## Imports

In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports all loaded modules
# before each cell runs so edits to imported source files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
from datasets import load_dataset
import evaluate

from dataloader.dataloader import gen_from_dataset
from dataloader.dataset_for_evaluation.ami_test import AMITestSet
from evaluation.string_edit_metrics import get_string_edit_metrics

# Load the "wer" metric from the `evaluate` library once; it is applied to the collected
# predictions/references in the scoring cell further down.
metric = evaluate.load("wer")

## User input

## Load model

In [4]:
# One checkpoint name and one set of decoding options, shared by every call below.
whisper_checkpoint = "openai/whisper-tiny"
decoding_options = {"language": "english", "task": "transcribe"}

model = WhisperForConditionalGeneration.from_pretrained(whisper_checkpoint)
processor = WhisperProcessor.from_pretrained(whisper_checkpoint, **decoding_options)

# Pin the decoder prompt to the same language/task and clear the suppressed-token list.
model.config.suppress_tokens = []
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(**decoding_options)  # type: ignore

# Text normaliser applied to both hypotheses and references before scoring.
# NOTE(review): `_normalize` is an underscore-prefixed (private) tokenizer method — confirm it is
# still available in the installed transformers version.
whisper_norm = processor.tokenizer._normalize

## Load dataset

In [5]:
# Build the project's AMI test-set wrapper with streaming turned on
# (presumably so the data is read lazily instead of downloaded up front — TODO confirm).
ds_group = AMITestSet(streaming=True)



In [6]:
# Look up the "ami_test" entry in the group's name -> dataset mapping.
ds = ds_group.str2dataset["ami_test"]

# Bare expression so the notebook shows the dataset object's repr.
ds

<datasets.iterable_dataset.IterableDataset at 0x28ccdf4c0>

## Create pipeline

In [7]:
# Reuse the already-loaded model together with the processor's tokenizer and feature extractor
# instead of letting the pipeline load its own copies.
asr_components = {
    "model": model,
    "tokenizer": processor.tokenizer,  # type: ignore
    "feature_extractor": processor.feature_extractor,  # type: ignore
}
whisper_asr = pipeline(task="automatic-speech-recognition", **asr_components)

## Run pipeline

In [8]:
# Number of non-empty (prediction, reference) pairs to collect before stopping.
MAX_SAMPLES = 100

count = 0

# Create placeholders for the predictions and references:
predictions = []
references = []

for out in whisper_asr(gen_from_dataset(ds),
                       batch_size=4,
                       generate_kwargs={"num_beams": 1}):  # type: ignore
    # NOTE(review): the reference is indexed with [0], so it presumably comes back wrapped in a
    # sequence — confirm against `gen_from_dataset`.
    ref = whisper_norm(out["reference"][0])
    pred = whisper_norm(out["text"])

    if not ref.strip():
        continue  # skip empty references to avoid error in WER computation

    predictions.append(pred)
    references.append(ref)
    count += 1

    # Test the limit right after storing a pair, using `>=`: the loop then stops at exactly
    # MAX_SAMPLES pairs and does not pull one more output from the pipeline just to discard it.
    if count >= MAX_SAMPLES:
        break



In [9]:
# Show each normalised hypothesis next to its reference for a quick visual comparison.
for hypothesis, target in zip(predictions, references):
    print(hypothesis, " | ", target)

thank you  |  yeah
yeah we are going to meet up  |  yeah we are going to meet up
we are not even there yet  |  i mean we are not even there yet
yeah  |  yeah
yeah  |  yeah yeah
monday afternoon  |  monday afternoon
you use my pen  |  get to use my pen
yeah  |  yeah yeah
everything  |  everything
yeah that is what we need  |  yeah that is what we need yeah
it is very scary  |  3 is good though
okay  |  yeah
he messed up  |  you better start
we can always decide that  |  we can always decide then i mean yeah
he is so .  |  you saw it yeah
yeah  |  yeah true
well it is kind of like .  |  well 0 dear
i mean  |  i mean
thanks for watching  |  yeah
0 i do not know  |  0 i do not know
yeah  |  yeah
thank you  |  i th
that is cool  |  that is cool
so even if it does not work you can jiggery poke it around and make it work  |  so even if it does not work you can jiggery pokery around and make it work
but most will probably want to go with the faults  |  but most will probably want to go with de

## Compute string edit metrics

In [10]:
# Score the collected predictions against their references with the "wer" metric loaded
# in the imports cell; the bare expression displays the resulting value.
metric.compute(predictions=predictions, references=references)

0.3147208121827411

In [11]:
# Run the same pairs through the project's own helper; the displayed result is a dict with the
# keys 'wer', 'sub', 'del' and 'ins' (presumably the overall rate and its substitution /
# deletion / insertion components — TODO confirm against the helper).
get_string_edit_metrics(predictions=predictions, references=references)

{'wer': 0.3147208121827411,
 'sub': 0.17597292724196278,
 'del': 0.06937394247038917,
 'ins': 0.06937394247038917}