## Imports

In [1]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
import torch

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

from trainer.prompting import get_labels_with_prompt

## User input

## Load model

In [4]:
# Load the processor and the model from the same checkpoint.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Default decoding setup: no suppressed tokens, and the English / transcribe
# prompt forced at the start of decoding.
model.config.suppress_tokens = []
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="english")  # type: ignore

# Tokenizer's text normalizer, used below on the reference transcript.
normalizer = processor.tokenizer._normalize

## Load dataset

In [5]:
# Small dummy LibriSpeech dataset ("clean" config, validation split) with audio files.
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
)

Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [6]:
# Grab the first example of the dataset and display it.
ds_iter = iter(ds)
x = next(ds_iter)

x

{'file': '/Users/Tony/.cache/huggingface/datasets/downloads/extracted/ebb1d3f740add5af71e53b628d8c9c55e64fc2ff14a6ff31de01228adc704d35/dev_clean/1272/128104/1272-128104-0000.flac',
 'audio': {'path': '/Users/Tony/.cache/huggingface/datasets/downloads/extracted/ebb1d3f740add5af71e53b628d8c9c55e64fc2ff14a6ff31de01228adc704d35/dev_clean/1272/128104/1272-128104-0000.flac',
  'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,
         0.0010376 ]),
  'sampling_rate': 16000},
 'text': 'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL',
 'speaker_id': 1272,
 'chapter_id': 128104,
 'id': '1272-128104-0000'}

In [7]:
# Reference transcript, passed through the tokenizer's normalizer.
label = normalizer(x["text"])

# Turn the raw waveform into model input features (PyTorch tensor with a batch dim).
audio = x["audio"]
input_features = processor(
    audio["array"],
    sampling_rate=audio["sampling_rate"],
    return_tensors="pt",
).input_features

label, input_features.shape

('mister quilter is the apostle of the middle classes and we are glad to welcome his gospel',
 torch.Size([1, 80, 3000]))

## Tokenize the labels for teacher-forcing

In [8]:
# Token ids of the normalized label, without any special tokens.
label_ids = processor.tokenizer(label, add_special_tokens=False).input_ids
tokenized_label = torch.LongTensor(label_ids)

# Add the batch dimension: (seq_len,) -> (1, seq_len).
tokenized_labels = tokenized_label.unsqueeze(0)

tokenized_labels

tensor([[   76,  1964, 31619,   391,   307,   220,  3322, 50244,   295,   220,
          3322,  2808,  5359,   293,   321,   366,  5404,   220,  1353,  2928,
           702, 14943]])

## Add prompts to teacher-forced labels

In [9]:
# (position, token id) pairs forced at the start of decoding for English transcription.
processor.get_decoder_prompt_ids(task="transcribe", language="english")

[(1, 50259), (2, 50359), (3, 50363)]

In [10]:
# Same lookup for French: compare the language token id with the English one.
processor.get_decoder_prompt_ids(task="transcribe", language="french")

[(1, 50265), (2, 50359), (3, 50363)]

In [11]:
# Wrap the raw label ids with the French / transcribe prompt.
# The two extra return values are presumably the numbers of prefix and suffix
# tokens that the helper added.
labels_with_prompt, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=tokenized_labels,
    tokenizer=processor.tokenizer,
    language="french",
    task="transcribe",
)

labels_with_prompt

tensor([[50258, 50265, 50359, 50363,    76,  1964, 31619,   391,   307,   220,
          3322, 50244,   295,   220,  3322,  2808,  5359,   293,   321,   366,
          5404,   220,  1353,  2928,   702, 14943, 50257]])

In [12]:
# Decode back to text, keeping special tokens so the added prompt is visible.
processor.tokenizer.batch_decode(
    labels_with_prompt,
    normalize=False,
    skip_special_tokens=False,
)

['<|startoftranscript|><|fr|><|transcribe|><|notimestamps|>mister quilter is the apostle of the middle classes and we are glad to welcome his gospel<|endoftext|>']

## Predict

