## Imports

In [1]:
# Automatically reload edited project modules before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

# Move from the notebook's directory up to the repo root, but only once: a bare `%cd ..` climbs one more level
# every time the cell is re-run. `configs/` is resolved relative to the repo root by the config-loading cell.
if not os.path.isdir("configs"):
    os.chdir("..")

# Make the repo-root packages (`utils`, `dataloader`, ...) importable. The previous
# `dirname(dirname(abspath(getcwd())))` pointed two levels above the repo root.
REPO_ROOT = os.getcwd()
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

REPO_ROOT

/home/tw581/mlmi_dissertation/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
from utils.initialize import initialize_env, print_envs
# Project-specific environment setup (defined in the project's `utils.initialize` module).
# NOTE(review): `print_envs` is imported but not used in this section of the notebook.
initialize_env()

In [23]:
from typing import Dict, Any

import torch

# Fail fast when no CUDA device is available.
assert torch.cuda.is_available(), "This script requires a GPU."
# Target device for tensors created later in the notebook.
# NOTE(review): the model-loading cell does not move `model` to this device.
device = torch.device("cuda:0")

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from transformers.generation.utils import BeamSearchEncoderDecoderOutput
from datasets import load_dataset

from dataloader.smart_load_dataset_dict import smart_load_dataset_dict
from utils.distil_config import DistilConfig
from utils.constants import GEN_MAX_LENGTH


In [5]:
# Debug config for sequence-level K-best distillation (teacher: whisper-base, student: whisper-tiny, k=3).
config = DistilConfig.from_yaml(
    "configs/distill_configs/seq_level_k_best_uniform/"
    "distil_base_to_tiny-seq_level_k_best_uniform-k_3-debug.yaml"
)
config

DistilConfig(experiment_name='distil_whisper_base_to_tiny-seq_level_k_best_uniform-k_3-debug', lang_name='english', task='transcribe', method='seq_level_k_best_uniform', teacher_model_name_or_path='openai/whisper-base', student_model_name_or_path='openai/whisper-tiny', is_tokenizer_multilingual=True, model_dir='./checkpoints/distillation/whisper_base_to_tiny/librispeech_debug/seq_level_k_best_uniform/k_3/', freeze_encoder=True, freeze_decoder=False, batch_size=16, gradient_accumulation_steps=4, eval_accumulation_steps=None, gradient_checkpointing=True, data_augmentation=False, dataset_name='librispeech_dummy', force_reprocess_dataset=False, optim='adamw_torch', learning_rate=1e-05, warmup_steps=5, eval_steps=5, generation_num_beams=3, save_steps=100, save_total_limit=2, logging_steps=5, num_train_epochs=10, early_stopping_patience=-1, ce_alpha=0.5, temperature=None, distillation_num_beams=3, decay_beta=1.0, smart_load=True, force_reprocess_k_best=False, log_preds_to_wandb=True, n_sampl

In [21]:
# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.config.forced_decoder_ids = None

normalizer = processor.tokenizer._normalize

In [6]:
# Load processor (contains both tokenizer and feature extractor)
processor = WhisperProcessor.from_pretrained(
    config.teacher_model_name_or_path,
    language=config.lang_name,
    task=config.task
)

In [7]:
# Load the train/validation/test splits; a previously preprocessed copy on disk is reused when found (see log below).
dataset_dict = smart_load_dataset_dict(config=config, processor=processor)

Previously preprocessed dataset found at `/home/tw581/rds/hpc-work/preprocessed_datasets/librispeech_dummy/multilingual_tokenizer`. Loading from disk...


In [8]:
# Inspect the available splits, their columns and number of rows.
dataset_dict

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id', 'input_features', 'labels'],
        num_rows: 72
    })
    validation: Dataset({
        features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id', 'input_features', 'labels'],
        num_rows: 73
    })
    test: Dataset({
        features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id', 'input_features', 'labels'],
        num_rows: 73
    })
})

In [52]:
num_beams = 3  # Should match `config.distillation_num_beams` (3 in the debug config).


