## Imports

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell runs, so changes to project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Move the working directory up one level.
%cd ..
import os, sys
# Prepend the directory two levels above the (new) working directory to the
# module search path.
# NOTE(review): presumably this is what makes the project-local `evaluation`
# package importable below -- confirm that two levels up is the intended root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.getcwd()))))

In [None]:
from IPython.display import Audio, display
from tqdm.auto import tqdm

import numpy as np
import torch
import whisper
import datasets

import matplotlib.pyplot as plt
import seaborn as sns

from evaluation.eval_dataset_name_to_dataset_group import EVAL_DATASET_NAME_TO_DATASET_GROUP

# Torch device handle for the "mps" backend.
# NOTE(review): `device` is not passed to anything in the cells visible here
# (the model is loaded without a device argument) -- confirm it is still needed.
device = torch.device('mps')
# Seaborn theme ("paper" context, "ticks" style) for the plots drawn below.
sns.set_theme(context="paper", style="ticks")

## User input

## Load model

In [None]:
# Load the Whisper "tiny" checkpoint.
# No `device` argument is passed here, so the `device` defined in the imports
# cell is not applied to the model by this call.
model = whisper.load_model("tiny")

## Load dataset

In [None]:
# Name of the evaluation dataset to analyse; it selects both the dataset group
# and the entry read from that group below.
dataset_name = "ami_10h"

ds_group = EVAL_DATASET_NAME_TO_DATASET_GROUP[dataset_name]()

# Key to read from `ds_group.str2dataset` for each dataset name this notebook
# supports. Note that "ami_10h" reads the "ami" entry of its group.
supported_dataset_keys = {
    "librispeech_dummy": "librispeech_dummy",
    "ami": "ami",
    "ami_10h": "ami",
}

if dataset_name not in supported_dataset_keys:
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; "
        f"expected one of {sorted(supported_dataset_keys)}."
    )

ds = ds_group.str2dataset[supported_dataset_keys[dataset_name]]
# Lower-case the reference transcripts so they can be compared with the
# predictions without case differences.
ds = ds.map(lambda x: {"text": x.lower()}, input_columns=["text"])

In [None]:
# Fill `list_idx_hallucination` according to previous analysis:
# these are the row indices of `ds` handed to `select` below, i.e. the samples
# that the earlier analysis flagged as hallucinations.
list_idx_hallucination = [38, 223, 227, 602, 789, 852, 1146]

# Subset of `ds` restricted to the rows listed above.
ds_hallucinations = ds.select(indices=list_idx_hallucination)

In [None]:
# Per-sample outputs, index-aligned with `ds_hallucinations`:
#   results     -- raw `model.transcribe` output (includes word timestamps)
#   references  -- lower-cased reference transcript
#   list_audios -- keyword arguments for `IPython.display.Audio`
results, references, list_audios = [], [], []

for sample in tqdm(ds_hallucinations, total=ds_hallucinations.num_rows):
    audio_entry = sample["audio"]

    # The waveform is cast to float32 before being handed to the model.
    transcription = model.transcribe(
        audio_entry["array"].astype(np.float32),
        language="en",
        temperature=0.0,
        word_timestamps=True,
    )

    results.append(transcription)
    references.append(sample["text"].lower())
    # Keep the original (un-cast) waveform and its sampling rate for playback.
    list_audios.append(
        {"data": audio_entry["array"], "rate": audio_entry["sampling_rate"]}
    )

In [None]:
# For every analysed sample: show an audio player, print the reference and the
# prediction, then plot the start/end time of each predicted word.
for audio, result, ref in zip(list_audios, results, references):
    # Load Audio player:
    display(Audio(**audio))

    # Print the reference and the transcribed text
    print("Reference: ", ref)
    print("Prediction: ", result["text"])
    print()

    # Flatten the per-segment word lists into one sequence of word entries.
    words = [word for segment in result["segments"] for word in segment["words"]]
    if not words:
        # Nothing to plot; a figure sized from zero tokens would have no height.
        print("No word timestamps to plot.")
        continue

    # Each word contributes two points (start and end) sharing one label.
    # The running index is prefixed so repeated words stay distinct on the
    # categorical axis.
    timestamps = []
    tokens = []
    for counter, word in enumerate(words):
        label = str(counter) + word["word"]
        timestamps.extend((word["start"], word["end"]))
        tokens.extend((label, label))

    # Scale the height with the number of labels, but keep a minimum so short
    # transcriptions still leave room for the title and axis labels.
    min_height_in = 1.0
    plt.figure(figsize=(8, max(min_height_in, len(tokens) * 0.1)))
    # Reversed so the first word ends up at the top of the axis.
    plt.plot(timestamps[::-1], tokens[::-1])
    plt.xlabel("Time (s)")
    plt.ylabel("Token")
    plt.title("Token Timestamps")
    plt.show()