## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

## Reference

In [4]:
# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.config.forced_decoder_ids = None

normalizer = processor.tokenizer._normalize

In [5]:
# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
label = normalizer(ds[0]["text"])  # normalize label
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features 

Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [6]:
# generate token ids
predicted_ids = model.generate(input_features)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)



In [7]:
# Compare pred and label after normalization
normalizer(transcription[0]), label

('mister quilter is the apostle of the middle classes and we are glad to welcome his gospel',
 'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel')

In [8]:
predicted_ids

tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
           286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
          7062,   465, 21443,    13, 50256]])

## Experiments

In [9]:
# Tokenize input sequence:
tokenized_seq = torch.tensor([processor.tokenizer(label, add_special_tokens=True).input_ids])
tokenized_seq

tensor([[50257, 50362,    76,  1694,   627,   346,   353,   318,   262, 46329,
           286,   262,  3504,  6097,   290,   356,   389,  9675,   284,  7062,
           465, 21443, 50256]])

In [10]:
processor.tokenizer.batch_decode(tokenized_seq)

['<|startoftranscript|><|notimestamps|>mister quilter is the apostle of the middle classes and we are glad to welcome his gospel<|endoftext|>']

In [11]:
# One-step generation:
print(processor.tokenizer.batch_decode(tokenized_seq[:, :2]))

output = model.forward(input_features=input_features,
                       decoder_input_ids=tokenized_seq[:, :2])

output_tokenized_seq = torch.argmax(output.logits, dim=-1)
processor.tokenizer.batch_decode(output_tokenized_seq)

['<|startoftranscript|><|notimestamps|>']


['<|notimestamps|> Mr']

## Compute $\mathcal{L}_{\mathrm{SEQ} - \mathrm{KD}}$

As a reminder: $\mathcal{L}_{\mathrm{SEQ} - \mathrm{KD}} \approx - \log p(\mathbf{t} = \hat{\mathbf{t}} \mid \mathbf{s})$

In [12]:
teacher_output = normalizer(label)  # assume the teacher perfectly transcribes all examples from LibriSpeech
# Tokenize input sequence:
tokenized_seq = torch.tensor([processor.tokenizer(teacher_output, add_special_tokens=True).input_ids])  # (1, n_tokens)

tokenized_seq.shape

torch.Size([1, 23])

**Note:** In practice, we should directly use `predicted_ids` obtained with `teacher_model.generate`.

In [13]:
# One-step generation:
output = model.forward(input_features=input_features,
                       decoder_input_ids=tokenized_seq[:, :-1])  # get rid of the EOT token "<|endoftext|>" as generation is supposed to stop here -> (1, n_tokens - 1)

log_prob_all = torch.nn.functional.log_softmax(output.logits, dim=-1)
log_prob_all.shape

torch.Size([1, 22, 51864])

In [14]:
log_prob_t_hat_step_wise = log_prob_all.take_along_dim(tokenized_seq[:, 1:, None], dim=-1)  # (1, n_tokens - 1)
log_prob_t_hat_step_wise

tensor([ -0.5336, -16.6798, -14.7979, -15.7406, -15.7755, -15.8578, -15.9635,
        -11.9060, -15.5229, -14.4615, -11.9060, -15.8558, -15.9108, -14.8024,
        -15.8403, -15.4526, -16.9435, -14.5130, -15.1455, -13.8468, -16.2474,
        -13.5256], grad_fn=<GatherBackward0>)

In [15]:
log_prob_t_hat = torch.sum(log_prob_t_hat_step_wise)
log_prob_t_hat

tensor(-317.2286, grad_fn=<SumBackward0>)

In [16]:
torch.exp(log_prob_t_hat)

tensor(0., grad_fn=<ExpBackward0>)

## Bonus: Recreate the `generate` behavior but with the gradient

Note that there is no need to recreate the `generate` behavior for the sequence-level KD as we are only interested in getting the score of $\mathbf{\hat t}$.

In [17]:
# Initialize the sequence
tokenized_seq = torch.tensor([processor.tokenizer("", add_special_tokens=True).input_ids])
tokenized_seq = tokenized_seq[:, :2]  # get rid of the EOT token "<|endoftext|>"
tokenized_seq

tensor([[50257, 50362]])

In [18]:
# One-step generation:
output = model.forward(input_features=input_features,
                       decoder_input_ids=tokenized_seq)

output_tokenized_seq = torch.argmax(output.logits, dim=-1)
processor.tokenizer.batch_decode(output_tokenized_seq)

['<|notimestamps|> Mr']

In [19]:
tokenized_seq.shape

torch.Size([1, 2])

In [20]:
torch.cat([tokenized_seq, output_tokenized_seq[:, -2:-1]], axis=-1)

tensor([[50257, 50362, 50362]])

## Archive (still useful to understand the left-shifted prediction for generative models)

```python
START_OFFSET = 2  # we want to start transcription with "<|startoftranscript|><|notimestamps|>"

res = []
scores = []

for idx in range(START_OFFSET, tokenized_seq.shape[1]):  # we add 1 to finish the loop with the full sentence
    # One-step generation:
    output = model.forward(input_features=input_features,
                           decoder_input_ids=tokenized_seq[:, :idx])
    
    log_prob_all = torch.nn.functional.log_softmax(output.logits, dim=-1)
    
    output_tokenized_seq = torch.argmax(output.logits, dim=-1)
    # scores.append(output.logits[..., output_tokenized_seq])
    # scores.append(output.logits.take_along_dim(output_tokenized_seq[..., None], dim=-1))
    # scores.append(output.logits.take_along_dim(output_tokenized_seq[..., None], dim=-1))
    scores.append(log_prob_all.take_along_dim(tokenized_seq[:, idx]))  # add the score of the ground truth
    res.append(processor.tokenizer.batch_decode(output_tokenized_seq))
```

```
>[['<|notimestamps|> Mr'],
> ['<|notimestamps|> Mrister'],
> ['<|notimestamps|> Mrister Qu'],
> ['<|notimestamps|> Mrister Quil'],
> ['<|notimestamps|> Mrister Quilter'],
> ['<|notimestamps|> Mrister Quilter is'],
> ['<|notimestamps|> Mrister Quilter is the'],
> ['<|notimestamps|> Mrister Quilter is the apostle'],
> ['<|notimestamps|> Mrister Quilter is the apostle of'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes,'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad to'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad to welcome'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad to welcome his'],
> ['<|notimestamps|> Mrister Quilter is the apostle of the middle classes, we are glad to welcome his gospel']]
```