<h1><font color='red'><b>Text-To-Speech with Transformers - Inferencing</b></font></h1>

In [1]:
# Mount Google Drive at /content/drive so files stored under MyDrive are
# reachable from the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Copy the project's helper files from Drive into /content (presumably the
# CNNPrenet / CNNDecoderPrenet / ScaledPositionalEncoding modules imported
# further down - confirm against the Drive folder contents).
!cp -r /content/drive/MyDrive/TTS/utils/* /content/

**Install Speechbrain**

In [3]:
!pip install speechbrain

Collecting speechbrain
  Downloading speechbrain-1.0.0-py3-none-any.whl (760 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m760.1/760.1 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting hyperpyyaml (from speechbrain)
  Downloading HyperPyYAML-1.2.2-py3-none-any.whl (16 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.9->speechbrain)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.9->speechbrain)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.9->speechbrain)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.9->speechbrain)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from 

In [4]:
!pip install --upgrade --no-cache-dir gdown



In [5]:
import os
import json
import re
import random
import torchaudio
import csv
import torch
import tqdm
import speechbrain as sb
from speechbrain.inference.text import GraphemeToPhoneme
from hyperpyyaml import load_hyperpyyaml
from pathlib import Path
from speechbrain.utils.data_utils import get_all_files

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
!gdown 19W4olk2gSgm_HB7ewINMs05OnPWhCICp

Downloading...
From (original): https://drive.google.com/uc?id=19W4olk2gSgm_HB7ewINMs05OnPWhCICp
From (redirected): https://drive.google.com/uc?id=19W4olk2gSgm_HB7ewINMs05OnPWhCICp&confirm=t&uuid=715c309c-4406-43bc-a61b-4913b5dcd520
To: /content/model.ckpt
100% 95.8M/95.8M [00:02<00:00, 40.5MB/s]
Downloading...
From (original): https://drive.google.com/uc?id=1E1osClrpIpKxLI_GUjNMH3aFTxuxSNw2
From (redirected): https://drive.google.com/uc?id=1E1osClrpIpKxLI_GUjNMH3aFTxuxSNw2&confirm=t&uuid=a255feb4-2dc5-429d-9420-bf6d32f5e636
To: /content/module_classes.py
100% 8.50k/8.50k [00:00<00:00, 20.0MB/s]


In [7]:
from CNNPrenet import CNNPrenet
from CNNDecoderPrenet import CNNDecoderPrenet
from  ScaledPositionalEncoding import ScaledPositionalEncoding

**Hyperparameter file for Inferencing**

In [8]:
%%file hyperparams_inferencing.yaml

############################################################################
# Model: TTS with attention-based mechanism
# Tokens: g2p + positional embeddings
# losses: MSE & BCE
# Training: LJSpeech
# ############################################################################

###################################
# Experiment Parameters and setup #
###################################
# Seed for PyTorch's RNG; `__set_seed` applies it via torch.manual_seed as
# soon as this file is loaded.
seed: 1234
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Folder setup: results are grouped per seed; <save_folder> is the directory
# the pretrainer at the end of this file uses as `collect_in`.
output_folder: !ref ./results/<seed>
save_folder: !ref <output_folder>/save


################################
# Model Parameters and model   #
################################
# Input parameters
# Phoneme inventory (39 symbols, no stress markers). TTSInferencing.__init__
# prepends "@@" to this list and feeds it to <input_encoder>, so the order
# here determines the integer id of every phoneme.
lexicon:
    - AA
    - AE
    - AH
    - AO
    - AW
    - AY
    - B
    - CH
    - D
    - DH
    - EH
    - ER
    - EY
    - F
    - G
    - HH
    - IH
    - IY
    - JH
    - K
    - L
    - M
    - N
    - NG
    - OW
    - OY
    - P
    - R
    - S
    - SH
    - T
    - TH
    - UH
    - UW
    - V
    - W
    - Y
    - Z
    - ZH

# Created empty here; populated from <lexicon> in TTSInferencing.__init__.
input_encoder: !new:speechbrain.dataio.encoder.TextEncoder



################################
# Model Parameters and model   #
# Transformer Parameters
################################
# Sizes passed to torch.nn.Transformer below; d_model is also the input size
# of the positional encodings and of the two output linears.
d_model: 512
nhead: 8
num_encoder_layers: 3
num_decoder_layers: 3
dim_feedforward: 512
dropout: 0.1


# Decoder parameters
# NOTE: none of the keys in this group is referenced through !ref anywhere
# in this file.
# The number of frames in the target per encoder step
n_frames_per_step: 1
decoder_rnn_dim: 1024
prenet_dim: 256
max_decoder_steps: 1000
gate_threshold: 0.5
p_decoder_dropout: 0.1
decoder_no_early_stopping: False

blank_index: 0 # This special token is for padding


# Masks
# `!name:` stores the functions themselves (not their results) so that the
# inference code can call them with tensors at run time.
lookahead_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask
padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask


################################
# CNN 3-layers Prenet          #
################################
# Encoder Prenet
# Instantiated with the class's default constructor arguments (class comes
# from the local CNNPrenet module).
encoder_prenet: !new:CNNPrenet.CNNPrenet

# Decoder Prenet
# Instantiated with default constructor arguments (local CNNDecoderPrenet module).
decoder_prenet: !new:CNNDecoderPrenet.CNNDecoderPrenet

################################
# Positional Encodings         #
################################

# Encoder-side positional encoding, sized to the transformer width.
pos_emb_enc: !new:ScaledPositionalEncoding.ScaledPositionalEncoding
    input_size: !ref <d_model>
    max_len: 5000

# Decoder-side positional encoding (separate instance, same configuration).
pos_emb_dec: !new:ScaledPositionalEncoding.ScaledPositionalEncoding
    input_size: !ref <d_model>
    max_len: 5000


################################
# S2S Transformer              #
################################

# Encoder-decoder Transformer. batch_first: True means it takes and returns
# tensors with the batch on the first dimension.
Seq2SeqTransformer: !new:torch.nn.Transformer
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    dim_feedforward: !ref <dim_feedforward>
    dropout: !ref <dropout>
    batch_first: True


################################
# CNN 5-layers PostNet         #
################################

# Instantiated with default constructor arguments; its output is added to the
# mel_lin projection in TTSInferencing.encode_batch.
decoder_postnet: !new:speechbrain.lobes.models.Tacotron2.Postnet


# Linear projection of the decoder output to a single stop-token logit per frame.
stop_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: 1


# Linear projection of the decoder output to 80 values per frame (the mel frame).
mel_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: 80

# Name -> module mapping. TTSInferencing looks its sub-networks up in this
# mapping by these exact keys.
modules:
    encoder_prenet: !ref <encoder_prenet>
    pos_emb_enc: !ref <pos_emb_enc>
    decoder_prenet: !ref <decoder_prenet>
    pos_emb_dec: !ref <pos_emb_dec>
    Seq2SeqTransformer: !ref <Seq2SeqTransformer>
    mel_lin: !ref <mel_lin>
    stop_lin: !ref <stop_lin>
    decoder_postnet: !ref <decoder_postnet>


# All sub-networks bundled into one ModuleList so a single checkpoint file can
# be loaded into them.
# NOTE(review): the order of this list presumably has to match the order used
# when model.ckpt was saved - confirm against the training recipe.
model: !new:torch.nn.ModuleList
    - [!ref <encoder_prenet>,!ref <pos_emb_enc>,
       !ref <decoder_prenet>, !ref <pos_emb_dec>, !ref <Seq2SeqTransformer>,
       !ref <mel_lin>, !ref <stop_lin>,  !ref <decoder_postnet>]


# Local path of the checkpoint (this is where the gdown cell stores it).
pretrained_model_path: /content/model.ckpt

# The pretrainer maps pretrained files to instances declared in this yaml.
# Here the file at <pretrained_model_path> is collected into <save_folder>
# and loaded into "model", i.e. the <model> ModuleList defined above.

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
   collect_in: !ref <save_folder>
   loadables:
      model: !ref <model>
   paths:
      model: !ref <pretrained_model_path>


Writing hyperparams_inferencing.yaml


**TTS Inferencing**

In [9]:
%%file TTSInferencing.py

import re
import logging
import torch
import torchaudio
import random
import speechbrain
from speechbrain.inference.interfaces import Pretrained
from speechbrain.inference.text import GraphemeToPhoneme

# Module-level logger; not referenced by the TTSInferencing class in this file.
logger = logging.getLogger(__name__)

class TTSInferencing(Pretrained):
    """
    A TTS class (text -> mel_spec).

    Input text is cleaned, converted to phonemes with the SoundChoice G2P
    model, encoded to integer ids and then decoded frame by frame with the
    modules declared in the hyperparameter file.

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)
    """

    # Hyperparameter keys this interface declares as required (presumably
    # validated by the Pretrained base class - confirm). The methods of this
    # class additionally read `lexicon`, `lookahead_mask`, `padding_mask` and
    # `blank_index` from hparams.
    HPARAMS_NEEDED = ["modules", "input_encoder"]

    # Keys that encode_batch looks up in the `modules` mapping.
    MODULES_NEEDED = ["encoder_prenet", "pos_emb_enc",
                      "decoder_prenet", "pos_emb_dec",
                      "Seq2SeqTransformer", "mel_lin",
                      "stop_lin", "decoder_postnet"]


    def __init__(self, *args, **kwargs):
        """Sets up the phoneme encoder, the module table and the G2P model.

        All positional and keyword arguments are forwarded unchanged to the
        parent ``Pretrained`` constructor.
        """
        super().__init__(*args, **kwargs)

        # "@@" is placed in front of the lexicon so it is the first symbol the
        # encoder sees; the unknown-symbol entry is added only afterwards.
        symbols = ["@@"] + self.hparams.lexicon
        encoder = self.hparams.input_encoder
        encoder.update_from_iterable(symbols, sequence_input=False)
        encoder.add_unk()
        self.input_encoder = encoder

        # NOTE(review): if Pretrained derives from torch.nn.Module, this
        # attribute shadows nn.Module.modules() on the instance - confirm
        # that nothing relies on that method.
        self.modules = self.hparams.modules

        # Grapheme-to-phoneme model used to turn words into phoneme sequences.
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")




    def generate_padded_phonemes(self, texts):
        """Converts a batch of texts into a zero-padded tensor of phoneme ids

        Each text is cleaned with ``custom_clean``, upper-cased, split on
        whitespace and passed through the G2P model as a list of words. The
        resulting phonemes (whitespace entries dropped) are encoded with
        ``self.input_encoder`` and right-padded with zeros to the longest
        sequence in the batch.

        Arguments
        ---------
        texts: List[str]
            texts to be converted to phoneme ids

        Returns
        -------
        torch.Tensor
            Float tensor of shape (len(texts), longest_phoneme_sequence) on
            ``self.device``; rows are in the same order as ``texts``.

        Raises
        ------
        ValueError
            If ``texts`` contains no items.
        """
        texts = list(texts)
        if not texts:
            # Previously this surfaced as an opaque IndexError further down.
            raise ValueError("texts must contain at least one string")

        encoded_phonemes = []
        for label in texts:
            # "label" is the raw input text; "words" is what the G2P model sees.
            words = self.custom_clean(label).upper().split()

            # A blank text has no words; skip the G2P call for it.
            words_phonemes = self.g2p(words) if words else []

            # Flatten the per-word phoneme lists, dropping whitespace entries.
            phoneme_label = [
                phoneme
                for word_phonemes in words_phonemes
                for phoneme in word_phonemes
                if not phoneme.isspace()
            ]

            # Encode the phonemes with the input text encoder.
            encoded_phonemes.append(
                torch.tensor(
                    self.input_encoder.encode_sequence(phoneme_label),
                    dtype=torch.long,
                    device=self.device,
                )
            )

        # Right zero-pad every id sequence to the longest one in the batch.
        max_input_len = max(len(seq) for seq in encoded_phonemes)
        phoneme_padded = torch.zeros(
            (len(encoded_phonemes), max_input_len),
            dtype=torch.long,
            device=self.device,
        )
        for seq_idx, seq in enumerate(encoded_phonemes):
            phoneme_padded[seq_idx, : len(seq)] = seq

        # The ids are handed on as floats, as before.
        return phoneme_padded.float()


    def encode_batch(self, texts, early_stopping=False):
        """Computes mel-spectrograms for a list of texts

        The texts are converted to padded phoneme ids and embedded once on
        the encoder side. Decoding is autoregressive: every step re-runs the
        decoder prenet, the Transformer and the postnet on all frames
        generated so far and appends the newest predicted frame.

        Arguments
        ---------
        texts: List[str]
            texts to be encoded into spectrogram
        early_stopping: bool
            If False (default), exactly ``hparams.max_decoder_steps`` frames
            are generated (1000 when that hyperparameter is absent). If True,
            decoding ends as soon as every item of the batch has produced a
            frame whose stop probability exceeds ``hparams.gate_threshold``
            (0.5 when absent); ``max_decoder_steps`` remains the upper bound.

        Returns
        -------
        torch.Tensor
            Predicted mel-spectrograms of shape (len(texts), 80, n_frames),
            in the same order as ``texts``. The start frame is not included.
        """

        # generate phonemes and pad the input texts
        encoded_phoneme_padded = self.generate_padded_phonemes(texts)

        # Fallbacks equal the values that used to be hard-coded here.
        max_iter = getattr(self.hparams, "max_decoder_steps", 1000)
        gate_threshold = getattr(self.hparams, "gate_threshold", 0.5)

        with torch.no_grad():

            # ---- Encoder side: computed once, reused by every decoding step ----
            phoneme_prenet_emb = self.modules['encoder_prenet'](encoded_phoneme_padded)
            # Positional Embeddings
            phoneme_pos_emb = self.modules['pos_emb_enc'](encoded_phoneme_padded)
            # Summing up embeddings
            enc_phoneme_emb = phoneme_prenet_emb.permute(0, 2, 1) + phoneme_pos_emb
            enc_phoneme_emb = enc_phoneme_emb.to(self.device)

            # Source masks depend only on the encoder input, so they are
            # built outside the decoding loop.
            src_mask = torch.zeros(enc_phoneme_emb.shape[1], enc_phoneme_emb.shape[1]).to(self.device)
            src_key_padding_mask = self.hparams.padding_mask(
                enc_phoneme_emb, pad_idx=self.hparams.blank_index
            ).to(self.device)

            # ---- Decoder side ----
            # Start frame: 80 rows, all zero except row 1 which is set to 2
            # (presumably the convention used during training - confirm).
            start_token = torch.full((80, 1), fill_value=0)
            start_token[1] = 2
            decoder_input = start_token.repeat(enc_phoneme_emb.size(0), 1, 1)
            decoder_input = decoder_input.to(self.device, non_blocking=True).float()

            # Per-item "stop seen" flags; only consulted when early_stopping is on.
            finished = torch.zeros(
                decoder_input.size(0), dtype=torch.bool, device=decoder_input.device
            )

            for _ in range(max_iter):

                # Decoder Prenet
                mel_prenet_emb = self.modules['decoder_prenet'](decoder_input).to(self.device).permute(0, 2, 1)

                # Positional Embeddings
                mel_pos_emb = self.modules['pos_emb_dec'](mel_prenet_emb).to(self.device)
                # Summing up Embeddings
                dec_mel_spec = mel_prenet_emb + mel_pos_emb

                # Target mask to avoid looking ahead, plus target padding mask
                tgt_mask = self.hparams.lookahead_mask(dec_mel_spec).to(self.device)
                tgt_key_padding_mask = self.hparams.padding_mask(
                    dec_mel_spec, pad_idx=self.hparams.blank_index
                ).to(self.device)

                # Running the Seq2Seq Transformer
                decoder_outputs = self.modules['Seq2SeqTransformer'](
                    src=enc_phoneme_emb,
                    tgt=dec_mel_spec,
                    src_mask=src_mask,
                    tgt_mask=tgt_mask,
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )

                # Mel Linears + postnet residual
                mel_linears = self.modules['mel_lin'](decoder_outputs).permute(0, 2, 1)
                mel_postnet = self.modules['decoder_postnet'](mel_linears)
                mel_pred = mel_linears + mel_postnet

                # Only the newest frame is kept; it extends the decoder input
                # for the next iteration.
                decoder_input = torch.cat([decoder_input, mel_pred[:, :, -1:]], dim=2)

                if early_stopping:
                    # Stop probability of the newest frame of every item.
                    stop_token_pred = self.modules['stop_lin'](decoder_outputs).squeeze(-1)
                    finished |= torch.sigmoid(stop_token_pred[:, -1]) > gate_threshold
                    if bool(finished.all()):
                        break

        # Drop the start frame.
        mel_outputs = decoder_input[:, :, 1:]

        return mel_outputs



    def encode_text(self, text):
        """Runs inference for one text by wrapping it in a single-item batch"""
        single_item_batch = [text]
        return self.encode_batch(single_item_batch)


    def forward(self, text_list):
        """Encodes the input texts (delegates to ``encode_batch``)."""
        mel_outputs = self.encode_batch(text_list)
        return mel_outputs


    def check_stop_condition(self, stop_token_pred, threshold=0.8):
        """
        check if stop token / EOS reached or not for mel_specs in the batch

        Arguments
        ---------
        stop_token_pred : torch.Tensor
            Stop-token logits with the batch on the first dimension,
            e.g. shape (batch, n_frames).
        threshold : float
            Probability a frame has to exceed to count as "stop".
            Defaults to 0.8, the value that was previously hard-coded.

        Returns
        -------
        List[bool]
            One flag per batch item. A flag is True only when *every* frame
            of that item exceeds ``threshold`` (an item with zero frames
            yields True).
        """

        # Applying sigmoid to turn the logits into stop probabilities
        stop_probs = torch.sigmoid(stop_token_pred)
        # NOTE(review): requiring all frames to exceed the threshold is the
        # behaviour this method has always had; for autoregressive decoding
        # one would usually test only the newest frame - confirm the intent.
        exceeds = stop_probs > threshold
        # Reduce over everything but the batch dimension in one tensor op.
        return exceeds.flatten(start_dim=1).all(dim=1).tolist()



    def custom_clean(self, text):
        """
        Normalises spacing and spells out common title abbreviations.

        Runs of spaces are first collapsed to a single space. Then each
        abbreviation that is followed by a period (e.g. "Mr.", "Dr.") is
        replaced, ignoring case, by its lower-case spoken form. The
        abbreviations are applied one after another in the listed order.

        Arguments
        ---------
        text : str
            Input text to be cleaned

        Returns
        -------
        text : str
            Cleaned text
        """

        expansions = (
            ("mrs", "missus"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        )

        cleaned = re.sub(" +", " ", text)

        for abbreviation, spoken_form in expansions:
            # Word boundary before, literal period after the abbreviation.
            pattern = re.compile("\\b%s\\." % abbreviation, re.IGNORECASE)
            cleaned = pattern.sub(spoken_form, cleaned)

        return cleaned


Writing TTSInferencing.py


Inference

In [10]:
import os
import torchaudio
from TTSInferencing import TTSInferencing

# combine the TTS model with a vocoder (that generates the final waveform)
# Initialize the vocoder (HiFi-GAN)
from speechbrain.inference.vocoders import HIFIGAN

# to access the audio with IPython.display
from IPython.display import Audio


# from speechbrain.inference.TTSInferencing import TTSInferencing

# Text-to-mel model: hyperparameters and custom module classes are read from
# local files under /content.
tts_model = TTSInferencing.from_hparams(
    hparams_file='/content/hyperparams_inferencing.yaml',
    pymodule_file='/content/module_classes.py',
    source="/content",
    savedir="/content/",
)

# Pretrained HiFi-GAN vocoder that turns mel-spectrograms into waveforms.
hifi_gan = HIFIGAN.from_hparams(
    savedir='/pretrained_models/hifi-gan-ljspeech',
    source="speechbrain/tts-hifigan-ljspeech",
)

hyperparams.yaml:   0%|          | 0.00/11.3k [00:00<?, ?B/s]

model.ckpt:   0%|          | 0.00/129M [00:00<?, ?B/s]

ctc_lin.ckpt:   0%|          | 0.00/177k [00:00<?, ?B/s]

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

hyperparams.yaml:   0%|          | 0.00/1.16k [00:00<?, ?B/s]



generator.ckpt:   0%|          | 0.00/55.8M [00:00<?, ?B/s]

Testing

In [11]:
# Text -> mel-spectrogram with the TTS model
text = ["How is your day going"]
mel_outputs = tts_model.encode_batch(text)

# Mel-spectrogram -> waveform with the vocoder
waveforms = hifi_gan.decode_batch(mel_outputs)

# Directory that receives the generated audio files
base_path = "./test_hifi_gan_waveforms"
os.makedirs(base_path, exist_ok=True)

# Write one wav file per batch item into base_path
for i in range(len(waveforms)):
    waveform = waveforms[i]
    file_path = os.path.join(base_path, f"_waveform_{i}.wav")
    torchaudio.save(file_path, waveform.squeeze(1), sample_rate=22050)



In [None]:
# Play the first generated file. The path is derived from `base_path` so it
# always points at the directory the previous cell wrote to, instead of
# assuming the working directory is /content.
Audio(filename=os.path.join(base_path, "_waveform_0.wav"), rate=22050)