## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
from functools import partial

import torch
from torch.utils.data import DataLoader

from transformers import WhisperTokenizerFast, WhisperFeatureExtractor, WhisperForConditionalGeneration
from datasets import load_dataset

from dataloader.collator import DataCollatorSpeechSeq2SeqWithPadding
from dataloader.preprocessing_train.preprocessing import prepare_dataset_fct, preprocess_dataset
from dataloader.collator import DataCollatorSpeechSeq2SeqWithPadding
from trainer.prompting import get_labels_with_prompt

## User input

## Load model

In [4]:
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-tiny", language="english", task="transcribe")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

model.config.forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language="english", task="transcribe")

normalizer = tokenizer._normalize

## Load dataset

In [5]:
# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


## Preprocess

In [6]:
prepare_dataset = partial(prepare_dataset_fct,
                          tokenizer=tokenizer,
                          feature_extractor=feature_extractor)
ds = ds.map(prepare_dataset, num_proc=4)

Loading cached processed dataset at /Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b/cache-03b95896c8a778b5_*_of_00004.arrow


## Create dataloader

In [7]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(tokenizer=tokenizer,
                                                     feature_extractor=feature_extractor,
                                                     return_attention_mask=True,
                                                     replace_padded_with_loss_mask_for_labels=False,
                                                     discard_first_bos_token=True)

dataloader = DataLoader(ds,
                        batch_size=2,
                        shuffle=False,
                        collate_fn=data_collator)

## EWC

In [8]:
x = next(iter(dataloader))

x.keys()

You're using a WhisperTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


dict_keys(['input_features', 'labels', 'attention_mask'])

In [9]:
x.labels.shape

torch.Size([2, 39])

In [10]:
x.labels

tensor([[50258, 50259, 50359, 50363,    44,  2343,  5568,  7246,  4620,  5568,
          6205,  5663,  5372, 28067,  2634, 11944,  5663, 32394,    35,  2634,
         12855, 19678,  2358,  8093, 15813, 22515, 16225,  6112,  8232,   343,
          3158,    34, 23344, 45470,   460,  4367,    47,  3158, 50257],
        [50258, 50259, 50359, 50363,    45,  2483,  6205,   376,  2343,  5568,
          7246,  4620,  5568,     6,    50, 15372, 24499,   441, 12268, 30219,
         14497,  3017,  3578,  1770, 45470,  5904,  5568, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257]])

In [11]:
outputs = model(**x)

logits = outputs.logits

In [12]:
outputs.keys()

odict_keys(['loss', 'logits', 'past_key_values', 'encoder_last_hidden_state'])

In [13]:
log_prob_all = torch.nn.functional.log_softmax(logits, dim=-1)
log_prob_all

tensor([[[-24.9806, -26.2099, -20.7334,  ..., -22.3784, -20.9287, -21.0396],
         [-31.5374, -35.2910, -30.8864,  ..., -32.0047, -32.2249, -29.0637],
         [-28.3196, -31.2453, -27.7084,  ..., -28.6791, -28.9073, -26.4736],
         ...,
         [-11.4710, -10.5310, -11.0410,  ..., -13.1237, -13.4056, -13.2360],
         [ -7.2245,  -7.4540, -12.2601,  ..., -11.0964, -11.6803, -12.4698],
         [ -2.3228,  -6.5579, -10.2790,  ..., -11.7745, -11.6631, -13.3594]],

        [[-22.5718, -23.0225, -18.2253,  ..., -20.4687, -18.9090, -19.0197],
         [-26.7977, -31.3636, -29.3994,  ..., -28.4255, -28.9399, -26.4021],
         [-29.9452, -32.8312, -29.4124,  ..., -30.4019, -30.6734, -28.2721],
         ...,
         [ -6.9904,  -4.5795,  -7.5210,  ..., -10.2067, -10.2836, -13.1559],
         [ -7.1982,  -4.9151,  -7.4744,  ...,  -9.8854, -10.0806, -13.3250],
         [ -6.9058,  -4.6428,  -7.4882,  ..., -10.0364, -10.2691, -13.5632]]],
       grad_fn=<LogSoftmaxBackward0>)

In [14]:
log_prob_all.shape

torch.Size([2, 39, 51865])

In [15]:
log_prob = log_prob_all.take_along_dim(x.labels[..., None], dim=-1).squeeze().sum(dim=-1)
log_prob.shape

torch.Size([2])

In [16]:
log_likelihood = torch.mean(log_prob)
log_likelihood

tensor(-114.8479, grad_fn=<MeanBackward0>)

In [17]:
grad_log_likelihood = torch.autograd.grad(log_likelihood, model.parameters())
len(grad_log_likelihood)

167

In [18]:
count = 0
for (name, param), grad_param in zip(model.named_parameters(), grad_log_likelihood):
    if count > 5:
        break
    print(name, param.shape, grad_param.shape)
    count += 1

model.encoder.conv1.weight torch.Size([384, 80, 3]) torch.Size([384, 80, 3])
model.encoder.conv1.bias torch.Size([384]) torch.Size([384])
model.encoder.conv2.weight torch.Size([384, 384, 3]) torch.Size([384, 384, 3])
model.encoder.conv2.bias torch.Size([384]) torch.Size([384])
model.encoder.embed_positions.weight torch.Size([1500, 384]) torch.Size([1500, 384])
model.encoder.layers.0.self_attn.k_proj.weight torch.Size([384, 384]) torch.Size([384, 384])
