In [9]:
%cd /home/adelmou/proj/speechbrain/gslms/speechbrain/recipes/LibriSpeech/SpeechLM/discrete


/home/adelmou/proj/speechbrain/gslms/speechbrain/recipes/LibriSpeech/SpeechLM/discrete


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [11]:
import os
import sys
import torch
import torchaudio
import logging
import speechbrain as sb
from speechbrain.utils.distributed import run_on_main, if_main_process
from hyperpyyaml import load_hyperpyyaml
from pathlib import Path
from speechbrain.lobes.models.huggingface_transformers.discrete_speechlm import (
    DiscreteSpeechLM,
    DiscreteSpeechLMConfig,
    InterleavedCodebookPattern,
)
from torch.nn import functional as F
import torch.nn as nn 
logger = logging.getLogger(__name__)

In [59]:
import torch
from torch.utils.data import Dataset
from copy import deepcopy

class PackedDatasetWrapper(Dataset):
    """Wrapper that packs tokens from an existing DynamicItemDataset."""
    
    def __init__(self, original_dataset, block_size, token_key="tokens", pad_token_id=-1):
        self.original_dataset = original_dataset
        self.block_size = block_size
        self.token_key = token_key
        self.pad_token_id = pad_token_id
        
        # Precompute the mapping from block index to original data indices
        self.blocks = []
        self.blocks_ids = []
        self._prepare_blocks()

    def _prepare_blocks(self):
        """
        Prepares the packed blocks by iterating through the original dataset,
        concatenating tokens until reaching `block_size`, and handling padding.
        """
        print("preparing blocks")
        # Generate list of indices
        indices = list(range(len(self.original_dataset)))

        buffer = []
        buffer_length = 0
        buffer_ids = []
        print(self.original_dataset[0]['tokens'][:, 0])
        for idx in indices:
            data_point = self.original_dataset[idx]
            tokens = data_point['tokens']  # Assuming 'tokens' is a list or tensor
            id_name = data_point.get('id', idx)  # Optional: fetch an 'id' if available

            # Convert tokens to tensor if they aren't already
            if isinstance(tokens, list):
                tokens = torch.tensor(tokens, dtype=torch.long)
            elif not isinstance(tokens, torch.Tensor):
                raise ValueError(f"Unsupported token type: {type(tokens)}")

            seq_length = tokens.size(0)
            if buffer_length + seq_length > self.block_size:
                if buffer:
                    num_tokens_to_keep = self.block_size - buffer_length
                    assert num_tokens_to_keep >= 0, num_tokens_to_keep
                    new_tokens = tokens[:num_tokens_to_keep]
                    buffer.append(new_tokens)
                    buffer_ids.append(id_name)
                    packed_tokens = torch.cat(buffer, dim=0)
                    self.blocks.append(packed_tokens)
                    self.blocks_ids.append(buffer_ids)
                    
                     # reset buffer
                    buffer = []
                    buffer_length = 0
                    buffer_ids = []

                    # calculate upper boundary of tokens to keep
                    splitted_tokens = tokens[num_tokens_to_keep:].split(
                        self.block_size
                    )
                    if len(splitted_tokens) > 0:
                        for t in splitted_tokens:
                            if t.shape[0] == self.block_size:
                                buffer_ids.append(id_name)
                                self.blocks.append(t)
                                self.blocks_ids.append(buffer_ids)

                    # Reset buffer and add the remaining tokens
                    buffer = [splitted_tokens[-1]]
                    buffer_ids = [id_name]
                    buffer_length = splitted_tokens[-1].size(0)
                else:
                    # If single data point exceeds block_size, split it
                    new_tokens = tokens[:self.block_size]
                    buffer_ids.append(id_name)
                    self.blocks.append(new_tokens)
                    self.blocks_ids.append(buffer_ids)

                    # reset buffer
                    buffer = []
                    buffer_length = 0
                    buffer_ids = []

                    # calculate upper boundary of tokens to keep
                    splitted_tokens = tokens[self.block_size:].split(
                        self.block_size
                    )
                    if len(splitted_tokens) > 0:
                        for t in splitted_tokens:
                            if t.shape[0] == self.block_size:
                                buffer_ids.append(id_name)
                                self.blocks.append(t)
                                self.blocks_ids.append(buffer_ids)

                    buffer = [splitted_tokens[-1]]
                    buffer_ids = [id_name]
                    buffer_length = splitted_tokens[-1].size(0) # - self.block_size
            else:
                buffer.append(tokens)
                buffer_ids.append(id_name)
                buffer_length += seq_length

        # Handle remaining tokens in the buffer
        if buffer:
            concatenated = torch.cat(buffer, dim=0)
            buffer_ids.append(id_name)
            if buffer_length < self.block_size:
                padding_length = self.block_size - buffer_length
                padded_buffer = torch.cat([
                    concatenated,
                    torch.full((padding_length, concatenated.size(1)), self.pad_token_id, dtype=concatenated.dtype)
                ], dim=0)
                self.blocks.append(padded_buffer)
                self.blocks_ids.append(buffer_ids)
            else:
                self.blocks.append(concatenated)
                self.blocks_ids.append(buffer_ids)

    def __len__(self) -> int:
        return len(self.blocks)

    def __getitem__(self, idx: int) -> torch.Tensor:
        data_id = '-'.join(self.blocks_ids[idx])
        data_point = self.blocks[idx]
        return {
            "id": data_id,
            "tokens": data_point,
        }

