# 00. Homework description

For this assignment, your task is to develop the [FastPitch](https://arxiv.org/abs/2006.06873) synthesis model, train it, and generate several audio samples.

The training and data processing code is already provided. Training runs on the LJSpeech dataset.

The total score for the homework is **12 points**, distributed as follows:
- 2 points for visualizing the input data
- 8 points for writing the model code
- 2 points for model inference

The homework submission should include:
- Completed notebook
- Attached WER and loss graphs from TensorBoard
- 1 audio file - the result of a regular model inference
- 4 additional audio files for the two inference points: experiment with adjusting phoneme durations and pitch slightly and listen to the results.

# 01. Preparation steps

In [None]:
device = "cuda"
gpu_avaiable = "1"    # Run nvidia-smi to find free GPU

In [None]:
path_to_repository = ...  # Path to the repository root, e.g. /home/user/speech_course

In [None]:
# If running in Colab

# Clone the repository and install the requirements:
# !git clone https://github.com/yandexdataschool/speech_course.git
# !pip install -r speech_course/week_07_tts_am/requirements.txt

### Dataset

We will work with [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) -- a single-speaker dataset with 24 hours of speech.

The data we will use contains pre-computed [MFA alignments](https://montreal-forced-aligner.readthedocs.io/en/latest/user_guide/workflows/alignment.html) alongside the original wavs and texts. If you are interested in how such alignments are extracted, please refer to this [tutorial](https://colab.research.google.com/gist/NTT123/12264d15afad861cb897f7a20a01762e/mfa-ljspeech.ipynb).

Download the dataset with precomputed alignments.

In [None]:
import requests
from urllib.parse import urlencode
from io import BytesIO
from zipfile import ZipFile

base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'
public_key = 'https://disk.yandex.ru/d/PpgePfWcQTAbug'

final_url = base_url + urlencode(dict(public_key=public_key))
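response = requests.get(final_url)
response.raise_for_status()    # Fail early if the download link cannot be resolved
download_url = response.json()['href']
response = requests.get(download_url)
response.raise_for_status()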

path_to_dataset = 'data/ljspeech'    # Choose any appropriate local path

# If running in Colab:
# path_to_dataset = '/content/ljspeech_aligned'

zipfile = ZipFile(BytesIO(response.content))
zipfile.extractall(path=path_to_dataset)

### Hi-Fi GAN checkpoint


Download a pretrained Hi-Fi GAN checkpoint (to generate audio from the predicted mel-spectrograms).

In [None]:
!wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_ds-ljs22khz/versions/21.08.0_amp/zip -O hifigan_ckpt.zip
!unzip hifigan_ckpt.zip
!rm hifigan_ckpt.zip

In [None]:
# In colab:
# !wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_ds-ljs22khz/versions/21.08.0_amp/zip -O /content/hifigan_ckpt.zip
# !unzip /content/hifigan_ckpt.zip
# !rm /content/hifigan_ckpt.zip

In [None]:
path_to_hfg_ckpt = "hifigan_gen_checkpoint_6500.pt" 

### Imports

In [None]:
import sys
import os
import json
import dataclasses
import torch
import subprocess as sp
import matplotlib.pyplot as plt

from g2p_en import G2p
import IPython.display as Ipd

In [None]:
sys.path.append(path_to_repository)

# 02. See a data sample (2 points)

In [None]:
from week_07_tts_am.fastpitch.hparams import HParamsFastpitch
from week_07_tts_am.fastpitch.data import prepare_loaders

The MFA alignment provides phonemes and their durations, which we will need during training:

In [None]:
with open(os.path.join(path_to_dataset, 'mfa_aligned', 'LJ001-0001.json')) as f:
  utterance = json.load(f)

utterance

The phoneme `sil` denotes a pause -- a period of silence between spoken phonemes. The phonemes come from the [ARPABET](https://en.wikipedia.org/wiki/ARPABET) alphabet.

In [None]:
hparams = HParamsFastpitch()
train_loader, val_loader = prepare_loaders(path_to_dataset, hparams)

In [None]:
train_iter = iter(train_loader)
batch = next(train_iter)

In [None]:
list(dataclasses.asdict(batch).keys())

In [None]:
batch.mels.shape

In [None]:
batch.pitches.shape

In [None]:
batch.durations.shape

## Task

**(1 point)** Draw a combined image showing both the mel-spectrogram and pitch for a sample from the batch. Use durations to ensure proper alignment of their shapes in the image.  
- The visualization from the seminar may be helpful for drawing pitch. The parameter `WINDOW_SIZE`, needed to determine the frequency resolution (df), is defined [here](https://github.com/yandexdataschool/speech_course/blob/main/week_07_tts_am/fastpitch/data.py#L108). 
   
**(1 point)** Add phoneme labels along the time axis of the image from the previous step, as in Figure 3 of the [paper](https://arxiv.org/pdf/2006.06873.pdf).  
-  Use `week_07_tts_am.fastpitch.data.SymbolsSet` to get the actual phonemes from their indices in the batch

A useful code snippet:
```py
ax.text(
   s, ax.get_ylim()[0] * 1.15, 
   symbol, 
   rotation=90
)
```

`ax.get_ylim()[0] * 1.15` places the phoneme labels below the spectrogram rather than directly on it.  
`rotation=90` will rotate them sideways, reducing overlap with each other.
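
The alignment step described above can be sketched on toy data. This is only an illustration, assuming the pitch in the batch holds one value per phoneme and the durations are integer frame counts; the arrays and symbol names below are made up rather than taken from the dataset.

```python
import numpy as np

# Toy stand-ins for one utterance (hypothetical values, not from the dataset).
durations = np.array([3, 2, 4])                     # frames per phoneme
pitch_per_phoneme = np.array([0.0, 180.0, 210.0])   # one pitch value per phoneme
symbols = ["sil", "HH", "AY1"]

# Repeat each phoneme-level value for as many frames as the phoneme lasts,
# so the pitch curve gets the same length as the mel-spectrogram time axis.
pitch_per_frame = np.repeat(pitch_per_phoneme, durations)

# The first frame of each phoneme is the sum of all preceding durations.
starts = np.concatenate([[0], np.cumsum(durations)[:-1]])

print(pitch_per_frame.shape)  # (9,)
print(starts.tolist())        # [0, 3, 5]
```

The `starts` values can then serve as the x positions (`s`) for the `ax.text` calls, one per entry in `symbols`.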

In [None]:
<YOUR CODE HERE>

# 03. Implement FastPitch model (8 points)

- Please implement the FastPitch model in the cell provided below. Running this cell will overwrite the model file in the repository. 
- Run training (see next cells)
- When submitting the homework, please include the Word Error Rate (WER) and loss curves obtained from TensorBoard as attachments.

If everything is implemented correctly, by the end of training the WER should reach about 0.006 and the loss about 0.69. Training takes approximately 30 minutes (3000 batches). 
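
Before decoding, FastPitch repeats each phoneme encoding for as many frames as the phoneme lasts. The repository provides `regulate_len` for this, but its signature is not shown here, so the sketch below only illustrates the idea with plain PyTorch. The helper name `expand_by_durations` is hypothetical and is not part of the course code.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def expand_by_durations(enc_out: torch.Tensor, durations: torch.Tensor):
    """Repeat each phoneme vector `durations[b, i]` times along the time axis.

    enc_out: (batch, text_len, dim); durations: (batch, text_len), integer frame counts.
    Returns zero-padded frames (batch, max_frames, dim) and frame counts (batch,).
    """
    expanded = [
        torch.repeat_interleave(x, d, dim=0)  # (sum(d), dim)
        for x, d in zip(enc_out, durations)
    ]
    lens = torch.tensor([e.shape[0] for e in expanded])
    return pad_sequence(expanded, batch_first=True), lens


enc_out = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
durations = torch.tensor([[2, 1, 3], [1, 0, 2]])  # a zero duration drops the phoneme
frames, frame_lens = expand_by_durations(enc_out, durations)
print(frames.shape, frame_lens.tolist())  # torch.Size([2, 6, 4]) [6, 3]
```

The second output plays the role of `mel_lens`: the number of valid frames per utterance after expansion.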

In [None]:
%%writefile <path to repository>/week_07_tts_am/fastpitch/model.py

import torch
from torch import nn as nn

from week_07_tts_am.fastpitch.common.layers import TemporalPredictor
from week_07_tts_am.fastpitch.common.utils import DeviceGetterMixin
from week_07_tts_am.fastpitch.common.utils import regulate_len
from week_07_tts_am.fastpitch.data import FastPitchBatch, SymbolsSet
from week_07_tts_am.fastpitch.hparams import HParamsFastpitch
from week_07_tts_am.fastpitch.common.transformer import FFTransformer


class FastPitch(nn.Module, DeviceGetterMixin):
    def __init__(self, hparams: HParamsFastpitch):
        super().__init__()
        self.hparams = hparams
        n_symbols = len(SymbolsSet().symbols_to_id)

        self.symbol_emb = nn.Embedding(n_symbols, hparams.symbols_embedding_dim)

        self.encoder = FFTransformer(
            n_layer=hparams.in_fft_n_layers,
            n_head=hparams.in_fft_n_heads,
            d_model=hparams.symbols_embedding_dim,
            d_head=hparams.in_fft_d_head,
            d_inner=4 * hparams.symbols_embedding_dim,
            kernel_size=hparams.in_fft_conv1d_kernel_size,
            dropout=hparams.p_in_fft_dropout,
            dropatt=hparams.p_in_fft_dropatt,
            dropemb=hparams.p_in_fft_dropemb
        )

        self.duration_predictor = TemporalPredictor(
            input_size=hparams.symbols_embedding_dim,
            filter_size=hparams.dur_predictor_filter_size,
            kernel_size=hparams.dur_predictor_kernel_size,
            dropout=hparams.p_dur_predictor_dropout,
            n_layers=hparams.dur_predictor_n_layers
        )

        self.pitch_predictor = TemporalPredictor(
            input_size=hparams.symbols_embedding_dim,
            filter_size=hparams.pitch_predictor_filter_size,
            kernel_size=hparams.pitch_predictor_kernel_size,
            dropout=hparams.p_pitch_predictor_dropout,
            n_layers=hparams.pitch_predictor_n_layers
        )

        self.pitch_emb = nn.Conv1d(1, hparams.symbols_embedding_dim, kernel_size=3, padding=1)

        self.decoder = FFTransformer(
            n_layer=hparams.out_fft_n_layers,
            n_head=hparams.out_fft_n_heads,
            d_model=hparams.symbols_embedding_dim,
            d_head=hparams.out_fft_d_head,
            d_inner=4 * hparams.symbols_embedding_dim,
            kernel_size=hparams.out_fft_conv1d_kernel_size,
            dropout=hparams.p_out_fft_dropout,
            dropatt=hparams.p_out_fft_dropatt,
            dropemb=hparams.p_out_fft_dropemb
        )

        self.proj = nn.Linear(hparams.symbols_embedding_dim, hparams.n_mel_channels, bias=True)

    def get_encoder_out(self, batch: FastPitchBatch):
        '''
        Return: 
        enc_out: 
            Output of the first series of FFT blocks (before adding pitch embedding)
            shape: (batch, len(text), symbols_embedding_dim)
        enc_mask:
            Boolean padding mask for the input text sequences
            shape: (batch, len(text), 1)
        '''
        <YOUR CODE HERE>
        return enc_out, enc_mask

    def forward(self, batch: FastPitchBatch, use_gt_durations=True, use_gt_pitch=True, max_duration=75):
        '''
        Flags `use_gt_durations` and `use_gt_pitch` should be both True during training and either True or False during inference.

        Use the function `regulate_len` to duplicate phonemes according to durations before passing them to the decoder.
        
        Return:
        mel_out:
            Predicted mel-spectrograms
            shape: (batch, time, mel_bins)
        mel_lens:
            Number of time frames in each of the predicted spectrograms
            shape: (batch,)
        dur_pred:
            The exponent of the predicted log-durations for each phoneme. Clamped to the range (0, max_duration) for numeric stability
            shape: (batch, len(text))
        log_dur_pred:
            The predicted log-durations for each phoneme (the output of the duration predictor).
            shape: (batch, len(text))
        pitch_pred:
            The predicted pitch for each phoneme
            shape: (batch, len(text))
        '''
        <YOUR CODE HERE>
        return mel_out, mel_lens, dur_pred, log_dur_pred, pitch_pred

    @torch.no_grad()
    def infer(self, batch: FastPitchBatch, max_duration=75):
        enc_out, dur_pred, pitch_pred = self.infer_encoder(batch, max_duration=max_duration)
        mel_out, mel_lens = self.infer_decoder(enc_out, dur_pred)
        return mel_out, mel_lens, dur_pred, pitch_pred

    def infer_encoder(self, batch: FastPitchBatch, max_duration=75):
        <YOUR CODE HERE>
        return enc_out, dur_pred, pitch_pred

    def infer_decoder(self, enc_out, dur_pred):
        <YOUR CODE HERE>
        return mel_out, mel_lens
    

In [None]:
# Allows reloading imported code without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
from week_07_tts_am.fastpitch.model import FastPitch

In [None]:
fp = FastPitch(hparams)

In [None]:
enc_out, enc_mask = fp.get_encoder_out(batch)

In [None]:
assert enc_out.shape == torch.Size([hparams.batch_size, batch.texts.shape[1], hparams.symbols_embedding_dim])
assert enc_mask.shape == torch.Size([hparams.batch_size, batch.texts.shape[1], 1])
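
The `enc_mask` checked above is a boolean padding mask built from the text lengths. A minimal sketch of one way to build such a mask with broadcasting follows. Whether `True` marks real symbols or padding in the repository's layers is an assumption here (this sketch uses `True` for real symbols), so check what `FFTransformer` expects.

```python
import torch

text_lengths = torch.tensor([4, 2])   # true lengths of two toy sequences
max_len = int(text_lengths.max())

# positions: (1, max_len); comparing with lengths (batch, 1) broadcasts to (batch, max_len).
positions = torch.arange(max_len).unsqueeze(0)
mask = positions < text_lengths.unsqueeze(1)   # True at real symbols, False at padding
mask = mask.unsqueeze(-1)                      # (batch, max_len, 1)

print(mask.shape)              # torch.Size([2, 4, 1])
print(mask[1, :, 0].tolist())  # [True, True, False, False]
```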

In [None]:
mel_out, mel_lens, dur_pred, log_dur_pred, pitch_pred = fp.forward(batch)

In [None]:
assert mel_out.shape == batch.mels.transpose(2, 1).shape
assert mel_lens.shape == batch.mel_lengths.shape
assert dur_pred.shape == batch.texts.shape
assert dur_pred.shape == log_dur_pred.shape
assert pitch_pred.shape == batch.texts.shape

### Run training

In [None]:
logs_dir = "logs"     # Choose any paths
ckpt_dir = "checkpoints"

In [None]:
os.makedirs(logs_dir, exist_ok=True)
os.makedirs(ckpt_dir, exist_ok=True)

In [None]:
sp.check_call(
    ' '.join([
        f'PYTHONPATH={path_to_repository} CUDA_VISIBLE_DEVICES={gpu_avaiable}',
        f'python3 -m week_07_tts_am.fastpitch.train_fastpitch',
        f'--logs {logs_dir}',
        f'--ckptdir {ckpt_dir}',
        f'--dataset {path_to_dataset}',
        f'--hfg {path_to_hfg_ckpt}'
    ]), shell=True
)

In [None]:
# If running in colab:

# %load_ext tensorboard
# %tensorboard --logdir logs

In [None]:
# If running in colab:

# %%shell

# mkdir logs checkpoints

# PYTHONPATH=speech_course python3 -m week_07_tts_am.fastpitch.train_fastpitch  \
# --logs logs \
# --ckptdir checkpoints \
# --dataset /content/ljspeech_aligned \
# --hfg /content/hifigan_gen_checkpoint_6500.pt

# 04. Inference (2 points)

In [None]:
from week_07_tts_am.fastpitch.common.checkpointer import Checkpointer
from week_07_tts_am.fastpitch.model import FastPitch
from week_07_tts_am.fastpitch.data import FastPitchBatch, SymbolsSet
from week_07_tts_am.hifigan.model import load_model as load_hfg_model

In [None]:
def get_symbol_ids(text):
    g2p = G2p()
    phonemes = g2p(text)

    symbols_set = SymbolsSet()
    
    symbols = []
    for ph in phonemes:
        if ph in symbols_set.symbols_to_id:
            symbols.append(ph)
        elif ph == ' ':
            continue
        else:
            symbols.append("sil")
    
    symbols_ids = torch.LongTensor(symbols_set.encode(symbols))
    text_length = torch.LongTensor([symbols_ids.shape[0]])

    return symbols_ids, text_length

In [None]:
checkpointer = Checkpointer(ckpt_dir)

In [None]:
hfg = load_hfg_model(path_to_hfg_ckpt)
hfg = hfg.to(device).eval()

In [None]:
ckpt_dict = checkpointer.load_last_checkpoint()
hparams = HParamsFastpitch.create(ckpt_dict['hparams'])
fp = FastPitch(hparams)
fp.load_state_dict(ckpt_dict['state_dict'])
fp = fp.to(device)

In [None]:
ckpt_dict

In [None]:
text = "Hi! My name is Annie Wilson! Nice to meet you."

In [None]:
symbols_ids, lengths = get_symbol_ids(text)

batch = FastPitchBatch(
    texts=symbols_ids.unsqueeze(0),
    text_lengths=lengths
).to(device)

In [None]:
with torch.no_grad():
    mels, mel_lens, *_ = fp.infer(batch)
    mels = mels.permute(0, 2, 1)
    audio = hfg(mels)

Ipd.display(Ipd.Audio(audio.squeeze().cpu().detach().numpy(), rate=22050))
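
To attach the result, the waveform has to be written to a file. One possible way, sketched below with the standard-library `wave` module, assumes the vocoder output is a mono float signal in [-1, 1]. The helper `save_wav` and the file name are illustrative, and the toy sine tone stands in for the real audio.

```python
import wave

import numpy as np


def save_wav(path: str, audio: np.ndarray, sample_rate: int = 22050) -> None:
    """Write a mono float waveform in [-1, 1] as 16-bit PCM."""
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype("<i2")  # little-endian int16
    with wave.open(path, "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)            # 2 bytes per sample = 16 bit
        f.setframerate(sample_rate)
        f.writeframes(pcm.tobytes())


# Toy signal: one second of a 440 Hz tone instead of the vocoder output.
t = np.arange(22050) / 22050
save_wav("prediction_example.wav", 0.5 * np.sin(2 * np.pi * 440 * t))
```

With the tensor from the cell above, the call would look like `save_wav("prediction.wav", audio.squeeze().cpu().numpy())`.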

## Task
- Execute the code provided above. Then, append the generated audio to the homework results
   - if attaching an archive, use name: `prediction.wav`
- **(1 point)** Try increasing and decreasing the prediction speed by a factor of 2, draw spectrograms for each case
    - if attaching an archive, use names:  `prediction_half_dur.wav`,  `prediction_double_dur.wav`
- **(1 point)** Try shifting prediction pitch 50 Hz up and down, draw spectrograms for each case
    - if attaching an archive, use names:  `prediction_50hz_up.wav`,  `prediction_50hz_down.wav`

Attach the resulting audio files to the homework report. 
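
The pitch values the model works with are normalized, so a shift given in Hz has to be converted first. The sketch below shows only the arithmetic, under the assumption that targets were normalized as `(f0 - mean) / scale` with the constants quoted in this notebook. The function names are hypothetical, and unvoiced phonemes may need separate handling.

```python
import torch

PITCH_MEAN = 215.42230   # constants quoted in the notebook
PITCH_SCALE = 62.51305


def hz_to_normalized(f0_hz: torch.Tensor) -> torch.Tensor:
    return (f0_hz - PITCH_MEAN) / PITCH_SCALE


def normalized_to_hz(pitch: torch.Tensor) -> torch.Tensor:
    return pitch * PITCH_SCALE + PITCH_MEAN


pitch = torch.tensor([-0.5, 0.0, 1.0])   # normalized toy values
shifted = hz_to_normalized(normalized_to_hz(pitch) + 50.0)

# Under this normalization, adding 50 Hz equals adding 50 / PITCH_SCALE.
print(torch.allclose(shifted, pitch + 50.0 / PITCH_SCALE, atol=1e-4))  # True
```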

In [None]:
def scale_durations(durations: torch.Tensor, scale_factor: float):
    <YOUR CODE HERE>


def shift_pitch(pitch: torch.Tensor, shift: float):
    scale = 62.51305    # Standard deviation of pitch (Hz) in LJSpeech, used for target pitch normalization
    mean = 215.42230    # Mean pitch (Hz) in LJSpeech
    <YOUR CODE HERE>

In [None]:
_, dur_pred, pitch_pred = fp.infer_encoder(batch)

In [None]:
batch = FastPitchBatch(
    texts=symbols_ids.unsqueeze(0),
    text_lengths=lengths,
    pitches=<YOUR CODE HERE>,
    durations=<YOUR CODE HERE>
).to(device)

In [None]:
with torch.no_grad():
    mels, mel_lens, *_ = fp(batch, use_gt_durations=True, use_gt_pitch=True)
    mels = mels.permute(0, 2, 1)
    audio = hfg(mels)

Ipd.display(Ipd.Audio(audio.squeeze().cpu().detach().numpy(), rate=22050))
plt.imshow(mels.squeeze().cpu().detach().numpy(), origin='lower', aspect='auto')    # Low frequencies at the bottom
plt.show()