# Imports

In [2]:
import time
import os
import pickle
import torch
import speechbrain as sb
from speechbrain.dataio.dataio import read_audio
from loquacious_set_prepare import load_datasets

# from tqdm import tqdm
from tqdm.notebook import tqdm
from hyperpyyaml import load_hyperpyyaml
from speechbrain.dataio.sampler import DynamicBatchSampler

import torchaudio
import torchaudio.transforms as T

import pandas as pd
import numpy as np
from sklearn.cluster import MiniBatchKMeans, KMeans
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt


# !pip install -U transformers


# Build the dataloader

## Load hparams and define the `data_prep` function

In [3]:
# Experiment configuration: base YAML file plus per-run overrides.
hparams_file = "/local_disk/apollon/rwhetten/sss_data_selection/hparams/data_select.yaml"
overrides = dict(
    tls_subset="small",
    hf_hub="speechbrain/LoquaciousSet",
    hf_caching_dir="/local_disk/apollon/rwhetten/hf_root/datasets",
    save_int=5,
    ckpt_path="ckpt.pkl",
    feature_function_name="mel",
)

with open(hparams_file, encoding="utf-8") as hparams_stream:
    hparams = load_hyperpyyaml(hparams_stream, overrides)

# Load the configured LoquaciousSet subset through the project helper.
hf_data_dict = load_datasets(
    hparams["tls_subset"],
    hparams["hf_hub"],
    hparams["hf_caching_dir"],
)

Using the latest cached version of the dataset since speechbrain/LoquaciousSet couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'small' at /local_disk/apollon/rwhetten/hf_root/datasets/speechbrain___loquacious_set/small/0.0.0/720eb654f18f115053c4aea052133234a8bb34b7 (last modified on Mon Jul 14 14:10:50 2025).


Loading dataset shards:   0%|          | 0/66 [00:00<?, ?it/s]

In [4]:
def data_prep(data_dict, hparams):
    """Build a SpeechBrain dataloader over the ``train`` split of ``data_dict``.

    Arguments
    ---------
    data_dict : dict-like
        Mapping from split name to a Hugging Face (arrow) dataset, as returned
        by ``load_datasets``. Only ``data_dict["train"]`` is used. It must have
        an ``ID`` column, a ``duration`` column and a ``wav`` column whose
        values hold the encoded audio under the ``"bytes"`` key.
    hparams : dict
        Loaded hyperparameters. Reads ``dynamic_batch_sampler_train`` (extra
        keyword arguments for ``DynamicBatchSampler``) and ``num_workers``.

    Returns
    -------
    The dataloader created by ``sb.dataio.dataloader.make_dataloader``. Its
    batches expose the keys ``id``, ``audio_id`` and ``sig``, and are grouped by
    a ``DynamicBatchSampler`` driven by the ``duration`` column.
    """
    # Rename 'ID': SpeechBrain's sampling already uses an 'id' key, and this
    # column is not an index but an audio identifier/path.
    train_data = data_dict["train"].rename_column("ID", "audio_id")
    # Precompute durations so the dynamic batch sampler does not have to
    # fetch every item to learn its length (much faster).
    train_len_list = list(train_data.select_columns("duration")["duration"])
    train_data = sb.dataio.dataset.DynamicItemDataset.from_arrow_dataset(train_data)

    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        # Decode the audio stored in the row's "bytes" field.
        return read_audio(wav["bytes"])

    # For now only the train split is prepared.
    datasets = [train_data]
    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
    sb.dataio.dataset.set_output_keys(datasets, ["id", "audio_id", "sig"])

    train_batch_sampler = DynamicBatchSampler(
        train_data,
        length_func=lambda x: x["duration"],
        lengths_list=train_len_list,
        **hparams["dynamic_batch_sampler_train"],
    )

    return sb.dataio.dataloader.make_dataloader(
        train_data,
        batch_sampler=train_batch_sampler,
        num_workers=hparams["num_workers"],
    )

## Run data prep

In [181]:
# Dataloader for the first (checkpoint-saving) pass.
dl = data_prep(data_dict=hf_data_dict, hparams=hparams)


In [182]:
from speechbrain.utils.checkpoints import Checkpointer
import yaml


