In [1]:
import pandas as pd

In [2]:
speakers_df = pd.read_csv(
    "./datasets_cache/LIBRITTS/LibriTTS/speakers.tsv",
    sep="\t",
    names=["READER", "GENDER", "SUBSET", "NAME"],
)

In [3]:
from dataclasses import dataclass, field
from typing import List, Literal, Tuple, Union

PreprocessLangType = Literal["english_only", "multilingual"]


@dataclass
class STFTConfig:
    filter_length: int
    hop_length: int
    win_length: int
    n_mel_channels: int
    mel_fmin: int
    mel_fmax: int


# Base class used with the Univnet vocoder
@dataclass
class PreprocessingConfig:
    language: PreprocessLangType
    stft: STFTConfig
    sampling_rate: int = 22050
    min_seconds: float = 0.5
    max_seconds: float = 6.0
    use_audio_normalization: bool = True
    workers: int = 8


@dataclass
class PreprocessingConfigUnivNet(PreprocessingConfig):
    stft: STFTConfig = field(
        default_factory=lambda: STFTConfig(
            filter_length=1024,
            hop_length=256,
            win_length=1024,
            n_mel_channels=100,  # univnet
            mel_fmin=20,
            mel_fmax=11025,
        ),
    )

In [4]:
@dataclass
class LangItem:
    r"""A class for storing language information."""

    phonemizer: str
    phonemizer_espeak: str
    nemo: str
    processing_lang_type: PreprocessLangType
    
langs_map: dict[str, LangItem] = {
    "en": LangItem(
        phonemizer="en_us",
        phonemizer_espeak="en-us",
        nemo="en",
        processing_lang_type="english_only",
    ),
}

def get_lang_map(lang: str) -> LangItem:
    r"""Returns a LangItem object for the given language.

    Args:
        lang (str): The language to get the LangItem for.

    Raises:
        ValueError: If the language is not supported.

    Returns:
        LangItem: The LangItem object for the given language.
    """
    if lang not in langs_map:
        raise ValueError(f"Language {lang} is not supported!")
    return langs_map[lang]

In [5]:
import re

# from nemo_text_processing.text_normalization.normalize import Normalizer
from unidecode import unidecode
import torchaudio


class NormalizeText:
    r"""NVIDIA NeMo is a conversational AI toolkit built for researchers working on automatic speech recognition (ASR), text-to-speech synthesis (TTS), large language models (LLMs), and natural language processing (NLP). The primary objective of NeMo is to help researchers from industry and academia to reuse prior work (code and pretrained models) and make it easier to create new conversational AI models.

    This class normalize the characters in the input text and normalize the input text with the `nemo_text_processing`.

    Args:
        lang (str): The language code to use for normalization. Defaults to "en".

    Attributes:
        lang (str): The language code to use for normalization. Defaults to "en".
        model (Normalizer): The `nemo_text_processing` Normalizer model.

    Methods:
        byte_encode(word: str) -> list: Encode a word as a list of bytes.
        normalize_chars(text: str) -> str: Normalize the characters in the input text.
        __call__(text: str) -> str: Normalize the input text with the `nemo_text_processing`.

    Examples:
        >>> from training.preprocess.normilize_text import NormalizeText
        >>> normilize_text = NormalizeText()
        >>> normilize_text("It’s a beautiful day…")
        "It's a beautiful day."
    """

    def __init__(self, lang: str = "en"):
        r"""Initialize a new instance of the NormalizeText class.

        Args:
            lang (str): The language code to use for normalization. Defaults to "en".

        """
        self.lang = lang
        self.processor = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor()
        # self.model = Normalizer(input_case="cased", lang=lang)

    def byte_encode(self, word: str) -> list:
        r"""Encode a word as a list of bytes.

        Args:
            word (str): The word to encode.

        Returns:
            list: A list of bytes representing the encoded word.

        """
        text = word.strip()
        return list(text.encode("utf-8"))

    def normalize_chars(self, text: str) -> str:
        r"""Normalize the characters in the input text.

        Args:
            text (str): The input text to normalize.

        Returns:
            str: The normalized text.

        Examples:
            >>> normalize_chars("It’s a beautiful day…")
            "It's a beautiful day."

        """
        # Define the character mapping
        char_mapping = {
            ord("’"): ord("'"),
            ord("”"): ord("'"),
            ord("…"): ord("."),
            ord("„"): ord("'"),
            ord("“"): ord("'"),
            ord('"'): ord("'"),
            ord("–"): ord("-"),
            ord("—"): ord("-"),
            ord("«"): ord("'"),
            ord("»"): ord("'"),
        }

        # Add unicode normalization as additional garanty and normalize the characters using translate() method
        normalized_string = unidecode(text).translate(char_mapping)

        # Remove redundant multiple characters
        # TODO: Maybe there is some effect on duplication?
        return re.sub(r"(\.|\!|\?|\-)\1+", r"\1", normalized_string)

    def __call__(self, text: str) -> str:
        r"""Normalize the input text with the `nemo_text_processing`.

        Args:
            text (str): The input text to normalize.

        Returns:
            str: The normalized text.

        """
        text = self.normalize_chars(text)
        # return self.model.normalize(text)

        # Split the text into lines
        # lines = text.split("\n")
        processed, lengths = self.processor(text)
        normalized_lines = [self.processor.tokens[i] for i in processed[0, : lengths[0]]]
        # normalized_lines = self.model.normalize_list(lines)

        # TODO: check this!
        # Join the normalized lines, replace \n with . and return the result
        result = ". ".join(normalized_lines)
        return result


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/user/codec/.venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/user/codec/.venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/user/codec/.venv/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/Users/u

In [6]:
# NOTE: for the backward comp.
# Prepare the phonemes list and dictionary for the embedding
phoneme_basic_symbols = [
    # IPA symbols
    "a",
    "b",
    "d",
    "e",
    "f",
    "g",
    "h",
    "i",
    "j",
    "k",
    "l",
    "m",
    "n",
    "o",
    "p",
    "r",
    "s",
    "t",
    "u",
    "v",
    "w",
    "x",
    "y",
    "z",
    "æ",
    "ç",
    "ð",
    "ø",
    "ŋ",
    "œ",
    "ɐ",
    "ɑ",
    "ɔ",
    "ə",
    "ɛ",
    "ɝ",
    "ɹ",
    "ɡ",
    "ɪ",
    "ʁ",
    "ʃ",
    "ʊ",
    "ʌ",
    "ʏ",
    "ʒ",
    "ʔ",
    "ˈ",
    "ˌ",
    "ː",
    "̃",
    "̍",
    "̥",
    "̩",
    "̯",
    "͡",
    "θ",
    # Punctuation
    "!",
    "?",
    ",",
    ".",
    "-",
    ":",
    ";",
    '"',
    "'",
    "(",
    ")",
    " ",
]

