## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
from functools import partial
from tqdm.auto import tqdm

import torch
from transformers import pipeline
from transformers.models.whisper import (WhisperTokenizer,
                                         WhisperTokenizerFast,
                                         WhisperFeatureExtractor,
                                         WhisperForConditionalGeneration)
from datasets import load_dataset
import evaluate

from dataloader.collator import DataCollatorSpeechSeq2SeqWithPadding
from dataloader.preprocessing_train.preprocessing import prepare_dataset_fct
from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP

if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():  # for Apple Silicon
    device = torch.device("mps")
else:
    device = "cpu"

## Load model

In [77]:
pretrained_model_name_or_path = "openai/whisper-tiny"

model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path).to(device)
feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(pretrained_model_name_or_path, language="english", task="transcribe")

model.generate = partial(model.generate, language="english", task="transcribe",
                         max_length=255, use_cache=True)

## Load LogitProcessor

In [78]:
from transformers.generation.logits_process import NoRepeatNGramLogitsProcessor

In [79]:
logit_processor = NoRepeatNGramLogitsProcessor(2)

## Load dataset

In [80]:
dataset_name = "ami_validation"

ds = EVAL_DATASET_NAME_TO_DATASET_GROUP[dataset_name]()[dataset_name]

if dataset_name != "librispeech_dummy":
    ds = ds.select([3680])



Found cached dataset ami (/Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5)


In [81]:
prepare_dataset = partial(prepare_dataset_fct, tokenizer=tokenizer, feature_extractor=feature_extractor)
ds = ds.map(lambda x: {"text": x["text"].lower()})
ds = ds.map(prepare_dataset, num_proc=4).with_format("pt")

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

num_proc must be <= 1. Reducing num_proc to 1 for dataset of size 1.


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

## Predict

In [82]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(tokenizer=tokenizer,
                                                     feature_extractor=feature_extractor,
                                                     replace_padded_with_loss_mask_for_labels=True,
                                                     discard_first_bos_token=True)

In [83]:
x = ds[0]
x

{'text': 'okay yeah yeah yeah yeah yeah that was really horrible',
 'audio': {'path': None,
  'array': tensor([-0.0080, -0.0015, -0.0049,  ...,  0.0082,  0.0081,  0.0004]),
  'sampling_rate': tensor(16000)},
 'input_features': tensor([[ 0.4937,  0.3521,  0.4399,  ..., -0.4473, -0.4473, -0.4473],
         [ 0.2797,  0.1397,  0.3037,  ..., -0.4473, -0.4473, -0.4473],
         [ 0.0123,  0.2742,  0.5504,  ..., -0.4473, -0.4473, -0.4473],
         ...,
         [-0.2306, -0.3587, -0.4473,  ..., -0.4473, -0.4473, -0.4473],
         [-0.3578, -0.4249, -0.4147,  ..., -0.4473, -0.4473, -0.4473],
         [-0.4052, -0.4152, -0.4409,  ..., -0.4473, -0.4473, -0.4473]]),
 'labels': tensor([50258, 50259, 50359, 50363, 26061,  1338,  1338,  1338,  1338,  1338,
           220,  6780,   390,   534,  9263, 50257])}

In [84]:
outputs = model.generate(x["input_features"][None, ...].to(device), output_scores=True, return_dict_in_generate=True)
tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

[' Okay, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah']

In [85]:
outputs.sequences[0, :20]

tensor([50258, 50259, 50359, 50363,  1033,    11,  1338,    11,  1338,    11,
         1338,    11,  1338,    11,  1338,    11,  1338,    11,  1338,    11],
       device='mps:0')

In [86]:
tokenizer.decode([1338, 11])

' yeah,'

In [87]:
inputs = data_collator([x]).to(device)

In [88]:
inputs

{'input_features': tensor([[[ 0.4937,  0.3521,  0.4399,  ..., -0.4473, -0.4473, -0.4473],
         [ 0.2797,  0.1397,  0.3037,  ..., -0.4473, -0.4473, -0.4473],
         [ 0.0123,  0.2742,  0.5504,  ..., -0.4473, -0.4473, -0.4473],
         ...,
         [-0.2306, -0.3587, -0.4473,  ..., -0.4473, -0.4473, -0.4473],
         [-0.3578, -0.4249, -0.4147,  ..., -0.4473, -0.4473, -0.4473],
         [-0.4052, -0.4152, -0.4409,  ..., -0.4473, -0.4473, -0.4473]]],
       device='mps:0'), 'labels': tensor([[50259, 50359, 50363, 26061,  1338,  1338,  1338,  1338,  1338,   220,
          6780,   390,   534,  9263, 50257]], device='mps:0')}

In [140]:
outputs = model.forward(**inputs)

In [141]:
outputs.logits

