## FastSpeech 2 FGP Emphasis Demo


#### Import libraries

In [23]:
import re
import argparse
import uuid
from string import punctuation

import os
import json

import matplotlib
from matplotlib import pyplot as plt

import torch
import yaml
import numpy as np
from torch.utils.data import DataLoader
from g2p_en import G2p
from pypinyin import pinyin, Style

from utils.model import get_vocoder
from utils.tools import to_device 
from dataset import TextDataset
from text import text_to_sequence
from sys import stdin, exit

import time
from scipy.io import wavfile

from utils.model import vocoder_infer
from model import FastSpeech2, ScheduledOptim

import IPython.display as ipd
from IPython.display import Image



#### Define Helper Functions

In [24]:
def _control_word_to_phoneme(control_word_level, text, kind):
    """Expand a word-level control list to one entry per phoneme.

    Shared implementation behind the pitch/energy/duration converters, which
    differ only in the label used in their diagnostic messages.

    Args:
        control_word_level: One control value per alphabetic word of ``text``.
        text: The raw text that will be synthesized.
        kind: Label ("pitch", "energy" or "duration") used in the messages.

    Returns:
        A list holding, for every token of ``text``, the control value of the
        current word repeated once per phoneme of that token. The scalar
        ``1.0`` is returned instead when the number of control values does not
        match the number of words, or when ``text`` contains no alphabetic
        word at all.
    """
    text = text.rstrip(punctuation)
    words = re.split(r"([,;.\-\?\!\s+])", text)
    proper_words = [w for w in words if re.search("[a-zA-Z]", w) is not None]

    if len(proper_words) != len(control_word_level):
        print("Word amount and word level {} control parameter amount does not match!".format(kind))
        print("{} control parameters amount: {} (parameters: {})".format(kind, len(control_word_level), control_word_level))
        print("word amount: {} (words: {})".format(len(proper_words), proper_words))
        return 1.0
    if not proper_words:
        # Nothing to scale; without this guard the indexing below would raise
        # IndexError for empty or letter-free text.
        return 1.0

    # Loaded only once the input is known to be valid.
    lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
    g2p = G2p()

    control_phoneme_level = []
    proper_word_index = 0
    for w in words:
        key = w.lower()
        if key in lexicon:
            phone_amount = len(lexicon[key])
        else:
            phone_amount = sum(1 for p in g2p(w) if p != " ")
        # Separator tokens are reached before the index advances, so their
        # phonemes receive the control value of the next proper word.
        control_phoneme_level += [control_word_level[proper_word_index]] * phone_amount
        if key == proper_words[proper_word_index].lower():
            proper_word_index += 1
            if proper_word_index >= len(proper_words):
                break
    if len(proper_words) != proper_word_index:
        print("Bug Warning! len proper words: {}, proper_word_index: {}".format(len(proper_words), proper_word_index))

    # NOTE(review): presumably accounts for one extra phoneme produced for a
    # trailing space -- confirm against preprocess_english.
    if text[-1] == " ":
        control_phoneme_level += [1.0]
    return control_phoneme_level


def pitch_control_word_to_phoneme(pitch_control_word_level, text):
    """Expand word-level pitch controls to phoneme level.

    Returns the phoneme-level list, or the scalar ``1.0`` when the controls
    cannot be aligned with the words of ``text``.
    """
    return _control_word_to_phoneme(pitch_control_word_level, text, "pitch")


def energy_control_word_to_phoneme(energy_control_word_level, text):
    """Expand word-level energy controls to phoneme level.

    Returns the phoneme-level list, or the scalar ``1.0`` when the controls
    cannot be aligned with the words of ``text``.
    """
    return _control_word_to_phoneme(energy_control_word_level, text, "energy")


def duration_control_word_to_phoneme(duration_control_word_level, text):
    """Expand word-level duration controls to phoneme level.

    Returns the phoneme-level list, or the scalar ``1.0`` when the controls
    cannot be aligned with the words of ``text``.
    """
    return _control_word_to_phoneme(duration_control_word_level, text, "duration")

