Transformers in Text-to-Speech
=====================================================

## Introduction
---------------

Transformers have revolutionized the field of natural language processing (NLP) with their ability to handle sequential data efficiently. In recent years, researchers have applied the Transformer architecture to text-to-speech (TTS) synthesis, achieving impressive results. This project focuses on the application and implementation of Transformers in TTS systems.

## Background
-------------

Traditional TTS systems relied on concatenative synthesis, statistical parametric synthesis, and waveform modeling. However, these approaches have limitations: concatenative systems require large databases of recorded speech, and statistical parametric systems are prone to over-smoothing, which makes the generated speech sound muffled. The introduction of deep learning techniques, particularly Transformers, has improved TTS performance significantly.

## Transformer-based TTS
-------------------------

The Transformer architecture, introduced by Vaswani et al. (2017) [1], is well-suited for sequence-to-sequence tasks, making it an ideal choice for TTS. The self-attention mechanism in Transformers enables the model to capture long-range dependencies in input sequences, which is essential for generating coherent and natural-sounding speech.
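
To make the self-attention step concrete, here is a minimal sketch in plain Python. It is not the implementation used later in this project: it assumes a single head and identity query, key, and value projections (real Transformer layers learn these), and the names `softmax` and `self_attention` are illustrative only.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x):
    """Scaled dot-product self-attention with identity Q/K/V projections.

    x: list of token vectors (seq_len x dim). Learned projections are
    omitted here for brevity (an assumption of this sketch).
    """
    dim = len(x[0])
    outputs, weights = [], []
    for query in x:
        # Similarity of this position to every position, scaled by sqrt(dim).
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(dim) for key in x
        ]
        attn = softmax(scores)
        weights.append(attn)
        # Weighted sum of the value vectors (here the inputs themselves).
        outputs.append(
            [sum(a * v[d] for a, v in zip(attn, x)) for d in range(dim)]
        )
    return outputs, weights


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = self_attention(tokens)
```

Each row of `weights` sums to 1, and every output vector is a mixture of all input positions, which is how a distant token can influence the current one within a single layer.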

One of the pioneering works in Transformer-based TTS is the Transformer TTS (T-TTS) model proposed by Li et al. (2019) [2]. This model uses a Transformer encoder to process the input text and a Transformer decoder to generate mel-spectrograms. The authors reported faster training than Tacotron 2 and, in listening tests, naturalness close to that of human recordings.

## Advancements and Variants
-----------------------------

Several variants of Transformer-based TTS models have been proposed to improve performance and efficiency. In this project we make extensive use of parts of the FastSpeech recipe from SpeechBrain, and we use HiFi-GAN as our vocoder:

### FastSpeech

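Ren et al. (2019) [3] proposed FastSpeech, a feed-forward, non-autoregressive architecture that generates all mel-spectrogram frames in parallel. It predicts a duration for each phoneme and uses a length regulator to expand the phoneme sequence to the length of the mel-spectrogram, which speeds up synthesis while maintaining quality.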
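
FastSpeech can generate frames in parallel because it expands phoneme-level hidden states to frame level using per-phoneme durations (the length regulator). The following is a toy sketch of that expansion, with made-up vectors and hypothetical durations; `length_regulate` is an illustrative name, not a SpeechBrain function.

```python
def length_regulate(hidden_states, durations):
    """Repeat each phoneme-level vector durations[i] times (frame-level output)."""
    if len(hidden_states) != len(durations):
        raise ValueError("need exactly one duration per hidden state")
    frames = []
    for state, n_frames in zip(hidden_states, durations):
        frames.extend([state] * n_frames)
    return frames


# Three phoneme-level vectors and hypothetical predicted durations (in frames).
phoneme_states = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
frame_states = length_regulate(phoneme_states, [2, 0, 3])
```

Here the first vector is repeated twice, the second is dropped (duration 0), and the third is repeated three times, giving five frame-level vectors that a decoder could process in one pass.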

### HiFi-GAN

Kong et al. (2020) [4] introduced HiFi-GAN, a vocoder based on a Generative Adversarial Network (GAN) that converts mel-spectrograms into waveforms efficiently. The authors reported speech quality close to that of human recordings.

### Conformer

Gulati et al. (2020) [5] proposed Conformer, an architecture originally designed for speech recognition that combines convolution and self-attention to capture both local and global dependencies in input sequences.

## Challenges and Future Directions
---------------------------------

Despite the success of Transformer-based TTS models, there are still challenges to be addressed:

### Over-smoothing

Transformers can suffer from over-smoothing, leading to a lack of expressiveness in generated speech.

### Data scarcity

Limited availability of high-quality, diverse speech datasets hinders the training of robust TTS models.

### Multimodal fusion

Integrating visual and linguistic information to generate more realistic and engaging speech remains an open problem.

## Technical Reference
----------
The technical references, inspiration, and guidance for this implementation come mainly from the SpeechBrain project [6] and Transformer-TTS by Soobinseo [7].

## Conclusion
----------

Transformer-based TTS models have revolutionized the field of speech synthesis, offering improved performance, efficiency, and flexibility. The proposed project aims to implement a Transformer TTS model using SpeechBrain, a popular open-source toolkit for speech and language processing. By leveraging the strengths of Transformers, the project aims to create a high-quality TTS system that can generate natural-sounding speech.

## References
--------------

[1] Vaswani et al. (2017). Attention is All You Need. In Advances in Neural Information Processing Systems (NIPS 2017).

[2] Li et al. (2019). Neural Speech Synthesis with Transformer Network. In Proceedings of the AAAI Conference on Artificial Intelligence (AAAI 2019).

[3] Ren et al. (2019). FastSpeech: Fast, Robust and Controllable Text to Speech. In Advances in Neural Information Processing Systems (NeurIPS 2019).

[4] Kong et al. (2020). HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis. In Advances in Neural Information Processing Systems (NeurIPS 2020).

[5] Gulati et al. (2020). Conformer: Convolution-augmented Transformer for Speech Recognition. In Proceedings of Interspeech 2020.

[6] M. Ravanelli et al., ‘SpeechBrain: A General-Purpose Speech Toolkit’, arXiv [eess.AS]. 2021.

[7] Soobinseo. (n.d.). GitHub - soobinseo/Transformer-TTS: A Pytorch Implementation of “Neural Speech Synthesis with Transformer Network.” GitHub.

## Project Beginning - Training

The first code cell below installs `speechbrain`, `tgt`, `unidecode`, and `hyperpyyaml` with pip. The next two cells use the following commands:

1. `!pip install --upgrade --no-cache-dir gdown`: Installs the `gdown` package with pip. The `--upgrade` flag ensures that if `gdown` is already installed, it is upgraded to the latest version. The `--no-cache-dir` flag disables pip's cache of downloaded files, which can save disk space.

2. `!gdown 1u28CGvLBQAVj4oHqe7l4oPNvsswAIFO3`: Uses `gdown` to download a file from Google Drive, identified by the ID `1u28CGvLBQAVj4oHqe7l4oPNvsswAIFO3`. The file is saved as `LJSpeech-1.1.zip`; it is a subset of the LJ Speech dataset used to check things quickly.

3. `%%capture`: This cell magic captures the output of the cell and prevents it from being displayed in the notebook. It is useful when a cell produces long output that you do not need to see.

4. `!unzip LJSpeech-1.1.zip -d data`: This shell command extracts the contents of `LJSpeech-1.1.zip` into a directory named `data`. The `-d` flag specifies the destination directory for the extracted files.
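
Outside a notebook, the extraction step can also be done from Python. The following is a minimal sketch using only the standard library; `extract_archive` is a hypothetical helper, not part of SpeechBrain or `gdown`.

```python
import os
import zipfile


def extract_archive(zip_path, dest_dir):
    """Extract zip_path into dest_dir and return the archive member names."""
    if not zipfile.is_zipfile(zip_path):
        raise FileNotFoundError(f"{zip_path} is missing or is not a zip archive")
    os.makedirs(dest_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest_dir)
        return archive.namelist()


# Example call (commented out because the archive must be downloaded first):
# extract_archive("LJSpeech-1.1.zip", "data")
```

Checking the archive before extracting gives a clear error if the download failed, which a captured `unzip` cell would otherwise hide.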

In [1]:
!pip install speechbrain
!pip install tgt
!pip install unidecode
!pip install hyperpyyaml

Collecting speechbrain
  Downloading speechbrain-1.0.0-py3-none-any.whl (760 kB)
Collecting hyperpyyaml (from speechbrain)
  Downloading HyperPyYAML-1.2.2-py3-none-any.whl (16 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.9->speechbrain)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.9->speechbrain)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none

In [2]:
!pip install --upgrade --no-cache-dir gdown
!gdown 1u28CGvLBQAVj4oHqe7l4oPNvsswAIFO3  # subset of the LJSpeech dataset, for quick checks

Downloading...
From (original): https://drive.google.com/uc?id=1u28CGvLBQAVj4oHqe7l4oPNvsswAIFO3
From (redirected): https://drive.google.com/uc?id=1u28CGvLBQAVj4oHqe7l4oPNvsswAIFO3&confirm=t&uuid=8b936d17-dfc6-4a77-930e-90516e3bedf8
To: /content/LJSpeech-1.1.zip
100% 415M/415M [00:04<00:00, 103MB/s] 


In [3]:
%%capture
!unzip LJSpeech-1.1.zip -d data

(Inspired by Speechbrain Fastspeech)

This Python code is used to prepare the LJ Speech dataset for use in training models for speech synthesis tasks. It involves tasks such as:

1. Splitting the dataset into training, validation, and test sets.
2. Generating JSON files containing information about audio files, their corresponding transcriptions, and optionally, additional data like phoneme alignments and pitch information.
3. Optionally computing phoneme alignments and pitch values for models that require such data, like FastSpeech2. In the version below, alignments and durations are processed for FastSpeech2, but the pitch-generation parts are not used.
4. Cleaning and preprocessing text data.
5. Ensuring reproducibility by setting random seeds and checking if the data preparation phase has been completed before.
6. Logging progress and errors during the preparation process.
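
The splitting step in the list above can be illustrated with a simplified sketch. It assumes only a train/valid split, uses a local `random.Random` instead of the global seed, and the name `split_by_session` is hypothetical; the actual `split_sets` function below also handles an optional test split.

```python
import random


def split_by_session(utt_ids, train_pct=90, seed=1234):
    """Split utterance ids into train/valid lists, session by session."""
    rng = random.Random(seed)  # local RNG so the split is repeatable
    sessions = {}
    for utt_id in utt_ids:
        # LJSpeech ids look like "LJ001-0001"; the prefix names the session.
        sessions.setdefault(utt_id.split("-")[0], []).append(utt_id)

    split = {"train": [], "valid": []}
    for session_ids in sessions.values():
        rng.shuffle(session_ids)
        n_train = int(len(session_ids) * train_pct / 100)
        split["train"].extend(session_ids[:n_train])
        split["valid"].extend(session_ids[n_train:])
    return split


# Two made-up sessions with ten utterances each.
ids = [f"LJ{s:03d}-{u:04d}" for s in (1, 2) for u in range(1, 11)]
split = split_by_session(ids)
```

Splitting within each session keeps every session represented in both sets, and fixing the seed means the same utterances land in the same split on every run.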

In [7]:
import os
import csv
import json
import random
import logging
import torch
import torchaudio
import numpy as np
from tqdm import tqdm
from speechbrain.utils.data_utils import download_file
from speechbrain.dataio.dataio import load_pkl, save_pkl
import tgt
from speechbrain.inference.text import GraphemeToPhoneme
import re
from unidecode import unidecode
from speechbrain.utils.text_to_sequence import _g2p_keep_punctuations


logger = logging.getLogger(__name__)
OPT_FILE = "opt_ljspeech_prepare.pkl"
METADATA_CSV = "metadata.csv"
TRAIN_JSON = "train.json"
VALID_JSON = "valid.json"
TEST_JSON = "test.json"
WAVS = "wavs"
DURATIONS = "durations"

def prepare_ljspeech(
    data_folder,
    save_folder,
    splits=["train", "valid"],
    split_ratio=[90, 10],
    model_name=None,
    seed=1234,
    pitch_n_fft=1024,
    pitch_hop_length=256,
    pitch_min_f0=65,
    pitch_max_f0=400,
    skip_prep=False,
    use_custom_cleaner=False,
    device="cpu",
):

    random.seed(seed)

    if skip_prep:
        return

    conf = {
        "data_folder": data_folder,
        "splits": splits,
        "split_ratio": split_ratio,
        "save_folder": save_folder,
        "seed": seed,
    }
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    meta_csv = os.path.join(data_folder, METADATA_CSV)
    wavs_folder = os.path.join(data_folder, WAVS)

    save_opt = os.path.join(save_folder, OPT_FILE)
    save_json_train = os.path.join(save_folder, TRAIN_JSON)
    save_json_valid = os.path.join(save_folder, VALID_JSON)
    save_json_test = os.path.join(save_folder, TEST_JSON)

    phoneme_alignments_folder = None
    duration_folder = None
    pitch_folder = None
    if model_name is not None and "FastSpeech2" in model_name:
        alignment_URL = (
            "https://www.dropbox.com/s/v28x5ldqqa288pu/LJSpeech.zip?dl=1"
        )
        phoneme_alignments_folder = os.path.join(
            data_folder, "TextGrid", "LJSpeech"
        )
        download_file(
            alignment_URL, data_folder + "/alignments.zip", unpack=True
        )

        duration_folder = os.path.join(data_folder, "durations")
        if not os.path.exists(duration_folder):
            os.makedirs(duration_folder)

        pitch_folder = os.path.join(data_folder, "pitch")
        if not os.path.exists(pitch_folder):
            os.makedirs(pitch_folder)

    if skip(splits, save_folder, conf):
        logger.info("Skipping preparation, completed in previous run.")
        return

    assert os.path.exists(meta_csv), "metadata.csv does not exist"
    assert os.path.exists(wavs_folder), "wavs/ folder does not exist"

    msg = "Creating json file for ljspeech Dataset.."
    logger.info(msg)
    data_split, meta_csv = split_sets(data_folder, splits, split_ratio)

    if "train" in splits:
        prepare_json(
            model_name,
            data_split["train"],
            save_json_train,
            wavs_folder,
            meta_csv,
            phoneme_alignments_folder,
            duration_folder,
            pitch_folder,
            pitch_n_fft,
            pitch_hop_length,
            pitch_min_f0,
            pitch_max_f0,
            use_custom_cleaner,
            device,
        )
    if "valid" in splits:
        prepare_json(
            model_name,
            data_split["valid"],
            save_json_valid,
            wavs_folder,
            meta_csv,
            phoneme_alignments_folder,
            duration_folder,
            pitch_folder,
            pitch_n_fft,
            pitch_hop_length,
            pitch_min_f0,
            pitch_max_f0,
            use_custom_cleaner,
            device,
        )
    if "test" in splits:
        prepare_json(
            model_name,
            data_split["test"],
            save_json_test,
            wavs_folder,
            meta_csv,
            phoneme_alignments_folder,
            duration_folder,
            pitch_folder,
            pitch_n_fft,
            pitch_hop_length,
            pitch_min_f0,
            pitch_max_f0,
            use_custom_cleaner,
            device,
        )
    save_pkl(conf, save_opt)


def skip(splits, save_folder, conf):

    skip = True

    split_files = {
        "train": TRAIN_JSON,
        "valid": VALID_JSON,
        "test": TEST_JSON,
    }

    for split in splits:
        if not os.path.isfile(os.path.join(save_folder, split_files[split])):
            skip = False

    save_opt = os.path.join(save_folder, OPT_FILE)
    if skip is True:
        if os.path.isfile(save_opt):
            opts_old = load_pkl(save_opt)
            if opts_old == conf:
                skip = True
            else:
                skip = False
        else:
            skip = False
    return skip


def split_sets(data_folder, splits, split_ratio):

    meta_csv = os.path.join(data_folder, METADATA_CSV)
    with open(meta_csv, 'r', encoding='utf-8') as csvfile:
        csv_reader = csv.reader(
           csvfile, delimiter="|", quoting=csv.QUOTE_NONE
        )

        meta_csv = list(csv_reader)

        index_for_sessions = []
        session_id_start = "LJ001"
        index_this_session = []
        for i in range(len(meta_csv)):
            session_id = meta_csv[i][0].split("-")[0]
            if session_id == session_id_start:
                index_this_session.append(i)
                if i == len(meta_csv) - 1:
                    index_for_sessions.append(index_this_session)
            else:
                index_for_sessions.append(index_this_session)
                session_id_start = session_id
                index_this_session = [i]

        session_len = [len(session) for session in index_for_sessions]

        data_split = {}
        for i, split in enumerate(splits):
            data_split[split] = []
            for j in range(len(index_for_sessions)):
                if split == "train":
                    random.shuffle(index_for_sessions[j])
                    n_snts = int(session_len[j] * split_ratio[i] / sum(split_ratio))
                    data_split[split].extend(index_for_sessions[j][0:n_snts])
                    del index_for_sessions[j][0:n_snts]
                if split == "valid":
                    if "test" in splits:
                        random.shuffle(index_for_sessions[j])
                        n_snts = int(
                            session_len[j] * split_ratio[i] / sum(split_ratio)
                        )
                        data_split[split].extend(index_for_sessions[j][0:n_snts])
                        del index_for_sessions[j][0:n_snts]
                    else:
                        data_split[split].extend(index_for_sessions[j])
                if split == "test":
                    data_split[split].extend(index_for_sessions[j])

    return data_split, meta_csv


def prepare_json(
    model_name,
    seg_lst,
    json_file,
    wavs_folder,
    csv_reader,
    phoneme_alignments_folder,
    durations_folder,
    pitch_folder,
    pitch_n_fft,
    pitch_hop_length,
    pitch_min_f0,
    pitch_max_f0,
    use_custom_cleaner=False,
    device="cpu",
):

    logger.info(f"preparing {json_file}.")
    if model_name in ["Tacotron2", "FastSpeech2"]:
        logger.info(
            "Computing phonemes for LJSpeech labels using SpeechBrain G2P. This may take a while."
        )
        g2p = GraphemeToPhoneme.from_hparams(
            "speechbrain/soundchoice-g2p", run_opts={"device": device}
        )
    if model_name is not None and "FastSpeech2" in model_name:
        logger.info(
            "Computing pitch as required for FastSpeech2. This may take a while."
        )

    json_dict = {}
    rows = list(csv_reader)  # materialize once instead of on every iteration
    for index in tqdm(seg_lst):
        id = rows[index][0]
        wav = os.path.join(wavs_folder, f"{id}.wav")
        label = rows[index][2]
        if use_custom_cleaner:
            label = custom_clean(label, model_name)

        json_dict[id] = {
            "uttid": id,
            "wav": wav,
            "label": label,
            "segment": True if "train" in json_file else False,
        }

        if model_name == "FastSpeech2":
            audio, fs = torchaudio.load(wav)

            textgrid_path = os.path.join(
                phoneme_alignments_folder, f"{id}.TextGrid"
            )
            textgrid = tgt.io.read_textgrid(
                textgrid_path, include_empty_intervals=True
            )

            last_phoneme_flags = get_last_phoneme_info(
                textgrid.get_tier_by_name("words"),
                textgrid.get_tier_by_name("phones"),
            )
            (
                phonemes,
                duration,
                start,
                end,
                trimmed_last_phoneme_flags,
            ) = get_alignment(
                textgrid.get_tier_by_name("phones"),
                fs,
                pitch_hop_length,
                last_phoneme_flags,
            )

            label_phoneme = " ".join(phonemes)
            spn_labels = [0] * len(phonemes)
            for i in range(1, len(phonemes)):
                if phonemes[i] == "spn":
                    spn_labels[i - 1] = 1
            if start >= end:
                print(f"Skipping {id}")
                continue

            duration_file_path = os.path.join(durations_folder, f"{id}.npy")
            np.save(duration_file_path, duration)


            json_dict[id].update({"label_phoneme": label_phoneme})
            json_dict[id].update({"spn_labels": spn_labels})
            json_dict[id].update({"start": start})
            json_dict[id].update({"end": end})
            json_dict[id].update({"durations": duration_file_path})
            json_dict[id].update(
                {"last_phoneme_flags": trimmed_last_phoneme_flags}
            )

    with open(json_file, mode="w") as json_f:
        json.dump(json_dict, json_f, indent=2)

    logger.info(f"{json_file} successfully created!")


def get_alignment(tier, sampling_rate, hop_length, last_phoneme_flags):

    sil_phones = ["sil", "sp", "spn", ""]

    phonemes = []
    durations = []
    start_time = 0
    end_time = 0
    end_idx = 0
    trimmed_last_phoneme_flags = []

    flag_iter = iter(last_phoneme_flags)

    for t in tier._objects:
        s, e, p = t.start_time, t.end_time, t.text
        current_flag = next(flag_iter)
        if phonemes == []:
            if p in sil_phones:
                continue
            else:
                start_time = s

        if p not in sil_phones:
            if p[-1].isdigit():
                phonemes.append(p[:-1])
            else:
                phonemes.append(p)
            trimmed_last_phoneme_flags.append(current_flag[1])
            end_time = e
            end_idx = len(phonemes)
        else:
            phonemes.append("spn")
            trimmed_last_phoneme_flags.append(current_flag[1])

        durations.append(
            int(
                np.round(e * sampling_rate / hop_length)
                - np.round(s * sampling_rate / hop_length)
            )
        )

    phonemes = phonemes[:end_idx]
    durations = durations[:end_idx]

    return phonemes, durations, start_time, end_time, trimmed_last_phoneme_flags


def get_last_phoneme_info(words_seq, phones_seq):

    phoneme_objects = phones_seq._objects
    phoneme_iter = iter(phoneme_objects)

    last_phoneme_flags = list()

    for word_obj in words_seq._objects:
        word_end_time = word_obj.end_time

        current_phoneme = next(phoneme_iter, None)
        while current_phoneme:
            phoneme_end_time = current_phoneme.end_time
            if phoneme_end_time == word_end_time:
                last_phoneme_flags.append((current_phoneme.text, 1))
                break
            else:
                last_phoneme_flags.append((current_phoneme.text, 0))
            current_phoneme = next(phoneme_iter, None)

    return last_phoneme_flags


def custom_clean(text, model_name):

    _abbreviations = [
        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
        for x in [
            ("mrs", "missus"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        ]
    ]
    text = unidecode(text.lower())
    if model_name != "FastSpeech2WithAlignment":
        text = re.sub("[:;]", " - ", text)
        text = re.sub(r'[)(\[\]"]', " ", text)
        text = text.strip().strip().strip("-")

    text = re.sub(" +", " ", text)
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text

Run the code above to prepare the train and valid splits of the LJSpeech dataset with a 90/10 ratio:

In [8]:
prepare_ljspeech(
    data_folder=r'data/LJSpeech-1.1',
    save_folder=r'results/save',
    splits=["train", "valid"],
    split_ratio=[90, 10],
    model_name='FastSpeech2',
    seed=1234,
    skip_prep=False,
)

Downloading https://www.dropbox.com/s/v28x5ldqqa288pu/LJSpeech.zip?dl=1 to data/LJSpeech-1.1/alignments.zip


LJSpeech.zip?dl=1: 18.1MB [00:01, 14.9MB/s]                            


Extracting data/LJSpeech-1.1/alignments.zip to data/LJSpeech-1.1


hyperparams.yaml:   0%|          | 0.00/11.3k [00:00<?, ?B/s]

model.ckpt:   0%|          | 0.00/129M [00:00<?, ?B/s]

ctc_lin.ckpt:   0%|          | 0.00/177k [00:00<?, ?B/s]


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

100%|██████████| 167/167 [00:03<00:00, 49.60it/s]
100%|██████████| 19/19 [00:00<00:00, 76.15it/s]
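
The two progress bars above presumably correspond to the train and valid manifests (167 and 19 utterances). As a quick sanity check, the generated JSON can be loaded and summarized. This is a minimal sketch: `summarize_manifest` is a hypothetical helper, and the path in the example call assumes the `save_folder` and `TRAIN_JSON` values used above.

```python
import json
from collections import Counter


def summarize_manifest(json_path):
    """Return (number of utterances, count of entries per 'segment' value)."""
    with open(json_path, encoding="utf-8") as handle:
        manifest = json.load(handle)
    segment_counts = Counter(entry["segment"] for entry in manifest.values())
    return len(manifest), segment_counts


# Example call (assumes preparation has already written the file):
# n_utts, segments = summarize_manifest("results/save/train.json")
```

Each manifest entry is keyed by utterance id and holds at least `uttid`, `wav`, `label`, and `segment`, as written by `prepare_json`.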


## TransformersTTS.py

The `TransformersTTS.py` script implements a Transformer-based Text-to-Speech (TTS) model using PyTorch. The model consists of an encoder and a decoder, each composed of several components.

1. **Encoder**:
   - The encoder takes input characters (text) and converts them into a sequence of embeddings.
   - These embeddings are then augmented with positional encodings and processed through multiple layers of self-attention mechanisms and feed-forward neural networks.
   - The encoder outputs a contextual representation of the input text, which captures its semantic information.

2. **Decoder**:
   - The decoder takes the mel-spectrogram features (acoustic features) as input along with the positional encodings.
   - Similar to the encoder, the decoder also employs self-attention mechanisms and feed-forward neural networks.
   - The decoder generates mel-spectrogram predictions step by step, conditioned on the input text representation produced by the encoder.
   - Additionally, the decoder predicts stop tokens to indicate the end of the mel-spectrogram sequence generation.

3. **Additional Components**:
   - The code includes modules for linear transformations, convolutions, multi-head attention mechanisms, and positional embeddings.
   - These components are essential for building the layers of the Transformer architecture and ensuring proper information flow between encoder and decoder.

4. **Model Class**:
   - The `Model` class encapsulates the entire Transformer TTS architecture by combining the encoder and decoder modules.
   - During forward pass, it takes input characters and mel-spectrogram features, processes them through the encoder and decoder respectively, and returns the predicted mel-spectrogram output along with attention probabilities and stop predictions.
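
The step-by-step decoding and stop-token behaviour described above can be sketched as a simple loop. This is an illustration under stated assumptions, not the project's inference code: `decoder_step` is a hypothetical callable standing in for one decoder pass, the 0.5 threshold and the frame cap are arbitrary choices, and `toy_step` exists only to make the example runnable.

```python
def generate_frames(decoder_step, n_mels=80, stop_threshold=0.5, max_frames=1000):
    """Greedy autoregressive decoding driven by a stop-token probability.

    decoder_step(frames) -> (next_frame, stop_prob) is a hypothetical callable
    representing one decoder pass over the frames produced so far.
    """
    frames = [[0.0] * n_mels]  # all-zero "go" frame that starts decoding
    for _ in range(max_frames):
        next_frame, stop_prob = decoder_step(frames)
        frames.append(next_frame)
        if stop_prob > stop_threshold:
            break
    return frames[1:]  # drop the go frame


def toy_step(frames):
    """Stand-in decoder: emits constant frames and signals stop on the 5th."""
    n_generated = len(frames) - 1  # frames generated before this call
    return [0.1] * len(frames[0]), 1.0 if n_generated >= 4 else 0.0


mel = generate_frames(toy_step)
```

The `max_frames` cap guards against a model that never predicts a stop token, which would otherwise make the loop run indefinitely.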

The implementation of **Transformer TTS** owes much to the work of **Soobinseo** [7], whose GitHub repository served as a valuable reference. Our approach differs significantly in some aspects (for instance, we use phonemes instead of characters and opt for HiFi-GAN over WaveNet because of its superior performance and generalizability), but the foundation laid by Soobinseo's work remains pivotal to our project.

In [9]:
%%file TransformersTTS.py

import torch.nn as nn
import torch as t
import torch.nn.functional as F
import math
import numpy as np
import copy
from collections import OrderedDict
from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding


def get_sinusoid_encoding_table(seq_len, hidden_dim, padding_idx=None):
    """Build a (seq_len, hidden_dim) sinusoidal position table (hidden_dim assumed even)."""
    position = t.arange(0, seq_len, dtype=t.float).unsqueeze(1)
    # Frequencies 1 / 10000^(2i / hidden_dim), one per sine/cosine pair.
    div_term = t.exp(t.arange(0, hidden_dim, 2).float() * (-t.log(t.tensor(10000.0)) / hidden_dim))
    position_encoding = t.zeros(seq_len, hidden_dim)
    position_encoding[:, 0::2] = t.sin(position * div_term)  # even dimensions
    position_encoding[:, 1::2] = t.cos(position * div_term)  # odd dimensions
    if padding_idx is not None:
        # Zero out the row reserved for padding positions.
        position_encoding[padding_idx] = 0.
    return position_encoding


def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class Linear(nn.Module):
    """
    Linear Module
    """
    def __init__(self, in_dim, out_dim, bias=True, w_init='linear'):
        """
        :param in_dim: dimension of input
        :param out_dim: dimension of output
        :param bias: boolean. if True, bias is included.
        :param w_init: str. weight inits with xavier initialization.
        """
        super(Linear, self).__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)

        nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        return self.linear_layer(x)


class Conv(nn.Module):
    """
    Convolution Module
    """
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=True, w_init='linear'):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. if True, bias is included.
        :param w_init: str. weight inits with xavier initialization.
        """
        super(Conv, self).__init__()

        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation,
                              bias=bias)

        nn.init.xavier_uniform_(
            self.conv.weight, gain=nn.init.calculate_gain(w_init))

    def forward(self, x):
        x = self.conv(x)
        return x

class PostConvNet(nn.Module):
    """
    Post Convolutional Network (mel --> mel)
    """
    def __init__(self, num_hidden):
        """
        :param num_hidden: dimension of hidden
        """
        super(PostConvNet, self).__init__()
        self.num_mels = 80
        self.outputs_per_step = 1

        self.conv1 = Conv(in_channels=self.num_mels * self.outputs_per_step,
                          out_channels=num_hidden,
                          kernel_size=5,
                          padding=4,
                          w_init='tanh')
        self.conv_list = clones(Conv(in_channels=num_hidden,
                                     out_channels=num_hidden,
                                     kernel_size=5,
                                     padding=4,
                                     w_init='tanh'), 3)
        self.conv2 = Conv(in_channels=num_hidden,
                          out_channels=self.num_mels * self.outputs_per_step,
                          kernel_size=5,
                          padding=4)

        self.batch_norm_list = clones(nn.BatchNorm1d(num_hidden), 3)
        self.pre_batchnorm = nn.BatchNorm1d(num_hidden)

        self.dropout1 = nn.Dropout(p=0.1)
        self.dropout_list = nn.ModuleList([nn.Dropout(p=0.1) for _ in range(3)])

    def forward(self, input_, mask=None):
        input_ = self.dropout1(t.tanh(self.pre_batchnorm(self.conv1(input_)[:, :, :-4])))
        for batch_norm, conv, dropout in zip(self.batch_norm_list, self.conv_list, self.dropout_list):
            input_ = dropout(t.tanh(batch_norm(conv(input_)[:, :, :-4])))
        input_ = self.conv2(input_)[:, :, :-4]
        return input_

class FFN(nn.Module):
    """
    Positionwise Feed-Forward Network
    """

    def __init__(self, num_hidden):
        """
        :param num_hidden: dimension of hidden
        """
        super(FFN, self).__init__()
        self.w_1 = Conv(num_hidden, num_hidden * 4, kernel_size=1, w_init='relu')
        self.w_2 = Conv(num_hidden * 4, num_hidden, kernel_size=1)
        self.dropout = nn.Dropout(p=0.1)
        self.layer_norm = nn.LayerNorm(num_hidden)

    def forward(self, input_):
        x = input_.transpose(1, 2)
        x = self.w_2(t.relu(self.w_1(x)))
        x = x.transpose(1, 2)

        x = x + input_

        x = self.layer_norm(x)

        return x

class Attention(nn.Module):
    """
    Attention Network
    """
    def __init__(self, num_hidden, h=4):
        """
        :param num_hidden: dimension of hidden
        :param h: num of heads
        """
        super(Attention, self).__init__()

        self.num_hidden = num_hidden
        self.num_hidden_per_attn = num_hidden // h
        self.h = h

        self.key = Linear(num_hidden, num_hidden, bias=False)
        self.value = Linear(num_hidden, num_hidden, bias=False)
        self.query = Linear(num_hidden, num_hidden, bias=False)

        self.multihead = MultiheadAttention(self.num_hidden_per_attn)

        self.residual_dropout = nn.Dropout(p=0.1)

        self.final_linear = Linear(num_hidden * 2, num_hidden)

        self.layer_norm_1 = nn.LayerNorm(num_hidden)

    def forward(self, memory, decoder_input, mask=None, query_mask=None):

        batch_size = memory.size(0)
        seq_k = memory.size(1)
        seq_q = decoder_input.size(1)

        if query_mask is not None:
            query_mask = query_mask.unsqueeze(-1).repeat(1, 1, seq_k)
            query_mask = query_mask.repeat(self.h, 1, 1)
        if mask is not None:
            mask = mask.repeat(self.h, 1, 1)

        key = self.key(memory).view(batch_size, seq_k, self.h, self.num_hidden_per_attn)
        value = self.value(memory).view(batch_size, seq_k, self.h, self.num_hidden_per_attn)
        query = self.query(decoder_input).view(batch_size, seq_q, self.h, self.num_hidden_per_attn)

        key = key.permute(2, 0, 1, 3).contiguous().view(-1, seq_k, self.num_hidden_per_attn)
        value = value.permute(2, 0, 1, 3).contiguous().view(-1, seq_k, self.num_hidden_per_attn)
        query = query.permute(2, 0, 1, 3).contiguous().view(-1, seq_q, self.num_hidden_per_attn)

        result, attns = self.multihead(key, value, query, mask=mask, query_mask=query_mask)

        result = result.view(self.h, batch_size, seq_q, self.num_hidden_per_attn)
        result = result.permute(1, 2, 0, 3).contiguous().view(batch_size, seq_q, -1)

        result = t.cat([decoder_input, result], dim=-1)

        result = self.final_linear(result)

        result = result + decoder_input

        result = self.layer_norm_1(result)

        return result, attns

class MultiheadAttention(nn.Module):
    """
    Multihead attention mechanism (dot attention)
    """
    def __init__(self, num_hidden_k):
        """
        :param num_hidden_k: dimension of hidden
        """
        super(MultiheadAttention, self).__init__()

        self.num_hidden_k = num_hidden_k
        self.attn_dropout = nn.Dropout(p=0.1)

    def forward(self, key, value, query, mask=None, query_mask=None):

        attn = t.bmm(query, key.transpose(1, 2))
        attn = attn / math.sqrt(self.num_hidden_k)

        if mask is not None:
            # Masked positions get a large negative score so softmax sends them to ~0.
            attn = attn.masked_fill(mask, -2 ** 32 + 1)
        attn = t.softmax(attn, dim=-1)

        if query_mask is not None:
            attn = attn * query_mask

        result = t.bmm(attn, value)

        return result, attn



class Prenet(nn.Module):
    """
    Prenet before passing through the network
    """
    def __init__(self, input_size, hidden_size, output_size, p=0.5):
        """
        :param input_size: dimension of input
        :param hidden_size: dimension of hidden unit
        :param output_size: dimension of output
        """
        super(Prenet, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.layer = nn.Sequential(OrderedDict([
             ('fc1', Linear(self.input_size, self.hidden_size)),
             ('relu1', nn.ReLU()),
             ('dropout1', nn.Dropout(p)),
             ('fc2', Linear(self.hidden_size, self.output_size)),
             ('relu2', nn.ReLU()),
             ('dropout2', nn.Dropout(p)),
        ]))

    def forward(self, input_):

        out = self.layer(input_)

        return out


class EncoderPrenet(nn.Module):
    """
    Pre-network for Encoder consists of convolution networks.
    """
    def __init__(self, embedding_size, num_hidden):
        super(EncoderPrenet, self).__init__()
        self.embedding_size = embedding_size
        self.embed = nn.Embedding(42, embedding_size, padding_idx=0)

        self.conv1 = Conv(in_channels=embedding_size,
                          out_channels=num_hidden,
                          kernel_size=5,
                          padding=int(np.floor(5 / 2)),
                          w_init='relu')
        self.conv2 = Conv(in_channels=num_hidden,
                          out_channels=num_hidden,
                          kernel_size=5,
                          padding=int(np.floor(5 / 2)),
                          w_init='relu')

        self.conv3 = Conv(in_channels=num_hidden,
                          out_channels=num_hidden,
                          kernel_size=5,
                          padding=int(np.floor(5 / 2)),
                          w_init='relu')

        self.batch_norm1 = nn.BatchNorm1d(num_hidden)
        self.batch_norm2 = nn.BatchNorm1d(num_hidden)
        self.batch_norm3 = nn.BatchNorm1d(num_hidden)

        self.dropout1 = nn.Dropout(p=0.2)
        self.dropout2 = nn.Dropout(p=0.2)
        self.dropout3 = nn.Dropout(p=0.2)
        self.projection = Linear(num_hidden, num_hidden)

    def forward(self, input_):
        input_ = self.embed(input_)
        input_ = input_.transpose(1, 2)
        input_ = self.dropout1(t.relu(self.batch_norm1(self.conv1(input_))))
        input_ = self.dropout2(t.relu(self.batch_norm2(self.conv2(input_))))
        input_ = self.dropout3(t.relu(self.batch_norm3(self.conv3(input_))))
        input_ = input_.transpose(1, 2)
        input_ = self.projection(input_)

        return input_


class Encoder(nn.Module):
    """
    Encoder Network
    """
    def __init__(self, embedding_size, num_hidden):
        """
        :param embedding_size: dimension of embedding
        :param num_hidden: dimension of hidden
        """
        super(Encoder, self).__init__()
        self.enc_d_model = 384
        self.alpha = nn.Parameter(t.ones(1))
        self.num_hidden = num_hidden
        self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(1024, num_hidden, padding_idx=0),
                                                    freeze=True)
        self.pos_dropout = nn.Dropout(p=0.1)
        self.encoder_prenet = EncoderPrenet(embedding_size, num_hidden)
        self.layers = clones(Attention(num_hidden), 3)
        self.ffns = clones(FFN(num_hidden), 3)


    def forward(self, x, pos):

        if self.training:
            c_mask = pos.ne(0).type(t.float)
            mask = pos.eq(0).unsqueeze(1).repeat(1, x.size(1), 1)

        else:
            c_mask, mask = None, None

        x = self.encoder_prenet(x)


        pos = self.pos_emb(pos)

        x = pos * self.alpha + x

        x = self.pos_dropout(x)

        attns = list()
        for layer, ffn in zip(self.layers, self.ffns):
            x, attn = layer(x, x, mask=mask, query_mask=c_mask)
            x = ffn(x)
            attns.append(attn)

        return x, c_mask, attns


class MelDecoder(nn.Module):
    """
    Decoder Network
    """
    def __init__(self, num_hidden):
        """
        :param num_hidden: dimension of hidden
        """
        super(MelDecoder, self).__init__()
        self.num_mels = 80
        self.outputs_per_step = 1
        self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(1024, num_hidden, padding_idx=0),
                                                    freeze=True)
        self.pos_dropout = nn.Dropout(p=0.1)
        self.alpha = nn.Parameter(t.ones(1))
        self.decoder_prenet = Prenet(self.num_mels, num_hidden * 2, num_hidden, p=0.2)
        self.norm = Linear(num_hidden, num_hidden)

        self.selfattn_layers = clones(Attention(num_hidden), 3)
        self.dotattn_layers = clones(Attention(num_hidden), 3)
        self.ffns = clones(FFN(num_hidden), 3)
        self.mel_linear = Linear(num_hidden, self.num_mels * self.outputs_per_step)
        self.stop_linear = Linear(num_hidden, 1, w_init='sigmoid')   #80

        self.postconvnet = PostConvNet(num_hidden)

    def forward(self, memory, decoder_input, c_mask, pos):
        batch_size = memory.size(0)
        decoder_len = decoder_input.size(1)

        # Causal mask: position i may not attend to any later position j > i.
        causal = t.triu(t.ones(decoder_len, decoder_len, device=decoder_input.device), diagonal=1)
        causal = causal.repeat(batch_size, 1, 1).byte()

        if self.training:
            m_mask = pos.ne(0).type(t.float)
            # Padding mask combined with the causal mask.
            mask = m_mask.eq(0).unsqueeze(1).repeat(1, decoder_len, 1)
            mask = (mask + causal).gt(0)
            zero_mask = c_mask.eq(0).unsqueeze(-1).repeat(1, 1, decoder_len)
            zero_mask = zero_mask.transpose(1, 2)
        else:
            mask = causal.gt(0)
            m_mask, zero_mask = None, None

        decoder_input = self.decoder_prenet(decoder_input)

        decoder_input = self.norm(decoder_input)

        pos = self.pos_emb(pos)
        decoder_input = pos * self.alpha + decoder_input

        decoder_input = self.pos_dropout(decoder_input)

        attn_dot_list = list()
        attn_dec_list = list()

        for selfattn, dotattn, ffn in zip(self.selfattn_layers, self.dotattn_layers, self.ffns):
            decoder_input, attn_dec = selfattn(decoder_input, decoder_input, mask=mask, query_mask=m_mask)
            decoder_input, attn_dot = dotattn(memory, decoder_input, mask=zero_mask, query_mask=m_mask)
            decoder_input = ffn(decoder_input)
            attn_dot_list.append(attn_dot)
            attn_dec_list.append(attn_dec)

        mel_out = self.mel_linear(decoder_input)

        postnet_input = mel_out.transpose(1, 2)
        out = self.postconvnet(postnet_input)
        out = postnet_input + out
        out = out.transpose(1, 2)

        stop_tokens = self.stop_linear(decoder_input)

        return mel_out, out, attn_dot_list, stop_tokens, attn_dec_list


class Model(nn.Module):
    """
    Transformer Network
    """
    def __init__(self):

        super(Model, self).__init__()
        self.embedding_size = 512
        self.hidden_size = 256
        self.encoder = Encoder(self.embedding_size, self.hidden_size)
        self.decoder = MelDecoder(self.hidden_size)

    def forward(self, characters, mel_input, pos_text, pos_mel):
        mel_input = mel_input.permute(0, 2, 1)
        memory, c_mask, attns_enc = self.encoder.forward(characters, pos=pos_text)

        mel_output, postnet_output, attn_probs, stop_preds, attns_dec = self.decoder.forward(memory, mel_input, c_mask,
                                                                                             pos=pos_mel)

        return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec

Writing TransformersTTS.py


The `utils.py` file defines a custom loss class called `Loss`. It is modeled on the loss used to train the FastSpeech2 speech synthesis model, adapted here to work with our Transformer TTS model and phoneme data.

It computes four components: a mel-spectrogram loss, a postnet mel-spectrogram loss, an SSIM loss, and a gate (stop-token) loss. Each component is multiplied by its weight, and the weighted terms are summed into the total loss that drives training.

I also tried a version of the loss without the gate term. Output quality improved more quickly, but stop-token prediction was unreliable, so the gated version is used here.
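
To make the gate term concrete, here is a minimal, dependency-free sketch of a positively weighted binary cross-entropy on logits. It is not the SpeechBrain `bce_loss` implementation; the helper name `weighted_bce_with_logits` and the toy logits are illustrative assumptions. Within the valid frames of an utterance only the last one has a target of 1, so positives are rare, and a `pos_weight` above 1 (the `Loss` class passes 5.0) makes a missed stop frame cost more than a false alarm.

```python
import math

def weighted_bce_with_logits(logits, targets, pos_weight=1.0):
    """Mean binary cross-entropy on raw logits, with positives scaled by pos_weight."""
    total = 0.0
    for x, y in zip(logits, targets):
        # log(sigmoid(x)) and log(1 - sigmoid(x)) in a numerically stable form
        log_p = -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))
        log_not_p = -x - math.log1p(math.exp(-x)) if x >= 0 else -math.log1p(math.exp(x))
        total += -(pos_weight * y * log_p + (1.0 - y) * log_not_p)
    return total / len(logits)

# Toy utterance with 6 valid frames: only the last frame is a "stop" frame.
gate_target = [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
# Hypothetical stop logits from a model that never fires the stop token.
stop_logits = [-3.0] * 6

plain = weighted_bce_with_logits(stop_logits, gate_target)
weighted = weighted_bce_with_logits(stop_logits, gate_target, pos_weight=5.0)
print(f"unweighted: {plain:.4f}  pos_weight=5: {weighted:.4f}")
```

With the weight applied, the single missed stop frame dominates the loss, which pushes the model to learn when to stop.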

The file also contains a class called `TextMelAlignmentCollator`, which assembles individual samples into batches during training and evaluation.

Used as the collate function, it orders the samples in a batch by sequence length, longest first, and then zero-pads the phoneme sequences and mel spectrograms so that every item in the batch has the same length.

It returns the padded sequences and position indices together with the input and output lengths, the gate targets, and the `wavs` and `labels` entries of each sample.
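
The padding logic can be sketched without tensors. The example below is a simplified illustration under stated assumptions, not the collator itself: the function name `collate` is hypothetical, plain Python lists stand in for tensors, and mels are stored frames-first rather than as `(num_mels, frames)`. A single ordering (longest text first) is applied to every field, so row `i` of each output always describes the same utterance.

```python
def collate(batch, pad_id=0):
    """Pad a list of (token_ids, mel_frames) pairs into rectangular lists."""
    # One ordering for all fields keeps text, mel, and gate rows aligned.
    order = sorted(range(len(batch)), key=lambda i: len(batch[i][0]), reverse=True)
    max_text = max(len(tokens) for tokens, _ in batch)
    max_mel = max(len(mel) for _, mel in batch)
    n_mels = len(batch[0][1][0])

    texts, mels, gates, text_lens, mel_lens = [], [], [], [], []
    for i in order:
        tokens, mel = batch[i]
        texts.append(tokens + [pad_id] * (max_text - len(tokens)))
        mels.append(mel + [[0.0] * n_mels for _ in range(max_mel - len(mel))])
        # Gate target: 0 while speaking, 1 from the last real frame onward.
        gates.append([0.0] * (len(mel) - 1) + [1.0] * (max_mel - len(mel) + 1))
        text_lens.append(len(tokens))
        mel_lens.append(len(mel))
    return texts, mels, gates, text_lens, mel_lens

# Two toy utterances: token ids plus mel frames with 2 bins each (made-up data).
batch = [
    ([5, 9], [[0.1, 0.2]] * 3),
    ([7, 3, 4, 8], [[0.3, 0.4]] * 2),
]
texts, mels, gates, text_lens, mel_lens = collate(batch)
print(texts)
print(gates)
```

The returned lengths let the loss ignore padded frames, and the gate targets mark where each utterance ends.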

In [10]:
%%file utils.py

import torch
from speechbrain.nnet.losses import bce_loss
from torch.nn.modules.loss import _Loss
from speechbrain.lobes.models.FastSpeech2 import SSIMLoss
import numpy as np

class Loss(torch.nn.Module):
    def __init__(self, ssim_loss_weight, mel_loss_weight, postnet_mel_loss_weight, gate_loss_weight=1.0, gate_loss_max_epochs=8):
        super().__init__()
        self.l1_loss = torch.nn.L1Loss()
        self.ssim_loss = SSIMLoss()
        self.mse_loss = torch.nn.MSELoss()
        self.ssim_loss_weight = ssim_loss_weight
        self.mel_loss_weight = mel_loss_weight
        self.postnet_mel_loss_weight = postnet_mel_loss_weight
        self.gate_loss_weight = gate_loss_weight
        self.gate_loss_max_epochs = gate_loss_max_epochs

    def forward(self, mel_pred, mel, postnet_pred, gate_target, stop_preds, mel_lengths, mask, current_epoch):
        mel_losses = []
        postnet_mel_losses = []

        for i in range(mel.shape[0]):
            mel_loss = self.mse_loss(mel_pred[i, :mel_lengths[i], :], mel[i, :mel_lengths[i], :])
            postnet_mel_loss = self.mse_loss(postnet_pred[i, :mel_lengths[i], :], mel[i, :mel_lengths[i], :])
            mel_losses.append(mel_loss)
            postnet_mel_losses.append(postnet_mel_loss)

        ssim_loss = self.ssim_loss(mel_pred, mel, mel_lengths)
        mel_loss = sum(mel_losses) / len(mel)
        postnet_mel_loss = sum(postnet_mel_losses) / len(mel)

        gate_target_masked = gate_target.masked_select(mask)
        stop_preds_masked = stop_preds.squeeze(-1).masked_select(mask)

        gate_loss = bce_loss(stop_preds_masked, gate_target_masked, pos_weight=torch.tensor(5.0))

        total_loss = (ssim_loss * self.ssim_loss_weight) + (mel_loss * self.mel_loss_weight) + (postnet_mel_loss * self.postnet_mel_loss_weight) + (gate_loss * self.gate_loss_weight)

        loss_dict = {
            "total_loss": total_loss,
            "ssim_loss": ssim_loss * self.ssim_loss_weight,
            "mel_loss": mel_loss * self.mel_loss_weight,
            "postnet_mel_loss": postnet_mel_loss * self.postnet_mel_loss_weight,
            "gate_loss": gate_loss * self.gate_loss_weight,
        }

        return loss_dict

class TextMelAlignmentCollator:
    def __call__(self, batch_data):
        raw_batch = list(batch_data)
        for i, item in enumerate(batch_data):
            batch_data[i] = item["mel_text_pair"]

        text_lengths, sorted_text_indices = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch_data]), dim=0, descending=True
        )
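        # Reuse the text ordering for the mel-side tensors so that row i of
        # every padded tensor belongs to the same utterance.
        sorted_mel_indices = sorted_text_indices
        mel_lengths = torch.LongTensor([batch_data[idx][1].size(1) for idx in sorted_mel_indices])
        max_mel_len = int(mel_lengths.max())
        max_text_len = int(text_lengths[0])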

        padded_text = torch.LongTensor(len(batch_data), max_text_len).zero_()
        for i, idx in enumerate(sorted_text_indices):
            text = batch_data[idx][0]
            padded_text[i, :text.size(0)] = text

        padded_mel_pos = torch.LongTensor(len(batch_data), max_mel_len).zero_()
        for i, idx in enumerate(sorted_mel_indices):
            mel_pos = batch_data[idx][4]
            padded_mel_pos[i, :mel_pos.size] = torch.LongTensor(mel_pos)

        padded_text_pos = torch.LongTensor(len(batch_data), max_text_len).zero_()
        for i, idx in enumerate(sorted_text_indices):
            text_pos = batch_data[idx][3]
            padded_text_pos[i, :text_pos.size] = torch.LongTensor(text_pos)

        num_mels, max_target_len = batch_data[0][1].size(0), max(x[1].size(1) for x in batch_data)
        padded_mels = torch.FloatTensor(len(batch_data), num_mels, max_target_len).zero_()
        padded_mel_inputs = torch.FloatTensor(len(batch_data), num_mels, max_target_len).zero_()
        padded_gates = torch.FloatTensor(len(batch_data), max_target_len).zero_()

        output_lengths = torch.LongTensor(len(batch_data))
        labels, wavs = [], []
        for i, idx in enumerate(sorted_mel_indices):
            mel = batch_data[idx][1]
            mel_input = torch.tensor(batch_data[idx][2])
            padded_mels[i, :, :mel.size(1)] = mel
            padded_mel_inputs[i, :, :mel_input.size(1)] = mel_input
            padded_gates[i, mel.size(1)-1:] = 1
            output_lengths[i] = mel.size(1)
            labels.append(raw_batch[idx]["label"])
            wavs.append(raw_batch[idx]["wav"])

        padded_mels = padded_mels.permute(0, 2, 1)
        return (
            padded_text,
            padded_mels,
            padded_mel_inputs,
            padded_text_pos,
            padded_mel_pos,
            text_lengths,
            output_lengths,
            padded_gates,
            wavs,
            labels,
        )

Writing utils.py


This file holds all the settings and values needed to train the speech synthesis model. I've tried to follow the research paper as closely as possible. The paper suggested training for 1000 epochs, but I stopped at around 920 epochs, which still gave good results.

In this file, you'll find key parameters such as the learning rate, the checkpoint-saving interval, and the batch size. These settings directly influence how well the model learns.
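
The learning rate is not constant during training: the configuration pairs it with a Noam scheduler using 4000 warmup steps. The sketch below shows the shape of that schedule under the assumption that the scheduler follows the usual Noam rule normalized by the warmup length (linear warmup to `lr_initial`, then inverse-square-root decay). The function name `noam_lr` is hypothetical, and the exact formula should be checked against the SpeechBrain source.

```python
def noam_lr(step, lr_initial=1e-4, n_warmup_steps=4000):
    """Learning rate at optimizer step `step` (step >= 1) under Noam warmup/decay."""
    normalize = n_warmup_steps ** 0.5
    scale = normalize * min(step ** -0.5, step * n_warmup_steps ** -1.5)
    return lr_initial * scale

for step in (1, 1000, 4000, 16000, 64000):
    print(f"step {step:>6}: lr = {noam_lr(step):.2e}")
```

Under this assumption the rate peaks at `learning_rate` exactly when warmup ends and halves every time the step count quadruples afterwards.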

To train the model, I used an A4000 GPU on Paperspace for about 50 hours, which made it practical to reach our training goals.

We started with a smaller batch size and fewer iterations to test on a lower-end GPU, then increased both on the higher-end GPU.

The file also includes options for loading the data, logging progress during training, and saving the trained model, so the whole TTS training session is configured in one place. Most of the settings are taken from the FastSpeech hyperparameters and its model; we substitute our own model for the original one.

After experimenting with various types of mel spectrograms, including custom ones, I found that the mel spectrogram provided by FastSpeech consistently outperformed the others. I then adjusted the hyperparameters to match the dimensions of the Transformer model, which took extensive experimentation to align the dimensions and tune the parameters.

In [11]:
%%file hparams_tts.yaml

seed: 1234
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
train_spn_predictor_epochs: 8
progress_samples: True
progress_sample_path: !ref <output_folder>/samples
progress_samples_min_run: 10
progress_samples_interval: 10
progress_batch_sample_size: 4

data_folder: #!PLACEHOLDER

train_json: !ref <save_folder>/train.json
valid_json: !ref <save_folder>/valid.json
test_json: !ref <save_folder>/test.json

splits: ["train", "valid"]
split_ratio: [90, 10]

skip_prep: False

sample_rate: 22050
hop_length: 256
win_length: null
n_mel_channels: 80
mel_fmin: 0.0
mel_fmax: 8000.0
power: 1
norm: "slaney"
mel_scale: "slaney"
dynamic_range_compression: True
mel_normalized: False
min_max_energy_norm: True
min_f0: 65
max_f0: 2093

#Main HyperParams
n_iter: 60
outputs_per_step: 1
epochs: 10000
lr : 0.001
save_step : 2000
image_step : 500
batch_size : 32

num_mels : 80
n_fft : 2048
sr : 22050
preemphasis : 0.97
frame_shift : 0.0125
frame_length : 0.05
# hop_length : int(sr*frame_shift)
# win_length : int(sr*frame_length)
n_mels : 80
# power : 1.2
min_level_db : -100
ref_level_db : 20
hidden_size : 256
embedding_size : 512
max_db : 100
ref_db : 20

cleaners : "english_cleaners"

data_path : "./dataset/LJSpeech-1.1"
checkpoint_path : "./checkpoint"
sample_path : "./samples"

learning_rate: 0.0001
weight_decay: 0.000001
max_grad_norm: 1.0
# batch_size: 32
num_workers_train: 32
num_workers_valid: 4
betas: [0.9, 0.98]

lexicon:
    - AA
    - AE
    - AH
    - AO
    - AW
    - AY
    - B
    - CH
    - D
    - DH
    - EH
    - ER
    - EY
    - F
    - G
    - HH
    - IH
    - IY
    - JH
    - K
    - L
    - M
    - N
    - NG
    - OW
    - OY
    - P
    - R
    - S
    - SH
    - T
    - TH
    - UH
    - UW
    - V
    - W
    - Y
    - Z
    - ZH
    - spn

n_symbols: 42
padding_idx: 0


enc_num_layers: 4
enc_num_head: 2
enc_d_model: 384
enc_ffn_dim: 1024
enc_k_dim: 384
enc_v_dim: 384
enc_dropout: 0.2


dec_num_layers: 4
dec_num_head: 2
dec_d_model: 384
dec_ffn_dim: 1024
dec_k_dim: 384
dec_v_dim: 384
dec_dropout: 0.2

postnet_embedding_dim: 512
postnet_kernel_size: 5
postnet_n_convolutions: 5
postnet_dropout: 0.5

normalize_before: True
ffn_type: 1dcnn
ffn_cnn_kernel_size_list: [9, 1]


dur_pred_kernel_size: 3
pitch_pred_kernel_size: 3
energy_pred_kernel_size: 3
variance_predictor_dropout: 0.5

model: !new:TransformersTTS.Model



mel_spectogram: !name:speechbrain.lobes.models.FastSpeech2.mel_spectogram
    sample_rate: !ref <sample_rate>
    hop_length: !ref <hop_length>
    win_length: !ref <win_length>
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mel_channels>
    f_min: !ref <mel_fmin>
    f_max: !ref <mel_fmax>
    power: !ref <power>
    normalized: !ref <mel_normalized>
    min_max_energy_norm: !ref <min_max_energy_norm>
    norm: !ref <norm>
    mel_scale: !ref <mel_scale>
    compression: !ref <dynamic_range_compression>

criterion: !new:utils.Loss
    ssim_loss_weight: 1.0
    mel_loss_weight: 1.0
    postnet_mel_loss_weight: 1.0


vocoder: "hifi-gan"
pretrained_vocoder: True
vocoder_source: speechbrain/tts-hifigan-ljspeech
vocoder_download_path: tmpdir_vocoder

modules:
    model: !ref <model>

train_dataloader_opts:
    batch_size: !ref <batch_size>
    drop_last: False
    num_workers: !ref <num_workers_train>
    shuffle: True
    collate_fn: !new:utils.TextMelAlignmentCollator

valid_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers_valid>
    shuffle: False
    collate_fn: !new:utils.TextMelAlignmentCollator

opt_class: !name:torch.optim.Adam
    lr: !ref <learning_rate>
    weight_decay: !ref <weight_decay>
    betas: !ref <betas>

noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <learning_rate>
    n_warmup_steps: 4000

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <epochs>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        lr_annealing: !ref <noam_annealing>
        counter: !ref <epoch_counter>

input_encoder: !new:speechbrain.dataio.encoder.TextEncoder

progress_sample_logger: !new:speechbrain.utils.train_logger.ProgressSampleLogger
    output_path: !ref <progress_sample_path>
    batch_sample_size: !ref <progress_batch_sample_size>
    formats:
        raw_batch: raw

Writing hparams_tts.yaml


This script, `tts_train.py`, is responsible for training a TransformerTTS model for text-to-speech (TTS) synthesis. Let's break down its functionality:

* The `FastSpeech2Brain` class loads the GraphemeToPhoneme model and defines the forward pass, loss computation, and batch processing functions. It also handles model inference and audio generation during training.

* The `dataio_prepare` function prepares the datasets for training. It loads the lexicon, encodes text and audio, and preprocesses mel spectrograms for input.

* The `main` function is the entry point of the script. It parses command-line arguments, loads hyperparameters from a YAML file, and sets up the experiment directory. Then, it prepares the datasets, initializes the FastSpeech2Brain instance, and starts the training loop.

* Inside the `main` function, the `fit` method of the model instance is called to train the model. It iterates over the training and validation datasets for the specified number of epochs, logging statistics and saving checkpoints periodically.

* As you might expect, we reuse much of the FastSpeech training code but modify it to accommodate our Transformer model.
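
The inference step mentioned above is autoregressive: the decoder is fed its own previous outputs until the stop token fires. The following dependency-free sketch illustrates that loop under stated assumptions. `toy_decoder` is a hypothetical stand-in for the real model, frames have 4 bins instead of 80, and unlike the training script the sketch drops the initial all-zero frame before returning.

```python
import math

def toy_decoder(tokens, frames):
    """Hypothetical stand-in for the model: returns (next_frame, stop_logit)."""
    next_frame = [0.1 * len(frames)] * 4           # 4 mel bins instead of 80
    stop_logit = len(frames) - 2 * len(tokens)     # turns positive after enough frames
    return next_frame, stop_logit

def greedy_decode(tokens, decoder, n_mels=4, max_steps=1023, threshold=0.5):
    """Generate mel frames one at a time until the stop probability exceeds threshold."""
    frames = [[0.0] * n_mels]                      # all-zero "go" frame
    for _ in range(max_steps):
        next_frame, stop_logit = decoder(tokens, frames)
        frames.append(next_frame)                  # feed the prediction back in
        stop_prob = 1.0 / (1.0 + math.exp(-stop_logit))
        if stop_prob > threshold:                  # the model signals end of speech
            break
    return frames[1:]                              # drop the go frame

mel = greedy_decode(tokens=[3, 7, 2], decoder=toy_decoder)
print(len(mel), "frames generated")
```

The `max_steps` cap matters in practice: if the stop token never fires, decoding still terminates before the positional table is exhausted.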


In [12]:
%%file tts_train.py

import os
import sys
import torch
import logging
import torchaudio
import numpy as np
import speechbrain as sb
from speechbrain.inference.vocoders import HIFIGAN
from pathlib import Path
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.data_utils import scalarize
from speechbrain.inference.text import GraphemeToPhoneme
from tqdm import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger = logging.getLogger(__name__)


class FastSpeech2Brain(sb.Brain):
    def on_fit_start(self):
        """Gets called at the beginning of ``fit()``, on multiple processes
        if ``distributed_count > 0`` and backend is ddp and initializes statistics
        """
        self.hparams.progress_sample_logger.reset()
        self.last_epoch = 0
        self.last_batch = None
        self.last_loss_stats = {}
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
        self.spn_token_encoded = (
            self.input_encoder.encode_sequence_torch(["spn"]).int().item()
        )
        return super().on_fit_start()

    def compute_forward(self, batch, stage):
        """Computes the forward pass
        Arguments
        ---------
        batch: tuple
            a single collated batch
        stage: speechbrain.Stage
            the training stage
        Returns
        -------
        the model output
        """
        inputs = self.batch_to_device(batch)

        phonemes, input_lengths, pos_text, spectogram, spectogram_input, pos_mel, pos_text, mel_lengths, gate_padded, wavs, labels = inputs

        ( mel_pred, postnet_pred, attn_probs, stop_preds,
         attns_enc, attns_dec )  = self.hparams.model(phonemes, spectogram_input, pos_text, pos_mel)

        return ( mel_pred, postnet_pred, attn_probs, stop_preds,
         attns_enc, attns_dec )

    def on_fit_batch_end(self, batch, outputs, loss, should_step):
        """At the end of the optimizer step, apply noam annealing."""
        if should_step:
            self.hparams.noam_annealing(self.optimizer)

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss given the predicted and targeted outputs.
        Arguments
        ---------
        predictions : torch.Tensor
            The model generated spectrograms and other metrics from `compute_forward`.
        batch : PaddedBatch
            This batch object contains all the relevant tensors for computation.
        stage : sb.Stage
            One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST.
        Returns
        -------
        loss : torch.Tensor
            A one-element tensor used for backpropagating the gradient.
        """
        x, metadata = self.batch_to_device(batch, return_metadata=True)
        phonemes, input_lengths, pos_text, spectogram, spectogram_input, pos_mel, pos_text, mel_lengths, gate_padded, wavs, labels = x

        gate_target = gate_padded[:, :mel_lengths.max().item()]

        mel_pred, postnet_pred, attn_probs, stop_preds, attns_enc, attns_dec = predictions

        self.last_batch = [phonemes, spectogram, spectogram_input, pos_mel, pos_text, input_lengths, mel_lengths, gate_padded, wavs, labels]

        self._remember_sample([phonemes, spectogram, spectogram_input, pos_mel, pos_text, input_lengths, mel_lengths, gate_padded, wavs, labels], predictions)

        srcmask_inverted = ~self.get_mask_from_lengths(mel_lengths)
        loss = self.hparams.criterion(
            mel_pred, spectogram, postnet_pred, gate_target, stop_preds, mel_lengths, srcmask_inverted, self.hparams.epoch_counter.current
        )
        self.last_loss_stats[stage] = scalarize(loss)
        return loss["total_loss"]

    def _remember_sample(self, batch, predictions):

        (
            phonemes, spectogram, spectogram_input, pos_mel, pos_text, input_lengths,  mel_lengths, gate_padded, wavs, labels
        ) = batch
        (
            mel_pred, postnet_pred, attn_probs, stop_preds,
         attns_enc, attns_dec
        ) = predictions

        self.hparams.progress_sample_logger.remember(
            target=self.process_mel(spectogram, mel_lengths),
            output=self.process_mel(mel_pred, mel_lengths),
            raw_batch=self.hparams.progress_sample_logger.get_batch_sample(
                {
                    "tokens": phonemes,
                    "input_lengths": input_lengths,
                    "mel_target": spectogram,
                    "pos_text": pos_text,
                    "pos_mel": pos_mel,
                    "spectrogram_input": spectogram_input,
                    "mel_pred": mel_pred,
                    "postnet_out": postnet_pred,
                    "stop_preds": stop_preds,
                    "gate_padded":gate_padded,
                    "labels": labels,
                    "wavs": wavs,
                }
            ),
        )

    def process_mel(self, mel, lengths, index=0):
        """Converts one mel from the batch to a sample for image logging."""
        assert mel.dim() == 3
        return torch.sqrt(torch.exp(mel[index][: lengths[index]]))

    def get_mask_from_lengths(self, lengths):
        """Returns a boolean mask that is True at padded positions."""
        max_len = torch.max(lengths).item()
        ids = torch.arange(0, max_len, device=lengths.device, dtype=lengths.dtype)
        mask = (lengths.unsqueeze(1) <= ids).to(self.device)
        return mask


    def on_stage_end(self, stage, stage_loss, epoch):

        if stage == sb.Stage.VALID:

            self.last_epoch = epoch
            lr = self.hparams.noam_annealing.current_lr


            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch": epoch, "lr": lr},
                train_stats=self.last_loss_stats[sb.Stage.TRAIN],
                valid_stats=self.last_loss_stats[sb.Stage.VALID],
            )
            output_progress_sample = (
                self.hparams.progress_samples
                and epoch % self.hparams.progress_samples_interval == 0
                and epoch >= self.hparams.progress_samples_min_run
            )

            if output_progress_sample:
                logger.info("Saving predicted samples")

                self.hparams.progress_sample_logger.save(epoch)
                self.run_inference()

            self.checkpointer.save_and_keep_only(
                meta=self.last_loss_stats[stage],
                min_keys=["total_loss"],
            )

        if stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                {"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=self.last_loss_stats[sb.Stage.TEST],
            )

    def run_inference(self):
        """Autoregressively decodes two samples from the last batch and vocodes them with HiFi-GAN."""
        if self.last_batch is None:
            return

        tokens, mel, mel_input, pos_mel, pos_text, text_lengths, output_lengths, gate_padded, wavs, labels = self.last_batch

        assert (
            self.hparams.vocoder == "hifi-gan"
            and self.hparams.pretrained_vocoder is True
        ), "Specified vocoder not supported yet"
        logger.info(
            f"Generating audio with pretrained {self.hparams.vocoder_source} vocoder"
        )
        hifi_gan = HIFIGAN.from_hparams(
            source=self.hparams.vocoder_source,
            savedir=self.hparams.vocoder_download_path,
        )

        for j in range(2):
            token = tokens[j].unsqueeze(0)
            mel_input = torch.zeros([1,1, 80]).to(tokens.device)
            pos_text = torch.arange(1, token.size(1)+1).unsqueeze(0)
            pos_text = pos_text.to(tokens.device)
            wav_name = wavs[j]

            length = 0
            pbar = tqdm(range(1023))
            with torch.no_grad():
                for i in pbar:
                    mel_input = mel_input.transpose(1,2)
                    pos_mel = torch.arange(1,mel_input.size(2)+1).unsqueeze(0).to(tokens.device)
                    mel_pred, postnet_pred, attn, stop_token, _, attn_dec = self.hparams.model.forward(token, mel_input, pos_text, pos_mel)
                    mel_input = torch.cat([mel_input, mel_pred[:,-1:,:].transpose(1,2)], dim=2)
                    mel_input = mel_input.transpose(1,2)
                    length = i
                    if torch.sigmoid(stop_token[:, -1, :])>0.5:
                        break
            length = torch.LongTensor([i])
            waveforms = hifi_gan.decode_batch(
                mel_input.transpose(2, 1), length, self.hparams.hop_length
            )
            for idx, wav in enumerate(waveforms):
                sample_type = 'with_spn'
                path = os.path.join(
                    self.hparams.progress_sample_path,
                    str(self.last_epoch),
                    f"pred_{sample_type}_{Path(wav_name).stem}.wav",
                )
                torchaudio.save(path, wav, self.hparams.sample_rate)

    def run_vocoder(self, inference_mel, mel_lens, sample_type=""):
        """Uses a pretrained vocoder to generate audio from predicted mel
        spectrogram. By default, uses speechbrain hifigan.

        Arguments
        ---------
        inference_mel: torch.Tensor
            predicted mel from model inference
        mel_lens: torch.Tensor
            predicted mel lengths from model inference
            used to mask the noise from padding
        sample_type: str
            used for logging the type of the inference sample being generated

        Returns
        -------
        None
        """
        if self.last_batch is None:
            return
        phonemes, spectogram, spectogram_input, pos_mel, pos_text, input_lengths, mel_lengths, gate_padded, wavs, labels = self.last_batch

        inference_mel = inference_mel[: self.hparams.progress_batch_sample_size]
        mel_lens = mel_lens[0 : self.hparams.progress_batch_sample_size]
        assert (
            self.hparams.vocoder == "hifi-gan"
            and self.hparams.pretrained_vocoder is True
        ), "Specified vocoder not supported yet"
        logger.info(
            f"Generating audio with pretrained {self.hparams.vocoder_source} vocoder"
        )
        hifi_gan = HIFIGAN.from_hparams(
            source=self.hparams.vocoder_source,
            savedir=self.hparams.vocoder_download_path,
        )
        waveforms = hifi_gan.decode_batch(
            inference_mel.transpose(2, 1), mel_lens, self.hparams.hop_length
        )
        for idx, wav in enumerate(waveforms):
            path = os.path.join(
                self.hparams.progress_sample_path,
                str(self.last_epoch),
                f"pred_{sample_type}_{Path(wavs[idx]).stem}.wav",
            )
            torchaudio.save(path, wav, self.hparams.sample_rate)

    def batch_to_device(self, batch, return_metadata=False):
        """Transfers the batch to the target device.

        Arguments
        ---------
        batch: tuple
            the batch to use
        return_metadata: bool
            indicates whether the metadata should be returned

        Returns
        -------
        batch: tuple
            the batch on the correct device
        """

        (
            text,
            mel,
            mel_input,
            pos_text,
            pos_mel,
            text_length,
            output_length,
            gate_padded,
            wavs,
            labels,
        ) = batch

        phonemes = text.to(self.device, non_blocking=True).long()
        spectogram = mel.to(self.device, non_blocking=True).float()
        spectogram_input = mel_input.to(self.device, non_blocking=True).float()
        input_lengths = text_length.to(self.device, non_blocking=True).long()
        pos_mel = pos_mel.to(self.device, non_blocking=True).long()
        pos_text = pos_text.to(self.device, non_blocking=True).long()
        mel_lengths = output_length.to(self.device, non_blocking=True).long()
        gate_padded = gate_padded.to(self.device, non_blocking=True).float()

        x = (phonemes, input_lengths, pos_text, spectogram, spectogram_input, pos_mel, pos_text, mel_lengths, gate_padded, wavs, labels )

        metadata = (labels, wavs)
        if return_metadata:
            return x, metadata
        return x




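def dataio_prepare(hparams):
    """Builds the datasets and the phoneme encoder described in hparams.

    Returns a dict mapping each split name to its DynamicItemDataset,
    together with the fitted input encoder.
    """
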
    lexicon = hparams["lexicon"]
    input_encoder = hparams.get("input_encoder")

    lexicon = ["@@"] + lexicon
    input_encoder.update_from_iterable(lexicon, sequence_input=False)
    input_encoder.add_unk()

    @sb.utils.data_pipeline.takes( "wav", "label_phoneme", "durations", "start", "end", "spn_labels",
                                      "last_phoneme_flags" )
    @sb.utils.data_pipeline.provides("mel_text_pair")
    def audio_pipeline( wav, label_phoneme, dur, start, end, spn_labels, last_phoneme_flags, ):

        durs = np.load(dur)
        durs_seq = torch.from_numpy(durs).int()

        label_phoneme = label_phoneme.strip()
        label_phoneme = label_phoneme.split()
        phoneme_seq = input_encoder.encode_sequence_torch(label_phoneme).int()

        assert len(phoneme_seq) == len(durs), (
            f"{len(phoneme_seq)} phoneme ids vs {len(durs)} durations "
            f"for {len(label_phoneme)} labels: {label_phoneme}"
        )


        audio, fs = torchaudio.load(wav)

        audio = audio.squeeze()
        audio = audio[int(fs * start) : int(fs * end)]

        mel, energy = hparams["mel_spectogram"](audio=audio)

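        text_length = len(phoneme_seq)
        # Position indices start at 1 for both text tokens and mel frames
        pos_text = np.arange(1, text_length + 1)

        pos_mel = np.arange(1, mel.shape[1] + 1)
        num_mels = 80
        # Decoder input: the target mel shifted right by one frame, with an
        # all-zero frame prepended (teacher forcing)
        mel_input = np.concatenate([np.zeros([num_mels, 1], np.float32), mel[:, :-1]], axis=1)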



        return (
            phoneme_seq,
            mel,
            mel_input,
            pos_text,
            pos_mel,
            text_length,
        )

    datasets = {}

    for dataset in hparams["splits"]:
        datasets[dataset] = sb.dataio.dataset.DynamicItemDataset.from_json(
            json_path=hparams[f"{dataset}_json"],
            replacements={"data_root": hparams["data_folder"]},
            dynamic_items=[audio_pipeline],
            output_keys=["mel_text_pair", "wav", "label" ],
        )
    return datasets, input_encoder


def main():
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)
    sb.utils.distributed.ddp_init_group(run_opts)

    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    datasets, input_encoder = dataio_prepare(hparams)

    fastspeech2_brain = FastSpeech2Brain(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    fastspeech2_brain.input_encoder = input_encoder
    fastspeech2_brain.fit(
        fastspeech2_brain.hparams.epoch_counter,
        datasets["train"],
        datasets["valid"],
        train_loader_kwargs=hparams["train_dataloader_opts"],
        valid_loader_kwargs=hparams["valid_dataloader_opts"],
    )


if __name__ == "__main__":
    main()

Writing tts_train.py
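
The `audio_pipeline` function in the script above builds the decoder input by shifting the target mel-spectrogram one frame to the right and prepending an all-zero frame, so that at each step the decoder only sees earlier frames. The following is a minimal sketch of that shift on a toy array. The 3-bin, 4-frame `mel` and the helper name `shift_right` are made up for illustration (the script itself uses 80 mel bins).

```python
import numpy as np


def shift_right(mel):
    """Prepend a zero frame and drop the last frame of a [n_mels, T] array."""
    n_mels = mel.shape[0]
    go_frame = np.zeros([n_mels, 1], np.float32)
    return np.concatenate([go_frame, mel[:, :-1]], axis=1)


# Toy target: 3 mel bins, 4 frames, values 1..12 so every entry is distinct
mel = np.arange(1, 13, dtype=np.float32).reshape(3, 4)
mel_input = shift_right(mel)

print(mel_input.shape)  # same shape as the target
print(mel_input[0])     # first bin: a zero, then the first three target values
```

Because the shapes match, frame `t` of `mel_input` lines up with frame `t` of the target, and the model is trained to predict the target frame from the frames before it.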


As in our lab assignments, we use the SpeechBrain training framework to run the training module and produce model checkpoints.

In [13]:
!python tts_train.py --device='cuda:0' --data_folder=LJSpeech-1.1 hparams_tts.yaml

speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results
speechbrain.core - Info: max_grad_norm arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
speechbrain.core - 9.0M trainable parameters in FastSpeech2Brain
speechbrain.utils.fetching - Fetch hyperparams.yaml: Using existing file/symlink in pretrained_models/GraphemeToPhoneme-9b27d6eb840bf95c5aedf15ae8ed1172/hyperparams.yaml.
speechbrain.utils.fetching - Fetch custom.py: Delegating to Huggingface hub, source speechbrain/soundchoice-g2p.
speechbrain.utils.fetching - Fetch model.ckpt: Using existing file/symlink in pretrained_models/GraphemeToPhoneme-9b27d6eb840bf95c5aedf15ae8ed1172/model.ckpt.
speechbrain.utils.fetching - Fetch ctc_lin.ckpt: Using existing file/symlink in pretrained_models/GraphemeToPhoneme-9b27d6eb840bf95c5aedf15ae8ed1172/ctc_lin.ckpt.
speechbrain.utils.parameter_transfer - Loading pretrained files for: model, ctc_lin
speechbrain.utils.che

Some sample audio links at different stages of training.

* At 250 epochs: https://drive.google.com/file/d/1ysT5gfXe6rVn3hUz0WVYrrLYnMuKvwXp/view?usp=sharing

* At 500 epochs: https://drive.google.com/file/d/1L3POSMxfQjvXPJeglE4EeeBLu9fuQbBf/view?usp=sharing

* At 750 epochs: https://drive.google.com/file/d/1Op7dlAEnPvv_wTAygNvbBx_yczzajyIF/view?usp=sharing

* At 900 epochs: https://drive.google.com/file/d/1gnyHWbm6dVuPLjoTHdV_lmGiGFC7tma7/view?usp=sharing

While the results may not be flawless due to minor mistakes in my codebase and approach, they still provide a solid foundation for further work. This experience has enhanced my understanding of working with transformer models, allowing me to refine my methods and improve future outcomes.

# Inference Module

Although time constraints prevented me from implementing a proper Hugging Face API, I have created a functional inference routine. It downloads the model checkpoint and generates audio from input text, as demonstrated below.

Note: Run the cells above first so that the model code and the other definitions they create are available.

The code below downloads the model checkpoint to run our inference.

In [14]:
!pip install --upgrade --no-cache-dir gdown
!gdown 1geJ8ILZJNCEQJidKU-bChqmvcbESDdiO
# https://drive.google.com/file/d/1geJ8ILZJNCEQJidKU-bChqmvcbESDdiO/view?usp=sharing

Downloading...
From (original): https://drive.google.com/uc?id=1geJ8ILZJNCEQJidKU-bChqmvcbESDdiO
From (redirected): https://drive.google.com/uc?id=1geJ8ILZJNCEQJidKU-bChqmvcbESDdiO&confirm=t&uuid=07bee924-1dd7-4bad-a5d1-fb78817d0121
To: /content/model.ckpt
100% 38.0M/38.0M [00:00<00:00, 75.4MB/s]


This cell runs the trained Transformer-based text-to-speech model end to end. It first loads the HiFi-GAN vocoder and a grapheme-to-phoneme converter, then converts the input text into a phoneme sequence. After loading the hyperparameters and the model checkpoint, it encodes the phonemes into tokens and starts generating speech. The model predicts mel-spectrogram frames one at a time until the stop token fires (or the 1023-step cap is reached), which determines the length of the generated audio. Finally, the HiFi-GAN vocoder converts the mel-spectrogram into a waveform, which is saved as `pred_sample.wav` in the `results` folder.

In [15]:
import torch
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from speechbrain.inference.text import GraphemeToPhoneme
from speechbrain.inference.vocoders import HIFIGAN
from TransformersTTS import Model
from tqdm import tqdm
import os
from pathlib import Path

def inference_run(input_text="This is the first test for my trasnformer model."):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    vocoder = HIFIGAN.from_hparams(source='speechbrain/tts-hifigan-ljspeech', savedir='tmpdir_vocoder')

    hyperparameters_file = 'hparams_tts.yaml'
    with open(hyperparameters_file) as f:
        hyperparameters = load_hyperpyyaml(f, '')

    grapheme_to_phoneme_converter = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p", savedir="pretrained_models/soundchoice-g2p")
    phonemes = grapheme_to_phoneme_converter.g2p(input_text)
    phoneme_sequence = " ".join(phonemes)

    tts_model = Model().to(device)
    checkpoint_path = 'model.ckpt'
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    tts_model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    phoneme_sequence = phoneme_sequence.strip().split()
    lexicon = ["@@"] + hyperparameters["lexicon"]
    input_encoder = hyperparameters.get("input_encoder")
    input_encoder.update_from_iterable(lexicon, sequence_input=False)
    tokens = input_encoder.encode_sequence_torch(phoneme_sequence).int()

    token_tensor = tokens.unsqueeze(0).to(device)
    mel_input = torch.zeros([1, 1, 80]).to(token_tensor.device)
    pos_text = torch.arange(1, token_tensor.size(1)+1).unsqueeze(0).to(token_tensor.device)

    tts_model.eval()
    max_length = 1023
    pbar = tqdm(range(max_length))
    with torch.no_grad():
        for i in pbar:
            mel_input = mel_input.transpose(1, 2)
            pos_mel = torch.arange(1, mel_input.size(2)+1).unsqueeze(0).to(token_tensor.device)
            mel_pred, postnet_pred, attn, stop_token, _, attn_dec = tts_model.forward(token_tensor, mel_input, pos_text, pos_mel)
            mel_input = torch.cat([mel_input, mel_pred[:, -1:, :].transpose(1, 2)], dim=2)
            mel_input = mel_input.transpose(1, 2)
            length = i
            if torch.sigmoid(stop_token[:, -1, :]) > 0.5:
                break

    length_tensor = torch.LongTensor([length]).to(device)

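    # 256 is the hop length; 22050 Hz is the sample rate of the saved audio
    waveforms = vocoder.decode_batch(mel_input.transpose(2, 1), length_tensor, 256)
    os.makedirs('results', exist_ok=True)  # make sure the output folder exists
    for idx, wav in enumerate(waveforms):
        output_path = os.path.join('results', "pred_sample.wav")
        torchaudio.save(output_path, wav.cpu(), 22050)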

if __name__ == "__main__":
    inference_run()

# The output is saved as results/pred_sample.wav

 26%|██▌       | 264/1023 [00:05<00:15, 49.25it/s]
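
The progress bar stops well before the 1023-step cap because the loop exits as soon as the sigmoid of the latest stop-token output exceeds 0.5 (equivalently, as soon as the raw output becomes positive). Below is a toy sketch of that stopping rule in isolation. The function `fake_stop_logit` is a hypothetical stand-in for the real model and is not part of the notebook's code.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def fake_stop_logit(step, stop_at=5):
    """Hypothetical stand-in for the model: negative before stop_at, positive from it on."""
    return float(step - stop_at) + 0.5


def generate(max_steps=1023):
    """Run the decoding loop and return the step index at which it stopped."""
    length = 0
    for i in range(max_steps):
        length = i
        if sigmoid(fake_stop_logit(i)) > 0.5:
            break
    return length


stopped_at = generate()
print(stopped_at)
```

If the stop token never fires, the loop simply runs to the cap and `length` ends at `max_steps - 1`, which is why the cap is there.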


**Here is the sample for the default text**: https://drive.google.com/file/d/1H9COSeXNifkLLZf7gl4e1PB37l1MHHr2/view?usp=sharing
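
The number of decoding steps translates directly into audio duration: each mel frame covers one hop of 256 samples at 22050 Hz, the values passed to `decode_batch` and `torchaudio.save` in the inference cell. Here is a rough back-of-the-envelope check, assuming the run stopped at about the 264 frames shown by the progress bar (the helper name `frames_to_seconds` is made up for this sketch):

```python
HOP_LENGTH = 256      # samples per mel frame, as passed to decode_batch
SAMPLE_RATE = 22050   # Hz, as passed to torchaudio.save


def frames_to_seconds(n_frames, hop_length=HOP_LENGTH, sample_rate=SAMPLE_RATE):
    """Approximate duration of n_frames mel frames in seconds."""
    return n_frames * hop_length / sample_rate


duration = frames_to_seconds(264)       # 264 is read off the progress bar, not exact
max_duration = frames_to_seconds(1023)  # upper bound set by the step cap
print(f"{duration:.2f} s generated, {max_duration:.2f} s at the 1023-step cap")
```

The same arithmetic shows that the 1023-step cap limits a single utterance to a little under twelve seconds of audio.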