# TODO: add support for other languages
# _letters_accented = "µßàáâäåæçèéêëìíîïñòóôöùúûüąćęłńœśşźżƒ"
# _letters_cyrilic = "абвгдежзийклмнопрстуфхцчшщъыьэюяёєіїґӧ"
# _pad = "$"

# This is the list of symbols from StyledTTS2
_punctuation = ';:,.!?¡¿—…"«»“”'
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

# Combine all symbols
symbols = list(_punctuation) + list(_letters) + list(_letters_ipa)

# Add only unique symbols
phones = phoneme_basic_symbols + [
    symbol for symbol in symbols if symbol not in phoneme_basic_symbols
]

# TODO: Need to understand how to replace this
# len(phones) == 184, leave it as is at this point
symbols = [str(el) for el in range(256)]
symbol2id = {s: i for i, s in enumerate(symbols)}
id2symbol = {i: s for i, s in enumerate(symbols)}

In [7]:
from logging import ERROR, Logger
import os

from phonemizer.backend import EspeakBackend

# IPA Phonemizer: https://github.com/bootphon/phonemizer
from phonemizer.backend.espeak.wrapper import EspeakWrapper

# Create a Logger instance
logger = Logger("my_logger")
# Set the level to ERROR
logger.setLevel(ERROR)

from dp.preprocessing.text import SequenceTokenizer

# from models.config import get_lang_map
# from models.config.symbols import phones

# INFO: Fix for windows, used for local env
if os.name == "nt":
    ESPEAK_LIBRARY = os.getenv(
        "ESPEAK_LIBRARY",
        "C:\\Program Files\\eSpeak NG\\libespeak-ng.dll",
    )
    EspeakWrapper.set_library(ESPEAK_LIBRARY)


class TokenizerIpaEspeak:
    def __init__(self, lang: str = "en"):
        lang_map = get_lang_map(lang)
        self.lang = lang_map.phonemizer_espeak
        self.lang_seq = lang_map.phonemizer

        # NOTE: for backward compatibility with previous IPA tokenizer see the TokenizerIPA class
        self.tokenizer = SequenceTokenizer(
            phones,
            languages=["de", "en_us"],
            lowercase=True,
            char_repeats=1,
            append_start_end=True,
        )

        self.phonemizer = EspeakBackend(
            language=self.lang,
            preserve_punctuation=True,
            with_stress=True,
            words_mismatch="ignore",
            logger=logger,
        ).phonemize

    def __call__(self, text: str):
        r"""Converts the input text to phonemes and tokenizes them.

        Args:
            text (str): The input text to be tokenized.

        Returns:
            Tuple[Union[str, List[str]], List[int]]: IPA phonemes and tokens.

        """
        phones_ipa = "".join(self.phonemizer([text]))

        tokens = self.tokenizer(phones_ipa, language=self.lang_seq)

        return phones_ipa, tokens

In [8]:
@dataclass
class VocoderBasicConfig:
    segment_size: int = 16384
    learning_rate: float = 0.0001
    adam_b1: float = 0.5
    adam_b2: float = 0.9
    lr_decay: float = 0.995
    synth_interval: int = 250
    checkpoint_interval: int = 250
    stft_lamb: float = 2.5

In [17]:
from typing import Optional, Tuple

import librosa
import torch
from torch.nn import Module


class TacotronSTFT(Module):
    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        center: bool,
        mel_fmax: Optional[int],
        mel_fmin: float = 0.0,
    ):
        r"""TacotronSTFT module that computes mel-spectrograms from a batch of waves.

        Args:
            filter_length (int): Length of the filter window.
            hop_length (int): Number of samples between successive frames.
            win_length (int): Size of the STFT window.
            n_mel_channels (int): Number of mel bins.
            sampling_rate (int): Sampling rate of the input waveforms.
            mel_fmin (int or None): Minimum frequency for the mel filter bank.
            mel_fmax (int or None): Maximum frequency for the mel filter bank.
            center (bool): Whether to pad the input signal on both sides.
        """
        super().__init__()

        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.n_fft = filter_length
        self.hop_size = hop_length
        self.win_size = win_length
        self.fmin = mel_fmin
        self.fmax = mel_fmax
        self.center = center

        # Define the mel filterbank
        mel = librosa.filters.mel(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )

        mel_basis = torch.tensor(mel, dtype=float).float()

        # Define the Hann window
        hann_window = torch.hann_window(win_length)

        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("hann_window", hann_window)

    def _spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the linear spectrogram of a batch of waves.

        Args:
            y (torch.Tensor): Input waveforms.

        Returns:
            torch.Tensor: Linear spectrogram.
        """
        assert torch.min(y.data) >= -1
        assert torch.max(y.data) <= 1

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                int((self.n_fft - self.hop_size) / 2),
                int((self.n_fft - self.hop_size) / 2),
            ),
            mode="reflect",
        )
        y = y.squeeze(1)
        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,  # type: ignore
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        return torch.view_as_real(spec)

    def linear_spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the linear spectrogram of a batch of waves.

        Args:
            y (torch.Tensor): Input waveforms.

        Returns:
            torch.Tensor: Linear spectrogram.
        """
        spec = self._spectrogram(y)
        return torch.norm(spec, p=2, dim=-1)

    def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Computes mel-spectrograms from a batch of waves.

        Args:
            y (torch.FloatTensor): Input waveforms with shape (B, T) in range [-1, 1]

        Returns:
            torch.FloatTensor: Spectrogram of shape (B, n_spech_channels, T)
            torch.FloatTensor: Mel-spectrogram of shape (B, n_mel_channels, T)
        """
        spec = self._spectrogram(y)

        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        mel = torch.matmul(self.mel_basis, spec)  # type: ignore
        mel = self.spectral_normalize_torch(mel)

        return spec, mel

    def spectral_normalize_torch(self, magnitudes: torch.Tensor) -> torch.Tensor:
        r"""Applies dynamic range compression to magnitudes.

        Args:
            magnitudes (torch.Tensor): Input magnitudes.

        Returns:
            torch.Tensor: Output magnitudes.
        """
        return self.dynamic_range_compression_torch(magnitudes)

    def dynamic_range_compression_torch(
        self,
        x: torch.Tensor,
        C: int = 1,
        clip_val: float = 1e-5,
    ) -> torch.Tensor:
        r"""Applies dynamic range compression to x.

        Args:
            x (torch.Tensor): Input tensor.
            C (float): Compression factor.
            clip_val (float): Clipping value.

        Returns:
            torch.Tensor: Output tensor.
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)

    # NOTE: audio np.ndarray changed to torch.FloatTensor!
    def get_mel_from_wav(self, audio: torch.Tensor) -> torch.Tensor:
        audio_tensor = audio.unsqueeze(0)
        with torch.no_grad():
            _, melspec = self.forward(audio_tensor)
        return melspec.squeeze(0)