In [183]:
@sb.utils.checkpoints.register_checkpoint_hooks
class Step:
    """Checkpointable step counter.

    Calling the instance advances ``step`` by one. ``step`` and ``end`` are
    written to and read from a YAML file by the SpeechBrain checkpoint hooks
    (``_save`` / ``_recover``), so the counter can be restored with a run.
    """

    def __init__(self, step=0, end=0):
        self.step, self.end = step, end

    def __call__(self):
        self.step += 1

    def __str__(self):
        return f"step count: {self.step}"

    @sb.utils.checkpoints.mark_as_saver
    def _save(self, path):
        """Serialize ``step`` and ``end`` to ``path`` as a YAML mapping."""
        state = {"step": self.step, "end": self.end}
        with open(path, "w", encoding="utf-8") as stream:
            yaml.dump(state, stream)

    @sb.utils.checkpoints.mark_as_loader
    def _recover(self, path, end_of_epoch):
        """Restore ``step`` and ``end`` from the YAML file at ``path``.

        ``end_of_epoch`` is part of the loader-hook signature but is not used.
        """
        del end_of_epoch  # required by the hook signature only
        with open(path, encoding="utf-8") as stream:
            state = yaml.safe_load(stream)
        self.step = state["step"]
        self.end = state["end"]

In [184]:
# Checkpoint the dataloader position together with the step counter.
save_path = "/local_disk/apollon/rwhetten/sss_data_selection/save_dl/"
step = Step()
checkpointer = Checkpointer(
    checkpoints_dir=save_path,
    recoverables={"dataloader": dl, "step": step},
)

In [185]:
# Iterate a few batches and save a checkpoint mid-epoch, so a later loader can
# be recovered from it and checked for where it resumes.
SAVE_AFTER_BATCH = 3
STOP_AFTER_BATCH = 5

for i, batch in enumerate(tqdm(dl)):
    print(f"{i}, {batch.id[0]}")
    step()
    if i == STOP_AFTER_BATCH:
        break
    if i == SAVE_AFTER_BATCH:
        print('saving')
        # Bind to a real name rather than `_`, which IPython uses for the
        # last cell output.
        saved_ckpt = checkpointer.save_checkpoint(end_of_epoch=False)

  0%|          | 0/3174 [00:00<?, ?it/s]

0, 70412
1, 83
2, 11028
3, 41
saving
4, 15375
5, 7825


In [201]:
# Fresh loader and counter, registered with a new Checkpointer on the same
# directory so their state can be restored from a saved checkpoint.
ndl = data_prep(data_dict=hf_data_dict, hparams=hparams)
nstep = Step()
ncheckpointer = Checkpointer(
    checkpoints_dir=save_path,
    recoverables={"dataloader": ndl, "step": nstep},
)


In [202]:
# Restore `ndl` and `nstep` from a saved checkpoint, if one exists. The result
# is bound to a name instead of `_`, which IPython reserves for the last output.
recovered_ckpt = ncheckpointer.recover_if_possible()

100
else


In [203]:
# Show the restored step counter.
print(str(nstep))

step count: 100


In [204]:
# Debug only: private SpeechBrain attribute set by recovery; presumably the
# number of batches the loader will skip on its next iteration.
skip_to = ndl._speechbrain_recovery_skip_to
skip_to

100

In [205]:
# Resume iterating the recovered loader. `initial` starts the progress bar at
# the recovered step so it lines up with the batches skipped on recovery.
STOP_AT_STEP = 200

with tqdm(ndl, initial=nstep.step, dynamic_ncols=True) as t:
    for batch in t:
        nstep()
        # `>=` rather than `==`: if the recovered step is already past the
        # target, `==` never fires and the loop would run the whole epoch.
        if nstep.step >= STOP_AT_STEP:
            break

  3%|##8                                                                                         | 100/3174 [0…

In [None]:
# Sanity check: a loader recovered from the checkpoint must resume exactly at
# the saved position. Its first batch must match batch number `resume_index`
# of an identical loader that was not recovered.
# Assumes the Step counter was saved in lock-step with the loader (both are
# registered with the same Checkpointer above).
check_dl = data_prep(hf_data_dict, hparams)
check_step = Step()
Checkpointer(save_path, {"dataloader": check_dl, "step": check_step}).recover_if_possible()
resume_index = check_step.step
resumed_batch = next(iter(check_dl))

reference_batch = None
for batch_index, candidate in enumerate(data_prep(hf_data_dict, hparams)):
    if batch_index == resume_index:
        reference_batch = candidate
        break

assert reference_batch is not None, f"loader has fewer than {resume_index + 1} batches"
assert resumed_batch.id == reference_batch.id, (resumed_batch.id, reference_batch.id)