# StyleTTS Demo (JSUT)


### Utils

In [1]:
%cd ..

/home/seichi/StyleTTS


In [3]:
# load packages
import os
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa

from models import *
from utils import *

# Select which physical GPU this notebook uses (default: GPU 1). A value already
# set by the launching environment is respected instead of being overwritten.
# This only takes effect if CUDA has not yet been initialized in this kernel.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "1")

%matplotlib inline

In [5]:
# Pick the compute device, falling back to CPU when CUDA is unavailable.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# torch.cuda.get_device_name() raises when no CUDA device is usable, so only
# query it when CUDA was actually selected.
if device == 'cuda':
    print(torch.cuda.get_device_name())

cuda
NVIDIA GeForce GTX 1070


In [15]:

import pandas as pd
import pyopenjtalk

# Phoneme -> index table, relative to the repository root (the notebook runs `%cd ..`).
DEFAULT_DICT_PATH = os.path.join('Configs', 'word_index_dict.txt')
class TextCleaner:
    """Convert text to a list of phoneme indices.

    Text is converted to phonemes with ``pyopenjtalk.g2p``. Each phoneme is then
    looked up in a header-less two-column CSV (``phoneme,index``) that is loaded
    when the object is constructed.
    """

    def __init__(self, word_index_dict_path=DEFAULT_DICT_PATH):
        self.word_index_dictionary = self.load_dictionary(word_index_dict_path)

    def __call__(self, text):
        """Return the list of phoneme indices for ``text``.

        Phonemes that are not in the dictionary are skipped. If any are found,
        one message is printed that names them, instead of silently dropping
        them or repeating the text once per miss.
        """
        indexes = []
        unknown = []
        # str.split() with no argument drops empty strings. An empty g2p result
        # therefore yields no tokens instead of a single '' lookup miss.
        for phoneme in pyopenjtalk.g2p(text).split():
            index = self.word_index_dictionary.get(phoneme)
            if index is None:
                unknown.append(phoneme)
            else:
                indexes.append(index)
        if unknown:
            print('Unknown phonemes %s skipped in: %s' % (unknown, text))
        return indexes

    def load_dictionary(self, path):
        """Load a header-less ``phoneme,index`` CSV file into a dict."""
        csv = pd.read_csv(path, header=None).values
        return {word: index for word, index in csv}

# Shared text -> phoneme-index converter used by the tokenization cells below.
# NOTE: the name is misspelled ("clenaer"). It is kept as-is because later
# cells reference `textclenaer` by exactly this name.
textclenaer = TextCleaner()

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Mel-spectrogram front-end used by `preprocess`. These settings are presumably
# the ones the model was trained with (TODO: confirm against the training config).
to_mel = torchaudio.transforms.MelSpectrogram(
    n_fft=2048,
    win_length=1200,
    hop_length=300,
    n_mels=80,
)
# Normalization applied to the log-mel features: (log_mel - mean) / std.
mean = -4
std = 4

def length_to_mask(lengths):
    """Build a padding mask from a tensor of sequence lengths.

    Args:
        lengths: integer tensor of shape (batch,), e.g. ``LongTensor([n])``.

    Returns:
        Boolean tensor of shape (batch, max(lengths)). It is True at padded
        positions, i.e. where the 1-based position exceeds that row's length.
    """
    batch_size = lengths.shape[0]
    positions = torch.arange(lengths.max()).unsqueeze(0)
    positions = positions.expand(batch_size, -1).type_as(lengths)
    return torch.gt(positions + 1, lengths.unsqueeze(1))

def preprocess(wave):
    """Turn a numpy waveform into a normalized log-mel tensor.

    Uses the module-level ``to_mel``, ``mean`` and ``std``, and prepends a
    leading axis of size 1 to the mel output.
    """
    mel = to_mel(torch.from_numpy(wave).float())
    log_mel = torch.log(1e-5 + mel.unsqueeze(0))
    return (log_mel - mean) / std

def compute_style(ref_dicts):
    """Compute a style embedding for each reference utterance.

    Args:
        ref_dicts: mapping from a key (e.g. an utterance or speaker name) to an
            audio file path.

    Returns:
        Dict mapping each key to ``(style_embedding, trimmed_audio_24kHz)``.
        References whose style encoding fails are skipped, and a message is
        printed for each one.
    """
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        # librosa.load resamples to the requested rate, so the audio is already
        # at 24 kHz and no separate resampling step is needed.
        wave, _ = librosa.load(path, sr=24000)
        # Strip leading/trailing segments quieter than 30 dB below the peak.
        audio, _ = librosa.effects.trim(wave, top_db=30)
        mel_tensor = preprocess(audio).to(device)
        try:
            with torch.no_grad():
                ref = model.style_encoder(mel_tensor.unsqueeze(1))
        except Exception as err:
            # Best-effort: skip this reference, but report why instead of
            # silently hiding the failure.
            print('Skipping reference %s (%s): %s' % (key, path, err))
            continue
        reference_embeddings[key] = (ref.squeeze(1), audio)

    return reference_embeddings

### Load models

In [10]:
# load hifi-gan

import sys
sys.path.insert(0, "./Demo/hifi-gan")

import glob
import os
import argparse
import json
import torch
from scipy.io.wavfile import write
from attrdict import AttrDict
from vocoder import Generator
import librosa
import numpy as np
import torchaudio

# Placeholder for the vocoder hyper-parameters. It is reassigned from the
# checkpoint's config.json when the generator is built further down.
h = None

def load_checkpoint(filepath, device):
    """Load a torch checkpoint file, mapping its tensors onto ``device``.

    Args:
        filepath: path to the checkpoint file.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        The deserialized checkpoint object.

    Raises:
        FileNotFoundError: if ``filepath`` is not an existing file.
    """
    # Use an explicit check rather than `assert`: it survives `python -O` and
    # reports which path was missing.
    if not os.path.isfile(filepath):
        raise FileNotFoundError("Checkpoint not found: '{}'".format(filepath))
    print("Loading '{}'".format(filepath))
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict

def scan_checkpoint(cp_dir, prefix):
    """Return the lexicographically last path in ``cp_dir`` starting with ``prefix``.

    With zero-padded step numbers (e.g. ``g_00750000``) this is the newest
    checkpoint. Returns ``''`` when no file matches.
    """
    candidates = glob.glob(os.path.join(cp_dir, prefix + '*'))
    return max(candidates) if candidates else ''

# Locate the newest HiFi-GAN generator checkpoint and the config.json next to it.
cp_g = scan_checkpoint("Vocoder/", 'g_')
if not cp_g:
    # Without this guard an empty result silently becomes 'config.json' in the
    # current working directory, and the failure surfaces later.
    raise FileNotFoundError("No generator checkpoint matching 'Vocoder/g_*' found")

config_file = os.path.join(os.path.dirname(cp_g), 'config.json')
with open(config_file) as f:
    json_config = json.load(f)
h = AttrDict(json_config)

device = torch.device(device)
generator = Generator(h).to(device)

state_dict_g = load_checkpoint(cp_g, device)
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
# Presumably folds the weight-norm reparametrization into plain weights for inference.
generator.remove_weight_norm()



Loading 'Vocoder/g_00750000'
Complete.
Removing weight norm...


In [53]:
# Load the StyleTTS checkpoint and the training config stored beside it.
model_path = "./Models/JSUT/epoch_2nd_00100.pth"
model_config_path = "./Models/JSUT/config.yml"

with open(model_config_path) as f:
    config = yaml.safe_load(f)

# Pretrained ASR model (used as the text aligner); paths come from the config.
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# Pretrained F0 (pitch) extractor.
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

model = build_model(Munch(config['model_params']), text_aligner, pitch_extractor)

params = torch.load(model_path, map_location='cpu')['net']
for key in model:
    # Discriminator weights are skipped (presumably only needed for training).
    if key in params and 'discriminator' not in key:
        print('%s loaded' % key)
        model[key].load_state_dict(params[key])

for key in model:
    model[key].eval()
    model[key].to(device)

predictor loaded
decoder loaded
pitch_extractor loaded
text_encoder loaded
style_encoder loaded
text_aligner loaded


### Synthesize speech (seen references from the JSUT training set)

In [43]:
# Use the first few training utterances as style references.
N_REFERENCES = 3

train_path = config.get('train_data', None)
val_path = config.get('val_data', None)
train_list, val_list = get_data_path_list(train_path, val_path)

ref_dicts = {}
# Slicing (instead of indexing range(3)) tolerates a list shorter than N_REFERENCES.
for line in train_list[:N_REFERENCES]:
    # The first '|'-separated field is used as the audio path.
    filename = line.split('|')[0]
    name = os.path.splitext(os.path.basename(filename))[0]
    ref_dicts[name] = filename

reference_embeddings = compute_style(ref_dicts)

In [64]:
# Japanese sentence to synthesize.
text = '''水をマレーシアから買わなくてはならないのです'''

# Phoneme indices with index 0 added at both ends, as a (1, T) LongTensor on `device`.
token_ids = [0] + textclenaer(text) + [0]
tokens = torch.LongTensor(token_ids).to(device).unsqueeze(0)

In [66]:
import IPython.display as ipd

converted_samples = {}

with torch.no_grad():
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    m = length_to_mask(input_lengths).to(device)
    t_en = model.text_encoder(tokens, input_lengths, m)

    for key, (ref, _) in reference_embeddings.items():
        s = ref.squeeze(1)

        # Predict a duration (in frames) for each token, at least 1 frame each.
        d = model.predictor.text_encoder(t_en, s, input_lengths, m)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Alignment matrix (tokens x frames): row i is 1 over the consecutive
        # frames assigned to token i.
        n_tokens = int(input_lengths.item())
        pred_aln_trg = torch.zeros(n_tokens, int(pred_dur.sum()))
        c_frame = 0
        for i in range(n_tokens):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i])] = 1
            c_frame += int(pred_dur[i])
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(device)

        # Map token-level features onto frames, then predict F0 and N.
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder(t_en @ pred_aln_trg, F0_pred, N_pred,
                            ref.squeeze().unsqueeze(0))

        # Run the vocoder once per reference.
        y_g_hat = generator(out.squeeze().unsqueeze(0))
        converted_samples[key] = y_g_hat.squeeze().cpu().numpy()

for key, wave in converted_samples.items():
    print('Synthesized: %s' % key)
    ipd.display(ipd.Audio(wave, rate=24000))
    try:
        print('Reference: %s' % key)
        ipd.display(ipd.Audio(reference_embeddings[key][-1], rate=24000))
    except KeyError:
        # No reference audio stored under this key.
        continue

Synthesized: BASIC5000_0001


Reference: BASIC5000_0001


Synthesized: BASIC5000_0002


Reference: BASIC5000_0002


Synthesized: BASIC5000_0003


Reference: BASIC5000_0003


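### Zero-shot TTS (unseen speakers, LibriTTS test-clean)

This section requires a local copy of LibriTTS test-clean. Note that the model loaded above was trained on JSUT (Japanese), while the references and the input text below are English.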

In [19]:
# Pick one reference utterance for each of the first few speakers in test-clean.
# The dataset root can be overridden with the LIBRITTS_TEST_CLEAN environment variable.
test_clean_path = os.environ.get(
    'LIBRITTS_TEST_CLEAN', '/share/naplab/users/yl4579/data/LibriTTS/test-clean/')
N_SPEAKERS = 3
rng = np.random.default_rng(0)  # fixed seed so the chosen chapters are reproducible

if not os.path.isdir(test_clean_path):
    raise FileNotFoundError(
        "LibriTTS test-clean not found at %r; set LIBRITTS_TEST_CLEAN" % test_clean_path)

ref_dicts = {}
# Sort, because os.scandir order is arbitrary and "first N speakers" would not be stable.
spk_dirs = sorted(e.path for e in os.scandir(test_clean_path) if e.is_dir())
for spk_path in spk_dirs[:N_SPEAKERS]:
    spk = os.path.basename(spk_path)
    # Only choose among sub-directories (chapters), never stray files.
    chapters = sorted(e.path for e in os.scandir(spk_path) if e.is_dir())
    if not chapters:
        continue
    chapter_path = chapters[rng.integers(len(chapters))]
    wavs = sorted(f for f in os.listdir(chapter_path) if f.endswith('.wav'))
    if wavs:
        ref_dicts[spk] = os.path.join(chapter_path, wavs[0])
reference_embeddings = compute_style(ref_dicts)

FileNotFoundError: [Errno 2] No such file or directory: '/share/naplab/users/yl4579/data/LibriTTS/test-clean/'

In [None]:
# synthesize a text
# NOTE(review): this is English text, but tokenization uses pyopenjtalk.g2p and
# the loaded model is from Models/JSUT. Confirm this pairing is intended.
text = ''' StyleTTS is a style based generative model that can synthesize diverse speech with natural prosody from a reference speech utterance. '''

In [20]:
# Phoneme indices with index 0 at both ends, shaped (1, T) on `device`.
tokens = torch.LongTensor([0] + textclenaer(text) + [0]).to(device).unsqueeze(0)

In [21]:
converted_samples = {}

with torch.no_grad():
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    m = length_to_mask(input_lengths).to(device)
    t_en = model.text_encoder(tokens, input_lengths, m)

    for key, (ref, _) in reference_embeddings.items():
        s = ref.squeeze(1)

        # Predict a duration (in frames) for each token, at least 1 frame each.
        d = model.predictor.text_encoder(t_en, s, input_lengths, m)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Alignment matrix (tokens x frames): row i is 1 over the consecutive
        # frames assigned to token i.
        n_tokens = int(input_lengths.item())
        pred_aln_trg = torch.zeros(n_tokens, int(pred_dur.sum()))
        c_frame = 0
        for i in range(n_tokens):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i])] = 1
            c_frame += int(pred_dur[i])
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(device)

        # Map token-level features onto frames, then predict F0 and N.
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder(t_en @ pred_aln_trg, F0_pred, N_pred,
                            ref.squeeze().unsqueeze(0))

        # Run the vocoder once per reference.
        y_g_hat = generator(out.squeeze().unsqueeze(0))
        converted_samples[key] = y_g_hat.squeeze().cpu().numpy()

In [22]:
import IPython.display as ipd

# Play each synthesized sample, followed by the reference audio it was conditioned on.
for key, wave in converted_samples.items():
    print('Synthesized: %s' % key)
    ipd.display(ipd.Audio(wave, rate=24000))
    try:
        print('Reference: %s' % key)
        ipd.display(ipd.Audio(reference_embeddings[key][-1], rate=24000))
    except KeyError:
        # No reference audio stored under this key.
        continue

Synthesized: BASIC5000_0001


Reference: BASIC5000_0001


Synthesized: BASIC5000_0002


Reference: BASIC5000_0002


Synthesized: BASIC5000_0003


Reference: BASIC5000_0003