In [10]:
from librosa.filters import mel as librosa_mel_fn
import torch


class AudioProcessor:
    r"""A class used to process audio signals and convert them into different representations.

    Attributes:
        hann_window (dict): A dictionary to store the Hann window for different configurations.
        mel_basis (dict): A dictionary to store the Mel basis for different configurations.

    Methods:
        name_mel_basis(spec, n_fft, fmax): Generate a name for the Mel basis based on the FFT size, maximum frequency, data type, and device.
        amp_to_db(magnitudes, C=1, clip_val=1e-5): Convert amplitude to decibels (dB).
        db_to_amp(magnitudes, C=1): Convert decibels (dB) to amplitude.
        wav_to_spec(y, n_fft, hop_length, win_length, center=False): Convert a waveform to a spectrogram and compute the magnitude.
        wav_to_energy(y, n_fft, hop_length, win_length, center=False): Convert a waveform to a spectrogram and compute the energy.
        spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): Convert a spectrogram to a Mel spectrogram.
        wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): Convert a waveform to a Mel spectrogram.
    """

    def __init__(self):
        self.hann_window = {}
        self.mel_basis = {}

    @staticmethod
    def name_mel_basis(spec: torch.Tensor, n_fft: int, fmax: int) -> str:
        """Generate a name for the Mel basis based on the FFT size, maximum frequency, data type, and device.

        Args:
            spec (torch.Tensor): The spectrogram tensor.
            n_fft (int): The FFT size.
            fmax (int): The maximum frequency.

        Returns:
            str: The generated name for the Mel basis.
        """
        n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
        return n_fft_len

    @staticmethod
    def amp_to_db(magnitudes: torch.Tensor, C: int = 1, clip_val: float = 1e-5) -> torch.Tensor:
        r"""Convert amplitude to decibels (dB).

        Args:
            magnitudes (Tensor): The amplitude magnitudes to convert.
            C (int, optional): A constant value used in the conversion. Defaults to 1.
            clip_val (float, optional): A value to clamp the magnitudes to avoid taking the log of zero. Defaults to 1e-5.

        Returns:
            Tensor: The converted magnitudes in dB.
        """
        return torch.log(torch.clamp(magnitudes, min=clip_val) * C)

    @staticmethod
    def db_to_amp(magnitudes: torch.Tensor, C: int = 1) -> torch.Tensor:
        r"""Convert decibels (dB) to amplitude.

        Args:
            magnitudes (Tensor): The dB magnitudes to convert.
            C (int, optional): A constant value used in the conversion. Defaults to 1.

        Returns:
            Tensor: The converted magnitudes in amplitude.
        """
        return torch.exp(magnitudes) / C

    def wav_to_spec(
        self,
        y: torch.Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        center: bool = False,
    ) -> torch.Tensor:
        r"""Convert a waveform to a spectrogram and compute the magnitude.

        Args:
            y (Tensor): The input waveform.
            n_fft (int): The FFT size.
            hop_length (int): The hop (stride) size.
            win_length (int): The window size.
            center (bool, optional): Whether to pad `y` such that frames are centered. Defaults to False.

        Returns:
            Tensor: The magnitude of the computed spectrogram.
        """
        y = y.squeeze(1)

        dtype_device = str(y.dtype) + "_" + str(y.device)
        wnsize_dtype_device = str(win_length) + "_" + dtype_device
        if wnsize_dtype_device not in self.hann_window:
            self.hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
            mode="reflect",
        )
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=self.hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        spec = torch.view_as_real(spec)

        # Compute the magnitude
        spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

        return spec

    def wav_to_energy(
        self,
        y: torch.Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        center: bool = False,
    ) -> torch.Tensor:
        r"""Convert a waveform to a spectrogram and compute the energy.

        Args:
            y (Tensor): The input waveform.
            n_fft (int): The FFT size.
            hop_length (int): The hop (stride) size.
            win_length (int): The window size.
            center (bool, optional): Whether to pad `y` such that frames are centered. Defaults to False.

        Returns:
            Tensor: The energy of the computed spectrogram.
        """
        spec = self.wav_to_spec(y, n_fft, hop_length, win_length, center=center)
        spec = torch.norm(spec, dim=1, keepdim=True).squeeze(0)

        # Normalize the energy
        return (spec - spec.mean()) / spec.std()

    def spec_to_mel(
            self,
            spec: torch.Tensor,
            n_fft: int,
            num_mels: int,
            sample_rate: int,
            fmin: int,
            fmax: int,
    ) -> torch.Tensor:
        r"""Convert a spectrogram to a Mel spectrogram.

        Args:
            spec (torch.Tensor): The input spectrogram of shape [B, C, T].
            n_fft (int): The FFT size.
            num_mels (int): The number of Mel bands.
            sample_rate (int): The sample rate of the audio.
            fmin (int): The minimum frequency.
            fmax (int): The maximum frequency.

        Returns:
            torch.Tensor: The computed Mel spectrogram of shape [B, C, T].
        """
        mel_basis_key = self.name_mel_basis(spec, n_fft, fmax)

        if mel_basis_key not in self.mel_basis:
            mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
            self.mel_basis[mel_basis_key] = torch.tensor(mel).to(dtype=spec.dtype, device=spec.device)

        mel = torch.matmul(self.mel_basis[mel_basis_key], spec)
        mel = self.amp_to_db(mel)

        return mel

    def wav_to_mel(
        self,
        y: torch.Tensor,
        n_fft: int,
        num_mels: int,
        sample_rate: int,
        hop_length: int,
        win_length: int,
        fmin: int,
        fmax: int,
        center: bool = False,
    ) -> torch.Tensor:
        r"""Convert a waveform to a Mel spectrogram.

        Args:
            y (torch.Tensor): The input waveform.
            n_fft (int): The FFT size.
            num_mels (int): The number of Mel bands.
            sample_rate (int): The sample rate of the audio.
            hop_length (int): The hop (stride) size.
            win_length (int): The window size.
            fmin (int): The minimum frequency.
            fmax (int): The maximum frequency.
            center (bool, optional): Whether to pad `y` such that frames are centered. Defaults to False.

        Returns:
            torch.Tensor: The computed Mel spectrogram.
        """
        # Convert the waveform to a spectrogram
        spec = self.wav_to_spec(y, n_fft, hop_length, win_length, center=center)

        # Convert the spectrogram to a Mel spectrogram
        mel = self.spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax)

        return mel

