## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

## Load model and dataset

In [4]:
# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.config.forced_decoder_ids = None

normalizer = processor.tokenizer._normalize

In [5]:
# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


## Create a 2-batch

In [6]:
sample_0 = ds[0]["audio"]
label_0 = normalizer(ds[0]["text"])  # normalize label
input_features_0 = processor(sample_0["array"], sampling_rate=sample_0["sampling_rate"], return_tensors="pt").input_features 

In [7]:
sample_1 = ds[1]["audio"]
label_1 = normalizer(ds[1]["text"])  # normalize label
input_features_1 = processor(sample_1["array"], sampling_rate=sample_1["sampling_rate"], return_tensors="pt").input_features 

In [8]:
label_0, label_1

('mister quilter is the apostle of the middle classes and we are glad to welcome his gospel',
 'nor is mister quilter is manner less interesting than his matter')

### Batch `input_features`

In [9]:
assert input_features_0.shape == input_features_1.shape

In [10]:
input_features = torch.concat([input_features_0, input_features_1], axis=0)
input_features.shape

torch.Size([2, 80, 3000])

### Batch `tokenized_label`

In [11]:
# Need to use `torch.LongTensor` because default dtype is float:
tokenized_label_0 = torch.LongTensor(processor.tokenizer(label_0, add_special_tokens=False).input_ids)
tokenized_label_1 = torch.LongTensor(processor.tokenizer(label_1, add_special_tokens=False).input_ids)

In [12]:
tokenized_label_0.shape, tokenized_label_1.shape

(torch.Size([20]), torch.Size([14]))

⚠️ The `input_features` share the same shape but this is not true for the tokenized sequences!

In [13]:
label_features = [{"input_ids": tokenized_label_0}, {"input_ids": tokenized_label_1}]
labels_batch = processor.tokenizer.pad(label_features, return_tensors="pt")  # type: ignore

In [14]:
labels_batch

{'input_ids': tensor([[   76,  1694,   627,   346,   353,   318,   262, 46329,   286,   262,
          3504,  6097,   290,   356,   389,  9675,   284,  7062,   465, 21443],
        [13099,   318,   285,  1694,   627,   346,   353,   318,  5642,  1342,
          3499,   621,   465,  2300, 50256, 50256, 50256, 50256, 50256, 50256]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])}

In [15]:
labels_batch["input_ids"].shape, labels_batch["attention_mask"].shape

(torch.Size([2, 20]), torch.Size([2, 20]))

## Forward

### Prepare inputs

In [16]:
input_features.shape, input_features.dtype

(torch.Size([2, 80, 3000]), torch.float32)

In [17]:
# Assume the teacher is a perfect model:
teacher_sequences = labels_batch["input_ids"]
teacher_sequences.shape, teacher_sequences.dtype

(torch.Size([2, 20]), torch.int64)

In [18]:
# Invert the attention mask:
attention_mask_labels = labels_batch["attention_mask"]
attention_mask_labels

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])

### Handle the special tokens

In [19]:
batch_size = teacher_sequences.size(0)

# Get prefix tokens:
forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")  # will take care of the EOS token as well
prefix_tokens = torch.IntTensor([processor.tokenizer.bos_token_id] + [token_id for idx, token_id in forced_decoder_ids])
prefix_tokens = prefix_tokens.expand(batch_size, -1)
n_prefix_tokens = prefix_tokens.shape[1]

# Get suffix tokens:
suffix_tokens = torch.IntTensor([processor.tokenizer.eos_token_id])
suffix_tokens = suffix_tokens.expand(batch_size, -1)

# Concatenate the prefix tensor with the original tensor along the second dimension
teacher_sequences_ = torch.cat((prefix_tokens, teacher_sequences, suffix_tokens), dim=1)
labels_ = torch.cat((prefix_tokens, teacher_sequences, suffix_tokens), dim=1)  # should be replaced with `labels` in the final code

teacher_sequences_.shape, labels_.shape

(torch.Size([2, 25]), torch.Size([2, 25]))

### Predict without attention mask

In [20]:
output_no_mask = model.forward(input_features=input_features,
                               decoder_input_ids=teacher_sequences_[:, :-1])  # don't predict when current token is EOS
