## Imports

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [26]:
from typing import Dict, Any
from functools import partial
from pprint import pprint

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

from dataloader.smart_load_dataset_dict import smart_load_dataset_dict
from utils.distil_config import DistilConfig
from utils.constants import GEN_MAX_LENGTH, DEFAULT_LABEL_STR_COL, DEFAULT_LABEL_TOKENIZED_COL

In [29]:
# Select the compute device used for the model and its inputs.
# Apple-silicon MPS is preferred (the original setting); otherwise fall back to CUDA and
# finally to the CPU, so the notebook also runs on machines that have no MPS backend.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

## Preliminary

In [6]:
# Parse the distillation settings from the debug YAML file.
# The path is relative, so it is resolved against the kernel's current working directory.
config = DistilConfig.from_yaml(
    "configs/distil_configs/debug/distil_base_to_tiny-seq_level_k_best_uniform-k_3-debug.yaml"
)

# Pretty-print the parsed configuration for inspection.
pprint(config)

DistilConfig(experiment_name='distil_whisper_base_to_tiny-seq_level_k_best_uniform-k_3-debug',
             lang_name='english',
             task='transcribe',
             method_distil='seq_level_k_best_uniform',
             teacher_model_name_or_path='openai/whisper-base',
             student_model_name_or_path='openai/whisper-tiny',
             is_tokenizer_multilingual=True,
             model_dir='./checkpoints/distillation/whisper_base_to_tiny/librispeech_debug/seq_level_k_best_uniform/k_3/',
             freeze_encoder=True,
             freeze_decoder=False,
             batch_size=32,
             gradient_accumulation_steps=1,
             gradient_checkpointing=True,
             dataset_name='librispeech_dummy',
             optim='adamw_torch',
             learning_rate=1e-05,
             warmup_steps=5,
             eval_steps=10,
             generation_num_beams=1,
             save_steps=10000,
             logging_steps=10,
             num_train_epochs=10,
   

In [34]:
# Processor = tokenizer + feature extractor, configured for the language and task
# given in the distillation config.
processor = WhisperProcessor.from_pretrained(config.student_model_name_or_path,
                                             language=config.lang_name,
                                             task=config.task)

# Keep a handle on the tokenizer's text normalizer.
# NOTE(review): `_normalize` is a private attribute of the tokenizer -- it may change
# between library versions.
normalizer = processor.tokenizer._normalize

# Load the checkpoint named by `config.student_model_name_or_path`, then move it to `device`.
model = WhisperForConditionalGeneration.from_pretrained(config.student_model_name_or_path)
model = model.to(device)

# Force the decoder prompt to the configured language and task
# (the original author's label for this step: "disable zero-shot").
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language=config.lang_name,
                                                                   task=config.task)

In [35]:
# Fetch the "validation" split of the "clean" configuration of the dummy LibriSpeech
# dataset from the Hugging Face hub (or from the local cache when present).
ds = load_dataset(
    path="hf-internal-testing/librispeech_asr_dummy",
    name="clean",
    split="validation",
)

Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [38]:
# K: beam width used for generation, and also the number of hypotheses returned per example.
num_beams = 2


