## Imports

In [None]:
# Automatically reload imported project modules (e.g. `evaluation`, `utils`) when their source files
# change, so edits to helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys

# Move from the notebook's directory to its parent exactly ONCE. A bare `%cd ..` walks one more level
# up every time this cell is re-run, which silently breaks every relative path used later
# (e.g. `OUTPUT_DIR`).
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
    os.chdir(os.pardir)
print(f"Working directory: {os.getcwd()}")

# NOTE(review): this puts the directory two levels above the new working directory on the import
# path. The project packages imported below (`evaluation`, `normalization`, `utils`) are presumably
# resolved via the working directory itself — confirm which directory is actually intended here.
_EXTRA_IMPORT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd())))
if _EXTRA_IMPORT_DIR not in sys.path:  # avoid stacking duplicate entries on re-runs
    sys.path.insert(0, _EXTRA_IMPORT_DIR)

In [None]:
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd

from transformers.models.whisper import WhisperTokenizerFast
from datasets import load_from_disk

import matplotlib.pyplot as plt
import seaborn as sns

from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP
from evaluation.string_edit_metrics import get_string_edit_metrics_ortho_and_norm
from normalization.whisper_normalization import get_whisper_normalizer
from utils.whisper_hallucinations.get_features import add_features_to_ds, compute_gzip_compression_ratio
from utils.whisper_hallucinations.eval_filter_criterion import eval_filter_criterion
from utils.notebook_utils import listen_to_audio

# Seaborn theme shared by every figure in this notebook.
sns.set_theme(style="ticks", context="paper")

# Figures are written below this directory, relative to the working directory set at the top of the notebook.
OUTPUT_DIR = Path("notebooks", "outputs", "8_1_best_kd", "librispeech_100h")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

## User input

In [None]:
# True: load the cached feature DataFrame from `pickle_filepath`.
# False: rebuild it from the dataset at `ds_dirpath` (slow) and overwrite the cache.
load_from_pickle = True

# Hugging Face dataset saved with `save_to_disk`, holding cached teacher outputs. The path suggests
# whisper-medium, k=1 beam search, LibriSpeech clean-100h train split — confirm.
# NOTE(review): absolute, user-specific paths; adjust for your own environment.
ds_dirpath = (
    "/home/tw581/rds/rds-altaslp-8YSp2LXTlkY/experiments/tw581/cache/huggingface/"
    "k_beam_search_cache/librispeech_clean_100h/whisper-medium/k_1/train"
)

# Cache file for the extracted feature DataFrame.
pickle_filepath = "/home/tw581/rds/hpc-work/librispeech_medium_cached_33p.pkl"

## Load tokenizer

In [None]:
# Tokenizer used to turn the cached teacher token ids back into text.
# NOTE(review): loaded from whisper-tiny although the teacher outputs come from whisper-medium; this
# presumably relies on both multilingual checkpoints sharing the same vocabulary — confirm.
pretrained_model_name_or_path = "openai/whisper-tiny"
tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path,
    task="transcribe",
    language="english",
)

## Load data

In [None]:
# Columns kept in the feature DataFrame; every other dataset column is dropped.
LIST_FEATURES = [
    'text',
    'teacher_text',
    'n_instant_tokens',
    'max_subarray_length',
    'audio_length',
    'n_tokens_labels',
    'n_tokens_teacher',
    'diff_n_tokens',
    'gzip_ratio',
    'teacher_gzip_ratio',
    'diff_gzip_ratio'
]

# Use the cache only when it was requested AND actually exists. Otherwise rebuild from the dataset and
# (re)write the cache, instead of crashing on a missing file.
# NOTE: `pd.read_pickle` can execute arbitrary code — only point `pickle_filepath` at files you created.
if load_from_pickle and Path(pickle_filepath).exists():
    df = pd.read_pickle(pickle_filepath)
