# Speech to Text Datasets

> Lhotse-based speech-to-text dataset and a LibriSpeech Lightning data module

In [None]:
#| default_exp audio.datasets.stt

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [None]:
#| export
import torch
from torch.utils.data import DataLoader, Dataset
from lightning import LightningDataModule, LightningModule
from matplotlib import pyplot as plt
from lhotse.dataset import BucketingSampler, OnTheFlyFeatures
from lhotse.dataset.collation import TokenCollater
from lhotse.recipes import download_librispeech, prepare_librispeech
from lhotse.dataset.vis import plot_batch
from lhotse import CutSet, RecordingSet, SupervisionSet, Fbank, FbankConfig
from pathlib import Path
from pprint import pprint

## Base class

In [None]:
#| export

class STTDataset(Dataset):
    """Dataset that converts a batch of cuts into padded features and tokens.

    Items are looked up with a whole `CutSet` instead of an integer index, so
    each `__getitem__` call yields one complete batch.
    """
    def __init__(self,
        tokenizer:TokenCollater, # text tokenizer
        num_mel_bins:int=80 # number of mel spectrogram bins
        ):
        # Filterbank features are computed on the fly from the audio of each batch.
        fbank = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
        self.tokenizer = tokenizer
        self.extractor = OnTheFlyFeatures(fbank)

    def __getitem__(self, cuts: CutSet) -> dict:
        """Return features, feature lengths and tokens for `cuts`.

        The cuts are sorted by duration first so that features and tokens are
        produced in the same order.
        """
        ordered = cuts.sort_by_duration()
        feats_pad, ilens = self.extractor(ordered)
        # The tokenizer also returns token lengths; they are not part of the batch.
        tokens_pad, _ = self.tokenizer(ordered)
        return {"feats_pad": feats_pad, "ilens": ilens, "tokens_pad": tokens_pad}



## LibriSpeech DataModule

In [None]:
#| export
class LibriSpeechDataModule(LightningDataModule):
    """Lightning data module that downloads, prepares and serves LibriSpeech with lhotse.

    `prepare_data` downloads the corpus, `setup` builds the manifests, the cut
    sets and the tokenizer, and the dataloader methods wrap `STTDataset` with a
    duration-bucketing sampler.
    """
    def __init__(self,
        target_dir="../data/en", # where data will be saved / retrieved
        dataset_parts="mini_librispeech", # either full librispeech or mini subset
        output_dir="../recipes/stt/librispeech/data", # where to save manifest
        num_jobs=1 # num_jobs depending on number of cpus available
    ):
        super().__init__()
        # Stores the constructor arguments under `self.hparams`.
        self.save_hyperparameters(logger=False)
        self.libri = {}

    def prepare_data(self,) -> None:
        """Download the requested LibriSpeech parts into `target_dir`."""
        download_librispeech(target_dir=self.hparams.target_dir, dataset_parts=self.hparams.dataset_parts)

    def _load_cuts(self, split: str) -> CutSet:
        """Build a `CutSet` from the prepared manifests of one split."""
        return CutSet.from_manifests(**self.libri[split])

    def _check_tokenizer_roundtrip(self) -> None:
        """Encode and decode the first two test cuts; results are discarded.

        Kept from the original setup code, where the same calls ran in every
        stage branch: any exception raised by the tokenizer surfaces here.
        """
        sample = self.cuts_test.subset(first=2)
        self.tokenizer(sample)
        self.tokenizer.inverse(*self.tokenizer(sample))

    def setup(self, stage = None):
        """Prepare manifests, then create cut sets and tokenizer for `stage`.

        `stage` of "fit" or `None` loads the train and test cuts and builds the
        tokenizer from the train cuts; "test" loads only the test cuts and
        builds the tokenizer from them. Other stages only prepare the manifests.
        """
        self.libri = prepare_librispeech(
            Path(self.hparams.target_dir) / "LibriSpeech",
            dataset_parts=self.hparams.dataset_parts,
            output_dir=self.hparams.output_dir,
            num_jobs=self.hparams.num_jobs,
        )
        # NOTE(review): the split names below are the mini_librispeech ones;
        # presumably other `dataset_parts` values need different keys — confirm.
        if stage is None or stage == "fit":
            self.cuts_train = self._load_cuts("train-clean-5")
            self.cuts_test = self._load_cuts("dev-clean-2")
            self.tokenizer = TokenCollater(self.cuts_train)
            self._check_tokenizer_roundtrip()
        elif stage == "test":
            self.cuts_test = self._load_cuts("dev-clean-2")
            # NOTE(review): this replaces any tokenizer built during "fit" with one
            # derived from the test cuts; verify the vocabulary is meant to differ.
            self.tokenizer = TokenCollater(self.cuts_test)
            self._check_tokenizer_roundtrip()

    def train_dataloader(self):
        """Return a shuffled, duration-bucketed loader over the training cuts."""
        train_sampler = BucketingSampler(self.cuts_train, max_duration=300, shuffle=True) #, bucket_method="equal_duration")
        # batch_size=None: the sampler already yields whole batches of cuts.
        return DataLoader(STTDataset(self.tokenizer), sampler=train_sampler, batch_size=None, num_workers=2)

    def test_dataloader(self):
        """Return an unshuffled, duration-bucketed loader over the test cuts."""
        test_sampler = BucketingSampler(self.cuts_test, max_duration=400, shuffle=False) #, bucket_method="equal_duration")
        return DataLoader(STTDataset(self.tokenizer), sampler=test_sampler, batch_size=None, num_workers=2)

    @property
    def model_kwargs(self):
        """Model arguments derived from the data: `odim` is the tokenizer vocabulary size."""
        return {
            "odim": len(self.tokenizer.idx2token),
        }

