# Preparations

In [None]:
#!g1.4
%%bash 
#install libraries
# NOTE(review): versions are unpinned, so a later re-run may install different
# releases. Inside %%bash, `pip` may also resolve to a different environment than
# the kernel's; `%pip install` in a Python cell would target the kernel directly.
pip install torchaudio
pip install wandb
pip install gdown
pip install unidecode
pip install inflect
pip install --upgrade pydantic
pip install seaborn
pip install hparams

#download LjSpeech
# `-o /dev/null` sends wget's log (not the archive) to /dev/null; the archive is
# saved in the current directory.
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 -o /dev/null
# NOTE(review): plain `mkdir` (no -p) fails when data/ already exists, e.g. on a
# re-run; the script keeps going because `set -e` is not used.
mkdir data
tar -xvf LJSpeech-1.1.tar.bz2 >> /dev/null
mv LJSpeech-1.1 data/LJSpeech-1.1

# Transcript list; TrainConfig.data_path points at ./data/train.txt.
# NOTE(review): the unquoted `?` in these URLs is a shell glob character; quoting
# the URL would avoid accidental filename matches.
gdown https://drive.google.com/u/0/uc?id=1-EdH0t0loc6vPiuVtXdhsDtzygWNSNZx
mv train.txt data/

#download Waveglow
# Downloads waveglow_256channels_ljs_v2.pt and stores it as
# waveglow/pretrained_model/waveglow_256channels.pt.
gdown https://drive.google.com/u/0/uc?id=1WsibBTsuRg_SF2Z6L6NFRTT-NjEy1oTx
mkdir -p waveglow/pretrained_model/
mv waveglow_256channels_ljs_v2.pt waveglow/pretrained_model/waveglow_256channels.pt

# NOTE(review): this cell does not fetch the per-utterance mel / energy / pitch /
# emotion .npy files that TrainConfig expects under ./mels/mels,
# ./energies/energies, ./pitches/pitches and ./emotions/emotions. The commented
# unpack cell below suggests they come from separate zip archives — a fresh
# "Run All" needs them to be provided.
# gdown https://drive.google.com/u/0/uc?id=1cJKJTmYd905a-9GFoo5gKjzhKjUVj83j
# tar -xvf mel.tar.gz
# echo $(ls mels | wc -l)

#download alignments
# presumably extracts ./alignments/<i>.npy (TrainConfig.alignment_path) — confirm
wget https://github.com/xcmyz/FastSpeech/raw/master/alignments.zip
unzip alignments.zip >> /dev/null

