## Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%cd ..
import os, sys
# Prepend the directory two levels above the current working directory to the import search path.
# NOTE(review): the preceding `%cd ..` is relative, so re-running this cell climbs one more level
# each time and changes what gets inserted here; run it once per kernel session.
# NOTE(review): with the working directory printed below (which looks like the repository root),
# this inserts the root's grandparent rather than the root itself — confirm that is intended.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [5]:
from pathlib import Path
from tqdm.auto import tqdm

import numpy as np
import pandas as pd
import torch
from transformers import pipeline
from transformers.models.whisper import WhisperTokenizerFast, WhisperFeatureExtractor, WhisperForConditionalGeneration

import matplotlib.pyplot as plt
import seaborn as sns

from dataloader.dataset_loader import gen_from_dataset
from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP
from evaluation.string_edit_metrics import get_string_edit_metrics_ortho_and_norm
from normalization.whisper_normalization import get_whisper_normalizer
from utils.file_io import load_json
from utils.whisper_hallucinations.dataloader import load_dataset
from utils.whisper_hallucinations.get_features import add_features_to_ds

# Device the model is moved to when it is loaded below.
device = torch.device('cpu')
# Seaborn theme applied globally to matplotlib/seaborn figures.
sns.set_theme(context="paper", style="ticks")

## Load model, tokenizer and feature extractor

In [6]:
# Checkpoint shared by the model, tokenizer and feature extractor loaded in this cell.
pretrained_model_name_or_path = "openai/whisper-tiny"

# Load the Whisper model and move it to `device`.
model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path).to(device)
# Hard-coded attention-head index pairs, presumably [decoder layer, head] for this specific
# checkpoint and used by `generate(..., return_token_timestamps=True)` below — TODO confirm
# against the checkpoint's generation config, and update if the checkpoint is changed.
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

# Tokenizer configured for English transcription; feature extractor uses the checkpoint defaults.
tokenizer = WhisperTokenizerFast.from_pretrained(pretrained_model_name_or_path, language="english", task="transcribe")
feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path)

## Load dataset

In [64]:
# Name of the evaluation split to inspect.
dataset_name = "ami_validation"

# `load_dataset` here is the project helper imported from
# `utils.whisper_hallucinations.dataloader` at the top of the notebook.
dataset = load_dataset(dataset_name)



Found cached dataset ami (/Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5)
Loading cached processed dataset at /Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5/cache-76a34bc037fa70e6.arrow
Loading cached processed dataset at /Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5/cache-8c6e325cf1e5403b.arrow


In [65]:
from functools import partial
from dataloader.preprocessing_train.preprocessing import prepare_dataset_fct

In [66]:
# Keep only the examples at indices 0-7 (a small subset for the cells below).
# This rebinds `dataset`, so the full split is no longer reachable under this name
# without re-running the loading cell.
dataset = dataset.select(list(range(8)))

In [67]:
dataset

Dataset({
    features: ['text', 'audio'],
    num_rows: 8
})

In [68]:
# Pre-bind the tokenizer and feature extractor so the resulting callable can be passed
# directly to `Dataset.map` in the next cell.
prepare_dataset = partial(prepare_dataset_fct, tokenizer=tokenizer, feature_extractor=feature_extractor)

In [69]:
# Run the preprocessing over the subset with 4 worker processes, then have the
# resulting dataset return PyTorch tensors when indexed.
ds = dataset.map(prepare_dataset, num_proc=4).with_format("pt")

Map (num_proc=4):   0%|          | 0/8 [00:00<?, ? examples/s]

In [70]:
# Generate transcriptions for the whole subset in one batch, requesting per-token timestamps.
# NOTE: despite its name, `predicted_ids` is the keyed generation output (the next cell lists
# 'sequences', 'token_timestamps', ...), not a bare tensor of token ids.
predicted_ids = model.generate(ds["input_features"], return_token_timestamps=True)

In [71]:
predicted_ids.keys()

odict_keys(['sequences', 'encoder_attentions', 'decoder_attentions', 'cross_attentions', 'token_timestamps'])

In [72]:
predicted_ids["sequences"][:3]

tensor([[50258, 50259, 50359, 50363,   583,   411,  6013, 10216,   362, 11171,
           293,   436,   434,  7084,    13, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257],
        [50258, 50259, 50359, 50363,   291, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257],
        [50258, 50259, 50359, 50363,  4919,    13, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257]])

