In [1]:
import pandas as pd

In [2]:
speakers_df = pd.read_csv(
    "./datasets_cache/LIBRITTS/LibriTTS/speakers.tsv",
    sep="\t",
    names=["READER", "GENDER", "SUBSET", "NAME"],
)

In [3]:
from dataclasses import dataclass, field
from typing import List, Literal, Tuple, Union

PreprocessLangType = Literal["english_only", "multilingual"]


@dataclass
class STFTConfig:
    filter_length: int
    hop_length: int
    win_length: int
    n_mel_channels: int
    mel_fmin: int
    mel_fmax: int


# Base class used with the Univnet vocoder
@dataclass
class PreprocessingConfig:
    language: PreprocessLangType
    stft: STFTConfig
    sampling_rate: int = 22050
    min_seconds: float = 0.5
    max_seconds: float = 6.0
    use_audio_normalization: bool = True
    workers: int = 8


@dataclass
class PreprocessingConfigUnivNet(PreprocessingConfig):
    stft: STFTConfig = field(
        default_factory=lambda: STFTConfig(
            filter_length=1024,
            hop_length=256,
            win_length=1024,
            n_mel_channels=100,  # univnet
            mel_fmin=20,
            mel_fmax=11025,
        ),
    )

In [4]:
@dataclass
class LangItem:
    r"""A class for storing language information."""

    phonemizer: str
    phonemizer_espeak: str
    nemo: str
    processing_lang_type: PreprocessLangType
    
langs_map: dict[str, LangItem] = {
    "en": LangItem(
        phonemizer="en_us",
        phonemizer_espeak="en-us",
        nemo="en",
        processing_lang_type="english_only",
    ),
}

def get_lang_map(lang: str) -> LangItem:
    r"""Returns a LangItem object for the given language.

    Args:
        lang (str): The language to get the LangItem for.

    Raises:
        ValueError: If the language is not supported.

    Returns:
        LangItem: The LangItem object for the given language.
    """
    if lang not in langs_map:
        raise ValueError(f"Language {lang} is not supported!")
    return langs_map[lang]

In [5]:
import re

# from nemo_text_processing.text_normalization.normalize import Normalizer
from unidecode import unidecode
import torchaudio


class NormalizeText:
    r"""NVIDIA NeMo is a conversational AI toolkit built for researchers working on automatic speech recognition (ASR), text-to-speech synthesis (TTS), large language models (LLMs), and natural language processing (NLP). The primary objective of NeMo is to help researchers from industry and academia to reuse prior work (code and pretrained models) and make it easier to create new conversational AI models.

    This class normalize the characters in the input text and normalize the input text with the `nemo_text_processing`.

    Args:
        lang (str): The language code to use for normalization. Defaults to "en".

    Attributes:
        lang (str): The language code to use for normalization. Defaults to "en".
        model (Normalizer): The `nemo_text_processing` Normalizer model.

    Methods:
        byte_encode(word: str) -> list: Encode a word as a list of bytes.
        normalize_chars(text: str) -> str: Normalize the characters in the input text.
        __call__(text: str) -> str: Normalize the input text with the `nemo_text_processing`.

    Examples:
        >>> from training.preprocess.normilize_text import NormalizeText
        >>> normilize_text = NormalizeText()
        >>> normilize_text("It’s a beautiful day…")
        "It's a beautiful day."
    """

    def __init__(self, lang: str = "en"):
        r"""Initialize a new instance of the NormalizeText class.

        Args:
            lang (str): The language code to use for normalization. Defaults to "en".

        """
        self.lang = lang
        self.processor = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor()
        # self.model = Normalizer(input_case="cased", lang=lang)

    def byte_encode(self, word: str) -> list:
        r"""Encode a word as a list of bytes.

        Args:
            word (str): The word to encode.

        Returns:
            list: A list of bytes representing the encoded word.

        """
        text = word.strip()
        return list(text.encode("utf-8"))

    def normalize_chars(self, text: str) -> str:
        r"""Normalize the characters in the input text.

        Args:
            text (str): The input text to normalize.

        Returns:
            str: The normalized text.

        Examples:
            >>> normalize_chars("It’s a beautiful day…")
            "It's a beautiful day."

        """
        # Define the character mapping
        char_mapping = {
            ord("’"): ord("'"),
            ord("”"): ord("'"),
            ord("…"): ord("."),
            ord("„"): ord("'"),
            ord("“"): ord("'"),
            ord('"'): ord("'"),
            ord("–"): ord("-"),
            ord("—"): ord("-"),
            ord("«"): ord("'"),
            ord("»"): ord("'"),
        }

        # Add unicode normalization as additional garanty and normalize the characters using translate() method
        normalized_string = unidecode(text).translate(char_mapping)

        # Remove redundant multiple characters
        # TODO: Maybe there is some effect on duplication?
        return re.sub(r"(\.|\!|\?|\-)\1+", r"\1", normalized_string)

    def __call__(self, text: str) -> str:
        r"""Normalize the input text with the `nemo_text_processing`.

        Args:
            text (str): The input text to normalize.

        Returns:
            str: The normalized text.

        """
        text = self.normalize_chars(text)
        # return self.model.normalize(text)

        # Split the text into lines
        # lines = text.split("\n")
        processed, lengths = self.processor(text)
        normalized_lines = [self.processor.tokens[i] for i in processed[0, : lengths[0]]]
        # normalized_lines = self.model.normalize_list(lines)

        # TODO: check this!
        # Join the normalized lines, replace \n with . and return the result
        result = ". ".join(normalized_lines)
        return result


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/user/codec/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/user/codec/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/user/codec/.venv/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/Users/u

In [6]:
# NOTE: for the backward comp.
# Prepare the phonemes list and dictionary for the embedding
phoneme_basic_symbols = [
    # IPA symbols
    "a",
    "b",
    "d",
    "e",
    "f",
    "g",
    "h",
    "i",
    "j",
    "k",
    "l",
    "m",
    "n",
    "o",
    "p",
    "r",
    "s",
    "t",
    "u",
    "v",
    "w",
    "x",
    "y",
    "z",
    "æ",
    "ç",
    "ð",
    "ø",
    "ŋ",
    "œ",
    "ɐ",
    "ɑ",
    "ɔ",
    "ə",
    "ɛ",
    "ɝ",
    "ɹ",
    "ɡ",
    "ɪ",
    "ʁ",
    "ʃ",
    "ʊ",
    "ʌ",
    "ʏ",
    "ʒ",
    "ʔ",
    "ˈ",
    "ˌ",
    "ː",
    "̃",
    "̍",
    "̥",
    "̩",
    "̯",
    "͡",
    "θ",
    # Punctuation
    "!",
    "?",
    ",",
    ".",
    "-",
    ":",
    ";",
    '"',
    "'",
    "(",
    ")",
    " ",
]

# TODO: add support for other languages
# _letters_accented = "µßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒ"
# _letters_cyrilic = "абвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ"
# _pad = "$"

# This is the list of symbols from StyledTTS2
_punctuation = ';:,.!?¡¿—…"«»“”'
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# Combine all symbols
symbols = list(_punctuation) + list(_letters) + list(_letters_ipa)

# Add only unique symbols
phones = phoneme_basic_symbols + [
    symbol for symbol in symbols if symbol not in phoneme_basic_symbols
]

# TODO: Need to understand how to replace this
# len(phones) == 184, leave it as is at this point
symbols = [str(el) for el in range(256)]
symbol2id = {s: i for i, s in enumerate(symbols)}
id2symbol = {i: s for i, s in enumerate(symbols)}

In [7]:
from logging import ERROR, Logger
import os

from phonemizer.backend import EspeakBackend

# IPA Phonemizer: https://github.com/bootphon/phonemizer
from phonemizer.backend.espeak.wrapper import EspeakWrapper

# Create a Logger instance
logger = Logger("my_logger")
# Set the level to ERROR
logger.setLevel(ERROR)

from dp.preprocessing.text import SequenceTokenizer

# from models.config import get_lang_map
# from models.config.symbols import phones

# INFO: Fix for windows, used for local env
if os.name == "nt":
    ESPEAK_LIBRARY = os.getenv(
        "ESPEAK_LIBRARY",
        "C:\\Program Files\\eSpeak NG\\libespeak-ng.dll",
    )
    EspeakWrapper.set_library(ESPEAK_LIBRARY)


class TokenizerIpaEspeak:
    def __init__(self, lang: str = "en"):
        lang_map = get_lang_map(lang)
        self.lang = lang_map.phonemizer_espeak
        self.lang_seq = lang_map.phonemizer

        # NOTE: for backward compatibility with previous IPA tokenizer see the TokenizerIPA class
        self.tokenizer = SequenceTokenizer(
            phones,
            languages=["de", "en_us"],
            lowercase=True,
            char_repeats=1,
            append_start_end=True,
        )

        self.phonemizer = EspeakBackend(
            language=self.lang,
            preserve_punctuation=True,
            with_stress=True,
            words_mismatch="ignore",
            logger=logger,
        ).phonemize

    def __call__(self, text: str):
        r"""Converts the input text to phonemes and tokenizes them.

        Args:
            text (str): The input text to be tokenized.

        Returns:
            Tuple[Union[str, List[str]], List[int]]: IPA phonemes and tokens.

        """
        phones_ipa = "".join(self.phonemizer([text]))

        tokens = self.tokenizer(phones_ipa, language=self.lang_seq)

        return phones_ipa, tokens

In [8]:
@dataclass
class VocoderBasicConfig:
    segment_size: int = 16384
    learning_rate: float = 0.0001
    adam_b1: float = 0.5
    adam_b2: float = 0.9
    lr_decay: float = 0.995
    synth_interval: int = 250
    checkpoint_interval: int = 250
    stft_lamb: float = 2.5

In [9]:
from typing import Optional, Tuple

import librosa
import torch
from torch.nn import Module


class TacotronSTFT(Module):
    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        center: bool,
        mel_fmax: Optional[int],
        mel_fmin: float = 0.0,
    ):
        r"""TacotronSTFT module that computes mel-spectrograms from a batch of waves.

        Args:
            filter_length (int): Length of the filter window.
            hop_length (int): Number of samples between successive frames.
            win_length (int): Size of the STFT window.
            n_mel_channels (int): Number of mel bins.
            sampling_rate (int): Sampling rate of the input waveforms.
            mel_fmin (int or None): Minimum frequency for the mel filter bank.
            mel_fmax (int or None): Maximum frequency for the mel filter bank.
            center (bool): Whether to pad the input signal on both sides.
        """
        super().__init__()

        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.n_fft = filter_length
        self.hop_size = hop_length
        self.win_size = win_length
        self.fmin = mel_fmin
        self.fmax = mel_fmax
        self.center = center

        # Define the mel filterbank
        mel = librosa.filters.mel(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )

        mel_basis = torch.tensor(mel, dtype=float).float()

        # Define the Hann window
        hann_window = torch.hann_window(win_length)

        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("hann_window", hann_window)

    def _spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the linear spectrogram of a batch of waves.

        Args:
            y (torch.Tensor): Input waveforms.

        Returns:
            torch.Tensor: Linear spectrogram.
        """
        assert torch.min(y.data) >= -1
        assert torch.max(y.data) <= 1

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                int((self.n_fft - self.hop_size) / 2),
                int((self.n_fft - self.hop_size) / 2),
            ),
            mode="reflect",
        )
        y = y.squeeze(1)
        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,  # type: ignore
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        return torch.view_as_real(spec)

    def linear_spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the linear spectrogram of a batch of waves.

        Args:
            y (torch.Tensor): Input waveforms.

        Returns:
            torch.Tensor: Linear spectrogram.
        """
        spec = self._spectrogram(y)
        return torch.norm(spec, p=2, dim=-1)

    def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Computes mel-spectrograms from a batch of waves.

        Args:
            y (torch.FloatTensor): Input waveforms with shape (B, T) in range [-1, 1]

        Returns:
            torch.FloatTensor: Spectrogram of shape (B, n_spech_channels, T)
            torch.FloatTensor: Mel-spectrogram of shape (B, n_mel_channels, T)
        """
        spec = self._spectrogram(y)

        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        mel = torch.matmul(self.mel_basis, spec)  # type: ignore
        mel = self.spectral_normalize_torch(mel)

        return spec, mel

    def spectral_normalize_torch(self, magnitudes: torch.Tensor) -> torch.Tensor:
        r"""Applies dynamic range compression to magnitudes.

        Args:
            magnitudes (torch.Tensor): Input magnitudes.

        Returns:
            torch.Tensor: Output magnitudes.
        """
        return self.dynamic_range_compression_torch(magnitudes)

    def dynamic_range_compression_torch(
        self,
        x: torch.Tensor,
        C: int = 1,
        clip_val: float = 1e-5,
    ) -> torch.Tensor:
        r"""Applies dynamic range compression to x.

        Args:
            x (torch.Tensor): Input tensor.
            C (float): Compression factor.
            clip_val (float): Clipping value.

        Returns:
            torch.Tensor: Output tensor.
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)

    # NOTE: audio np.ndarray changed to torch.FloatTensor!
    def get_mel_from_wav(self, audio: torch.Tensor) -> torch.Tensor:
        audio_tensor = audio.unsqueeze(0)
        with torch.no_grad():
            _, melspec = self.forward(audio_tensor)
        return melspec.squeeze(0)

In [10]:
from librosa.filters import mel as librosa_mel_fn
import torch


class AudioProcessor:
    r"""A class used to process audio signals and convert them into different representations.

    Attributes:
        hann_window (dict): A dictionary to store the Hann window for different configurations.
        mel_basis (dict): A dictionary to store the Mel basis for different configurations.

    Methods:
        name_mel_basis(spec, n_fft, fmax): Generate a name for the Mel basis based on the FFT size, maximum frequency, data type, and device.
        amp_to_db(magnitudes, C=1, clip_val=1e-5): Convert amplitude to decibels (dB).
        db_to_amp(magnitudes, C=1): Convert decibels (dB) to amplitude.
        wav_to_spec(y, n_fft, hop_length, win_length, center=False): Convert a waveform to a spectrogram and compute the magnitude.
        wav_to_energy(y, n_fft, hop_length, win_length, center=False): Convert a waveform to a spectrogram and compute the energy.
        spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): Convert a spectrogram to a Mel spectrogram.
        wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): Convert a waveform to a Mel spectrogram.
    """

    def __init__(self):
        self.hann_window = {}
        self.mel_basis = {}

    @staticmethod
    def name_mel_basis(spec: torch.Tensor, n_fft: int, fmax: int) -> str:
        """Generate a name for the Mel basis based on the FFT size, maximum frequency, data type, and device.

        Args:
            spec (torch.Tensor): The spectrogram tensor.
            n_fft (int): The FFT size.
            fmax (int): The maximum frequency.

        Returns:
            str: The generated name for the Mel basis.
        """
        n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
        return n_fft_len

    @staticmethod
    def amp_to_db(magnitudes: torch.Tensor, C: int = 1, clip_val: float = 1e-5) -> torch.Tensor:
        r"""Convert amplitude to decibels (dB).

        Args:
            magnitudes (Tensor): The amplitude magnitudes to convert.
            C (int, optional): A constant value used in the conversion. Defaults to 1.
            clip_val (float, optional): A value to clamp the magnitudes to avoid taking the log of zero. Defaults to 1e-5.

        Returns:
            Tensor: The converted magnitudes in dB.
        """
        return torch.log(torch.clamp(magnitudes, min=clip_val) * C)

    @staticmethod
    def db_to_amp(magnitudes: torch.Tensor, C: int = 1) -> torch.Tensor:
        r"""Convert decibels (dB) to amplitude.

        Args:
            magnitudes (Tensor): The dB magnitudes to convert.
            C (int, optional): A constant value used in the conversion. Defaults to 1.

        Returns:
            Tensor: The converted magnitudes in amplitude.
        """
        return torch.exp(magnitudes) / C

    def wav_to_spec(
        self,
        y: torch.Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        center: bool = False,
    ) -> torch.Tensor:
        r"""Convert a waveform to a spectrogram and compute the magnitude.

        Args:
            y (Tensor): The input waveform.
            n_fft (int): The FFT size.
            hop_length (int): The hop (stride) size.
            win_length (int): The window size.
            center (bool, optional): Whether to pad `y` such that frames are centered. Defaults to False.

        Returns:
            Tensor: The magnitude of the computed spectrogram.
        """
        y = y.squeeze(1)

        dtype_device = str(y.dtype) + "_" + str(y.device)
        wnsize_dtype_device = str(win_length) + "_" + dtype_device
        if wnsize_dtype_device not in self.hann_window:
            self.hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
            mode="reflect",
        )
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=self.hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        spec = torch.view_as_real(spec)

        # Compute the magnitude
        spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

        return spec

    def wav_to_energy(
        self,
        y: torch.Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        center: bool = False,
    ) -> torch.Tensor:
        r"""Convert a waveform to a spectrogram and compute the energy.

        Args:
            y (Tensor): The input waveform.
            n_fft (int): The FFT size.
            hop_length (int): The hop (stride) size.
            win_length (int): The window size.
            center (bool, optional): Whether to pad `y` such that frames are centered. Defaults to False.

        Returns:
            Tensor: The energy of the computed spectrogram.
        """
        spec = self.wav_to_spec(y, n_fft, hop_length, win_length, center=center)
        spec = torch.norm(spec, dim=1, keepdim=True).squeeze(0)

        # Normalize the energy
        return (spec - spec.mean()) / spec.std()

    def spec_to_mel(
            self,
            spec: torch.Tensor,
            n_fft: int,
            num_mels: int,
            sample_rate: int,
            fmin: int,
            fmax: int,
    ) -> torch.Tensor:
        r"""Convert a spectrogram to a Mel spectrogram.

        Args:
            spec (torch.Tensor): The input spectrogram of shape [B, C, T].
            n_fft (int): The FFT size.
            num_mels (int): The number of Mel bands.
            sample_rate (int): The sample rate of the audio.
            fmin (int): The minimum frequency.
            fmax (int): The maximum frequency.

        Returns:
            torch.Tensor: The computed Mel spectrogram of shape [B, C, T].
        """
        mel_basis_key = self.name_mel_basis(spec, n_fft, fmax)

        if mel_basis_key not in self.mel_basis:
            mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
            self.mel_basis[mel_basis_key] = torch.tensor(mel).to(dtype=spec.dtype, device=spec.device)

        mel = torch.matmul(self.mel_basis[mel_basis_key], spec)
        mel = self.amp_to_db(mel)

        return mel

    def wav_to_mel(
        self,
        y: torch.Tensor,
        n_fft: int,
        num_mels: int,
        sample_rate: int,
        hop_length: int,
        win_length: int,
        fmin: int,
        fmax: int,
        center: bool = False,
    ) -> torch.Tensor:
        r"""Convert a waveform to a Mel spectrogram.

        Args:
            y (torch.Tensor): The input waveform.
            n_fft (int): The FFT size.
            num_mels (int): The number of Mel bands.
            sample_rate (int): The sample rate of the audio.
            hop_length (int): The hop (stride) size.
            win_length (int): The window size.
            fmin (int): The minimum frequency.
            fmax (int): The maximum frequency.
            center (bool, optional): Whether to pad `y` such that frames are centered. Defaults to False.

        Returns:
            torch.Tensor: The computed Mel spectrogram.
        """
        # Convert the waveform to a spectrogram
        spec = self.wav_to_spec(y, n_fft, hop_length, win_length, center=center)

        # Convert the spectrogram to a Mel spectrogram
        mel = self.spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax)

        return mel

In [11]:
import sys
from typing import Tuple, Union

import librosa
import numpy as np
import torch
import torchaudio


def stereo_to_mono(audio: torch.Tensor) -> torch.Tensor:
    r"""Converts a stereo audio tensor to mono by taking the mean across channels.

    Args:
        audio (torch.Tensor): Input audio tensor of shape (channels, samples).

    Returns:
        torch.Tensor: Mono audio tensor of shape (1, samples).
    """
    return torch.mean(audio, 0, True)


