<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [None]:
#default_exp data_loader

In [None]:
#export
import os
import random
import re
from pathlib import Path
from typing import List, Optional

import numpy as np
from scipy.io.wavfile import read
import torch
from torch.utils.data import Dataset

from uberduck_ml_dev.models.common import STFT, MelSTFT
from uberduck_ml_dev.text.util import text_to_sequence
from uberduck_ml_dev.utils import load_filepaths_and_text


In [None]:
?text_to_sequence

In [None]:
# export

def _orig_to_dense_speaker_id(speaker_ids):
    speaker_ids = sorted(list(set(speaker_ids)))
    return {
        orig: idx for orig, idx in zip(speaker_ids, range(len(speaker_ids)))
    }


class TextMelDataset(Dataset):
    """Dataset yielding ``(text_sequence, melspec, speaker_id)`` triples.

    Every filelist entry is unpacked as ``(audio path, transcription,
    speaker id)``; audio paths are resolved relative to ``dataset_path``.
    """

    def __init__(
        self,
        dataset_path: str,
        audiopaths_and_text: str,
        text_cleaners: List[str],
        n_mel_channels: int,
        sample_rate: int,
        mel_fmin: float,
        mel_fmax: float,
        filter_length: int,
        hop_length: int,
        win_length: int,
        max_wav_value: float = 32768.0,
        include_f0: bool = False,
        debug: bool = False,
        debug_dataset_size: Optional[int] = None,
    ):
        """Load the filelist and set up the mel-spectrogram transform.

        Args:
            dataset_path: Root directory of the dataset.
            audiopaths_and_text: Filelist name, joined onto ``dataset_path``
                and handed to ``load_filepaths_and_text``.
            text_cleaners: Cleaner names forwarded to ``text_to_sequence``.
            n_mel_channels, sample_rate, mel_fmin, mel_fmax, filter_length,
                hop_length, win_length: Forwarded to ``MelSTFT``.
            max_wav_value: Divisor applied to the raw waveform samples.
            include_f0: Not supported yet; must be false.
            debug: Enables the ``debug_dataset_size`` length cap.
            debug_dataset_size: Maximum length reported when ``debug`` is set.

        Raises:
            NotImplementedError: If ``include_f0`` is true.
        """
        super().__init__()
        if include_f0:
            # ``NotImplemented`` is a comparison sentinel, not an exception;
            # raising it would surface as a confusing TypeError.
            raise NotImplementedError("include_f0 is not supported yet")
        path = str(Path(dataset_path) / audiopaths_and_text)
        self.dataset_path = dataset_path
        self.audiopaths_and_text = load_filepaths_and_text(path)
        self.text_cleaners = text_cleaners
        self.stft = MelSTFT(
            filter_length=filter_length,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            sampling_rate=sample_rate,
            mel_fmin=mel_fmin,
            mel_fmax=mel_fmax,
        )
        self.max_wav_value = max_wav_value
        self.include_f0 = include_f0
        # Speaker id lookup table: original id -> dense 0-based index.
        speaker_ids = [i[2] for i in self.audiopaths_and_text]
        self._speaker_id_map = _orig_to_dense_speaker_id(speaker_ids)
        self.debug = debug
        self.debug_dataset_size = debug_dataset_size

    def _get_data(self, audiopath_and_text):
        """Build one training example from a filelist entry.

        Returns:
            Tuple ``(text_sequence, melspec, speaker_id)`` where
            ``speaker_id`` is the dense index from the lookup table.
        """
        path, transcription, speaker_id = audiopath_and_text
        speaker_id = self._speaker_id_map[speaker_id]
        path = Path(self.dataset_path) / path
        # NOTE(review): the file's own sample rate is discarded and never
        # compared with the rate MelSTFT was configured with.
        _, wav_data = read(path)
        text_sequence = torch.LongTensor(text_to_sequence(transcription, self.text_cleaners))
        audio = torch.FloatTensor(wav_data)
        audio_norm = audio / self.max_wav_value
        # Add a leading dimension for the STFT call; it is squeezed off below.
        audio_norm = audio_norm.unsqueeze(0)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)
        return (text_sequence, melspec, speaker_id)

    def __getitem__(self, idx):
        """Return the ``(text_sequence, melspec, speaker_id)`` example at ``idx``."""
        return self._get_data(self.audiopaths_and_text[idx])

    def __len__(self):
        """Return the number of examples, capped in debug mode.

        The cap never exceeds the real filelist length, so every index below
        ``len(self)`` is valid for ``__getitem__``.
        """
        total = len(self.audiopaths_and_text)
        if self.debug and self.debug_dataset_size:
            return min(self.debug_dataset_size, total)
        return total

In [None]:
# export