logits_no_mask = output_no_mask.logits[:, n_prefix_tokens-1:, :]  # remove what the model tried to predict for the special tokens
logits_no_mask.shape

torch.Size([2, 21, 51864])

⚠️ To be used with categorical targets, `F.cross_entropy` needs to be used with a tensor for which the 2nd dimension is the class dimension. See [documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html) for more information.

In [21]:
F.cross_entropy(input=rearrange(logits_no_mask, pattern="b n v -> b v n"),
                target=teacher_sequences_[:, n_prefix_tokens:],
                ignore_index=processor.tokenizer.pad_token_id)

tensor(2.0745, grad_fn=<NllLoss2DBackward0>)

### Predict with attention mask

In [22]:
attention_prefix = torch.ones(batch_size, n_prefix_tokens)
attention_mask_labels_ = torch.cat((attention_prefix, attention_mask_labels), dim=1)

In [23]:
output_with_mask = model.forward(input_features=input_features,
                                 decoder_input_ids=teacher_sequences_[:, :-1],  # don't predict when current token is EOS
                                 decoder_attention_mask=attention_mask_labels_)
logits_with_mask = output_with_mask.logits[:, prefix_tokens.size(1)-1:, :]  # remove what the model tried to predict for the special tokens

F.cross_entropy(input=rearrange(logits_with_mask, pattern="b n v -> b v n"),
                target=teacher_sequences_[:, prefix_tokens.size(1):],
                ignore_index=processor.tokenizer.pad_token_id)

tensor(2.0745, grad_fn=<NllLoss2DBackward0>)

### Output comparison

In [24]:
logits_no_mask[1, -1, :]

tensor([5.6267, 8.0974, 5.5852,  ..., 4.2737, 4.8921, 2.6027],
       grad_fn=<SliceBackward0>)

In [25]:
logits_with_mask[1, -1, :]

tensor([20.0693, 19.4200, 14.3050,  ..., 12.6726, 12.9365, 10.4247],
       grad_fn=<SliceBackward0>)

In [26]:
logits_no_mask == logits_with_mask

