# Lhotse support for datasets

> Build datasets on top of the preliminary data preparation provided by Lhotse recipes

In [None]:
#| default_exp data.utils.lhotse

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

## TTS Lhotse

In [None]:
#| export
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader

from lhotse import CutSet, RecordingSet, SupervisionSet, Fbank, FbankConfig, MonoCut, NumpyFilesWriter, NumpyHdf5Writer
from lhotse.dataset import BucketingSampler, OnTheFlyFeatures, DynamicBucketingSampler
from lhotse.dataset.collation import TokenCollater
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.dataset.vis import plot_batch
from lhotse.recipes import download_librispeech, prepare_librispeech, download_ljspeech, prepare_ljspeech

from typing import Tuple, Dict
import json
import numpy as np

from nimrod.audio.embedding import EncoDecExtractor
from nimrod.text.normalizers import TTSTextNormalizer
from nimrod.text.phonemizers import Phonemizer


In [None]:
#| export
class Encoder(nn.Module):
    """Two-layer MLP mapping a flattened 28x28 input to a 3-dimensional code."""

    def __init__(self):
        super().__init__()
        # Single Sequential stored as `l1`, so parameters live under
        # l1.0.* (first linear layer) and l1.2.* (second linear layer).
        stack = [
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        ]
        self.l1 = nn.Sequential(*stack)

    def forward(self, x):
        """Run `x` (last dimension of size 784) through the layer stack."""
        code = self.l1(x)
        return code


class Decoder(nn.Module):
    """Two-layer MLP mapping a 3-dimensional code back to 28x28 = 784 values."""

    def __init__(self):
        super().__init__()
        # Single Sequential stored as `l1`, so parameters live under
        # l1.0.* (first linear layer) and l1.2.* (second linear layer).
        stack = [
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28),
        ]
        self.l1 = nn.Sequential(*stack)

    def forward(self, x):
        """Run `x` (last dimension of size 3) through the layer stack."""
        reconstruction = self.l1(x)
        return reconstruction

### Usage

#### Download data and load into Lhotse cuts

In [None]:
# One-time corpus download, left commented out because it has already been done:
# download_ljspeech('~/Data/en/')
# Prepare the Lhotse manifests from the local LJSpeech copy (first argument);
# the second argument is the directory passed as the manifest output location.
# The result is unpacked with CutSet.from_manifests(**ljspeech) in the next cell.
ljspeech = prepare_ljspeech('../data/en/LJSpeech-1.1', '../recipes/tts/ljspeech/data')

In [None]:
# Build cuts from the prepared manifests, keep the first three for quick
# experiments, and round-trip that small subset through a manifest file.
subset_manifest = '../recipes/tts/ljspeech/data/first_3.jsonl.gz'

cut_set = CutSet.from_manifests(**ljspeech)
subset = cut_set.subset(first=3)

subset.to_file(subset_manifest)
reload_subset = CutSet.from_file(subset_manifest)

In [None]:
# Show the same cut from the in-memory subset and from the reloaded manifest
# so the two can be compared by eye, then report the subset size.
for manifest in (subset, reload_subset):
    print(manifest[1])
print(len(subset))

#### Encodec feature extractor

In [None]:
# Instantiate the project's EncoDecExtractor (nimrod.audio.embedding) with its
# default arguments; it is passed as the Lhotse feature extractor below.
encodec_extractor = EncoDecExtractor()

In [None]:
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)

In [None]:
# TODO: fix bug for num_jobs > 1 (only the single-job path is used for now)
# Compute features for every cut in `subset` with the extractor above and store
# them under `storage_path`; the returned value is bound to `cuts` and used by
# all following cells.
cuts = subset.compute_and_store_features(
    extractor=encodec_extractor,
    storage_path="../recipes/tts/ljspeech/data/encodec",
    num_jobs=1,
    # storage_type=NumpyHdf5Writer
)

In [None]:
# Inspect the first cut returned by compute_and_store_features.
print(cuts[0])

In [None]:
# Persist the cuts returned by feature extraction and reload them to check the
# manifest round trip.
cuts.to_file("../recipes/tts/ljspeech/data/first_3.encodec.jsonl.gz")
reload_cuts = CutSet.from_file("../recipes/tts/ljspeech/data/first_3.encodec.jsonl.gz")
# A bare expression in the middle of a notebook cell is evaluated and then
# discarded (only the last expression is displayed), so the in-memory cut is
# printed explicitly to make the comparison with the reloaded one visible.
print(cuts[0])
reload_cuts[0]