class TextMelCollate:
    """Collate function that zero-pads a list of examples into batch tensors."""

    def __init__(self, n_frames_per_step: int = 1, include_f0: bool = False):
        self.include_f0 = include_f0
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collate a training batch from normalized text and mel-spectrograms.

        PARAMS
        ------
        batch: [text_normalized, mel_normalized, speaker_id]
            (plus an f0 tensor in position 3 when ``include_f0`` is set)

        RETURNS
        -------
        (text_padded, input_lengths, mel_padded, gate_padded,
         output_lengths, speaker_ids[, f0_padded]), with rows ordered by
        decreasing text length.
        """
        batch_size = len(batch)

        # Order examples by decreasing text length.
        text_lens = torch.LongTensor([len(sample[0]) for sample in batch])
        input_lengths, ids_sorted_decreasing = torch.sort(
            text_lens, dim=0, descending=True
        )
        max_input_len = input_lengths[0]

        # Right zero-pad the text sequences to the longest one.
        text_padded = torch.LongTensor(batch_size, max_input_len)
        text_padded.zero_()
        for row, src in enumerate(ids_sorted_decreasing):
            tokens = batch[src][0]
            text_padded[row, : tokens.size(0)] = tokens

        # Target length: longest mel, rounded up to a multiple of
        # n_frames_per_step.
        num_mels = batch[0][1].size(0)
        max_target_len = max(sample[1].size(1) for sample in batch)
        remainder = max_target_len % self.n_frames_per_step
        if remainder != 0:
            max_target_len += self.n_frames_per_step - remainder
            assert max_target_len % self.n_frames_per_step == 0

        # Output buffers: padded mels, gate targets, lengths and speaker ids.
        mel_padded = torch.FloatTensor(batch_size, num_mels, max_target_len)
        mel_padded.zero_()
        gate_padded = torch.FloatTensor(batch_size, max_target_len)
        gate_padded.zero_()
        output_lengths = torch.LongTensor(batch_size)
        speaker_ids = torch.LongTensor(batch_size)
        if self.include_f0:
            f0_padded = torch.FloatTensor(batch_size, 1, max_target_len)
            f0_padded.zero_()

        for row, src in enumerate(ids_sorted_decreasing):
            sample = batch[src]
            mel = sample[1]
            n_frames = mel.size(1)
            mel_padded[row, :, :n_frames] = mel
            # Gate is 1 from the last real frame through the padding.
            gate_padded[row, n_frames - 1 :] = 1
            output_lengths[row] = n_frames
            speaker_ids[row] = sample[2]
            if self.include_f0:
                f0 = sample[3]
                f0_padded[row, :, : f0.size(1)] = f0

        # NOTE(zach): would model_inputs be better as a namedtuple or dataclass?
        model_inputs = (
            text_padded,
            input_lengths,
            mel_padded,
            gate_padded,
            output_lengths,
            speaker_ids,
        )
        if self.include_f0:
            model_inputs = model_inputs + (f0_padded,)

        return model_inputs

In [None]:
# Sanity check: the duplicate 9 collapses and ids are ranked in sorted order.
_orig_to_dense_speaker_id([4, 2, 9, 3, 9])

{2: 0, 3: 1, 4: 2, 9: 3}

In [None]:
# Build a debug-sized dataset; keyword arguments make each setting explicit.
ds = TextMelDataset(
    dataset_path="../dataset",
    audiopaths_and_text="val.txt",
    text_cleaners=["english_cleaners"],
    n_mel_channels=80,
    sample_rate=22050,
    mel_fmin=0,
    mel_fmax=8000,
    filter_length=1024,
    hop_length=256,
    win_length=1024,
    debug=True,
    debug_dataset_size=12,
)
len(ds)

12

In [None]:
from torch.utils.data import DataLoader

# Pull batches of 12 through the collate function and show each one.
collate_fn = TextMelCollate()
dl = DataLoader(ds, batch_size=12, collate_fn=collate_fn)
for i, batch in enumerate(dl):
    print(i, batch, sep="\n")

0
(tensor([[97, 79, 86,  ..., 88, 81,  4],
        [86, 83, 80,  ...,  0,  0,  0],
        [94, 82, 75,  ...,  0,  0,  0],
        ...,
        [89, 82,  4,  ...,  0,  0,  0],
        [78, 83, 78,  ...,  0,  0,  0],
        [99, 79, 75,  ...,  0,  0,  0]]), tensor([86, 76, 66, 51, 50, 35, 22, 20, 16, 14, 14,  9]), tensor([[[ -9.2433,  -8.2504,  -7.7021,  ...,   0.0000,   0.0000,   0.0000],
         [ -8.8916,  -8.1060,  -7.6521,  ...,   0.0000,   0.0000,   0.0000],
         [ -9.2697,  -8.5779,  -8.2722,  ...,   0.0000,   0.0000,   0.0000],
         ...,
         [-11.5129, -11.3297, -10.4252,  ...,   0.0000,   0.0000,   0.0000],
         [-11.5129, -11.2788, -10.6261,  ...,   0.0000,   0.0000,   0.0000],
         [-11.5129, -11.5129, -10.5454,  ...,   0.0000,   0.0000,   0.0000]],

        [[-11.5129, -11.5129, -10.1166,  ...,  -8.9449, -10.6371, -11.5129],
         [-11.5129, -11.5129, -10.0481,  ...,  -8.8486, -10.6160, -11.5129],
         [-11.5129, -11.5129, -10.7568,  ...,  -8.79