def expand(values, durations):
    """Repeat each value as many times as its paired duration says.

    Pairs are formed with ``zip``, so the shorter input bounds the result.
    Durations are truncated with ``int`` and negative ones count as zero.
    Returns the repeated values as a NumPy array.
    """
    repeated = [
        item
        for item, length in zip(values, durations)
        for _ in range(max(0, int(length)))
    ]
    return np.array(repeated)

def plot_mel(data, stats, titles):
    """Plot mel-spectrograms with pitch and energy contours drawn on top.

    Args:
        data: Sequence of ``(mel, pitch, energy)`` triples, one per subplot.
            Assumes ``mel`` is a 2-D array of shape (n_mels, frames) and that
            ``pitch`` is normalized with the statistics in ``stats`` --
            TODO confirm against the caller.
        stats: ``(pitch_min, pitch_max, pitch_mean, pitch_std, energy_min,
            energy_max)``; the pitch bounds are de-normalized with the mean
            and standard deviation before being used as axis limits.
        titles: One title per subplot, or ``None`` for untitled subplots.

    Returns:
        The matplotlib figure holding the subplots.
    """
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    if titles is None:
        titles = [None for _ in range(len(data))]
    pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
    pitch_min = pitch_min * pitch_std + pitch_mean
    pitch_max = pitch_max * pitch_std + pitch_mean

    def add_axis(fig, old_ax):
        # Transparent axes stacked exactly on top of an existing subplot.
        ax = fig.add_axes(old_ax.get_position(), anchor="W")
        ax.set_facecolor("None")
        return ax

    for row, (mel, pitch, energy) in enumerate(data):
        base_ax = axes[row][0]
        pitch = pitch * pitch_std + pitch_mean

        base_ax.imshow(mel, origin="lower")
        base_ax.set_aspect(2.5, adjustable="box")
        base_ax.set_ylim(0, mel.shape[0])
        base_ax.set_title(titles[row], fontsize="medium")
        base_ax.tick_params(labelsize="x-small", left=False, labelleft=False)
        base_ax.set_anchor("W")

        pitch_ax = add_axis(fig, base_ax)
        pitch_ax.plot(pitch, color="tomato")
        pitch_ax.set_xlim(0, mel.shape[1])
        pitch_ax.set_ylim(0, pitch_max)
        pitch_ax.set_ylabel("F0", color="tomato")
        pitch_ax.tick_params(
            labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
        )

        energy_ax = add_axis(fig, base_ax)
        energy_ax.plot(energy, color="darkviolet")
        energy_ax.set_xlim(0, mel.shape[1])
        energy_ax.set_ylim(energy_min, energy_max)
        energy_ax.set_ylabel("Energy", color="darkviolet")
        energy_ax.yaxis.set_label_position("right")
        energy_ax.tick_params(
            labelsize="x-small",
            colors="darkviolet",
            bottom=False,
            labelbottom=False,
            left=False,
            labelleft=False,
            right=True,
            labelright=True,
        )

    # The caller keeps this value; previously the function returned None.
    return fig

