In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

from models.whisper_zero_cross_attention import WhisperForConditionalGenerationZeroCrossAttention

## Reference

In [4]:
# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.config.forced_decoder_ids = None

# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features 

Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [5]:
# generate token ids
predicted_ids = model.generate(input_features)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
predicted_ids, transcription



(tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 21443,    13, 50256]]),
 [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'])

## Zero cross-attention

In [6]:
model_zero_cross_attention = WhisperForConditionalGenerationZeroCrossAttention.from_pretrained("openai/whisper-tiny.en")
model_zero_cross_attention.config.forced_decoder_ids = None

### Sanity check

In [7]:
tokenized_seq = torch.tensor([processor.tokenizer("Hello my name is", add_special_tokens=False).input_ids])
output = model_zero_cross_attention.forward(input_features=input_features,
                                            decoder_input_ids=tokenized_seq)
type(output), output.keys()

(transformers.modeling_outputs.Seq2SeqLMOutput,
 odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state']))

In [8]:
output.encoder_last_hidden_state

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [9]:
assert torch.all(output.encoder_last_hidden_state == 0.).item(), "Encoder should output a tensor full of 0s."

### Sentence-completion

**Comments:** Because we have no accoustic model, Whisper is unable to use the audio source to decode. Thus, we get some garbage output as expected.

In [10]:
# User input:
input_seq = "Hello, my name is Tony."

# Tokenize input sequence:
tokenized_seq = torch.tensor([processor.tokenizer(input_seq, add_special_tokens=False).input_ids])

# Shift inputs for next-word prediction:
decoder_input_ids = tokenized_seq[:, 1:]
shifted_left_decoder_input_ids = tokenized_seq[:, :-1]

# One-step generation:
output = model_zero_cross_attention.forward(input_features=input_features,
                                            decoder_input_ids=decoder_input_ids)

# Convert logits to log-probabilities:
log_prob_all = torch.nn.functional.log_softmax(output.logits, dim=-1)

# Take probabilities for the ground-truth tokens:
log_prob = log_prob_all.take_along_dim(shifted_left_decoder_input_ids[..., None], dim=-1)

# Compute perplexity:
perplexity = torch.exp(-log_prob.mean())

log_prob, perplexity

(tensor([[[-14.9954],
          [ -0.1458],
          [ -8.8252],
          [ -4.5569],
          [ -7.2125],
          [ -3.7046]]], grad_fn=<GatherBackward0>),
 tensor(715.8162, grad_fn=<ExpBackward0>))

With an another example. Since the input sequence is garbage, the obtained perplexity should be much higher.

In [11]:
# User input:
input_seq = "mountain no laptop apple sunny cambridge"

# Tokenize input sequence:
tokenized_seq = torch.tensor([processor.tokenizer(input_seq, add_special_tokens=False).input_ids])

# Shift inputs for next-word prediction:
decoder_input_ids = tokenized_seq[:, 1:]
shifted_left_decoder_input_ids = tokenized_seq[:, :-1]

# One-step generation:
output = model_zero_cross_attention.forward(input_features=input_features,
                                            decoder_input_ids=decoder_input_ids)

# Convert logits to log-probabilities:
log_prob_all = torch.nn.functional.log_softmax(output.logits, dim=-1)

# Take probabilities for the ground-truth tokens:
log_prob = log_prob_all.take_along_dim(shifted_left_decoder_input_ids[..., None], dim=-1)

# Compute perplexity:
perplexity = torch.exp(-log_prob.mean())

log_prob, perplexity

(tensor([[[-14.3764],
          [-34.3825],
          [-15.2447],
          [ -8.7378],
          [ -7.8441],
          [-10.8189],
          [ -6.8493]]], grad_fn=<GatherBackward0>),
 tensor(1246996.2500, grad_fn=<ExpBackward0>))

### Bonus: Behavior with `generate`

In [12]:
# Tokenize input sequence:
decoder_input_ids = torch.tensor([processor.tokenizer("Hello my name is", add_special_tokens=False).input_ids])

# One-step generation:
output = model_zero_cross_attention.forward(input_features=input_features,
                                            decoder_input_ids=decoder_input_ids)

In [13]:
# generate token ids
predicted_ids = model_zero_cross_attention.generate(input_features)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
predicted_ids, transcription

(tensor([[50257, 50362,   314,  1101, 50256]]),
 ["<|startoftranscript|><|notimestamps|> I'm<|endoftext|>"])