<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#TextMelDataset" data-toc-modified-id="TextMelDataset-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>TextMelDataset</a></span></li><li><span><a href="#TextMelCollate" data-toc-modified-id="TextMelCollate-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>TextMelCollate</a></span></li><li><span><a href="#TextAudioLoader" data-toc-modified-id="TextAudioLoader-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>TextAudioLoader</a></span></li><li><span><a href="#TextAudioCollate" data-toc-modified-id="TextAudioCollate-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>TextAudioCollate</a></span></li><li><span><a href="#DistributedBucketSampler" data-toc-modified-id="DistributedBucketSampler-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>DistributedBucketSampler</a></span><ul class="toc-item"><li><ul class="toc-item"><li><span><a href="#TextMelBatch-GradTTS" data-toc-modified-id="TextMelBatch-GradTTS-5.0.1"><span class="toc-item-num">5.0.1&nbsp;&nbsp;</span>TextMelBatch GradTTS</a></span></li></ul></li></ul></li></ul></div>

In [None]:
# default_exp data_loader

In [None]:
# export
import os
import random
import re
from pathlib import Path
from typing import List

import numpy as np
from scipy.io.wavfile import read
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

from uberduck_ml_dev.models.common import STFT, MelSTFT
from uberduck_ml_dev.text.symbols import (
    DEFAULT_SYMBOLS,
    IPA_SYMBOLS,
    NVIDIA_TACO2_SYMBOLS,
    GRAD_TTS_SYMBOLS,
)
from uberduck_ml_dev.text.util import cleaned_text_to_sequence, text_to_sequence
from uberduck_ml_dev.utils.audio import compute_yin, load_wav_to_torch
from uberduck_ml_dev.utils.utils import (
    load_filepaths_and_text,
    intersperse,
)

In [None]:
# export
from collections import defaultdict


def oversample(filepaths_text_sid, sid_to_weight):
    """Repeat filelist rows according to a per-speaker integer weight.

    Args:
        filepaths_text_sid: iterable of rows whose element at index 2 is the
            speaker id.
        sid_to_weight: mapping from speaker id (keys must be ``str``) to the
            number of times each row of that speaker is emitted. Speakers that
            are missing from the mapping are emitted exactly once.

    Returns:
        A new list with the rows in their original order, each one repeated
        consecutively according to its speaker's weight.
    """
    assert all([isinstance(sid, str) for sid in sid_to_weight.keys()])
    return [
        row
        for row in filepaths_text_sid
        for _ in range(sid_to_weight.get(row[2], 1))
    ]

In [None]:
# Speaker "1" is weighted 3x, so its single row must appear three times in a
# row, while the unweighted speaker "0" rows are kept once each.
mock_fts = [
    ("speaker0/1.wav", "Test one two", "0"),
    ("speaker0/2.wav", "Test one two", "0"),
    ("speaker1/1.wav", "Test one two", "1"),
]
assert oversample(mock_fts, {"1": 3}) == mock_fts[:2] + [mock_fts[2]] * 3

# TextMelDataset

In [None]:
# export


def _orig_to_dense_speaker_id(speaker_ids):
    speaker_ids = sorted(list(set(speaker_ids)))
    return {orig: idx for orig, idx in zip(speaker_ids, range(len(speaker_ids)))}


