## Imports

In [1]:
# Automatically re-import edited project modules (e.g. `dataloader`, `evaluation`)
# before each cell executes, so changes to their .py files take effect without a
# kernel restart. Comments are kept on separate lines because text after a line
# magic is passed to it as arguments.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Work from the project root so the project packages (`dataloader`, `evaluation`, ...)
# import and relative paths resolve. Assumes the kernel starts one directory below
# the root (the original cell did `%cd ..`).
# Resolve the root only on the first run: a bare `%cd ..` climbs one more directory
# every time the cell is re-executed.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.dirname(os.path.abspath(os.getcwd()))
os.chdir(PROJECT_ROOT)

# Put the root itself on the import path. The previous expression applied `dirname`
# twice *after* `cd ..`, which yielded the root's grandparent instead.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

PROJECT_ROOT

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [22]:
from functools import partial
from tqdm.auto import tqdm

import torch
from transformers import pipeline
from transformers.models.whisper import (WhisperTokenizer,
                                         WhisperTokenizerFast,
                                         WhisperFeatureExtractor,
                                         WhisperForConditionalGeneration)
from datasets import load_dataset
import evaluate

from dataloader.collator import DataCollatorSpeechSeq2SeqWithPadding
from dataloader.preprocessing_train.preprocessing import prepare_dataset_fct
from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP

# Prefer Apple-silicon MPS when this machine/PyTorch build supports it. Otherwise fall
# back to CPU so `device` is usable on any machine.
# NOTE(review): none of the cells below move `model` or the input features to
# `device`, so generation currently runs on CPU regardless of this setting.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

## Load model

In [9]:
# Whisper checkpoint together with its matching feature extractor and tokenizer.
pretrained_model_name_or_path = "openai/whisper-tiny"
LANGUAGE, TASK = "english", "transcribe"

model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path,
    language=LANGUAGE,
    task=TASK,
)

# Bake decoding defaults (language, task, max_length=255, KV cache on) into
# `generate` so later cells only pass strategy-specific options. Keyword arguments
# given at call time override these defaults (functools.partial semantics).
model.generate = partial(
    model.generate,
    language=LANGUAGE,
    task=TASK,
    max_length=255,
    use_cache=True,
)

## Load dataset

In [36]:
# Evaluation dataset to inspect. "librispeech_dummy" was also used here as an
# alternative key of EVAL_DATASET_NAME_TO_DATASET_GROUP.
dataset_name = "ami"
# Number of AMI examples kept so preprocessing and inspection stay fast.
N_AMI_SAMPLES = 32

ds = EVAL_DATASET_NAME_TO_DATASET_GROUP[dataset_name]()[dataset_name]

if dataset_name == "ami":
    # Clamp to the dataset size: selecting an index past the end raises IndexError.
    ds = ds.select(range(min(N_AMI_SAMPLES, len(ds))))



Found cached dataset ami (/Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5)


In [37]:
# Bind the Whisper tokenizer and feature extractor so the map function only
# receives a dataset row.
prepare_dataset = partial(
    prepare_dataset_fct,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
)

# Preprocess with 4 worker processes, then return columns as PyTorch tensors.
ds = ds.map(prepare_dataset, num_proc=4)
ds = ds.with_format("pt")

Map (num_proc=4):   0%|          | 0/32 [00:00<?, ? examples/s]

## Predict: beam search vs. nucleus sampling

In [119]:
# Pick a single example to inspect and display its reference transcript.
EXAMPLE_IDX = 25
x = ds[EXAMPLE_IDX]
x["text"]

"SO EVEN IF IT DOESN'T WORK YOU CAN JIGGERY POKERY AROUND AND MAKE IT WORK"

In [120]:
# Beam search with 3 beams, returning every beam's hypothesis for comparison.
batched_features = x["input_features"][None]  # prepend a batch axis of size 1
outputs = model.generate(
    batched_features,
    num_beams=3,
    num_return_sequences=3,
)

tokenizer.batch_decode(outputs, skip_special_tokens=True)

[" So even if it doesn't work, you can jigory poke it around and make it work.",
 " So even if it doesn't work, you can jiggery poke it around and make it work.",
 " So, even if it doesn't work, you can jiggery poke it around and make it work."]

In [121]:
# Nucleus (top-p) sampling: draw 3 stochastic transcriptions to compare with the
# beam-search hypotheses above. Seed the torch RNG first; without a seed, every
# re-run of this cell produces different samples.
SAMPLING_SEED = 0
torch.manual_seed(SAMPLING_SEED)
outputs = model.generate(x["input_features"][None, ...], do_sample=True, top_p=0.92, num_return_sequences=3)

tokenizer.batch_decode(outputs, skip_special_tokens=True)

[" So even if it doesn't work, you can jiggerle it up around and make it work.",
 " So even if it doesn't work you can jitter eat pot agree around and make it work.",
 " Even if it doesn't work, you can jigger rip out around and make it work."]