## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

## Load model and dataset

In [4]:
# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.config.forced_decoder_ids = None

normalizer = processor.tokenizer._normalize

In [5]:
# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


## Create a 2-batch

In [6]:
sample_0 = ds[0]["audio"]
label_0 = normalizer(ds[0]["text"])  # normalize label
input_features_0 = processor(sample_0["array"], sampling_rate=sample_0["sampling_rate"], return_tensors="pt").input_features 

In [7]:
sample_1 = ds[1]["audio"]
label_1 = normalizer(ds[1]["text"])  # normalize label
input_features_1 = processor(sample_1["array"], sampling_rate=sample_1["sampling_rate"], return_tensors="pt").input_features 

In [8]:
label_0, label_1

('mister quilter is the apostle of the middle classes and we are glad to welcome his gospel',
 'nor is mister quilter is manner less interesting than his matter')

### Batch `input_features`

In [9]:
assert input_features_0.shape == input_features_1.shape

In [10]:
input_features = torch.concat([input_features_0, input_features_1], axis=0)
input_features.shape

torch.Size([2, 80, 3000])

### Batch `tokenized_label`

In [11]:
# Need to use `torch.LongTensor` because default dtype is float:
tokenized_label_0 = torch.LongTensor(processor.tokenizer(label_0, add_special_tokens=True).input_ids)
tokenized_label_1 = torch.LongTensor(processor.tokenizer(label_1, add_special_tokens=True).input_ids)

In [12]:
tokenized_label_0.shape, tokenized_label_1.shape

(torch.Size([23]), torch.Size([17]))

⚠️ The `input_features` share the same shape but this is not true for the tokenized sequences!

In [13]:
label_features = [{"input_ids": tokenized_label_0}, {"input_ids": tokenized_label_1}]
labels_batch = processor.tokenizer.pad(label_features, return_tensors="pt")  # type: ignore

In [14]:
labels_batch

{'input_ids': tensor([[50257, 50362,    76,  1694,   627,   346,   353,   318,   262, 46329,
           286,   262,  3504,  6097,   290,   356,   389,  9675,   284,  7062,
           465, 21443, 50256],
        [50257, 50362, 13099,   318,   285,  1694,   627,   346,   353,   318,
          5642,  1342,  3499,   621,   465,  2300, 50256, 50256, 50256, 50256,
         50256, 50256, 50256]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])}

In [15]:
labels_batch["input_ids"].shape, labels_batch["attention_mask"].shape

(torch.Size([2, 23]), torch.Size([2, 23]))

## Forward

### Prepare inputs

In [16]:
input_features.shape, input_features.dtype

(torch.Size([2, 80, 3000]), torch.float32)

In [17]:
# Assume the teacher is a perfect model:
teacher_sequences = labels_batch["input_ids"]
teacher_sequences.shape, teacher_sequences.dtype

(torch.Size([2, 23]), torch.int64)

In [18]:
labels_batch["attention_mask"].shape

torch.Size([2, 23])

### Predict without attention mask

In [19]:
output_no_mask = model.forward(input_features=input_features,
                               decoder_input_ids=teacher_sequences[:, :-1],
                               labels=teacher_sequences[:, 1:])
output_no_mask.loss

tensor(1.1817, grad_fn=<NllLossBackward0>)

### Predict with attention mask

In [20]:
output_with_mask = model.forward(input_features=input_features,
                                 decoder_input_ids=teacher_sequences[:, :-1],
                                 labels=teacher_sequences[:, 1:],
                                 decoder_attention_mask=labels_batch["attention_mask"][:, :-1])
output_with_mask.loss

tensor(1.1591, grad_fn=<NllLossBackward0>)

### Output comparison

Note how the 2 losses are different. We can then deduce that the logits are also different but let's display a few values for the 2nd example to confirm our assumption (the 1st example has no padded tokens because it is the longest of the two).

In [21]:
output_no_mask.logits[1]

tensor([[ 4.3879,  4.8095,  3.3221,  ...,  4.8180,  4.0716,  3.8354],
        [ 3.6629,  2.7278,  0.0581,  ...,  2.5407,  2.9286,  1.1074],
        [13.8955, 14.5302,  8.8014,  ..., 10.1489,  9.2304,  8.5802],
        ...,
        [10.9755, 10.8193,  8.5395,  ...,  6.0354,  5.9234,  3.9300],
        [10.9274, 11.0745,  9.0230,  ...,  6.5457,  6.4562,  4.4609],
        [10.0324, 10.8805,  8.8594,  ...,  6.5510,  6.4207,  4.4512]],
       grad_fn=<SelectBackward0>)

In [22]:
output_with_mask.logits[1]