class TextMelDataset(Dataset):
    """Dataset of text-sequence / mel-spectrogram samples for TTS training.

    Filelist rows are unpacked as ``(audio path, transcription, speaker id)``.
    Every item is a dict with the keys ``text_sequence``, ``mel``,
    ``speaker_id``, ``embedded_gst`` and ``f0``; the last two are ``None``
    unless ``compute_gst`` / ``include_f0`` are enabled.
    """

    def __init__(
        self,
        audiopaths_and_text: str,
        text_cleaners: List[str],
        p_arpabet: float,
        n_mel_channels: int,
        sampling_rate: int,
        mel_fmin: float,
        mel_fmax: float,
        filter_length: int,
        hop_length: int,
        win_length: int,
        symbol_set: str,
        padding: int = None,
        max_wav_value: float = 32768.0,
        include_f0: bool = False,
        pos_weight: float = 10,
        f0_min: int = 80,
        f0_max: int = 880,
        harmonic_thresh=0.25,
        debug: bool = False,
        debug_dataset_size: int = None,
        oversample_weights=None,
        intersperse_text: bool = False,
        intersperse_token: int = 0,
        compute_gst=None,
    ):
        """Load the filelist and configure text, mel and f0 extraction.

        Args:
            audiopaths_and_text: path of the filelist to load.
            text_cleaners, p_arpabet, symbol_set: forwarded to
                ``text_to_sequence``.
            n_mel_channels, sampling_rate, mel_fmin, mel_fmax, filter_length,
                hop_length, win_length, padding: forwarded to ``MelSTFT``.
            max_wav_value: divisor applied to raw samples before the STFT.
            include_f0: when true, items also carry an ``f0`` tensor.
            pos_weight: accepted for backward compatibility; unused here.
            f0_min, f0_max, harmonic_thresh: forwarded to ``compute_yin``.
            debug, debug_dataset_size: when both are truthy, ``len()`` is
                capped at ``debug_dataset_size``.
            oversample_weights: optional speaker id -> repeat count mapping
                handed to ``oversample``.
            intersperse_text, intersperse_token: when enabled, the token is
                interleaved into the text sequence via ``intersperse``.
            compute_gst: optional callable invoked with ``[transcription]``
                to produce the ``embedded_gst`` entry.
        """
        super().__init__()
        self.audiopaths_and_text = oversample(
            load_filepaths_and_text(audiopaths_and_text), oversample_weights or {}
        )
        self.text_cleaners = text_cleaners
        self.p_arpabet = p_arpabet

        self.stft = MelSTFT(
            filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            sampling_rate=sampling_rate,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
            padding=padding,
        )
        self.max_wav_value = max_wav_value
        self.sampling_rate = sampling_rate
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.include_f0 = include_f0
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.harmonic_threshold = harmonic_thresh
        # Speaker id lookup table: original id -> dense index.
        speaker_ids = [row[2] for row in self.audiopaths_and_text]
        self._speaker_id_map = _orig_to_dense_speaker_id(speaker_ids)
        self.debug = debug
        self.debug_dataset_size = debug_dataset_size
        self.symbol_set = symbol_set
        self.intersperse_text = intersperse_text
        self.intersperse_token = intersperse_token
        self.compute_gst = compute_gst

    def _get_f0(self, audio):
        """Return the zero-padded f0 contour of ``audio`` as a float32 array."""
        f0, harmonic_rates, argmins, times = compute_yin(
            audio,
            self.sampling_rate,
            self.filter_length,
            self.hop_length,
            self.f0_min,
            self.f0_max,
            self.harmonic_threshold,
        )
        # Pad both ends by half the number of hops in one analysis window.
        # NOTE(review): list concatenation assumes compute_yin returns f0 as a
        # Python list - confirm.
        pad = int((self.filter_length / self.hop_length) / 2)
        f0 = [0.0] * pad + f0 + [0.0] * pad
        return np.array(f0, dtype=np.float32)

    def _get_gst(self, transcription):
        """Delegate to the ``compute_gst`` callable supplied at construction."""
        return self.compute_gst(transcription)

    def _get_data(self, audiopath_and_text):
        """Build the sample dict for one ``(path, transcription, speaker id)`` row.

        Raises:
            ValueError: if the wav file's sampling rate differs from the
                ``sampling_rate`` this dataset was configured with.
        """
        path, transcription, speaker_id = audiopath_and_text
        speaker_id = self._speaker_id_map[speaker_id]
        sampling_rate, wav_data = read(path)
        # The mel filterbank is built for self.sampling_rate; audio recorded
        # at another rate would silently produce wrong spectrograms.
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} SR {} doesn't match target SR {}".format(
                    path, sampling_rate, self.sampling_rate
                )
            )
        text_sequence = torch.LongTensor(
            text_to_sequence(
                transcription,
                self.text_cleaners,
                p_arpabet=self.p_arpabet,
                symbol_set=self.symbol_set,
            )
        )
        if self.intersperse_text:
            # Interleave the blank token (self.intersperse_token) with the
            # symbol ids.
            text_sequence = torch.LongTensor(
                intersperse(text_sequence.numpy(), self.intersperse_token)
            )

        audio = torch.FloatTensor(wav_data)
        audio_norm = (audio / self.max_wav_value).unsqueeze(0)
        melspec = torch.squeeze(self.stft.mel_spectrogram(audio_norm), 0)
        data = {
            "text_sequence": text_sequence,
            "mel": melspec,
            "speaker_id": speaker_id,
            "embedded_gst": None,
            "f0": None,
        }

        if self.compute_gst:
            data["embedded_gst"] = self._get_gst([transcription])

        if self.include_f0:
            f0 = torch.from_numpy(self._get_f0(audio.data.cpu().numpy()))[None]
            # Trim to the mel length so both share the same frame axis.
            data["f0"] = f0[:, : melspec.size(1)]

        return data

    def __getitem__(self, idx):
        """Return data for a single audio file + transcription."""
        try:
            data = self._get_data(self.audiopaths_and_text[idx])
        except Exception as e:
            # Report the offending row before propagating the error.
            print(f"Error while getting data: {self.audiopaths_and_text[idx]}")
            print(e)
            raise
        return data

    def __len__(self):
        """Number of rows, capped at ``debug_dataset_size`` in debug mode."""
        if self.debug and self.debug_dataset_size:
            return min(self.debug_dataset_size, len(self.audiopaths_and_text))
        return len(self.audiopaths_and_text)

    def sample_test_batch(self, size):
        """Return ``size`` distinct, randomly chosen items as a list."""
        chosen = np.random.choice(range(len(self)), size=size, replace=False)
        return [self[index] for index in chosen]

# TextMelCollate

In [None]:
# export