In [73]:
predicted_ids["token_timestamps"][:3]

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.2800,  0.4200,  0.7600,
          1.0600,  1.4200,  1.8000,  2.2000,  2.2800,  2.2800,  2.6600,  5.2800,
         23.0800, 23.0800, 23.0800, 23.0800, 23.0800, 23.1000, 23.1400],
        [ 0.0000,  0.0000, 29.6200, 29.6200, 29.6200, 29.6200, 29.6200, 29.6400,
         29.6600, 29.6600, 29.6600, 29.6600, 29.6600, 29.6600, 29.6600, 29.6600,
         29.6600, 29.6600, 29.6600, 29.7800, 29.7800, 29.7800, 29.7800],
        [ 0.0000,  0.0000, 11.6000, 29.6400, 29.6400, 29.6400, 29.6400, 29.6400,
         29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.6600, 29.7800, 29.7800,
         29.7800, 29.7800, 29.7800, 29.7800, 29.7800, 29.7800, 29.7800]])

In [43]:
# Decode the first generated sequence back to text (special tokens are kept in the string).
# NOTE(review): this cell's execution count (43) is lower than that of the `generate` cell (70),
# so the output shown below appears to come from an earlier run on different data — re-run
# the notebook top to bottom to refresh it.
tokenizer.batch_decode(predicted_ids["sequences"][0:1])

['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z Z']

In [45]:
# Token timestamps of the first sequence in the batch, used as input to the helpers below.
# NOTE(review): executed as In [45], i.e. before the `generate` cell shown above (In [70]);
# the displayed tensor has 448 entries while the timestamps printed above have 23 per row,
# so this output looks stale — re-run the notebook top to bottom.
x = predicted_ids["token_timestamps"][0]
x

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  2.1000,  4.0600,  4.1000,
         4.1000,  4.5400,  5.2600,  5.5000,  5.9400,  6.2200,  6.2600,  6.5800,
         7.2400,  9.3200,  9.3200, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600,
        16.2600, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600,
        16.2600, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600,
        16.2600, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600, 16.2600,
        16.2600, 16.2600, 16.2800, 16.2800, 16.2800, 16.2800, 16.2800, 16.3600,
        16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600,
        16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600,
        16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600,
        16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600,
        16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600, 16.3600,
        16.3600, 16.3600, 16.3600, 16.36

In [51]:
x.shape

torch.Size([448])

In [46]:
import torch

def count_zero_length_elements(tensor):
    """Count the elements whose value equals that of their immediate successor.

    Each element is treated as a start time and the following element as its end
    time, so an element counts as "zero length" when the two are equal. The last
    element along the final dimension has no successor and is never counted.

    Args:
        tensor: Tensor of values (e.g. one row of token timestamps). Comparisons are
            made along the last dimension only, so rows of a batched tensor are not
            compared with each other.

    Returns:
        A 0-dim integer tensor holding the number of zero-length elements.
    """
    # Compare each element with its true successor. `torch.roll(tensor, -1)` would wrap
    # around and also compare the last element with the first one, over-counting whenever
    # they happen to be equal (and always counting 1 for a single-element tensor).
    starts = tensor[..., :-1]
    ends = tensor[..., 1:]
    return torch.sum(ends == starts)

In [50]:
count_zero_length_elements(x)

tensor(424)

In [58]:
import torch

def max_subarray_length(x):
    """Return the length of the longest run of consecutive equal values in ``x``.

    Args:
        x: 1-D tensor of values (e.g. the token timestamps of one sequence).

    Returns:
        A 0-dim integer tensor with the length of the longest run, or 0 when ``x``
        is empty.
    """
    # An empty tensor has no runs; `torch.max` would raise on an empty input.
    if x.numel() == 0:
        return torch.tensor(0, device=x.device)

    # `unique_consecutive` collapses each run of equal neighbours and reports its size.
    # This replaces the hand-rolled diff/nonzero bookkeeping, whose closing sentinel
    # (`len(x)` instead of `len(x) - 1`) over-counted the final run by one.
    _, run_lengths = torch.unique_consecutive(x, return_counts=True)
    return torch.max(run_lengths)

In [59]:
max_subarray_length(x)

tensor(172)