In [None]:
# cuts[0].recording
# IPython shell escape: run SoX's `soxi` on the first LJSpeech wav file to
# print its audio header information.
!soxi '../data/en/LJSpeech-1.1/wavs/LJ001-0001.wav'

In [None]:
# Load the stored features of all cuts as one batch plus per-cut lengths.
strategy = PrecomputedFeatures()
feats, feats_len = strategy(cuts)

# Per-cut view first (shape of each batch entry, length of each cut),
# then the shapes of the two batch tensors themselves.
per_cut_shapes = [cut_feats.shape for cut_feats in feats]
per_cut_lengths = [int(n_frames) for n_frames in feats_len]
print(per_cut_shapes)
print(per_cut_lengths)
print(feats.shape, feats_len.shape)

# TODO: debug OnTheFlyFeature case
# strategy = OnTheFlyFeatures(extractor=encodec_extractor)
# feats, feats_len = strategy(cuts)
# print(feats, feats_len)

#### Text normalization, tokenization and numericalization

In [None]:
# Text front-end objects from the nimrod package, both built with default
# arguments: `cleaner` normalizes raw text and `tokenizer` is the phonemizer
# whose output is stored under cut.custom['phonemes'] further down.
cleaner = TTSTextNormalizer()
tokenizer = Phonemizer()

In [None]:
# Quick manual check of the normalizer on a short sample sentence; the result
# is displayed as the cell output.
cleaner("tutu. this is ture!")

In [None]:
n_jobs = 1  # not used below: the loop processes cuts sequentially
unique_phonemes = set()
# Annotate every cut with its normalized text and phoneme sequence and stream
# the annotated cuts to a new manifest.
with CutSet.open_writer('../recipes/tts/ljspeech/data/first_3.final.jsonl.gz', overwrite=True) as writer:
    for cut in cuts:
        text = cut.supervisions[0].text
        print(text)
        normalized = cleaner(text)
        print(normalized)
        # Phonemize the *normalized* text: the raw transcript was previously
        # passed here, which left the normalization step without any effect
        # on the stored phonemes.
        phonemes = tokenizer(normalized)
        print(phonemes)
        # Merge into any existing custom fields instead of replacing them.
        cut.custom = {**(cut.custom or {}), 'normalized': normalized, 'phonemes': phonemes}
        writer.write(cut, flush=True)
        unique_phonemes.update(phonemes)


#### Export phoneme lexicon

In [None]:
# Reload the annotated cuts from the manifest written by the previous cell
# (it lives next to the other recipe manifests, not in the raw corpus folder).
cuts = CutSet.from_file('../recipes/tts/ljspeech/data/first_3.final.jsonl.gz')
print(cuts[0])

# Collect every distinct symbol found in the stored phoneme sequences.
unique_syms = set()
for cut in cuts:
    unique_syms.update(cut.custom['phonemes'])

# index -> symbol table in sorted order, with "<eps>" appended as the last
# entry. Named `phoneme_map` so the builtin `map` is not shadowed.
phoneme_map = dict(enumerate(sorted(unique_syms)))
phoneme_map[len(phoneme_map)] = "<eps>"
print(phoneme_map, len(phoneme_map))

# JSON object keys are strings, so the integer indices are written as strings.
json_map = json.dumps(phoneme_map)
with open("../data/en/LJSpeech-1.1/map.json", "w") as f:
    f.write(json_map)

In [None]:
# Read the exported lexicon back from disk and show it.
with open('../data/en/LJSpeech-1.1/map.json', 'r') as f:
    raw_json = f.read()
data = json.loads(raw_json)

print(data)

#### Collate

