In [1]:
import json
import re
import subprocess
import torch
from pathlib import Path

from scipy.io.wavfile import write as wav_write
from tqdm.notebook import tqdm

from src.models.hifi_gan.models import Generator, load_model as load_hifi
from src.train_config import TrainParams, load_config

In [2]:
# Experiment configuration; its `checkpoint_name` selects the checkpoint folder
# that every later cell reads models and vocabularies from.
CONFIG_PATH = "configs/freest_tune.yml"
config = load_config(CONFIG_PATH)
checkpoint_path = Path("checkpoints") / f"{config.checkpoint_name}"

In [3]:
# Generator checkpoints: files under <checkpoint>/hifi whose names start with "g_".
# Sorted so checkpoints are processed in a reproducible order (rglob order depends
# on the filesystem); directories are excluded.
# NOTE(review): "*.*" only matches names containing a dot; checkpoints saved without
# an extension (e.g. "g_00010000") would be skipped — confirm the naming scheme.
generators = sorted(
    file
    for file in (checkpoint_path / "hifi").rglob("*.*")
    if file.is_file() and file.name.startswith("g_")
)

In [6]:
# Use the first GPU when one is available, otherwise fall back to CPU so the
# notebook also runs on CPU-only machines.
# NOTE(review): `config.device` was deliberately bypassed here; confirm before using it.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [7]:
# Test phrases written as tone-numbered pinyin (presumably short Chinese idioms),
# one whitespace-separated token per syllable. The full-width comma is kept as its
# own token and is looked up in the lexicon like any syllable.
PHONEMES_CHI = [
    phrase.split()
    for phrase in (
        "qing1 chu1 yu2 lan2 er2 sheng4 yu2 lan2",
        "tian1 dao4 chou2 qin2",
        "jiu3 tian1 lan3 yue4",
        "sai1 weng1 shi1 ma3 ， yan1 zhi1 fei1 fu2",
        "yi1 ming2 jing1 ren2",
        "yi1 si1 bu4 gou3",
        "yi1 jian4 shuang1 diao1",
        "shan1 yu3 yu4 lai2 feng1 man3 lou2",
        "ma2 que4 sui1 xiao3 ， wu3 zang4 ju4 quan2",
        "qiang2 long2 nan2 ya1 di4 tou2 she2",
        "qian2 pa4 lang2 hou4 pa4 hu3",
        "da4 zhi4 ruo4 yu2",
    )
]

In [8]:
# Pinyin syllable -> phones lexicon (tab-separated).
lexicon_path = Path("models") / "pinyin-lexicon_with_tab.txt"

In [9]:
def read_lexicon(lex_path, encoding="utf-8"):
    """Read a tab-separated pronunciation lexicon into a dict.

    Each line is expected to look like ``word<TAB>phones[<TAB>...]``. Words are
    lower-cased, only the first field after the word is kept, and when a word
    occurs more than once the first occurrence wins. Blank lines and lines without
    a tab-separated second field are skipped instead of raising ``IndexError``.

    Args:
        lex_path: Path to the lexicon file.
        encoding: Text encoding of the file; explicit so the result does not depend
            on the platform's locale default.

    Returns:
        dict mapping lower-cased word -> phone string (callers split it on spaces).
    """
    lexicon = {}
    with open(lex_path, encoding=encoding) as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 2:
                # Blank or malformed line: nothing to map.
                continue
            word = fields[0].lower()
            if word not in lexicon:
                lexicon[word] = fields[1]
    return lexicon

In [10]:
# Syllable -> phone-string mapping used to expand PHONEMES_CHI below.
lexicon = read_lexicon(lex_path=lexicon_path)

In [11]:
# Expand every phrase's pinyin syllables into lexicon phones. Tokens missing from
# the lexicon (presumably the full-width comma, which has no pronunciation) are
# skipped and reported. Only missing keys are tolerated; any other error propagates
# instead of being hidden by a bare `except`.
PHONEMES = []
skipped_tokens = set()
for syllables in PHONEMES_CHI:
    phones = []
    for syllable in syllables:
        lexicon_entry = lexicon.get(syllable)
        if lexicon_entry is None:
            skipped_tokens.add(syllable)
            continue
        phones.extend(lexicon_entry.split(' '))
    PHONEMES.append(phones)

if skipped_tokens:
    print(f"Skipped tokens missing from the lexicon: {sorted(skipped_tokens)}")

In [29]:
# PHONEMES  # uncomment to inspect the expanded phone sequences

In [14]:
def to_phones(PHONEMES_TO_IDS, phones):
    """Convert a sequence of phone symbols into their integer ids.

    Args:
        PHONEMES_TO_IDS: mapping from phone symbol to model input id.
        phones: iterable of phone symbols.

    Returns:
        list of ids in the same order as ``phones``.

    Raises:
        KeyError: if a phone is not present in ``PHONEMES_TO_IDS``.
    """
    ids = []
    for phone in phones:
        ids.append(PHONEMES_TO_IDS[phone])
    return ids


In [15]:
# Map every expanded phone sequence to model input ids using the checkpoint's own
# phoneme vocabulary.
phonemes_json = checkpoint_path / "feature" / "phonemes.json"
with open(phonemes_json) as f:
    phonemes_to_ids = json.load(f)
phonemes_list = [to_phones(phonemes_to_ids, hp) for hp in PHONEMES]

In [16]:
# Load the whole pickled feature model (phoneme ids -> mels) directly onto `device`.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted checkpoints.
# On PyTorch >= 2.6 `weights_only` defaults to True, which rejects whole pickled
# modules; `weights_only=False` would then be needed — confirm the torch version.
feature_model_path = checkpoint_path / "feature" / "feature_model.pth"
feature_model = torch.load(feature_model_path, map_location=device)