# we will use waveglow code, data and audio preprocessing from this repo
# (the `text` package provides text_to_sequence, imported in a later cell)
git clone https://github.com/xcmyz/FastSpeech.git
mv FastSpeech/text .
mv FastSpeech/audio .
mv FastSpeech/waveglow/* waveglow/
mv FastSpeech/utils.py .
mv FastSpeech/glow.py .

In [None]:
#!g1.4
# Peek at the first 5 lines of the transcript file (TrainConfig.data_path).
!head -n 5 data/train.txt

In [None]:
# #!g1.4
# import shutil

# shutil.unpack_archive('energies.zip', 'energies')
# shutil.unpack_archive('mels.zip', 'mels')
# shutil.unpack_archive('pitches.zip', 'pitches')
# shutil.unpack_archive('emotions.zip', 'emotions')

# Imports

In [None]:
#!g1.4
import pathlib
import random
import itertools
from tqdm import tqdm_notebook

from IPython import display
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import distributions
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader

import torchaudio
from torchaudio.transforms import MelSpectrogram
import math
import time
import os
import librosa
import pandas as pd
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from dataclasses import dataclass
from collections import OrderedDict

import seaborn as sns 
sns.set()

import sys
sys.path.append('.')

import soundfile # read audio files
import numpy as np
import librosa # extract features
import glob
import os
import pickle # to save model after training
import time

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import torch
from torch import nn
import torch.optim as optim

import math

from sklearn.neural_network import MLPRegressor
from sklearn import preprocessing, metrics
from sklearn.ensemble import RandomForestRegressor
from sklearn.utils import shuffle

import seaborn as sn
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.font_manager import FontProperties

from IPython.display import clear_output

%matplotlib inline

# Configs

In [None]:
#!g1.4
@dataclass
class MelSpectrogramConfig:
    """Mel-spectrogram settings; ``num_mels`` sizes the model's output projection.

    ``num_mels`` is a plain (unannotated) class attribute, so ``@dataclass``
    generates no fields for it.
    """

    num_mels = 80  # mel channels predicted per output frame

@dataclass
class FastSpeechConfig:
    """Architecture hyper-parameters for the FastSpeech2-style model.

    All values are plain (unannotated) class attributes, so ``@dataclass``
    creates no fields; every instance simply exposes these constants.
    """

    vocab_size, max_seq_len = 300, 3000

    # Encoder FFT stack: width, depth, attention heads, FFN hidden width.
    encoder_dim, encoder_n_layer, encoder_head, encoder_conv1d_filter_size = 256, 4, 2, 1024

    # Decoder FFT stack: width, depth, attention heads, FFN hidden width.
    decoder_dim, decoder_n_layer, decoder_head, decoder_conv1d_filter_size = 256, 4, 2, 1024

    # (first, second) Conv1d kernel size / padding of the position-wise FFN.
    fft_conv1d_kernel, fft_conv1d_padding = (9, 1), (4, 0)

    # Duration / variance predictor settings.
    predictor_filter_size, predictor_kernel_size, predictor_dropout = 256, 3, 0.5

    dropout = 0.1

    # Special token ids and their surface strings.
    PAD, UNK, BOS, EOS = 0, 1, 2, 3
    PAD_WORD, UNK_WORD, BOS_WORD, EOS_WORD = '<blank>', '<unk>', '<s>', '</s>'


@dataclass
class TrainConfig:
    """Training hyper-parameters plus data, checkpoint and logging paths.

    Values are plain (unannotated) class attributes, so ``@dataclass``
    generates no fields; instances read these shared class-level values.
    """
    checkpoint_path = "./model_new"
    logger_path = "./logger"
    # Directories of per-utterance .npy feature dumps.
    mel_ground_truth = "./mels/mels"
    energy_path =  "./energies/energies"
    pitch_path =  "./pitches/pitches"
    emotion_path =  "./emotions/emotions"
    alignment_path = "./alignments"
    data_path = './data/train.txt'

    wandb_project = 'fastspeech_example'

    text_cleaners = ['english_cleaners']

    # Use GPU 0 when CUDA is available, otherwise fall back to CPU.
    # Hard-coding 'cuda:0' breaks every `.to(device)` call on CPU-only hosts.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    batch_size = 16
    epochs = 200
    n_warm_up_step = 4000

    learning_rate = 1e-3
    weight_decay = 1e-6
    grad_clip_thresh = 1.0
    decay_step = [500000, 1000000, 2000000]

    save_step = 3000
    log_step = 5
    clear_Time = 20

    # Number of batch_size sub-batches the collate function cuts each loaded batch into.
    batch_expand_size = 32
    

# Module-level config objects; later cells read these globals directly
# (e.g. model_config.PAD, train_config.batch_expand_size, mel_config.num_mels).
mel_config = MelSpectrogramConfig()
model_config = FastSpeechConfig()
train_config = TrainConfig()

In [None]:
#!g1.4
from text import text_to_sequence


def pad_1D(inputs, PAD=0):
    """Right-pad 1-D numpy arrays with ``PAD`` to the longest length and stack them.

    Args:
        inputs: sequence of 1-D arrays.
        PAD: fill value for the padded tail.

    Returns:
        np.ndarray of shape (len(inputs), longest_length).
    """
    target_len = max(len(seq) for seq in inputs)

    rows = []
    for seq in inputs:
        tail = target_len - seq.shape[0]
        rows.append(np.pad(seq, (0, tail), mode='constant', constant_values=PAD))

    return np.stack(rows)


def pad_1D_tensor(inputs, PAD=0):
    """Right-pad 1-D tensors with ``PAD`` to the longest length and stack them.

    Args:
        inputs: sequence of 1-D torch tensors.
        PAD: fill value for the padded tail (default 0).

    Returns:
        torch.Tensor of shape (len(inputs), longest_length).
    """
    max_len = max(len(x) for x in inputs)
    # F.pad fills with 0 unless `value` is given, so PAD must be passed explicitly
    # (it used to be silently ignored).
    return torch.stack([F.pad(x, (0, max_len - x.shape[0]), value=PAD) for x in inputs])


def pad_2D(inputs, maxlen=None):
    """Zero-pad numpy arrays along axis 0 to a common length and stack them.

    Args:
        inputs: sequence of arrays, e.g. (frames, n_mels) spectrograms.
        maxlen: target length for axis 0; when falsy the longest input is used.

    Returns:
        np.ndarray of shape (len(inputs), target_length, *inputs[0].shape[1:]).

    Raises:
        ValueError: if an input is longer than the target length.
    """

    def pad(x, max_len):
        PAD = 0
        if np.shape(x)[0] > max_len:
            raise ValueError("not max_len")

        # Pad only axis 0. A scalar (0, n) pad width pads *every* axis: it
        # allocated a far larger array that then had to be sliced back, and for
        # inputs with more than 2 dims it left extra zeros on the trailing axes.
        pad_width = [(0, max_len - np.shape(x)[0])] + [(0, 0)] * (np.ndim(x) - 1)
        return np.pad(x, pad_width, mode='constant', constant_values=PAD)

    if maxlen:
        output = np.stack([pad(x, maxlen) for x in inputs])
    else:
        max_len = max(np.shape(x)[0] for x in inputs)
        output = np.stack([pad(x, max_len) for x in inputs])

    return output


def pad_2D_tensor(inputs, maxlen=None):
    """Zero-pad 2-D tensors along dim 0 to a common length and stack them.

    Args:
        inputs: sequence of (length, features) tensors.
        maxlen: target length; when falsy the longest input's length is used.

    Returns:
        torch.Tensor of shape (len(inputs), target_length, features).

    Raises:
        ValueError: if an input is longer than the target length.
    """

    def pad(x, max_len):
        rows, cols = x.size(0), x.size(1)
        if rows > max_len:
            raise ValueError("not max_len")
        # (0, 0) leaves the last dim alone; (0, max_len - rows) extends the one before it.
        return F.pad(x, (0, 0, 0, max_len - rows))[:, :cols]

    target = maxlen if maxlen else max(x.size(0) for x in inputs)
    return torch.stack([pad(x, target) for x in inputs])


def process_text(train_text_path):
    """Return every line of a UTF-8 text file, trailing newlines included."""
    with open(train_text_path, "r", encoding="utf-8") as handle:
        return list(handle)


def get_data_to_buffer(train_config):
    """Load every training utterance into memory and normalise energy / pitch.

    Line ``i`` (0-based) of ``train_config.data_path`` is paired with:
      * ``<mel_ground_truth>/ljspeech-mel-%05d.npy``     (file index ``i + 1``)
      * ``<alignment_path>/<i>.npy``                       (file index ``i``)
      * ``<energy_path>/ljspeech-energy-%05d.npy``         (file index ``i + 1``)
      * ``<pitch_path>/ljspeech-pitch-%05d.npy``           (file index ``i + 1``)
      * ``<emotion_path>/ljspeech-emotion-%05d.npy``       (file index ``i + 1``)

    Energy and pitch are then divided by the dataset-wide mean of their
    per-utterance means, which brings that mean to ~1. Emotion is left as loaded.

    Args:
        train_config: object exposing the path attributes above plus
            ``data_path`` and ``text_cleaners``.

    Returns:
        list[dict]: one dict per utterance with tensor values under
        ``"text"``, ``"duration"``, ``"mel_target"``, ``"energy"``,
        ``"pitch"`` and ``"emotion"``.
    """
    def _load_feature(directory, file_pattern, index):
        # One .npy dump per utterance -> tensor sharing the numpy buffer.
        return torch.from_numpy(np.load(os.path.join(directory, file_pattern % index)))

    buffer = []
    text = process_text(train_config.data_path)

    start = time.perf_counter()
    for i in tqdm(range(len(text))):
        mel_gt_target = _load_feature(train_config.mel_ground_truth, "ljspeech-mel-%05d.npy", i + 1)
        duration = torch.from_numpy(np.load(os.path.join(
            train_config.alignment_path, str(i) + ".npy")))
        energy_gt_target = _load_feature(train_config.energy_path, "ljspeech-energy-%05d.npy", i + 1)
        pitch_gt_target = _load_feature(train_config.pitch_path, "ljspeech-pitch-%05d.npy", i + 1)
        emotion_gt_target = _load_feature(train_config.emotion_path, "ljspeech-emotion-%05d.npy", i + 1)

        # Strip only the trailing newline. Slicing off the last character
        # unconditionally truncated a final line that has no newline.
        character = text[i].rstrip("\n")
        character = torch.from_numpy(np.array(
            text_to_sequence(character, train_config.text_cleaners)))

        buffer.append({"text": character, "duration": duration,
                       "mel_target": mel_gt_target, "energy": energy_gt_target,
                       "pitch": pitch_gt_target, "emotion": emotion_gt_target})

    # Normalise energy and pitch by the mean of per-utterance means (computed in
    # float64). Emotion is intentionally not normalised, so no statistic is
    # computed for it; averaging an empty list only produced a NaN and a warning.
    if buffer:
        energy_scale = float(np.mean([b["energy"].double().mean().item() for b in buffer]))
        pitch_scale = float(np.mean([b["pitch"].double().mean().item() for b in buffer]))
        for b in buffer:
            b["energy"] = b["energy"] / energy_scale
            b["pitch"] = b["pitch"] / pitch_scale

    end = time.perf_counter()
    print("cost {:.2f}s to load all data into buffer.".format(end-start))

    return buffer


class BufferDataset(Dataset):
    """Map-style dataset over an in-memory list of pre-loaded samples.

    Indexing returns the stored sample object unchanged; the length is
    captured once at construction time.
    """

    def __init__(self, buffer):
        self.buffer = buffer
        self.length_dataset = len(buffer)

    def __getitem__(self, idx):
        return self.buffer[idx]

    def __len__(self):
        return self.length_dataset


def _padded_positions(lengths):
    """(len(lengths), max(lengths)) int64 tensor of 1-based positions, 0-padded."""
    lengths = torch.as_tensor(lengths, dtype=torch.long)
    steps = torch.arange(1, int(lengths.max()) + 1, dtype=torch.long).unsqueeze(0)
    # Keep position p while p <= the row length; the tail becomes 0 (the PAD index).
    return steps * (steps <= lengths.unsqueeze(1))


def reprocess_tensor(batch, cut_list):
    """Collate ``batch[i] for i in cut_list`` into one padded sub-batch.

    Args:
        batch: list of sample dicts as produced by ``get_data_to_buffer``.
        cut_list: indices into ``batch`` forming this sub-batch.

    Returns:
        dict with padded ``"text"``, ``"mel_target"``, ``"duration"``,
        ``"energy"``, ``"pitch"`` and ``"emotion"`` tensors, int64 1-based
        ``"src_pos"`` / ``"mel_pos"`` (0 at padding), and ``"mel_max_len"``
        (int, longest mel length in the sub-batch).
    """
    samples = [batch[ind] for ind in cut_list]

    texts = [s["text"] for s in samples]
    mel_targets = [s["mel_target"] for s in samples]

    # Positions are built directly in torch: int64 on every platform, no float
    # fallback for empty rows, and no quadratic np.append loop.
    src_pos = _padded_positions([t.size(0) for t in texts])
    mel_lengths = [m.size(0) for m in mel_targets]
    max_mel_len = int(max(mel_lengths))
    mel_pos = _padded_positions(mel_lengths)

    out = {"text": pad_1D_tensor(texts),
           "mel_target": pad_2D_tensor(mel_targets),
           "duration": pad_1D_tensor([s["duration"] for s in samples]),
           "energy": pad_1D_tensor([s["energy"] for s in samples]),
           "pitch": pad_1D_tensor([s["pitch"] for s in samples]),
           "emotion": pad_1D_tensor([s["emotion"] for s in samples]),
           "mel_pos": mel_pos,
           "src_pos": src_pos,
           "mel_max_len": max_mel_len}

    return out


def collate_fn_tensor(batch):
    """DataLoader collate: split one large batch into length-sorted sub-batches.

    Samples are ordered by descending text length, then cut into
    ``train_config.batch_expand_size`` consecutive chunks of
    ``len(batch) // batch_expand_size`` samples each. Each chunk is padded by
    ``reprocess_tensor``. Any remainder beyond the last full chunk is dropped.

    Returns:
        list of ``batch_expand_size`` sub-batch dicts.
    """
    n_chunks = train_config.batch_expand_size
    text_lengths = np.array([sample["text"].size(0) for sample in batch])
    order = np.argsort(-text_lengths)
    chunk = len(batch) // n_chunks

    return [reprocess_tensor(batch, order[k * chunk:(k + 1) * chunk])
            for k in range(n_chunks)]



In [None]:
#!g1.4
# Load every utterance (text, durations, mel, energy, pitch, emotion) into RAM.
buffer = get_data_to_buffer(train_config)

dataset = BufferDataset(buffer)

# Each DataLoader batch holds batch_expand_size * batch_size samples.
# collate_fn_tensor sorts them by text length and returns a list of
# batch_expand_size padded sub-batches of batch_size samples each.
# drop_last=True ensures every batch has exactly that many samples.
# NOTE(review): worker processes must reach the notebook-defined
# collate_fn_tensor / train_config. This typically relies on the 'fork' start
# method (Linux) — confirm on other platforms.
training_loader = DataLoader(
    dataset,
    batch_size=train_config.batch_expand_size * train_config.batch_size,
    shuffle=True,
    collate_fn=collate_fn_tensor,
    drop_last=True,
    num_workers=4
)

# Encoder

## Transformer Block

### Multi Head Attention

In [None]:
#!g1.4
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention: softmax(Q K^T / temperature) V

    Expects 3-D (batch, len, dim) inputs (``bmm``). ``mask`` entries that are
    True are set to -inf before the softmax. Dropout is applied to the weights.
    Returns ``(output, attention_weights)``.
    '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        scores = q.bmm(k.transpose(-1, -2)) / self.temperature
        if mask is not None:
            scores = scores.masked_fill(mask, -math.inf)

        weights = self.dropout(self.softmax(scores))
        return weights.bmm(v), weights

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module

    Self-attention only: queries, keys and values are all projected from the
    same input ``x``. The output projection ``fc`` is followed by dropout,
    a residual connection to ``x`` and then LayerNorm (post-norm).

    Args:
        n_head: number of attention heads.
        d_model: input/output feature width.
        d_x: per-head width of the query/key/value projections.
        dropout: dropout applied after ``fc``.
            NOTE(review): ScaledDotProductAttention is constructed without this
            value, so the attention-weight dropout uses its own default (0.1) —
            confirm this is intended.
    '''

    def __init__(self, n_head, d_model, d_x, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_x = d_x
        self.d_model = d_model

        # Project d_model -> n_head * d_x (all heads at once).
        self.w_qs = nn.Linear(d_model, n_head * d_x)
        self.w_ks = nn.Linear(d_model, n_head * d_x)
        self.w_vs = nn.Linear(d_model, n_head * d_x)

        # Scores are divided by sqrt(d_x).
        self.attention = ScaledDotProductAttention(
            temperature=d_x**0.5)
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_x, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        # Re-initialise the Q/K/V projection weights from a zero-mean normal with
        # std sqrt(2 / (d_model + d_x)); biases and fc are not touched here.
        nn.init.normal_(self.w_qs.weight, mean=0,
                        std=np.sqrt(2.0 / (self.d_model + self.d_x)))
        nn.init.normal_(self.w_ks.weight, mean=0,
                        std=np.sqrt(2.0 / (self.d_model + self.d_x)))
        nn.init.normal_(self.w_vs.weight, mean=0,
                        std=np.sqrt(2.0 / (self.d_model + self.d_x)))

    def forward(self, x, mask=None):
        """Self-attend over ``x``.

        Args:
            x: (batch, len, d_model) tensor (the 3-way ``x.size()`` unpack
                requires rank 3).
            mask: optional bool tensor of shape (batch, len, len); True entries
                are excluded (set to -inf before the softmax).

        Returns:
            (output, attn): ``output`` is (batch, len, d_model); ``attn`` holds
            the attention weights with heads folded into the batch axis,
            shape (n_head * batch, len, len).
        """
        d_x, n_head = self.d_x, self.n_head

        sz_b, len_x, _ = x.size()

        residual = x

        # (batch, len, n_head, d_x)
        q = self.w_qs(x).view(sz_b, len_x, n_head, d_x)
        k = self.w_ks(x).view(sz_b, len_x, n_head, d_x)
        v = self.w_vs(x).view(sz_b, len_x, n_head, d_x)

        # Fold heads into the batch axis, head-major: row index = head * batch + b.
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_x)  # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_x)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_x)  # (n*b) x lv x dv

        if mask is not None:
            # repeat() tiles the batch masks head-major, matching the folding above.
            mask = mask.repeat(n_head, 1, 1)  # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)

        # Unfold heads and concatenate them along the feature axis.
        output = output.view(n_head, sz_b, len_x, d_x)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)  # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)

        return output, attn

### Positionwise Feed Forward

In [None]:
#!g1.4
class PositionwiseFeedForward(nn.Module):
    ''' A two-layer position-wise feed-forward module built from Conv1d layers.

    Pre-norm residual block: ``x + Dropout(Conv2(ReLU(Conv1(LayerNorm(x)))))``.
    The convolutions run over the time axis, so ``x`` is (batch, time, d_in).

    Args:
        d_in: input/output feature width.
        d_hid: hidden width between the two convolutions.
        dropout: dropout applied to the FFN output before the residual add.
        kernel_size: (first, second) Conv1d kernel sizes; defaults to
            ``model_config.fft_conv1d_kernel``.
        padding: (first, second) Conv1d paddings; defaults to
            ``model_config.fft_conv1d_padding``.
    '''

    def __init__(self, d_in, d_hid, dropout=0.1, kernel_size=None, padding=None):
        super().__init__()

        # Fall back to the global model config so existing callers are unchanged.
        if kernel_size is None:
            kernel_size = model_config.fft_conv1d_kernel
        if padding is None:
            padding = model_config.fft_conv1d_padding

        self.w_1 = nn.Conv1d(d_in, d_hid, kernel_size=kernel_size[0], padding=padding[0])
        self.w_2 = nn.Conv1d(d_hid, d_in, kernel_size=kernel_size[1], padding=padding[1])

        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        # Conv1d expects (batch, channels, time): transpose in and back out.
        output = self.layer_norm(x).transpose(1, 2)
        output = self.w_2(F.relu(self.w_1(output))).transpose(1, 2)
        return residual + self.dropout(output)

### FFTBlock

In [None]:
#!g1.4
class FFTBlock(torch.nn.Module):
    """FFT Block: multi-head self-attention followed by a position-wise conv FFN.

    When ``non_pad_mask`` is given, the output of each sub-layer is multiplied
    by it so that padded positions are zeroed. Returns ``(output, attention)``.
    """

    def __init__(self,
                 d_model,
                 d_inner,
                 n_head,
                 d_x,
                 dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_x, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
        def zero_padding(t):
            return t if non_pad_mask is None else t * non_pad_mask

        attended, attn_weights = self.slf_attn(enc_input, mask=slf_attn_mask)
        hidden = self.pos_ffn(zero_padding(attended))
        return zero_padding(hidden), attn_weights

## Length Regulator

### Aligner

In [None]:
#!g1.4
def create_alignment(base_mat, duration_predictor_output):
    """Write a frame-to-token one-hot alignment into ``base_mat`` in place.

    For each batch row, token ``tok`` occupies ``duration[tok]`` consecutive
    frames that start right after the previous token's frames.
    ``base_mat[b][frame][tok]`` is set to 1 for each of those frames.

    Args:
        base_mat: (batch, frames, tokens) array, typically zeros; mutated.
        duration_predictor_output: (batch, tokens) integer durations.

    Returns:
        The same ``base_mat`` object.
    """
    n_batch, n_tokens = duration_predictor_output.shape
    for b in range(n_batch):
        frame = 0
        for tok in range(n_tokens):
            span = duration_predictor_output[b][tok]
            for offset in range(span):
                base_mat[b][frame + offset][tok] = 1
            frame += span
    return base_mat

In [None]:
#!g1.4
# Sanity check: 3 tokens with durations [1, 2, 3] expand into 6 frames, giving a
# (1, 6, 3) one-hot frame-to-token alignment.
create_alignment(
    torch.zeros(1, 6, 3).numpy(),
    torch.LongTensor([[1,2,3]])
)

### Predictor (shared by duration, pitch, energy and emotion)

In [None]:
#!g1.4
class Transpose(nn.Module):
    """Module wrapper around ``torch.transpose`` so it can sit in ``nn.Sequential``."""

    def __init__(self, dim_1, dim_2):
        super().__init__()
        self.dim_1, self.dim_2 = dim_1, dim_2

    def forward(self, x):
        return torch.transpose(x, self.dim_1, self.dim_2)


In [None]:
#!g1.4
class Predictor(nn.Module):
    """ Predictor

    Shared architecture for the duration, pitch, energy and emotion predictors:
    two (Conv1d -> LayerNorm -> ReLU -> Dropout) stages over the time axis, then
    a Linear projection to one value per step and a final ReLU (outputs >= 0).

    Expected input: (batch, time, encoder_dim) hidden states.
    Output: (batch, time) tensor.
    """

    def __init__(self, model_config: "FastSpeechConfig"):
        super(Predictor, self).__init__()

        self.input_size = model_config.encoder_dim
        self.filter_size = model_config.predictor_filter_size
        self.kernel = model_config.predictor_kernel_size
        self.conv_output_size = model_config.predictor_filter_size
        self.dropout = model_config.predictor_dropout
        # "Same" padding for odd kernels keeps the time length unchanged
        # (this equals the previously hard-coded padding=1 for kernel size 3).
        padding = (self.kernel - 1) // 2

        self.conv_net = nn.Sequential(
            Transpose(-1, -2),  # (B, T, C) -> (B, C, T) for Conv1d
            nn.Conv1d(
                self.input_size, self.filter_size,
                kernel_size=self.kernel, padding=padding
            ),
            Transpose(-1, -2),  # back to (B, T, C) so LayerNorm normalises channels
            nn.LayerNorm(self.filter_size),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            Transpose(-1, -2),
            nn.Conv1d(
                self.filter_size, self.filter_size,
                kernel_size=self.kernel, padding=padding
            ),
            Transpose(-1, -2),
            nn.LayerNorm(self.filter_size),
            nn.ReLU(),
            nn.Dropout(self.dropout)
        )

        self.linear_layer = nn.Linear(self.conv_output_size, 1)
        self.relu = nn.ReLU()

    def forward(self, encoder_output):
        out = self.relu(self.linear_layer(self.conv_net(encoder_output)))
        # Remove only the size-1 feature axis so the result is (batch, time) in
        # both modes. A bare squeeze() also dropped the batch axis when batch == 1,
        # and the eval-time unsqueeze(0) then produced (1, batch, time) for batch > 1.
        return out.squeeze(-1)

### Length Regulator (LR)

In [None]:
#!g1.4
class LengthRegulator(nn.Module):
    """ Length Regulator

    Repeats each token's hidden state according to its duration, going from
    token resolution to mel-frame resolution. Owns the duration predictor.
    """

    def __init__(self, model_config):
        super(LengthRegulator, self).__init__()
        self.duration_predictor = Predictor(model_config)

    def LR(self, x, duration_predictor_output, mel_max_length=None):
        """Expand ``x`` (batch, tokens, dim) by integer durations (batch, tokens).

        Returns (batch, frames, dim), where frames is the largest total duration
        in the batch, zero-padded on the right to ``mel_max_length`` when given.
        """
        expand_max_len = int(torch.sum(duration_predictor_output, -1).max())
        alignment = torch.zeros(duration_predictor_output.size(0),
                                expand_max_len,
                                duration_predictor_output.size(1)).numpy()
        alignment = create_alignment(alignment,
                                     duration_predictor_output.cpu().numpy())
        alignment = torch.from_numpy(alignment).to(x.device)

        # (batch, frames, tokens) @ (batch, tokens, dim) -> (batch, frames, dim)
        output = alignment @ x
        if mel_max_length:
            output = F.pad(
                output, (0, 0, 0, mel_max_length-output.size(1), 0, 0))
        return output

    def forward(self, x, alpha=1.0, target=None, mel_max_length=None):
        """Regulate lengths with ground-truth (``target``) or predicted durations.

        Returns:
            With ``target``: (expanded x, exp(predictor output)).
            Without: (expanded x, mel_pos) where mel_pos is a (1, frames) int64
            tensor of 1-based positions on ``x``'s device.

        NOTE(review): the predictor ends in a ReLU, so exp() of its output is
        >= 1 and every token gets at least one frame at inference.
        """
        log_dur_predictor_output = self.duration_predictor(x)

        if target is not None:
            output = self.LR(x, target, mel_max_length)
            return output, torch.exp(log_dur_predictor_output)

        # Round the alpha-scaled durations to the nearest integer frame count.
        dur_predictor_output = (torch.exp(log_dur_predictor_output) * alpha + 0.5).int()
        output = self.LR(x, dur_predictor_output)

        # Build positions on the activations' device rather than the global
        # train_config.device, so CPU inference works regardless of that setting.
        mel_pos = torch.arange(1, output.size(1) + 1, device=x.device).unsqueeze(0)

        return output, mel_pos
        

In [None]:
#!g1.4
# One scalar per emotion label (a 1-D "dimensional" emotion space). In inference,
# VarianceAdaptor blends the chosen label's value into the emotion prediction.
circumplex_model = dict(
    neutral=0.2,
    happy=0.8,
    sad=-0.8,
    angry=-0.1,
    fearful=-0.3,
    disgust=-0.3,
)

In [None]:
#!g1.4
import torch.nn as nn
class VarianceAdaptor(nn.Module):
    """ Variance Adaptor

    Adds quantised pitch, energy and emotion information to length-regulated
    hidden states ``x``. Each per-frame signal is assigned to the nearest of
    ``N_BINS`` (256) bin edges, which are quantiles of that same tensor. The
    result is turned into a one-hot vector and *added* to ``x``, so the feature
    width of ``x`` must equal ``N_BINS``.

    Training (``target_pitch`` given, and then also ``target_energy`` and
    ``target_emotion``): the ground-truth signals are quantised. Returns
    ``(output, pitch_pred, energy_pred, emotion_pred)`` so the predictors can be
    trained separately.

    Inference: the predictors' own outputs are quantised. Returns ``output`` only.
    """

    N_BINS = 256

    def __init__(self, model_config):
        super(VarianceAdaptor, self).__init__()
        self.pitch_predictor = Predictor(model_config)
        self.energy_predictor = Predictor(model_config)
        self.emotion_predictor = Predictor(model_config)

    @staticmethod
    def _quantile_one_hot(values, n_bins=256):
        """One-hot encode a (batch, time) tensor by its nearest quantile bin.

        Bin edges are the quantiles ``0, 1/n_bins, ..., (n_bins-1)/n_bins`` of all
        entries of ``values``. Ties resolve to the lowest bin (argmin). Returns a
        (batch, time, n_bins) tensor with the dtype of ``values``.
        """
        levels = torch.arange(n_bins, dtype=values.dtype, device=values.device) / n_bins
        edges = values.quantile(levels)                                   # (n_bins,)
        nearest = (values.unsqueeze(-1) - edges).abs().argmin(dim=-1)     # (batch, time)
        return F.one_hot(nearest, n_bins).to(values.dtype).view(values.shape[0], -1, n_bins)

    def forward(self, x, alpha_pitch=1.0, alpha_energy=1.0, preferred_emotion=None, target_pitch=None, target_energy=None, target_emotion=None, mel_max_length=None):
        pitch_predictor_output = self.pitch_predictor(x)
        energy_predictor_output = self.energy_predictor(x)
        emotion_predictor_output = self.emotion_predictor(x)

        if target_pitch is not None:
            output = (x
                      + self._quantile_one_hot(target_pitch, self.N_BINS)
                      + self._quantile_one_hot(target_energy, self.N_BINS)
                      + self._quantile_one_hot(target_emotion, self.N_BINS))

            if mel_max_length:
                output = F.pad(output, (0, 0, 0, mel_max_length-output.size(1), 0, 0))
            return output, pitch_predictor_output, energy_predictor_output, emotion_predictor_output

        # NOTE(review): the bin edges are quantiles of the very tensor being
        # binned. A positive scaling (alpha_pitch / alpha_energy) or a positive
        # affine blend (preferred_emotion) moves the values and the edges together,
        # so the chosen bins stay essentially unchanged and these controls have
        # almost no effect. Real control needs fixed edges, e.g. from training data.
        pitch_scaled = pitch_predictor_output * alpha_pitch
        energy_scaled = energy_predictor_output * alpha_energy

        emotion_signal = emotion_predictor_output
        if preferred_emotion is not None:
            emotion_signal = 0.1 * emotion_predictor_output + 0.9 * circumplex_model[preferred_emotion]

        output = (x
                  + self._quantile_one_hot(pitch_scaled, self.N_BINS)
                  + self._quantile_one_hot(energy_scaled, self.N_BINS)
                  + self._quantile_one_hot(emotion_signal, self.N_BINS))

        return output

## Final Block

### Attention masks

In [None]:
#!g1.4
def get_non_pad_mask(seq, pad_idx=None):
    """Float mask of shape (batch, len, 1): 1.0 where ``seq != pad_idx``, else 0.0.

    Args:
        seq: 2-D (batch, len) tensor of ids or positions.
        pad_idx: padding value; defaults to ``model_config.PAD``.
    """
    assert seq.dim() == 2
    if pad_idx is None:
        pad_idx = model_config.PAD
    return seq.ne(pad_idx).type(torch.float).unsqueeze(-1)

def get_attn_key_pad_mask(seq_k, seq_q, pad_idx=None):
    ''' For masking out the padding part of key sequence.

    Returns a bool tensor of shape (batch, len_q, len_k) that is True wherever
    the key position is padding (``seq_k == pad_idx``), repeated for every query.

    Args:
        seq_k: (batch, len_k) key ids/positions.
        seq_q: (batch, len_q) query ids/positions (only its length is used).
        pad_idx: padding value; defaults to ``model_config.PAD``.
    '''
    if pad_idx is None:
        pad_idx = model_config.PAD
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(pad_idx)
    # Expand to fit the shape of the key/query attention matrix.
    return padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk

### Encoder

In [None]:
#!g1.4
class Encoder(nn.Module):
    """ Encoder: token embeddings + learned position embeddings -> stack of FFT blocks. """

    def __init__(self, model_config):
        super(Encoder, self).__init__()

        len_max_seq = model_config.max_seq_len
        # +1: positions are 1-based and index 0 is the padding position.
        n_position = len_max_seq + 1
        n_layers = model_config.encoder_n_layer

        self.src_word_emb = nn.Embedding(
            model_config.vocab_size,
            model_config.encoder_dim,
            padding_idx=model_config.PAD
        )

        self.position_enc = nn.Embedding(
            n_position,
            model_config.encoder_dim,
            padding_idx=model_config.PAD
        )

        self.layer_stack = nn.ModuleList([FFTBlock(
            model_config.encoder_dim,
            model_config.encoder_conv1d_filter_size,
            model_config.encoder_head,
            model_config.encoder_dim // model_config.encoder_head,
            dropout=model_config.dropout
        ) for _ in range(n_layers)])

    def forward(self, src_seq, src_pos, return_attns=False):
        """Encode padded token ids.

        Args:
            src_seq: (batch, len) token ids, PAD at padded positions.
            src_pos: (batch, len) 1-based positions, 0 at padded positions.
            return_attns: when True, also return the per-layer attention maps.

        Returns:
            ``(enc_output, non_pad_mask)``; with ``return_attns=True`` a third
            element holds the list of per-layer attention tensors (previously
            collected but never returned).
        """
        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        enc_output = self.src_word_emb(src_seq) + self.position_enc(src_pos)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list.append(enc_slf_attn)

        if return_attns:
            return enc_output, non_pad_mask, enc_slf_attn_list
        return enc_output, non_pad_mask

In [None]:
#!g1.4
class Decoder(nn.Module):
    """ Decoder: adds learned frame-position embeddings to length-regulated
    features and runs them through a stack of FFT blocks. """

    def __init__(self, model_config):

        super(Decoder, self).__init__()

        len_max_seq = model_config.max_seq_len
        # +1: positions are 1-based and index 0 is the padding position.
        n_position = len_max_seq + 1
        n_layers = model_config.decoder_n_layer

        # The model width stays encoder_dim because the input comes straight
        # from the variance adaptor at that width. FastSpeech2.mel_linear reads
        # decoder_dim, so the two config values are expected to be equal.
        self.position_enc = nn.Embedding(
            n_position,
            model_config.encoder_dim,
            padding_idx=model_config.PAD,
        )

        # Use the decoder-specific head count and FFN width from the config
        # (the encoder_* values were read before; they are equal by default).
        self.layer_stack = nn.ModuleList([FFTBlock(
            model_config.encoder_dim,
            model_config.decoder_conv1d_filter_size,
            model_config.decoder_head,
            model_config.encoder_dim // model_config.decoder_head,
            dropout=model_config.dropout
        ) for _ in range(n_layers)])

    def forward(self, enc_seq, enc_pos, return_attns=False):
        """Decode length-regulated features.

        Args:
            enc_seq: (batch, frames, encoder_dim) features.
            enc_pos: (batch, frames) 1-based frame positions, 0 at padding.
            return_attns: when True, also return the per-layer attention maps.

        Returns:
            ``dec_output``; with ``return_attns=True`` the tuple
            ``(dec_output, list_of_attention_tensors)`` instead.
        """
        dec_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=enc_pos, seq_q=enc_pos)
        non_pad_mask = get_non_pad_mask(enc_pos)

        # -- Forward
        dec_output = enc_seq + self.position_enc(enc_pos)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn = dec_layer(
                dec_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                dec_slf_attn_list.append(dec_slf_attn)

        if return_attns:
            return dec_output, dec_slf_attn_list
        return dec_output
    

In [None]:
#!g1.4
def get_mask_from_lengths(lengths, max_len=None):
    """Boolean (batch, max_len) mask, True at positions ``< lengths[b]``.

    Args:
        lengths: 1-D tensor of sequence lengths.
        max_len: mask width; defaults to ``lengths.max()``.
    """
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len, 1, device=lengths.device)
    # (max_len,) vs (batch, 1) broadcasts to (batch, max_len); already bool.
    return ids < lengths.unsqueeze(1)

In [None]:
#!g1.4
class FastSpeech2(nn.Module):
    """FastSpeech2: encoder -> length regulator -> variance adaptor -> decoder -> mel projection."""

    def __init__(self, model_config):
        super(FastSpeech2, self).__init__()

        self.encoder = Encoder(model_config)
        self.length_regulator = LengthRegulator(model_config)
        self.variance_adaptor = VarianceAdaptor(model_config)
        self.decoder = Decoder(model_config)

        self.mel_linear = nn.Linear(model_config.decoder_dim, mel_config.num_mels)

    def mask_tensor(self, mel_output, position, mel_max_length):
        """Zero every frame of ``mel_output`` beyond each row's valid length.

        The largest value in each row of ``position`` is taken as that row's
        length (presumably positions are 1-based with 0 padding — confirm
        against the collate function).
        """
        lengths = torch.max(position, -1)[0]
        mask = ~get_mask_from_lengths(lengths, max_len=mel_max_length)
        mask = mask.unsqueeze(-1).expand(-1, -1, mel_output.size(-1))
        return mel_output.masked_fill(mask, 0.)

    def forward(self, src_seq, src_pos, mel_pos=None, mel_max_length=None, length_target=None, target_pitch=None, target_energy=None, target_emotion=None, alpha=1.0, alpha_pitch=1.0, alpha_energy=1.0, preferred_emotion=None):
        """Run the model.

        In training mode the length regulator and variance adaptor are fed the
        ground-truth targets. The method returns
        ``(mel, duration_pred, pitch_pred, energy_pred, emotion_pred)``.
        In eval mode it returns only the mel output. The ``alpha*`` factors
        and ``preferred_emotion`` are forwarded to the length regulator /
        variance adaptor.
        """
        x, non_pad_mask = self.encoder(src_seq, src_pos)

        if self.training:
            output, dur_predictor_output = self.length_regulator(x, alpha, length_target, mel_max_length)
            output, pitch_predictor_output, energy_predictor_output, emotion_predictor_output = self.variance_adaptor(output, alpha_pitch, alpha_energy, preferred_emotion, target_pitch, target_energy, target_emotion, mel_max_length)
            output = self.decoder(output, mel_pos)
            # Project first, then mask: nn.Linear adds its bias to every frame,
            # so masking before the projection would leave padded frames equal
            # to the bias (and contributing to the mel loss) instead of zero.
            output = self.mel_linear(output)
            output = self.mask_tensor(output, mel_pos, mel_max_length)
            return output, dur_predictor_output, pitch_predictor_output, energy_predictor_output, emotion_predictor_output

        output, mel_pos = self.length_regulator(x, alpha)
        output = self.variance_adaptor(output, alpha_pitch, alpha_energy, preferred_emotion)
        output = self.decoder(output, mel_pos)
        output = self.mel_linear(output)

        return output

## Loss

In [None]:
#!g1.4
import torch
import torch.nn as nn

class FastSpeechLoss(nn.Module):
    """Per-component losses for FastSpeech2 training.

    MSE is used for the mel, pitch, energy and emotion terms. L1 is used for
    durations, whose target is cast to float first.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

    def forward(self, mel, duration_predicted, pitch_predicted, energy_predicted, emotion_predicted, mel_target, duration_predictor_target, pitch_predictor_target, energy_predictor_target, emotion_predictor_target):
        """Return ``(mel, duration, pitch, energy, emotion)`` losses as a tuple."""
        duration_target = duration_predictor_target.float()
        return (
            self.mse_loss(mel, mel_target),
            self.l1_loss(duration_predicted, duration_target),
            self.mse_loss(pitch_predicted, pitch_predictor_target),
            self.mse_loss(energy_predicted, energy_predictor_target),
            self.mse_loss(emotion_predicted, emotion_predictor_target),
        )


# Train

In [None]:
#!g1.4
from torch.optim.lr_scheduler  import OneCycleLR
from wandb_writer import WanDBWriter

In [None]:
#!g1.4
# Warm-start weights for this run.
# NOTE(review): only the 'model' weights are restored — the optimizer state
# saved next to them is ignored and `current_step` restarts at 0, so the
# OneCycle schedule begins from its warm-up again. Confirm this is intended.
INIT_CHECKPOINT = './model_new2/checkpoint_63000.pth.tar'

model = FastSpeech2(model_config)
# Map tensors directly onto the training device instead of hard-coding
# cuda:0, so the cell also works on CPU or another GPU.
model.load_state_dict(torch.load(INIT_CHECKPOINT, map_location=train_config.device)['model'])
model = model.to(train_config.device)

fastspeech_loss = FastSpeechLoss()
current_step = 0

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_config.learning_rate,
    betas=(0.9, 0.98),
    eps=1e-9)

# The training loop takes one scheduler step per sub-batch, i.e.
# len(training_loader) * batch_expand_size steps per epoch.
scheduler = OneCycleLR(
    optimizer,
    max_lr=train_config.learning_rate,
    epochs=train_config.epochs,
    steps_per_epoch=len(training_loader) * train_config.batch_expand_size,
    anneal_strategy="cos",
    pct_start=0.1,
)

In [None]:
#!g1.4
# Weights & Biases writer; the training loop calls logger.set_step(...) and
# logger.add_scalar(...) on it once per optimizer step.
logger = WanDBWriter(train_config)

## Train loop

In [None]:
#!g1.4
from IPython.display import clear_output

# A later cell switches the model to eval mode, where FastSpeech2.forward
# returns only the mel tensor; re-running this cell would then fail on the
# 5-way unpacking below. Always (re)enter training mode explicitly.
model.train()
# torch.save does not create missing directories.
os.makedirs(train_config.checkpoint_path, exist_ok=True)

tqdm_bar = tqdm(total=train_config.epochs * len(training_loader) * train_config.batch_expand_size - current_step)

for epoch in range(train_config.epochs):
    for batchs in training_loader:
        # Each loader item is iterated as a sequence of sub-batches; every
        # sub-batch is one optimizer/scheduler step (this matches the
        # steps_per_epoch passed to OneCycleLR).
        for db in batchs:
            current_step += 1
            tqdm_bar.update(1)

            logger.set_step(current_step)

            # Get Data
            character = db["text"].long().to(train_config.device)
            mel_target = db["mel_target"].float().to(train_config.device)
            duration = db["duration"].int().to(train_config.device)

            energy = db["energy"].float().to(train_config.device)
            pitch = db["pitch"].float().to(train_config.device)
            emotion = db["emotion"].float().to(train_config.device)

            mel_pos = db["mel_pos"].long().to(train_config.device)
            src_pos = db["src_pos"].long().to(train_config.device)
            max_mel_len = db["mel_max_len"]

            # Forward
            (mel_output, duration_predictor_output, pitch_predictor_output,
             energy_predictor_output, emotion_predictor_output) = model(
                character,
                src_pos,
                mel_pos=mel_pos,
                mel_max_length=max_mel_len,
                length_target=duration,
                target_pitch=pitch,
                target_energy=energy,
                target_emotion=emotion,
            )

            # Calc Loss
            mel_loss, duration_loss, pitch_loss, energy_loss, emotion_loss = fastspeech_loss(
                mel_output,
                duration_predictor_output,
                pitch_predictor_output,
                energy_predictor_output,
                emotion_predictor_output,
                mel_target,
                duration,
                pitch,
                energy,
                emotion,
            )
            total_loss = mel_loss + duration_loss + pitch_loss + energy_loss + emotion_loss

            # Logger: .item() hands plain Python floats to the scalar log.
            # "energy_loss" matches the *_loss naming of the other metrics.
            logger.add_scalar("emotion_loss", emotion_loss.item())
            logger.add_scalar("energy_loss", energy_loss.item())
            logger.add_scalar("pitch_loss", pitch_loss.item())
            logger.add_scalar("duration_loss", duration_loss.item())
            logger.add_scalar("mel_loss", mel_loss.item())
            logger.add_scalar("total_loss", total_loss.item())

            # Backward
            total_loss.backward()

            # Clipping gradients to avoid gradient explosion
            nn.utils.clip_grad_norm_(model.parameters(), train_config.grad_clip_thresh)

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

            # wait=True defers clearing until new output arrives (less flicker).
            clear_output(wait=True)

            if current_step % train_config.save_step == 0:
                torch.save(
                    {'model': model.state_dict(), 'optimizer': optimizer.state_dict()},
                    os.path.join(train_config.checkpoint_path, 'checkpoint_%d.pth.tar' % current_step),
                )
                print("save model at step %d ..." % current_step)

tqdm_bar.close()


In [None]:
#!g1.4
import waveglow
import text
import audio
import utils

In [None]:
#!g1.4
# Vocoder on the GPU, plus acoustic-model weights for synthesis; the model is
# then put in eval mode so FastSpeech2.forward takes its inference branch.
# NOTE(review): this loads ./model_new/checkpoint_9000.pth.tar, while the
# training setup warm-started from ./model_new2/... — confirm this is the
# checkpoint you intend to evaluate.
WaveGlow = utils.get_WaveGlow().cuda()
model.load_state_dict(
    torch.load('./model_new/checkpoint_9000.pth.tar', map_location='cuda:0')['model']
)
model.eval()

In [None]:
#!g1.4
def synthesis(model, text, alpha=1.0, alpha_pitch=1.0, alpha_energy=1.0, emotion='neutral', device=None):
    """Generate a mel spectrogram for a single encoded utterance.

    Args:
        model: acoustic model called as
            ``model(seq, pos, alpha=..., alpha_pitch=..., alpha_energy=..., preferred_emotion=...)``
            and returning a batched 3-D tensor.
        text: sequence of token ids for one utterance.
        alpha, alpha_pitch, alpha_energy: control factors forwarded to the model.
        emotion: forwarded to the model as ``preferred_emotion``.
        device: device for the input tensors; defaults to ``train_config.device``.

    Returns:
        ``(mel_cpu, mel_batched)``: the first batch item moved to CPU with its
        last two axes swapped, and the whole batch with axes 1 and 2 swapped
        (left on the model's device).
    """
    if device is None:
        device = train_config.device

    # Use the `text` argument itself. The original read the global loop
    # variable `phn` and ignored this parameter.
    token_ids = torch.as_tensor(text, dtype=torch.long).to(device).unsqueeze(0)
    # 1-based positions, one per token (same as the original `i + 1`).
    positions = torch.arange(1, token_ids.size(1) + 1, device=device).unsqueeze(0)

    with torch.no_grad():
        mel = model(
            token_ids,
            positions,
            alpha=alpha,
            alpha_pitch=alpha_pitch,
            alpha_energy=alpha_energy,
            preferred_emotion=emotion,
        )
    return mel[0].cpu().transpose(0, 1), mel.contiguous().transpose(1, 2)

def get_data():
    """Encode the fixed evaluation sentences into token-id sequences.

    Returns:
        A list with one ``text.text_to_sequence`` result per sentence, using
        ``train_config.text_cleaners``.
    """
    sentences = (
        "A defibrillator is a device that gives a high energy electric shock to the heart of someone who is in cardiac arrest",
        "Massachusetts Institute of Technology may be best known for its math, science and engineering education",
        "Wasserstein distance or Kantorovich Rubinstein metric is a distance function defined between probability distributions on a given metric space",
    )
    return [text.text_to_sequence(sentence, train_config.text_cleaners) for sentence in sentences]

# Synthesize every evaluation sentence for each combination of speed (alpha),
# pitch, energy and emotion, writing one WaveGlow .wav per combination.
data_list = get_data()
os.makedirs("results", exist_ok=True)  # only needs to happen once

factors = [0.8, 1., 1.2]
for alpha, alpha_pitch, alpha_energy, emotion in itertools.product(
        factors, factors, factors, ['happy', 'sad', 'disgust']):
    for i, phn in tqdm(enumerate(data_list), total=len(data_list)):
        # Forward every control factor. Previously only `alpha` was passed,
        # so all pitch/energy/emotion variants produced identical audio.
        mel, mel_cuda = synthesis(
            model, phn, alpha,
            alpha_pitch=alpha_pitch,
            alpha_energy=alpha_energy,
            emotion=emotion,
        )

        waveglow.inference.inference(
            mel_cuda, WaveGlow,
            f"results/s=speed{alpha}_pitch{alpha_pitch}_energy{alpha_energy}_{emotion}_{i}_waveglow.wav"
        )