## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Move the working directory up one level before importing project code.
# NOTE(review): `%cd ..` is relative, so re-running this cell moves up one more
# directory each time; run it exactly once per kernel session.
%cd ..
import os, sys
# Prepend the directory two levels above the (new) working directory to the
# import path.
# NOTE(review): presumably meant to make the project packages importable —
# confirm that two levels up is the intended directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
import torch

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

from trainer.prompting import get_labels_with_prompt

## User input

## Load model

In [4]:
# Checkpoint directory; passed to `from_pretrained` for both the model and the
# processor.
checkpoint_dirpath = "checkpoints/tiny-finetuned_on_ami/"

In [5]:
# Restore the processor and the model from the same checkpoint directory.
processor = WhisperProcessor.from_pretrained(checkpoint_dirpath)
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_dirpath)

# Generation settings: do not suppress any token, and force the decoder prompt
# for English transcription.
model.config.suppress_tokens = []
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(  # type: ignore
    language="english",
    task="transcribe",
)

# Text normalizer taken from the tokenizer (note: this is a private method).
normalizer = processor.tokenizer._normalize

## Load dataset

In [6]:
# AMI train split, "ihm" configuration, opened in streaming mode so that
# examples are consumed through an iterator.
ds = load_dataset(
    "edinburghcstr/ami",
    name="ihm",
    split="train",
    streaming=True,
)

In [7]:
n_skip = 3
ds_iter = iter(ds)

# Discard the first `n_skip` examples of the stream ...
for _ in range(n_skip):
    next(ds_iter)

# ... and keep the one that follows as the working sample.
x = next(ds_iter)

x

{'meeting_id': 'EN2001a',
 'audio_id': 'AMI_EN2001a_H04_MEO069_0145515_0146152',
 'text': "YEAH IT'LL IT'LL PLAY THEM IN SOME ORDER IN WHICH THEY WERE SET BECAUSE OTHERWISE IT'S GONNA BE MORE ENTERTAINING",
 'audio': {'path': 'EN2001a/train_ami_en2001a_h04_meo069_0145515_0146152.wav',
  'array': array([ 0.00000000e+00,  0.00000000e+00,  6.10351562e-05, ...,
         -6.10351562e-05, -6.10351562e-05, -3.05175781e-05]),
  'sampling_rate': 16000},
 'begin_time': 1455.15,
 'end_time': 1461.52,
 'microphone_id': 'H04',
 'speaker_id': 'MEO069'}

In [8]:
# Reference transcript after text normalization.
label = normalizer(x["text"])

# Convert the waveform into model input features (PyTorch tensors).
input_features = processor(
    x["audio"]["array"],
    sampling_rate=x["audio"]["sampling_rate"],
    return_tensors="pt",
).input_features

label, input_features.shape

('yeah it will it will play them in some order in which they were set because otherwise it is going to be more entertaining',
 torch.Size([1, 80, 3000]))

## Tokenize the labels for teacher-forcing

In [9]:
# Token ids of the normalized label, without any special tokens.
tokenized_label = torch.LongTensor(
    processor.tokenizer(label, add_special_tokens=False).input_ids
)

# Prepend a batch dimension: (n_tokens,) -> (1, n_tokens).
tokenized_labels = tokenized_label.unsqueeze(0)

tokenized_labels

tensor([[19650,   309,   486,   309,   486,   862,   220, 47959,   294,   512,
          1668,   294,   597,   220, 13162,   645,   992,   570,  5911,   309,
           307,   516,   220,  1353,   312,   544, 20402]])

## Add prompts to teacher-forced labels

In [10]:
# Inspect the decoder prompt ids the tokenizer returns when language and task
# are both None.
processor.tokenizer.get_decoder_prompt_ids(language=None, task=None)

[(1, 50259), (2, 50359), (3, 50363)]

In [11]:
# Surround the label ids with the prompt tokens for English transcription.
# The two extra return values are presumably the number of prefix / suffix
# tokens that were added — confirm against `trainer.prompting`.
labels_with_prompt, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=tokenized_labels,
    tokenizer=processor.tokenizer,
    language="english",
    task="transcribe",
)

labels_with_prompt

tensor([[50258, 50259, 50359, 50363, 19650,   309,   486,   309,   486,   862,
           220, 47959,   294,   512,  1668,   294,   597,   220, 13162,   645,
           992,   570,  5911,   309,   307,   516,   220,  1353,   312,   544,
         20402, 50257]])

In [12]:
# Decode back to text, keeping special tokens, to check the prompt layout.
processor.tokenizer.batch_decode(labels_with_prompt, skip_special_tokens=False, normalize=False)

['<|startoftranscript|><|en|><|transcribe|><|notimestamps|>yeah it will it will play them in some order in which they were set because otherwise it is going to be more entertaining<|endoftext|>']

## Predict

## Teacher-forced from greedy search

In [13]:
# Vanilla generation (greedy search).
pred_gen_raw = model.generate(inputs=input_features)

# Decode to normalized text with special tokens stripped ...
pred_gen_str = processor.tokenizer.batch_decode(
    pred_gen_raw,
    skip_special_tokens=True,
    normalize=True,
)

# ... then re-tokenize the first (only) transcript without special tokens and
# prepend a batch dimension.
pred_gen = torch.LongTensor(
    processor.tokenizer.encode(pred_gen_str[0], add_special_tokens=False)
).unsqueeze(0)

pred_gen



tensor([[19650,   309,   486,   309,   486,   862,   220, 47959,   294,   512,
          1668,   294,   597,   220, 13162,   645,   992,  3082,  5911,   309,
           307,   516,   220,  1353,   312,   544, 20402]])