tensor([[ 4.3879,  4.8095,  3.3221,  ...,  4.8180,  4.0716,  3.8354],
        [ 3.6629,  2.7278,  0.0581,  ...,  2.5407,  2.9286,  1.1074],
        [13.8955, 14.5302,  8.8014,  ..., 10.1489,  9.2304,  8.5802],
        ...,
        [19.6800, 18.7776, 13.9299,  ..., 11.1630, 11.0890,  8.8313],
        [19.2387, 18.4690, 13.6283,  ..., 11.0549, 10.9934,  8.7443],
        [18.5419, 17.7504, 13.1372,  ..., 10.5019, 10.4164,  8.1769]],
       grad_fn=<SelectBackward0>)

**Comments:**
- The first first rows ARE equal because our model uses a causal attention mechanism. Therefore, the attention doesn't have the chance to consider the pad tokens at the end of the sequence. Hence the similarity.
- The last rows differ for the same reason. Although we could just set them to 0 (which is something we will do eventually), it is good practice to mask them properly. Moreover, we can save some computation time here.

### Apply softmax to get the vocab probabilities per step and per example

In [23]:
output_log_prob = torch.nn.functional.log_softmax(output_with_mask.logits, dim=-1)  # (batch_size, n_tokens-1, vocab_size)
output_log_prob

tensor([[[-15.3003, -14.7483, -15.9119,  ..., -14.5757, -15.3342, -15.6124],
         [-11.7886, -12.7940, -14.2199,  ..., -15.2796, -16.0519, -18.6459],
         [-10.8017, -10.6024, -10.8728,  ..., -10.2841, -11.3689, -13.7109],
         ...,
         [-13.6878, -11.9338, -15.5179,  ..., -17.6529, -16.9946, -18.3999],
         [-12.3648, -13.3746, -15.7683,  ..., -17.4096, -18.9184, -19.5019],
         [ -3.2312,  -6.1660, -13.3379,  ..., -16.0522, -15.7055, -18.1209]],

        [[-14.8043, -14.3827, -15.8701,  ..., -14.3742, -15.1206, -15.3568],
         [-11.9963, -12.9314, -15.6011,  ..., -13.1185, -12.7306, -14.5518],
         [ -8.8980,  -8.2633, -13.9921,  ..., -12.6446, -13.5631, -14.2132],
         ...,
         [ -5.7692,  -6.6717, -11.5194,  ..., -14.2862, -14.3603, -16.6180],
         [ -6.1038,  -6.8735, -11.7142,  ..., -14.2876, -14.3490, -16.5981],
         [ -6.2201,  -7.0116, -11.6248,  ..., -14.2601, -14.3456, -16.5851]]],
       grad_fn=<LogSoftmaxBackward0>)

### Set the values associated to the pad tokens to 0

Because we will sum the log-probabilities to make use of the product rule in the log space, a sufficient method to ignore the padded values is to set them to 0.

In [24]:
n_vocab = output_with_mask.logits.shape[-1]
n_vocab

51864

In [25]:
# Repeat attention_mask for the n_vocab dimension:
mask = labels_batch["attention_mask"][:, :-1, None].expand(-1, -1, n_vocab)
mask.shape

torch.Size([2, 22, 51864])

In [26]:
output_log_prob_masked = output_log_prob.masked_fill(mask.ne(1), 0)
output_log_prob_masked

tensor([[[-15.3003, -14.7483, -15.9119,  ..., -14.5757, -15.3342, -15.6124],
         [-11.7886, -12.7940, -14.2199,  ..., -15.2796, -16.0519, -18.6459],
         [-10.8017, -10.6024, -10.8728,  ..., -10.2841, -11.3689, -13.7109],
         ...,
         [-13.6878, -11.9338, -15.5179,  ..., -17.6529, -16.9946, -18.3999],
         [-12.3648, -13.3746, -15.7683,  ..., -17.4096, -18.9184, -19.5019],
         [ -3.2312,  -6.1660, -13.3379,  ..., -16.0522, -15.7055, -18.1209]],

        [[-14.8043, -14.3827, -15.8701,  ..., -14.3742, -15.1206, -15.3568],
         [-11.9963, -12.9314, -15.6011,  ..., -13.1185, -12.7306, -14.5518],
         [ -8.8980,  -8.2633, -13.9921,  ..., -12.6446, -13.5631, -14.2132],
         ...,
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]],
       grad_fn=<MaskedFillBackward0>)

In [27]:
# Row 1 should be non-null and row 2 should be full of 0s:
output_log_prob_masked[:, -1, :]

tensor([[ -3.2312,  -6.1660, -13.3379,  ..., -16.0522, -15.7055, -18.1209],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],
       grad_fn=<SliceBackward0>)

## Compute the log-probability of each sentence

In [28]:
log_prob_t_hat_step_wise = output_log_prob_masked.take_along_dim(teacher_sequences[:, 1:, None], dim=-1)
log_prob_t_hat_step_wise.shape

torch.Size([2, 22, 1])

In [29]:
log_prob_t_hat_step_wise.squeeze().sum(-1)

tensor([-19.2571, -30.6052], grad_fn=<SumBackward1>)