## Imports

In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Move the working directory one level up, then prepend an ancestor directory
# to sys.path (presumably so the project-local packages imported below resolve).
# NOTE(review): `%cd ..` is not idempotent — re-running this cell climbs one
# more level each time, so run it exactly once per kernel session.
%cd ..
import os, sys
# NOTE(review): the inserted path is two levels above the *new* working
# directory — confirm that this is the intended project root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

In [None]:
from typing import Dict, Any
from tqdm.auto import tqdm

import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import torch
from transformers import pipeline
from transformers.models.whisper import (WhisperTokenizer,
                                         WhisperTokenizerFast,
                                         WhisperFeatureExtractor,
                                         WhisperForConditionalGeneration)

from dataloader.dataset_loader import gen_from_dataset
from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP
from evaluation.string_edit_metrics import get_string_edit_metrics

from utils.constants import GEN_MAX_LENGTH, DEFAULT_EVAL_NUM_BEAMS

# Pick the best available accelerator instead of hard-coding Apple's MPS
# backend, so the notebook also runs on CUDA machines and CPU-only hosts.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Global seaborn theme for all figures below: "paper" context with "ticks" style.
sns.set_theme(style="ticks", context="paper")

## User input

## Load model

In [None]:
# Checkpoint shared by the model, the feature extractor and the tokenizer.
pretrained_model_name_or_path = "openai/whisper-tiny"

# The tokenizer is configured for English transcription.
tokenizer = WhisperTokenizerFast.from_pretrained(
    pretrained_model_name_or_path,
    task="transcribe",
    language="english",
)
feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
model = WhisperForConditionalGeneration.from_pretrained(pretrained_model_name_or_path)

# Bound reference to the tokenizer's `_normalize` method (a private API).
whisper_norm = tokenizer._normalize

## Load dataset

In [None]:
dataset_name = "ami_10h"

ds_group = EVAL_DATASET_NAME_TO_DATASET_GROUP[dataset_name]()

# Dataset names supported by this notebook, mapped to the key under which the
# dataset is stored in `ds_group.str2dataset`.
dataset_name_to_str2dataset_key = {
    "librispeech_dummy": "librispeech_dummy",
    "ami": "ami",
    "ami_10h": "ami",
}

if dataset_name not in dataset_name_to_str2dataset_key:
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; "
        f"expected one of {sorted(dataset_name_to_str2dataset_key)}."
    )

ds = ds_group.str2dataset[dataset_name_to_str2dataset_key[dataset_name]]
# Lower-case the reference transcripts stored in the "text" column.
ds = ds.map(lambda x: {"text": x.lower()}, input_columns=["text"])

## Create pipeline

In [None]:
# Speech-recognition pipeline built from the model, tokenizer and feature
# extractor loaded above, placed on the selected device.
whisper_asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    feature_extractor=feature_extractor,
    tokenizer=tokenizer,
    device=device,
)

## Run pipeline

In [None]:
# Decoding language and task. They are defined here because nothing earlier in
# the notebook defines them (a fresh kernel would raise NameError); the values
# match the tokenizer configuration used when the model was loaded.
language = "english"
task = "transcribe"

generate_kwargs = {"max_length": GEN_MAX_LENGTH,
                   "num_beams": DEFAULT_EVAL_NUM_BEAMS,
                   "language": language,
                   "task": task}

# Accumulators for the lower-cased predictions and references, in dataset order:
predictions = []
references = []

asr_outputs = whisper_asr(gen_from_dataset(ds),
                          batch_size=32,
                          generate_kwargs=generate_kwargs)

for out in tqdm(asr_outputs, total=ds.num_rows):
    # The reference is read from the first element of the "reference" field of
    # each output; the prediction is the output's "text" field.
    ref = out["reference"][0].lower()
    pred = out["text"].lower()
    references.append(ref)
    predictions.append(pred)

## Compute string edit metrics

In [None]:
# Collect the string edit metrics into a Series and scale every value by 100.
raw_string_edit_metrics = get_string_edit_metrics(predictions=predictions, references=references)
string_edit_metrics = pd.Series(raw_string_edit_metrics) * 100

string_edit_metrics

## Analysis

### Add audio length to the dataset features

In [None]:
def get_audio_length_in_seconds(x: Dict[str, Any]) -> Dict[str, float]:
    """
    Compute the duration of one example's audio clip.

    Args:
        x: Example holding an "audio" entry with the keys "array" (the samples)
           and "sampling_rate".

    Returns:
        A dict with the single key "audio_length": the number of samples
        divided by the sampling rate.
    """
    assert "audio" in x  # TODO
    clip = x["audio"]
    n_samples = len(clip["array"])
    return {"audio_length": n_samples / clip["sampling_rate"]}

In [None]:
# Attach an "audio_length" column computed per example.
ds = ds.map(function=get_audio_length_in_seconds)

ds.features

### Add predictions to the dataset features