def get_k_beam_features(batch: Dict[str, Any]) -> Dict[str, Any]:
    """
    Add K-beam search features to a batch, for use with `Dataset.map(..., batched=True)`.

    Runs `model.generate` with `num_beams` beams and keeps all `num_beams` returned hypotheses per sample.

    Args:
        batch: Batch dict with an `input_features` entry, expected to be a stacked tensor
            (the dataset is put in "pt" format before mapping).

    Returns:
        The same dict, with two new entries holding one element per sample:
        - `sequences`: tensor of shape (num_beams, n_tokens)
        - `sequences_scores`: tensor of shape (num_beams,)
    """
    # `Dataset.map` does not move data to the model's device (unlike `Trainer`), so do it here:
    input_features = batch["input_features"].to(model.device)  # type: ignore

    # Generate teacher predictions using K-beam search:
    outputs = model.generate(input_features,  # type: ignore
                             max_length=GEN_MAX_LENGTH,
                             num_beams=num_beams,
                             num_return_sequences=num_beams,
                             output_scores=True,
                             return_dict_in_generate=True)

    # outputs.sequences -> (batch_size * num_beams, n_tokens), the hypotheses of one sample are contiguous
    # outputs.sequences_scores -> (batch_size * num_beams,)
    # Split into one chunk per sample, returned as CPU tensors so the stored features don't depend on the model device.
    batch["sequences"] = list(torch.split(outputs.sequences.cpu(),
                                          split_size_or_sections=num_beams,
                                          dim=0))
    batch["sequences_scores"] = list(torch.split(outputs.sequences_scores.cpu(),
                                                 split_size_or_sections=num_beams,
                                                 dim=0))

    return batch

In [28]:
# Only the test split is used in the rest of this section; expose its columns as PyTorch tensors.
ds_test = dataset_dict["test"].with_format("pt")

In [48]:
# Schema of the test split before adding the K-beam columns.
ds_test.features

{'file': Value(dtype='string', id=None),
 'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),
 'text': Value(dtype='string', id=None),
 'speaker_id': Value(dtype='int64', id=None),
 'chapter_id': Value(dtype='int64', id=None),
 'id': Value(dtype='string', id=None),
 'input_features': Sequence(feature=Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None), length=-1, id=None),
 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}

In [53]:
# Add the K-beam `sequences` / `sequences_scores` columns, generating 2 samples at a time.
ds_test_ = ds_test.map(
    get_k_beam_features,
    batched=True,
    batch_size=2,
)

Map:   0%|          | 0/73 [00:00<?, ? examples/s]

In [55]:
# Schema after `map`: the two new columns `sequences` and `sequences_scores` are appended.
ds_test_.features

{'file': Value(dtype='string', id=None),
 'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),
 'text': Value(dtype='string', id=None),
 'speaker_id': Value(dtype='int64', id=None),
 'chapter_id': Value(dtype='int64', id=None),
 'id': Value(dtype='string', id=None),
 'input_features': Sequence(feature=Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None), length=-1, id=None),
 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'sequences': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'sequences_scores': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None)}

In [56]:
# The `num_beams` hypotheses of the first test sample, shape (num_beams, n_tokens).
# Note: `ds_test_["sequences"]` materializes the whole column before taking element 0.
ds_test_["sequences"][0]

tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
           286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
          7062,   465, 21443,    13, 50256],
        [50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
           286,   262,  3504,  6097,   290,   356,   389,  9675,   284,  7062,
           465, 21443,    13, 50256, 50256],
        [50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
           286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
          7062,   465, 23244,    13, 50256]])

In [43]:
# Scores of the first test sample's hypotheses (one per beam).
# NOTE(review): this output was produced at execution count 43, before the `map` cell (count 53) that defines the
# current `ds_test_` -- re-run to refresh.
ds_test_["sequences_scores"][0]

tensor([-0.1096, -0.1257, -0.1332])

