## Imports

In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before each
# cell is executed, so edits to the project's .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [4]:
import torch
from transformers.models.whisper import WhisperTokenizerFast, WhisperFeatureExtractor, WhisperForConditionalGeneration

import matplotlib.pyplot as plt
import seaborn as sns

from functools import partial
from dataloader.preprocessing_train.preprocessing import prepare_dataset_fct
from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP

# Global seaborn styling for figures drawn in this notebook.
sns.set_theme(style="ticks", context="paper")

# Device the model is moved to after loading.
device = torch.device("cpu")

## Load model, tokenizer and feature extractor

In [5]:
pretrained_model_name_or_path = "openai/whisper-tiny"

# Audio and text front-ends come from the same checkpoint as the model.
feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path,
    language="english",
    task="transcribe",
)

model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)
model = model.to(device)

# Heads consulted when token-level timestamps are requested from `generate`.
# NOTE(review): presumably [layer, head] indices of decoder cross-attention heads for
# whisper-tiny -- confirm against the checkpoint's published generation config.
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

## Load dataset

In [9]:
dataset_name = "ami_validation"

# "ami_eval" selects a dataset-group factory; calling it yields a container that is
# indexed by dataset name. Use `dataset_name` instead of repeating the literal so the
# variable and the dataset actually loaded cannot drift apart.
ds = EVAL_DATASET_NAME_TO_DATASET_GROUP["ami_eval"]()[dataset_name]



Found cached dataset ami (/Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5)
Found cached dataset ami (/Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5)


In [10]:
# Work on a small subset only: the first 8 examples.
ds = ds.select(range(8))

# Bind the tokenizer and feature extractor once so `map` can call the function per example.
prepare_dataset = partial(
    prepare_dataset_fct,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
)

# Preprocess with 4 worker processes, then have columns returned as PyTorch tensors.
ds = ds.map(prepare_dataset, num_proc=4)
ds = ds.with_format("pt")

Map (num_proc=4):   0%|          | 0/8 [00:00<?, ? examples/s]

In [11]:
# Run generation on the preprocessed input features and request per-token timestamps.
# Despite its name, `predicted_ids` is the full generation output (a keyed container),
# not just a tensor of ids; the token ids are stored under "sequences".
predicted_ids = model.generate(ds["input_features"], return_token_timestamps=True)

# Show which fields the generation output exposes.
predicted_ids.keys()



odict_keys(['sequences', 'encoder_attentions', 'decoder_attentions', 'cross_attentions', 'token_timestamps'])

In [12]:
# Generated token ids for the first three examples.
predicted_ids["sequences"][:3]

tensor([[50258, 50259, 50359, 50363,   583,   411,  6013, 10216,   362, 11171,
           293,   436,   434,  7084,    13, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257],
        [50258, 50259, 50359, 50363,   291, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257],
        [50258, 50259, 50359, 50363,  4919,    13, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257]])

In [13]:
# Per-token timestamps for the first three examples.
# NOTE(review): presumably expressed in seconds -- confirm against the transformers docs.
predicted_ids["token_timestamps"][:3]

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.2800,  0.4200,  0.7600,
          1.0600,  1.4200,  1.8000,  2.2000,  2.2800,  2.2800,  2.6600,  5.2800,
         23.0800, 23.0800, 23.0800, 23.0800, 23.0800, 23.1000, 23.1400],
        [ 0.0000,  0.0000, 29.6200, 29.6200, 29.6200, 29.6200, 29.6200, 29.6400,
         29.6600, 29.6600, 29.6600, 29.6600, 29.6600, 29.6600, 29.6600, 29.6600,
         29.6600, 29.6600, 29.6600, 29.7800, 29.7800, 29.7800, 29.7800],
        [ 0.0000,  0.0000, 11.6000, 29.6400, 29.6400, 29.6400, 29.6400, 29.6400,
         29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.6600, 29.7800, 29.7800,
         29.7800, 29.7800, 29.7800, 29.7800, 29.7800, 29.7800, 29.7800]])