class TextMelCollate:
    """Zero-pad ``TextMelDataset`` samples into batch tensors.

    Samples are reordered by decreasing text length; every returned tensor
    (including the GST embeddings) follows that same order.
    """

    def __init__(self, n_frames_per_step: int = 1, include_f0: bool = False):
        self.n_frames_per_step = n_frames_per_step
        # NOTE(review): stored for API compatibility; __call__ does not emit an
        # f0 tensor, the returned tuple always has seven entries.
        self.include_f0 = include_f0

    def set_frames_per_step(self, n_frames_per_step):
        """Set n_frames_step.

        This is used to train with gradual training, where we start with a large
        n_frames_per_step in order to learn attention quickly and decrease it
        over the course of training in order to increase accuracy. Gradual training
        reference:
        https://erogol.com/gradual-training-with-tacotron-for-faster-convergence/
        """
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collate a training batch from sample dicts.

        PARAMS
        ------
        batch: list of dicts with the keys ``text_sequence``, ``mel``,
            ``speaker_id`` and ``embedded_gst``.

        RETURNS
        -------
        Tuple ``(text_padded, input_lengths, mel_padded, gate_padded,
        output_lengths, speaker_ids, embedded_gsts)``; ``embedded_gsts`` is
        ``None`` when the first sample carries no GST embedding.
        """
        batch_size = len(batch)
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x["text_sequence"]) for x in batch]),
            dim=0,
            descending=True,
        )
        # Materialize the sorted view once so every output shares one order.
        sorted_batch = [batch[idx] for idx in ids_sorted_decreasing.tolist()]

        # Right zero-pad all text sequences to the longest one.
        max_input_len = int(input_lengths[0])
        text_padded = torch.zeros(batch_size, max_input_len, dtype=torch.long)
        for i, sample in enumerate(sorted_batch):
            text = sample["text_sequence"]
            text_padded[i, : text.size(0)] = text

        # Right zero-pad mel-specs; the target length is rounded up to a
        # multiple of n_frames_per_step.
        num_mels = batch[0]["mel"].size(0)
        max_target_len = max(x["mel"].size(1) for x in batch)
        remainder = max_target_len % self.n_frames_per_step
        if remainder != 0:
            max_target_len += self.n_frames_per_step - remainder

        mel_padded = torch.zeros(
            batch_size, num_mels, max_target_len, dtype=torch.float32
        )
        gate_padded = torch.zeros(batch_size, max_target_len, dtype=torch.float32)
        output_lengths = torch.zeros(batch_size, dtype=torch.long)
        speaker_ids = torch.zeros(batch_size, dtype=torch.long)
        for i, sample in enumerate(sorted_batch):
            mel = sample["mel"]
            n_frames = mel.size(1)
            mel_padded[i, :, :n_frames] = mel
            # The gate (stop target) is 1 from the last real frame onwards.
            gate_padded[i, n_frames - 1 :] = 1
            output_lengths[i] = n_frames
            speaker_ids[i] = sample["speaker_id"]

        if batch[0]["embedded_gst"] is None:
            embedded_gsts = None
        else:
            # Stack in sorted order so row i matches text_padded[i].
            embedded_gsts = torch.FloatTensor(
                np.array([sample["embedded_gst"] for sample in sorted_batch])
            )

        return (
            text_padded,
            input_lengths,
            mel_padded,
            gate_padded,
            output_lengths,
            speaker_ids,
            embedded_gsts,
        )

In [None]:
# Duplicate ids collapse and the remaining ids are ranked in ascending order.
_orig_to_dense_speaker_id([4, 2, 9, 3, 9])

{2: 0, 3: 1, 4: 2, 9: 3}

In [None]:
# Build the dataset from the bundled validation filelist; keyword arguments
# spell out which number is which STFT / mel setting.
ds = TextMelDataset(
    audiopaths_and_text="test/fixtures/val.txt",
    text_cleaners=["english_cleaners"],
    p_arpabet=0.0,
    n_mel_channels=80,
    sampling_rate=22050,
    mel_fmin=0,
    mel_fmax=8000,
    filter_length=1024,
    hop_length=256,
    win_length=1024,
    padding=None,
    symbol_set="default",
    debug=True,
    debug_dataset_size=12,
)
len(ds)

1

In [None]:
from torch.utils.data import DataLoader

# Smoke test: every collated batch is the 7-entry model-input tuple.
collate_fn = TextMelCollate()
dl = DataLoader(ds, batch_size=12, collate_fn=collate_fn)
for i, batch in enumerate(dl):
    assert len(batch) == 7

In [None]:
# Same fixture as before, this time with f0 extraction switched on.
ds = TextMelDataset(
    audiopaths_and_text="test/fixtures/val.txt",
    text_cleaners=["english_cleaners"],
    p_arpabet=0.0,
    n_mel_channels=80,
    sampling_rate=22050,
    mel_fmin=0,
    mel_fmax=8000,
    filter_length=1024,
    hop_length=256,
    win_length=1024,
    padding=None,
    symbol_set="default",
    include_f0=True,
    debug=True,
    debug_dataset_size=12,
)
assert len(ds) == 1

In [None]:
# Batch layout produced by TextMelCollate:
#   text_padded, input_lengths, mel_padded, gate_padded, output_lengths,
#   speaker_ids, embedded_gsts
collate_fn = TextMelCollate(include_f0=True)
dl = DataLoader(ds, 12, collate_fn=collate_fn)
for i, batch in enumerate(dl):
    (
        text_padded,
        input_lengths,
        mel_padded,
        gate_padded,
        output_lengths,
        speaker_ids,
        *_,
    ) = batch
    # Assertion messages are plain strings: `assert cond, print(...)` would
    # print as a side effect and leave the AssertionError message as None.
    assert output_lengths.item() == 566, f"output lengths: {output_lengths}"
    assert gate_padded.size(1) == 566
    assert mel_padded.size(2) == 566
    assert len(batch) == 7

In [None]:
# Shape of the padded mel tensor left over from the loop in the previous cell.
mel_padded.shape

torch.Size([1, 80, 566])

In [None]:
# testing n_frames_per_step > 1
ds = TextMelDataset(
    audiopaths_and_text="test/fixtures/val.txt",
    text_cleaners=["english_cleaners"],
    p_arpabet=0.0,
    n_mel_channels=80,
    sampling_rate=22050,
    mel_fmin=0,
    mel_fmax=8000,
    filter_length=1024,
    hop_length=256,
    win_length=1024,
    padding=None,
    symbol_set="default",
    include_f0=True,
    debug=True,
    debug_dataset_size=12,
)
assert len(ds) == 1
collate_fn = TextMelCollate(n_frames_per_step=5, include_f0=True)
dl = DataLoader(ds, 12, collate_fn=collate_fn)
# Batch layout produced by TextMelCollate:
#   text_padded, input_lengths, mel_padded, gate_padded, output_lengths,
#   speaker_ids, embedded_gsts
for i, batch in enumerate(dl):
    (
        text_padded,
        input_lengths,
        mel_padded,
        gate_padded,
        output_lengths,
        speaker_ids,
        *_,
    ) = batch
    # The true length stays 566, while the padded axes are rounded up to the
    # next multiple of n_frames_per_step (570).
    assert output_lengths.item() == 566, output_lengths.item()
    # Messages are strings; `print(...)` as a message would evaluate to None.
    assert mel_padded.size(2) == 570, f"actual shape: {mel_padded.shape}"
    assert gate_padded.size(1) == 570, f"actual shape: {gate_padded.shape}"
    assert len(batch) == 7

# TextAudioLoader

In [None]:
# export


class TextAudioSpeakerLoader(Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.

    Filelist rows are indexed as ``(audiopath, text, speaker_id)``; each item
    is the tuple ``(text, spec, wav, sid)``.
    """

    def __init__(
        self, audiopaths_sid_text, hparams, debug=False, debug_dataset_size=None
    ):
        """Load, shuffle and length-filter the filelist described by ``hparams``.

        Args:
            audiopaths_sid_text: path of the filelist to load.
            hparams: object exposing the attributes read below
                (``oversample_weights``, ``text_cleaners``, STFT settings,
                ``add_blank`` and optionally ``cleaned_text``,
                ``min_text_len``, ``max_text_len``).
            debug, debug_dataset_size: when both are truthy, ``len()`` is
                capped at ``debug_dataset_size``.
        """
        oversample_weights = hparams.oversample_weights or {}
        self.audiopaths_sid_text = oversample(
            load_filepaths_and_text(audiopaths_sid_text), oversample_weights
        )
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length

        self.debug = debug
        self.debug_dataset_size = debug_dataset_size

        self.stft = MelSTFT(
            filter_length=self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            n_mel_channels=hparams.n_mel_channels,
            sampling_rate=hparams.sampling_rate,
            mel_fmin=hparams.mel_fmin,
            mel_fmax=hparams.mel_fmax,
            padding=(self.filter_length - self.hop_length) // 2,
        )

        self.cleaned_text = getattr(hparams, "cleaned_text", False)
        # NOTE(zach): Parametrize this later if desired.
        self.symbol_set = IPA_SYMBOLS

        self.add_blank = hparams.add_blank
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 190)

        # Fixed seed so the shuffled order is reproducible across runs.
        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)
        self._filter()

    def _filter(self):
        """
        Filter text & store spec lengths
        """
        # Store spectrogram lengths for Bucketing
        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
        # spec_length = wav_length // hop_length

        kept_rows = []
        lengths = []
        # Rows are (audiopath, text, sid) - the same order that
        # get_audio_text_speaker_pair() reads - so the length check is applied
        # to the transcription rather than to the speaker id.
        for audiopath, text, sid in self.audiopaths_sid_text:
            if self.min_text_len <= len(text) <= self.max_text_len:
                kept_rows.append([audiopath, text, sid])
                lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
        self.audiopaths_sid_text = kept_rows
        self.lengths = lengths

    def get_audio_text_speaker_pair(self, audiopath_sid_text):
        """Turn one filelist row into the ``(text, spec, wav, sid)`` tuple."""
        # separate filename, text and speaker_id
        audiopath, text, sid = (
            audiopath_sid_text[0],
            audiopath_sid_text[1],
            audiopath_sid_text[2],
        )
        text = self.get_text(text)
        spec, wav = self.get_audio(audiopath)
        sid = self.get_sid(sid)
        return (text, spec, wav, sid)

    def get_audio(self, filename):
        """Load ``filename`` and return ``(spec, normalized audio)``.

        The spectrogram is cached next to the audio file with the
        ``.uberduck.spec.pt`` suffix and reused when the cache file exists.

        Raises:
            ValueError: if the file's sampling rate differs from the
                configured one.
        """
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} {} SR doesn't match target {} SR".format(
                    filename, sampling_rate, self.sampling_rate
                )
            )

        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        # Swap only the extension: a substring replace of ".wav" would return
        # the audio path itself for any other extension, and would also
        # rewrite ".wav" inside directory names.
        spec_filename = os.path.splitext(filename)[0] + ".uberduck.spec.pt"
        if os.path.exists(spec_filename):
            # NOTE(review): torch.load unpickles the cache file; only point
            # this loader at trusted dataset directories.
            spec = torch.load(spec_filename)
        else:
            spec = self.stft.spectrogram(audio_norm)
            spec = torch.squeeze(spec, 0)
            torch.save(spec, spec_filename)
        return spec, audio_norm

    def get_text(self, text):
        """Convert ``text`` to a ``LongTensor`` of symbol ids."""
        if self.cleaned_text:
            text_norm = cleaned_text_to_sequence(text, symbol_set=self.symbol_set)
        else:
            text_norm = text_to_sequence(
                text, self.text_cleaners, symbol_set=self.symbol_set
            )
        if self.add_blank:
            # Interleave token 0 between the symbol ids.
            text_norm = intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def get_sid(self, sid):
        """Return the speaker id as a one-element ``LongTensor``."""
        return torch.LongTensor([int(sid)])

    def __getitem__(self, index):
        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        """Number of rows, capped at ``debug_dataset_size`` in debug mode."""
        if self.debug and self.debug_dataset_size:
            return min(self.debug_dataset_size, len(self.audiopaths_sid_text))
        return len(self.audiopaths_sid_text)