## Teacher-forced from greedy search

In [13]:
# Switch the forced decoder prompt to French / transcribe for the generation below.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    task="transcribe",
    language="french",
)

In [14]:
# Vanilla generation with greedy search.
pred_gen_raw = model.generate(inputs=input_features)

# Drop the special tokens by decoding to text...
pred_gen_str = processor.tokenizer.batch_decode(
    pred_gen_raw,
    skip_special_tokens=True,
    normalize=False,
)

# ...then re-tokenize the first transcription and add the batch dimension back.
pred_gen_ids = processor.tokenizer.encode(pred_gen_str[0], add_special_tokens=False)
pred_gen = torch.LongTensor(pred_gen_ids).unsqueeze(0)

pred_gen



tensor([[ 1456,  1804,   368,   635,   417,  8604,   871,   368,   635, 32400,
          1030,  4666,  2795,  1443,   263,   373,   892,   287,     6,  1246,
           368,   287,     6,  1246,   368,   635,   417,  8604,    13]])

In [15]:
# Decode the re-tokenized prediction to check that it round-trips to the same text.
processor.tokenizer.batch_decode(
    pred_gen,
    normalize=False,
    skip_special_tokens=False,
)

[" Le plus de la chasse est de la classe et nous débrouillons l'air de l'air de la chasse."]

In [16]:
# Wrap the greedy prediction with the French / transcribe prompt.
# Note: this overwrites the prefix/suffix counts returned by any earlier call.
pred_gen_with_prompts, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=pred_gen,
    tokenizer=processor.tokenizer,
    language="french",
    task="transcribe",
)

pred_gen_with_prompts

tensor([[50258, 50265, 50359, 50363,  1456,  1804,   368,   635,   417,  8604,
           871,   368,   635, 32400,  1030,  4666,  2795,  1443,   263,   373,
           892,   287,     6,  1246,   368,   287,     6,  1246,   368,   635,
           417,  8604,    13, 50257]])

In [17]:
# Sanity check: decode the prompted sequence with special tokens kept.
# (`pred_gen_str` is only displayed here; it won't be used.)
pred_gen_str = processor.tokenizer.batch_decode(
    pred_gen_with_prompts,
    normalize=False,
    skip_special_tokens=False,
)
pred_gen_str

["<|startoftranscript|><|fr|><|transcribe|><|notimestamps|> Le plus de la chasse est de la classe et nous débrouillons l'air de l'air de la chasse.<|endoftext|>"]

In [18]:
# Teacher-forced forward pass: the decoder is fed the prompted greedy prediction
# and produces logits for every decoder position.
# - Call the module itself rather than `.forward()` so that registered hooks run.
# - This is inference only, so autograd tracking is disabled.
with torch.no_grad():
    output = model(input_features=input_features,
                   decoder_input_ids=pred_gen_with_prompts)
logits = output.logits

# Highest-scoring token id at each decoder position.
pred_ids = torch.argmax(logits, dim=-1)

pred_ids

tensor([[50259, 50358, 50363,  1456,  1804,   368,   635,   417,  8604,   871,
           368,   635, 32400,  1030,  4666,  2795,  1443,   263,   373,   892,
           287,     6,  1246,   368,   287,     6,  1246,   368,   635,   417,
          8604,    13, 50257, 50257]])

In [19]:
# Decode the teacher-forced predictions, keeping special tokens visible.
processor.tokenizer.batch_decode(
    pred_ids,
    normalize=False,
    skip_special_tokens=False,
)

["<|en|><|translate|><|notimestamps|> Le plus de la chasse est de la classe et nous débrouillons l'air de l'air de la chasse.<|endoftext|><|endoftext|>"]

**Interesting:** The model recognized that the source audio is English, since it output `<|en|>` as its first prediction. Yet it was teacher-forced with `<|fr|>` as the language token. At that point, the model decided that the task of interest was `<|translate|>`, which makes sense. However, the rest of the prediction is not a translation of the original sentence, again because of teacher forcing.