In [1]:
# !pip3 install tgt -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip3 install librosa -i https://pypi.tuna.tsinghua.edu.cn/simple

In [4]:
import whisper

class WhisperExtractor():
    """Extract Whisper encoder features from audio files.

    `get_feature` returns a (#features, time) array: the encoder output for the part
    of Whisper's 30-second input window that is covered by real audio (frames that only
    see zero-padding are dropped).
    """

    def __init__(self, fs=16000, device="cuda", model_name="large"):
        """Load a Whisper model onto `device`.

        Args:
            fs: Sample rate in Hz. Kept for backward compatibility only; frame trimming
                uses Whisper's own window length (`whisper.audio.N_SAMPLES`) because
                `whisper.load_audio` resamples to Whisper's rate regardless of this value.
            device: Torch device to place the model on (e.g. "cuda" or "cpu").
            model_name: Checkpoint name passed to `whisper.load_model` (default "large").
        """
        self.fs = fs
        self.model = whisper.load_model(model_name, device=device)
        self.device = device

    def get_feature(self, path):
        """Return Whisper encoder features for the audio file at `path`.

        Returns:
            numpy array of shape (#features, time): one row per encoder channel and one
            column per encoder frame that covers real (unpadded) audio. On CUDA the
            encoder runs in float16 (as `whisper.decode` does by default), so the array
            is float16 there; on other devices it is float32.
        """
        import torch  # local import: this cell does not import torch itself

        gtaudio = whisper.load_audio(path)
        # Fit the waveform to Whisper's fixed 30 s window; longer audio is truncated.
        audio = whisper.pad_or_trim(gtaudio)
        # Use the checkpoint's own mel-bin count instead of a hard-coded value
        # (e.g. large-v3 uses 128 bins, earlier checkpoints use 80).
        mel = whisper.log_mel_spectrogram(audio, n_mels=self.model.dims.n_mels).to(self.model.device)
        if self.model.device.type == "cuda":
            # Keep the fp16 precision that whisper.decode() applied to the encoder input.
            mel = mel.half()
        # Only the encoder output is needed, so run the encoder directly rather than a
        # full whisper.decode() pass (which also detects the language and generates text).
        with torch.no_grad():
            encoded = self.model.embed_audio(mel.unsqueeze(0))  # (1, time, #features)
        # Transpose to the documented (#features, time) layout.
        features = encoded[0].T.cpu().numpy()
        # The encoder spreads whisper.audio.N_SAMPLES input samples over features.shape[1]
        # frames; keep only the frames backed by real audio. Integer arithmetic avoids
        # float rounding at exact frame boundaries.
        n_frames = len(gtaudio) * features.shape[1] // whisper.audio.N_SAMPLES
        return features[:, :n_frames]

In [5]:
import warnings
warnings.filterwarnings("ignore")
import torch
import sys

import tgt
import os
import numpy as np
from tqdm import tqdm
import glob

In [6]:
# Corpus root (utterances live at <subset>/<speaker>/<chapter>/<utterance>.wav) and the
# directory that receives the extracted Whisper features.
dataset_dir = "/mntcephfs/lee_dataset/tts/LibriTTS_R/"
feat_base_dir = "/mntcephfs/lab_data/shoinoue/Dataset/LibriTTS_R/features/"

extractor = WhisperExtractor()

# Speaker IDs: names of every second-level entry under the corpus root, sorted.
speakers = sorted(os.path.basename(entry) for entry in glob.glob(dataset_dir + "*/*"))

In [7]:
# Set to True to write features to disk; with False the loop only runs extraction.
save = False
# When saving, utterances whose feature file already exists are skipped so an
# interrupted run can be resumed; set to True to recompute everything.
overwrite = False

# A set guards against the same file being collected twice if a speaker ID matched
# more than one subset directory.
files = sorted({
    wav
    for spk in tqdm(speakers)
    for wav in glob.glob(dataset_dir + f"*/{spk}/*/*.wav")
})
for path in tqdm(files):
    # Mirror <subset>/<speaker>/<chapter>/<utt> under feat_base_dir, ".wav" -> "_whisper.npy".
    stem = os.path.splitext("/".join(path.split("/")[-4:]))[0]
    savepath = feat_base_dir + stem + "_whisper.npy"
    if save and not overwrite and os.path.exists(savepath):
        continue
    embedding = extractor.get_feature(path)
    if save:
        # The <subset>/<speaker>/<chapter> directories may not exist yet under feat_base_dir.
        os.makedirs(os.path.dirname(savepath), exist_ok=True)
        np.save(savepath, embedding)

100%|██████████| 1230/1230 [00:02<00:00, 567.97it/s]
  0%|          | 7/160267 [00:08<55:01:25,  1.24s/it]
Exception ignored in: <function Group.__del__ at 0x155479692710>
Traceback (most recent call last):
  File "/mntcephfs/lab_data/shoinoue/miniconda3/envs/whisper/lib/python3.10/site-packages/regex/_regex_core.py", line 3035, in __del__
    def __del__(self):
KeyboardInterrupt: 

KeyboardInterrupt



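# Normalize

Fit a `StandardScaler` on the Whisper features extracted for the training subsets (`train*`), then save it with `joblib` under `ckpts/scalers/`.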

In [None]:
from sklearn.preprocessing import StandardScaler
import numpy as np
import joblib

In [26]:
# Basename of the fitted scaler checkpoint (saved as ckpts/scalers/<scaler_name>.save).
scaler_name = "LibriTTS-R_whisper"

In [30]:
# partial_fit accumulates mean/variance file by file, so all features never need to
# be loaded into memory at once.
scaler = StandardScaler()
# Training-subset feature files only. The "*[0-9]_whisper.npy" pattern requires the
# utterance stem to end in a digit right before the suffix. The set guards against
# collecting a file twice.
files = sorted({
    npy
    for spk in tqdm(speakers)
    for npy in glob.glob(feat_base_dir + f"train*/{spk}/*/*[0-9]_whisper.npy")
})
# Fail loudly rather than saving a scaler that was never fitted.
assert files, f"no training feature files found under {feat_base_dir}"
for path in tqdm(files):
    feat = np.load(path)
    # Transposed so rows are samples for StandardScaler; assumes each file holds a
    # (#features, time) array, as documented in WhisperExtractor.get_feature.
    scaler.partial_fit(feat.T)

scaler_filename = f"ckpts/scalers/{scaler_name}.save"
# joblib.dump does not create missing parent directories.
os.makedirs(os.path.dirname(scaler_filename), exist_ok=True)
joblib.dump(scaler, scaler_filename)
# Reload later with: scaler = joblib.load(scaler_filename)

100%|██████████| 1230/1230 [00:04<00:00, 282.54it/s]
100%|██████████| 333/333 [00:01<00:00, 287.49it/s]