tensor([[[-0.7393,  0.1077,  3.5260,  ...,  1.8835,  2.4299,  3.7835],
         [-5.8741, -8.5107, -5.6013,  ..., -5.7708, -5.8595, -4.3501],
         [12.8666, 10.9578,  7.6026,  ...,  7.9113,  7.8484,  7.1011],
         ...,
         [ 5.2729,  5.3784,  2.5678,  ...,  1.8604,  0.6186, -2.2679],
         [ 6.4986,  5.0265,  3.4901,  ...,  1.9698,  0.2080, -2.4857],
         [26.8334, 23.0717, 20.2516,  ..., 16.7327, 16.7596, 13.8943]]],
       device='mps:0', grad_fn=<LinearBackward0>)

In [142]:
outputs.logits.shape

torch.Size([1, 15, 51865])

In [143]:
outputs.logits[0, :, [1338, 11]]

tensor([[ 1.2667,  0.1475],
        [-2.0751,  1.6884],
        [ 9.0733, 15.3917],
        [12.4683,  5.3053],
        [24.3392, 23.3065],
        [23.8486, 20.8278],
        [26.3616, 22.5004],
        [26.4366, 22.7905],
        [25.8284, 22.6519],
        [25.1259, 22.4605],
        [21.4763, 21.5335],
        [22.7957, 26.3061],
        [ 3.1814,  7.4455],
        [ 4.8735,  9.7000],
        [24.4710, 23.7138]], device='mps:0', grad_fn=<IndexBackward0>)

In [144]:
logit_processor = NoRepeatNGramLogitsProcessor(1)

In [145]:
batch_size, sequence_length, vocab_size = outputs.logits.shape

outputs_logits = outputs.logits.clone()

list_processed_logits = []
for idx in range(sequence_length):
    list_processed_logits.append(logit_processor(input_ids=inputs["labels"][:, :idx+1], scores=outputs_logits[:, idx, :]).reshape(batch_size, 1, vocab_size))

y = torch.cat(list_processed_logits, dim=1)
y

tensor([[[-0.7393,  0.1077,  3.5260,  ...,  1.8835,  2.4299,  3.7835],
         [-5.8741, -8.5107, -5.6013,  ..., -5.7708, -5.8595, -4.3501],
         [12.8666, 10.9578,  7.6026,  ...,  7.9113,  7.8484,  7.1011],
         ...,
         [ 5.2729,  5.3784,  2.5678,  ...,  1.8604,  0.6186, -2.2679],
         [ 6.4986,  5.0265,  3.4901,  ...,  1.9698,  0.2080, -2.4857],
         [26.8334, 23.0717, 20.2516,  ..., 16.7327, 16.7596, 13.8943]]],
       device='mps:0', grad_fn=<CatBackward0>)

In [146]:
outputs.logits.shape

torch.Size([1, 15, 51865])

In [147]:
outputs.logits[0, :, [1338, 11]]

tensor([[ 1.2667,  0.1475],
        [-2.0751,  1.6884],
        [ 9.0733, 15.3917],
        [12.4683,  5.3053],
        [24.3392, 23.3065],
        [23.8486, 20.8278],
        [26.3616, 22.5004],
        [26.4366, 22.7905],
        [25.8284, 22.6519],
        [25.1259, 22.4605],
        [21.4763, 21.5335],
        [22.7957, 26.3061],
        [ 3.1814,  7.4455],
        [ 4.8735,  9.7000],
        [24.4710, 23.7138]], device='mps:0', grad_fn=<IndexBackward0>)

In [148]:
y[:, [1338, 11]]

tensor([[[-3.3757, -2.9258, -0.3346,  ..., -4.5569, -2.7298, -0.9344],
         [22.2727, 27.1006, 19.2470,  ..., 19.7796, 19.2167, 15.9238]]],
       device='mps:0', grad_fn=<IndexBackward0>)

In [149]:
(outputs.logits == y).all()

tensor(False, device='mps:0')

In [150]:
batch_size, sequence_length, vocab_size = outputs.logits.shape

outputs_logits = outputs.logits.clone()

for idx in range(sequence_length):
    logit_processor(input_ids=inputs["labels"][:, :idx+1], scores=outputs_logits[:, idx, :])

In [152]:
outputs_logits[0, :, [1338, 11]]

tensor([[ 1.2667,  0.1475],
        [-2.0751,  1.6884],
        [ 9.0733, 15.3917],
        [12.4683,  5.3053],
        [   -inf, 23.3065],
        [   -inf, 20.8278],
        [   -inf, 22.5004],
        [   -inf, 22.7905],
        [   -inf, 22.6519],
        [   -inf, 22.4605],
        [   -inf, 21.5335],
        [   -inf, 26.3061],
        [   -inf,  7.4455],
        [   -inf,  9.7000],
        [   -inf, 23.7138]], device='mps:0', grad_fn=<IndexBackward0>)