tensor([[[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True]],

        [[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]])

**Comments:**
- The first first rows ARE equal because our model uses a causal attention mechanism. Therefore, the attention doesn't have the chance to consider the pad tokens at the end of the sequence. Hence the similarity.
- The last rows differ for the same reason. Although we could just set them to 0 (which is something we will do eventually), it is good practice to mask them properly. Moreover, we can save some computation time here.

### Apply softmax to get the vocab probabilities per step and per example

In [27]:
output_log_prob = torch.nn.functional.log_softmax(logits_with_mask, dim=-1)  # (batch_size, n_tokens-1, vocab_size)
output_log_prob

tensor([[[-17.6328, -16.8028, -17.1586,  ..., -17.6936, -17.4344, -19.5560],
         [ -7.2255,  -8.6642, -10.1993,  ..., -10.0005, -10.4366, -11.5987],
         [-10.0087,  -8.3535, -13.8741,  ...,  -9.9771,  -9.5403, -11.5162],
         ...,
         [-13.8001, -12.3024, -16.6508,  ..., -18.9416, -18.2672, -19.6933],
         [-11.7289, -12.2270, -15.8037,  ..., -17.4969, -18.7981, -19.3218],
         [ -2.9698,  -5.3366, -12.4520,  ..., -15.4419, -15.0755, -17.9387]],

        [[-18.5667, -17.4100, -18.5190,  ..., -18.7813, -18.6986, -20.4263],
         [-13.3955, -13.2569, -16.5791,  ..., -15.3639, -15.4941, -16.4072],
         [ -9.1820,  -8.6636, -12.2782,  ..., -11.1930, -11.2660, -14.2472],
         ...,
         [ -7.0847,  -7.7393, -12.9708,  ..., -14.5783, -14.3619, -16.8758],
         [ -6.9128,  -7.0502, -12.9996,  ..., -14.9758, -14.7272, -17.2178],
         [ -7.1692,  -7.8185, -12.9335,  ..., -14.5660, -14.3020, -16.8139]]],
       grad_fn=<LogSoftmaxBackward0>)

### Set the values associated to the pad tokens to 0

First, let's get the vocabulary size.

In [28]:
n_vocab = output_with_mask.logits.shape[-1]
n_vocab

51864

Because we will sum the log-probabilities to make use of the product rule in the log space, a sufficient method to ignore the padded values is to set them to 0.

In [29]:
# Repeat attention_mask for the n_vocab dimension:
mask = attention_mask_labels_[:, n_prefix_tokens-1:, None].expand(-1, -1, n_vocab)
mask.shape

torch.Size([2, 21, 51864])

In [30]:
output_log_prob_masked = output_log_prob.masked_fill(mask.ne(1), 0)
output_log_prob_masked

tensor([[[-17.6328, -16.8028, -17.1586,  ..., -17.6936, -17.4344, -19.5560],
         [ -7.2255,  -8.6642, -10.1993,  ..., -10.0005, -10.4366, -11.5987],
         [-10.0087,  -8.3535, -13.8741,  ...,  -9.9771,  -9.5403, -11.5162],
         ...,
         [-13.8001, -12.3024, -16.6508,  ..., -18.9416, -18.2672, -19.6933],
         [-11.7289, -12.2270, -15.8037,  ..., -17.4969, -18.7981, -19.3218],
         [ -2.9698,  -5.3366, -12.4520,  ..., -15.4419, -15.0755, -17.9387]],

        [[-18.5667, -17.4100, -18.5190,  ..., -18.7813, -18.6986, -20.4263],
         [-13.3955, -13.2569, -16.5791,  ..., -15.3639, -15.4941, -16.4072],
         [ -9.1820,  -8.6636, -12.2782,  ..., -11.1930, -11.2660, -14.2472],
         ...,
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]],
       grad_fn=<MaskedFillBackward0>)

In [31]:
# Row 1 should be non-null and row 2 should be full of 0s:
output_log_prob_masked[:, -1, :]

tensor([[ -2.9698,  -5.3366, -12.4520,  ..., -15.4419, -15.0755, -17.9387],
        [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],
       grad_fn=<SliceBackward0>)

## Compute the log-probability of each sentence

In [32]:
log_prob_t_hat_step_wise = output_log_prob_masked.take_along_dim(teacher_sequences_[:, n_prefix_tokens:, None], dim=-1)
log_prob_t_hat_step_wise.shape

torch.Size([2, 21, 1])

In [33]:
log_prob_t_hat_step_wise.squeeze().sum(-1)

tensor([-30.6683, -43.4395], grad_fn=<SumBackward1>)

## Archive (still useful to understand the left-shifted prediction for generative models)

```python
START_OFFSET = 2  # we want to start transcription with "<|startoftranscript|><|notimestamps|>"

res = []
scores = []

for idx in range(START_OFFSET, tokenized_seq.shape[1]):  # we add 1 to finish the loop with the full sentence
    # One-step generation:
    output = model.forward(input_features=input_features,
                           decoder_input_ids=tokenized_seq[:, :idx])
    
    log_prob_all = torch.nn.functional.log_softmax(output.logits, dim=-1)
    
    output_tokenized_seq = torch.argmax(output.logits, dim=-1)
    # scores.append(output.logits[..., output_tokenized_seq])
    # scores.append(output.logits.take_along_dim(output_tokenized_seq[..., None], dim=-1))
    # scores.append(output.logits.take_along_dim(output_tokenized_seq[..., None], dim=-1))
    scores.append(log_prob_all.take_along_dim(tokenized_seq[:, idx]))  # add the score of the ground truth
    res.append(processor.tokenizer.batch_decode(output_tokenized_seq))
```

```
>[['<|notimestamps|> Mr'],
> ['<|notimestamps|> Mrister'],
> ['<|notimestamps|> Mrister Qu'],
> ['<|notimestamps|> Mrister Quil'],
> ['<|notimestamps|> Mrister Quilter'],
> ['<|notimestamps|> Mrister Quilter is'],
> ['<|notimestamps|> Mrister Quilter is the'],
> ['<|notimestamps|> Mrister Quilter is the apostle'],
> ['<|notimestamps|> Mrister Quilter is the apostle of'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes,'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad to'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad to welcome'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad to welcome his'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad to welcome his gospel']]
```