## Imports

In [1]:
# Automatically reload edited project modules (e.g. `dataloader`, `evaluation`)
# so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
from tqdm.auto import tqdm

import torch
from transformers import pipeline
from transformers.models.whisper import (WhisperTokenizer,
                                         WhisperTokenizerFast,
                                         WhisperFeatureExtractor,
                                         WhisperForConditionalGeneration)
from datasets import load_dataset
import evaluate

from dataloader.dataset_loader import gen_from_dataset
from dataloader.dataset_for_evaluation.librispeech_dummy_dataset import LibriSpeechDummyDataset
from evaluation.string_edit_metrics import get_string_edit_metrics

# Pick the compute backend: Apple-silicon GPU (MPS) if available, then CUDA, else CPU,
# so the notebook also runs (Restart & Run All) on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Hugging Face `evaluate` word-error-rate metric; applied to the normalized transcripts below.
metric = evaluate.load("wer")

## User input

## Load model

In [4]:
pretrained_model_name_or_path = "openai/whisper-tiny"

# One language/task pair, shared by the tokenizer and the forced decoder prompt.
LANGUAGE, TASK = "english", "transcribe"

model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path,
    language=LANGUAGE,
    task=TASK,
)

# Force the decoder prompt (language + task tokens) and clear the suppressed-token list.
# NOTE(review): depending on the transformers version, `generate` may read these settings
# from `model.generation_config` rather than `model.config` -- confirm they take effect.
model.config.forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language=LANGUAGE, task=TASK)  # type: ignore
model.config.suppress_tokens = []

# Whisper's text normalizer, applied to predictions and references before scoring.
# NOTE(review): `_normalize` is a private tokenizer method and may change between versions.
whisper_norm = tokenizer._normalize

## Load dataset

In [5]:
# Project wrapper around the LibriSpeech dummy split (served from the local Hugging Face
# datasets cache, per the log below).
ds_group = LibriSpeechDummyDataset()



Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


In [6]:
# Select the dummy split from the group's name -> dataset mapping and display it.
DATASET_NAME = "librispeech_dummy"
ds = ds_group.str2dataset[DATASET_NAME]
ds

Dataset({
    features: ['file', 'audio', 'text', 'speaker_id', 'chapter_id', 'id'],
    num_rows: 73
})

## Create pipeline

In [7]:
# Bundle the model, tokenizer and feature extractor into an ASR pipeline on `device`.
asr_pipeline_kwargs = dict(
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    device=device,
)
whisper_asr = pipeline(task="automatic-speech-recognition", **asr_pipeline_kwargs)  # type: ignore

## Run pipeline

In [8]:
# Inference settings: 16 samples per batch, single-beam decoding.
BATCH_SIZE = 16
GENERATE_KWARGS = {"num_beams": 1}

# Aligned lists of (prediction, reference) pairs to score below.
predictions = []
references = []

asr_outputs = whisper_asr(gen_from_dataset(ds),
                          batch_size=BATCH_SIZE,
                          generate_kwargs=GENERATE_KWARGS)

for output in tqdm(asr_outputs, total=ds.num_rows):
    # NOTE(review): the reference transcript comes back from the pipeline as a list
    # (presumably attached by `gen_from_dataset`); its first element is used.
    reference = output["reference"][0]
    prediction = output["text"]

    # Keep only samples with a non-blank reference -- empty ones break the WER computation.
    if reference.strip():
        predictions.append(prediction)
        references.append(reference)

  0%|          | 0/73 [00:00<?, ?it/s]

  if unfinished_sequences.max() == 0:


## Compute string edit metrics

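Compare two WER implementations on the Whisper-normalized transcripts: the Hugging Face `evaluate` metric and the project's `get_string_edit_metrics`, which also breaks WER down into substitution, deletion and insertion rates.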

In [9]:
# Normalize predictions and references with Whisper's text normalizer before scoring.
# The inference loop only drops references that are blank *before* normalization, but
# normalization can still reduce a reference to "" (NOTE(review): presumably e.g.
# punctuation-only text -- confirm against the normalizer), and the WER computation
# errors on empty references. Drop such pairs here, keeping both lists aligned.
normalized_pairs = [(whisper_norm(pred), whisper_norm(ref))
                    for pred, ref in zip(predictions, references)]
predictions_norm = [pred for pred, ref in normalized_pairs if ref.strip()]
references_norm = [ref for pred, ref in normalized_pairs if ref.strip()]

In [10]:
# Word error rate from Hugging Face `evaluate` on the normalized transcripts.
metric.compute(references=references_norm, predictions=predictions_norm)

0.11804961505560307

In [11]:
# Project implementation: overall WER plus its substitution / deletion / insertion rates.
get_string_edit_metrics(references=references_norm, predictions=predictions_norm)

{'wer': 0.11804961505560307,
 'sub': 0.08297690333618478,
 'del': 0.013686911890504704,
 'ins': 0.0213857998289136}