In [49]:
import sys
from typing import Tuple, Union

import librosa
import numpy as np
import torch
import torchaudio


def stereo_to_mono(audio: torch.Tensor) -> torch.Tensor:
    r"""Converts a stereo audio tensor to mono by taking the mean across channels.

    Args:
        audio (torch.Tensor): Input audio tensor of shape (channels, samples).

    Returns:
        torch.Tensor: Mono audio tensor of shape (1, samples).
    """
    return torch.mean(audio, 0, True)


def resample(wav: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    r"""Resamples an audio waveform from the original sampling rate to the target sampling rate.

    Args:
        wav (np.ndarray): The audio waveform to be resampled.
        orig_sr (int): The original sampling rate of the audio waveform.
        target_sr (int): The target sampling rate to resample the audio waveform to.

    Returns:
        np.ndarray: The resampled audio waveform.
    """
    return librosa.resample(wav, orig_sr=orig_sr, target_sr=target_sr)


def safe_load(path: str, sr: Union[int, None]) -> Tuple[np.ndarray, int]:
    r"""Load an audio file from disk and return its content as a numpy array.

    Args:
        path (str): The path to the audio file.
        sr (int or None): The target sampling rate. If None, the original sampling rate is used.

    Returns:
        Tuple[np.ndarray, int]: A tuple containing the audio content as a numpy array and the actual sampling rate.
    """
    try:
        audio, sr_actual = torchaudio.load(path) # type: ignore
        if audio.shape[0] > 0:
            audio = stereo_to_mono(audio)
        audio = audio.squeeze(0)
        if sr_actual != sr and sr is not None:
            audio = resample(audio.numpy(), orig_sr=sr_actual, target_sr=sr)
            sr_actual = sr
        else:
            audio = audio.numpy()
    except Exception as e:
        raise type(e)(
            f"The following error happened loading the file {path} ... \n" + str(e),
        ).with_traceback(sys.exc_info()[2])

    return audio, sr_actual


def preprocess_audio(
    audio: torch.Tensor, sr_actual: int, sr: Union[int, None],
) -> Tuple[torch.Tensor, int]:
    r"""Preprocesses audio by converting stereo to mono, resampling if necessary, and returning the audio tensor and sample rate.

    Args:
        audio (torch.Tensor): The audio tensor to preprocess.
        sr_actual (int): The actual sample rate of the audio.
        sr (Union[int, None]): The target sample rate to resample the audio to, if necessary.

    Returns:
        Tuple[torch.Tensor, int]: The preprocessed audio tensor and sample rate.
    """
    try:
        if audio.shape[0] > 0:
            audio = stereo_to_mono(audio)
        audio = audio.squeeze(0)
        if sr_actual != sr and sr is not None:
            if isinstance(audio, torch.Tensor):
                detach = audio.data.detach().tolist()
                
                audio = np.array(detach, dtype=float)

            # audio
            audio_np = resample(audio, orig_sr=sr_actual, target_sr=sr)
            # Convert back to torch tensor
            audio = torch.tensor(audio_np)
            sr_actual = sr
    except Exception as e:
        raise type(e)(
            f"The following error happened while processing the audio ... \n {e!s}",
        ).with_traceback(sys.exc_info()[2])

    return audio, sr_actual


def normalize_loudness(wav: torch.Tensor) -> torch.Tensor:
    r"""Normalize the loudness of an audio waveform.

    Args:
        wav (torch.Tensor): The input waveform.

    Returns:
        torch.Tensor: The normalized waveform.

    Examples:
        >>> wav = np.array([1.0, 2.0, 3.0])
        >>> normalize_loudness(wav)
        tensor([0.33333333, 0.66666667, 1.  ])
    """
    return wav / torch.max(torch.abs(wav))

In [12]:
from typing import List, Tuple

import numpy as np
import torch
import torch.nn.functional as F

# TODO: LOOK AT THIS ESTIMATION ALGO
#######################################################################################
# Original implementation from https://github.com/NVIDIA/mellotron/blob/master/yin.py #
#######################################################################################


def differenceFunction(x: np.ndarray, N: int, tau_max: int) -> np.ndarray:
    r"""Compute the difference function of an audio signal.

    This function computes the difference function of an audio signal `x` using the algorithm described in equation (6) of [1]. The difference function is a measure of the similarity between the signal and a time-shifted version of itself, and is commonly used in pitch detection algorithms.

    This implementation uses the NumPy FFT functions to compute the difference function efficiently.

    Parameters
        x (np.ndarray): The audio signal to compute the difference function for.
        N (int): The length of the audio signal.
        tau_max (int): The maximum integration window size to use.

    Returns
        np.ndarray: The difference function of the audio signal.

    References
        [1] A. de Cheveigné and H. Kawahara, "YIN, a fundamental frequency estimator for speech and music," The Journal of the Acoustical Society of America, vol. 111, no. 4, pp. 1917-1930, 2002.
    """
    x = np.array(x, np.float64)
    w = x.size
    tau_max = min(tau_max, w)
    x_cumsum = np.concatenate((np.array([0.0]), (x * x).cumsum()))
    size = w + tau_max
    p2 = (size // 32).bit_length()
    nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
    size_pad = min(x * 2**p2 for x in nice_numbers if x * 2**p2 >= size)
    fc = np.fft.rfft(x, size_pad)
    conv = np.fft.irfft(fc * fc.conjugate())[:tau_max]
    return x_cumsum[w : w - tau_max : -1] + x_cumsum[w] - x_cumsum[:tau_max] - 2 * conv


def cumulativeMeanNormalizedDifferenceFunction(df: np.ndarray, N: int) -> np.ndarray:
    r"""Compute the cumulative mean normalized difference function (CMND) of a difference function.

    The CMND is defined as the element-wise product of the difference function with a range of values from 1 to N-1,
    divided by the cumulative sum of the difference function up to that point, plus a small epsilon value to avoid
    division by zero. The first element of the CMND is set to 1.

    Args:
        df (np.ndarray): The difference function.
        N (int): The length of the data.

    Returns:
        np.ndarray: The cumulative mean normalized difference function.

    References:
        [1] K. K. Paliwal and R. P. Sharma, "A robust algorithm for pitch detection in noisy speech signals,"
            Speech Communication, vol. 12, no. 3, pp. 249-263, 1993.
    """
    cmndf = (
        df[1:] * range(1, N) / (np.cumsum(df[1:]).astype(float) + 1e-8)
    )  # scipy method
    return np.insert(cmndf, 0, 1)


def getPitch(cmdf: np.ndarray, tau_min: int, tau_max: int, harmo_th: float=0.1) -> int:
    r"""Compute the fundamental period of a frame based on the Cumulative Mean Normalized Difference function (CMND).

    The CMND is a measure of the periodicity of a signal, and is computed as the cumulative mean normalized difference
    function of the difference function of the signal. The fundamental period is the first value of the index `tau`
    between `tau_min` and `tau_max` where the CMND is below the `harmo_th` threshold. If there are no such values, the
    function returns 0 to indicate that the signal is unvoiced.

    Args:
        cmdf (np.ndarray): The Cumulative Mean Normalized Difference function of the signal.
        tau_min (int): The minimum period for speech.
        tau_max (int): The maximum period for speech.
        harmo_th (float, optional): The harmonicity threshold to determine if it is necessary to compute pitch
            frequency. Defaults to 0.1.

    Returns:
        int: The fundamental period of the signal, or 0 if the signal is unvoiced.

    References:
        [1] K. K. Paliwal and R. P. Sharma, "A robust algorithm for pitch detection in noisy speech signals,"
            Speech Communication, vol. 12, no. 3, pp. 249-263, 1993.
    """
    tau = tau_min
    while tau < tau_max:
        if cmdf[tau] < harmo_th:
            while tau + 1 < tau_max and cmdf[tau + 1] < cmdf[tau]:
                tau += 1
            return tau
        tau += 1

    return 0  # if unvoiced


def compute_yin(
    sig_torch: torch.Tensor,
    sr: int,
    w_len: int = 512,
    w_step: int = 256,
    f0_min: int = 100,
    f0_max: int = 500,
    harmo_thresh: float = 0.1,
) -> Tuple[np.ndarray, List[float], List[float], List[float]]:
    r"""Compute the Yin Algorithm for pitch detection on an audio signal.

    The Yin Algorithm is a widely used method for pitch detection in speech and music signals. It works by computing the
    Cumulative Mean Normalized Difference function (CMND) of the difference function of the signal, and finding the first
    minimum of the CMND below a given threshold. The fundamental period of the signal is then estimated as the inverse of
    the lag corresponding to this minimum.

    Args:
        sig_torch (torch.Tensor): The audio signal as a 1D numpy array of floats.
        sr (int): The sampling rate of the signal.
        w_len (int, optional): The size of the analysis window in samples. Defaults to 512.
        w_step (int, optional): The size of the lag between two consecutive windows in samples. Defaults to 256.
        f0_min (int, optional): The minimum fundamental frequency that can be detected in Hz. Defaults to 100.
        f0_max (int, optional): The maximum fundamental frequency that can be detected in Hz. Defaults to 500.
        harmo_thresh (float, optional): The threshold of detection. The algorithm returns the first minimum of the CMND
            function below this threshold. Defaults to 0.1.

    Returns:
        Tuple[np.ndarray, List[float], List[float], List[float]]: A tuple containing the following elements:
            * pitches (np.ndarray): A 1D numpy array of fundamental frequencies estimated for each analysis window.
            * harmonic_rates (List[float]): A list of harmonic rate values for each fundamental frequency value, which
              can be interpreted as a confidence value.
            * argmins (List[float]): A list of the minimums of the Cumulative Mean Normalized Difference Function for
              each analysis window.
            * times (List[float]): A list of the time of each estimation, in seconds.

    References:
        [1] A. K. Jain, Fundamentals of Digital Image Processing, Prentice Hall, 1989.
        [2] A. de Cheveigné and H. Kawahara, "YIN, a fundamental frequency estimator for speech and music," The Journal
            of the Acoustical Society of America, vol. 111, no. 4, pp. 1917-1930, 2002.
    """
    sig_torch = sig_torch.view(1, 1, -1)
    sig_torch = F.pad(
        sig_torch.unsqueeze(1),
        (int((w_len - w_step) / 2), int((w_len - w_step) / 2), 0, 0),
        mode="reflect",
    )
    sig_torch_n: np.ndarray = sig_torch.view(-1).numpy()

    tau_min = int(sr / f0_max)
    tau_max = int(sr / f0_min)

    timeScale = range(
        0, len(sig_torch_n) - w_len, w_step,
    )  # time values for each analysis window
    times = [t / float(sr) for t in timeScale]
    frames = [sig_torch_n[t : t + w_len] for t in timeScale]

    pitches = [0.0] * len(timeScale)
    harmonic_rates = [0.0] * len(timeScale)
    argmins = [0.0] * len(timeScale)

    for i, frame in enumerate(frames):
        # Compute YIN
        df = differenceFunction(frame, w_len, tau_max)
        cmdf = cumulativeMeanNormalizedDifferenceFunction(df, tau_max)
        p = getPitch(cmdf, tau_min, tau_max, harmo_thresh)

        # Get results
        if np.argmin(cmdf) > tau_min:
            argmins[i] = float(sr / np.argmin(cmdf))
        if p != 0:  # A pitch was found
            pitches[i] = float(sr / p)
            harmonic_rates[i] = cmdf[p]
        else:  # No pitch, but we compute a value of the harmonic rate
            harmonic_rates[i] = min(cmdf)

    return np.array(pitches), harmonic_rates, argmins, times


def norm_interp_f0(f0: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    r"""Normalize and interpolate the fundamental frequency (f0) values.

    Args:
        f0 (np.ndarray): The input f0 values.

    Returns:
        Tuple[np.ndarray, np.ndarray]: A tuple containing the normalized f0 values and a boolean array indicating which values were interpolated.

    Examples:
        >>> f0 = np.array([0, 100, 0, 200, 0])
        >>> norm_interp_f0(f0)
        (
            np.array([100, 100, 150, 200, 200]),
            np.array([True, False, True, False, True]),
        )
    """
    uv: np.ndarray = f0 == 0
    if sum(uv) == len(f0):
        f0[uv] = 0
    elif sum(uv) > 0:
        f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
    return f0, uv


def compute_pitch(
    sig_torch: torch.Tensor,
    sr: int,
    w_len: int = 1024,
    w_step: int = 256,
    f0_min: int = 50,
    f0_max: int = 1000,
    harmo_thresh: float = 0.25,
):
    r"""Compute the pitch of an audio signal using the Yin algorithm.

    The Yin algorithm is a widely used method for pitch detection in speech and music signals. This function uses the
    Yin algorithm to compute the pitch of the input audio signal, and then normalizes and interpolates the pitch values.
    Returns the normalized and interpolated pitch values.

    Args:
        sig_torch (torch.Tensor): The audio signal as a 1D numpy array of floats.
        sr (int): The sampling rate of the signal.
        w_len (int, optional): The size of the analysis window in samples.
        w_step (int, optional): The size of the lag between two consecutive windows in samples.
        f0_min (int, optional): The minimum fundamental frequency that can be detected in Hz.
        f0_max (int, optional): The maximum fundamental frequency that can be detected in Hz.
        harmo_thresh (float, optional): The threshold of detection. The algorithm returns the first minimum of the CMND function below this threshold.

    Returns:
        np.ndarray: The normalized and interpolated pitch values of the audio signal.
    """
    pitch, _, _, _ = compute_yin(
        sig_torch,
        sr=sr,
        w_len=w_len,
        w_step=w_step,
        f0_min=f0_min,
        f0_max=f0_max,
        harmo_thresh=harmo_thresh,
    )

    pitch, _ = norm_interp_f0(pitch)

    return pitch

In [13]:
from dataclasses import dataclass
import math
import random
from typing import Any, List, Tuple, Union

import numpy as np
from scipy.stats import betabinom
import torch
import torch.nn.functional as F

# from models.config import PreprocessingConfig, VocoderBasicConfig, get_lang_map

# from .audio import normalize_loudness, preprocess_audio
# from .audio_processor import AudioProcessor
# from .compute_yin import compute_yin, norm_interp_f0
# from .normalize_text import NormalizeText
# from .tacotron_stft import TacotronSTFT
# from .tokenizer_ipa_espeak import TokenizerIpaEspeak as TokenizerIPA


@dataclass
class PreprocessForAcousticResult:
    wav: torch.Tensor
    mel: torch.Tensor
    pitch: torch.Tensor
    phones_ipa: Union[str, List[str]]
    phones: torch.Tensor
    attn_prior: torch.Tensor
    energy: torch.Tensor
    raw_text: str
    normalized_text: str
    speaker_id: int
    chapter_id: str | int
    utterance_id: str
    pitch_is_normalized: bool


class PreprocessLibriTTS:
    r"""Preprocessing PreprocessLibriTTS audio and text data for use with a TacotronSTFT model.

    Args:
        preprocess_config (PreprocessingConfig): The preprocessing configuration.
        lang (str): The language of the input text.

    Attributes:
        min_seconds (float): The minimum duration of audio clips in seconds.
        max_seconds (float): The maximum duration of audio clips in seconds.
        hop_length (int): The hop length of the STFT.
        sampling_rate (int): The sampling rate of the audio.
        use_audio_normalization (bool): Whether to normalize the loudness of the audio.
        tacotronSTFT (TacotronSTFT): The TacotronSTFT object used for computing mel spectrograms.
        min_samples (int): The minimum number of audio samples in a clip.
        max_samples (int): The maximum number of audio samples in a clip.
    """

    def __init__(
        self,
        preprocess_config: PreprocessingConfig,
        lang: str = "en",
    ):
        super().__init__()

        lang_map = get_lang_map(lang)

        self.phonemizer_lang = lang_map.phonemizer
        normilize_text_lang = lang_map.nemo

        self.normilize_text = NormalizeText(normilize_text_lang)
        self.tokenizer = TokenizerIpaEspeak(lang)
        self.vocoder_train_config = VocoderBasicConfig()

        self.preprocess_config = preprocess_config

        self.sampling_rate = self.preprocess_config.sampling_rate
        self.use_audio_normalization = self.preprocess_config.use_audio_normalization

        self.hop_length = self.preprocess_config.stft.hop_length
        self.filter_length = self.preprocess_config.stft.filter_length
        self.mel_fmin = self.preprocess_config.stft.mel_fmin
        self.win_length = self.preprocess_config.stft.win_length

        self.tacotronSTFT = TacotronSTFT(
            filter_length=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.preprocess_config.stft.win_length,
            n_mel_channels=self.preprocess_config.stft.n_mel_channels,
            sampling_rate=self.sampling_rate,
            mel_fmin=self.mel_fmin,
            mel_fmax=self.preprocess_config.stft.mel_fmax,
            center=False,
        )

        min_seconds, max_seconds = (
            self.preprocess_config.min_seconds,
            self.preprocess_config.max_seconds,
        )

        self.min_samples = int(self.sampling_rate * min_seconds)
        self.max_samples = int(self.sampling_rate * max_seconds)

        self.audio_processor = AudioProcessor()

    def beta_binomial_prior_distribution(
        self,
        phoneme_count: int,
        mel_count: int,
        scaling_factor: float = 1.0,
    ) -> torch.Tensor:
        r"""Computes the beta-binomial prior distribution for the attention mechanism.

        Args:
            phoneme_count (int): Number of phonemes in the input text.
            mel_count (int): Number of mel frames in the input mel-spectrogram.
            scaling_factor (float, optional): Scaling factor for the beta distribution. Defaults to 1.0.

        Returns:
            torch.Tensor: A 2D tensor containing the prior distribution.
        """
        P, M = phoneme_count, mel_count
        x = np.arange(0, P)
        mel_text_probs = []
        for i in range(1, M + 1):
            a, b = scaling_factor * i, scaling_factor * (M + 1 - i)
            rv: Any = betabinom(P, a, b)
            mel_i_prob = rv.pmf(x)
            mel_text_probs.append(mel_i_prob)
        return torch.tensor(np.array(mel_text_probs))

    def acoustic(
        self,
        row: Tuple[torch.Tensor, int, str, str, int, str | int, str],
    ) -> Union[None, PreprocessForAcousticResult]:
        r"""Preprocesses audio and text data for use with a TacotronSTFT model.

        Args:
            row (Tuple[torch.FloatTensor, int, str, str, int, str | int, str]): The input row. The row is a tuple containing the following elements: (audio, sr_actual, raw_text, normalized_text, speaker_id, chapter_id, utterance_id).

        Returns:
            dict: A dictionary containing the preprocessed audio and text data.

        Examples:
            >>> preprocess_audio = PreprocessAudio("english_only")
            >>> audio = torch.randn(1, 44100)
            >>> sr_actual = 44100
            >>> raw_text = "Hello, world!"
            >>> output = preprocess_audio(audio, sr_actual, raw_text)
            >>> output.keys()
            dict_keys(['wav', 'mel', 'pitch', 'phones', 'raw_text', 'normalized_text', 'speaker_id', 'chapter_id', 'utterance_id', 'pitch_is_normalized'])
        """
        (
            audio,
            sr_actual,
            raw_text,
            normalized_text,
            speaker_id,
            chapter_id,
            utterance_id,
        ) = row

        wav, sampling_rate = preprocess_audio(audio, sr_actual, self.sampling_rate)

        # TODO: check this, maybe you need to move it to some other place
        # TODO: maybe we can increate the max_samples ?
        # if wav.shape[0] < self.min_samples or wav.shape[0] > self.max_samples:
        #     return None

        if self.use_audio_normalization:
            wav = normalize_loudness(wav)

        normalized_text = self.normilize_text(normalized_text)

        # NOTE: fixed version of tokenizer with punctuation
        phones_ipa, phones = self.tokenizer(normalized_text)

        # Convert to tensor
        phones = torch.Tensor(phones)

        mel_spectrogram = self.tacotronSTFT.get_mel_from_wav(wav)

        # Skipping small sample due to the mel-spectrogram containing less than self.mel_fmin frames
        # if mel_spectrogram.shape[1] < self.mel_fmin:
        #     return None

        # Text is longer than mel, will be skipped due to monotonic alignment search
        if phones.shape[0] >= mel_spectrogram.shape[1]:
            return None

        pitch, _, _, _ = compute_yin(
            wav,
            sr=sampling_rate,
            w_len=self.filter_length,
            w_step=self.hop_length,
            f0_min=50,
            f0_max=1000,
            harmo_thresh=0.25,
        )

        pitch, _ = norm_interp_f0(pitch)

        if np.sum(pitch != 0) <= 1:
            return None

        pitch = torch.tensor(pitch)

        # TODO this shouldnt be necessary, currently pitch sometimes has 1 less frame than spectrogram,
        # We should find out why
        mel_spectrogram = mel_spectrogram[:, : pitch.shape[0]]

        attn_prior = self.beta_binomial_prior_distribution(
            phones.shape[0],
            mel_spectrogram.shape[1],
        ).T

        assert pitch.shape[0] == mel_spectrogram.shape[1], (
            pitch.shape,
            mel_spectrogram.shape[1],
        )

        energy = self.audio_processor.wav_to_energy(
            wav.unsqueeze(0),
            self.filter_length,
            self.hop_length,
            self.win_length,
        )

        return PreprocessForAcousticResult(
            wav=wav,
            mel=mel_spectrogram,
            pitch=pitch,
            attn_prior=attn_prior,
            energy=energy,
            phones_ipa=phones_ipa,
            phones=phones,
            raw_text=raw_text,
            normalized_text=normalized_text,
            speaker_id=speaker_id,
            chapter_id=chapter_id,
            utterance_id=utterance_id,
            # TODO: check the pitch normalization process
            pitch_is_normalized=False,
        )

    def univnet(self, row: Tuple[torch.Tensor, int, str, str, int, str | int, str]):
        r"""Preprocesses audio data for use with a UnivNet model.

        This method takes a row of data, extracts the audio and preprocesses it.
        It then selects a random segment from the preprocessed audio and its corresponding mel spectrogram.

        Args:
            row (Tuple[torch.FloatTensor, int, str, str, int, str | int, str]): The input row. The row is a tuple containing the following elements: (audio, sr_actual, raw_text, normalized_text, speaker_id, chapter_id, utterance_id).

        Returns:
            Tuple[torch.Tensor, torch.Tensor, int]: A tuple containing the selected segment of the mel spectrogram, the corresponding audio segment, and the speaker ID.

        Examples:
            >>> preprocess = PreprocessLibriTTS()
            >>> audio = torch.randn(1, 44100)
            >>> sr_actual = 44100
            >>> speaker_id = 0
            >>> mel, audio_segment, speaker_id = preprocess.preprocess_univnet((audio, sr_actual, "", "", speaker_id, 0, ""))
        """
        (
            audio,
            sr_actual,
            _,
            _,
            speaker_id,
            _,
            _,
        ) = row

        segment_size = self.vocoder_train_config.segment_size
        frames_per_seg = math.ceil(segment_size / self.hop_length)

        wav, _ = preprocess_audio(audio, sr_actual, self.sampling_rate)

        if self.use_audio_normalization:
            wav = normalize_loudness(wav)

        mel_spectrogram = self.tacotronSTFT.get_mel_from_wav(wav)

        if wav.shape[0] < segment_size:
            wav = F.pad(
                wav,
                (0, segment_size - wav.shape[0]),
                "constant",
            )

        if mel_spectrogram.shape[1] < frames_per_seg:
            mel_spectrogram = F.pad(
                mel_spectrogram,
                (0, frames_per_seg - mel_spectrogram.shape[1]),
                "constant",
            )

        from_frame = random.randint(0, mel_spectrogram.shape[1] - frames_per_seg)

        # Skip last frame, otherwise errors are thrown, find out why
        if from_frame > 0:
            from_frame -= 1

        till_frame = from_frame + frames_per_seg

        mel_spectrogram = mel_spectrogram[:, from_frame:till_frame]
        wav = wav[from_frame * self.hop_length : till_frame * self.hop_length]

        return mel_spectrogram, wav, speaker_id

In [14]:
from typing import List, Union

import torch
from torch import Tensor, nn


def pad_1D(inputs: List[Tensor], pad_value: float = 0.0) -> Tensor:
    r"""Pad a list of 1D tensor list to the same length.

    Args:
        inputs (List[torch.Tensor]): List of 1D numpy arrays to pad.
        pad_value (float): Value to use for padding. Default is 0.0.

    Returns:
        torch.Tensor: Padded 2D numpy array of shape (len(inputs), max_len), where max_len is the length of the longest input array.
    """
    max_len = max(x.size(0) for x in inputs)
    padded_inputs = [nn.functional.pad(x, (0, max_len - x.size(0)), value=pad_value) for x in inputs]
    return torch.stack(padded_inputs)


def pad_2D(
    inputs: List[Tensor], maxlen: Union[int, None] = None, pad_value: float = 0.0,
) -> Tensor:
    r"""Pad a list of 2D tensor arrays to the same length.

    Args:
        inputs (List[torch.Tensor]): List of 2D numpy arrays to pad.
        maxlen (Union[int, None]): Maximum length to pad the arrays to. If None, pad to the length of the longest array. Default is None.
        pad_value (float): Value to use for padding. Default is 0.0.

    Returns:
        torch.Tensor: Padded 3D numpy array of shape (len(inputs), max_len, input_dim), where max_len is the maximum length of the input arrays, and input_dim is the dimension of the input arrays.
    """
    max_len = max(x.size(1) for x in inputs) if maxlen is None else maxlen

    padded_inputs = [nn.functional.pad(x, (0, max_len - x.size(1), 0, 0), value=pad_value) for x in inputs]
    return torch.stack(padded_inputs)


def pad_3D(inputs: Union[Tensor, List[Tensor]], B: int, T: int, L: int) -> Tensor:
    r"""Pad a 3D torch tensor to a specified shape.

    Args:
        inputs (torch.Tensor): 3D numpy array to pad.
        B (int): Batch size to pad the array to.
        T (int): Time steps to pad the array to.
        L (int): Length to pad the array to.

    Returns:
        torch.Tensor: Padded 3D numpy array of shape (B, T, L), where B is the batch size, T is the time steps, and L is the length.
    """
    if isinstance(inputs, list):
        inputs_padded = torch.zeros(B, T, L, dtype=inputs[0].dtype)
        for i, input_ in enumerate(inputs):
            inputs_padded[i, :input_.size(0), :input_.size(1)] = input_

    elif isinstance(inputs, torch.Tensor):
        inputs_padded = torch.zeros(B, T, L, dtype=inputs.dtype)
        inputs_padded[:inputs.size(0), :inputs.size(1), :inputs.size(2)] = inputs

    return inputs_padded

In [15]:
from typing import Any, Dict, List

import numpy as np
import torch
from torch.utils.data import Dataset
from torchaudio import datasets

# from models.config import PreprocessingConfigUnivNet, get_lang_map

# from training.preprocess import PreprocessLibriTTS
# from training.tools import pad_1D, pad_2D


class LibriTTSDatasetVocoder(Dataset):
    r"""Loading preprocessed univnet model data."""

    def __init__(
        self,
        root: str,
        batch_size: int,
        download: bool = True,
        lang: str = "en",
    ):
        r"""A PyTorch dataset for loading preprocessed univnet data.

        Args:
            root (str): Path to the directory where the dataset is found or downloaded.
            batch_size (int): Batch size for the dataset.
            download (bool, optional): Whether to download the dataset if it is not found. Defaults to True.
        """
        self.dataset = datasets.LIBRITTS(root=root, download=download)
        self.batch_size = batch_size

        lang_map = get_lang_map(lang)
        self.preprocess_libtts = PreprocessLibriTTS(
            PreprocessingConfigUnivNet(lang_map.processing_lang_type),
        )

    def __len__(self) -> int:
        r"""Returns the number of samples in the dataset.

        Returns
            int: Number of samples in the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r"""Returns a sample from the dataset at the given index.

        Args:
            idx (int): Index of the sample to return.

        Returns:
            Dict[str, Any]: A dictionary containing the sample data.
        """
        # Retrive the dataset row
        data = self.dataset[idx]

        data = self.preprocess_libtts.univnet(data)

        if data is None:
            # print("Skipping due to preprocessing error")
            rand_idx = np.random.randint(0, self.__len__())
            return self.__getitem__(rand_idx)

        mel, audio, speaker_id = data

        return {
            "mel": mel,
            "audio": audio,
            "speaker_id": speaker_id,
        }

    def collate_fn(self, data: List) -> List:
        r"""Collates a batch of data samples.

        Args:
            data (List): A list of data samples.

        Returns:
            List: A list of reprocessed data batches.
        """
        data_size = len(data)

        idxs = list(range(data_size))

        # Initialize empty lists to store extracted values
        empty_lists: List[List] = [[] for _ in range(4)]
        (
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ) = empty_lists

        # Extract fields from data dictionary and populate the lists
        for idx in idxs:
            data_entry = data[idx]

            mels.append(data_entry["mel"])
            mel_lens.append(data_entry["mel"].shape[1])
            audios.append(data_entry["audio"])
            speaker_ids.append(data_entry["speaker_id"])

        mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
        mel_lens = torch.tensor(mel_lens, dtype=torch.int64)
        audios = torch.tensor(pad_1D(audios), dtype=torch.float32)
        speaker_ids = torch.tensor(speaker_ids, dtype=torch.int64)

        return [
            mels,
            mel_lens,
            audios,
            speaker_ids,
        ]

In [56]:
batch_size = 2
dataset = LibriTTSDatasetVocoder(
            root="datasets_cache/LIBRITTS",
            batch_size=batch_size,
            download=False,
        )

In [57]:
assert len(dataset) == 33236

In [58]:
sample = dataset[0]

assert sample["mel"].shape == torch.Size([100, 64])
assert sample["audio"].shape == torch.Size([16384])
assert sample["speaker_id"] == 1034

In [62]:
data = [
    dataset[0],
    dataset[2],
]
result = dataset.collate_fn(data)
assert len(result) == 4
for batch in result:
    assert len(batch) == batch_size

  mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
  audios = torch.tensor(pad_1D(audios), dtype=torch.float32)


In [66]:
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=dataset.collate_fn,
)

dataloader_iter = iter(dataloader)

for _, items in enumerate([next(dataloader_iter), next(dataloader_iter)]):
    assert len(items) == 4
    for it in items:
        assert len(it) == batch_size

  mels = torch.tensor(pad_2D(mels), dtype=torch.float32)
  audios = torch.tensor(pad_1D(audios), dtype=torch.float32)