In [None]:
#| export
class PhonemeCollater(TokenCollater):
    """Collate the phoneme sequences stored in ``cut.custom['phonemes']``.

    Same interface as the parent ``TokenCollater``, but both the vocabulary and
    the encoded sequences come from the per-cut phoneme entry instead of the
    supervision text.

    Args:
        cuts: cuts whose ``custom['phonemes']`` entries define the vocabulary.
        add_eos: append ``eos_symbol`` to every sequence.
        add_bos: prepend ``bos_symbol`` to every sequence.
        pad_symbol: symbol used to pad sequences to the batch maximum length.
        bos_symbol: beginning-of-sequence symbol.
        eos_symbol: end-of-sequence symbol.
        unk_symbol: symbol substituted for anything missing from the vocabulary.
    """

    def __init__(
            self,  cuts: CutSet,
            add_eos: bool = True,
            add_bos: bool = True,
            pad_symbol: str = "<pad>",
            bos_symbol: str = "<bos>",
            eos_symbol: str = "<eos>",
            unk_symbol: str = "<unk>",
        ):
        super().__init__(
            cuts,
            add_eos=add_eos,
            add_bos=add_bos,
            pad_symbol=pad_symbol,
            bos_symbol=bos_symbol,
            eos_symbol=eos_symbol,
            unk_symbol=unk_symbol
            )
        # Rebuild the vocabulary from the phoneme entries, replacing the tables
        # assigned just below the parent call's own setup.
        tokens = {char for cut in cuts for char in cut.custom['phonemes']}
        tokens_unique = (
            [pad_symbol, unk_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(tokens_unique)}
        self.idx2token = list(tokens_unique)
        # Index used for symbols that are not part of the vocabulary.
        self._unk_idx = self.token2idx[unk_symbol]

    def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the phonemes of `cuts` as a padded index batch.

        Returns:
            ``(tokens_batch, tokens_lens)``: an int64 tensor of token indices
            padded with ``pad_symbol``, and an int tensor with the unpadded
            length of every sequence (bos/eos included when enabled).
        """
        # Split every entry into symbols exactly the way __init__ did when it
        # built the vocabulary. Joining the symbols with spaces first would
        # interleave a space token between all of them.
        token_sequences = [list(cut.custom['phonemes']) for cut in cuts]
        # default=0 keeps an empty batch from raising inside max().
        max_len = max((len(seq) for seq in token_sequences), default=0)
        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + seq
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in token_sequences
        ]

        tokens_batch = torch.from_numpy(
            np.array(
                [
                    # Unseen symbols map to the unk index instead of raising KeyError.
                    [self.token2idx.get(token, self._unk_idx) for token in seq]
                    for seq in seqs
                ],
                dtype=np.int64,
            )
        )

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in token_sequences
            ]
        )

        return tokens_batch, tokens_lens


In [None]:
# Display the first reloaded cut (it carries the `custom` phoneme fields read
# by PhonemeCollater).
cuts[0]

In [None]:
# Build the collater vocabulary from `cuts`, encode the same cuts, and decode
# the batch again with the inherited `inverse` as a round-trip check.
pc = PhonemeCollater(cuts)
tokens, tokens_len = pc(cuts)
print(tokens, tokens_len)
print(pc.inverse(tokens, tokens_len))

In [None]:
class ValleDataset(Dataset):
    """Dataset that converts a whole batch of cuts into padded tensors.

    Args:
        cuts: cuts used to build the phoneme vocabulary of the collater.
        strategy: callable returning ``(features, feature_lengths)`` for a
            CutSet; defaults to loading precomputed features.
    """

    def __init__(
            self,
            cuts:CutSet,
            strategy:BatchIO=PrecomputedFeatures()
        ):
        self.tokenizer = PhonemeCollater(cuts)
        self.extractor = strategy

    def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
        """Collate `cuts` into feature and token tensors.

        Unlike a usual PyTorch dataset, the "index" is a full CutSet and the
        returned dict already is a batch.
        """
        # Sort once so features and tokens are collated in the same cut order.
        ordered = cuts.sort_by_duration()
        feats_pad, feats_lens = self.extractor(ordered)
        tokens_pad, tokens_lens = self.tokenizer(ordered)

        batch = {}
        batch["feats_pad"] = feats_pad
        batch["feats_lens"] = feats_lens
        batch["tokens_pad"] = tokens_pad
        batch["tokens_lens"] = tokens_lens
        return batch

In [None]:
# ValleDataset.__getitem__ receives a whole CutSet, so it is driven by a Lhotse
# sampler that yields batches of cuts rather than integer indices.
ds = ValleDataset(cuts)
# Alternative sampler kept for reference:
# train_sampler = BucketingSampler(cuts, max_duration=300, shuffle=True, bucket_method="equal_duration")
train_sampler = DynamicBucketingSampler(cuts, max_duration=300, shuffle=True, num_buckets=2)
# Dataset performs batching by itself, so we have to indicate that to the DataLoader with batch_size=None
dl = DataLoader(ds, sampler=train_sampler, batch_size=None, num_workers=0)
# Pull a single batch to check the whole pipeline end to end.
print(next(iter(dl)))

In [None]:
#| hide
# Export the cells marked for export to the library module.
import nbdev

nbdev.nbdev_export()