In [60]:
def dataio_prepare(hparams):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined functions."""
    data_folder = "test"


    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
    )
    valid_data = valid_data.filtered_sorted(sort_key="duration")


    datasets = [valid_data]

    # 1. Define tokens pipeline:
    tokens_loader = hparams["tokens_loader"]
    num_codebooks = hparams["num_codebooks"]
    
    @sb.utils.data_pipeline.takes("id")
    @sb.utils.data_pipeline.provides("tokens")
    def tokens_pipeline(id):
        # (T, C)
        tokens = tokens_loader.tokens_by_uttid(id, num_codebooks=num_codebooks)
        # concat eos_token to the end of the sequence
        eos_token = torch.full((1, tokens.size(-1)), hparams['eos_token'], dtype=tokens.dtype, device=tokens.device)
        tokens = torch.cat([tokens, eos_token], dim=0)
        return tokens
        
    sb.dataio.dataset.add_dynamic_item(datasets, tokens_pipeline)

    # 2. Set output:
    sb.dataio.dataset.set_output_keys(
        datasets, ["id", "tokens"],
    )

    valid_data = PackedDatasetWrapper(valid_data, 32, token_key="tokens", pad_token_id=hparams["pad_token"])
    return valid_data


In [63]:
valid_data = dataio_prepare(hparams)
dataloader = sb.dataio.dataloader.make_dataloader(
        valid_data,
        batch_size=3,
        num_workers=0,
        collate_fn=lambda x: sb.dataio.batch.PaddedBatch(x, padding_kwargs={"value": hparams["pad_token"]})
)


for i, batch in enumerate(dataloader):
    tokens, _ = batch['tokens']
    print(tokens[:, :, 0])
    break

preparing blocks
tensor([ 256,  907,  488,  153,    1,  478,  706,  170,  175,  859,  284,  324,
         198,  726,  958,  330,  734,  942,  379,  640,  256,  256,  726,  369,
         135,  890,  760,  241,   57,  894,  387,  829,  464,  984,  241,  441,
         462,  965,  109,  242,  209,  241,  216,  324,  100,  771,  563,  635,
        1017,  598,  178,  730,  750,  867,  324,  178,   32,  896,  484,  320,
         256,  759,  163,  726,  151,  192,  958,  330,  175,  817,  367,  634,
         949, 1024])
tensor([[ 256,  907,  488,  153,    1,  478,  706,  170,  175,  859,  284,  324,
          198,  726,  958,  330,  734,  942,  379,  640,  256,  256,  726,  369,
          135,  890,  760,  241,   57,  894,  387,  829],
        [ 464,  984,  241,  441,  462,  965,  109,  242,  209,  241,  216,  324,
          100,  771,  563,  635, 1017,  598,  178,  730,  750,  867,  324,  178,
           32,  896,  484,  320,  256,  759,  163,  726],
        [ 151,  192,  958,  330,  175,  81

it seems to work. We see the first item splitted into multiple batches.

In [17]:
hparams = r"""
seed: 1986
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
output_folder: !ref results/LibriSpeech/discrete_speechLM/wavlm/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Data files
tokens_folder: /scratch/adelmou/results/dac/librispeech
csv_folder: /home/adelmou/proj/speechbrain/gslms/speechbrain/recipes/LibriSpeech/SpeechLM/discrete/results/LibriSpeech/discrete_speechLM/wavlm/1986
train_csv: !ref <csv_folder>/train.csv
valid_csv: !ref <csv_folder>/dev-clean.csv
test_csv:
   - !ref <csv_folder>/test-clean.csv
   - !ref <csv_folder>/test-other.csv