def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    """Plot the predicted spectrograms and vocode the batch to audio.

    Args:
        targets: The input batch. Unused here; kept for interface
            compatibility.
        predictions: Model output, indexed positionally: ``[1]`` mel
            predictions, ``[2]`` pitch, ``[3]`` energy, ``[5]`` durations,
            ``[8]`` source lengths, ``[9]`` mel lengths.
        vocoder: Vocoder passed through to ``vocoder_infer``.
        model_config: Model configuration passed to ``vocoder_infer``.
        preprocess_config: Preprocessing configuration (feature levels, hop
            length, location of ``stats.json``).
        path: Directory in which ``mel.png`` is written.

    Returns:
        ``(wav, fig)``: the first vocoded waveform of the batch and the value
        returned by ``plot_mel`` for the last plotted sample (``None`` when
        the batch is empty).
    """
    preprocessing = preprocess_config["preprocessing"]
    pitch_is_phoneme_level = preprocessing["pitch"]["feature"] == "phoneme_level"
    energy_is_phoneme_level = preprocessing["energy"]["feature"] == "phoneme_level"

    # The statistics do not depend on the sample, so read them once.
    stats_path = os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
    with open(stats_path) as f:
        stats = json.load(f)
    stats = stats["pitch"] + stats["energy"][:2]

    fig = None
    for i in range(len(predictions[0])):
        src_len = predictions[8][i].item()
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
        duration = predictions[5][i, :src_len].detach().cpu().numpy()

        # Phoneme-level features are stretched to frame level via the durations.
        if pitch_is_phoneme_level:
            pitch = predictions[2][i, :src_len].detach().cpu().numpy()
            pitch = expand(pitch, duration)
        else:
            pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
        if energy_is_phoneme_level:
            energy = predictions[3][i, :src_len].detach().cpu().numpy()
            energy = expand(energy, duration)
        else:
            energy = predictions[3][i, :mel_len].detach().cpu().numpy()

        fig = plot_mel(
            [
                (mel_prediction.cpu().numpy(), pitch, energy),
            ],
            stats,
            ["Synthetized Spectrogram"],
        )
        # Every sample writes the same file, so the last one wins.
        plt.savefig(os.path.join(path, "mel.png"))
        plt.close()

    mel_predictions = predictions[1].transpose(1, 2)
    lengths = predictions[9] * preprocessing["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    return wav_predictions[0], fig
        

def get_model(restore_step, configs, device, train=False):
    """Build a FastSpeech2 model and optionally restore a checkpoint.

    Args:
        restore_step: Checkpoint step to restore; a falsy value skips loading.
            The checkpoint is read from
            ``<train_config["path"]["ckpt_path"]>/<restore_step>.pth.tar``.
        configs: ``(preprocess_config, model_config, train_config)``.
        device: Device the model is moved to and the checkpoint is mapped to.
        train: When true, also build the scheduled optimizer and put the
            model in training mode.

    Returns:
        ``(model, scheduled_optim)`` when ``train`` is true, otherwise the
        model in eval mode with gradients disabled on its parameters.
    """
    (preprocess_config, model_config, train_config) = configs

    model = FastSpeech2(preprocess_config, model_config).to(device)
    if restore_step:
        ckpt_path = os.path.join(
            train_config["path"]["ckpt_path"],
            "{}.pth.tar".format(restore_step),
        )
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model"])

    if train:
        scheduled_optim = ScheduledOptim(
            model, train_config, model_config, restore_step
        )
        if restore_step:
            scheduled_optim.load_state_dict(ckpt["optimizer"])
        model.train()
        return model, scheduled_optim

    model.eval()
    # Call the method: assigning to ``model.requires_grad_`` only shadowed it
    # with an attribute and left gradients enabled.
    model.requires_grad_(False)
    return model

def read_lexicon(lex_path):
    """Load a pronunciation lexicon into a dict.

    Each line is split on whitespace: the first field is the word and the
    remaining fields are its phones. Words are keyed in lowercase, and the
    first entry seen for a word is kept.
    """
    entries = {}
    with open(lex_path) as handle:
        for raw_line in handle:
            head, *tail = re.split(r"\s+", raw_line.strip("\n"))
            entries.setdefault(head.lower(), tail)
    return entries

def preprocess_english(text, preprocess_config):
    """Convert English text to a NumPy array of symbol ids.

    Trailing punctuation is stripped and the text is split on punctuation and
    whitespace. Tokens found in the lexicon use its phones; every other token
    goes through g2p with space phones dropped. Phones that are empty or a
    single non-word, non-space character become ``sp``, and the resulting
    brace-wrapped phoneme string is passed to ``text_to_sequence`` with the
    configured text cleaners.
    """
    stripped = text.rstrip(punctuation)
    lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
    g2p = G2p()

    phones = []
    for token in re.split(r"([,;.\-\?\!\s+])", stripped):
        key = token.lower()
        if key in lexicon:
            phones.extend(lexicon[key])
        else:
            phones.extend(p for p in g2p(token) if p != " ")

    braced = "{" + "}{".join(phones) + "}"
    braced = re.sub(r"\{[^\w\s]?\}", "{sp}", braced)
    phoneme_string = braced.replace("}{", " ")

    cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
    return np.array(text_to_sequence(phoneme_string, cleaners))

def synthesize(model, step, configs, vocoder, batchs, control_values):
    """Run the model on the first batch and return its audio and figure.

    Args:
        model: Callable model; invoked with ``batch[2:]`` plus the three
            control keyword arguments.
        step: Unused; kept for interface compatibility.
        configs: ``(preprocess_config, model_config, train_config)``.
        vocoder: Vocoder handed to ``synth_samples``.
        batchs: Iterable of batches. Only the first one is synthesized.
        control_values: ``(pitch_control, energy_control, duration_control)``.

    Returns:
        ``(wave, fig)`` from ``synth_samples`` for the first batch, or
        ``None`` when ``batchs`` is empty.

    Note:
        Batches are moved to the module-level ``device``.
    """
    preprocess_config, model_config, train_config = configs
    pitch_control, energy_control, duration_control = control_values

    for batch in batchs:
        batch = to_device(batch, device)
        with torch.no_grad():
            output = model(
                *(batch[2:]),
                p_control=pitch_control,
                e_control=energy_control,
                d_control=duration_control
            )
            wave, fig = synth_samples(
                batch,
                output,
                vocoder,
                model_config,
                preprocess_config,
                train_config["path"]["result_path"],
            )
            # The demo synthesizes a single utterance, so stop after the
            # first batch.
            return wave, fig

def parse_xml_input(xml_text):
    """Parse SSML-like markup into plain text plus word-level controls.

    Recognized tags are ``<prosody range="N%">`` (pitch),
    ``<prosody volume="N%">`` (energy), ``<prosody rate="N%">`` (duration,
    stored as the inverse multiplier) and ``<emphasis>`` (duration 1.8 and
    pitch ``"HIGH"``), each with a matching closing tag. Tags may be nested;
    the innermost open tag of each kind applies. Unknown tags and closing
    tags without a matching opening tag are reported and ignored.

    Args:
        xml_text: The marked-up input text.

    Returns:
        ``(cleaned_text, pitch, energy, duration)``: the text tokens joined by
        single spaces, and three lists with one control value per alphabetic
        word of that text. Words are counted with the same split pattern and
        letter test that the ``*_control_word_to_phoneme`` converters use, so
        the list lengths line up with their word count.

    Raises:
        ValueError: If a prosody value is not a number.
        ZeroDivisionError: If a prosody rate of 0 is given.
    """
    cleaned_text = []
    pitch_control_word_level = []
    energy_control_word_level = []
    duration_control_word_level = []

    # One stack per control so nested tags restore the outer value on close.
    current_pitch = []
    current_energy = []
    current_duration = []

    strip_chars = {ord(ch): None for ch in '"% '}

    def percent_to_multiplier(raw):
        # Percentage value to multiplier.
        return float(raw.translate(strip_chars)) / 100.0

    def close_tag(tag, *stacks):
        if all(stacks):
            for stack in stacks:
                stack.pop()
        else:
            print("Warning! Closing tag without matching opening tag: {}".format(tag))

    # Glue the tag name and attribute together so whitespace splitting keeps
    # each prosody tag in one piece.
    for attribute in ("rate", "volume", "range"):
        xml_text = xml_text.replace("<prosody " + attribute, "<prosody-" + attribute)
        xml_text = xml_text.replace("</prosody " + attribute, "</prosody-" + attribute)
    # Detach tags from preceding text, e.g. "word</emphasis>".
    xml_text = xml_text.replace("<", " <")

    # TODO: check for syntax errors!
    text_parts = [part for part in re.split(r"[\s>]", xml_text) if part]

    print("text parts: " + str(text_parts))

    for text in text_parts:
        if not text.startswith("<"):
            # Text to synthesize.
            cleaned_text.append(text)
            # A token may hold zero words (",") or several ("well-known").
            word_count = sum(
                1
                for piece in re.split(r"([,;.\-\?\!\s+])", text)
                if re.search("[a-zA-Z]", piece) is not None
            )
            text_pitch = current_pitch[-1] if current_pitch else 1.0
            text_energy = current_energy[-1] if current_energy else 1.0
            text_duration = current_duration[-1] if current_duration else 1.0
            pitch_control_word_level.extend([text_pitch] * word_count)
            energy_control_word_level.extend([text_energy] * word_count)
            duration_control_word_level.extend([text_duration] * word_count)
            continue

        if not text.startswith("</"):
            print("opening tag")
            body = text[1:]
            if body.startswith("prosody-range="):
                print("pitch")
                current_pitch.append(percent_to_multiplier(body[len("prosody-range="):]))
            elif body.startswith("prosody-volume="):
                print("energy")
                current_energy.append(percent_to_multiplier(body[len("prosody-volume="):]))
            elif body.startswith("prosody-rate="):
                print("duration")
                rate = percent_to_multiplier(body[len("prosody-rate="):])
                # Higher prosody rate tag means faster speech, but internally
                # duration parameter is handled the opposite way => use inverse
                current_duration.append(1.0 / rate)
            elif body.startswith("emphasis"):
                print("emph")
                current_duration.append(1.8)
                current_pitch.append("HIGH")
            else:
                print("Warning! Unknown tag: {}".format(text))
        else:
            print("closing tag")
            name = text[2:]
            if name == "prosody-range":
                print("pitch")
                close_tag(text, current_pitch)
            elif name == "prosody-volume":
                print("energy")
                close_tag(text, current_energy)
            elif name == "prosody-rate":
                print("duration")
                close_tag(text, current_duration)
            elif name.startswith("emphasis"):
                print("close emph")
                close_tag(text, current_duration, current_pitch)
            else:
                print("Warning! Unknown tag: {}".format(text))

    cleaned_text = ' '.join(cleaned_text)
    print("Cleaned text: " + cleaned_text)
    print("pitch control: " + str(pitch_control_word_level))
    print("energy control: " + str(energy_control_word_level))
    print("duration control: " + str(duration_control_word_level))

    return cleaned_text, pitch_control_word_level, energy_control_word_level, duration_control_word_level

            

In [25]:
# Fall back to the CPU when no CUDA device is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
restore_step = 900000

# Load the three LJSpeech YAML configs; ``with`` closes each file after parsing.
with open("config/LJSpeech/preprocess.yaml", "r") as f:
    preprocess_config = yaml.load(f, Loader=yaml.FullLoader)
with open("config/LJSpeech/model.yaml", "r") as f:
    model_config = yaml.load(f, Loader=yaml.FullLoader)
with open("config/LJSpeech/train.yaml", "r") as f:
    train_config = yaml.load(f, Loader=yaml.FullLoader)

configs = (preprocess_config, model_config, train_config)

#### Load Models

In [26]:
# Acoustic model restored from the checkpoint at `restore_step`, in inference
# mode (`train` defaults to False).
model = get_model(restore_step, configs, device)

# Vocoder built from the model configuration.
vocoder = get_vocoder(model_config, device)

Removing weight norm...


#### Synthesize Speech

In [31]:
speakers = np.array([0])
# Set to True when `input_txt` contains the markup understood by parse_xml_input.
xml = False

input_txt = "Yes."

if xml:
    input_txt, pitch_control, energy_control, duration_control = parse_xml_input(input_txt)
else:
    pitch_control = 1.0
    energy_control = 1.0
    duration_control = 1.0

# Word-level control lists are expanded to phoneme level; plain float
# controls are passed through unchanged.
if not isinstance(pitch_control, float):
    pitch_control_phoneme_level = pitch_control_word_to_phoneme(pitch_control, input_txt)
    pitch_control = pitch_control_phoneme_level

# get energy control
if not isinstance(energy_control, float):
    energy_control_phoneme_level = energy_control_word_to_phoneme(energy_control, input_txt)
    energy_control = energy_control_phoneme_level

# get duration control
if not isinstance(duration_control, float):
    duration_control_phoneme_level = duration_control_word_to_phoneme(duration_control, input_txt)
    duration_control = duration_control_phoneme_level

control_values = pitch_control, energy_control, duration_control

ids = [str(uuid.uuid4())]
raw_texts = [input_txt[:100]]
texts = np.array([preprocess_english(input_txt, preprocess_config)])
text_lens = np.array([len(texts[0])])
batchs = [(ids, raw_texts, speakers, texts, text_lens, max(text_lens))]
wave, fig = synthesize(model, restore_step, configs, vocoder, batchs, control_values)

# Only the last expression of a cell is rendered automatically, so the image
# has to be displayed explicitly. It is read from the directory that
# synthesize() passes to synth_samples for saving mel.png.
ipd.display(ipd.Image(filename=os.path.join(train_config["path"]["result_path"], "mel.png")))
ipd.Audio(wave, rate=preprocess_config["preprocessing"]["audio"]["sampling_rate"])