def resample(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    r"""Resamples an audio waveform from the original sampling rate to the target sampling rate.

    Args:
        wav (np.ndarray): The audio waveform to be resampled.
        orig_sr (int): The original sampling rate of the audio waveform.
        target_sr (int): The target sampling rate to resample the audio waveform to.

    Returns:
        np.ndarray: The resampled audio waveform.
    """
    return librosa.resample(wav, orig_sr=orig_sr, target_sr=target_sr)


def safe_load(path: str, sr: Union[int, None]) -> Tuple[np.ndarray, int]:
    r"""Load an audio file from disk and return its content as a numpy array.

    Args:
        path (str): The path to the audio file.
        sr (int or None): The target sampling rate. If None, the original sampling rate is used.

    Returns:
        Tuple[np.ndarray, int]: A tuple containing the audio content as a numpy array and the actual sampling rate.
    """
    try:
        audio, sr_actual = torchaudio.load(path) # type: ignore
        if audio.shape[0] > 0:
            audio = stereo_to_mono(audio)
        audio = audio.squeeze(0)
        if sr_actual != sr and sr is not None:
            audio = resample(np.array(audio), orig_sr=sr_actual, target_sr=sr)
            # audio = resample(audio.numpy(), orig_sr=sr_actual, target_sr=sr)
            sr_actual = sr
        else:
            audio = np.array(audio)
            # .numpy()
            # audio = audio.numpy()
    except Exception as e:
        raise type(e)(
            f"The following error happened loading the file {path} ... \n" + str(e),
        ).with_traceback(sys.exc_info()[2])

    return audio, sr_actual


def preprocess_audio(
    audio: torch.Tensor, sr_actual: int, sr: Union[int, None],
) -> Tuple[torch.Tensor, int]:
    r"""Preprocesses audio by converting stereo to mono, resampling if necessary, and returning the audio tensor and sample rate.

    Args:
        audio (torch.Tensor): The audio tensor to preprocess.
        sr_actual (int): The actual sample rate of the audio.
        sr (Union[int, None]): The target sample rate to resample the audio to, if necessary.

    Returns:
        Tuple[torch.Tensor, int]: The preprocessed audio tensor and sample rate.
    """
    try:
        if audio.shape[0] > 0:
            audio = stereo_to_mono(audio)
        audio = audio.squeeze(0)
        if sr_actual != sr and sr is not None:
            if isinstance(audio, torch.Tensor):
                detach = audio.data.detach().tolist()
                
                audio = np.array(detach, dtype=float)

            # audio
            audio_np = resample(audio, orig_sr=sr_actual, target_sr=sr)
            # Convert back to torch tensor
            audio = torch.tensor(audio_np)
            sr_actual = sr
    except Exception as e:
        raise type(e)(
            f"The following error happened while processing the audio ... \n {e!s}",
        ).with_traceback(sys.exc_info()[2])

    return audio, sr_actual


def normalize_loudness(wav: torch.Tensor) -> torch.Tensor:
    r"""Normalize the loudness of an audio waveform.

    Args:
        wav (torch.Tensor): The input waveform.

    Returns:
        torch.Tensor: The normalized waveform.

    Examples:
        >>> wav = np.array([1.0, 2.0, 3.0])
        >>> normalize_loudness(wav)
        tensor([0.33333333, 0.66666667, 1.  ])
    """
    return wav / torch.max(torch.abs(wav))

In [12]:
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F

# TODO: LOOK AT THIS ESTIMATION ALGO
#######################################################################################
# Original implementation from https://github.com/NVIDIA/mellotron/blob/master/yin.py #
#######################################################################################


def differenceFunction(x: np.ndarray, N: int, tau_max: int) -> np.ndarray:
    r"""Compute the difference function of an audio signal.

    This function computes the difference function of an audio signal `x` using the algorithm described in equation (6) of [1]. The difference function is a measure of the similarity between the signal and a time-shifted version of itself, and is commonly used in pitch detection algorithms.

    This implementation uses the NumPy FFT functions to compute the difference function efficiently.

    Parameters
        x (np.ndarray): The audio signal to compute the difference function for.
        N (int): The length of the audio signal.
        tau_max (int): The maximum integration window size to use.

    Returns
        np.ndarray: The difference function of the audio signal.

    References
        [1] A. de Cheveigné and H. Kawahara, "YIN, a fundamental frequency estimator for speech and music," The Journal of the Acoustical Society of America, vol. 111, no. 4, pp. 1917-1930, 2002.
    """
    x = np.array(x, np.float64)
    w = x.size
    tau_max = min(tau_max, w)
    x_cumsum = np.concatenate((np.array([0.0]), (x * x).cumsum()))
    size = w + tau_max
    p2 = (size // 32).bit_length()
    nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
    size_pad = min(x * 2**p2 for x in nice_numbers if x * 2**p2 >= size)
    fc = np.fft.rfft(x, size_pad)
    conv = np.fft.irfft(fc * fc.conjugate())[:tau_max]
    return x_cumsum[w : w - tau_max : -1] + x_cumsum[w] - x_cumsum[:tau_max] - 2 * conv


def cumulativeMeanNormalizedDifferenceFunction(df: np.ndarray, N: int) -> np.ndarray:
    r"""Compute the cumulative mean normalized difference function (CMND) of a difference function.

    The CMND is defined as the element-wise product of the difference function with a range of values from 1 to N-1,
    divided by the cumulative sum of the difference function up to that point, plus a small epsilon value to avoid
    division by zero. The first element of the CMND is set to 1.

    Args:
        df (np.ndarray): The difference function.
        N (int): The length of the data.

    Returns:
        np.ndarray: The cumulative mean normalized difference function.

    References:
        [1] K. K. Paliwal and R. P. Sharma, "A robust algorithm for pitch detection in noisy speech signals,"
            Speech Communication, vol. 12, no. 3, pp. 249-263, 1993.
    """
    cmndf = (
        df[1:] * range(1, N) / (np.cumsum(df[1:]).astype(float) + 1e-8)
    )  # scipy method
    return np.insert(cmndf, 0, 1)


def getPitch(cmdf: np.ndarray, tau_min: int, tau_max: int, harmo_th: float=0.1) -> int:
    r"""Compute the fundamental period of a frame based on the Cumulative Mean Normalized Difference function (CMND).

    The CMND is a measure of the periodicity of a signal, and is computed as the cumulative mean normalized difference
    function of the difference function of the signal. The fundamental period is the first value of the index `tau`
    between `tau_min` and `tau_max` where the CMND is below the `harmo_th` threshold. If there are no such values, the
    function returns 0 to indicate that the signal is unvoiced.

    Args:
        cmdf (np.ndarray): The Cumulative Mean Normalized Difference function of the signal.
        tau_min (int): The minimum period for speech.
        tau_max (int): The maximum period for speech.
        harmo_th (float, optional): The harmonicity threshold to determine if it is necessary to compute pitch
            frequency. Defaults to 0.1.

    Returns:
        int: The fundamental period of the signal, or 0 if the signal is unvoiced.

    References:
        [1] K. K. Paliwal and R. P. Sharma, "A robust algorithm for pitch detection in noisy speech signals,"
            Speech Communication, vol. 12, no. 3, pp. 249-263, 1993.
    """
    tau = tau_min
    while tau < tau_max:
        if cmdf[tau] < harmo_th:
            while tau + 1 < tau_max and cmdf[tau + 1] < cmdf[tau]:
                tau += 1
            return tau
        tau += 1

    return 0  # if unvoiced


def compute_yin(
    sig_torch: torch.Tensor,
    sr: int,
    w_len: int = 512,
    w_step: int = 256,
    f0_min: int = 100,
    f0_max: int = 500,
    harmo_thresh: float = 0.1,
) -> Tuple[np.ndarray, List[float], List[float], List[float]]:
    r"""Compute the Yin Algorithm for pitch detection on an audio signal.

    The Yin Algorithm is a widely used method for pitch detection in speech and music signals. It works by computing the
    Cumulative Mean Normalized Difference function (CMND) of the difference function of the signal, and finding the first
    minimum of the CMND below a given threshold. The fundamental period of the signal is then estimated as the inverse of
    the lag corresponding to this minimum.

    Args:
        sig_torch (torch.Tensor): The audio signal as a 1D numpy array of floats.
        sr (int): The sampling rate of the signal.
        w_len (int, optional): The size of the analysis window in samples. Defaults to 512.
        w_step (int, optional): The size of the lag between two consecutive windows in samples. Defaults to 256.
        f0_min (int, optional): The minimum fundamental frequency that can be detected in Hz. Defaults to 100.
        f0_max (int, optional): The maximum fundamental frequency that can be detected in Hz. Defaults to 500.
        harmo_thresh (float, optional): The threshold of detection. The algorithm returns the first minimum of the CMND
            function below this threshold. Defaults to 0.1.

    Returns:
        Tuple[np.ndarray, List[float], List[float], List[float]]: A tuple containing the following elements:
            * pitches (np.ndarray): A 1D numpy array of fundamental frequencies estimated for each analysis window.
            * harmonic_rates (List[float]): A list of harmonic rate values for each fundamental frequency value, which
              can be interpreted as a confidence value.
            * argmins (List[float]): A list of the minimums of the Cumulative Mean Normalized Difference Function for
              each analysis window.
            * times (List[float]): A list of the time of each estimation, in seconds.

    References:
        [1] A. K. Jain, Fundamentals of Digital Image Processing, Prentice Hall, 1989.
        [2] A. de Cheveigné and H. Kawahara, "YIN, a fundamental frequency estimator for speech and music," The Journal
            of the Acoustical Society of America, vol. 111, no. 4, pp. 1917-1930, 2002.
    """
    sig_torch = sig_torch.view(1, 1, -1)
    sig_torch = F.pad(
        sig_torch.unsqueeze(1),
        (int((w_len - w_step) / 2), int((w_len - w_step) / 2), 0, 0),
        mode="reflect",
    )

    sig_torch_n: np.ndarray = np.array(sig_torch.view(-1).tolist())

    # sig_torch_n: np.ndarray = sig_torch.view(-1).numpy()

    tau_min = int(sr / f0_max)
    tau_max = int(sr / f0_min)

    timeScale = range(
        0, len(sig_torch_n) - w_len, w_step,
    )  # time values for each analysis window
    times = [t / float(sr) for t in timeScale]
    frames = [sig_torch_n[t : t + w_len] for t in timeScale]

    pitches = [0.0] * len(timeScale)
    harmonic_rates = [0.0] * len(timeScale)
    argmins = [0.0] * len(timeScale)

    for i, frame in enumerate(frames):
        # Compute YIN
        df = differenceFunction(frame, w_len, tau_max)
        cmdf = cumulativeMeanNormalizedDifferenceFunction(df, tau_max)
        p = getPitch(cmdf, tau_min, tau_max, harmo_thresh)

        # Get results
        if np.argmin(cmdf) > tau_min:
            argmins[i] = float(sr / np.argmin(cmdf))
        if p != 0:  # A pitch was found
            pitches[i] = float(sr / p)
            harmonic_rates[i] = cmdf[p]
        else:  # No pitch, but we compute a value of the harmonic rate
            harmonic_rates[i] = min(cmdf)

    return np.array(pitches), harmonic_rates, argmins, times


def norm_interp_f0(f0: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    r"""Normalize and interpolate the fundamental frequency (f0) values.

    Args:
        f0 (np.ndarray): The input f0 values.

    Returns:
        Tuple[np.ndarray, np.ndarray]: A tuple containing the normalized f0 values and a boolean array indicating which values were interpolated.

    Examples:
        >>> f0 = np.array([0, 100, 0, 200, 0])
        >>> norm_interp_f0(f0)
        (
            np.array([100, 100, 150, 200, 200]),
            np.array([True, False, True, False, True]),
        )
    """
    uv: np.ndarray = f0 == 0
    if sum(uv) == len(f0):
        f0[uv] = 0
    elif sum(uv) > 0:
        f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
    return f0, uv


def compute_pitch(
    sig_torch: torch.Tensor,
    sr: int,
    w_len: int = 1024,
    w_step: int = 256,
    f0_min: int = 50,
    f0_max: int = 1000,
    harmo_thresh: float = 0.25,
):
    r"""Compute the pitch of an audio signal using the Yin algorithm.

    The Yin algorithm is a widely used method for pitch detection in speech and music signals. This function uses the
    Yin algorithm to compute the pitch of the input audio signal, and then normalizes and interpolates the pitch values.
    Returns the normalized and interpolated pitch values.

    Args:
        sig_torch (torch.Tensor): The audio signal as a 1D numpy array of floats.
        sr (int): The sampling rate of the signal.
        w_len (int, optional): The size of the analysis window in samples.
        w_step (int, optional): The size of the lag between two consecutive windows in samples.
        f0_min (int, optional): The minimum fundamental frequency that can be detected in Hz.
        f0_max (int, optional): The maximum fundamental frequency that can be detected in Hz.
        harmo_thresh (float, optional): The threshold of detection. The algorithm returns the first minimum of the CMND function below this threshold.

    Returns:
        np.ndarray: The normalized and interpolated pitch values of the audio signal.
    """
    pitch, _, _, _ = compute_yin(
        sig_torch,
        sr=sr,
        w_len=w_len,
        w_step=w_step,
        f0_min=f0_min,
        f0_max=f0_max,
        harmo_thresh=harmo_thresh,
    )

    pitch, _ = norm_interp_f0(pitch)

    return pitch

In [13]:
from dataclasses import dataclass
import math
import random
from typing import Any, List, Tuple, Union

import numpy as np
from scipy.stats import betabinom
import torch
import torch.nn.functional as F

# from models.config import PreprocessingConfig, VocoderBasicConfig, get_lang_map

# from .audio import normalize_loudness, preprocess_audio
# from .audio_processor import AudioProcessor
# from .compute_yin import compute_yin, norm_interp_f0
# from .normalize_text import NormalizeText
# from .tacotron_stft import TacotronSTFT
# from .tokenizer_ipa_espeak import TokenizerIpaEspeak as TokenizerIPA


@dataclass
class PreprocessForAcousticResult:
    wav: torch.Tensor
    mel: torch.Tensor
    pitch: torch.Tensor
    phones_ipa: Union[str, List[str]]
    phones: torch.Tensor
    attn_prior: torch.Tensor
    energy: torch.Tensor
    raw_text: str
    normalized_text: str
    speaker_id: int
    chapter_id: str | int
    utterance_id: str
    pitch_is_normalized: bool


class PreprocessLibriTTS:
    r"""Preprocessing PreprocessLibriTTS audio and text data for use with a TacotronSTFT model.

    Args:
        preprocess_config (PreprocessingConfig): The preprocessing configuration.
        lang (str): The language of the input text.

    Attributes:
        min_seconds (float): The minimum duration of audio clips in seconds.
        max_seconds (float): The maximum duration of audio clips in seconds.
        hop_length (int): The hop length of the STFT.
        sampling_rate (int): The sampling rate of the audio.
        use_audio_normalization (bool): Whether to normalize the loudness of the audio.
        tacotronSTFT (TacotronSTFT): The TacotronSTFT object used for computing mel spectrograms.
        min_samples (int): The minimum number of audio samples in a clip.
        max_samples (int): The maximum number of audio samples in a clip.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        lang: str = "en",
    ):
        super().__init__()

        lang_map = get_lang_map(lang)

        self.phonemizer_lang = lang_map.phonemizer
        normilize_text_lang = lang_map.nemo

        self.normilize_text = NormalizeText(normilize_text_lang)
        self.tokenizer = TokenizerIpaEspeak(lang)
        self.vocoder_train_config = VocoderBasicConfig()

        self.preprocess_config = preprocess_config

        self.sampling_rate = self.preprocess_config.sampling_rate
        self.use_audio_normalization = self.preprocess_config.use_audio_normalization

        self.hop_length = self.preprocess_config.stft.hop_length
        self.filter_length = self.preprocess_config.stft.filter_length
        self.mel_fmin = self.preprocess_config.stft.mel_fmin
        self.win_length = self.preprocess_config.stft.win_length

        self.tacotronSTFT = TacotronSTFT(
            filter_length=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.preprocess_config.stft.win_length,
            n_mel_channels=self.preprocess_config.stft.n_mel_channels,
            sampling_rate=self.sampling_rate,
            mel_fmin=self.mel_fmin,
            mel_fmax=self.preprocess_config.stft.mel_fmax,
            center=False,
        )

        min_seconds, max_seconds = (
            self.preprocess_config.min_seconds,
            self.preprocess_config.max_seconds,
        )

        self.min_samples = int(self.sampling_rate * min_seconds)
        self.max_samples = int(self.sampling_rate * max_seconds)

        self.audio_processor = AudioProcessor()

    def beta_binomial_prior_distribution(
        self,
        phoneme_count: int,
        mel_count: int,
        scaling_factor: float = 1.0,
    ) -> torch.Tensor:
        r"""Computes the beta-binomial prior distribution for the attention mechanism.

        Args:
            phoneme_count (int): Number of phonemes in the input text.
            mel_count (int): Number of mel frames in the input mel-spectrogram.
            scaling_factor (float, optional): Scaling factor for the beta distribution. Defaults to 1.0.

        Returns:
            torch.Tensor: A 2D tensor containing the prior distribution.
        """
        P, M = phoneme_count, mel_count
        x = np.arange(0, P)
        mel_text_probs = []
        for i in range(1, M + 1):
            a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
            rv: Any = betabinom(P, a, b)
            mel_i_prob = rv.pmf(x)
            mel_text_probs.append(mel_i_prob)
        return torch.tensor(np.array(mel_text_probs))

    def acoustic(
        self,
        row: Tuple[torch.Tensor, int, str, str, int, str | int, str],
    ) -> Union[None, PreprocessForAcousticResult]:
        r"""Preprocesses audio and text data for use with a TacotronSTFT model.

        Args:
            row (Tuple[torch.FloatTensor, int, str, str, int, str | int, str]): The input row. The row is a tuple containing the following elements: (audio, sr_actual, raw_text, normalized_text, speaker_id, chapter_id, utterance_id).

        Returns:
            dict: A dictionary containing the preprocessed audio and text data.

        Examples:
            >>> preprocess_audio = PreprocessAudio("english_only")
            >>> audio = torch.randn(1, 44100)
            >>> sr_actual = 44100
            >>> raw_text = "Hello, world!"
            >>> output = preprocess_audio(audio, sr_actual, raw_text)
            >>> output.keys()
            dict_keys(['wav', 'mel', 'pitch', 'phones', 'raw_text', 'normalized_text', 'speaker_id', 'chapter_id', 'utterance_id', 'pitch_is_normalized'])
        """
        (
            audio,
            sr_actual,
            raw_text,
            normalized_text,
            speaker_id,
            chapter_id,
            utterance_id,
        ) = row

        wav, sampling_rate = preprocess_audio(audio, sr_actual, self.sampling_rate)

        # TODO: check this, maybe you need to move it to some other place
        # TODO: maybe we can increate the max_samples ?
        # if wav.shape[0] < self.min_samples or wav.shape[0] > self.max_samples:
        #     return None

        if self.use_audio_normalization:
            wav = normalize_loudness(wav)

        normalized_text = self.normilize_text(normalized_text)

        # NOTE: fixed version of tokenizer with punctuation
        phones_ipa, phones = self.tokenizer(normalized_text)

        # Convert to tensor
        phones = torch.Tensor(phones)

        mel_spectrogram = self.tacotronSTFT.get_mel_from_wav(wav)

        # Skipping small sample due to the mel-spectrogram containing less than self.mel_fmin frames
        # if mel_spectrogram.shape[1] < self.mel_fmin:
        #     return None

        # Text is longer than mel, will be skipped due to monotonic alignment search
        if phones.shape[0] >= mel_spectrogram.shape[1]:
            return None

        pitch, _, _, _ = compute_yin(
            wav,
            sr=sampling_rate,
            w_len=self.filter_length,
            w_step=self.hop_length,
            f0_min=50,
            f0_max=1000,
            harmo_thresh=0.25,
        )

        pitch, _ = norm_interp_f0(pitch)

        if np.sum(pitch != 0) <= 1:
            return None

        pitch = torch.tensor(pitch)

        # TODO this shouldnt be necessary, currently pitch sometimes has 1 less frame than spectrogram,
        # We should find out why
        mel_spectrogram = mel_spectrogram[:, : pitch.shape[0]]

        attn_prior = self.beta_binomial_prior_distribution(
            phones.shape[0],
            mel_spectrogram.shape[1],
        ).T

        assert pitch.shape[0] == mel_spectrogram.shape[1], (
            pitch.shape,
            mel_spectrogram.shape[1],
        )

        energy = self.audio_processor.wav_to_energy(
            wav.unsqueeze(0),
            self.filter_length,
            self.hop_length,
            self.win_length,
        )

        return PreprocessForAcousticResult(
            wav=wav,
            mel=mel_spectrogram,
            pitch=pitch,
            attn_prior=attn_prior,
            energy=energy,
            phones_ipa=phones_ipa,
            phones=phones,
            raw_text=raw_text,
            normalized_text=normalized_text,
            speaker_id=speaker_id,
            chapter_id=chapter_id,
            utterance_id=utterance_id,
            # TODO: check the pitch normalization process
            pitch_is_normalized=False,
        )

    def univnet(self, row: Tuple[torch.Tensor, int, str, str, int, str | int, str]):
        r"""Preprocesses audio data for use with a UnivNet model.

        This method takes a row of data, extracts the audio and preprocesses it.
        It then selects a random segment from the preprocessed audio and its corresponding mel spectrogram.

        Args:
            row (Tuple[torch.FloatTensor, int, str, str, int, str | int, str]): The input row. The row is a tuple containing the following elements: (audio, sr_actual, raw_text, normalized_text, speaker_id, chapter_id, utterance_id).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, int]: A tuple containing the selected segment of the mel spectrogram, the corresponding audio segment, and the speaker ID.

        Examples:
            >>> preprocess = PreprocessLibriTTS()
            >>> audio = torch.randn(1, 44100)
            >>> sr_actual = 44100
            >>> speaker_id = 0
            >>> mel, audio_segment, speaker_id = preprocess.preprocess_univnet((audio, sr_actual, "", "", speaker_id, 0, ""))
        """
        (
            audio,
            sr_actual,
            _,
            _,
            speaker_id,
            _,
            _,
        ) = row

        segment_size = self.vocoder_train_config.segment_size
        frames_per_seg = math.ceil(segment_size / self.hop_length)

        wav, _ = preprocess_audio(audio, sr_actual, self.sampling_rate)

        if self.use_audio_normalization:
            wav = normalize_loudness(wav)

        mel_spectrogram = self.tacotronSTFT.get_mel_from_wav(wav)

        if wav.shape[0] < segment_size:
            wav = F.pad(
                wav,
                (0, segment_size - wav.shape[0]),
                "constant",
            )

        if mel_spectrogram.shape[1] < frames_per_seg:
            mel_spectrogram = F.pad(
                mel_spectrogram,
                (0, frames_per_seg - mel_spectrogram.shape[1]),
                "constant",
            )

        from_frame = random.randint(0, mel_spectrogram.shape[1] - frames_per_seg)

        # Skip last frame, otherwise errors are thrown, find out why
        if from_frame > 0:
            from_frame -= 1

        till_frame = from_frame + frames_per_seg

        mel_spectrogram = mel_spectrogram[:, from_frame:till_frame]
        wav = wav[from_frame * self.hop_length : till_frame * self.hop_length]

        return mel_spectrogram, wav, speaker_id

In [14]:
from typing import List, Union

import torch
from torch import Tensor, nn


def pad_1D(inputs: List[Tensor], pad_value: float = 0.0) -> Tensor:
    r"""Pad a list of 1D tensor list to the same length.

    Args:
        inputs (List[torch.Tensor]): List of 1D numpy arrays to pad.
        pad_value (float): Value to use for padding. Default is 0.0.

    Returns:
        torch.Tensor: Padded 2D numpy array of shape (len(inputs), max_len), where max_len is the length of the longest input array.
    """
    max_len = max(x.size(0) for x in inputs)
    padded_inputs = [nn.functional.pad(x, (0, max_len - x.size(0)), value=pad_value) for x in inputs]
    return torch.stack(padded_inputs)


def pad_2D(
    inputs: List[Tensor], maxlen: Union[int, None] = None, pad_value: float = 0.0,
) -> Tensor:
    r"""Pad a list of 2D tensor arrays to the same length.

    Args:
        inputs (List[torch.Tensor]): List of 2D numpy arrays to pad.
        maxlen (Union[int, None]): Maximum length to pad the arrays to. If None, pad to the length of the longest array. Default is None.
        pad_value (float): Value to use for padding. Default is 0.0.

    Returns:
        torch.Tensor: Padded 3D numpy array of shape (len(inputs), max_len, input_dim), where max_len is the maximum length of the input arrays, and input_dim is the dimension of the input arrays.
    """
    max_len = max(x.size(1) for x in inputs) if maxlen is None else maxlen

    padded_inputs = [nn.functional.pad(x, (0, max_len - x.size(1), 0, 0), value=pad_value) for x in inputs]
    return torch.stack(padded_inputs)


def pad_3D(inputs: Union[Tensor, List[Tensor]], B: int, T: int, L: int) -> Tensor:
    r"""Pad a 3D torch tensor to a specified shape.

    Args:
        inputs (torch.Tensor): 3D numpy array to pad.
        B (int): Batch size to pad the array to.
        T (int): Time steps to pad the array to.
        L (int): Length to pad the array to.

    Returns:
        torch.Tensor: Padded 3D numpy array of shape (B, T, L), where B is the batch size, T is the time steps, and L is the length.
    """
    if isinstance(inputs, list):
        inputs_padded = torch.zeros(B, T, L, dtype=inputs[0].dtype)
        for i, input_ in enumerate(inputs):
            inputs_padded[i, :input_.size(0), :input_.size(1)] = input_

    elif isinstance(inputs, torch.Tensor):
        inputs_padded = torch.zeros(B, T, L, dtype=inputs.dtype)
        inputs_padded[:inputs.size(0), :inputs.size(1), :inputs.size(2)] = inputs

    return inputs_padded

In [15]:
from typing import Any, Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset
from torchaudio import datasets

# from models.config import PreprocessingConfigUnivNet, get_lang_map

# from training.preprocess import PreprocessLibriTTS
# from training.tools import pad_1D, pad_2D


class LibriTTSDatasetVocoder(Dataset):
    r"""Loading preprocessed univnet model data."""

    def __init__(
        self,
        root: str,
        batch_size: int,
        download: bool = True,
        lang: str = "en",
    ):
        r"""A PyTorch dataset for loading preprocessed univnet data.

        Args:
            root (str): Path to the directory where the dataset is found or downloaded.
            batch_size (int): Batch size for the dataset.
            download (bool, optional): Whether to download the dataset if it is not found. Defaults to True.
        """
        self.dataset = datasets.LIBRITTS(root=root, download=download)
        self.batch_size = batch_size

        lang_map = get_lang_map(lang)
        self.preprocess_libtts = PreprocessLibriTTS(
            PreprocessingConfigUnivNet(lang_map.processing_lang_type),
        )

    def __len__(self) -> int:
        r"""Returns the number of samples in the dataset.

        Returns
            int: Number of samples in the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r"""Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Dict[str, Any]: A dictionary containing the sample data.
        """
        # Retrive the dataset row
        data = self.dataset[idx]

        data = self.preprocess_libtts.univnet(data)

        if data is None:
            # print("Skipping due to preprocessing error")
            rand_idx = np.random.randint(0, self.__len__())
            return self.__getitem__(rand_idx)

        mel, audio, speaker_id = data

        return {
            "mel": mel,
            "audio": audio,
            "speaker_id": speaker_id,
        }

    def collate_fn(self, data: List) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List): A list of data samples.

        Returns:
            List: A list of reprocessed data batches.
        """
        data_size = len(data)

        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(4)]
        (
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ) = empty_lists

        # Extract fields from data dictionary and populate the lists
        for idx in idxs:
            data_entry = data[idx]

            mels.append(data_entry["mel"])
            mel_lens.append(data_entry["mel"].shape[1])
            audios.append(data_entry["audio"])
            speaker_ids.append(data_entry["speaker_id"])

        mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
        mel_lens = torch.tensor(mel_lens, dtype=torch.int64)
        audios = torch.tensor(pad_1D(audios), dtype=torch.float32)
        speaker_ids = torch.tensor(speaker_ids, dtype=torch.int64)

        return [
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ]

In [16]:
# batch_size = 2
# dataset = LibriTTSDatasetVocoder(
#             root="datasets_cache/LIBRITTS",
#             batch_size=batch_size,
#             download=False,
#         )

In [17]:
# assert len(dataset) == 33236

In [18]:
# sample = dataset[0]

# assert sample["mel"].shape == torch.Size([100, 64])
# assert sample["audio"].shape == torch.Size([16384])
# assert sample["speaker_id"] == 1034

In [19]:
# data = [
#     dataset[0],
#     dataset[2],
# ]
# result = dataset.collate_fn(data)
# assert len(result) == 4
# for batch in result:
#     assert len(batch) == batch_size

In [20]:
from torch.utils.data import DataLoader

# dataloader = DataLoader(
#     dataset,
#     batch_size=batch_size,
#     shuffle=False,
#     collate_fn=dataset.collate_fn,
# )

# dataloader_iter = iter(dataloader)

# for _, items in enumerate([next(dataloader_iter), next(dataloader_iter)]):
#     assert len(items) == 4
#     for it in items:
#         assert len(it) == batch_size

In [21]:
@dataclass
class ConformerConfig:
    n_layers: int
    n_heads: int
    n_hidden: int
    p_dropout: float
    kernel_size_conv_mod: int
    kernel_size_depthwise: int
    with_ff: bool
    
@dataclass
class ReferenceEncoderConfig:
    bottleneck_size_p: int
    bottleneck_size_u: int
    ref_enc_filters: List[int]
    ref_enc_size: int
    ref_enc_strides: List[int]
    ref_enc_pad: List[int]
    ref_enc_gru_size: int
    ref_attention_dropout: float
    token_num: int
    predictor_kernel_size: int


@dataclass
class VarianceAdaptorConfig:
    n_hidden: int
    kernel_size: int
    emb_kernel_size: int
    p_dropout: float
    n_bins: int


@dataclass
class AcousticLossConfig:
    ssim_loss_alpha: float
    mel_loss_alpha: float
    aligner_loss_alpha: float
    pitch_loss_alpha: float
    energy_loss_alpha: float
    u_prosody_loss_alpha: float
    p_prosody_loss_alpha: float
    dur_loss_alpha: float
    binary_align_loss_alpha: float
    binary_loss_warmup_epochs: int

@dataclass
class AcousticENModelConfig:
    speaker_embed_dim: int = 1024
    lang_embed_dim: int = 1
    encoder: ConformerConfig = field(
        default_factory=lambda: ConformerConfig(
            n_layers=6,
            n_heads=8,
            n_hidden=512,
            p_dropout=0.1,
            kernel_size_conv_mod=7,
            kernel_size_depthwise=7,
            with_ff=True,
        ),
    )
    decoder: ConformerConfig = field(
        default_factory=lambda: ConformerConfig(
            n_layers=6,
            n_heads=8,
            n_hidden=512,
            p_dropout=0.1,
            kernel_size_conv_mod=11,
            kernel_size_depthwise=11,
            with_ff=True,
        ),
    )
    reference_encoder: ReferenceEncoderConfig = field(
        default_factory=lambda: ReferenceEncoderConfig(
            bottleneck_size_p=4,
            bottleneck_size_u=256,
            ref_enc_filters=[32, 32, 64, 64, 128, 128],
            ref_enc_size=3,
            ref_enc_strides=[1, 2, 1, 2, 1],
            ref_enc_pad=[1, 1],
            ref_enc_gru_size=32,
            ref_attention_dropout=0.2,
            token_num=32,
            predictor_kernel_size=5,
        ),
    )
    variance_adaptor: VarianceAdaptorConfig = field(
        default_factory=lambda: VarianceAdaptorConfig(
            n_hidden=512,
            kernel_size=5,
            emb_kernel_size=3,
            p_dropout=0.5,
            n_bins=256,
        ),
    )
    loss: AcousticLossConfig = field(
        default_factory=lambda: AcousticLossConfig(
            ssim_loss_alpha=1.0,
            mel_loss_alpha=1.0,
            aligner_loss_alpha=1.0,
            pitch_loss_alpha=1.0,
            energy_loss_alpha=1.0,
            u_prosody_loss_alpha=0.25,
            p_prosody_loss_alpha=0.25,
            dur_loss_alpha=1.0,
            binary_align_loss_alpha=0.1,
            binary_loss_warmup_epochs=10,
        ),
    )

AcousticModelConfigType = Union[AcousticENModelConfig]

In [22]:
SUPPORTED_LANGUAGES = [
    "en",
    "uk",
]

lang2id = {s: i for i, s in enumerate(SUPPORTED_LANGUAGES)}

In [23]:
@dataclass
class AcousticTrainingOptimizerConfig:
    learning_rate: float
    weight_decay: float
    lr_decay: float
    betas: Tuple[float, float] = (0.9, 0.98)
    eps: float = 0.000000001
    grad_clip_thresh: float = 1.0
    warm_up_step: float = 4000
    anneal_steps: List[int] = field(default_factory=list)
    anneal_rate: float = 0.3

@dataclass
class AcousticPretrainingConfig:
    batch_size = 5
    grad_acc_step = 5
    train_steps = 500000
    log_step = 20
    synth_step = 250
    val_step = 4000
    save_step = 1000
    freeze_bert_until = 4000
    mcd_gen_max_samples = 400
    only_train_speaker_until = 0
    optimizer_config: AcousticTrainingOptimizerConfig = field(
        default_factory=lambda: AcousticTrainingOptimizerConfig(
            learning_rate=0.0002,
            weight_decay=0.01,
            lr_decay=1.0,
        ),
    )

In [24]:
class DepthWiseConv1d(Module):
    r"""Implements Depthwise 1D convolution. This module will apply a spatial convolution over inputs
    independently over each input channel in the style of depthwise convolutions.

    In a depthwise convolution, each input channel is convolved with its own set of filters, as opposed
    to standard convolutions where each input channel is convolved with all filters.
    At `groups=in_channels`, each input channel is convolved with its own set of filters.
    Filters in the
    DepthwiseConv1d are not shared among channels. This method can drastically reduce the number of
    parameters/learnable weights in the model, as each input channel gets its own filter.

    This technique is best suited to scenarios where the correlation between different channels is
    believed to be low. It is commonly employed in MobileNet models due to the reduced number of
    parameters, which is critical in mobile devices where computational resources are limited.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the convolving kernel
        padding (int): Zero-padding added to both sides of the input

    Shape:
        - Input: (N, C_in, L_in)
        - Output: (N, C_out, L_out), where

          `L_out = [L_in + 2*padding - (dilation*(kernel_size-1) + 1)]/stride + 1`

    Attributes:
        weight (Tensor): the learnable weights of shape (`out_channels`, `in_channels`/`group`, `kernel_size`)
        bias (Tensor, optional): the learnable bias of the module of shape (`out_channels`)

    Examples:
    ```python
    m = DepthWiseConv1d(16, 33, 3, padding=1)
    input = torch.randn(20, 16, 50)
    output = m(input)
    ```
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int,
    ):
        super().__init__()

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
            groups=in_channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x: input tensor of shape (batch_size, in_channels, signal_length)

        Returns:
            output tensor of shape (batch_size, out_channels, signal_length)
        """
        return self.conv(x)


class PointwiseConv1d(Module):
    r"""Applies a 1D pointwise (aka 1x1) convolution over an input signal composed of several input
    planes, officially known as channels in this context.

    The operation implemented is also known as a "channel mixing" operation, as each output channel can be
    seen as a linear combination of input channels.

    In the simplest case, the output value of the layer with input size
    (N, C_in, L) and output (N, C_out, L_out) can be
    precisely described as:

    $$out(N_i, C_{out_j}) = bias(C_{out_j}) +
        weight(C_{out_j}, k) * input(N_i, k)$$

    where 'N' is a batch size, 'C' denotes a number of channels,
    'L' is a length of signal sequence.
    The symbol '*' in the above indicates a 1D cross-correlation operation.

    The 1D cross correlation operation "*": [Wikipedia Cross-correlation](https://en.wikipedia.org/wiki/Cross-correlation)

    This module supports `TensorFloat32<tf32_on_ampere>`.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        stride (int): Stride of the convolution. Default: 1
        padding (int): Zero-padding added to both sides of the input. Default: 0
        bias (bool): If set to False, the layer will not learn an additive bias. Default: True
        kernel_size (int): Size of the convolving kernel. Default: 1

    Shape:
        - Input: (N, C_in, L_in)
        - Output: (N, C_out, L_out), where

          L_out = [L_in + 2*padding - (dilation*(kernel_size-1) + 1)]/stride + 1

    Attributes:
        weight (Tensor): the learnable weights of shape (out_channels, in_channels, kernel_size)
        bias (Tensor, optional): the learnable bias of the module of shape (out_channels)

    Example:
    ```python
    m = PointwiseConv1d(16, 33, 1, padding=0, bias=True)
    input = torch.randn(20, 16, 50)
    output = m(input)
    ```


    Description of parameters:
        stride (default 1): Controls the stride for the operation, which is the number of steps the convolutional
        kernel moves for each operation. A stride of 1 means that the kernel moves one step at a time and a stride
        of 2 means skipping every other step. Higher stride values can down sample the output and lead to smaller
        output shapes.

        padding (default 0): Controls the amount of padding applied to the input. By adding padding, the spatial
        size of the output can be controlled. If it is set to 0, no padding is applied. If it is set to 1, zero
        padding of one pixel width is added to the input data.

        bias (default True): Controls whether the layer uses a bias vector. By default, it is True, meaning that
        the layer has a learnable bias parameter.

        kernel_size (default 1): The size of the convolving kernel. In the case of 1D convolution, kernel_size is
        a single integer that specifies the number of elements the filter that convolves the input should have.
        In your PointwiseConv1d case, the default kernel size is 1, indicating a 1x1 convolution is applied
        which is commonly known as a pointwise convolution.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
        kernel_size: int = 1,
    ):
        super().__init__()

        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Defines the computation performed at every call.

        Args:
            x (torch.Tensor): input tensor of shape (batch_size, in_channels, signal_length)

        Returns:
            output (torch.Tensor): tensor of shape (batch_size, out_channels, signal_length)
        """
        return self.conv(x)

In [25]:
class BSConv1d(Module):
    r"""`BSConv1d` implements the `BSConv` concept which is based on the paper [BSConv:
    Binarized Separated Convolutional Neural Networks](https://arxiv.org/pdf/2003.13549.pdf).

    `BSConv` is an amalgamation of depthwise separable convolution and pointwise convolution.
    Depthwise separable convolution utilizes far fewer parameters by separating the spatial
    (depthwise) and channel-wise (pointwise) operations. Meanwhile, pointwise convolution
    helps in transforming the channel characteristics without considering the channel's context.

    Args:
        channels_in (int): Number of input channels
        channels_out (int): Number of output channels produced by the convolution
        kernel_size (int): Size of the kernel used in depthwise convolution
        padding (int): Zeropadding added around the input tensor along the height and width directions

    Attributes:
        pointwise (PointwiseConv1d): Pointwise convolution module
        depthwise (DepthWiseConv1d): Depthwise separable convolution module
    """

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        kernel_size: int,
        padding: int,
    ):
        super().__init__()

        # Instantiate Pointwise Convolution Module:
        # First operation in BSConv: the number of input channels is transformed to the number
        # of output channels without taking into account the channel context.
        self.pointwise = PointwiseConv1d(channels_in, channels_out)

        # Instantiate Depthwise Convolution Module:
        # Second operation in BSConv: A spatial convolution is performed independently over each output
        # channel from the pointwise convolution.
        self.depthwise = DepthWiseConv1d(
            channels_out,
            channels_out,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Propagate input tensor through pointwise convolution.
        x1 = self.pointwise(x)

        # Propagate the result of the previous pointwise convolution through the depthwise convolution.
        # Return final output of the sequence of pointwise and depthwise convolutions
        return self.depthwise(x1)

In [26]:
class Conv1dGLU(Module):
    r"""`Conv1dGLU` implements a variant of Convolutional Layer with a Gated Linear Unit (GLU).
    It's based on the Deep Voice 3 project.

    Args:
        d_model (int): model dimension parameter.
        kernel_size (int): kernel size for the convolution layer.
        padding (int): padding size for the convolution layer.
        embedding_dim (int): dimension of the embedding.

    Attributes:
         bsconv1d (BSConv1d) : an instance of the Binarized Separated Convolution (1d)
         embedding_proj (torch.nn.Modules.Linear): linear transformation for embeddings.
         sqrt (torch.Tensor): buffer that stores the square root of 0.5
         softsign (torch.nn.SoftSign): SoftSign Activation function
    """

    def __init__(
        self,
        d_model: int,
        kernel_size: int,
        padding: int,
        embedding_dim: int,
    ):
        super().__init__()

        self.bsconv1d = BSConv1d(
            d_model,
            2 * d_model,
            kernel_size=kernel_size,
            padding=padding,
        )

        self.embedding_proj = nn.Linear(
            embedding_dim,
            d_model,
        )

        self.register_buffer("sqrt", torch.sqrt(torch.tensor([0.5])).squeeze(0))

        self.softsign = torch.nn.Softsign()

    def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        """Forward propagation method for the Conv1dGLU layer.

        Args:
            x (torch.Tensor): input tensor
            embeddings (torch.Tensor): input embeddings

        Returns:
            x (torch.Tensor): output tensor after application of Conv1dGLU
        """
        x = x.permute((0, 2, 1))
        residual = x
        x = self.bsconv1d(x)
        splitdim = 1
        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)
        embeddings = self.embedding_proj(embeddings)
        softsign = self.softsign(embeddings)
        a = a + softsign.permute((0, 2, 1))
        x = a * torch.sigmoid(b)
        x = x + residual
        x = x * self.sqrt
        return x.permute((0, 2, 1))

In [27]:
LEAKY_RELU_SLOPE = 0.3

In [28]:
class FeedForward(Module):
    r"""Creates a feed-forward neural network.
    The network includes a layer normalization, an activation function (LeakyReLU), and dropout layers.

    Args:
        d_model (int): The number of expected features in the input.
        kernel_size (int): The size of the convolving kernel for the first conv1d layer.
        dropout (float): The dropout probability.
        expansion_factor (int, optional): The expansion factor for the hidden layer size in the feed-forward network, default is 4.
        leaky_relu_slope (float, optional): Controls the angle of the negative slope of LeakyReLU activation, default is `LEAKY_RELU_SLOPE`.
    """

    def __init__(
        self,
        d_model: int,
        kernel_size: int,
        dropout: float,
        expansion_factor: int = 4,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(d_model)
        self.conv_1 = nn.Conv1d(
            d_model,
            d_model * expansion_factor,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.act = nn.LeakyReLU(leaky_relu_slope)
        self.conv_2 = nn.Conv1d(d_model * expansion_factor, d_model, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the feed-forward neural network.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, num_features).

        Returns:
            Tensor: Output tensor of shape (batch_size, seq_len, num_features).
        """
        # Apply layer normalization
        x = self.ln(x)

        # Forward pass through the first convolution layer, activation layer and dropout layer
        x = x.permute((0, 2, 1))
        x = self.conv_1(x)
        x = x.permute((0, 2, 1))
        x = self.act(x)
        x = self.dropout(x)

        # Forward pass through the second convolution layer and dropout layer
        x = x.permute((0, 2, 1))
        x = self.conv_2(x)
        x = x.permute((0, 2, 1))
        x = self.dropout(x)

        # Scale the output by 0.5 (this helps with training stability)
        return 0.5 * x

In [29]:
class GLUActivation(Module):
    r"""Implements the Gated Linear Unit (GLU) activation function.

    The GLU activation splits the input in half across the channel dimension.
    One half is passed through a nonlinear activation function (like sigmoid or leaky ReLU),
    and the output from this activation function is used as a gate to control the
    amplitude of the other half of the input. An element-wise multiplication is then performed
    between the gating signal and the other half of the input.

    The GLU activation allows the model to dynamically choose which inputs to pass through and
    what information to suppress, which can help improving the model performance on certain tasks.

    Args:
        slope: Controls the slope for the leaky ReLU activation function. Default: 0.3 or see the const `LEAKY_RELU_SLOPE`

    Shape:
        - Input: (N, 2*C, L) where C is the number of input channels.
        - Output: (N, C, L)

    Examples:
    ```python
    m = GLUActivation(0.3)
    input = torch.randn(16, 2*20, 44)
    output = m(input)
    ```

    """

    def __init__(self, slope: float = LEAKY_RELU_SLOPE):
        super().__init__()
        self.lrelu = nn.LeakyReLU(slope)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x: The input tensor of shape (batch_size, 2*channels, signal_length)

        Returns:
            x: The output tensor of shape (batch_size, channels, signal_length)
        """
        # Split the input into two equal parts (chunks) along dimension 1
        out, gate = x.chunk(2, dim=1)

        # Perform element-wise multiplication of the first half (out)
        # with the result of applying LeakyReLU on the second half (gate)
        return out * self.lrelu(gate)

In [30]:
def calc_same_padding(kernel_size: int) -> Tuple[int, int]:
    r"""Calculates the necessary padding for 'same' padding in convolutional operations.

    For 'same' padding, the output size is the same as the input size for `stride=1`. This function returns
    two integers, representing the padding to be added on either side of the input to achieve 'same' padding.

    Args:
        kernel_size (int): Size of the convolving kernel.

    Returns:
        Tuple[int, int]: A tuple of two integers representing the number of padding elements to be applied on
        left and right (or top and bottom for 2D) of the input tensor respectively.
    """
    # Check if kernel_size is an integer greater than zero
    if not isinstance(kernel_size, int) or kernel_size <= 0:
        raise ValueError("kernel_size must be an integer greater than zero")

    # Determine base padding amount (equal to half the kernel size, truncated down)
    pad = kernel_size // 2

    # Return padding for each side of the kernel. If kernel size is odd, padding is (pad, pad).
    # If kernel size is even, padding is (pad, pad - 1) because we can't pad equally on both sides.
    return (pad, pad - (kernel_size + 1) % 2)

class ConformerConvModule(Module):
    r"""Conformer Convolution Module class represents a module in the Conformer model architecture.
    The module includes a layer normalization, pointwise and depthwise convolutional layers,
    Gated Linear Units (GLU) activation, and dropout layer.

    Args:
        d_model (int): The number of expected features in the input.
        expansion_factor (int): The expansion factor for the hidden layer size in the feed-forward network, default is 2.
        kernel_size (int): The size of the convolving kernel, default is 7.
        dropout (float): The dropout probability, default is 0.1.
        leaky_relu_slope (float): Controls the angle of the negative slope of the LeakyReLU activation, default is `LEAKY_RELU_SLOPE`.
    """

    def __init__(
        self,
        d_model: int,
        expansion_factor: int = 2,
        kernel_size: int = 7,
        dropout: float = 0.1,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ):
        super().__init__()
        inner_dim = d_model * expansion_factor
        self.ln_1 = nn.LayerNorm(d_model)
        self.conv_1 = PointwiseConv1d(
            d_model,
            inner_dim * 2,
        )
        self.conv_act = GLUActivation()
        self.depthwise = DepthWiseConv1d(
            inner_dim,
            inner_dim,
            kernel_size=kernel_size,
            padding=calc_same_padding(kernel_size)[0],
        )
        self.ln_2 = nn.GroupNorm(
            1,
            inner_dim,
        )
        self.activation = nn.LeakyReLU(leaky_relu_slope)
        self.conv_2 = PointwiseConv1d(
            inner_dim,
            d_model,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the Conformer conv module.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, num_features).

        Returns:
            Tensor: The output tensor of shape (batch_size, seq_len, num_features).
        """
        x = self.ln_1(x)
        x = x.permute(0, 2, 1)
        x = self.conv_1(x)
        x = self.conv_act(x)
        x = self.depthwise(x)
        x = self.ln_2(x)
        x = self.activation(x)
        x = self.conv_2(x)
        x = x.permute(0, 2, 1)
        return self.dropout(x)

In [31]:
class RelativeMultiHeadAttention(Module):
    r"""Multi-head attention with relative positional encoding.
    This concept was proposed in the
    [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860)

    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.

    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing query vector
        - **key** (batch, time, dim): Tensor containing key vector
        - **value** (batch, time, dim): Tensor containing value vector
        - **pos_embedding** (batch, time, dim): Positional embedding tensor
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
    Returns:
        - **outputs**: Tensor produces by relative multi head attention module.

    Note: `d_model` should be divisible by `num_heads` in other words `d_model % num_heads` should be zero.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 16,
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(d_model)

        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model, bias=False)
        self.value_proj = nn.Linear(d_model, d_model, bias=False)
        self.pos_proj = nn.Linear(d_model, d_model, bias=False)

        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))

        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_embedding: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Function applies multi-head attention along with relative positional encoding to the inputs. It restructures the input queries, keys, and values according to individual attention heads, applies biases, calculates content and position scores, and combines these to get the final score. A softmax activation is applied over the final score, followed by the calculation of context (contextual representation of input).

        Performs the forward pass on the queries, keys, values, and positional embeddings with a mask.

        Args:
            query (torch.Tensor): The input tensor containing query vectors.
            key (torch.Tensor): The input tensor containing key vectors.
            value (torch.Tensor): The input tensor containing value vectors.
            pos_embedding (torch.Tensor): The positional embedding tensor.
            mask (torch.Tensor): The mask tensor containing indices to be masked.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The context and attention tensors.
            Tensor produces by relative multi head attention module.
        """
        batch_size = query.shape[0]
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = (
            self.key_proj(key)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        value = (
            self.value_proj(value)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        pos_embedding = self.pos_proj(pos_embedding).view(
            batch_size, -1, self.num_heads, self.d_head,
        )
        u_bias = self.u_bias.expand_as(query)
        v_bias = self.v_bias.expand_as(query)
        a = (query + u_bias).transpose(1, 2)
        content_score = a @ key.transpose(2, 3)
        b = (query + v_bias).transpose(1, 2)
        pos_score = b @ pos_embedding.permute(0, 2, 3, 1)
        pos_score = self._relative_shift(pos_score)

        score = content_score + pos_score
        score = score * (1.0 / self.sqrt_dim)

        score.masked_fill_(mask, -1e9)

        attn = F.softmax(score, -1)

        context = (attn @ value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context), attn

    def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:
        r"""The main idea of relative positional encoding is that the attention score doesn't only depend on the query and the key, but also on the relative position of the key with respect to the query. This becomes particularly useful when working with sequences of tokens, like in NLP tasks, as it helps the model to be aware of the position of the words (or tokens) in the sentence.

        Performs the relative shift operation on the positional scores.

        Args:
            pos_score (torch.Tensor): The positional scores tensor.

        Returns:
            torch.Tensor: The shifted positional scores tensor.
        """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = torch.zeros(
            (batch_size, num_heads, seq_length1, 1), device=pos_score.device,
        )
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
        padded_pos_score = padded_pos_score.view(
            batch_size, num_heads, seq_length2 + 1, seq_length1,
        )
        return padded_pos_score[:, :, 1:].view_as(pos_score)

In [32]:
class ConformerMultiHeadedSelfAttention(Module):
    """Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
    the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
    module to generalize better on different input length and the resulting encoder is more robust to the variance of
    the utterance length. Conformer use `prenorm` residual units with dropout which helps training
    and regularizing deeper models.

    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout

    Inputs: inputs, mask
        - **inputs** (batch, time, dim): Tensor containing input vector
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked

    Returns:
        (batch, time, dim): Tensor produces by relative multi headed self attention module.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout_p: float,
    ):
        super().__init__()

        # Initialize the RelativeMultiHeadAttention module passing the model dimension and number of attention heads
        self.attention = RelativeMultiHeadAttention(
            d_model=d_model, num_heads=num_heads,
        )
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor,
        encoding: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, _, _ = key.size()

        # Trim or extend the "encoding" to match the size of key, and repeat this for each input in the batch
        encoding = encoding[:, : key.shape[1]]
        encoding = encoding.repeat(batch_size, 1, 1)

        # Pass inputs through the RelativeMultiHeadAttention layer, dropout the resulting outputs
        outputs, attn = self.attention(
            query, key, value, pos_embedding=encoding, mask=mask,
        )

        # Apply dropout to the attention outputs
        outputs = self.dropout(outputs)
        return outputs, attn

In [33]:
class ConformerBlock(Module):
    r"""ConformerBlock class represents a block in the Conformer model architecture.
    The block includes a pointwise convolution followed by Gated Linear Units (`GLU`) activation layer (`Conv1dGLU`),
    a Conformer self attention layer (`ConformerMultiHeadedSelfAttention`), and optional feed-forward layer (`FeedForward`).

    Args:
        d_model (int): The number of expected features in the input.
        n_head (int): The number of heads for the multiheaded attention mechanism.
        kernel_size_conv_mod (int): The size of the convolving kernel for the convolution module.
        embedding_dim (int): The dimension of the embeddings.
        dropout (float): The dropout probability.
        with_ff (bool): If True, uses FeedForward layer inside ConformerBlock.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        kernel_size_conv_mod: int,
        embedding_dim: int,
        dropout: float,
        with_ff: bool,
    ):
        super().__init__()
        self.with_ff = with_ff
        self.conditioning = Conv1dGLU(
            d_model=d_model,
            kernel_size=kernel_size_conv_mod,
            padding=kernel_size_conv_mod // 2,
            embedding_dim=embedding_dim,
        )
        if self.with_ff:
            self.ff = FeedForward(
                d_model=d_model,
                dropout=dropout,
                kernel_size=3,
            )
        self.conformer_conv_1 = ConformerConvModule(
            d_model,
            kernel_size=kernel_size_conv_mod,
            dropout=dropout,
        )
        self.ln = nn.LayerNorm(
            d_model,
        )
        self.slf_attn = ConformerMultiHeadedSelfAttention(
            d_model=d_model,
            num_heads=n_head,
            dropout_p=dropout,
        )
        self.conformer_conv_2 = ConformerConvModule(
            d_model,
            kernel_size=kernel_size_conv_mod,
            dropout=dropout,
        )

    def forward(
        self,
        x: torch.Tensor,
        embeddings: torch.Tensor,
        mask: torch.Tensor,
        slf_attn_mask: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        r"""Forward pass of the Conformer block.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, num_features).
            embeddings (Tensor): Embeddings tensor.
            mask (Tensor): The mask tensor.
            slf_attn_mask (Tensor): The mask for self-attention layer.
            encoding (Tensor): The positional encoding tensor.

        Returns:
            Tensor: The output tensor of shape (batch_size, seq_len, num_features).
        """
        x = self.conditioning.forward(x, embeddings=embeddings)
        if self.with_ff:
            x = self.ff(x) + x
        x = self.conformer_conv_1(x) + x
        res = x
        x = self.ln(x)
        x, _ = self.slf_attn(
            query=x,
            key=x,
            value=x,
            mask=slf_attn_mask,
            encoding=encoding,
        )
        x = x + res
        x = x.masked_fill(mask.unsqueeze(-1), 0)
        return self.conformer_conv_2(x) + x

In [34]:
class Conformer(Module):
    r"""`Conformer` class represents the `Conformer` model which is a sequence-to-sequence model
    used in some modern automated speech recognition systems. It is composed of several `ConformerBlocks`.

    Args:
        dim (int): The number of expected features in the input.
        n_layers (int): The number of `ConformerBlocks` in the Conformer model.
        n_heads (int): The number of heads in the multiheaded self-attention mechanism in each `ConformerBlock`.
        embedding_dim (int): The dimension of the embeddings.
        p_dropout (float): The dropout probability to be used in each `ConformerBlock`.
        kernel_size_conv_mod (int): The size of the convolving kernel in the convolution module of each `ConformerBlock`.
        with_ff (bool): If True, each `ConformerBlock` uses FeedForward layer inside it.
    """

    def __init__(
        self,
        dim: int,
        n_layers: int,
        n_heads: int,
        embedding_dim: int,
        p_dropout: float,
        kernel_size_conv_mod: int,
        with_ff: bool,
    ):
        super().__init__()
        self.layer_stack = nn.ModuleList(
            [
                ConformerBlock(
                    dim,
                    n_heads,
                    kernel_size_conv_mod=kernel_size_conv_mod,
                    dropout=p_dropout,
                    embedding_dim=embedding_dim,
                    with_ff=with_ff,
                )
                for _ in range(n_layers)
            ],
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        embeddings: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        r"""Forward Pass of the Conformer block.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, num_features).
            mask (Tensor): The mask tensor.
            embeddings (Tensor): Embeddings tensor.
            encoding (Tensor): The positional encoding tensor.

        Returns:
            Tensor: The output tensor of shape (batch_size, seq_len, num_features).
        """
        attn_mask = mask.view((mask.shape[0], 1, 1, mask.shape[1]))
        attn_mask.to(x.device)
        for enc_layer in self.layer_stack:
            x = enc_layer(
                x,
                mask=mask,
                slf_attn_mask=attn_mask,
                embeddings=embeddings,
                encoding=encoding,
            )
        return x

In [35]:
class ConvTransposed(Module):
    r"""`ConvTransposed` applies a 1D convolution operation, with the main difference that it transposes the
    last two dimensions of the input tensor before and after applying the `BSConv1d` convolution operation.
    This can be useful in certain architectures where the tensor dimensions are processed in a different order.

    The `ConvTransposed` class performs a `BSConv` operation after transposing the input tensor dimensions. Specifically, it swaps the channels and width dimensions of a tensor, applies the convolution, and then swaps the dimensions back to their original order. The intuition behind swapping dimensions can depend on the specific use case in the larger architecture; typically, it's used when the operation or sequence of operations expected a different arrangement of dimensions.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the kernel used in convolution
        padding (int): Zero-padding added around the input tensor along the width direction

    Attributes:
        conv (BSConv1d): `BSConv1d` module to apply convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 1,
        padding: int = 0,
    ):
        super().__init__()

        # Define BSConv1d convolutional layer
        self.conv = BSConv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward propagation method for the ConvTransposed layer.

        Args:
            x (torch.Tensor): input tensor

        Returns:
            x (torch.Tensor): output tensor after application of ConvTransposed
        """
        # Transpose the last two dimensions (dimension 1 and 2 here). Now the tensor has shape (N, W, C)
        x = x.contiguous().transpose(1, 2)

        # Apply BSConv1d convolution.
        x = self.conv(x)

        # Transpose the last two dimensions back to their original order. Now the tensor has shape (N, C, W)
        # Return final output tensor
        return x.contiguous().transpose(1, 2)

class VariancePredictor(Module):
    r"""Duration and Pitch predictor neural network module in PyTorch.

    It consists of multiple layers, including `ConvTransposed` layers (custom convolution transpose layers from
    the `model.conv_blocks` module), LeakyReLU activation functions, Layer Normalization and Dropout layers.

    Constructor for `VariancePredictor` class.

    Args:
        channels_in (int): Number of input channels.
        channels (int): Number of output channels for ConvTransposed layers and input channels for linear layer.
        channels_out (int): Number of output channels for linear layer.
        kernel_size (int): Size of the kernel for ConvTransposed layers.
        p_dropout (float): Probability of dropout.

    Returns:
        torch.Tensor: Output tensor.
    """

    def __init__(
        self,
        channels_in: int,
        channels: int,
        channels_out: int,
        kernel_size: int,
        p_dropout: float,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ):
        super().__init__()

        self.layers = nn.ModuleList(
            [
                # Convolution transpose layer followed by LeakyReLU, LayerNorm and Dropout
                ConvTransposed(
                    channels_in,
                    channels,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                ),
                nn.LeakyReLU(leaky_relu_slope),
                nn.LayerNorm(
                    channels,
                ),
                nn.Dropout(p_dropout),
                # Another "block" of ConvTransposed, LeakyReLU, LayerNorm, and Dropout
                ConvTransposed(
                    channels,
                    channels,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                ),
                nn.LeakyReLU(leaky_relu_slope),
                nn.LayerNorm(
                    channels,
                ),
                nn.Dropout(p_dropout),
            ],
        )

        # Output linear layer
        self.linear_layer = nn.Linear(
            channels,
            channels_out,
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        r"""Forward pass for `VariancePredictor`.

        Args:
            x (torch.Tensor): Input tensor.
            mask (torch.Tensor): Mask tensor, has the same size as x.

        Returns:
            torch.Tensor: Output tensor.
        """
        # Sequentially pass the input through all defined layers
        # (ConvTransposed -> LeakyReLU -> LayerNorm -> Dropout -> ConvTransposed -> LeakyReLU -> LayerNorm -> Dropout)
        for layer in self.layers:
            x = layer(x)
        x = self.linear_layer(x)
        x = x.squeeze(-1)
        return x.masked_fill(mask, 0.0)

In [36]:
def average_over_durations(values: torch.Tensor, durs: torch.Tensor) -> torch.Tensor:
    r"""Function calculates the average of values over specified durations.

    Args:
    values (torch.Tensor): A 3D tensor of shape [B, 1, T_de] where B is the batch size,
                           T_de is the duration of each element in the batch. The values
                           represent some quantity that needs to be averaged over durations.
    durs (torch.Tensor): A 2D tensor of shape [B, T_en] where B is the batch size,
                         T_en is the number of elements in each batch. The values represent
                         the durations over which the averaging needs to be done.

    Returns:
    avg (torch.Tensor): A 3D tensor of shape [B, 1, T_en] where B is the batch size,
                        T_en is the number of elements in each batch. The values represent
                        the average of the input values over the specified durations.

    Note:
    The function uses PyTorch operations for efficient computation on GPU.

    Shapes:
        - values: :math:`[B, 1, T_de]`
        - durs: :math:`[B, T_en]`
        - avg: :math:`[B, 1, T_en]`
    """
    durs_cums_ends = torch.cumsum(durs, dim=1).long()
    durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
    values_nonzero_cums = torch.nn.functional.pad(torch.cumsum(values != 0.0, dim=2), (1, 0))
    values_cums = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0))

    bs, l = durs_cums_ends.size()
    n_formants = values.size(1)
    dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l)
    dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l)

    values_sums = (torch.gather(values_cums, 2, dce) - torch.gather(values_cums, 2, dcs)).float()
    values_nelems = (torch.gather(values_nonzero_cums, 2, dce) - torch.gather(values_nonzero_cums, 2, dcs)).float()

    avg = torch.where(values_nelems == 0.0, values_nelems, values_sums / values_nelems)
    return avg

In [37]:
class PitchAdaptorConv(nn.Module):
    """The PitchAdaptorConv class is a pitch adaptor network in the model.
    Updated version of the PitchAdaptorConv uses the conv embeddings for the pitch.

    Args:
        channels_in (int): Number of in channels for conv layers.
        channels_out (int): Number of out channels.
        kernel_size (int): Size the kernel for the conv layers.
        dropout (float): Probability of dropout.
        leaky_relu_slope (float): Slope for the leaky relu.
        emb_kernel_size (int): Size the kernel for the pitch embedding.

    Inputs: inputs, mask
        - **inputs** (batch, time1, dim): Tensor containing input vector
        - **target** (batch, 1, time2): Tensor containing the pitch target
        - **dr** (batch, time1): Tensor containing aligner durations vector
        - **mask** (batch, time1): Tensor containing indices to be masked
    Returns:
        - **pitch prediction** (batch, 1, time1): Tensor produced by pitch predictor
        - **pitch embedding** (batch, channels, time1): Tensor produced pitch adaptor
        - **average pitch target(train only)** (batch, 1, time1): Tensor produced after averaging over durations

    """

    def __init__(
        self,
        channels_in: int,
        channels_hidden: int,
        channels_out: int,
        kernel_size: int,
        dropout: float,
        leaky_relu_slope: float,
        emb_kernel_size: int,
    ):
        super().__init__()
        self.pitch_predictor = VariancePredictor(
            channels_in=channels_in,
            channels=channels_hidden,
            channels_out=channels_out,
            kernel_size=kernel_size,
            p_dropout=dropout,
            leaky_relu_slope=leaky_relu_slope,
        )
        self.pitch_emb = nn.Conv1d(
            1,
            channels_hidden,
            kernel_size=emb_kernel_size,
            padding=int((emb_kernel_size - 1) / 2),
        )

    def get_pitch_embedding_train(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        dr: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Function is used during training to get the pitch prediction, average pitch target,
        and pitch embedding.

        Args:
            x (torch.Tensor): A 3D tensor of shape [B, T_src, C] where B is the batch size,
                            T_src is the source sequence length, and C is the number of channels.
            target (torch.Tensor): A 3D tensor of shape [B, 1, T_max2] where B is the batch size,
                                T_max2 is the maximum target sequence length.
            dr (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                                T_src is the source sequence length. The values represent the durations.
            mask (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                                T_src is the source sequence length. The values represent the mask.

        Returns:
            pitch_pred (torch.Tensor): A 3D tensor of shape [B, 1, T_src] where B is the batch size,
                                        T_src is the source sequence length. The values represent the pitch prediction.
            avg_pitch_target (torch.Tensor): A 3D tensor of shape [B, 1, T_src] where B is the batch size,
                                            T_src is the source sequence length. The values represent the average pitch target.
            pitch_emb (torch.Tensor): A 3D tensor of shape [B, C, T_src] where B is the batch size,
                                    C is the number of channels, T_src is the source sequence length. The values represent the pitch embedding.
        Shapes:
            x: :math: `[B, T_src, C]`
            target: :math: `[B, 1, T_max2]`
            dr: :math: `[B, T_src]`
            mask: :math: `[B, T_src]`
        """
        pitch_pred = self.pitch_predictor.forward(x, mask)
        pitch_pred = pitch_pred.unsqueeze(1)

        avg_pitch_target = average_over_durations(target, dr)
        pitch_emb = self.pitch_emb(avg_pitch_target)

        return pitch_pred, avg_pitch_target, pitch_emb

    def add_pitch_embedding_train(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        dr: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Add pitch embedding during training.

        This method calculates the pitch embedding and adds it to the input tensor 'x'.
        It also returns the predicted pitch and the average target pitch.

        Args:
            x (torch.Tensor): The input tensor to which the pitch embedding will be added.
            target (torch.Tensor): The target tensor used in the pitch embedding calculation.
            dr (torch.Tensor): The duration tensor used in the pitch embedding calculation.
            mask (torch.Tensor): The mask tensor used in the pitch embedding calculation.

        Returns:
            x (torch.Tensor): The input tensor with added pitch embedding.
            pitch_pred (torch.Tensor): The predicted pitch tensor.
            avg_pitch_target (torch.Tensor): The average target pitch tensor.
        """
        pitch_pred, avg_pitch_target, pitch_emb = self.get_pitch_embedding_train(
            x=x,
            target=target.unsqueeze(1),
            dr=dr,
            mask=mask,
        )
        x_pitch = x + pitch_emb.transpose(1, 2)
        return x_pitch, pitch_pred, avg_pitch_target

    def get_pitch_embedding(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Function is used during inference to get the pitch embedding and pitch prediction.

        Args:
            x (torch.Tensor): A 3D tensor of shape [B, T_src, C] where B is the batch size,
                            T_src is the source sequence length, and C is the number of channels.
            mask (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                                T_src is the source sequence length. The values represent the mask.

        Returns:
            pitch_emb_pred (torch.Tensor): A 3D tensor of shape [B, C, T_src] where B is the batch size,
                                            C is the number of channels, T_src is the source sequence length. The values represent the pitch embedding.
            pitch_pred (torch.Tensor): A 3D tensor of shape [B, 1, T_src] where B is the batch size,
                                        T_src is the source sequence length. The values represent the pitch prediction.
        """
        pitch_pred = self.pitch_predictor.forward(x, mask)
        pitch_pred = pitch_pred.unsqueeze(1)

        pitch_emb_pred = self.pitch_emb(pitch_pred)
        return pitch_emb_pred, pitch_pred

    def add_pitch_embedding(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Add pitch embedding during inference.

        This method calculates the pitch embedding and adds it to the input tensor 'x'.
        It also returns the predicted pitch.

        Args:
            x (torch.Tensor): The input tensor to which the pitch embedding will be added.
            mask (torch.Tensor): The mask tensor used in the pitch embedding calculation.
            pitch_transform (Callable): A function to transform the pitch prediction.

        Returns:
            x (torch.Tensor): The input tensor with added pitch embedding.
            pitch_pred (torch.Tensor): The predicted pitch tensor.
        """
        pitch_emb_pred, pitch_pred = self.get_pitch_embedding(x, mask)
        x_pitch = x + pitch_emb_pred.transpose(1, 2)
        return x_pitch, pitch_pred

In [38]:
class EnergyAdaptor(nn.Module):
    """Variance Adaptor with an added 1D conv layer. Used to
    get energy embeddings.

    Args:
        channels_in (int): Number of in channels for conv layers.
        channels_out (int): Number of out channels.
        kernel_size (int): Size the kernel for the conv layers.
        dropout (float): Probability of dropout.
        leaky_relu_slope (float): Slope for the leaky relu.
        emb_kernel_size (int): Size the kernel for the pitch embedding.

    Inputs: inputs, mask
        - **inputs** (batch, time1, dim): Tensor containing input vector
        - **target** (batch, 1, time2): Tensor containing the energy target
        - **dr** (batch, time1): Tensor containing aligner durations vector
        - **mask** (batch, time1): Tensor containing indices to be masked
    Returns:
        - **energy prediction** (batch, 1, time1): Tensor produced by energy predictor
        - **energy embedding** (batch, channels, time1): Tensor produced energy adaptor
        - **average energy target(train only)** (batch, 1, time1): Tensor produced after averaging over durations

    """

    def __init__(
        self,
        channels_in: int,
        channels_hidden: int,
        channels_out: int,
        kernel_size: int,
        dropout: float,
        leaky_relu_slope: float,
        emb_kernel_size: int,
    ):
        super().__init__()
        self.energy_predictor = VariancePredictor(
            channels_in=channels_in,
            channels=channels_hidden,
            channels_out=channels_out,
            kernel_size=kernel_size,
            p_dropout=dropout,
            leaky_relu_slope=leaky_relu_slope,
        )
        self.energy_emb = nn.Conv1d(
            1,
            channels_hidden,
            kernel_size=emb_kernel_size,
            padding=int((emb_kernel_size - 1) / 2),
        )

    def get_energy_embedding_train(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        dr: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Function is used during training to get the energy prediction, average energy target, and energy embedding.

        Args:
            x (torch.Tensor): A 3D tensor of shape [B, T_src, C] where B is the batch size,
                            T_src is the source sequence length, and C is the number of channels.
            target (torch.Tensor): A 3D tensor of shape [B, 1, T_max2] where B is the batch size,
                                T_max2 is the maximum target sequence length.
            dr (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                                T_src is the source sequence length. The values represent the durations.
            mask (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                                T_src is the source sequence length. The values represent the mask.

        Returns:
            energy_pred (torch.Tensor): A 3D tensor of shape [B, 1, T_src] where B is the batch size,
                                        T_src is the source sequence length. The values represent the energy prediction.
            avg_energy_target (torch.Tensor): A 3D tensor of shape [B, 1, T_src] where B is the batch size,
                                            T_src is the source sequence length. The values represent the average energy target.
            energy_emb (torch.Tensor): A 3D tensor of shape [B, C, T_src] where B is the batch size,
                                    C is the number of channels, T_src is the source sequence length. The values represent the energy embedding.
        Shapes:
            x: :math: `[B, T_src, C]`
            target: :math: `[B, 1, T_max2]`
            dr: :math: `[B, T_src]`
            mask: :math: `[B, T_src]`
        """
        energy_pred = self.energy_predictor.forward(x, mask)
        energy_pred = energy_pred.unsqueeze(1)

        avg_energy_target = average_over_durations(target, dr)
        energy_emb = self.energy_emb(avg_energy_target)

        return energy_pred, avg_energy_target, energy_emb

    def add_energy_embedding_train(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        dr: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Add energy embedding during training.

        This method calculates the energy embedding and adds it to the input tensor 'x'.
        It also returns the predicted energy and the average target energy.

        Args:
            x (torch.Tensor): The input tensor to which the energy embedding will be added.
            target (torch.Tensor): The target tensor used in the energy embedding calculation.
            dr (torch.Tensor): The duration tensor used in the energy embedding calculation.
            mask (torch.Tensor): The mask tensor used in the energy embedding calculation.

        Returns:
            x (torch.Tensor): The input tensor with added energy embedding.
            energy_pred (torch.Tensor): The predicted energy tensor.
            avg_energy_target (torch.Tensor): The average target energy tensor.
        """
        energy_pred, avg_energy_target, energy_emb = self.get_energy_embedding_train(
            x=x,
            target=target,
            dr=dr,
            mask=mask,
        )
        x_energy = x + energy_emb.transpose(1, 2)
        return x_energy, energy_pred, avg_energy_target

    def get_energy_embedding(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Function is used during inference to get the energy embedding and energy prediction.

        Args:
            x (torch.Tensor): A 3D tensor of shape [B, T_src, C] where B is the batch size,
                            T_src is the source sequence length, and C is the number of channels.
            mask (torch.Tensor): A 2D tensor of shape [B, T_src] where B is the batch size,
                                T_src is the source sequence length. The values represent the mask.

        Returns:
            energy_emb_pred (torch.Tensor): A 3D tensor of shape [B, C, T_src] where B is the batch size,
                                            C is the number of channels, T_src is the source sequence length. The values represent the energy embedding.
            energy_pred (torch.Tensor): A 3D tensor of shape [B, 1, T_src] where B is the batch size,
                                        T_src is the source sequence length. The values represent the energy prediction.
        """
        energy_pred = self.energy_predictor.forward(x, mask)
        energy_pred = energy_pred.unsqueeze(1)

        energy_emb_pred = self.energy_emb(energy_pred)
        return energy_emb_pred, energy_pred

    def add_energy_embedding(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Add energy embedding during inference.

        This method calculates the energy embedding and adds it to the input tensor 'x'.
        It also returns the predicted energy.

        Args:
            x (torch.Tensor): The input tensor to which the energy embedding will be added.
            mask (torch.Tensor): The mask tensor used in the energy embedding calculation.
            energy_transform (Callable): A function to transform the energy prediction.

        Returns:
            x (torch.Tensor): The input tensor with added energy embedding.
            energy_pred (torch.Tensor): The predicted energy tensor.
        """
        energy_emb_pred, energy_pred = self.get_energy_embedding(x, mask)
        x_energy = x + energy_emb_pred.transpose(1, 2)
        return x_energy, energy_pred

In [39]:
def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor:
    r"""Takes a list of 1D or 2D tensors and pads them to match the maximum length.

    Args:
        input_ele (List[torch.Tensor]): The list of tensors to be padded.
        max_len (int): The length to which the tensors should be padded.

    Returns:
        torch.Tensor: A tensor containing all the padded input tensors.
    """
    # Create an empty list to store the padded tensors
    out_list = torch.jit.annotate(List[torch.Tensor], [])
    for batch in input_ele:
        if len(batch.shape) == 1:
            # Perform padding for 1D tensor
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0,
            )
        else:
            # Perform padding for 2D tensor
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0,
            )
        # Append the padded tensor to the list
        out_list.append(one_batch_padded)

    # Stack all the tensors in the list into a single tensor
    return torch.stack(out_list)

In [40]:
class LengthAdaptor(Module):
    r"""DEPRECATED: The LengthAdaptor module is used to adjust the duration of phonemes.
    It contains a dedicated duration predictor and methods to upsample the input features to match predicted durations.

    Args:
        model_config (AcousticModelConfigType): The model configuration object containing model parameters.
    """

    def __init__(
        self,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()
        # Initialize the duration predictor
        self.duration_predictor = VariancePredictor(
            channels_in=model_config.encoder.n_hidden,
            channels=model_config.variance_adaptor.n_hidden,
            channels_out=1,
            kernel_size=model_config.variance_adaptor.kernel_size,
            p_dropout=model_config.variance_adaptor.p_dropout,
        )

    def length_regulate(
        self,
        x: torch.Tensor,
        duration: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Regulates the length of the input tensor using the duration tensor.

        Args:
            x (torch.Tensor): The input tensor.
            duration (torch.Tensor): The tensor containing duration for each time step in x.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The regulated output tensor and the tensor containing the length of each sequence in the batch.
        """
        output = torch.jit.annotate(List[torch.Tensor], [])
        mel_len = torch.jit.annotate(List[int], [])
        max_len = 0
        for batch, expand_target in zip(x, duration):
            expanded = self.expand(batch, expand_target)
            if expanded.shape[0] > max_len:
                max_len = expanded.shape[0]
            output.append(expanded)
            mel_len.append(expanded.shape[0])
        output = pad(output, max_len)
        return output, torch.tensor(mel_len, dtype=torch.int64)

    def expand(self, batch: torch.Tensor, predicted: torch.Tensor) -> torch.Tensor:
        r"""Expands the input tensor based on the predicted values.

        Args:
            batch (torch.Tensor): The input tensor.
            predicted (torch.Tensor): The tensor containing predicted expansion factors.

        Returns:
            torch.Tensor: The expanded tensor.
        """
        out = torch.jit.annotate(List[torch.Tensor], [])
        for i, vec in enumerate(batch):
            expand_size = predicted[i].item()
            out.append(vec.expand(max(int(expand_size), 0), -1))
        return torch.cat(out, 0)

    def upsample_train(
        self,
        x: torch.Tensor,
        x_res: torch.Tensor,
        duration_target: torch.Tensor,
        embeddings: torch.Tensor,
        src_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Upsamples the input tensor during training using ground truth durations.

        Args:
            x (torch.Tensor): The input tensor.
            x_res (torch.Tensor): Another input tensor for duration prediction.
            duration_target (torch.Tensor): The ground truth durations tensor.
            embeddings (torch.Tensor): The tensor containing phoneme embeddings.
            src_mask (torch.Tensor): The mask tensor indicating valid entries in x and x_res.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The upsampled x, log duration prediction, and upsampled embeddings.
        """
        x_res = x_res.detach()
        log_duration_prediction = self.duration_predictor(
            x_res,
            src_mask,
        )  # type: torch.Tensor
        x, _ = self.length_regulate(x, duration_target)
        embeddings, _ = self.length_regulate(embeddings, duration_target)
        return x, log_duration_prediction, embeddings

    def upsample(
        self,
        x: torch.Tensor,
        x_res: torch.Tensor,
        src_mask: torch.Tensor,
        embeddings: torch.Tensor,
        control: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Upsamples the input tensor during inference.

        Args:
            x (torch.Tensor): The input tensor.
            x_res (torch.Tensor): Another input tensor for duration prediction.
            src_mask (torch.Tensor): The mask tensor indicating valid entries in x and x_res.
            embeddings (torch.Tensor): The tensor containing phoneme embeddings.
            control (float): A control parameter for pitch regulation.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The upsampled x, approximated duration, and upsampled embeddings.
        """
        log_duration_prediction = self.duration_predictor(
            x_res,
            src_mask,
        )
        duration_rounded = torch.clamp(
            (torch.round(torch.exp(log_duration_prediction) - 1) * control),
            min=0,
        )
        x, _ = self.length_regulate(x, duration_rounded)
        embeddings, _ = self.length_regulate(embeddings, duration_rounded)
        return x, duration_rounded, embeddings

In [41]:
class AddCoords(Module):
    r"""AddCoords is a PyTorch module that adds additional channels to the input tensor containing the relative
    (normalized to `[-1, 1]`) coordinates of each input element along the specified number of dimensions (`rank`).
    Essentially, it adds spatial context information to the tensor.

    Typically, these inputs are feature maps coming from some CNN, where the spatial organization of the input
    matters (such as an image or speech signal).

    This additional spatial context allows subsequent layers (such as convolutions) to learn position-dependent
    features. For example, in tasks where the absolute position of features matters (such as denoising and
    segmentation tasks), it helps the model to know where (in terms of relative position) the features are.

    Args:
        rank (int): The dimensionality of the input tensor. That is to say, this tells us how many dimensions the
                    input tensor's spatial context has. It's assumed to be 1, 2, or 3 corresponding to some 1D, 2D,
                    or 3D data (like an image).

        with_r (bool): Boolean indicating whether to add an extra radial distance channel or not. If True, an extra
                       channel is appended, which measures the Euclidean (L2) distance from the center of the image.
                       This might be useful when the proximity to the center of the image is important to the task.
    """

    def __init__(self, rank: int, with_r: bool = False):
        super().__init__()
        self.rank = rank
        self.with_r = with_r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the AddCoords module. Depending on the rank of the tensor, it adds one or more new channels
        with relative coordinate values. If `with_r` is True, an extra radial channel is included.

        For example, for an image (`rank=2`), two channels would be added which contain the normalized x and y
        coordinates respectively of each pixel.

        Calling the forward method updates the original tensor `x` with the added channels.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            out (torch.Tensor): The input tensor with added coordinate and possibly radial channels.
        """
        if self.rank == 1:
            batch_size_shape, _, dim_x = x.shape
            xx_range = torch.arange(dim_x, dtype=torch.int32, device=x.device)
            xx_channel = xx_range[None, None, :]

            xx_channel = xx_channel.float() / (dim_x - 1)
            xx_channel = xx_channel * 2 - 1
            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1)

            out = torch.cat([x, xx_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)

        elif self.rank == 2:
            batch_size_shape, _, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32, device=x.device)
            yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32, device=x.device)

            xx_range = torch.arange(dim_y, dtype=torch.int32, device=x.device)
            yy_range = torch.arange(dim_x, dtype=torch.int32, device=x.device)
            xx_range = xx_range[None, None, :, None]
            yy_range = yy_range[None, None, :, None]

            xx_channel = torch.matmul(xx_range, xx_ones)
            yy_channel = torch.matmul(yy_range, yy_ones)

            # transpose y
            yy_channel = yy_channel.permute(0, 1, 3, 2)

            xx_channel = xx_channel.float() / (dim_y - 1)
            yy_channel = yy_channel.float() / (dim_x - 1)

            xx_channel = xx_channel * 2 - 1
            yy_channel = yy_channel * 2 - 1

            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
            yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)

            out = torch.cat([x, xx_channel, yy_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(
                    torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2),
                )
                out = torch.cat([out, rr], dim=1)

        elif self.rank == 3:
            batch_size_shape, _, dim_z, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32, device=x.device)
            yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32, device=x.device)
            zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32, device=x.device)

            xy_range = torch.arange(dim_y, dtype=torch.int32, device=x.device)
            xy_range = xy_range[None, None, None, :, None]

            yz_range = torch.arange(dim_z, dtype=torch.int32, device=x.device)
            yz_range = yz_range[None, None, None, :, None]

            zx_range = torch.arange(dim_x, dtype=torch.int32, device=x.device)
            zx_range = zx_range[None, None, None, :, None]

            xy_channel = torch.matmul(xy_range, xx_ones)
            xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2)

            yz_channel = torch.matmul(yz_range, yy_ones)
            yz_channel = yz_channel.permute(0, 1, 3, 4, 2)
            yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4)

            zx_channel = torch.matmul(zx_range, zz_ones)
            zx_channel = zx_channel.permute(0, 1, 4, 2, 3)
            zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3)

            out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(
                    torch.pow(xx_channel - 0.5, 2)
                    + torch.pow(yy_channel - 0.5, 2)
                    + torch.pow(zz_channel - 0.5, 2),
                )
                out = torch.cat([out, rr], dim=1)
        else:
            raise NotImplementedError

        return out

In [42]:
from torch.nn.modules import conv

class CoordConv1d(conv.Conv1d, Module):
    r"""`CoordConv1d` is an extension of the standard 1D convolution layer (`conv.Conv1d`), with the addition of extra coordinate
    channels. These extra channels encode positional coordinates, and optionally, the radial distance from the origin.
    This is inspired by the paper:
    [An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution](https://arxiv.org/abs/1807.03247)
    and is designed to help Convolution layers to pay attention to the absolute position of features in the input space.

    The responsibility of this class is to intercept the input tensor and append extra channels to it. These extra channels
    encode the positional coordinates (and optionally, the radial distance from the center). The enhanced tensor is then
    immediately passed through a standard Conv1D layer.

    In concrete terms, this means Convolution layer does not just process the color in an image-based task, but also 'knows'
    where in the overall image this color is located.

    In a typical Text-To-Speech (TTS) system like DelightfulTTS, the utterance is processed in a sequential manner.
    The importance of sequential data in such a use-case can benefit from `CoordConv` layer as it offers a way to draw
    more attention to the positioning of data. `CoordConv` is a drop-in replacement for standard convolution layers,
    enriches spatial representation in Convolutional Neural Networks (CNN) with additional positional information.

    Hence, the resultant Convolution does not only process the characteristics of the sound in the input speech signal,
    but also 'knows' where in the overall signal this particular sound is located, providing it with the spatial context.
    This can be particularly useful in TTS systems where the sequence of phonemes and their timing can be critical.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1.
        padding (int): Zero-padding added to both sides of the input . Default: 0.
        dilation (int): Spacing between kernel elements. Default: 1.
        groups (int): Number of blocked connections from input channels to output channels. Default: 1.
        bias (bool): If True, adds a learnable bias to the output. Default: True.
        with_r (bool): If True, adds a radial coordinate channel. Default: False.

    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        with_r: bool = False,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

        self.rank = 1
        self.addcoords = AddCoords(self.rank, with_r)

        self.conv = nn.Conv1d(
            in_channels + self.rank + int(with_r),
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""The forward pass of the `CoordConv1d` module. It adds the coordinate channels to the input tensor with the `AddCoords`
        module, and then immediately passes the result through a 1D convolution.

        As a result, the subsequent Conv layers don't merely process sound characteristics of the speech signal, but are
        also aware of their relative positioning, offering a notable improvement over traditional methods, particularly for
        challenging TTS tasks where the sequence is critical.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, out_channels, length).
        """
        # Apply AddCoords layer to add coordinate channels to the input tensor
        x = self.addcoords(x)

        # Apply convolution
        return self.conv(x)

In [43]:
def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor:
    r"""Generate a mask tensor from a tensor of sequence lengths.

    Args:
        lengths (torch.Tensor): A tensor of sequence lengths of shape: (batch_size, )

    Returns:
        torch.Tensor: A mask tensor of shape: (batch_size, max_len) where max_len is the
            maximum sequence length in the provided tensor. The mask tensor has a value of
            True at each position that is more than the length of the sequence (padding positions).

    Example:
      lengths: `torch.tensor([2, 3, 1, 4])`
      Mask tensor will be: `torch.tensor([
            [False, False, True, True],
            [False, False, False, True],
            [False, True, True, True],
            [False, False, False, False]
        ])`
    """
    # Get batch size
    batch_size = lengths.shape[0]

    # Get maximum sequence length in the batch
    max_len = int(torch.max(lengths).item())

    # Generate a tensor of shape (batch_size, max_len)
    # where each row contains values from 0 to max_len
    ids = (
        torch.arange(0, max_len, device=lengths.device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    # Compare each value in the ids tensor with
    # corresponding sequence length to generate a mask.
    # The mask will have True at positions where id >= sequence length,
    # indicating padding positions in the original sequences
    return ids >= lengths.unsqueeze(1).type(torch.int64).expand(-1, max_len)

In [44]:
def stride_lens_downsampling(lens: torch.Tensor, stride: int = 2) -> torch.Tensor:
    r"""Function computes the lengths of 1D tensor when applying a stride for downsampling.

    Args:
        lens (torch.Tensor): Tensor containing the lengths to be downsampled.
        stride (int, optional): The stride to be used for downsampling. Defaults to 2.

    Returns:
        torch.Tensor: A tensor of the same shape as the input containing the downsampled lengths.
    """
    # The torch.ceil function is used to handle cases where the length is not evenly divisible
    # by the stride. The torch.ceil function rounds up to the nearest integer, ensuring that
    # each item is present at least once in the downsampled lengths.
    # Finally, the .int() is used to convert the resulting float32 tensor to an integer tensor.
    return torch.ceil(lens / stride).int()

In [45]:
class ReferenceEncoder(Module):
    r"""A class to define the reference encoder.
    Similar to Tacotron model, the reference encoder is used to extract the high-level features from the reference

    It consists of a number of convolutional blocks (`CoordConv1d` for the first one and `nn.Conv1d` for the rest),
    then followed by instance normalization and GRU layers.
    The `CoordConv1d` at the first layer to better preserve positional information, paper:
    [Robust and fine-grained prosody control of end-to-end speech synthesis](https://arxiv.org/pdf/1811.02122.pdf)

    Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.

    Args:
        preprocess_config (PreprocessingConfig): Configuration object with preprocessing parameters.
        model_config (AcousticModelConfigType): Configuration object with acoustic model parameters.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing three tensors. _First_: The sequence tensor
            produced by the last GRU layer after padding has been removed. _Second_: The GRU's final hidden state tensor.
            _Third_: The mask tensor, which has the same shape as x, and contains `True` at positions where the input x
            has been masked.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()

        n_mel_channels = preprocess_config.stft.n_mel_channels
        ref_enc_filters = model_config.reference_encoder.ref_enc_filters
        ref_enc_size = model_config.reference_encoder.ref_enc_size
        ref_enc_strides = model_config.reference_encoder.ref_enc_strides
        ref_enc_gru_size = model_config.reference_encoder.ref_enc_gru_size

        self.n_mel_channels = n_mel_channels
        K = len(ref_enc_filters)
        filters = [self.n_mel_channels, *ref_enc_filters]
        strides = [1, *ref_enc_strides]

        # Use CoordConv1d at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf
        convs = [
            CoordConv1d(
                in_channels=filters[0],
                out_channels=filters[0 + 1],
                kernel_size=ref_enc_size,
                stride=strides[0],
                padding=ref_enc_size // 2,
                with_r=True,
            ),
            *[
                nn.Conv1d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=ref_enc_size,
                    stride=strides[i],
                    padding=ref_enc_size // 2,
                )
                for i in range(1, K)
            ],
        ]
        # Define convolution layers (ModuleList)
        self.convs = nn.ModuleList(convs)

        self.norms = nn.ModuleList(
            [
                nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True)
                for i in range(K)
            ],
        )

        # Define GRU layer
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1],
            hidden_size=ref_enc_gru_size,
            batch_first=True,
        )

    def forward(
        self,
        x: torch.Tensor,
        mel_lens: torch.Tensor,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Forward pass of the ReferenceEncoder.

        Args:
            x (torch.Tensor): A 3-dimensional tensor containing the input sequences, its size is [N, n_mels, timesteps].
            mel_lens (torch.Tensor): A 1-dimensional tensor containing the lengths of each sequence in x. Its length is N.
            leaky_relu_slope (float): The slope of the leaky relu function.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing three tensors. _First_: The sequence tensor
                produced by the last GRU layer after padding has been removed. _Second_: The GRU's final hidden state tensor.
                _Third_: The mask tensor, which has the same shape as x, and contains `True` at positions where the input x
                has been masked.
        """
        mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1)
        mel_masks = mel_masks.to(x.device)

        x = x.masked_fill(mel_masks, 0)
        for conv, norm in zip(self.convs, self.norms):
            x = x.float()
            x = conv(x)
            x = F.leaky_relu(x, leaky_relu_slope)  # [N, 128, Ty//2^K, n_mels//2^K]
            x = norm(x)

        for _ in range(2):
            mel_lens = stride_lens_downsampling(mel_lens)

        mel_masks = get_mask_from_lengths(mel_lens)

        x = x.masked_fill(mel_masks.unsqueeze(1), 0)
        x = x.permute((0, 2, 1))

        packed_sequence = torch.nn.utils.rnn.pack_padded_sequence(
            x,
            lengths=mel_lens.cpu().int(),
            batch_first=True,
            enforce_sorted=False,
        )

        self.gru.flatten_parameters()
        # memory --- [N, Ty, E//2], out --- [1, N, E//2]
        out, memory = self.gru(packed_sequence)
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        return out, memory, mel_masks

    def calculate_channels(
        self,
        L: int,
        kernel_size: int,
        stride: int,
        pad: int,
        n_convs: int,
    ) -> int:
        r"""Calculate the number of channels after applying convolutions.

        Args:
            L (int): The original size.
            kernel_size (int): The kernel size used in the convolutions.
            stride (int): The stride used in the convolutions.
            pad (int): The padding used in the convolutions.
            n_convs (int): The number of convolutions.

        Returns:
            int: The size after the convolutions.
        """
        # Loop through each convolution
        for _ in range(n_convs):
            # Calculate the size after each convolution
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L

In [46]:
class StyleEmbedAttention(Module):
    r"""Mechanism is being used to extract style features from audio data in the form of spectrograms.

    Each style token (parameterized by an embedding vector) represents a unique style feature. The model applies the `StyleEmbedAttention` mechanism to combine these style tokens (style features) in a weighted manner. The output of the attention module is a sum of style tokens, with each token weighted by its relevance to the input.

    This technique is often used in text-to-speech synthesis (TTS) such as Tacotron-2, where the goal is to modulate the prosody, stress, and intonation of the synthesized speech based on the reference audio or some control parameters. The concept of "global style tokens" (GST) was introduced in
    [Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis](https://arxiv.org/abs/1803.09017) by Yuxuan Wang et al.

    The `StyleEmbedAttention` class is a PyTorch module implementing the attention mechanism.
    This class is specifically designed for handling multiple attention heads.
    Attention here operates on a query and a set of key-value pairs to produce an output.

    Builds the `StyleEmbedAttention` network.

    Args:
        query_dim (int): Dimensionality of the query vectors.
        key_dim (int): Dimensionality of the key vectors.
        num_units (int): Total dimensionality of the query, key, and value vectors.
        num_heads (int): Number of parallel attention layers (heads).

    Note: `num_units` should be divisible by `num_heads`.
    """

    def __init__(
        self,
        query_dim: int,
        key_dim: int,
        num_units: int,
        num_heads: int,
    ):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(
            in_features=query_dim,
            out_features=num_units,
            bias=False,
        )
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False,
        )

    def forward(self, query: torch.Tensor, key_soft: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the StyleEmbedAttention module calculates the attention scores.

        Args:
            query (torch.Tensor): The input tensor for queries of shape `[N, T_q, query_dim]`
            key_soft (torch.Tensor): The input tensor for keys of shape `[N, T_k, key_dim]`

        Returns:
            out (torch.Tensor): The output tensor of shape `[N, T_q, num_units]`
        """
        values = self.W_value(key_soft)
        split_size = self.num_units // self.num_heads
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)

        # out_soft = scores_soft = None
        queries = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key_soft)  # [N, T_k, num_units]

        # [h, N, T_q, num_units/h]
        queries = torch.stack(torch.split(queries, split_size, dim=2), dim=0)
        # [h, N, T_k, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)
        # [h, N, T_k, num_units/h]

        # score = softmax(QK^T / (d_k ** 0.5))
        scores_soft = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores_soft = scores_soft / (self.key_dim**0.5)
        scores_soft = F.softmax(scores_soft, dim=3)

        # out = score * V
        # [h, N, T_q, num_units/h]
        out_soft = torch.matmul(scores_soft, values)
        return torch.cat(torch.split(out_soft, 1, dim=0), dim=3).squeeze(
            0,
        )  # [N, T_q, num_units] scores_soft

In [47]:
class STL(Module):
    r"""Style Token Layer (STL).
    This layer helps to encapsulate different speaking styles in token embeddings.

    Args:
        model_config (AcousticModelConfigType): An object containing the model's configuration parameters.

    Attributes:
        embed (nn.Parameter): The style token embedding tensor.
        attention (StyleEmbedAttention): The attention module used to compute a weighted sum of embeddings.
    """

    def __init__(
        self,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()

        # Number of attention heads
        num_heads = 1
        # Dimension of encoder hidden states
        n_hidden = model_config.encoder.n_hidden
        # Number of style tokens
        self.token_num = model_config.reference_encoder.token_num

        # Define a learnable tensor for style tokens embedding
        self.embed = nn.Parameter(
            torch.FloatTensor(self.token_num, n_hidden // num_heads),
        )

        # Dimension of query in attention
        d_q = n_hidden // 2
        # Dimension of keys in attention
        d_k = n_hidden // num_heads

        # Style Embedding Attention module
        self.attention = StyleEmbedAttention(
            query_dim=d_q,
            key_dim=d_k,
            num_units=n_hidden,
            num_heads=num_heads,
        )

        # Initialize the embedding with normal distribution
        torch.nn.init.normal_(self.embed, mean=0, std=0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the Style Token Layer
        Args:
            x (torch.Tensor): The input tensor.

        Returns
            torch.Tensor: The emotion embedded tensor after applying attention mechanism.
        """
        N = x.size(0)

        # Reshape input tensor to [N, 1, n_hidden // 2]
        query = x.unsqueeze(1)

        keys_soft = (
            torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)
        )  # [N, token_num, n_hidden // num_heads]

        # Apply attention mechanism to get weighted sum of style token embeddings
        return self.attention(query, keys_soft)

In [48]:
class UtteranceLevelProsodyEncoder(Module):
    r"""A class to define the utterance level prosody encoder.

    The encoder uses a Reference encoder class to convert input sequences into high-level features,
    followed by prosody embedding, self attention on the embeddings, and a feedforward transformation to generate the final output.Initializes the encoder with given specifications and creates necessary layers.

    Args:
        preprocess_config (PreprocessingConfig): Configuration object with preprocessing parameters.
        model_config (AcousticModelConfigType): Configuration object with acoustic model parameters.

    Returns:
        torch.Tensor: A 3-dimensional tensor sized `[N, seq_len, E]`.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()

        self.E = model_config.encoder.n_hidden
        ref_enc_gru_size = model_config.reference_encoder.ref_enc_gru_size
        ref_attention_dropout = model_config.reference_encoder.ref_attention_dropout
        bottleneck_size = model_config.reference_encoder.bottleneck_size_u

        # Define important layers/modules for the encoder
        self.encoder = ReferenceEncoder(preprocess_config, model_config)
        self.encoder_prj = nn.Linear(ref_enc_gru_size, self.E // 2)
        self.stl = STL(model_config)
        self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size)
        self.dropout = nn.Dropout(ref_attention_dropout)

    def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor:
        r"""Defines the forward pass of the utterance level prosody encoder.

        Args:
            mels (torch.Tensor): A 3-dimensional tensor containing input sequences. Size is `[N, Ty/r, n_mels*r]`.
            mel_lens (torch.Tensor): A 1-dimensional tensor containing the lengths of each sequence in mels. Length is N.

        Returns:
            torch.Tensor: A 3-dimensional tensor sized `[N, seq_len, E]`.
        """
        # Use the reference encoder to get prosody embeddings
        _, embedded_prosody, _ = self.encoder(mels, mel_lens)

        # Bottleneck
        # Use the linear projection layer on the prosody embeddings
        embedded_prosody = self.encoder_prj(embedded_prosody)

        # Apply the style token layer followed by the bottleneck layer
        out = self.encoder_bottleneck(self.stl(embedded_prosody))

        # Apply dropout for regularization
        out = self.dropout(out)

        # Reshape the output tensor before returning
        return out.view((-1, 1, out.shape[3]))

In [49]:
class PhonemeProsodyPredictor(Module):
    r"""A class to define the Phoneme Prosody Predictor.

    In linguistics, prosody (/ˈprɒsədi, ˈprɒzədi/) is the study of elements of speech that are not individual phonetic segments (vowels and consonants) but which are properties of syllables and larger units of speech, including linguistic functions such as intonation, stress, and rhythm. Such elements are known as suprasegmentals.

    [Wikipedia Prosody (linguistics)](https://en.wikipedia.org/wiki/Prosody_(linguistics))

    This prosody predictor is non-parallel and is inspired by the **work of Du et al., 2021 ?**. It consists of
    multiple convolution transpose, Leaky ReLU activation, LayerNorm, and dropout layers, followed by a
    linear transformation to generate the final output.

    Args:
        model_config (AcousticModelConfigType): Configuration object with model parameters.
        phoneme_level (bool): A flag to decide whether to use phoneme level bottleneck size.
        leaky_relu_slope (float): The negative slope of LeakyReLU activation function.
    """

    def __init__(
        self,
        model_config: AcousticModelConfigType,
        phoneme_level: bool,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ):
        super().__init__()

        # Get the configuration
        self.d_model = model_config.encoder.n_hidden
        kernel_size = model_config.reference_encoder.predictor_kernel_size
        dropout = model_config.encoder.p_dropout

        # Decide on the bottleneck size based on phoneme level flag
        bottleneck_size = (
            model_config.reference_encoder.bottleneck_size_p
            if phoneme_level
            else model_config.reference_encoder.bottleneck_size_u
        )

        # Define the layers
        self.layers = nn.ModuleList(
            [
                ConvTransposed(
                    self.d_model,
                    self.d_model,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                ),
                nn.LeakyReLU(leaky_relu_slope),
                nn.LayerNorm(
                    self.d_model,
                ),
                nn.Dropout(dropout),
                ConvTransposed(
                    self.d_model,
                    self.d_model,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                ),
                nn.LeakyReLU(leaky_relu_slope),
                nn.LayerNorm(
                    self.d_model,
                ),
                nn.Dropout(dropout),
            ],
        )

        # Output bottleneck layer
        self.predictor_bottleneck = nn.Linear(
            self.d_model,
            bottleneck_size,
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the prosody predictor.

        Args:
            x (torch.Tensor): A 3-dimensional tensor `[B, src_len, d_model]`.
            mask (torch.Tensor): A 2-dimensional tensor `[B, src_len]`.

        Returns:
            torch.Tensor: A 3-dimensional tensor `[B, src_len, 2 * d_model]`.
        """
        # Expand the mask tensor's dimensions from [B, src_len] to [B, src_len, 1]
        mask = mask.unsqueeze(2)

        # Pass the input through the layers
        for layer in self.layers:
            x = layer(x)

        # Apply mask
        x = x.masked_fill(mask, 0.0)

        # Final linear transformation
        return self.predictor_bottleneck(x)

In [50]:
class PhonemeLevelProsodyEncoder(Module):
    r"""Phoneme Level Prosody Encoder Module

    This Class is used to encode the phoneme level prosody in the speech synthesis pipeline.

    Args:
        preprocess_config (PreprocessingConfig): Configuration for preprocessing.
        model_config (AcousticModelConfigType): Acoustic model configuration.

    Returns:
        torch.Tensor: The encoded tensor after applying masked fill.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        model_config: AcousticModelConfigType,
    ):
        super().__init__()

        # Obtain the bottleneck size and reference encoder gru size from the model config.
        bottleneck_size = model_config.reference_encoder.bottleneck_size_p
        ref_enc_gru_size = model_config.reference_encoder.ref_enc_gru_size

        # Initialize ReferenceEncoder, Linear layer and ConformerMultiHeadedSelfAttention for attention mechanism.
        self.encoder = ReferenceEncoder(preprocess_config, model_config)
        self.encoder_prj = nn.Linear(ref_enc_gru_size, model_config.encoder.n_hidden)
        self.attention = ConformerMultiHeadedSelfAttention(
            d_model=model_config.encoder.n_hidden,
            num_heads=model_config.encoder.n_heads,
            dropout_p=model_config.encoder.p_dropout,
        )

        # Bottleneck layer to transform the output of the attention mechanism.
        self.encoder_bottleneck = nn.Linear(
            model_config.encoder.n_hidden, bottleneck_size,
        )

    def forward(
        self,
        x: torch.Tensor,
        src_mask: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        encoding: torch.Tensor,
    ) -> torch.Tensor:
        r"""The forward pass of the PhonemeLevelProsodyEncoder. Input tensors are passed through the reference encoder,
        attention mechanism, and a bottleneck.

        Args:
            x (torch.Tensor): Input tensor of shape [N, seq_len, encoder_embedding_dim].
            src_mask (torch.Tensor): The mask tensor which contains `True` at positions where the input x has been masked.
            mels (torch.Tensor): The mel-spectrogram with shape [N, Ty/r, n_mels*r], where r=1.
            mel_lens (torch.Tensor): The lengths of each sequence in mels.
            encoding (torch.Tensor): The relative positional encoding tensor.

        Returns:
            torch.Tensor: Output tensor of shape [N, seq_len, bottleneck_size].
        """
        # Use the reference encoder to embed prosody representation
        embedded_prosody, _, mel_masks = self.encoder(mels, mel_lens)

        # Pass the prosody representation through a bottleneck (dimension reduction)
        embedded_prosody = self.encoder_prj(embedded_prosody)

        # Flatten and apply attention mask
        attn_mask = mel_masks.view((mel_masks.shape[0], 1, 1, -1))
        x, _ = self.attention(
            query=x,
            key=embedded_prosody,
            value=embedded_prosody,
            mask=attn_mask,
            encoding=encoding,
        )

        # Apply the bottleneck to the output and mask out irrelevant parts
        x = self.encoder_bottleneck(x)
        return x.masked_fill(src_mask.unsqueeze(-1), 0.0)

In [51]:
from numba import njit, prange
@njit(fastmath=True)
def mas_width1(attn_map: np.ndarray) -> np.ndarray:
    r"""Applies a Monotonic Alignments Shrink (MAS) operation with a hard-coded width of 1 to an attention map.
    Mas with hardcoded width=1
    Essentially, it produces optimal alignments based on previous attention distribution.

    Args:
        attn_map (np.ndarray): The original attention map, a 2D numpy array where rows correspond to mel bins and columns to text bins.

    Returns:
        opt (np.ndarray): Returns the optimal attention map after applying the MAS operation.
    """
    # assumes mel x text
    # Create a placeholder for the output
    opt = np.zeros_like(attn_map)

    # Convert the attention map to log scale for stability
    attn_map = np.log(attn_map)

    # Initialize the first row of attention map appropriately
    attn_map[0, 1:] = -np.inf

    # Initialize log_p with the first row of attention map
    log_p = np.zeros_like(attn_map)
    log_p[0, :] = attn_map[0, :]

    # Placeholder to remember the previous indices for backtracking later
    prev_ind = np.zeros_like(attn_map, dtype=np.int64)

    # Compute the log probabilities based on previous attention distribution
    for i in range(1, attn_map.shape[0]):
        for j in range(attn_map.shape[1]):  # for each text dim
            prev_log = log_p[i - 1, j]
            prev_j = j

            # Compare with left (j-1) pixel and update if the left pixel has larger log probability
            if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
                prev_log = log_p[i - 1, j - 1]
                prev_j = j - 1

            log_p[i, j] = attn_map[i, j] + prev_log

            # Store the position of maximum cumulative log probability
            prev_ind[i, j] = prev_j

    # Backtrack to retrieve the path of attention with maximum cumulative log probability
    curr_text_idx = attn_map.shape[1] - 1
    for i in range(attn_map.shape[0] - 1, -1, -1):
        opt[i, curr_text_idx] = 1
        curr_text_idx = prev_ind[i, curr_text_idx]

    # Mark the first position of the optimal path
    opt[0, curr_text_idx] = 1
    return opt


# @njit(parallel=True)
def b_mas(
    b_attn_map: np.ndarray,
    in_lens: np.ndarray,
    out_lens: np.ndarray,
    width: int=1) -> np.ndarray:
    r"""Applies Monotonic Alignments Shrink (MAS) operation in parallel to the batches of an attention map.
    It uses the `mas_width1` function internally to perform MAS operation.

    Args:
        b_attn_map (np.ndarray): The batched attention map; a 3D array where the first dimension is the batch size, second dimension corresponds to source length, and third dimension corresponds to target length.
        in_lens (np.ndarray): Lengths of sequences in the input batch.
        out_lens (np.ndarray): Lengths of sequences in the output batch.
        width (int, optional): The width for the MAS operation. Defaults to 1.

    Raises:
        AssertionError: If width is not equal to 1. This function currently supports only width of 1.

    Returns:
        np.ndarray: The batched attention map after applying the MAS operation. It has the same dimensions as `b_attn_map`.
    """
    # Assert that the width is 1. This function currently supports only width of 1
    assert width == 1
    attn_out = np.zeros_like(b_attn_map)

    # Loop over each attention map in the batch in parallel
    for b in prange(b_attn_map.shape[0]):
        # Apply Monotonic Alignments Shrink operation to the b-th attention map in the batch
        out = mas_width1(b_attn_map[b, 0, : out_lens[b], : in_lens[b]])

        # Update the b-th attention map in the output with the result of MAS operation
        attn_out[b, 0, : out_lens[b], : in_lens[b]] = out

    # Return the batched attention map after applying the MAS operation
    return attn_out

In [52]:
class Aligner(Module):
    r"""DEPRECATED: Aligner class represents a PyTorch module responsible for alignment tasks
    in a sequence-to-sequence model. It uses convolutional layers combined with
    LeakyReLU activation functions to project inputs to a hidden representation.
    The class utilizes both softmax and log-softmax to calculate softmax
    along dimension 3.

    Args:
        d_enc_in (int): Number of channels in the input for the encoder.
        d_dec_in (int): Number of channels in the input for the decoder.
        d_hidden (int): Number of channels in the output (hidden layers).
        kernel_size_enc (int, optional): Size of the convolving kernel for encoder, default is 3.
        kernel_size_dec (int, optional): Size of the convolving kernel for decoder, default is 7.
        temperature (float, optional): The temperature value applied in Gaussian isotropic
            attention mechanism, default is 0.0005.
        leaky_relu_slope (float, optional): Controls the angle of the negative slope of
            LeakyReLU activation, default is LEAKY_RELU_SLOPE.

    """

    def __init__(
        self,
        d_enc_in: int,
        d_dec_in: int,
        d_hidden: int,
        kernel_size_enc: int = 3,
        kernel_size_dec: int = 7,
        temperature: float = 0.0005,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ):
        super().__init__()
        self.temperature = temperature

        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)

        self.key_proj = nn.Sequential(
            nn.Conv1d(
                d_enc_in,
                d_hidden,
                kernel_size=kernel_size_enc,
                padding=kernel_size_enc // 2,
            ),
            nn.LeakyReLU(leaky_relu_slope),
            nn.Conv1d(
                d_hidden,
                d_hidden,
                kernel_size=kernel_size_enc,
                padding=kernel_size_enc // 2,
            ),
            nn.LeakyReLU(leaky_relu_slope),
        )

        self.query_proj = nn.Sequential(
            nn.Conv1d(
                d_dec_in,
                d_hidden,
                kernel_size=kernel_size_dec,
                padding=kernel_size_dec // 2,
            ),
            nn.LeakyReLU(leaky_relu_slope),
            nn.Conv1d(
                d_hidden,
                d_hidden,
                kernel_size=kernel_size_dec,
                padding=kernel_size_dec // 2,
            ),
            nn.LeakyReLU(leaky_relu_slope),
            nn.Conv1d(
                d_hidden,
                d_hidden,
                kernel_size=kernel_size_dec,
                padding=kernel_size_dec // 2,
            ),
            nn.LeakyReLU(leaky_relu_slope),
        )

    def binarize_attention_parallel(
        self,
        attn: torch.Tensor,
        in_lens: torch.Tensor,
        out_lens: torch.Tensor,
    ) -> torch.Tensor:
        r"""For training purposes only! Binarizes attention with MAS.
        Binarizes the attention tensor using Maximum Attention Strategy (MAS).

        This process is applied for training purposes only and the resulting
        binarized attention tensor will no longer receive a gradient in the
        backpropagation process.

        Args:
            attn (Tensor): The attention tensor. Must be of shape (B, 1, max_mel_len, max_text_len),
                where B represents the batch size, max_mel_len represents the maximum length
                of the mel spectrogram, and max_text_len represents the maximum length of the text.
            in_lens (Tensor): A 1D tensor of shape (B,) that contains the input sequence lengths,
                which likely corresponds to text sequence lengths.
            out_lens (Tensor): A 1D tensor of shape (B,) that contains the output sequence lengths,
                which likely corresponds to mel spectrogram lengths.

        Returns:
            Tensor: The binarized attention tensor. The output tensor has the same shape as the input `attn` tensor.
        """
        with torch.no_grad():
            attn_cpu = np.array(attn.data.cpu().tolist())
            # .numpy()
            attn_out = b_mas(
                attn_cpu,
                np.array(in_lens.cpu().tolist()),
                # .numpy(),
                np.array(out_lens.cpu().tolist()),
                # .numpy(),
                width=1,
            )
        return torch.tensor(attn_out)

    def forward(
        self,
        enc_in: torch.Tensor,
        dec_in: torch.Tensor,
        enc_len: torch.Tensor,
        dec_len: torch.Tensor,
        enc_mask: torch.Tensor,
        attn_prior: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Performs the forward pass through the Aligner module.

        Args:
            enc_in (Tensor): The text encoder outputs.
                Must be of shape (B, C_1, T_1), where B is the batch size, C_1 the number of
                channels in encoder inputs,
                and T_1 the sequence length of encoder inputs.
            dec_in (Tensor): The data to align with encoder outputs.
                Must be of shape (B, C_2, T_2), where C_2 is the number of channels in decoder inputs,
                and T_2 the sequence length of decoder inputs.
            enc_len (Tensor): 1D tensor representing the lengths of each sequence in the batch in `enc_in`.
            dec_len (Tensor): 1D tensor representing the lengths of each sequence in the batch in `dec_in`.
            enc_mask (Tensor): Binary mask tensor used to avoid attention to certain timesteps.
            attn_prior (Tensor): Previous attention values for attention calculation.

        Returns:
            Tuple[Tensor, Tensor, Tensor, Tensor]: Returns a tuple of Tensors representing the log-probability, soft attention, hard attention, and hard attention duration.
        """
        queries = dec_in.float()
        keys = enc_in.float()
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
        queries_enc = self.query_proj(queries)

        # Simplistic Gaussian Isotopic Attention
        attn = (
            queries_enc[:, :, :, None] - keys_enc[:, :, None]
        ) ** 2  # B x n_attn_dims x T1 x T2
        attn = -self.temperature * attn.sum(1, keepdim=True)

        if attn_prior is not None:
            # print(f"AlignmentEncoder \t| mel: {queries.shape} phone: {keys.shape}
            # mask: {mask.shape} attn: {attn.shape} attn_prior: {attn_prior.shape}")
            attn = self.log_softmax(attn) + torch.log(
                attn_prior.permute((0, 2, 1))[:, None] + 1e-8,
            )
            # print(f"AlignmentEncoder \t| After prior sum attn: {attn.shape}")"""

        attn_logprob = attn.clone()

        if enc_mask is not None:
            attn.masked_fill(enc_mask.unsqueeze(1).unsqueeze(1), -float("inf"))

        attn_soft = self.softmax(attn)  # softmax along T2
        attn_hard = self.binarize_attention_parallel(attn_soft, enc_len, dec_len)
        attn_hard_dur = attn_hard.sum(2)[:, 0, :]
        return attn_logprob, attn_soft, attn_hard, attn_hard_dur

In [53]:
def initialize_embeddings(shape: Tuple[int, ...]) -> torch.Tensor:
    r"""Initialize embeddings using Kaiming initialization (He initialization).

    This method is specifically designed for 2D matrices and helps to avoid
    the vanishing/exploding gradient problem in deep neural networks.
    This is achieved by keeping the variance of the outputs of a layer to be
    the same as the variance of its inputs.

    Args:
        shape (Tuple[int, ...]): The shape of the embedding matrix to create, denoted as a tuple of integers.
                                 The shape should comprise 2 dimensions, i.e., (embedding_dim, num_embeddings).

    Raises:
        AssertionError: if the provided shape is not 2D.

    Returns:
        torch.Tensor: the created embedding matrix.
    """
    # Check if the input shape is 2D
    assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..."

    # Initialize the embedding matrix using Kaiming initialization
    return torch.randn(shape) * np.sqrt(2 / shape[1])

In [54]:
def positional_encoding(
    d_model: int, length: int,
) -> torch.Tensor:
    r"""Function to calculate positional encoding for transformer model.

    Args:
        d_model (int): Dimension of the model (often corresponds to embedding size).
        length (int): Length of sequences.

    Returns:
        torch.Tensor: Tensor having positional encodings.
    """
    # Initialize placeholder for positional encoding
    pe = torch.zeros(length, d_model)

    # Generate position indices and reshape to have shape (length, 1)
    position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)

    # Calculate term for division
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float()
        * -(math.log(10000.0) / d_model),
    )

    # Assign sin of position * div_term to even indices in the encoding matrix
    pe[:, 0::2] = torch.sin(position * div_term)

    # Assign cos of position * div_term to odd indices in the encoding matrix
    pe[:, 1::2] = torch.cos(position * div_term)

    # Add an extra dimension to match expected output shape
    return pe.unsqueeze(0)


In [55]:
from torch.nn.parameter import Parameter

class AcousticModel(Module):
    r"""The DelightfulTTS AcousticModel class represents a PyTorch module for an acoustic model in text-to-speech (TTS).
    The acoustic model is responsible for predicting speech signals from phoneme sequences.

    The model comprises multiple sub-modules including encoder, decoder and various prosody encoders and predictors.
    Additionally, a pitch and length adaptor are instantiated.

    Args:
        preprocess_config (PreprocessingConfig): Object containing the configuration used for preprocessing the data
        model_config (AcousticModelConfigType): Configuration object containing various model parameters
        n_speakers (int): Total number of speakers in the dataset
        leaky_relu_slope (float, optional): Slope for the leaky relu. Defaults to LEAKY_RELU_SLOPE.

    Note:
        For more specific details on the implementation of sub-modules please refer to their individual respective modules.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        model_config: AcousticModelConfigType,
        n_speakers: int,
        leaky_relu_slope: float = LEAKY_RELU_SLOPE,
    ):
        super().__init__()
        self.emb_dim = model_config.encoder.n_hidden

        self.encoder = Conformer(
            dim=model_config.encoder.n_hidden,
            n_layers=model_config.encoder.n_layers,
            n_heads=model_config.encoder.n_heads,
            embedding_dim=model_config.speaker_embed_dim + model_config.lang_embed_dim,
            p_dropout=model_config.encoder.p_dropout,
            kernel_size_conv_mod=model_config.encoder.kernel_size_conv_mod,
            with_ff=model_config.encoder.with_ff,
        )

        self.pitch_adaptor_conv = PitchAdaptorConv(
            channels_in=model_config.encoder.n_hidden,
            channels_hidden=model_config.variance_adaptor.n_hidden,
            channels_out=1,
            kernel_size=model_config.variance_adaptor.kernel_size,
            emb_kernel_size=model_config.variance_adaptor.emb_kernel_size,
            dropout=model_config.variance_adaptor.p_dropout,
            leaky_relu_slope=leaky_relu_slope,
        )

        self.energy_adaptor = EnergyAdaptor(
            channels_in=model_config.encoder.n_hidden,
            channels_hidden=model_config.variance_adaptor.n_hidden,
            channels_out=1,
            kernel_size=model_config.variance_adaptor.kernel_size,
            emb_kernel_size=model_config.variance_adaptor.emb_kernel_size,
            dropout=model_config.variance_adaptor.p_dropout,
            leaky_relu_slope=leaky_relu_slope,
        )

        self.length_regulator = LengthAdaptor(model_config)

        self.utterance_prosody_encoder = UtteranceLevelProsodyEncoder(
            preprocess_config,
            model_config,
        )

        self.utterance_prosody_predictor = PhonemeProsodyPredictor(
            model_config=model_config,
            phoneme_level=False,
        )

        self.phoneme_prosody_encoder = PhonemeLevelProsodyEncoder(
            preprocess_config,
            model_config,
        )

        self.phoneme_prosody_predictor = PhonemeProsodyPredictor(
            model_config=model_config,
            phoneme_level=True,
        )

        self.u_bottle_out = nn.Linear(
            model_config.reference_encoder.bottleneck_size_u,
            model_config.encoder.n_hidden,
        )

        self.u_norm = nn.LayerNorm(
            model_config.reference_encoder.bottleneck_size_u,
            elementwise_affine=False,
        )

        self.p_bottle_out = nn.Linear(
            model_config.reference_encoder.bottleneck_size_p,
            model_config.encoder.n_hidden,
        )

        self.p_norm = nn.LayerNorm(
            model_config.reference_encoder.bottleneck_size_p,
            elementwise_affine=False,
        )

        self.aligner = Aligner(
            d_enc_in=model_config.encoder.n_hidden,
            d_dec_in=preprocess_config.stft.n_mel_channels,
            d_hidden=model_config.encoder.n_hidden,
        )

        self.decoder = Conformer(
            dim=model_config.decoder.n_hidden,
            n_layers=model_config.decoder.n_layers,
            n_heads=model_config.decoder.n_heads,
            embedding_dim=model_config.speaker_embed_dim + model_config.lang_embed_dim,
            p_dropout=model_config.decoder.p_dropout,
            kernel_size_conv_mod=model_config.decoder.kernel_size_conv_mod,
            with_ff=model_config.decoder.with_ff,
        )

        self.src_word_emb = Parameter(
            initialize_embeddings(
                (len(symbols), model_config.encoder.n_hidden),
            ),
        )

        self.to_mel = nn.Linear(
            model_config.decoder.n_hidden,
            preprocess_config.stft.n_mel_channels,
        )

        # NOTE: here you can manage the speaker embeddings, can be used for the voice export ?
        # NOTE: flexibility of the model binded by the n_speaker parameter, maybe I can find another way?
        # NOTE: in LIBRITTS there are 2477 speakers, we can add more, just extend the speaker_embed matrix
        # Need to think about it more
        self.speaker_embed = Parameter(
            initialize_embeddings(
                (n_speakers, model_config.speaker_embed_dim),
            ),
        )

        self.lang_embed = Parameter(
            initialize_embeddings(
                (len(SUPPORTED_LANGUAGES), model_config.lang_embed_dim),
            ),
        )

    def get_embeddings(
        self,
        token_idx: torch.Tensor,
        speaker_idx: torch.Tensor,
        src_mask: torch.Tensor,
        lang_idx: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Given the tokens, speakers, source mask, and language indices, compute
        the embeddings for tokens, speakers and languages and return the
        token_embeddings and combined speaker and language embeddings

        Args:
            token_idx (torch.Tensor): Tensor of token indices.
            speaker_idx (torch.Tensor): Tensor of speaker identities.
            src_mask (torch.Tensor): Mask tensor for source sequences.
            lang_idx (torch.Tensor): Tensor of language indices.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Token embeddings tensor,
            and combined speaker and language embeddings tensor.
        """
        token_embeddings = F.embedding(token_idx, self.src_word_emb)
        # NOTE: here you can manage the speaker embeddings, can be used for the voice export ?
        speaker_embeds = F.embedding(speaker_idx, self.speaker_embed)
        lang_embeds = F.embedding(lang_idx, self.lang_embed)

        # Merge the speaker and language embeddings
        embeddings = torch.cat([speaker_embeds, lang_embeds], dim=2)

        # Apply the mask to the embeddings and token embeddings
        embeddings = embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0)
        token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0)

        return token_embeddings, embeddings

    def prepare_for_export(self) -> None:
        r"""Prepare the model for export.

        This method is called when the model is about to be exported, such as for deployment
        or serializing for later use. The method removes unnecessary components that are
        not needed during inference. Specifically, it removes the phoneme and utterance
        prosody encoders for this acoustic model. These components are typically used during
        training and are not needed when the model is used for making predictions.

        Returns
            None
        """
        del self.phoneme_prosody_encoder
        del self.utterance_prosody_encoder

    # NOTE: freeze/unfreeze params changed, because of the conflict with the lightning module
    def freeze_params(self) -> None:
        r"""Freeze the trainable parameters in the model.

        By freezing, the parameters are no longer updated by gradient descent.
        This is typically done when you want to keep parts of your model fixed while training other parts.
        For this model, it freezes all parameters and then selectively unfreezes the
        speaker embeddings and the pitch adaptor's pitch embeddings to allow these components to update during training.

        Returns
            None
        """
        for par in self.parameters():
            par.requires_grad = False
        self.speaker_embed.requires_grad = True

    # NOTE: freeze/unfreeze params changed, because of the conflict with the lightning module
    def unfreeze_params(self, freeze_text_embed: bool, freeze_lang_embed: bool) -> None:
        r"""Unfreeze the trainable parameters in the model, allowing them to be updated during training.

        This method is typically used to 'unfreeze' previously 'frozen' parameters, making them trainable again.
        For this model, it unfreezes all parameters and then selectively freezes the
        text embeddings and language embeddings, if required.

        Args:
            freeze_text_embed (bool): Flag to indicate if text embeddings should remain frozen.
            freeze_lang_embed (bool): Flag to indicate if language embeddings should remain frozen.

        Returns:
            None
        """
        # Iterate through all model parameters and make them trainable
        for par in self.parameters():
            par.requires_grad = True

        # If freeze_text_embed flag is True, keep the source word embeddings frozen
        if freeze_text_embed:
            # @fixed self.src_word_emb.parameters has no parameters() method!
            # for par in self.src_word_emb.parameters():
            self.src_word_emb.requires_grad = False

        # If freeze_lang_embed flag is True, keep the language embeddings frozen
        if freeze_lang_embed:
            self.lang_embed.requires_grad = False

    def average_utterance_prosody(
        self,
        u_prosody_pred: torch.Tensor,
        src_mask: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute the average utterance prosody over the length of non-masked elements.

        This method averages the output of the utterance prosody predictor over
        the sequence lengths (non-masked elements). This function will return
        a tensor with the same first dimension but singleton trailing dimensions.

        Args:
            u_prosody_pred (torch.Tensor): Tensor containing the predicted utterance prosody of dimension (batch_size, T, n_features).
            src_mask (torch.Tensor): Tensor of dimension (batch_size, T) acting as a mask where masked entries are set to False.

        Returns:
            torch.Tensor: Tensor of dimension (batch_size, 1, n_features) containing average utterance prosody over non-masked sequence length.
        """
        # Compute the real sequence lengths by negating the mask and summing along the sequence dimension
        lengths = ((~src_mask) * 1.0).sum(1)

        # Compute the sum of u_prosody_pred across the sequence length dimension,
        #  then divide by the sequence lengths tensor to calculate the average.
        #  This performs a broadcasting operation to account for the third dimension (n_features).
        # Return the averaged prosody prediction
        return u_prosody_pred.sum(1, keepdim=True) / lengths.view(-1, 1, 1)

    def forward_train(
        self,
        x: torch.Tensor,
        speakers: torch.Tensor,
        src_lens: torch.Tensor,
        mels: torch.Tensor,
        mel_lens: torch.Tensor,
        pitches: torch.Tensor,
        langs: torch.Tensor,
        attn_priors: Union[torch.Tensor, None],
        energies: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        r"""Forward pass during training phase.

        For a given phoneme sequence, speaker identities, sequence lengths, mels,
        mel lengths, pitches, language, and attention priors, the forward pass
        processes these inputs through the defined architecture.

        Args:
            x (torch.Tensor): Tensor of phoneme sequence.
            speakers (torch.Tensor): Tensor of speaker identities.
            src_lens (torch.Tensor): Long tensor representing the lengths of source sequences.
            mels (torch.Tensor): Tensor of mel spectrograms.
            mel_lens (torch.Tensor): Long tensor representing the lengths of mel sequences.
            pitches (torch.Tensor): Tensor of pitch values.
            langs (torch.Tensor): Tensor of language identities.
            attn_priors (torch.Tensor): Prior attention values.
            energies (torch.Tensor): Tensor of energy values.

        Returns:
            Dict[str, torch.Tensor]: Returns the prediction outputs as a dictionary.
        """
        # Generate masks for padding positions in the source sequences and mel sequences
        src_mask = get_mask_from_lengths(src_lens)
        mel_mask = get_mask_from_lengths(mel_lens)

        x, embeddings = self.get_embeddings(
            token_idx=x,
            speaker_idx=speakers,
            src_mask=src_mask,
            lang_idx=langs,
        )

        encoding = positional_encoding(
            self.emb_dim,
            max(x.shape[1], int(mel_lens.max().item())),
        )
        x = x.to(src_mask.device)
        encoding = encoding.to(src_mask.device)
        embeddings = embeddings.to(src_mask.device)

        x = self.encoder(x, src_mask, embeddings=embeddings, encoding=encoding)

        u_prosody_ref = self.u_norm(
            self.utterance_prosody_encoder(mels=mels, mel_lens=mel_lens),
        )
        u_prosody_pred = self.u_norm(
            self.average_utterance_prosody(
                u_prosody_pred=self.utterance_prosody_predictor(x=x, mask=src_mask),
                src_mask=src_mask,
            ),
        )

        p_prosody_ref = self.p_norm(
            self.phoneme_prosody_encoder(
                x=x,
                src_mask=src_mask,
                mels=mels,
                mel_lens=mel_lens,
                encoding=encoding,
            ),
        )
        p_prosody_pred = self.p_norm(
            self.phoneme_prosody_predictor(
                x=x,
                mask=src_mask,
            ),
        )

        x = x + self.u_bottle_out(u_prosody_pred)
        x = x + self.p_bottle_out(p_prosody_pred)

        # Save the residual for later use
        x_res = x

        attn_logprob, attn_soft, attn_hard, attn_hard_dur = self.aligner(
            enc_in=x_res.permute((0, 2, 1)),
            dec_in=mels,
            enc_len=src_lens,
            dec_len=mel_lens,
            enc_mask=src_mask,
            attn_prior=attn_priors,
        )

        attn_hard_dur = attn_hard_dur.to(src_mask.device)

        x, pitch_prediction, avg_pitch_target = (
            self.pitch_adaptor_conv.add_pitch_embedding_train(
                x=x,
                target=pitches,
                dr=attn_hard_dur,
                mask=src_mask,
            )
        )

        energies = energies.to(src_mask.device)

        x, energy_pred, avg_energy_target = (
            self.energy_adaptor.add_energy_embedding_train(
                x=x,
                target=energies,
                dr=attn_hard_dur,
                mask=src_mask,
            )
        )

        x, log_duration_prediction, embeddings = self.length_regulator.upsample_train(
            x=x,
            x_res=x_res,
            duration_target=attn_hard_dur,
            src_mask=src_mask,
            embeddings=embeddings,
        )

        # Decode the encoder output to pred mel spectrogram
        decoder_output = self.decoder(
            x,
            mel_mask,
            embeddings=embeddings,
            encoding=encoding,
        )

        y_pred = self.to_mel(decoder_output)
        y_pred = y_pred.permute((0, 2, 1))

        return {
            "y_pred": y_pred,
            "pitch_prediction": pitch_prediction,
            "pitch_target": avg_pitch_target,
            "energy_pred": energy_pred,
            "energy_target": avg_energy_target,
            "log_duration_prediction": log_duration_prediction,
            "u_prosody_pred": u_prosody_pred,
            "u_prosody_ref": u_prosody_ref,
            "p_prosody_pred": p_prosody_pred,
            "p_prosody_ref": p_prosody_ref,
            "attn_logprob": attn_logprob,
            "attn_soft": attn_soft,
            "attn_hard": attn_hard,
            "attn_hard_dur": attn_hard_dur,
        }

    def forward(
        self,
        x: torch.Tensor,
        speakers: torch.Tensor,
        langs: torch.Tensor,
        d_control: float = 1.0,
    ) -> torch.Tensor:
        r"""Forward pass during model inference.

        The forward pass receives phoneme sequence, speaker identities, languages, pitch control and
        duration control, conducts a series of operations on these inputs and returns the predicted mel
        spectrogram.

        Args:
            x (torch.Tensor): Tensor of phoneme sequences.
            speakers (torch.Tensor): Tensor of speaker identities.
            langs (torch.Tensor): Tensor of language identities.
            d_control (float): Duration control parameter. Defaults to 1.0.

        Returns:
            torch.Tensor: Predicted mel spectrogram.
        """
        # Generate masks for padding positions in the source sequences
        src_mask = get_mask_from_lengths(
            torch.tensor([x.shape[1]], dtype=torch.int64),
        ).to(x.device)

        # Obtain the embeddings for the input
        x, embeddings = self.get_embeddings(
            token_idx=x,
            speaker_idx=speakers,
            src_mask=src_mask,
            lang_idx=langs,
        )

        # Generate positional encodings
        encoding = positional_encoding(
            self.emb_dim,
            x.shape[1],
        ).to(x.device)

        # Process the embeddings through the encoder
        x = self.encoder(x, src_mask, embeddings=embeddings, encoding=encoding)

        # Predict prosody at utterance level and phoneme level
        u_prosody_pred = self.u_norm(
            self.average_utterance_prosody(
                u_prosody_pred=self.utterance_prosody_predictor(x=x, mask=src_mask),
                src_mask=src_mask,
            ),
        )
        p_prosody_pred = self.p_norm(
            self.phoneme_prosody_predictor(
                x=x,
                mask=src_mask,
            ),
        )

        x = x + self.u_bottle_out(u_prosody_pred)
        x = x + self.p_bottle_out(p_prosody_pred)

        x_res = x

        x, _ = self.pitch_adaptor_conv.add_pitch_embedding(
            x=x,
            mask=src_mask,
        )

        x, _ = self.energy_adaptor.add_energy_embedding(
            x=x,
            mask=src_mask,
        )

        x, _, embeddings = self.length_regulator.upsample(
            x=x,
            x_res=x_res,
            src_mask=src_mask,
            control=d_control,
            embeddings=embeddings,
        )

        mel_mask = get_mask_from_lengths(
            torch.tensor([x.shape[1]], dtype=torch.int64),
        ).to(x.device)

        if x.shape[1] > encoding.shape[1]:
            encoding = positional_encoding(self.emb_dim, x.shape[1]).to(x.device)

        decoder_output = self.decoder(
            x,
            mel_mask,
            embeddings=embeddings,
            encoding=encoding,
        )

        x = self.to_mel(decoder_output)
        x = x.permute((0, 2, 1))

        return x

In [56]:
class BinLoss(Module):
    r"""Binary cross-entropy loss for hard and soft attention.

    Attributes
        None

    Methods
        forward: Computes the binary cross-entropy loss for hard and soft attention.

    """

    def __init__(self):
        super().__init__()

    def forward(
        self, hard_attention: torch.Tensor, soft_attention: torch.Tensor,
    ) -> torch.Tensor:
        r"""Computes the binary cross-entropy loss for hard and soft attention.

        Args:
            hard_attention (torch.Tensor): A binary tensor indicating the hard attention.
            soft_attention (torch.Tensor): A tensor containing the soft attention probabilities.

        Returns:
            torch.Tensor: The binary cross-entropy loss.

        """
        log_sum = torch.log(
            torch.clamp(soft_attention[hard_attention == 1], min=1e-12),
        ).sum()
        return -log_sum / hard_attention.sum()

In [57]:
class ForwardSumLoss(Module):
    r"""Computes the forward sum loss for sequence-to-sequence models with attention.

    Args:
        blank_logprob (float): The log probability of the blank symbol. Default: -1.

    Attributes:
        log_softmax (nn.LogSoftmax): The log softmax function.
        ctc_loss (nn.CTCLoss): The CTC loss function.
        blank_logprob (float): The log probability of the blank symbol.

    Methods:
        forward: Computes the forward sum loss for sequence-to-sequence models with attention.

    """

    def __init__(self, blank_logprob: float = -1):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=3)
        self.ctc_loss = nn.CTCLoss(zero_infinity=True)
        self.blank_logprob = blank_logprob

    def forward(
        self, attn_logprob: torch.Tensor, in_lens: torch.Tensor, out_lens: torch.Tensor,
    ) -> float:
        r"""Computes the forward sum loss for sequence-to-sequence models with attention.

        Args:
            attn_logprob (torch.Tensor): The attention log probabilities of shape (batch_size, max_out_len, max_in_len).
            in_lens (torch.Tensor): The input lengths of shape (batch_size,).
            out_lens (torch.Tensor): The output lengths of shape (batch_size,).

        Returns:
            float: The forward sum loss.

        """
        key_lens = in_lens
        query_lens = out_lens
        attn_logprob_padded = F.pad(
            input=attn_logprob, pad=(1, 0), value=self.blank_logprob,
        )

        total_loss = 0.0
        for bid in range(attn_logprob.shape[0]):
            target_seq = torch.arange(1, int(key_lens[bid]) + 1).unsqueeze(0)
            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[
                : int(query_lens[bid]), :, : int(key_lens[bid]) + 1,
            ]

            curr_logprob = self.log_softmax(curr_logprob[None])[0]
            loss = self.ctc_loss(
                curr_logprob,
                target_seq,
                input_lengths=query_lens[bid : bid + 1],
                target_lengths=key_lens[bid : bid + 1],
            )
            total_loss += loss

        total_loss /= attn_logprob.shape[0]
        return total_loss

In [58]:
def sample_wise_min_max(x: Tensor) -> Tensor:
    r"""Applies sample-wise min-max normalization to a tensor.

    Args:
        x (torch.Tensor): Input tensor of shape (batch_size, num_samples, num_features).

    Returns:
        torch.Tensor: Normalized tensor of the same shape as the input tensor.
    """
    # Compute the maximum and minimum values of each sample in the batch
    maximum = torch.amax(x, dim=(1, 2), keepdim=True)
    minimum = torch.amin(x, dim=(1, 2), keepdim=True)

    # Apply sample-wise min-max normalization to the input tensor
    return (x - minimum) / (maximum - minimum)

In [59]:
from piq import SSIMLoss

class FastSpeech2LossGen(Module):
    def __init__(
        self,
        bin_warmup: bool = True,
        binarization_loss_enable_steps: int = 1260,
        binarization_loss_warmup_steps: int = 700,
    ):
        r"""Initializes the FastSpeech2LossGen module.

        Args:
            bin_warmup (bool, optional): Whether to use binarization warmup. Defaults to True. NOTE: Switch this off if you preload the model with a checkpoint that has already passed the warmup phase.
            binarization_loss_enable_steps (int, optional): Number of steps to enable the binarization loss. Defaults to 1260.
            binarization_loss_warmup_steps (int, optional): Number of warmup steps for the binarization loss. Defaults to 700.
        """
        super().__init__()

        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()
        self.ssim_loss = SSIMLoss()
        self.sum_loss = ForwardSumLoss()
        self.bin_loss = BinLoss()

        self.bin_warmup = bin_warmup
        self.binarization_loss_enable_steps = binarization_loss_enable_steps
        self.binarization_loss_warmup_steps = binarization_loss_warmup_steps

    def forward(
        self,
        src_masks: torch.Tensor,
        mel_masks: torch.Tensor,
        mel_targets: torch.Tensor,
        mel_predictions: torch.Tensor,
        log_duration_predictions: torch.Tensor,
        u_prosody_ref: torch.Tensor,
        u_prosody_pred: torch.Tensor,
        p_prosody_ref: torch.Tensor,
        p_prosody_pred: torch.Tensor,
        durations: torch.Tensor,
        pitch_predictions: torch.Tensor,
        p_targets: torch.Tensor,
        attn_logprob: torch.Tensor,
        attn_soft: torch.Tensor,
        attn_hard: torch.Tensor,
        step: int,
        src_lens: torch.Tensor,
        mel_lens: torch.Tensor,
        energy_pred: torch.Tensor,
        energy_target: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        r"""Computes the loss for the FastSpeech2 model.

        Args:
            src_masks (torch.Tensor): Mask for the source sequence.
            mel_masks (torch.Tensor): Mask for the mel-spectrogram.
            mel_targets (torch.Tensor): Target mel-spectrogram.
            mel_predictions (torch.Tensor): Predicted mel-spectrogram.
            log_duration_predictions (torch.Tensor): Predicted log-duration.
            u_prosody_ref (torch.Tensor): Reference unvoiced prosody.
            u_prosody_pred (torch.Tensor): Predicted unvoiced prosody.
            p_prosody_ref (torch.Tensor): Reference voiced prosody.
            p_prosody_pred (torch.Tensor): Predicted voiced prosody.
            durations (torch.Tensor): Ground-truth durations.
            pitch_predictions (torch.Tensor): Predicted pitch.
            p_targets (torch.Tensor): Ground-truth pitch.
            attn_logprob (torch.Tensor): Log-probability of attention.
            attn_soft (torch.Tensor): Soft attention.
            attn_hard (torch.Tensor): Hard attention.
            step (int): Current training step.
            src_lens (torch.Tensor): Lengths of the source sequences.
            mel_lens (torch.Tensor): Lengths of the mel-spectrograms.
            energy_pred (torch.Tensor): Predicted energy.
            energy_target (torch.Tensor): Ground-truth energy.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: The total loss and its components.

        Note:
            Here is the description of the returned loss components:
            `total_loss`: This is the total loss computed as the sum of all the other losses.
            `mel_loss`: This is the mean absolute error (MAE) loss between the predicted and target mel-spectrograms. It measures how well the model predicts the mel-spectrograms.
            `sc_mag_loss`: This is the spectral convergence loss between the predicted and target mel-spectrograms. It measures how well the model predicts the mel-spectrograms in terms of their spectral structure.
            `log_mag_loss`: This is the log STFT magnitude loss between the predicted and target mel-spectrograms. It measures how well the model predicts the mel-spectrograms in terms of their spectral structure.
            `ssim_loss`: This is the Structural Similarity Index (SSIM) loss between the predicted and target mel-spectrograms. It measures the similarity between the two mel-spectrograms in terms of their structure, contrast, and luminance.
            `duration_loss`: This is the mean squared error (MSE) loss between the predicted and target log-durations. It measures how well the model predicts the durations of the phonemes.
            `u_prosody_loss`: This is the MAE loss between the predicted and reference unvoiced prosody. It measures how well the model predicts the prosody (rhythm, stress, and intonation) of the unvoiced parts of the speech.
            `p_prosody_loss`: This is the MAE loss between the predicted and reference voiced prosody. It measures how well the model predicts the prosody of the voiced parts of the speech.
            `pitch_loss`: This is the MSE loss between the predicted and target pitch. It measures how well the model predicts the pitch of the speech.
            `ctc_loss`: This is the Connectionist Temporal Classification (CTC) loss computed from the log-probability of attention and the lengths of the source sequences and mel-spectrograms. It measures how well the model aligns the input and output sequences.
            `bin_loss`: This is the binarization loss computed from the hard and soft attention. It measures how well the model learns to attend to the correct parts of the input sequence.
            `energy_loss`: This is the MSE loss between the predicted and target energy. It measures how well the model predicts the energy of the speech.
        """
        log_duration_targets = torch.log(durations.float() + 1).to(src_masks.device)

        log_duration_targets.requires_grad = False
        mel_targets.requires_grad = False
        p_targets.requires_grad = False
        energy_target.requires_grad = False

        log_duration_predictions = log_duration_predictions.masked_select(~src_masks)
        log_duration_targets = log_duration_targets.masked_select(~src_masks)

        mel_masks_expanded = mel_masks.unsqueeze(1)

        mel_predictions_normalized = (
            sample_wise_min_max(mel_predictions).float().to(mel_predictions.device)
        )
        mel_targets_normalized = (
            sample_wise_min_max(mel_targets).float().to(mel_predictions.device)
        )

        ssim_loss: torch.Tensor = self.ssim_loss(
            mel_predictions_normalized.unsqueeze(1),
            mel_targets_normalized.unsqueeze(1),
        )

        if ssim_loss.item() > 1.0 or ssim_loss.item() < 0.0:
            ssim_loss = torch.tensor([1.0], device=mel_predictions.device)

        masked_mel_predictions = mel_predictions.masked_select(~mel_masks_expanded)

        masked_mel_targets = mel_targets.masked_select(~mel_masks_expanded)

        mel_loss: torch.Tensor = self.mae_loss(
            masked_mel_predictions,
            masked_mel_targets,
        )

        p_prosody_ref = p_prosody_ref.permute((0, 2, 1))
        p_prosody_pred = p_prosody_pred.permute((0, 2, 1))

        p_prosody_ref = p_prosody_ref.masked_fill(src_masks.unsqueeze(1), 0.0)
        p_prosody_pred = p_prosody_pred.masked_fill(src_masks.unsqueeze(1), 0.0)

        p_prosody_ref = p_prosody_ref.detach()

        p_prosody_loss: torch.Tensor = 0.5 * self.mae_loss(
            p_prosody_ref.masked_select(~src_masks.unsqueeze(1)),
            p_prosody_pred.masked_select(~src_masks.unsqueeze(1)),
        )

        u_prosody_ref = u_prosody_ref.detach()
        u_prosody_loss: torch.Tensor = 0.5 * self.mae_loss(
            u_prosody_ref,
            u_prosody_pred,
        )

        duration_loss: torch.Tensor = self.mse_loss(
            log_duration_predictions,
            log_duration_targets,
        )

        pitch_predictions = pitch_predictions.masked_select(~src_masks)
        p_targets = p_targets.masked_select(~src_masks)

        pitch_loss: torch.Tensor = self.mse_loss(pitch_predictions, p_targets)

        ctc_loss: torch.Tensor = self.sum_loss(
            attn_logprob=attn_logprob,
            in_lens=src_lens,
            out_lens=mel_lens,
        )

        if self.bin_warmup:
            if step < self.binarization_loss_enable_steps:
                bin_loss_weight = 0.0
            else:
                bin_loss_weight = (
                    min(
                        (step - self.binarization_loss_enable_steps)
                        / self.binarization_loss_warmup_steps,
                        1.0,
                    )
                    * 1.0
                )

            bin_loss: torch.Tensor = (
                self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft)
                * bin_loss_weight
            )
        else:
            bin_loss: torch.Tensor = self.bin_loss(
                hard_attention=attn_hard,
                soft_attention=attn_soft,
            )

        energy_loss: torch.Tensor = self.mse_loss(energy_pred, energy_target)

        total_loss = (
            mel_loss
            + duration_loss
            + u_prosody_loss
            + p_prosody_loss
            + ssim_loss
            + pitch_loss
            + ctc_loss
            + bin_loss
            + energy_loss
        )

        return (
            total_loss,
            mel_loss,
            ssim_loss,
            duration_loss,
            u_prosody_loss,
            p_prosody_loss,
            pitch_loss,
            ctc_loss,
            bin_loss,
            energy_loss,
        )

  from .autonotebook import tqdm as notebook_tqdm


In [60]:
# The selected speakers from the HiFiTTS dataset
speakers_hifi_ids = [
    # "Cori Samuel",  # 92,
    # "Tony Oliva",  # 6671,
    # "John Van Stan",  # 9017,
    # "Helen Taylor",  # 9136,
    # "Phil Benson",  # 6097,
    # "Mike Pelton",  # 6670,
    # "Maria Kasper",  # 8051,
    # "Sylviamb",  # 11614,
    # "Celine Major",  # 11697,
    # "LikeManyWaters",  # 12787,
]

# The selected speakers from the LibriTTS dataset
speakers_libri_ids = list(
    map(
        str,
        [
            84
            # train-clean-100
            # 40,
            # 1088,
            # train-clean-360
            # 3307,
            # 5935,
            # train-other-500
            # 215,
            # 6594,
            # 3867,
            # 5733,
            # 5181,
        ],
    ),
)

selected_speakers_ids = {
    v: k
    for k, v in enumerate(
        speakers_hifi_ids + speakers_libri_ids,
    )
}

In [61]:
@dataclass
class PreprocessingConfigHifiGAN(PreprocessingConfig):
    stft: STFTConfig = field(
        default_factory=lambda: STFTConfig(
            filter_length=1024,
            hop_length=256,
            win_length=1024,
            n_mel_channels=80,  # For univnet 100
            mel_fmin=20,
            mel_fmax=11025,
        ),
    )

    def __post_init__(self):
        r"""It modifies the 'stft' attribute based on the 'sampling_rate' attribute.
        If 'sampling_rate' is 44100, 'stft' is set with specific values for this rate.
        If 'sampling_rate' is not 22050 or 44100, a ValueError is raised.

        Raises:
            ValueError: If 'sampling_rate' is not 22050 or 44100.
        """
        if self.sampling_rate == 44100:
            self.stft = STFTConfig(
                filter_length=2048,
                hop_length=512,  # NOTE: 441 ?? https://github.com/jik876/hifi-gan/issues/116#issuecomment-1436999858
                win_length=2048,
                n_mel_channels=80,  # Based on https://github.com/jik876/hifi-gan/issues/116
                mel_fmin=20,
                mel_fmax=11025,
            )
        if self.sampling_rate not in [22050, 44100]:
            raise ValueError("Sampling rate must be 22050 or 44100")


In [62]:
from lhotse import CutSet, RecordingSet, SupervisionSet

def prep_2_cutset(prep: Dict[str, Dict[str, RecordingSet | SupervisionSet]]) -> CutSet:
    r"""Prepare the dataset for the model. This function is used to convert the prepared dataset to a CutSet.

    Args:
        prep (Dict[str, Dict[str, RecordingSet | SupervisionSet]]): The prepared dataset.

    Returns:
        CutSet: The dataset prepared for the model.
    """
    recordings_hifi = RecordingSet()
    supervisions_hifi = SupervisionSet()

    for hifi_row in prep.values():
        record = hifi_row["recordings"]
        supervision = hifi_row["supervisions"]

        # Separate the recordings and supervisions
        if isinstance(record, RecordingSet):
            recordings_hifi += record

        if isinstance(supervision, SupervisionSet):
            supervisions_hifi += supervision

    # Add the recordings and supervisions to the CutSet
    return CutSet.from_manifests(
        recordings=recordings_hifi,
        supervisions=supervisions_hifi,
    )


DATASET_TYPES = Literal["hifitts", "libritts"]


@dataclass
class HifiLibriItem:
    """Dataset row for the HiFiTTS and LibriTTS datasets combined in this code.

    Args:
        id (str): The ID of the item.
        wav (Tensor): The waveform of the audio.
        mel (Tensor): The mel spectrogram.
        pitch (Tensor): The pitch.
        text (Tensor): The text.
        attn_prior (Tensor): The attention prior.
        energy (Tensor): The energy.
        raw_text (str): The raw text.
        normalized_text (str): The normalized text.
        speaker (int): The speaker ID.
        pitch_is_normalized (bool): Whether the pitch is normalized.
        lang (int): The language ID.
        dataset_type (DATASET_TYPES): The type of dataset.
    """

    id: str
    wav: Tensor
    mel: Tensor
    pitch: Tensor
    text: Tensor
    attn_prior: Tensor
    energy: Tensor
    raw_text: str
    normalized_text: str
    speaker: int
    pitch_is_normalized: bool
    lang: int
    dataset_type: DATASET_TYPES


In [63]:
from dataclasses import asdict, dataclass
import os
from pathlib import Path
import tempfile
from typing import Dict, List, Literal, Optional, Tuple
from lhotse.cut import MonoCut
from lhotse.recipes import hifitts, libritts
import numpy as np
import soundfile as sf
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from voicefixer import VoiceFixer

# from models.config import PreprocessingConfigHifiGAN as PreprocessingConfig
# from models.config import get_lang_map, lang2id
# from training.preprocess import PreprocessLibriTTS
# from training.tools import pad_1D, pad_2D, pad_3D

NUM_JOBS = (os.cpu_count() or 2) - 1

class HifiLibriDataset(Dataset):
    r"""A PyTorch dataset for loading delightful TTS data."""

    def __init__(
        self,
        lang: str = "en",
        root: str = "datasets_cache",
        sampling_rate: int = 44100,
        hifitts_path: str = "hifitts",
        hifi_cutset_file_name: str = "hifi.json.gz",
        libritts_path: str = "librittsr",
        libritts_cutset_file_name: str = "libri.json.gz",
        libritts_subsets: List[str] | str = "dev-clean",
        cache: bool = False,
        cache_dir: str = "/dev/shm",
        num_jobs: int = NUM_JOBS,
        min_seconds: Optional[float] = None,
        max_seconds: Optional[float] = None,
        include_libri: bool = True,
        libri_speakers: List[str] = speakers_libri_ids,
        hifi_speakers: List[str] = speakers_hifi_ids,
    ):
        r"""Initializes the dataset.

        Args:
            lang (str, optional): The language of the dataset. Defaults to "en".
            root (str, optional): The root directory of the dataset. Defaults to "datasets_cache".
            sampling_rate (int, optional): The sampling rate of the audio. Defaults to 44100.
            hifitts_path (str, optional): The path to the HiFiTTS dataset. Defaults to "hifitts".
            hifi_cutset_file_name (str, optional): The file name of the HiFiTTS cutset. Defaults to "hifi.json.gz".
            libritts_path (str, optional): The path to the LibriTTS dataset. Defaults to "librittsr".
            libritts_cutset_file_name (str, optional): The file name of the LibriTTS cutset. Defaults to "libri.json.gz".
            libritts_subsets (Union[List[str], str], optional): The subsets of the LibriTTS dataset to use. Defaults to "all".
            cache (bool, optional): Whether to cache the dataset. Defaults to False.
            cache_dir (str, optional): The directory to cache the dataset in. Defaults to "/dev/shm".
            num_jobs (int, optional): The number of jobs to use for preparing the dataset. Defaults to NUM_JOBS.
            min_seconds (Optional[float], optional): The minimum duration of the audio. Defaults from the preprocess config.
            max_seconds (Optional[float], optional): The maximum duration of the audio. Defaults from the preprocess config.
            include_libri (bool, optional): Whether to include the LibriTTS dataset. Defaults to True.
            libri_speakers (List[str], optional): The selected speakers from the LibriTTS dataset. Defaults to selected_speakers_libri_ids.
            hifi_speakers (List[str], optional): The selected speakers from the HiFiTTS dataset. Defaults to selected_speakers_hi_fi_ids.
        """
        lang_map = get_lang_map(lang)
        processing_lang_type = lang_map.processing_lang_type
        self.preprocess_config = PreprocessingConfigHifiGAN(
            processing_lang_type,
            sampling_rate=sampling_rate,
        )

        self.min_seconds = min_seconds or self.preprocess_config.min_seconds
        self.max_seconds = max_seconds or self.preprocess_config.max_seconds

        self.dur_filter = (
            lambda duration: duration >= self.min_seconds
            and duration <= self.max_seconds
        )

        self.preprocess_libtts = PreprocessLibriTTS(
            self.preprocess_config,
            lang,
        )
        self.root_dir = Path(root)
        self.voicefixer = VoiceFixer()

        # Map the speaker ids to string and list of selected speaker ids to set
        self.selected_speakers_libri_ids_ = set(libri_speakers)
        self.selected_speakers_hi_fi_ids_ = set(hifi_speakers)

        self.cache = cache
        self.cache_dir = Path(cache_dir) / f"cache-{libritts_path}"

        # Prepare the HiFiTTS dataset
        self.hifitts_path = self.root_dir / hifitts_path
        hifi_cutset_file_path = self.root_dir / hifi_cutset_file_name

        # Initialize the cutset
        self.cutset = CutSet()

        # Check if the HiFiTTS dataset has been prepared
        # if hifi_cutset_file_path.exists():
        #     self.cutset_hifi = CutSet.from_file(hifi_cutset_file_path)
        # else:
        #     hifitts_root = hifitts.download_hifitts(self.hifitts_path)
        #     prepared_hifi = hifitts.prepare_hifitts(
        #         hifitts_root,
        #         num_jobs=num_jobs,
        #     )

        #     # Add the recordings and supervisions to the CutSet
        #     self.cutset_hifi = prep_2_cutset(prepared_hifi)
        #     # Save the prepared HiFiTTS dataset cutset
        #     self.cutset_hifi.to_file(hifi_cutset_file_path)

        # # Filter the HiFiTTS cutset to only include the selected speakers
        # self.cutset_hifi = self.cutset_hifi.filter(
        #     lambda cut: isinstance(cut, MonoCut)
        #     and str(cut.supervisions[0].speaker) in self.selected_speakers_hi_fi_ids_
        #     and self.dur_filter(cut.duration),
        # ).to_eager()

        # Add the HiFiTTS cutset to the final cutset
        # self.cutset += self.cutset_hifi

        if include_libri:
            # Prepare the LibriTTS dataset
            self.libritts_path = self.root_dir / libritts_path
            libritts_cutset_file_path = self.root_dir / libritts_cutset_file_name

            # Check if the LibriTTS dataset has been prepared
            if libritts_cutset_file_path.exists():
                self.cutset_libri = CutSet.from_file(libritts_cutset_file_path)
            else:
                libritts_root = libritts.download_librittsr(
                    self.libritts_path,
                    dataset_parts=libritts_subsets,
                )
                prepared_libri = libritts.prepare_librittsr(
                    libritts_root / "LibriTTS_R",
                    dataset_parts=libritts_subsets,
                    num_jobs=num_jobs,
                )

                # Add the recordings and supervisions to the CutSet
                self.cutset_libri = prep_2_cutset(prepared_libri)
                # Save the prepared cutset for LibriTTS
                self.cutset_libri.to_file(libritts_cutset_file_path)

            # Filter the libri cutset to only include the selected speakers
            self.cutset_libri = self.cutset_libri.filter(
                lambda cut: isinstance(cut, MonoCut)
                and str(cut.supervisions[0].speaker)
                in self.selected_speakers_libri_ids_
                and self.dur_filter(cut.duration),
            ).to_eager()

            # Add the LibriTTS cutset to the final cutset
            self.cutset += self.cutset_libri

        # to_eager() is used to evaluates all lazy operations on this manifest
        self.cutset = self.cutset.to_eager()

    def get_cache_subdir_path(self, idx: int) -> Path:
        r"""Calculate the path to the cache subdirectory.

        Args:
            idx (int): The index of the cache subdirectory.

        Returns:
            Path: The path to the cache subdirectory.
        """
        return self.cache_dir / str(((idx // 1000) + 1) * 1000)

    def get_cache_file_path(self, idx: int) -> Path:
        r"""Calculate the path to the cache file.

        Args:
            idx (int): The index of the cache file.

        Returns:
            Path: The path to the cache file.
        """
        return self.get_cache_subdir_path(idx) / f"{idx}.pt"

    def __len__(self) -> int:
        r"""Returns the length of the dataset.

        Returns:
            int: The length of the dataset.
        """
        return len(self.cutset)

    def __getitem__(self, idx: int) -> HifiLibriItem:
        r"""Returns the item at the specified index.

        Args:
            idx (int): The index of the item.

        Returns:
            HifiLibriItem: The item at the specified index.
        """
        cache_file = self.get_cache_file_path(idx)

        if self.cache and cache_file.exists():
            cached_data: Dict = torch.load(cache_file)
            # Cast the cached data to the PreprocessForAcousticResult class
            result = HifiLibriItem(**cached_data)
            return result

        print(self.cutset)
        print(idx)
        cutset = self.cutset[idx]

        if isinstance(cutset, MonoCut) and cutset.recording is not None:
            dataset_speaker_id = str(cutset.supervisions[0].speaker)

            # Map the dataset speaker id to the speaker id in the model
            speaker_id = selected_speakers_ids.get(
                dataset_speaker_id,
                len(selected_speakers_ids) + 1,
            )

            # Run voicefixer only for the libri speakers
            # if str(dataset_speaker_id) in self.selected_speakers_libri_ids_:
            #     audio_path = cutset.recording.sources[0].source
            #     # Restore LibriTTS-R audio
            #     with tempfile.NamedTemporaryFile(
            #         suffix=".wav",
            #         delete=True,
            #     ) as out_file:
            #         self.voicefixer.restore(
            #             input=audio_path,  # low quality .wav/.flac file
            #             output=out_file.name,  # save file path
            #             cuda=False,  # GPU acceleration
            #             mode=0,
            #         )
            #         audio, _ = sf.read(out_file.name)
            #         # Convert the np audio to a tensor
            #         audio = torch.from_numpy(audio).float().unsqueeze(0)
            # else:
            #     # Load the audio from the cutset
            #     audio = torch.from_numpy(cutset.load_audio())


            #     detach = audio.data.detach().tolist()       
            #     audio = np.array(detach, dtype=float)
            # # audio
            # audio_np = resample(audio, orig_sr=sr_actual, target_sr=sr)
            # # Convert back to torch tensor
            # audio = torch.tensor(audio_np)

            audio = torch.tensor(np.array(cutset.load_audio(), dtype=float))

            # audio = torch.from_numpy(cutset.load_audio())

            text: str = str(cutset.supervisions[0].text)

            fileid = str(cutset.supervisions[0].recording_id)

            split_fileid = fileid.split("_")
            chapter_id = split_fileid[1]
            utterance_id = split_fileid[-1]

            libri_row = (
                audio,
                cutset.sampling_rate,
                text,
                text,
                speaker_id,
                chapter_id,
                utterance_id,
            )
            
            data = self.preprocess_libtts.acoustic(libri_row)

            if data is None:
                rand_idx = int(
                    torch.randint(
                        0,
                        self.__len__(),
                        (1,),
                    ).item(),
                )
                return self.__getitem__(rand_idx)

            data.wav = data.wav.unsqueeze(0)

            result = HifiLibriItem(
                id=data.utterance_id,
                wav=data.wav,
                mel=data.mel,
                pitch=data.pitch,
                text=data.phones,
                attn_prior=data.attn_prior,
                energy=data.energy,
                raw_text=data.raw_text,
                normalized_text=data.normalized_text,
                speaker=speaker_id,
                pitch_is_normalized=data.pitch_is_normalized,
                lang=lang2id["en"],
                dataset_type="libritts",

                # dataset_type="hifitts" if idx < len(self.cutset_hifi) else "libritts",
            )

            if self.cache:
                # Create the cache subdirectory if it doesn't exist
                Path.mkdir(
                    self.get_cache_subdir_path(idx),
                    parents=True,
                    exist_ok=True,
                )
                # Save the preprocessed data to the cache
                torch.save(asdict(result), cache_file)

            return result
        else:
            raise FileNotFoundError(f"Cut not found at index {idx}.")

    def __iter__(self):
        r"""Method makes the class iterable. It iterates over the `_walker` attribute
        and for each item, it gets the corresponding item from the dataset using the
        `__getitem__` method.

        Yields:
        The item from the dataset corresponding to the current item in `_walker`.
        """
        for item in range(self.__len__()):
            yield self.__getitem__(item)

    def collate_fn(self, data: List[HifiLibriItem]) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List[HifiLibriItem]): A list of data samples.

        Returns:
            List: A list of reprocessed data batches.
        """
        data_size = len(data)

        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(12)]
        (
            ids,
            speakers,
            texts,
            raw_texts,
            mels,
            pitches,
            attn_priors,
            langs,
            src_lens,
            mel_lens,
            wavs,
            energy,
        ) = empty_lists

        # Extract fields from data dictionary and populate the lists
        for idx in idxs:
            data_entry = data[idx]
            ids.append(data_entry.id)
            speakers.append(data_entry.speaker)
            texts.append(data_entry.text)
            raw_texts.append(data_entry.raw_text)
            mels.append(data_entry.mel)
            pitches.append(data_entry.pitch)
            attn_priors.append(data_entry.attn_prior)
            langs.append(data_entry.lang)
            src_lens.append(data_entry.text.shape[0])
            mel_lens.append(data_entry.mel.shape[1])
            wavs.append(data_entry.wav)
            energy.append(data_entry.energy)

        # Convert langs, src_lens, and mel_lens to numpy arrays
        langs = np.array(langs)
        src_lens = np.array(src_lens)
        mel_lens = np.array(mel_lens)

        # NOTE: Instead of the pitches for the whole dataset, used stat for the batch
        # Take only min and max values for pitch
        pitches_stat = list(self.normalize_pitch(pitches)[:2])

        texts = pad_1D(texts)
        mels = pad_2D(mels)
        pitches = pad_1D(pitches)
        attn_priors = pad_3D(attn_priors, len(idxs), max(src_lens), max(mel_lens))

        speakers = np.repeat(
            np.expand_dims(np.array(speakers), axis=1),
            texts.shape[1],
            axis=1,
        )
        langs = np.repeat(
            np.expand_dims(np.array(langs), axis=1),
            texts.shape[1],
            axis=1,
        )

        wavs = pad_2D(wavs)
        energy = pad_2D(energy)

        return [
            ids,
            raw_texts,
            torch.tensor(speakers, dtype=int),
            texts.int(),
            torch.tensor(src_lens, dtype=int),
            mels,
            pitches,
            pitches_stat,
            torch.tensor(mel_lens, dtype=int),
            torch.tensor(langs, dtype=int),
            attn_priors,
            wavs,
            energy,
        ]

    def normalize_pitch(
        self,
        pitches: List[torch.Tensor],
    ) -> Tuple[float, float, float, float]:
        r"""Normalizes the pitch values.

        Args:
            pitches (List[torch.Tensor]): A list of pitch values.

        Returns:
            Tuple: A tuple containing the normalized pitch values.
        """
        pitches_t = torch.concatenate(pitches)

        min_value = torch.min(pitches_t).item()
        max_value = torch.max(pitches_t).item()

        mean = torch.mean(pitches_t).item()
        std = torch.std(pitches_t).item()

        return min_value, max_value, mean, std