def get_k_beam_features(batch: Dict[str, Any], processor) -> Dict[str, Any]:
    """
    Utility to create K-Beam features for a dataset.

    Despite the parameter name, `batch` is handled as a SINGLE example: its "audio" entry
    is passed to the feature extractor as one array. Use it with a non-batched
    `Dataset.map` call.

    Besides its arguments, the function reads the notebook-level globals `model`,
    `device` and `num_beams`, which must be defined before it is called.

    Args:
        batch: Example dict. Must contain "audio" (with an "array" entry) and the
            reference transcript under `DEFAULT_LABEL_STR_COL`.
        processor: Object exposing `feature_extractor` and `tokenizer`
            (a `WhisperProcessor` in this notebook).

    Returns:
        The same dict, updated in place with:
            - "input_features": output of the feature extractor for the audio,
            - `DEFAULT_LABEL_TOKENIZED_COL`: token ids of the reference transcript,
            - "sequences": list of chunks of `outputs.sequences`, `num_beams` rows each,
            - "sequences_scores": list of chunks of `outputs.sequences_scores`,
              `num_beams` entries each.
    """

    # ===== Code from prepare_dataset_fct ====

    audio = batch["audio"]

    # Extract features from audio (including log-Mel input features):
    # Note: the sampling rate arg is redundant but required to dismiss warnings.
    batch["input_features"] = processor.feature_extractor(
        audio["array"],
        sampling_rate=processor.feature_extractor.sampling_rate,
        return_tensors="pt",
    ).input_features

    # Encode from target text to label ids:
    batch[DEFAULT_LABEL_TOKENIZED_COL] = processor.tokenizer(batch[DEFAULT_LABEL_STR_COL]).input_ids  # type: ignore

    # =========================================

    # Generate the K best hypotheses with beam search. The inputs have to be moved to the
    # model's device manually here (which is not the case with Trainer).
    # NOTE(review): the original comment called these "teacher predictions", but the global
    # `model` in this notebook is loaded from `config.student_model_name_or_path` -- confirm
    # which checkpoint is intended.
    outputs = model.generate(batch["input_features"].to(device),  # type: ignore
                             max_length=GEN_MAX_LENGTH,
                             num_beams=num_beams,
                             num_return_sequences=num_beams,
                             output_scores=True,
                             return_dict_in_generate=True)

    # outputs.sequences -> (batch_size * num_beams, n_tokens)
    # outputs.sequences_scores -> (batch_size * num_beams,)
    # Regroup the flat beam dimension into one chunk of `num_beams` hypotheses per input.
    batch["sequences"] = list(torch.split(outputs.sequences,
                                          split_size_or_sections=num_beams,
                                          dim=0))
    batch["sequences_scores"] = list(torch.split(outputs.sequences_scores,
                                                 split_size_or_sections=num_beams,
                                                 dim=0))

    return batch

In [39]:
# Apply the K-beam feature extraction to every example (default non-batched map:
# one example per call). No extra positional argument is passed: the mapped function
# calls the in-kernel `model`, so the map runs in the main process.
ds = ds.map(partial(get_k_beam_features, processor=processor))

Map:   0%|          | 0/73 [00:00<?, ? examples/s]

  input_ids = input_ids.repeat_interleave(expand_size, dim=0)
  sent_lengths_max = sent_lengths.max().item() + 1


In [40]:
# Show the dataset schema, including the columns added by the map step.
ds.features

{'file': Value(dtype='string', id=None),
 'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),
 'text': Value(dtype='string', id=None),
 'speaker_id': Value(dtype='int64', id=None),
 'chapter_id': Value(dtype='int64', id=None),
 'id': Value(dtype='string', id=None),
 'input_features': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'sequences': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),
 'sequences_scores': Sequence(feature=Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None), length=-1, id=None)}

In [45]:
# Grab the first example of the dataset for manual inspection.
x = next(iter(ds))

In [47]:
# List the columns available on the example.
x.keys()

dict_keys(['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id', 'input_features', 'labels', 'sequences', 'sequences_scores'])

In [60]:
# Take the first (and, for a single example, only) group of beam hypotheses and turn the
# nested list of token ids back into an int64 tensor with one row per hypothesis.
tokenized_seq = torch.tensor(x["sequences"][0], dtype=torch.long)
tokenized_seq

tensor([[50258, 50259, 50359, 50363,  2221,    13,  2326,   388,   391,   307,
           264, 50244,   295,   264,  2808,  5359,   293,   321,   366,  5404,
           281,  2928,   702, 14943,    13, 50257, 50257],
        [50258, 50259, 50359, 50363,  2221,    13,  2326,   388,   391,   307,
           264, 50244,   295,   264,  2808,  5359,    11,   293,   321,   366,
          5404,   281,  2928,   702, 14943,    13, 50257]])

In [62]:
# Decode each hypothesis back to text; special tokens are kept in the output.
processor.tokenizer.batch_decode(tokenized_seq)

['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|><|endoftext|>',
 '<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.<|endoftext|>']

In [70]:
# Tokenize the raw reference transcript of the example, for comparison with the
# generated hypotheses above.
processor.tokenizer(x["text"])

{'input_ids': [50258, 50259, 50359, 50363, 44, 2343, 5568, 7246, 4620, 5568, 6205, 5663, 5372, 28067, 2634, 11944, 5663, 32394, 35, 2634, 12855, 19678, 2358, 8093, 15813, 22515, 16225, 6112, 8232, 343, 3158, 34, 23344, 45470, 460, 4367, 47, 3158, 50257], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}