## Usage

In [None]:
# Data module for the mini LibriSpeech subset; manifests are written to the
# same ../data/en/LibriSpeech directory that holds the downloaded corpus.
dm = LibriSpeechDataModule(
    target_dir="../data/en", 
    dataset_parts="mini_librispeech",
    output_dir="../data/en/LibriSpeech",
    num_jobs=1
)

In [None]:
# Download the corpus into target_dir.
# Skip this cell at export time so the export does not spend time on the download.
dm.prepare_data()

In [None]:
# libri = prepare_librispeech("../data/en/LibriSpeech", dataset_parts='mini_librispeech')

In [None]:
# Delete the compressed (*.gz) files in the manifest directory.
# NOTE(review): presumably this forces setup() to regenerate the manifests — confirm.
! rm ../data/en/LibriSpeech/*.gz

In [None]:
# Test stage only: prepares the manifests, loads the dev-clean-2 cuts and
# builds the tokenizer from them (no training cuts are loaded).
dm.setup(stage='test')

In [None]:
# Echo the test CutSet created by setup().
dm.cuts_test

In [None]:
# Read the dev-clean-2 manifests straight from disk and compare their sizes.
recs = RecordingSet.from_file("../data/en/LibriSpeech/librispeech_recordings_dev-clean-2.jsonl.gz")
# Use `from_file` here as well: calling `SupervisionSet(path)` hands the path
# string to the constructor as if it were the segments, so nothing is read.
sup = SupervisionSet.from_file("../data/en/LibriSpeech/librispeech_supervisions_dev-clean-2.jsonl.gz")
print(len(recs),len(sup))

In [None]:
# Fetch the first test batch, print the shape of each entry and show the
# features of its first item as an image.
test_dl = dm.test_dataloader()
b = next(iter(test_dl))
print(*(b[key].shape for key in ("feats_pad", "tokens_pad", "ilens")))
first_feats = b["feats_pad"][0]
# Swap the two axes of the first item's features before displaying them.
plt.imshow(first_feats.transpose(0, 1), origin="lower")

# dm.tokenizer.idx2token(b["tokens_pad"][0])
# dm.tokenizer.inverse(b["tokens_pad"][0], b["ilens"][0])

In [None]:
# Print the test CutSet, then take its first cut and plot its audio.
print(dm.cuts_test)
cut = dm.cuts_test[0]
# pprint(cut.to_dict())
cut.plot_audio()

In [None]:
#| hide
# Export the cells marked `#| export` to the library module.
import nbdev

nbdev.nbdev_export()