In [64]:
def train_dataloader(
    batch_size: int = 6,
    num_workers: int = 5,
    sampling_rate: int = 22050,
    shuffle: bool = False,
    lang: str = "en",
    root: str = "datasets_cache",
    hifitts_path: str = "hifitts",
    hifi_cutset_file_name: str = "hifi.json.gz",
    libritts_path: str = "librittsr",
    libritts_cutset_file_name: str = "libri.json.gz",
    libritts_subsets: List[str] | str = "all",
    cache: bool = False,
    cache_dir: str = "/dev/shm",
    include_libri: bool = True,
    libri_speakers: List[str] = speakers_libri_ids,
    hifi_speakers: List[str] = speakers_hifi_ids,
) -> DataLoader:
    r"""Returns the training dataloader, that is using the HifiLibriDataset dataset.

    Args:
        batch_size (int): The batch size.
        num_workers (int): The number of workers.
        sampling_rate (int): The sampling rate of the audio. Defaults to 22050.
        shuffle (bool): Whether to shuffle the dataset.
        lang (str): The language of the dataset.
        root (str): The root directory of the dataset.
        hifitts_path (str): The path to the HiFiTTS dataset.
        hifi_cutset_file_name (str): The file name of the HiFiTTS cutset.
        libritts_path (str): The path to the LibriTTS dataset.
        libritts_cutset_file_name (str): The file name of the LibriTTS cutset.
        libritts_subsets (List[str] | str): The subsets of the LibriTTS dataset to use.
        cache (bool): Whether to cache the dataset.
        cache_dir (str): The directory to cache the dataset in.
        include_libri (bool): Whether to include the LibriTTS dataset.
        libri_speakers (List[str]): The selected speakers from the LibriTTS dataset.
        hifi_speakers (List[str]): The selected speakers from the HiFiTTS dataset.

    Returns:
        DataLoader: The training dataloader.
    """
    dataset = HifiLibriDataset(
        root=root,
        hifitts_path=hifitts_path,
        sampling_rate=sampling_rate,
        hifi_cutset_file_name=hifi_cutset_file_name,
        libritts_path=libritts_path,
        libritts_cutset_file_name=libritts_cutset_file_name,
        libritts_subsets=libritts_subsets,
        cache=cache,
        cache_dir=cache_dir,
        lang=lang,
        include_libri=include_libri,
        libri_speakers=libri_speakers,
        hifi_speakers=hifi_speakers,
    )

    train_loader = DataLoader(
        dataset,
        # 4x80Gb max 10 sec audio
        # batch_size=20, # self.train_config.batch_size,
        # 4*80Gb max ~20.4 sec audio
        batch_size=batch_size,
        # TODO: find the optimal num_workers
        num_workers=num_workers,
        persistent_workers=True,
        pin_memory=True,
        shuffle=shuffle,
        collate_fn=dataset.collate_fn,
    )

    return train_loader