In [None]:
# There must be exactly one reference and one prediction per dataset row
# before the predictions are attached as the "pred" column.
assert len(predictions) == len(references) == ds.num_rows
ds = ds.add_column("pred", predictions)

ds.features

### Tokenize both labels and predictions

In [None]:
def _tokenize_text_and_pred(batch):
    """Tokenize a batch of reference texts and predictions into token ids."""
    return {
        "labels": tokenizer(batch["text"]).input_ids,
        "pred_tokenized": tokenizer(batch["pred"]).input_ids,
    }


ds = ds.map(_tokenize_text_and_pred, batched=True)

ds.features

### Add n_tokens to the dataset features

In [None]:
def _count_tokens(x):
    """Return the number of token ids in the labels and in the tokenized prediction."""
    return {
        "n_tokens_labels": len(x["labels"]),
        "n_tokens_pred": len(x["pred_tokenized"]),
    }


ds = ds.map(_count_tokens)

ds.features

### Get DataFrame

In [None]:
# Dataset columns carried over into the analysis DataFrame.
cols_of_interest = [
    "audio_length",
    "text",
    "labels",
    "n_tokens_labels",
    "pred",
    "pred_tokenized",
    "n_tokens_pred",
]

df = pd.DataFrame({col: ds[col] for col in cols_of_interest})
# Positive values mean the prediction has more tokens than the reference.
df = df.assign(diff_n_tokens=df["n_tokens_pred"] - df["n_tokens_labels"])

df.head()

In [None]:
# Histogram of the audio lengths.
fig, ax = plt.subplots(figsize=(5, 3))
df["audio_length"].plot.hist(ax=ax);

In [None]:
# Token counts of labels and predictions, overlaid in one histogram.
fig, ax = plt.subplots(figsize=(5, 3))
sns.histplot(data=df[["n_tokens_labels", "n_tokens_pred"]], ax=ax)

# Per-example difference (prediction minus label) in token count.
fig, ax = plt.subplots(figsize=(5, 3))
sns.histplot(data=df["diff_n_tokens"], ax=ax);

In [None]:
# Label token counts grouped into 10 bins, listed in bin order rather than by frequency.
df["n_tokens_labels"].value_counts(bins=10, sort=False)

In [None]:
# Summary statistics of the label token counts.
df["n_tokens_labels"].describe()

In [None]:
# Prediction token counts grouped into 10 bins, listed in bin order rather than by frequency.
df["n_tokens_pred"].value_counts(bins=10, sort=False)

In [None]:
# Summary statistics of the prediction token counts.
df["n_tokens_pred"].describe()

In [None]:
# Token-count differences (prediction minus label) grouped into 10 bins, listed in bin order.
df["diff_n_tokens"].value_counts(bins=10, sort=False)

In [None]:
# Summary statistics of the token-count differences (prediction minus label).
df["diff_n_tokens"].describe()

In [None]:
# Same token-count-difference histogram, with the x-axis restricted to [-15, 30].
fig, ax = plt.subplots(figsize=(5, 3))
sns.histplot(data=df["diff_n_tokens"], ax=ax)
ax.set_xlim(-15, 30)

In [None]:
# Keep the examples whose prediction is at least 7 tokens longer than the reference.
is_much_longer = df["diff_n_tokens"] >= 7
df_candidates = df.loc[is_much_longer]

In [None]:
# Print reference and prediction side by side for every candidate row,
# identified by its index label in `df`.
for row_label, row in df_candidates.iterrows():
    print(f"Idx = {row_label}")
    print("Reference: ", row["text"])
    print("Prediction: ", row["pred"])
    print()

In [None]:
def _max_token_repetitions(token_ids):
    """Return how many times the most frequent token id occurs in `token_ids`."""
    return pd.Series(token_ids).value_counts().max()


df["max_token_repetitions"] = df["pred_tokenized"].apply(_max_token_repetitions)
df["max_token_repetitions"]

In [None]:
# Exact (unbinned) distribution of the per-prediction maximum token repetition count.
df["max_token_repetitions"].value_counts(sort=False)

In [None]:
# Distribution of the maximum token repetition count over hand-picked bin
# edges, listed in bin order rather than by frequency.
# NOTE(review): values above 100 lie outside the last bin edge and are
# presumably left out of these counts — confirm if such rows can occur.
df["max_token_repetitions"].value_counts(bins=[0, 1, 2, 3, 4, 5, 10, 20, 50, 100], sort=False)

In [None]:
# Print every prediction in which some token occurs at least 7 times.
is_repetitive = df["max_token_repetitions"] >= 7
for x in df.loc[is_repetitive, "pred"]:
    print(x)

In [None]:
# Token-count difference (prediction minus label) against audio length.
sns.scatterplot(x="audio_length", y="diff_n_tokens", data=df);

In [None]:
# Same scatter plot, with the y-axis restricted to differences between 0 and 10.
ax = sns.scatterplot(x="audio_length", y="diff_n_tokens", data=df)
ax.set_ylim(0, 10)