## Finetuning FastPitch

In [None]:
import os
import json
from pathlib import Path

import wandb

import torch
import pandas as pd
import IPython.display as ipd
from matplotlib.pyplot import imshow
from matplotlib import pyplot as plt

Let's also download the pretrained checkpoint that we want to finetune from. NeMo saves downloaded checkpoints under `~/.cache`, so we will copy the checkpoint from there to our current directory.

*Note: the copy step below uses `Path.home()` to locate the cache. Check that it resolves to your home folder; otherwise, set the path manually.*

In [None]:
SPEAKER_ID = "9017"
MODEL_NAME = "tts_en_fastpitch"

Let's download the pretrained model.

In [None]:
from nemo.collections.tts.models import FastPitchModel
FastPitchModel.from_pretrained(MODEL_NAME);

We will copy the model to our current working directory for simplicity.

In [None]:
for nemo_file in (Path.home()/".cache/torch/NeMo/").glob(f"**/{MODEL_NAME}_align.nemo"):
    print(f"Copying {nemo_file} to ./{nemo_file.name}")
    Path(f"./{nemo_file.name}").write_bytes(nemo_file.read_bytes())
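Reading the whole checkpoint into memory with `read_bytes()` works, but `.nemo` archives can be large. A minimal sketch of the same copy step using `shutil.copy2`, which copies the file without loading it fully into memory, might look like this. The helper name `copy_matching_files` and its parameters are hypothetical and not part of NeMo.

```python
import shutil
from pathlib import Path


def copy_matching_files(src_root, pattern, dest_dir="."):
    """Copy every file under src_root that matches the glob pattern into dest_dir.

    Hypothetical helper (not a NeMo API). Returns the list of destination paths.
    """
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    copied = []
    for src in sorted(Path(src_root).glob(pattern)):
        if not src.is_file():
            continue  # skip directories that happen to match the pattern
        dest = dest_dir / src.name
        shutil.copy2(src, dest)  # copies contents and metadata
        copied.append(dest)
    return copied
```

With the names used in this notebook, the call would be `copy_matching_files(Path.home() / ".cache/torch/NeMo", f"**/{MODEL_NAME}_align.nemo")`.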

## Train Time!

We can now train our model with the following command:

**NOTE: This will take about 50 minutes on Colab's K80 GPUs.**

In [None]:
!(python fastpitch_finetune.py --config-name=fastpitch_align_v1.05.yaml \
  train_dataset=9017_manifest_train_local.json \
  validation_datasets=9017_manifest_valid_local.json \
  exp_manager.exp_dir=./fastpitch_finetune \
  +exp_manager.create_wandb_logger=True \
  +exp_manager.wandb_logger_kwargs='{project:nemo, job_type:training, log_model:True}' \
  +init_from_nemo_model=./tts_en_fastpitch_align.nemo \
  model.pitch_mean=121.9 model.pitch_std=23.1 \
  model.pitch_fmin=30 model.pitch_fmax=512 \
  trainer.max_steps=1000 \
)

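Let's take a closer look at the overrides the training script accepts. Not all of them appear in the command above, and the example values below use speaker 6097 from the Hi-Fi TTS dataset rather than the 9017 manifests used above: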

* `--config-name=fastpitch_align_v1.05.yaml`
  * We first tell the script what config file to use.

* `train_dataset=./6097_manifest_train_dur_5_mins_local.json 
  validation_datasets=./6097_manifest_dev_ns_all_local.json 
  sup_data_path=./fastpitch_sup_data`
  * We tell the script what manifest files to train and eval on, as well as where supplementary data is located (or will be calculated and saved during training if not provided).
  
* `phoneme_dict_path=tts_dataset_files/cmudict-0.7b_nv22.10 
  heteronyms_path=tts_dataset_files/heteronyms-052722
  whitelist_path=tts_dataset_files/lj_speech.tsv`
  * We tell the script where the files for `phoneme_dict_path`, `heteronyms_path` and `whitelist_path` are located. These are the additional files we downloaded earlier, and they are used in preprocessing the data.
  
* `exp_manager.exp_dir=./ljspeech_to_6097_no_mixing_5_mins`
  * Where we want to save our log files, TensorBoard files, checkpoints, and more.

* `+init_from_nemo_model=./tts_en_fastpitch_align.nemo`
  * We tell the script what checkpoint to finetune from.

* `+trainer.max_steps=1000 ~trainer.max_epochs trainer.check_val_every_n_epoch=25`
  * For this experiment, we tell the script to train for 1000 training steps/iterations rather than specifying a number of epochs to run. Since the config file specifies `max_epochs` instead, we need to remove that using `~trainer.max_epochs`.

* `model.train_ds.dataloader_params.batch_size=24 model.validation_ds.dataloader_params.batch_size=24`
  * Set batch sizes for the training and validation data loaders.

* `model.n_speakers=1`
  * The number of speakers in the data. There is only 1 for now, but we will revisit this parameter later in the notebook.

* `model.pitch_mean=121.9 model.pitch_std=23.1 model.pitch_fmin=30 model.pitch_fmax=512`
  * For the new speaker, we need to define new pitch hyperparameters for better audio quality.
  * These parameters work for speaker 6097 from the Hi-Fi TTS dataset.
  * For speaker 92, we suggest `model.pitch_mean=214.5 model.pitch_std=30.9 model.pitch_fmin=80 model.pitch_fmax=512`.
  * fmin and fmax are hyperparameters to librosa's pyin function. We recommend tweaking these per speaker.
  * After fmin and fmax are defined, pitch mean and std can be easily extracted.

* `model.optim.lr=2e-4 ~model.optim.sched model.optim.name=adam`
  * For fine-tuning, we lower the learning rate.
  * We use a fixed learning rate of 2e-4.
  * We switch from the lamb optimizer to the adam optimizer.

* `trainer.devices=1 trainer.strategy=null`
  * For this notebook, we default to 1 GPU, which means that we do not need DDP.
  * If you have the compute resources, feel free to scale this up to the number of free GPUs you have available.
  * Please remove the `trainer.strategy=null` override if you intend to train on multiple GPUs.
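The pitch statistics mentioned above can be sketched as follows. This is a minimal illustration, assuming each utterance's F0 track has already been extracted (for example with `librosa.pyin(audio, fmin=..., fmax=..., sr=...)`, whose first return value marks unvoiced frames as NaN). The helper name `pitch_stats` and the toy values are hypothetical.

```python
import math
import statistics


def pitch_stats(f0_tracks):
    """Return (mean, std) of the voiced F0 values across several utterances.

    f0_tracks: iterable of per-utterance F0 sequences in Hz, with NaN marking
    unvoiced frames. Hypothetical helper, not part of NeMo.
    """
    voiced = [
        float(f0)
        for track in f0_tracks
        for f0 in track
        if not math.isnan(f0)
    ]
    if not voiced:
        raise ValueError("no voiced frames found; check fmin/fmax")
    # Population standard deviation (the same convention as numpy's default ddof=0).
    return statistics.fmean(voiced), statistics.pstdev(voiced)


# Toy F0 tracks standing in for real pyin output (values are made up).
tracks = [
    [float("nan"), 100.0, 110.0],
    [120.0, float("nan"), 130.0],
]
mean, std = pitch_stats(tracks)
print(f"model.pitch_mean={mean:.1f} model.pitch_std={std:.1f}")
```

The printed values can be passed as the `model.pitch_mean` and `model.pitch_std` overrides. Because unvoiced frames are excluded, the result depends on the chosen `fmin` and `fmax`, which is why those are fixed first.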

## Synthesize Samples from Finetuned Checkpoints

Once we have finetuned our FastPitch model, we can synthesize audio samples for a given text using the following inference steps. We use a HiFi-GAN vocoder trained on LJSpeech.

We define some helper functions as well.

In [None]:
from nemo.collections.tts.models import HifiGanModel
from nemo.collections.tts.models import FastPitchModel

We will load a pretrained vocoder to generate audio from the spectrogram (we will finetune this model later).

In [None]:
vocoder = HifiGanModel.from_pretrained("tts_hifigan")
vocoder = vocoder.eval().cuda();

We can grab the finetuned model from the `wandb` artifact:

In [None]:
wandb.init(project="nemo", job_type="fastpitch_validation")
artifact = wandb.use_artifact('capecape/nemo/model-2022-12-06_16-43-23:v0', type='model')
artifact_dir = artifact.download()

In [None]:
def ls(path): return list(Path(path).iterdir())

In [None]:
ls(artifact_dir)

In [None]:
def infer(spec_gen_model, vocoder_model, str_input, speaker=None):
    """
    Synthesizes spectrogram and audio from a text string given a spectrogram synthesis and vocoder model.
    
    Args:
        spec_gen_model: Spectrogram generator model (FastPitch in our case)
        vocoder_model: Vocoder model (HiFiGAN in our case)
        str_input: Text input for the synthesis
        speaker: Speaker ID
    
    Returns:
        spectrogram and waveform of the synthesized audio.
    """
    with torch.no_grad():
        parsed = spec_gen_model.parse(str_input)
        if speaker is not None:
            speaker = torch.tensor([speaker]).long().to(device=spec_gen_model.device)
        spectrogram = spec_gen_model.generate_spectrogram(tokens=parsed, speaker=speaker)
        audio = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)
        
    if spectrogram is not None:
        if isinstance(spectrogram, torch.Tensor):
            spectrogram = spectrogram.to('cpu').numpy()
        if len(spectrogram.shape) == 3:
            spectrogram = spectrogram[0]
    if isinstance(audio, torch.Tensor):
        audio = audio.to('cpu').numpy()
    return spectrogram, audio

In [None]:
last_ckpt = str(ls(artifact_dir)[0])
print(last_ckpt)

spec_model = FastPitchModel.load_from_checkpoint(last_ckpt)
spec_model.eval().cuda();
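Taking the first entry of the directory listing relies on the artifact containing a single file, since `iterdir()` returns entries in arbitrary order. A hedged sketch of a more explicit selection, assuming the artifact holds one or more files with a `.ckpt` suffix (the suffix and the helper name `newest_checkpoint` are assumptions, not something the artifact guarantees):

```python
from pathlib import Path


def newest_checkpoint(directory, suffix=".ckpt"):
    """Return the most recently modified file in directory with the given suffix.

    Hypothetical helper; the ".ckpt" suffix is an assumption about the artifact contents.
    """
    candidates = [
        p for p in Path(directory).iterdir()
        if p.is_file() and p.suffix == suffix
    ]
    if not candidates:
        raise FileNotFoundError(f"no *{suffix} files in {directory}")
    # Pick the file with the latest modification time.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

With this helper, `last_ckpt = str(newest_checkpoint(artifact_dir))` would replace the index-based lookup.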

### View results in the notebook

In [None]:
valid_df = pd.read_json(f"{SPEAKER_ID}_manifest_valid_local.json", lines=True)
valid_df
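The manifest is a JSON-lines file, which is why `lines=True` is needed above: each line is one JSON object describing an utterance. As a rough sketch of what that parsing involves, the same file can be read with the standard library alone. The helper name `read_manifest` is hypothetical, and the required keys are taken from the fields this notebook accesses later.

```python
import json
from pathlib import Path


def read_manifest(path, required_keys=("audio_filepath", "text")):
    """Read a JSON-lines manifest into a list of dicts, checking required keys.

    Hypothetical helper, not a NeMo API.
    """
    records = []
    with Path(path).open(encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            missing = [k for k in required_keys if k not in record]
            if missing:
                raise KeyError(f"{path}:{line_no} is missing {missing}")
            records.append(record)
    return records
```

Checking the keys up front surfaces a malformed manifest entry before inference starts, rather than partway through the synthesis loop.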

In [None]:
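# Assumes a single-speaker finetune (model.n_speakers=1), so no speaker ID is passed.
speaker_id = None

def generate_audio(text):
    "Generate the mel spectrogram and synthesized audio for `text`."
    spec, audio = infer(spec_model, vocoder, text, speaker=speaker_id)
    return spec, audio.flatten()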

In [None]:
%matplotlib inline

for _, val_record in valid_df.iterrows():
    print("Real validation audio")
    ipd.display(ipd.Audio(val_record['audio_filepath'], rate=22050))
    # Only names defined in this notebook are used in the label.
    print(f"SYNTHESIZED FOR -- Speaker: {SPEAKER_ID} | Text: {val_record['text']}")
    spec, audio = generate_audio(val_record["text"])
    ipd.display(ipd.Audio(audio, rate=22050))
    imshow(spec, origin="lower", aspect="auto")
    plt.show()

### The same on a wandb.Table

In [None]:
table = wandb.Table(columns=['Text', 'Real validation audio', f'Audio Speaker {SPEAKER_ID}', 'Spec'])

sample_rate = 22050

for _, val_record in valid_df.iterrows():
    # Assumes a single-speaker finetune, so no speaker ID is passed.
    speaker_spec, speaker_audio = infer(spec_model, vocoder, val_record['text'], speaker=None)
    row = [val_record["text_no_preprocessing"],
           wandb.Audio(val_record['audio_filepath'], sample_rate=sample_rate), 
           wandb.Audio(speaker_audio.flatten(), sample_rate=sample_rate),
           wandb.Image(speaker_spec)]
    table.add_data(*row)

wandb.log({"table": table})

In [None]:
wandb.finish()