## Imports

In [1]:
# Enable the autoreload extension so edits to imported modules are picked up live,
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Move the working directory up one level; later cells use paths relative to it
# (e.g. "notebooks/data/...").
# NOTE(review): `%cd ..` is not idempotent -- re-running this cell climbs one more
# directory each time, so run it exactly once per kernel.
%cd ..
import os, sys
# Prepend the directory two levels above the *current* working directory to the module search path.
# NOTE(review): this runs after the `%cd ..` above, so the inserted entry is two levels
# above the directory printed below -- confirm this is the intended import root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [4]:
from pathlib import Path

import numpy as np
import pandas as pd
from transformers.models.whisper import WhisperTokenizerFast

import matplotlib.pyplot as plt
import seaborn as sns

from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP
from evaluation.string_edit_metrics import get_string_edit_metrics
from utils.file_io import load_json

# Set the seaborn plotting theme: "paper" context with the "ticks" style.
sns.set_theme(context="paper", style="ticks")

## Load tokenizer

In [5]:
# Name of the Whisper checkpoint whose tokenizer is loaded below.
pretrained_model_name_or_path = "openai/whisper-tiny"

# Fast Whisper tokenizer configured for English transcription.
tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path,
    task="transcribe",
    language="english",
)

## Load dataset

In [6]:
dataset_name = "ami"

# Maps every supported `dataset_name` to the key read from `ds_group.str2dataset`.
DATASET_NAME_TO_STR2DATASET_KEY = {
    "librispeech_dummy": "librispeech_dummy",
    "ami": "ami",
    "ami_10h": "ami",  # reads the same "ami" entry as the full AMI group
}

# Instantiate the dataset group registered under `dataset_name`
# (raises KeyError if the name is not registered at all).
ds_group = EVAL_DATASET_NAME_TO_DATASET_GROUP[dataset_name]()

if dataset_name not in DATASET_NAME_TO_STR2DATASET_KEY:
    raise ValueError(
        f"Unsupported dataset `{dataset_name}`. "
        f"Expected one of: {sorted(DATASET_NAME_TO_STR2DATASET_KEY)}."
    )

ds = ds_group.str2dataset[DATASET_NAME_TO_STR2DATASET_KEY[dataset_name]]

# Lower-case the reference transcripts; with `input_columns=["text"]` the lambda
# receives the value of the "text" column directly.
ds = ds.map(lambda x: {"text": x.lower()}, input_columns=["text"])



Found cached dataset ami (/Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5)
Loading cached processed dataset at /Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5/cache-818d7a4553237f1c.arrow


## Load predictions

In [8]:
# Path (relative to the working directory) of the JSON file holding the cached predictions.
# NOTE(review): the path is hard-coded to an AMI file and does not depend on
# `dataset_name` -- confirm it matches the dataset loaded above before changing either.
cache_preds_filepath = "notebooks/data/whisper_preds/with_ts/ami_test-medium.json"
assert Path(cache_preds_filepath).is_file(), f"Cached predictions not found at `{cache_preds_filepath}`."

data = load_json(cache_preds_filepath)
results = data["predictions"]  # one entry per example; each entry has at least a "text" field
references = data["references"]
print(f"Loaded cached predictions from `{cache_preds_filepath}`.")

# Lower-cased predicted transcripts, in the same order as `results`.
predictions = [x["text"].lower() for x in results]

# Fail early with a clear message if the cached file is inconsistent, instead of
# failing later inside the metric computation.
assert len(predictions) == len(references), (
    f"Got {len(predictions)} predictions but {len(references)} references "
    f"in `{cache_preds_filepath}`."
)

Loaded cached predictions from `notebooks/data/whisper_preds/with_ts/ami_test-medium.json`.


In [9]:
# String edit metrics between references and predictions, scaled by 100 so they read as percentages.
edit_metrics_raw = get_string_edit_metrics(references=references, predictions=predictions)
string_edit_metrics = pd.Series(edit_metrics_raw) * 100

string_edit_metrics

wer    39.098432
sub    24.473044
del    11.270716
ins     3.354672
dtype: float64

## Add predictions to dataset

In [10]:
# Tokenize the reference transcripts in batches and store the resulting token ids
# in a new "labels" column.
ds = ds.map(lambda batch: {"labels": tokenizer(batch["text"]).input_ids}, batched=True)

Loading cached processed dataset at /Users/Tony/.cache/huggingface/datasets/edinburghcstr___ami/ihm/0.0.0/0d128d0aa8145d0f16f3d5b4da86c5d5759dbe9e8f947fda04b25edb56442bd5/cache-f3a2015a802ea41f.arrow


In [11]:
# Attach the cached predictions (`results`) and the features derived from them to the dataset.
# NOTE(review): `add_features_to_ds` is not imported or defined in any of the cells shown
# (execution counts In [3] and In [7] are missing), so this presumably raises NameError
# under Restart & Run All -- restore its import in the imports cell.
ds = add_features_to_ds(ds, results, tokenizer=tokenizer)

# Display the resulting schema (columns and their types).
ds.features

Map:   0%|          | 0/12643 [00:00<?, ? examples/s]

Map:   0%|          | 0/12643 [00:00<?, ? examples/s]

Map:   0%|          | 0/12643 [00:00<?, ? examples/s]

Map:   0%|          | 0/12643 [00:00<?, ? examples/s]

Map:   0%|          | 0/12643 [00:00<?, ? examples/s]

Map:   0%|          | 0/12643 [00:00<?, ? examples/s]

{'text': Value(dtype='string', id=None),
 'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),
 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'teacher_text': Value(dtype='string', id=None),
 'teacher_labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'audio_length': Value(dtype='float64', id=None),
 'n_tokens_labels': Value(dtype='int64', id=None),
 'n_tokens_teacher': Value(dtype='int64', id=None),
 'n_overlaps': Value(dtype='int64', id=None),
 'is_fast_utterance': Value(dtype='bool', id=None),
 'diff_n_tokens': Value(dtype='int64', id=None),
 'max_token_repetitions_labels': Value(dtype='int64', id=None),
 'max_token_repetitions_teacher': Value(dtype='int64', id=None)}

In [12]:
# Destination directory for the dataset, one sub-directory per dataset name.
savepath = f"notebooks/data/whisper_hallucinations_cached_ds/{dataset_name}"

# Create the parent directory (and any missing ancestors) before writing.
savepath_parent = Path(savepath).parent
savepath_parent.mkdir(parents=True, exist_ok=True)

ds.save_to_disk(savepath)

print(f"Cached dataset at `{savepath}`")

Saving the dataset (0/3 shards):   0%|          | 0/12643 [00:00<?, ? examples/s]

Cached dataset at `notebooks/data/whisper_hallucinations_cached_ds/ami`