In [None]:
def get_batched_k_beam_search_output_from_inputs(inputs,
                                                 col_id: str,
                                                 distillation_num_beams: int,
                                                 id_to_k_beam_search_output: Dict[str, "BeamSearchEncoderDecoderOutput"],
                                                 pad_token_id: int = 0) -> "BeamSearchEncoderDecoderOutput":
    """
    Assemble the K-beam search output of a batch from pre-computed per-sample outputs.

    For each row of `inputs`, the row's `col_id` value is used as the key into `id_to_k_beam_search_output`, and the
    first `distillation_num_beams` (K) stored hypotheses are kept. The result follows the batched layout of
    `generate(..., num_beams=K, num_return_sequences=K, return_dict_in_generate=True)`: the hypotheses of sample `i`
    occupy rows `i * K` to `(i + 1) * K - 1`, and all sequences are right-padded to the longest kept sequence.

    Args:
        inputs: Iterable batch of rows, e.g. a `datasets.Dataset` or a list of dicts; each row must contain `col_id`.
        col_id: Name of the column holding the sample ids used as keys in `id_to_k_beam_search_output`.
        distillation_num_beams: Number K of hypotheses to keep per sample; must not exceed the stored beam size.
        id_to_k_beam_search_output: Mapping from sample id to its pre-computed output, whose `sequences` has
            shape (beam_search_size, n_tokens) and `sequences_scores` has shape (beam_search_size,).
        pad_token_id: Token id used to right-pad shorter sequences. Defaults to 0 (the previous behavior);
            pass the model's pad token id to reproduce the padding of `generate`.

    Returns:
        `BeamSearchEncoderDecoderOutput` with `sequences` of shape (batch_size * K, max_n_tokens) and
        `sequences_scores` of shape (batch_size * K,), both allocated on the global `device`.

    Raises:
        AssertionError: if `inputs` exposes a schema without `col_id`, or if K exceeds a sample's stored beam size.
    """
    # Sanity check on the column name (only possible when `inputs` exposes its schema, e.g. `datasets.Dataset`):
    features = getattr(inputs, "features", None)
    if features is not None:
        assert col_id in features, f"Column `{col_id}` not found in inputs."

    # Look up the pre-computed K-beam search output of every sample in the batch:
    k_beam_search_outputs = [id_to_k_beam_search_output[row[col_id]] for row in inputs]
    batch_size = len(k_beam_search_outputs)

    # Keep the top-K hypotheses (and their scores) of each sample:
    kept_sequences = []
    kept_scores = []
    for k_beam_search_output in k_beam_search_outputs:
        sample_sequences = torch.as_tensor(k_beam_search_output.sequences)  # (beam_search_size, n_tokens)
        sample_scores = torch.as_tensor(k_beam_search_output.sequences_scores)  # (beam_search_size,)
        beam_search_size = sample_sequences.shape[0]
        assert distillation_num_beams <= beam_search_size, \
            f"Invalid `distillation_num_beams` value `{distillation_num_beams}`. Must be <= `{beam_search_size}`."
        kept_sequences.append(sample_sequences[:distillation_num_beams])
        kept_scores.append(sample_scores[:distillation_num_beams])

    # The whole batch is padded to its longest kept sequence:
    n_tokens = max((seq.shape[1] for seq in kept_sequences), default=0)

    # Initialize the output tensors:
    sequences = torch.full((batch_size * distillation_num_beams, n_tokens), pad_token_id,
                           dtype=torch.long, device=device)  # (batch_size * distillation_num_beams, n_tokens)
    sequences_scores = torch.zeros((batch_size * distillation_num_beams,),
                                   dtype=torch.float, device=device)  # (batch_size * distillation_num_beams,)

    # Write each sample's K hypotheses into its own contiguous, non-overlapping slice of rows:
    for idx, (sample_sequences, sample_scores) in enumerate(zip(kept_sequences, kept_scores)):
        start = idx * distillation_num_beams
        end = start + distillation_num_beams
        sequences[start:end, :sample_sequences.shape[1]] = sample_sequences
        sequences_scores[start:end] = sample_scores  # type: ignore

    return BeamSearchEncoderDecoderOutput(sequences=sequences, sequences_scores=sequences_scores)  # type: ignore