# TextAudioCollate

In [None]:
# export


class TextAudioSpeakerCollate:
    """Zero-pads text, spectrogram and waveform samples into batch tensors."""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collate a batch of normalized text, audio and speaker identities.

        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]

        RETURNS
        -------
        ``(text_padded, text_lengths, spec_padded, spec_lengths, wav_padded,
        wav_lengths, sid)``, ordered by decreasing spectrogram length. When
        ``return_ids`` is set, the sort indices are appended as an eighth
        entry.
        """
        batch_size = len(batch)
        spec_frame_counts = torch.LongTensor([item[1].size(1) for item in batch])
        _, ids_sorted_decreasing = torch.sort(
            spec_frame_counts, dim=0, descending=True
        )

        # Padded containers sized by the longest sample of each modality.
        text_padded = torch.LongTensor(
            batch_size, max(len(item[0]) for item in batch)
        ).zero_()
        spec_padded = torch.FloatTensor(
            batch_size, batch[0][1].size(0), max(item[1].size(1) for item in batch)
        ).zero_()
        wav_padded = torch.FloatTensor(
            batch_size, 1, max(item[2].size(1) for item in batch)
        ).zero_()

        # Every slot below is overwritten in the loop.
        text_lengths = torch.LongTensor(batch_size)
        spec_lengths = torch.LongTensor(batch_size)
        wav_lengths = torch.LongTensor(batch_size)
        sid = torch.LongTensor(batch_size)

        for dst in range(len(ids_sorted_decreasing)):
            sample = batch[ids_sorted_decreasing[dst]]
            text, spec, wav = sample[0], sample[1], sample[2]

            n_text = text.size(0)
            text_padded[dst, :n_text] = text
            text_lengths[dst] = n_text

            n_spec = spec.size(1)
            spec_padded[dst, :, :n_spec] = spec
            spec_lengths[dst] = n_spec

            n_wav = wav.size(1)
            wav_padded[dst, :, :n_wav] = wav
            wav_lengths[dst] = n_wav

            sid[dst] = sample[3]

        outputs = (
            text_padded,
            text_lengths,
            spec_padded,
            spec_lengths,
            wav_padded,
            wav_lengths,
            sid,
        )
        if self.return_ids:
            return outputs + (ids_sorted_decreasing,)
        return outputs

# DistributedBucketSampler

In [None]:
# export


class DistributedBucketSampler(DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.

    Iterating yields lists of dataset indices (one list per batch).
    """

    def __init__(
        self,
        dataset,
        batch_size,
        boundaries,
        num_replicas=None,
        rank=None,
        shuffle=True,
    ):
        """Bucket ``dataset.lengths`` by ``boundaries``.

        Args:
            dataset: dataset exposing a ``lengths`` sequence (one per sample).
            batch_size: number of samples per batch on each replica.
            boundaries: ascending bucket edges; the list is copied, so the
                caller's object is never modified.
            num_replicas, rank, shuffle: forwarded to ``DistributedSampler``.
        """
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        # Private copy: _create_buckets() drops the edges of empty buckets.
        self.boundaries = list(boundaries)

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        """Assign sample indices to buckets and compute padded bucket sizes.

        Returns:
            ``(buckets, num_samples_per_bucket)`` where each bucket size is
            rounded up to a multiple of ``num_replicas * batch_size``.
        """
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i, length in enumerate(self.lengths):
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        # Drop every empty bucket, *including bucket 0*: an empty bucket left
        # in place makes __iter__ divide by its zero length.
        for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        total_batch_size = self.num_replicas * self.batch_size
        num_samples_per_bucket = []
        for bucket in buckets:
            len_bucket = len(bucket)
            rem = (
                total_batch_size - (len_bucket % total_batch_size)
            ) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        """Yield this replica's batches for the current epoch."""
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = (
                ids_bucket
                + ids_bucket * (rem // len_bucket)
                + ids_bucket[: (rem % len_bucket)]
            )

            # subsample: every num_replicas-th entry belongs to this rank
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[
                        j * self.batch_size : (j + 1) * self.batch_size
                    ]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        """Return the bucket index ``i`` with ``boundaries[i] < x <= boundaries[i + 1]``.

        Returns -1 when ``x`` falls outside the searched range.
        """
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        """Number of batches produced per replica."""
        return self.num_samples // self.batch_size

### TextMelBatch GradTTS

In [None]:
# # export
# import torchaudio as ta


# class TextMelDatasetGradTTS(torch.utils.data.Dataset):
#     def __init__(
#         self,
#         filelist_path,
#         intersperse_text=True,
#         n_fft=1024,
#         n_mels=80,
#         sample_rate=22050,
#         hop_length=256,
#         win_length=1024,
#         f_min=0.0,
#         f_max=8000,
#         intersperse_token=0,
#         symbol_set="grad_tts",
#         text_cleaners=["english"],
#     ):
#         self.filepaths_and_text = load_filepaths_and_text(filelist_path)
#         self.intersperse_text = intersperse_text
#         self.intersperse_token = intersperse_token
#         self.n_fft = n_fft
#         self.n_mels = n_mels
#         self.sample_rate = sample_rate
#         self.hop_length = hop_length
#         self.win_length = win_length
#         self.f_min = f_min
#         self.f_max = f_max
#         self.symbol_set = symbol_set
#         self.text_cleaners = text_cleaners
#         self.p_arpabet = 1.0
#         random.seed(1234)
#         random.shuffle(self.filepaths_and_text)

#     def get_pair(self, filepath_and_text):
#         filepath, text = filepath_and_text[0], filepath_and_text[1]
#         text = self.get_text(text, intersperse_text=self.intersperse_text)
#         mel = self.get_mel(filepath)
#         return (text, mel)

#     def get_mel(self, filepath):
#         audio, sr = ta.load(filepath)
#         assert sr == self.sample_rate
#         mel = mel_spectrogram(
#             audio,
#             self.n_fft,
#             self.n_mels,
#             self.sample_rate,
#             self.hop_length,
#             self.win_length,
#             self.f_min,
#             self.f_max,
#             center=False,
#         ).squeeze()
#         return mel

#     def get_text(self, text, intersperse_text=True):
#         text_sequence = text_to_sequence(
#             text,
#             self.text_cleaners,
#             p_arpabet=self.p_arpabet,
#             symbol_set=self.symbol_set,
#         )

#         if self.intersperse_text:
#             text_sequence = intersperse(
#                 text_sequence, self.intersperse_token
#             )  # add a blank token, whose id number is len(symbols)

#         text_sequence = torch.IntTensor(text_sequence)
#         return text_sequence

#     def __getitem__(self, index):
#         text, mel = self.get_pair(self.filepaths_and_text[index])
#         item = {"y": mel, "x": text}
#         return item

#     def __len__(self):
#         return len(self.filepaths_and_text)

#     def sample_test_batch(self, size):
#         idx = np.random.choice(range(len(self)), size=size, replace=False)
#         test_batch = []
#         for index in idx:
#             test_batch.append(self.__getitem__(index))
#         return test_batch

In [None]:
# # export

# from uberduck_ml_dev.models.grad_tts import fix_len_compatibility


# class TextMelBatchCollateGradTTS(object):
#     def __call__(self, batch):
#         B = len(batch)
#         y_max_length = max([item["y"].shape[-1] for item in batch])
#         y_max_length = fix_len_compatibility(y_max_length)
#         x_max_length = max([item["x"].shape[-1] for item in batch])
#         n_feats = batch[0]["y"].shape[-2]

#         y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
#         x = torch.zeros((B, x_max_length), dtype=torch.long)
#         y_lengths, x_lengths = [], []

#         for i, item in enumerate(batch):
#             y_, x_ = item["y"], item["x"]
#             y_lengths.append(y_.shape[-1])
#             x_lengths.append(x_.shape[-1])
#             y[i, :, : y_.shape[-1]] = y_
#             x[i, : x_.shape[-1]] = x_

#         y_lengths = torch.LongTensor(y_lengths)
#         x_lengths = torch.LongTensor(x_lengths)
#         return {"x": x, "x_lengths": x_lengths, "y": y, "y_lengths": y_lengths}

In [None]:
# skip
# these classes are under development
# make a class containing e.g. texts and sequences that can be read as a batch in both forward passes and inference
# create lists for particular off-the-shelf models?


class TTSDataset(Dataset):
    """Work-in-progress dataset splitting sample preparation into small steps.

    NOTE(review): ``_get_data`` calls ``_get_audio``, ``_get_mel`` and
    ``_get_sid``, none of which is defined on this class yet, and no speaker
    id lookup table is built.
    """

    def __init__(
        self,
        audiopaths_and_text: str,
        text_cleaners: List[str],
        p_arpabet: float,
        n_mel_channels: int,
        sampling_rate: int,
        mel_fmin: float,
        mel_fmax: float,
        filter_length: int,
        hop_length: int,
        padding: int,
        win_length: int,
        symbol_set: str,
        max_wav_value: float = 32768.0,
        include_f0: bool = False,
        pos_weight: float = 10,
        f0_min: int = 80,
        f0_max: int = 880,
        harmonic_thresh=0.25,
        debug: bool = False,
        debug_dataset_size: int = None,
        oversample_weights=None,
        intersperse_text=False,
        intersperse_token=0,
    ):
        """Load the filelist and store the text / STFT / f0 settings.

        ``pos_weight``, ``debug`` and ``debug_dataset_size`` are accepted but
        not stored or used by this class.
        """
        super().__init__()

        # oversample
        self.audiopaths_and_text = oversample(
            load_filepaths_and_text(audiopaths_and_text), oversample_weights or {}
        )

        # text to seq parameters
        self.text_cleaners = text_cleaners
        self.p_arpabet = p_arpabet
        self.intersperse_text = intersperse_text
        self.intersperse_token = intersperse_token
        self.symbol_set = symbol_set

        self.stft = MelSTFT(
            filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            sampling_rate=sampling_rate,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
            padding=padding,
        )
        self.max_wav_value = max_wav_value
        self.sampling_rate = sampling_rate
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.include_f0 = include_f0
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.harmonic_threshold = harmonic_thresh

    def _get_f0(self, audio):
        """Return the zero-padded f0 contour of ``audio`` as a float32 array."""
        f0, harmonic_rates, argmins, times = compute_yin(
            audio,
            self.sampling_rate,
            self.filter_length,
            self.hop_length,
            self.f0_min,
            self.f0_max,
            self.harmonic_threshold,
        )
        # Pad both ends by half the number of hops in one analysis window.
        pad = int((self.filter_length / self.hop_length) / 2)
        f0 = [0.0] * pad + f0 + [0.0] * pad
        f0 = np.array(f0, dtype=np.float32)
        return f0

    def _text_to_seq(self, audiopath_and_text):
        """Return the transcription of a filelist row as a ``LongTensor``."""
        path, transcription, speaker_id = audiopath_and_text
        text_sequence = torch.LongTensor(
            text_to_sequence(
                transcription,
                self.text_cleaners,
                p_arpabet=self.p_arpabet,
                symbol_set=self.symbol_set,
            )
        )
        if self.intersperse_text:
            # Interleave the blank token (self.intersperse_token) with the
            # symbol ids.
            text_sequence = torch.LongTensor(
                intersperse(text_sequence.numpy(), self.intersperse_token)
            )
        # Return the tensor built above (not the text_to_sequence function).
        return text_sequence

    #     def _get_f0(self,audio):

    #         if self.include_f0:
    #         else:
    #             return None

    #    def _get_mel(self,):

    def _get_data(self, audiopath_and_text):
        """Assemble ``(text sequence, mel, speaker id, f0)`` for one row."""
        sequence = self._text_to_seq(audiopath_and_text)
        audio = self._get_audio(audiopath_and_text)
        melspec = self._get_mel(audio)
        f0 = self._get_f0(audio)
        speaker_id = self._get_sid(audiopath_and_text)

        return (sequence, melspec, speaker_id, f0)


class Collate:
    """
    Collate assembles batches from a list indexed by sample id
    (text, spectrogram, etc).

    Unless a method says otherwise, every element of ``batch`` is indexed as
    ``(text_sequence, melspec, speaker_id, f0)``.
    NOTE(review): this layout is assumed from the tuple returned by the
    dataset's ``_get_data`` -- confirm against the actual caller.
    """

    def __init__(self, **args):
        """Store the collate options.

        Only ``batch_format`` is consulted by the code below ("taco2ss" and
        "taco2ms" are the values it checks for). Any other keyword is accepted
        and ignored, as before. Options sketched for this class but not wired
        up yet: ``n_frames_per_step``, ``include_f0``, ``include_sid``.
        """
        self.batch_format = args.get("batch_format")

    def _pad_sequence(self, batch):
        """Right-pad the text sequence of every sample with zeros.

        Assumes ``x[0]`` is a 1-D LongTensor of symbol ids -- TODO confirm.

        Returns:
            tuple: ``(text_padded, input_lengths)`` where ``text_padded`` is a
            ``(batch_size, max_len)`` LongTensor and ``input_lengths`` is a
            LongTensor holding the unpadded length of each row.
        """
        batch_size = len(batch)
        input_lengths = torch.LongTensor([len(x[0]) for x in batch])
        # `.max()` is undefined for an empty tensor, so guard the empty batch.
        max_input_len = int(input_lengths.max()) if batch_size else 0
        text_padded = torch.zeros(batch_size, max_input_len, dtype=torch.long)
        for i, sample in enumerate(batch):
            text = sample[0]
            text_padded[i, : text.size(0)] = text

        return text_padded, input_lengths

    def _pad_mel(self, batch):
        """Right-pad the mel spectrogram of every sample with zeros.

        Assumes ``x[1]`` is a ``(n_mel_channels, n_frames)`` tensor -- TODO
        confirm. The gate target is 0 everywhere except from the last real
        frame onward, where it is 1.
        NOTE(review): gate layout follows the usual Tacotron2 convention;
        verify it matches what the model expects.

        Returns:
            tuple: ``(mel_padded, gate_padded, output_lengths)`` with shapes
            ``(batch_size, n_mel_channels, max_frames)``,
            ``(batch_size, max_frames)`` and ``(batch_size,)``.
        """
        batch_size = len(batch)
        output_lengths = torch.LongTensor([x[1].shape[1] for x in batch])
        max_target_len = int(output_lengths.max()) if batch_size else 0
        n_mel_channels = batch[0][1].shape[0] if batch_size else 0

        mel_padded = torch.zeros(batch_size, n_mel_channels, max_target_len)
        gate_padded = torch.zeros(batch_size, max_target_len)
        for i, sample in enumerate(batch):
            mel = sample[1]
            n_frames = mel.shape[1]
            mel_padded[i, :, :n_frames] = mel
            gate_padded[i, max(n_frames - 1, 0) :] = 1

        # assert len(f0) == len(mel)
        return mel_padded, gate_padded, output_lengths

    def _pad_f0(self, batch):
        """Placeholder: F0 padding is not implemented yet, always returns None."""
        return None

    def __call__(self, batch):
        """Collate a list of training samples.

        Returns:
            For ``batch_format == "taco2ss"`` the tuple
            ``(text_padded, mel_padded, mel_padded, output_lengths,
            input_lengths)``; otherwise a ``Batch`` built from the padded
            fields.
        """
        text_padded, input_lengths = self._pad_sequence(batch)  # idx
        mel_padded, gate_padded, output_lengths = self._pad_mel(batch)

        if self.batch_format == "taco2ss":
            # NOTE(review): mel_padded appears twice and gate_padded is left
            # out; kept exactly as written -- confirm the intended layout.
            return (text_padded, mel_padded, mel_padded, output_lengths, input_lengths)
        # if batch_format == 'taco2ms':

        f0 = self._pad_f0(batch)
        speaker_ids = [x[2] for x in batch]
        # NOTE(review): `Batch` is not defined in the visible part of this
        # file (only a commented-out dataclass draft exists), so this path
        # needs a module-level `Batch` accepting these keywords.
        return Batch(
            text=text_padded,
            input_lengths=input_lengths,
            mel_padded=mel_padded,
            gate_padded=gate_padded,
            output_lengths=output_lengths,
            f0=f0,
            speaker_ids=speaker_ids,
        )

    def inference(self, batch):
        """Collate only the text inputs needed at inference time.

        Returns:
            ``(text_padded, input_lengths)`` when ``batch_format`` is
            "taco2ss" or "taco2ms", otherwise None.
        """
        if self.batch_format in ("taco2ss", "taco2ms"):
            return self._pad_sequence(batch)
        return None

    # need to have pad_sequences equivalent
    def _to_tacotron2_singlespeaker_inference(self, batch):
        """Return ``(text_padded, input_lengths)`` for ``batch``.

        Mel padding is not performed here: its result was never used.
        """
        text_padded, input_lengths = self._pad_sequence(batch)  # idx
        return (text_padded, input_lengths)

    # NOTE(zach): would model_inputs be better as a namedtuple or dataclass?
    def _to_mellotron_train_f0(self, batch):
        """Return the training inputs including F0.

        Returns:
            tuple: ``(text_padded, input_lengths, mel_padded, gate_padded,
            output_lengths, speaker_ids, f0_padded)``. ``f0_padded`` is
            whatever ``_pad_f0`` returns (currently None).
        """
        text_padded, input_lengths = self._pad_sequence(batch)
        mel_padded, gate_padded, output_lengths = self._pad_mel(batch)
        speaker_ids = [x[2] for x in batch]
        f0_padded = self._pad_f0(batch)
        return (
            text_padded,
            input_lengths,
            mel_padded,
            gate_padded,
            output_lengths,
            speaker_ids,
            f0_padded,
        )

    def _to_tacotron2_multispeaker_inference(self, batch, speakers=None):
        """Pad text sequences sorted by decreasing length.

        Unlike the other methods, ``batch`` here is a list of 1-D text
        tensors, not a list of sample tuples.
        NOTE(review): speaker ids cannot be derived from such a batch, so they
        are taken from the optional ``speakers`` argument -- confirm where
        they are meant to come from.

        Args:
            batch: list of 1-D LongTensors of symbol ids.
            speakers: optional tensor or sequence aligned with ``batch``; it
                is reordered to match the sorted rows.

        Returns:
            tuple: ``(text_padded, speakers, input_lengths, sort_indices)``.
        """
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x) for x in batch]), dim=0, descending=True
        )
        max_input_len = int(input_lengths[0]) if len(batch) else 0

        text_padded = torch.zeros(len(batch), max_input_len, dtype=torch.long)
        sorted_positions = ids_sorted_decreasing.tolist()
        for row, source_index in enumerate(sorted_positions):
            text = batch[source_index]
            text_padded[row, : text.size(0)] = text

        if speakers is not None:
            if torch.is_tensor(speakers):
                speakers = speakers[ids_sorted_decreasing]
            else:
                speakers = [speakers[i] for i in sorted_positions]

        return (text_padded, speakers, input_lengths, ids_sorted_decreasing)