In [14]:
# Decode the re-tokenized prediction back to text as a round-trip check.
processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=False, normalize=False)

['yeah it will it will play them in some order in which they were set cause otherwise it is going to be more entertaining']

In [15]:
# Surround the generated ids with prompt tokens so they can be fed back to the
# decoder for teacher forcing.
# Language and task are passed explicitly (rather than None) so the prompt does
# not depend on whatever defaults the tokenizer currently carries, and matches
# the English-transcription prompt forced on the model at load time.
# NOTE: this rebinds `n_prefix_tokens_labels` / `n_suffix_tokens_labels`.
pred_gen_with_prompts, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=pred_gen, language="english", task="transcribe", tokenizer=processor.tokenizer)

pred_gen_with_prompts

tensor([[50258, 50259, 50359, 50363, 19650,   309,   486,   309,   486,   862,
           220, 47959,   294,   512,  1668,   294,   597,   220, 13162,   645,
           992,  3082,  5911,   309,   307,   516,   220,  1353,   312,   544,
         20402, 50257]])

In [16]:
# Sanity check: decode the prompted ids with special tokens kept.
# This rebinds `pred_gen_str`; the new value is only displayed, not reused here.
pred_gen_str = processor.tokenizer.batch_decode(
    pred_gen_with_prompts,
    skip_special_tokens=False,
    normalize=False,
)
pred_gen_str

['<|startoftranscript|><|en|><|transcribe|><|notimestamps|>yeah it will it will play them in some order in which they were set cause otherwise it is going to be more entertaining<|endoftext|>']

In [17]:
# Single teacher-forced pass: the prompted prediction is given to the decoder
# as input and the model scores every position at once.
# The module is called directly (not via `.forward`) so that any registered
# hooks run, and under `torch.no_grad()` because this is inference only and no
# autograd graph is needed.
with torch.no_grad():
    output = model(input_features=input_features,
                   decoder_input_ids=pred_gen_with_prompts)
logits = output.logits

# Highest-scoring token id at each decoder position.
pred_ids = torch.argmax(logits, dim=-1)

pred_ids

tensor([[50258, 50359, 50363,  1338,   309,   486,   309,   486,   862,   220,
         47959,   294,   512,  1668,   294,   597,   220, 13162,   645,   992,
          3082,  5911,   309,   307,   516,   220,  1353,   312,   544, 20402,
         50257, 50257]])

In [18]:
# Decode the teacher-forced argmax ids, keeping special tokens visible.
processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=False, normalize=False)

['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will it will play them in some order in which they were set cause otherwise it is going to be more entertaining<|endoftext|><|endoftext|>']

## With `generate`

In [19]:
# Generate with greedy search - vanilla
# NOTE(review): this rebinds `pred_gen`. It previously held the re-tokenized
# prediction without special tokens; from here on it holds the raw `generate`
# output. Re-running earlier cells that read `pred_gen` after this one will
# therefore see a different tensor.
pred_gen = model.generate(inputs=input_features)
pred_gen

tensor([[50258, 50259, 50359, 50363,  1338,   309,   486,   309,   486,   862,
           220, 47959,   294,   512,  1668,   294,   597,   220, 13162,   645,
           992,  3082,  5911,   309,   307,   516,   220,  1353,   312,   544,
         20402, 50257]])

In [20]:
# Decode the `generate` output, keeping special tokens visible.
processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=False, normalize=False)

['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> yeah it will it will play them in some order in which they were set cause otherwise it is going to be more entertaining<|endoftext|>']

## Comparison

In [21]:
# Teacher-forced prediction as plain text (special tokens skipped, no normalization).
processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, normalize=False)

[' yeah it will it will play them in some order in which they were set cause otherwise it is going to be more entertaining']

In [22]:
# `generate` prediction as plain text (special tokens skipped, no normalization).
processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=True, normalize=False)

[' yeah it will it will play them in some order in which they were set cause otherwise it is going to be more entertaining']

In [23]:
# Normalized reference transcript, shown for comparison with the two predictions.
label

'yeah it will it will play them in some order in which they were set because otherwise it is going to be more entertaining'

## Bonus: Step-wise teacher-forced

In [24]:
res = []

# Step-wise teacher forcing: feed growing prefixes of the prompted prediction
# to the decoder and record the decoded argmax output after each step.
# The upper bound is `shape[1] + 1` so that the last prefix is the full
# sequence; without the `+ 1` the final token would never be fed.
# The module is called directly (not via `.forward`) so that any registered
# hooks run, and `torch.no_grad()` avoids building an autograd graph for these
# inference-only passes.
with torch.no_grad():
    for idx in range(1, pred_gen_with_prompts.shape[1] + 1):
        # One-step generation:
        output = model(input_features=input_features,
                       decoder_input_ids=pred_gen_with_prompts[:, :idx])

        # Log-probabilities over the last logits dimension. Not used in this
        # cell; kept in case a later cell inspects the final step — TODO confirm.
        log_prob_all = torch.nn.functional.log_softmax(output.logits, dim=-1)

        output_tokenized_seq = torch.argmax(output.logits, dim=-1)
        res.append(processor.tokenizer.batch_decode(output_tokenized_seq))

In [25]:
res

[['<|startoftranscript|>'],
 ['<|startoftranscript|><|transcribe|>'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|>'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will it'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will it will'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will it will play'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will it will play '],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will it will play them'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will it will play them in'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will it will play them in some'],
 ['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will it w