# Training parameters
number_of_epochs: 20
lr: 3e-3
sorting: ascending
precision: bf16
skip_prep: True 
num_codebooks: 8 
sample_rate: 16000
codebook_size: 1024
block_size: 2048 
eos_token: !ref <codebook_size>
pad_token: !ref <codebook_size> + 1
vocabsize: !ref <codebook_size> + 2 # card(codebook) + eos + pad -> 1026

# aim for something like 500k tokens / BP step.
# 32 * 2048 * 8 ~= 500k tokens
batch_size: 32
grad_accumulation_factor: 8
test_batch_size: 1
num_workers: 0

### discrete SSL configuration

# Dataloader options
train_dataloader_opts:
  shuffle: True
  batch_size: !ref <batch_size>
  num_workers: !ref <num_workers>
  collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
      padding_kwargs:
          value: !ref <pad_token>
# todo: limit the num of steps on val
valid_dataloader_opts:
  batch_size: !ref <batch_size>
  num_workers: !ref <num_workers>
  collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
      padding_kwargs:
          value: !ref <pad_token>
test_dataloader_opts:
  batch_size: !ref <test_batch_size>
  # num_workers: !ref <num_workers>
  collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
      padding_kwargs:
          value: !ref <pad_token>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
   limit: !ref <number_of_epochs>

### Config for Tokenizer
tokens_loader: !new:utils.tokens.TokensLoader
   data_path: !ref <tokens_folder>

codebook_pattern: !new:speechbrain.lobes.models.huggingface_transformers.discrete_speechlm.InterleavedCodebookPattern
  audio_pad_token: !ref <pad_token>

# define LM in the YAML
config: !new:speechbrain.lobes.models.huggingface_transformers.discrete_speechlm.DiscreteSpeechLMConfig
  source: "HuggingFaceTB/SmolLM-135M"
  cache_dir: "/scratch/adelmou/hf_home/hub/"
  block_size: !ref <block_size>
  n_codebooks: !ref <num_codebooks>
  vocabsize: !ref <vocabsize>
  tie_embds: True

model: !new:speechbrain.lobes.models.huggingface_transformers.discrete_speechlm.DiscreteSpeechLM
  config: !ref <config>

lr_annealing: !new:speechbrain.nnet.schedulers.LinearScheduler
  initial_value: !ref <lr>
  final_value: 3e-5
  epoch_count: !ref <number_of_epochs>

opt_class: !name:torch.optim.AdamW
  lr: !ref <lr>
  betas: (0.9, 0.95)
   
modules:
  model: !ref <model>
"""
hparams = load_hyperpyyaml(hparams)

INFO:speechbrain.utils.seed:[rank: 0] Setting seed to 1986


vocabsize =  1026


In [21]:
valid_data = dataio_prepare(hparams)


prepare blocks