else:
    ds = load_from_disk(ds_dirpath)
    # Keep only the first third of the rows (cf. the "33p" suffix of the cache filename).
    ds = ds.select(range(ds.num_rows // 3))
    # Decode the teacher token ids to text. Batched decoding is equivalent to per-example `decode`, but faster.
    ds = ds.map(
        lambda batch: {
            "teacher_text": tokenizer.batch_decode(batch["teacher_sequences"], skip_special_tokens=True)
        },
        batched=True,
    )
    ds = add_features_to_ds(ds)

    # Report requested features the dataset does not provide rather than dropping them silently.
    missing_features = [col for col in LIST_FEATURES if col not in ds.column_names]
    if missing_features:
        print(f"Warning: features not found in the dataset and skipped: {missing_features}")

    df = pd.DataFrame({col: ds[col] for col in LIST_FEATURES if col in ds.column_names})
    df.to_pickle(pickle_filepath)

In [None]:
# Quick look at the first rows of the feature table.
df.head(5)

In [None]:
# The 10 utterances with the most teacher tokens (presumably where repetitive/hallucinated outputs show up).
df.sort_values(by="n_tokens_teacher", ascending=False).head(10)

## First analysis

In [None]:
# String edit metrics of the teacher transcriptions against the reference labels. Per the helper's
# name, these are computed both on the raw (orthographic) text and after Whisper English normalization.
whisper_normalizer = get_whisper_normalizer("english")
dict_string_edit_metrics = get_string_edit_metrics_ortho_and_norm(
    references=df["text"],
    predictions=df["teacher_text"],
    norm_fn=whisper_normalizer,
)

dict_string_edit_metrics

**Observation:** Applying the Whisper normalizer drastically reduces all of the string edit metric errors. Since the 1-best setup operates without any text normalization, we focus on the orthographic WER in this study.

In [None]:
# Distribution of `audio_length` over the analysed subset.
# Explicit fig/ax instead of the pyplot state machine, and labelled so the figure stands alone.
fig, ax = plt.subplots(figsize=(5, 3))
df["audio_length"].plot.hist(ax=ax)
ax.set_xlabel("audio_length")
ax.set_title("Distribution of audio lengths");

In [None]:
# Spread of token counts: reference labels vs. teacher predictions.
fig, ax = plt.subplots(figsize=(12, 3))
sns.boxplot(data=df[["n_tokens_labels", "n_tokens_teacher"]], orient="h", ax=ax);

In [None]:
x_col = "n_tokens_labels"
y_col = "n_tokens_teacher"

# Work on the returned JointGrid explicitly rather than pyplot's implicit "current" axes/figure,
# so the reference line and legend are guaranteed to land on the joint (scatter) axes rather than
# on a marginal histogram.
grid = sns.jointplot(data=df, x=x_col, y=y_col, alpha=0.3)
line_max_coord = min(df[x_col].max(), df[y_col].max())
# Points above the dashed y=x line: the teacher emits more tokens than the reference label.
grid.ax_joint.plot([0, line_max_coord], [0, line_max_coord], 'b--', label=r"$y=x$")
grid.ax_joint.legend()
fig = grid.ax_joint.figure
fig.tight_layout()

savepath = OUTPUT_DIR / "analysis" / "n_tokens_teacher_wrt_n_tokens_label.png"
savepath.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(savepath)
print(f"Figure saved at `{savepath}`.")

In [None]:
# Box plot of `diff_n_tokens` over the whole subset, saved for the report.
fig, ax = plt.subplots(figsize=(12, 1.5))
sns.boxplot(data=df[["diff_n_tokens"]], orient="h", ax=ax)

savepath = OUTPUT_DIR.joinpath("analysis", "boxplot_n_diff_tokens.png")
savepath.parent.mkdir(exist_ok=True, parents=True)
fig.savefig(savepath)
print(f"Figure saved at `{savepath}`.")

In [None]:
# Same box plot, zoomed in on the [-50, 50] range.
fig, ax = plt.subplots(figsize=(12, 1.5))
sns.boxplot(data=df[["diff_n_tokens"]], orient="h", ax=ax)
ax.set_xlim(-50, 50);

In [None]:
# `diff_n_tokens` as a function of `audio_length`, one point per utterance.
sns.scatterplot(x=df["audio_length"], y=df["diff_n_tokens"], alpha=0.3);

In [None]:
# Token counts vs. audio length: reference labels (left) and teacher predictions (right), on a shared
# y-axis so the two panels are directly comparable.
fig, axis = plt.subplots(1, 2, figsize=(8, 3), sharey=True)
sns.scatterplot(data=df, x="audio_length", y="n_tokens_labels", label="Labels", alpha=0.3, ax=axis[0])
# `color=` is seaborn's parameter for a single colour; matplotlib's `c=` bypasses seaborn's colour handling.
sns.scatterplot(data=df, x="audio_length", y="n_tokens_teacher", label="Predictions", color="coral", alpha=0.3, ax=axis[1])
# Seaborn labels the right panel "n_tokens_teacher" even though the y-axis is shared; keep a single,
# consistent label on the left panel only.
axis[0].set_ylabel("n_tokens")
axis[1].set_ylabel("")
fig.tight_layout()

savepath = OUTPUT_DIR / "analysis" / "n_tokens_wrt_audio_length.png"
savepath.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(savepath)
print(f"Figure saved at `{savepath}`.")