## Imports

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [7]:
from tqdm.auto import tqdm
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, pipeline
from datasets import load_dataset
import evaluate

from dataloader.dataloader import gen_from_dataset
from dataloader.dataset_for_evaluation.librispeech_dummy_dataset import LibriSpeechDummyDataset
from evaluation.string_edit_metrics import get_string_edit_metrics
from utils.constants import GEN_MAX_LENGTH

# Hugging Face `evaluate` metric loaded under the name "wer"; its `compute`
# method is called on the collected predictions/references further down.
metric = evaluate.load("wer")

## User input

## Load model

In [4]:
# Checkpoint shared by the model and its processor; named once so the two
# `from_pretrained` calls cannot drift apart.
MODEL_NAME = "openai/whisper-tiny"

model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
processor = WhisperProcessor.from_pretrained(MODEL_NAME)

# Pin the decoder prompt to English transcription and clear the list of
# suppressed tokens on the model config.
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")  # type: ignore
model.config.suppress_tokens = []

# Text normalizer applied to both predictions and references before scoring.
# Prefer the tokenizer's public `normalize` method; fall back to the private
# `_normalize` on transformers releases that only expose the latter.
_tokenizer = processor.tokenizer  # type: ignore
whisper_norm = _tokenizer.normalize if hasattr(_tokenizer, "normalize") else _tokenizer._normalize

## Load dataset

In [5]:
# Instantiate the project's LibriSpeech dummy dataset group; the cell output
# shows it being served from the local Hugging Face datasets cache.
ds_group = LibriSpeechDummyDataset()



Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [6]:
# Pick the "librispeech_dummy" entry out of the group's name-to-dataset mapping.
ds = ds_group.str2dataset["librispeech_dummy"]

# Echo the dataset so its features and row count show up in the cell output.
ds

Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
})

## Create pipeline

In [9]:
# Build the speech-recognition pipeline from the already-loaded model and the
# processor's tokenizer and feature extractor.
asr_components = {
    "model": model,
    "tokenizer": processor.tokenizer,  # type: ignore
    "feature_extractor": processor.feature_extractor,  # type: ignore
}
whisper_asr = pipeline("automatic-speech-recognition", **asr_components)

## Run pipeline

In [16]:
# Decoding settings, named once instead of being buried in the pipeline call.
BATCH_SIZE = 8
NUM_BEAMS = 5

# Normalized hypothesis / ground-truth pairs, kept index-aligned for scoring.
predictions = []
references = []

# NOTE(review): GEN_MAX_LENGTH (imported from utils.constants) replaces the
# literal 255 that was hard-coded here -- confirm the constant carries the
# intended generation length cap.
asr_outputs = whisper_asr(gen_from_dataset(ds),
                          batch_size=BATCH_SIZE,
                          generate_kwargs={"max_length": GEN_MAX_LENGTH, "num_beams": NUM_BEAMS})

for out in tqdm(asr_outputs, total=ds.num_rows):
    # `out["reference"]` is indexed at 0 to get the transcript string --
    # presumably the pipeline wraps pass-through fields in a sequence; verify
    # against `gen_from_dataset`.
    ref = whisper_norm(out["reference"][0])
    pred = whisper_norm(out["text"])

    # Skip pairs whose reference is empty after normalization to avoid an
    # error in the WER computation.
    if not ref.strip():
        continue

    predictions.append(pred)
    references.append(ref)

  0%|          | 0/73 [00:00<?, ?it/s]

In [17]:
# Eyeball the first 11 prediction/reference pairs side by side.
for hyp, target in zip(predictions[:11], references[:11]):
    print(hyp, " | ", target)

mister quilter is the apostle of the middle classes and we are glad to welcome his gospel  |  mister quilter is the apostle of the middle classes and we are glad to welcome his gospel
nor is mister quilter is manner less interesting than his matter  |  nor is mister quilter is manner less interesting than his matter
he tells us that at this festive season of the year with christmas and roast beef looming before us similarly drawn from eating and its results occur most readily to the mind  |  he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind
he has grave doubts whether sir frederick layton is work is really greek after all and can discover in it but little of rocky ithaca  |  he has grave doubts whether sir frederick leighton is work is really greek after all and can discover in it but little of rocky ithaca
linel is pictures are a sort of upguards and atom painting

## Compute string edit metrics

In [18]:
# Word error rate over all collected (prediction, reference) pairs.
metric.compute(references=references, predictions=predictions)

0.09923011120615911

In [19]:
# Project helper that reports the error rate together with its substitution,
# deletion and insertion components (see the keys in the cell output).
get_string_edit_metrics(references=references, predictions=predictions)

{'wer': 0.09923011120615911,
 'sub': 0.06843455945252352,
 'del': 0.011976047904191617,
 'ins': 0.018819503849443968}