## Imports

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
import torch

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

from trainer.prompting import get_labels_with_prompt

## User input

## Load model

In [4]:
# Checkpoint directory loaded below by both the model and the processor.
# The path is relative, so it resolves against the current working directory.
checkpoint_dirpath = "checkpoints/tiny-finetuned_on_ami/"

In [5]:
# Load the processor and the model from the same checkpoint directory.
processor = WhisperProcessor.from_pretrained(checkpoint_dirpath)
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_dirpath)

# Pin the decoder prompt to English transcription and clear the list of
# suppressed tokens.
forced_prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
model.config.forced_decoder_ids = forced_prompt_ids  # type: ignore
model.config.suppress_tokens = []

# Text normalizer taken from the tokenizer; applied to the reference transcript below.
normalizer = processor.tokenizer._normalize

## Load dataset

In [6]:
# Open the "ihm" configuration of the AMI corpus (train split) in streaming
# mode, so samples are yielded one at a time when iterated.
ds = load_dataset(
    "edinburghcstr/ami",
    name="ihm",
    split="train",
    streaming=True,
)

In [7]:
n_skip = 3
ds_iter = iter(ds)

# Consume n_skip + 1 samples: the first `n_skip` are discarded by being
# overwritten, and `x` ends up holding the one right after them.
for _ in range(n_skip + 1):
    x = next(ds_iter)

x

{'meeting_id': 'EN2001a',
 'audio_id': 'AMI_EN2001a_H04_MEO069_0145515_0146152',
 'text': "YEAH IT'LL IT'LL PLAY THEM IN SOME ORDER IN WHICH THEY WERE SET BECAUSE OTHERWISE IT'S GONNA BE MORE ENTERTAINING",
 'audio': {'path': 'EN2001a/train_ami_en2001a_h04_meo069_0145515_0146152.wav',
  'array': array([ 0.00000000e+00,  0.00000000e+00,  6.10351562e-05, ...,
         -6.10351562e-05, -6.10351562e-05, -3.05175781e-05]),
  'sampling_rate': 16000},
 'begin_time': 1455.15,
 'end_time': 1461.52,
 'microphone_id': 'H04',
 'speaker_id': 'MEO069'}

In [8]:
# Normalize the reference transcript of the selected sample.
label = normalizer(x["text"])

# Convert the raw waveform into the model's input features (PyTorch tensor).
audio = x["audio"]
input_features = processor(
    audio["array"],
    sampling_rate=audio["sampling_rate"],
    return_tensors="pt",
).input_features

label, input_features.shape

('yeah it will it will play them in some order in which they were set because otherwise it is going to be more entertaining',
 torch.Size([1, 80, 3000]))

## 🆕 Change a few words to see whether the model has turned into a unigram predictor

In [9]:
# Hand-edited variant of the normalized transcript: "yeah" -> "yes" and one
# "it will play them" removed. This overwrites `label` from the previous cell,
# so every cell below works on the altered text, not the ground truth.
label = 'yes it will in some order in which they were set because otherwise it is going to be more entertaining'

## Tokenize the labels for teacher-forcing

In [10]:
# Token ids of the transcript text only; special tokens are not added here.
label_token_ids = processor.tokenizer(label, add_special_tokens=False).input_ids
tokenized_label = torch.tensor(label_token_ids, dtype=torch.long)

# Insert a leading batch axis: (seq_len,) -> (1, seq_len).
tokenized_labels = tokenized_label.unsqueeze(0)

tokenized_labels

tensor([[ 2346,   309,   486,   294,   512,  1668,   294,   597,   220, 13162,
           645,   992,   570,  5911,   309,   307,   516,   220,  1353,   312,
           544, 20402]])

## Add prompts to teacher-forced labels

In [11]:
# Display the (position, token id) pairs the tokenizer returns as decoder prompt
# when neither a language nor a task is specified.
processor.tokenizer.get_decoder_prompt_ids(language=None, task=None)

[(1, 50259), (2, 50359), (3, 50363)]

In [12]:
# Wrap the transcript token ids with the English/transcribe prompt tokens.
# NOTE(review): the two extra return values presumably count the prefix and
# suffix tokens that were added -- confirm in `trainer.prompting`.
(
    labels_with_prompt,
    n_prefix_tokens_labels,
    n_suffix_tokens_labels,
) = get_labels_with_prompt(
    labels=tokenized_labels,
    language="english",
    task="transcribe",
    tokenizer=processor.tokenizer,
)

labels_with_prompt

tensor([[50258, 50259, 50359, 50363,  2346,   309,   486,   294,   512,  1668,
           294,   597,   220, 13162,   645,   992,   570,  5911,   309,   307,
           516,   220,  1353,   312,   544, 20402, 50257]])

In [13]:
# Decode the prompted labels back to text, keeping special tokens and skipping
# normalization, to check that the prompt and end token were added as intended.
processor.tokenizer.batch_decode(labels_with_prompt, skip_special_tokens=False, normalize=False)

['<|startoftranscript|><|en|><|transcribe|><|notimestamps|>yes it will in some order in which they were set because otherwise it is going to be more entertaining<|endoftext|>']

## Predict

## Teacher-forced from ground-truth

In [14]:
# Sanity check: decode the teacher-forcing inputs back to text, keeping special
# tokens and skipping normalization.
# NOTE(review): despite the `_ids` suffix, this variable holds decoded strings.
# An earlier version of this comment mentioned `pred_gen_str`, which is not
# defined in the cells shown here.
labels_with_prompt_ids = processor.tokenizer.batch_decode(labels_with_prompt, skip_special_tokens=False, normalize=False)
labels_with_prompt_ids

['<|startoftranscript|><|en|><|transcribe|><|notimestamps|>yes it will in some order in which they were set because otherwise it is going to be more entertaining<|endoftext|>']

In [15]:
# Teacher-forced pass: the decoder is fed the reference tokens instead of its
# own previous predictions.
# - The module is called directly rather than through `.forward(...)`, so any
#   registered hooks run.
# - The logits are only inspected, never back-propagated, so gradient tracking
#   is disabled to avoid building an unused autograd graph.
with torch.no_grad():
    output = model(input_features=input_features,
                   decoder_input_ids=labels_with_prompt)
logits = output.logits

# Greedy choice at every position. Position i holds the prediction for the
# token that follows input token i, so the result is shifted one step relative
# to `labels_with_prompt`, and the last entry is the prediction made after the
# final input token.
pred_ids = torch.argmax(logits, dim=-1)

pred_ids

tensor([[50258, 50359, 50363,  1338,   309,   486,   862,   862,  1668,   294,
           597,   220, 13162,   645,   992,  3082,  5911,   309,   307,   516,
           220,  1353,   312,   544, 20402, 50257, 50257]])

In [16]:
# Decode the greedy teacher-forced predictions to text, keeping special tokens
# and skipping normalization so they can be compared with the inputs above.
processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=False, normalize=False)

['<|startoftranscript|><|transcribe|><|notimestamps|> yeah it will play play order in which they were set cause otherwise it is going to be more entertaining<|endoftext|><|endoftext|>']