In [17]:
# Inference mode: `train(False)` is what `nn.Module.eval()` does; it updates the
# module in place and returns it.
feature_model = feature_model.train(False)

In [18]:
def get_tacotron_batch(
    phonemes_ids, reference, speaker_id, device, mels_mean, mels_std
):
    """Assemble the single-utterance input tuple passed to ``feature_model.inference``.

    Args:
        phonemes_ids: sequence of integer phoneme ids for one utterance.
        reference: 2-D reference mel tensor. It is normalized with
            ``mels_mean``/``mels_std``, given a leading batch axis, and its two
            original axes are swapped (presumably (frames, n_mels) ->
            (1, n_mels, frames) — TODO confirm).
        speaker_id: integer speaker id.
        device: device the id and reference tensors are moved to.
        mels_mean: normalization mean, broadcast against ``reference``.
        mels_std: normalization std, broadcast against ``reference``.

    Returns:
        ``(phonemes_ids_tensor, text_lengths_tensor, speaker_ids_tensor, reference)``:
        ``(1, len)`` int64 ids, ``(1,)`` int64 length, ``(1,)`` int64 speaker id and
        the normalized ``(1, d1, d0)`` reference. The length tensor is left on CPU
        (presumably because the model packs sequences, which needs CPU lengths —
        confirm).
    """
    scaled = (reference - mels_mean) / mels_std
    reference_batch = scaled[None].permute(0, 2, 1).to(device)
    ids_batch = torch.LongTensor(phonemes_ids)[None].to(device)
    lengths = torch.LongTensor([len(phonemes_ids)])
    speakers = torch.LongTensor([speaker_id]).to(device)
    return ids_batch, lengths, speakers, reference_batch

In [19]:
# Reference mels for the configured language; the parent folder name of each file
# is used as the speaker name.
reference_pathes = Path("references") / f"{config.lang}"

In [20]:
# Root folder for the synthesized WAV files of this checkpoint.
generated_path = Path("generated_hifi") / f"{config.checkpoint_name}"

In [21]:
# Speaker name -> speaker id mapping saved alongside the feature model.
speakers_json = checkpoint_path / "feature" / "speakers.json"
speaker_to_id = json.loads(speakers_json.read_text())

In [22]:
# Mel normalization statistics (mean first, then std), loaded onto `device` as float32.
feature_dir = checkpoint_path / "feature"
mels_mean, mels_std = (
    torch.load(feature_dir / f"{stat_name}.pth", map_location=device).float()
    for stat_name in ("mels_mean", "mels_std")
)

In [23]:
# Sampling rate of the written WAV files.
# TODO(review): confirm this matches the rate the HiFi-GAN checkpoints were trained
# at (it may be available on `config`).
SAMPLE_RATE = 22050
# Scale from the generator's float waveform to 16-bit PCM.
MAX_WAV_VALUE = 32768.0


def _load_generator(generator_path):
    """Build a HiFi-GAN ``Generator`` from a ``g_*`` checkpoint, ready for inference.

    Weights are read on CPU and copied onto ``device`` by ``load_state_dict``;
    weight norm is removed before switching the module to eval mode.
    """
    state_dict = torch.load(generator_path, map_location="cpu")
    generator = Generator(config=config.train_hifi.model_param, num_mels=config.n_mels).to(device)
    generator.load_state_dict(state_dict["generator"])
    generator.remove_weight_norm()
    generator.eval()
    return generator


# Stage 1: run the feature model once per (reference, phrase) pair. De-normalized
# mels are cached on CPU so every generator checkpoint can reuse them.
mel_cache = []  # entries: (speaker, reference stem, phrase index, mel batch on CPU)
with torch.no_grad():
    for reference in tqdm(sorted(reference_pathes.rglob("*.pkl")), desc="feature model"):
        speaker = reference.parent.name
        speaker_id = speaker_to_id[speaker]
        ref_mel = torch.load(reference, map_location=device)
        for i, phonemes in enumerate(phonemes_list):
            batch = get_tacotron_batch(phonemes, ref_mel, speaker_id, device, mels_mean, mels_std)
            mels = feature_model.inference(batch)
            mels = mels.permute(0, 2, 1).squeeze(0)
            mels = mels * mels_std.to(device) + mels_mean.to(device)
            mel_cache.append((speaker, reference.stem, i, mels.unsqueeze(0).cpu()))

# Stage 2: vocode every cached mel with each generator. Each checkpoint is loaded
# once, rather than once per reference x phrase.
with torch.no_grad():
    for generator_path in tqdm(generators, desc="generators"):
        generator = _load_generator(generator_path)
        for speaker, reference_stem, i, x in mel_cache:
            y_g_hat = generator(x.to(device))
            # Clamp before casting: a sample of exactly +1.0 scales to 32768, which
            # does not fit in int16.
            audio = (y_g_hat.squeeze() * MAX_WAV_VALUE).clamp(-32768, 32767)
            audio = audio.to(torch.int16).cpu().numpy()
            save_path = generated_path / generator_path.stem / speaker / reference_stem
            save_path.mkdir(exist_ok=True, parents=True)
            wav_write(save_path / f"{i + 1}.wav", SAMPLE_RATE, audio)
        # Release this checkpoint's GPU memory before loading the next one.
        del generator
        torch.cuda.empty_cache()


  0%|          | 0/50 [00:00<?, ?it/s]