## Imports

In [1]:
# Reload all imported modules before executing each cell, so edits to the project's
# .py helpers (evaluation/, normalization/, utils/) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Run from the repository root so relative paths such as "notebooks/data/..." and the
# project packages imported in the next cell (evaluation, normalization, utils) resolve.
# Guarded so that re-running this cell does not keep walking up the directory tree.
if not os.path.isdir("notebooks"):
    os.chdir("..")

# Put the repo root itself (not one of its ancestor directories) on the import path, once.
REPO_ROOT = os.getcwd()
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
print(REPO_ROOT)

/Users/Tony/Other Docs/distilling-and-forgetting-in-large-pre-trained-models


In [3]:
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, auc

from transformers.models.whisper import WhisperTokenizerFast
from datasets import load_from_disk

import matplotlib.pyplot as plt
import seaborn as sns

from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP
from evaluation.string_edit_metrics import get_string_edit_metrics_ortho_and_norm
from normalization.whisper_normalization import get_whisper_normalizer
from utils.whisper_hallucinations.get_features import add_features_to_ds, compute_gzip_compression_ratio
from utils.whisper_hallucinations.eval_filter_criterion import eval_filter_criterion
from utils.notebook_utils import listen_to_audio

# Seaborn defaults shared by every figure in this notebook.
sns.set_theme(style="ticks", context="paper")

# Destination for artefacts from this analysis; relative to the repo root (the cwd set above).
OUTPUT_DIR = Path("notebooks") / "outputs" / "8_1_best_kd" / "ami_100h"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

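## User input

Set `load_from_pickle = True` to reuse the cached feature DataFrame at `pickle_filepath`. Set it to `False` to recompute the features from the Hugging Face dataset at `ds_dirpath`; this also refreshes the pickle.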

In [4]:
# True  -> reuse the cached DataFrame at `pickle_filepath`.
# False -> rebuild it from the HF dataset at `ds_dirpath` (and overwrite the cache).
load_from_pickle = True

pickle_filepath = "notebooks/data/ami_100h_medium_cached_33p.pkl"
# Machine-specific absolute path; override with an environment variable rather than
# editing the notebook. The default preserves the original location.
ds_dirpath = os.environ.get(
    "AMI_100H_MEDIUM_CACHED_33P_DIRPATH",
    "/home/tw581/rds/hpc-work/ami_100h_medium_cached_33p",
)

## Load data

In [5]:
# Columns kept from the HF dataset when building the analysis DataFrame.
LIST_FEATURES = [
    'text',
    'teacher_text',
    'n_instant_tokens',
    'max_subarray_length',
    'audio_length',
    'n_tokens_labels',
    'n_tokens_teacher',
    'diff_n_tokens',
    'gzip_ratio',
    'teacher_gzip_ratio',
    'diff_gzip_ratio'
]

if load_from_pickle:
    # Cache written by the `else` branch below. `read_pickle` can execute arbitrary code,
    # so only point `pickle_filepath` at files you generated yourself.
    df = pd.read_pickle(pickle_filepath)
else:
    # Tokenizer used to decode the teacher's token ids back to text. The checkpoint is
    # assumed from the dataset name ("*_medium_*") — TODO confirm it matches the teacher
    # that produced `teacher_sequences`.
    tokenizer = WhisperTokenizerFast.from_pretrained("openai/whisper-medium")

    ds = load_from_disk(ds_dirpath)
    ds = ds.map(lambda x: {"teacher_text": tokenizer.decode(x["teacher_sequences"], skip_special_tokens=True)})
    ds = add_features_to_ds(ds)
    df = pd.DataFrame({col: ds[col] for col in ds.features.keys() if col in LIST_FEATURES})

    # `to_pickle` fails if the cache directory does not exist yet.
    Path(pickle_filepath).parent.mkdir(parents=True, exist_ok=True)
    df.to_pickle(pickle_filepath)

# Every downstream cell works on the teacher transcripts.
assert "teacher_text" in df.columns, f"'teacher_text' missing from {list(df.columns)}"

In [6]:
# Peek at the first five rows of the loaded features.
df.head(n=5)

Unnamed: 0,text,teacher_text,audio_length,n_tokens_labels,n_tokens_teacher,diff_n_tokens,gzip_ratio,teacher_gzip_ratio,diff_gzip_ratio,n_instant_tokens,max_subarray_length
0,if you if you s. s. h. and they have this big ...,"If you ask this agent, they have this big war...",4.21,33,30,-3,1.084211,1.041237,-0.042973,22,23
1,i've gotten mm hardly any,Hardly any.,1.63,11,9,-2,0.555556,0.375,-0.180556,5,6
2,it's yeah i mean the wave data are obviously n...,because the wave data are obviously not going...,5.25,23,22,-1,0.929412,0.914634,-0.014778,13,12
3,yeah it'll it'll play them in some order in wh...,"Yeah, it'll play them in some order in which ...",6.37,30,33,3,1.118812,1.075472,-0.04334,26,27
4,yeah,Yeah.,0.37,6,7,1,0.166667,0.230769,0.064103,4,4


In [7]:
from tqdm.auto import trange
from functools import partial
from utils.whisper_hallucinations.get_features import max_contiguous_ngrams

In [12]:
# One `max_contiguous_ngrams_<n>` column per n-gram size n = 1..7, computed on the teacher
# transcripts (large values correspond to the repetitive outputs shown in the tables below).
for n in trange(1, 8):
    ngram_col = f"max_contiguous_ngrams_{n}"
    # Series.apply forwards extra keyword arguments to the function.
    df[ngram_col] = df["teacher_text"].apply(max_contiguous_ngrams, n=n)

  0%|          | 0/7 [00:00<?, ?it/s]

In [13]:
# Ten rows with the highest 4-gram repetition score, highest first.
sort_col = "max_contiguous_ngrams_4"
df.sort_values(by=sort_col, ascending=False).head(n=10)

Unnamed: 0,text,teacher_text,audio_length,n_tokens_labels,n_tokens_teacher,diff_n_tokens,gzip_ratio,teacher_gzip_ratio,diff_gzip_ratio,n_instant_tokens,max_subarray_length,max_contiguous_ngrams_1,max_contiguous_ngrams_2,max_contiguous_ngrams_3,max_contiguous_ngrams_4,max_contiguous_ngrams_5,max_contiguous_ngrams_6,max_contiguous_ngrams_7
8571,no no no no no no no,No no no no no no no no no no no no no no no ...,1.13,12,226,214,0.8,22.1,21.300001,219,116,220,110,73,55,44,36,31
22414,s,"So, so, so, so, so, so, so, so, so, so, so, s...",1.8,6,226,220,0.047619,14.766666,14.719048,217,58,109,54,36,27,21,18,15
16133,well no no,"No, no, no, no, no, no, no, no, no, no, no, n...",1.23,8,226,218,0.357143,14.766666,14.409524,217,147,109,54,36,27,21,18,15
5910,well 'cause sometimes they go you know like th...,"Because sometimes they go, uh, uh, uh, uh, uh...",7.27,36,227,191,1.145631,8.5,7.354369,218,88,108,54,36,27,21,18,15
8840,yeah yeah parametri yeah i i i'm looking for p...,"Yeah, yeah, yeah, yeah, yeah, yeah, yeah, yea...",3.39,21,226,205,1.015152,19.558823,18.543671,214,135,109,54,36,27,21,18,15
19881,for instance um let's say oh oh um,"for instance, let's see, oh, oh, oh, oh, oh, ...",6.3,14,226,212,0.666667,9.06,8.393333,217,88,107,53,35,26,21,17,15
69,so then you'd start with all your utterances h...,so then you start with all your utterances he...,7.8,40,231,191,1.377551,6.910891,5.53334,219,102,99,49,33,24,19,16,14
26281,fifty one one two three four five six seven eight,"1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...",4.67,18,226,208,0.790323,3.741379,2.951057,213,104,56,28,18,14,11,9,8
18474,so i'm thinking if we imagine that we're takin...,So I'm thinking if we imagine that we're taki...,10.200001,54,53,-1,1.45082,1.396825,-0.053994,44,44,1,1,1,2,1,1,1
24166,oh there are issues oh there are issues,"Oh, there are issues. Oh, there are issues.",2.42,15,19,4,0.951219,1.023256,0.072036,13,14,1,1,1,2,1,1,1


In [14]:
# For each n-gram size, print the ten teacher transcripts with the highest
# `max_contiguous_ngrams_<n>` score, separated by blank lines.
for n in range(1, 8):
    print(f"\nn = {n}")
    score_col = f"max_contiguous_ngrams_{n}"
    top_rows = df.sort_values(score_col, ascending=False).head(10)
    for teacher_text in top_rows["teacher_text"]:
        print(teacher_text, end="\n\n")
    print()


n = 1
 No no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no

 Yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yeah, yea

In [10]:
# NOTE(review): this prints the same rows as the n = 4 iteration of the loop in the
# preceding cell (and was executed out of order, In [10]); consider deleting it.
top_by_4grams = df.sort_values("max_contiguous_ngrams_4", ascending=False).head(10)
for text in top_by_4grams["teacher_text"]:
    print(text)
    print()

 No no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no no

 So, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so, so