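# Whisper-tiny on the AMI test set: WER evaluation

Evaluates the pretrained `openai/whisper-tiny` checkpoint on the AMI test set. It uses greedy decoding (`num_beams=1`) and English transcription. The evaluation covers the first 200 utterances whose normalised reference is non-empty. It reports the WER and its substitution / deletion / insertion breakdown, then inspects individual utterances with a very high WER.

## Imports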

In [1]:
# Reload imported modules automatically before running each cell, so edits to the
# project's own modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

# This notebook sits one directory below the repository root. Move to the root so the
# project-local packages (`dataloader`, `evaluation`) can be imported. The move is guarded
# so that re-running the cell does not keep climbing the directory tree.
# NOTE(review): assumes the `dataloader` package lives directly under the repository root.
if not os.path.isdir("dataloader"):
    os.chdir("..")

repo_root = os.getcwd()
print(repo_root)

# Make the repository root itself importable, independent of how the kernel populated
# sys.path. Skip the insert if it is already there, so re-runs don't grow sys.path.
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
from tqdm.auto import tqdm

import torch
from transformers import pipeline
from transformers.models.whisper import (WhisperTokenizer,
                                         WhisperTokenizerFast,
                                         WhisperFeatureExtractor,
                                         WhisperForConditionalGeneration)
from datasets import load_dataset
import evaluate

from dataloader.dataset_loader import gen_from_dataset
from dataloader.dataset_for_evaluation.ami_test import AMITestSet
from evaluation.string_edit_metrics import get_string_edit_metrics

# Prefer the Apple-Silicon GPU backend (MPS), then CUDA, and fall back to CPU, so the
# notebook also runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Hugging Face `evaluate` word-error-rate metric; used to score the collected
# predictions against their references.
metric = evaluate.load("wer")

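## User input

The tunable parameters do not have a cell of their own. The checkpoint name is set in *Load model*, and the number of evaluated utterances and the batch size are set in *Run pipeline*.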

## Load model

In [4]:
# Checkpoint under evaluation, plus the decoding language/task shared by the tokenizer
# and the forced decoder prompt (kept in one place so the two cannot drift apart).
pretrained_model_name_or_path = "openai/whisper-tiny"
decode_language, decode_task = "english", "transcribe"

model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path,
    language=decode_language,
    task=decode_task,
)

# Force the language/task prompt tokens at the start of decoding.
prompt_ids = tokenizer.get_decoder_prompt_ids(language=decode_language, task=decode_task)
model.config.forced_decoder_ids = prompt_ids  # type: ignore
# Clear the configured list of tokens suppressed during generation.
model.config.suppress_tokens = []

# Tokenizer's (private) text normaliser, applied to both predictions and references
# before scoring.
whisper_norm = tokenizer._normalize

## Load dataset

In [5]:
# Project wrapper around the AMI test data. With `streaming=True`, examples are presumably
# read lazily rather than downloaded up front — see AMITestSet to confirm.
ds_group = AMITestSet(streaming=True)



In [6]:
# Select the AMI dataset from the group's name -> dataset mapping (a streaming
# IterableDataset, as the displayed repr shows).
ds = ds_group.str2dataset["ami"]

ds

<datasets.iterable_dataset.IterableDataset at 0x17c035540>

## Create pipeline

In [7]:
# Bundle the already-configured model and its pre/post-processors, then wrap them in a
# Hugging Face automatic-speech-recognition pipeline running on `device`.
asr_components = dict(
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    device=device,
)
whisper_asr = pipeline("automatic-speech-recognition", **asr_components)

## Run pipeline

In [8]:
# Number of utterances to score. Utterances whose reference is empty after
# normalisation are skipped and do NOT count towards this budget.
n_samples = 200
batch_size = 16

# Create placeholders for the predictions and references:
predictions = []
references = []
n_skipped = 0

# Greedy decoding (num_beams=1).
asr_outputs = whisper_asr(gen_from_dataset(ds),
                          batch_size=batch_size,
                          generate_kwargs={"num_beams": 1})

# Advance the progress bar manually so it counts *kept* utterances, matching
# `n_samples`. Wrapping the raw iterator would also count the skipped ones.
with tqdm(total=n_samples, desc="transcribing") as pbar:
    for out in asr_outputs:
        # `reference` comes back as a sequence; take its first element
        # (presumably passed through from `gen_from_dataset` — confirm there).
        ref = whisper_norm(out["reference"][0])
        pred = whisper_norm(out["text"])

        if not ref.strip():
            n_skipped += 1  # skip empty references to avoid error in WER computation
            continue

        predictions.append(pred)
        references.append(ref)
        pbar.update(1)

        if len(predictions) >= n_samples:
            break

count = len(predictions)
print(f"kept {count} utterances, skipped {n_skipped} with an empty normalised reference")

  0%|          | 0/200 [00:00<?, ?it/s]

  if unfinished_sequences.max() == 0:


## Compute string edit metrics

In [9]:
# WER of all collected (normalised) predictions against their references, using the
# Hugging Face `evaluate` metric.
metric.compute(predictions=predictions, references=references)

0.30985915492957744

In [10]:
# Same WER via the project helper, which also returns its substitution ("sub"),
# deletion ("del") and insertion ("ins") components.
get_string_edit_metrics(predictions=predictions, references=references)

{'wer': 0.30985915492957744,
 'sub': 0.16735708367854185,
 'del': 0.09113504556752279,
 'ins': 0.05136702568351284}

## Per-example analysis

In [13]:
# Flag utterances with a very high individual WER. WER = (S + D + I) / N and S + D <= N,
# so a per-utterance WER above 1 requires insertions: the prediction is longer than the
# reference. A threshold of 2 therefore surfaces transcripts that largely invent text
# for short references.
wer_threshold = 2

for idx, (prediction, reference) in enumerate(zip(predictions, references)):
    wer = get_string_edit_metrics(predictions=[prediction], references=[reference])["wer"]
    if wer > wer_threshold:
        print("idx: ", idx)
        print("prediction: ", prediction)
        print("reference: ", reference)
        print(wer)
        print()

idx:  18
prediction:  thanks for watching
reference:  yeah
3.0

idx:  57
prediction:  you are so funny
reference:  so
3.0

idx:  146
prediction:  0 it is heavy
reference:  ooh
4.0



In [14]:
# Inspect a hand-picked set of utterances. Only these indices are scored, rather than
# computing the WER of every utterance and discarding all but a few. Indices are visited
# in ascending order, once each. Out-of-range indices are reported, not silently ignored.
idx_of_interest = [117, 152, 156]

for idx in sorted(set(idx_of_interest)):
    if not 0 <= idx < min(len(predictions), len(references)):
        print(f"idx {idx} is out of range ({len(predictions)} utterances collected); skipped")
        print()
        continue

    prediction, reference = predictions[idx], references[idx]
    wer = get_string_edit_metrics(predictions=[prediction], references=[reference])["wer"]
    print("idx: ", idx)
    print("prediction: ", prediction)
    print("reference: ", reference)
    print(wer)
    print()

idx:  117
prediction:  yeah
reference:  yeah
0.0

idx:  152
prediction:  you
reference:  yeah
1.0

idx:  156
prediction:  hahaha
reference:  yeah
1.0