from dataclasses import dataclass

# @dataclass
# class Batch:
#     textint_padded: torch.Tensor,
#     input_lengths: list
#     mel_padded: torch.Tensor
#     gate_padded:
#     output_length: list,
#     speaker_ids: list,
#     f0_padded: list,

# # export

# from uberduck_ml_dev.text.symbols import (
#     DEFAULT_SYMBOLS,
#     IPA_SYMBOLS,
#     NVIDIA_TACO2_SYMBOLS,
#     GRAD_TTS_SYMBOLS,
# )


# def pad_sequences(batch):
#     input_lengths = torch.LongTensor([len(x) for x in batch])
#     max_input_len = input_lengths.max()

#     text_padded = torch.LongTensor(len(batch), max_input_len)
#     text_padded.zero_()
#     for i in range(len(batch)):
#         text = batch[i]
#         text_padded[i, : text.size(0)] = text

#     return text_padded, input_lengths


# def prepare_input_sequence(
#     texts, cpu_run=False, arpabet=False, symbol_set=NVIDIA_TACO2_SYMBOLS
# ):
#     p_arpabet = float(arpabet)
#     seqs = []
#     for text in texts:
#         seqs.append(
#             torch.IntTensor(
#                 text_to_sequence(
#                     text,
#                     ["english_cleaners"],
#                     p_arpabet=p_arpabet,
#                     symbol_set=symbol_set,
#                 )[:]
#             )
#         )
#     text_padded, input_lengths = pad_sequences(seqs)
#     if not cpu_run:
#         text_padded = text_padded.cuda().long()
#         input_lengths = input_lengths.cuda().long()
#     else:
#         text_padded = text_padded.long()
#         input_lengths = input_lengths.long()

#     return text_padded, input_lengths