In [65]:

from typing import List

from lightning.pytorch.core import LightningModule
import torch
from torch import Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader

# from models.config import (
#     AcousticFinetuningConfig,
#     AcousticModelConfigType,
#     AcousticMultilingualModelConfig,
#     AcousticPretrainingConfig,
#     AcousticTrainingConfig,
#     PreprocessingConfig,
#     get_lang_map,
#     lang2id,
# )
# from models.helpers.tools import get_mask_from_lengths
# from training.datasets.hifi_libri_dataset import (
#     speakers_hifi_ids,
#     speakers_libri_ids,
#     train_dataloader,
# )
# from training.loss import FastSpeech2LossGen
# from training.preprocess.normalize_text import NormalizeText

# # Updated version of the tokenizer
# from training.preprocess.tokenizer_ipa_espeak import TokenizerIpaEspeak as TokenizerIPA

# from .acoustic_model import AcousticModel

MEL_SPEC_EVERY_N_STEPS = 1000
AUDIO_EVERY_N_STEPS = 100


class DelightfulTTS(LightningModule):
    r"""Trainer for the acoustic model.

    Args:
        preprocess_config PreprocessingConfig: The preprocessing configuration.
        model_config AcousticModelConfigType: The model configuration.
        fine_tuning (bool, optional): Whether to use fine-tuning mode or not. Defaults to False.
        bin_warmup (bool, optional): Whether to use binarization warmup for the loss or not. Defaults to True.
        lang (str): Language of the dataset.
        n_speakers (int): Number of speakers in the dataset.generation during training.
        batch_size (int): The batch size.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        model_config: AcousticModelConfigType = AcousticENModelConfig(),
        fine_tuning: bool = False,
        bin_warmup: bool = True,
        lang: str = "en",
        n_speakers: int = 5392,
        batch_size: int = 19,
    ):
        super().__init__()

        self.lang = lang
        self.lang_id = lang2id[self.lang]

        self.fine_tuning = fine_tuning
        self.batch_size = batch_size

        lang_map = get_lang_map(lang)
        normilize_text_lang = lang_map.nemo

        self.tokenizer = TokenizerIpaEspeak(lang)
        self.normilize_text = NormalizeText(normilize_text_lang)

        self.train_config_acoustic = AcousticPretrainingConfig()

        self.preprocess_config = preprocess_config

        # TODO: fix the arguments!
        self.acoustic_model = AcousticModel(
            preprocess_config=self.preprocess_config,
            model_config=model_config,
            # NOTE: this parameter may be hyperparameter that you can define based on the demands
            n_speakers=n_speakers,
        )

        # NOTE: in case of training from 0 bin_warmup should be True!
        self.loss_acoustic = FastSpeech2LossGen(
            bin_warmup=bin_warmup,
        )

    def forward(
        self,
        text: str,
        speaker_idx: Tensor,
    ) -> Tensor:
        r"""Performs a forward pass through the AcousticModel.
        This code must be run only with the loaded weights from the checkpoint!

        Args:
            text (str): The input text.
            speaker_idx (Tensor): The index of the speaker

        Returns:
            Tensor: The generated waveform with hifi-gan.
        """
        normalized_text = self.normilize_text(text)
        _, phones = self.tokenizer(normalized_text)

        # Convert to tensor
        x = torch.tensor(
            phones,
            dtype=torch.int,
            device=speaker_idx.device,
        ).unsqueeze(0)

        speakers = speaker_idx.repeat(x.shape[1]).unsqueeze(0)

        langs = (
            torch.tensor(
                [self.lang_id],
                dtype=torch.int,
                device=speaker_idx.device,
            )
            .repeat(x.shape[1])
            .unsqueeze(0)
        )

        mel_pred = self.acoustic_model.forward(
            x=x,
            speakers=speakers,
            langs=langs,
        )

        return mel_pred

    def training_step(self, batch: List, _: int):
        r"""Performs a training step for the model.

        Args:
        batch (List): The batch of data for training. The batch should contain:
            - ids: List of indexes.
            - raw_texts: Raw text inputs.
            - speakers: Speaker identities.
            - texts: Text inputs.
            - src_lens: Lengths of the source sequences.
            - mels: Mel spectrogram targets.
            - pitches: Pitch targets.
            - pitches_stat: Statistics of the pitches.
            - mel_lens: Lengths of the mel spectrograms.
            - langs: Language identities.
            - attn_priors: Prior attention weights.
            - wavs: Waveform targets.
            - energies: Energy targets.
        batch_idx (int): Index of the batch.

        Returns:
            - 'loss': The total loss for the training step.
        """
        (
            _,
            _,
            speakers,
            texts,
            src_lens,
            mels,
            pitches,
            _,
            mel_lens,
            langs,
            attn_priors,
            _,
            energies,
        ) = batch

        outputs = self.acoustic_model.forward_train(
            x=texts,
            speakers=speakers,
            src_lens=src_lens,
            mels=mels,
            mel_lens=mel_lens,
            pitches=pitches,
            langs=langs,
            attn_priors=attn_priors,
            energies=energies,
        )

        y_pred = outputs["y_pred"]
        log_duration_prediction = outputs["log_duration_prediction"]
        p_prosody_ref = outputs["p_prosody_ref"]
        p_prosody_pred = outputs["p_prosody_pred"]
        pitch_prediction = outputs["pitch_prediction"]
        energy_pred = outputs["energy_pred"]
        energy_target = outputs["energy_target"]

        src_mask = get_mask_from_lengths(src_lens)
        mel_mask = get_mask_from_lengths(mel_lens)

        (
            total_loss,
            mel_loss,
            ssim_loss,
            duration_loss,
            u_prosody_loss,
            p_prosody_loss,
            pitch_loss,
            ctc_loss,
            bin_loss,
            energy_loss,
        ) = self.loss_acoustic.forward(
            src_masks=src_mask,
            mel_masks=mel_mask,
            mel_targets=mels,
            mel_predictions=y_pred,
            log_duration_predictions=log_duration_prediction,
            u_prosody_ref=outputs["u_prosody_ref"],
            u_prosody_pred=outputs["u_prosody_pred"],
            p_prosody_ref=p_prosody_ref,
            p_prosody_pred=p_prosody_pred,
            pitch_predictions=pitch_prediction,
            p_targets=outputs["pitch_target"],
            durations=outputs["attn_hard_dur"],
            attn_logprob=outputs["attn_logprob"],
            attn_soft=outputs["attn_soft"],
            attn_hard=outputs["attn_hard"],
            src_lens=src_lens,
            mel_lens=mel_lens,
            energy_pred=energy_pred,
            energy_target=energy_target,
            step=self.trainer.global_step,
        )

        self.log(
            "train_total_loss",
            total_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log("train_mel_loss", mel_loss, sync_dist=True, batch_size=self.batch_size)
        self.log(
            "train_ssim_loss",
            ssim_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "train_duration_loss",
            duration_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "train_u_prosody_loss",
            u_prosody_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "train_p_prosody_loss",
            p_prosody_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log(
            "train_pitch_loss",
            pitch_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )
        self.log("train_ctc_loss", ctc_loss, sync_dist=True, batch_size=self.batch_size)
        self.log("train_bin_loss", bin_loss, sync_dist=True, batch_size=self.batch_size)
        self.log(
            "train_energy_loss",
            energy_loss,
            sync_dist=True,
            batch_size=self.batch_size,
        )

        return total_loss

    def configure_optimizers(self):
        r"""Configures the optimizer used for training.

        Returns
            tuple: A tuple containing three dictionaries. Each dictionary contains the optimizer and learning rate scheduler for one of the models.
        """
        lr_decay = self.train_config_acoustic.optimizer_config.lr_decay
        default_lr = self.train_config_acoustic.optimizer_config.learning_rate

        init_lr = (
            default_lr
            if self.trainer.global_step == 0
            else default_lr * (lr_decay**self.trainer.global_step)
        )

        optimizer_acoustic = AdamW(
            self.acoustic_model.parameters(),
            lr=init_lr,
            betas=self.train_config_acoustic.optimizer_config.betas,
            eps=self.train_config_acoustic.optimizer_config.eps,
            weight_decay=self.train_config_acoustic.optimizer_config.weight_decay,
        )

        scheduler_acoustic = ExponentialLR(optimizer_acoustic, gamma=lr_decay)

        return {
            "optimizer": optimizer_acoustic,
            "lr_scheduler": scheduler_acoustic,
        }

    def train_dataloader(
        self,
        root: str = "datasets_cache",
        cache: bool = True,
        cache_dir: str = "/dev/shm",
        include_libri: bool = False,
        libri_speakers: List[str] = speakers_libri_ids,
        hifi_speakers: List[str] = speakers_hifi_ids,
    ) -> DataLoader:
        r"""Returns the training dataloader, that is using the LibriTTS dataset.

        Args:
            root (str): The root directory of the dataset.
            cache (bool): Whether to cache the preprocessed data.
            cache_dir (str): The directory for the cache. Defaults to "/dev/shm".
            include_libri (bool): Whether to include the LibriTTS dataset or not.
            libri_speakers (List[str]): The list of LibriTTS speakers to include.
            hifi_speakers (List[str]): The list of HiFi-GAN speakers to include.

        Returns:
            Tupple[DataLoader, DataLoader]: The training and validation dataloaders.
        """
        return train_dataloader(
            batch_size=self.batch_size,
            num_workers=self.preprocess_config.workers,
            sampling_rate=self.preprocess_config.sampling_rate,
            root=root,
            cache=cache,
            cache_dir=cache_dir,
            lang=self.lang,
            include_libri=include_libri,
            libri_speakers=libri_speakers,
            hifi_speakers=hifi_speakers,
        )


In [66]:
from voicefixer import Vocoder


cache_dir = "datasets_cache"
dataset = HifiLibriDataset(cache_dir=cache_dir, cache=True)
vocoder_vf = Vocoder(44100)

In [67]:
from lightning.pytorch import Trainer

batch_size = 1

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=dataset.collate_fn,
    # persistent_workers=True,
    # pin_memory=True,
    # num_workers=2,
)

# preprocessing_config =  PreprocessingConfig(
#     "english_only",
#     stft=STFTConfig(
#         filter_length=1024,
#         hop_length=256,
#         win_length=1024,
#         n_mel_channels=100,
#         mel_fmin=20,
#         mel_fmax=11025,
#     )
# )

preprocessing_config =  PreprocessingConfigHifiGAN('multilingual')

dataloader_iterator = iter(dataloader)

# # Now you can fetch batches from it
first_batch = next(dataloader_iterator)
print(first_batch)  # This will print the first batch


default_root_dir = "checkpoints/vcoder"

trainer = Trainer(
    default_root_dir=default_root_dir,
    fast_dev_run=1,
    limit_train_batches=1,
    max_epochs=1,
)

module = DelightfulTTS(preprocess_config=preprocessing_config)
# module.train_dataloader()
# result = trainer.fit(module, train_dataloaders=dataloader)

# print(result)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.


[['000000'], ['"It is a young girl."'], tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0]]), tensor([[ 2, 63, 71, 50,  4, 42, 63, 71, 21, 50, 11, 52, 63, 71, 71, 63, 71, 50,
          4, 42, 63, 71, 50, 38, 20, 63, 71, 71, 63, 71, 50,  7, 42, 63, 71, 71,
         63, 71, 24, 50,  4, 42, 63, 71, 50, 17, 45, 63, 71, 12, 50, 22, 52, 63,
         71, 50, 38, 16, 63, 71,  6, 48, 50, 11, 52, 63, 71, 71, 63, 71,  6, 48,
         50, 11, 52, 63, 71, 50,  4, 42, 63, 71, 50, 35, 52, 40, 63, 71, 50, 38,
         14, 63, 71, 63, 63, 71,  3]], dtype=torch.int32), tensor([97]), tensor([[[-2.1233, -1.5432, -1.8827,  ..., -3.0057, -3.1312, -2.7474],
         [-3.0421, -3.1670, -3.9047,  ..., -3.7185, -3.7156, -3.9338],
 

In [68]:
result = trainer.fit(module, train_dataloaders=dataloader)


  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | acoustic_model | AcousticModel      | 133 M  | train
1 | loss_acoustic  | FastSpeech2LossGen | 0      | train
--------------------------------------------------------------
133 M     Trainable params
0         Non-trainable params
133 M     Total params
534.261   Total estimated model params size (MB)
799       Modules in train mode
0         Modules in eval mode
/Users/user/codec/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 1/1 [00:08<00:00,  0.12it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


Epoch 0: 100%|██████████| 1/1 [00:08<00:00,  0.12it/s]


In [70]:
print(result)

None
