## Imports

In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys

# Record the directory the kernel started in only the first time this cell runs.
# A bare `%cd ..` climbs one more level on every re-execution; anchoring on the
# remembered start directory makes the cell idempotent.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.path.abspath(os.getcwd())

# The project root is the parent of the notebook directory (the same place `%cd ..`
# used to land on its first execution).
PROJECT_ROOT = os.path.dirname(NOTEBOOK_DIR)
os.chdir(PROJECT_ROOT)

# Put the project root itself on the import path. The previous code inserted the
# directory two levels *above* the root, which presumably does not contain the
# `dataloader` / `evaluation` / `utils` packages imported further down.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Echo the working directory, as `%cd` did.
print(PROJECT_ROOT)

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
from typing import Dict, Any
from tqdm.auto import tqdm

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import torch
from transformers import pipeline
from transformers.models.whisper import (WhisperTokenizer,
                                         WhisperTokenizerFast,
                                         WhisperFeatureExtractor,
                                         WhisperForConditionalGeneration)
from datasets import load_dataset

from dataloader.dataset_loader import gen_from_dataset
from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP
from evaluation.string_edit_metrics import get_string_edit_metrics

from utils.constants import GEN_MAX_LENGTH, DEFAULT_EVAL_NUM_BEAMS

# Pick the best available accelerator instead of assuming Apple-silicon MPS, so the
# notebook also runs on CUDA or CPU-only machines. MPS stays the first choice, so
# the selected device is unchanged wherever MPS is available.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Global seaborn theme for every figure in this notebook: "ticks" style, "paper" context.
sns.set_theme(style="ticks", context="paper")

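## User input

The tunable inputs (`pretrained_model_name_or_path` and `dataset_name`) are set at the top of the "Load model" and "Load dataset" sections below.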

## Load model

In [4]:
# Checkpoint shared by the model weights, the feature extractor and the tokenizer.
pretrained_model_name_or_path = "openai/whisper-tiny"

# Load the weights first, then move the module to the selected device.
model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
model = model.to(device)

feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path)

# Tokenizer configured for English transcription with timestamp prediction enabled.
tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path,
    language="english",
    task="transcribe",
    predict_timestamps=True,
)

# Short alias for the tokenizer's text normaliser.
# NOTE(review): `_normalize` is a private tokenizer method - confirm it is still
# exposed by the installed transformers version.
whisper_norm = tokenizer._normalize

## Load dataset

In [5]:
dataset_name = "librispeech_dummy"

# Instantiate the dataset group registered for this evaluation dataset name.
ds_group = EVAL_DATASET_NAME_TO_DATASET_GROUP[dataset_name]()

# Resolve which entry of the group to read. "ami_10h" reads the "ami" entry,
# exactly as the original branch did.
if dataset_name == "librispeech_dummy":
    ds_key = "librispeech_dummy"
elif dataset_name in ("ami", "ami_10h"):
    ds_key = "ami"
else:
    # Fail with an actionable message instead of a bare ValueError.
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; "
        "expected one of: 'librispeech_dummy', 'ami', 'ami_10h'."
    )

# Lower-case the reference transcripts (same post-processing for every supported dataset).
ds = ds_group.str2dataset[ds_key].map(lambda text: {"text": text.lower()}, input_columns=["text"])



Downloading builder script:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

Downloading and preparing dataset librispeech_asr_dummy/clean to /Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset librispeech_asr_dummy downloaded and prepared to /Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b. Subsequent calls will reuse this data.


Map:   0%|          | 0/73 [00:00<?, ? examples/s]

In [22]:
# Take the audio entry of the first example in the dataset.
_first_row = ds[0]
sample = _first_row["audio"]

In [26]:
# Run the feature extractor on the sample's waveform at its own sampling rate and
# keep the `input_features` field of the result (PyTorch tensors requested via "pt").
_extracted = feature_extractor(
    sample["array"],
    sampling_rate=sample["sampling_rate"],
    return_tensors="pt",
)
input_features = _extracted.input_features

In [32]:
# Generate token ids for the transcription task with timestamps requested;
# the features are moved to the model's device first.
generate_ids = model.generate(
    input_features.to(device),
    task="transcribe",
    return_timestamps=True,
)

In [34]:
# Decode the first generated sequence, keeping the timestamp tokens in the text.
# Left as the final expression so the notebook displays the decoded string.
_first_sequence = generate_ids[0]
tokenizer.decode(_first_sequence, decode_with_timestamps=True)

'<|startoftranscript|><|en|><|transcribe|><|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|><|endoftext|>'