In [None]:
import json
import os

import evaluate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, load_from_disk
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, T5Tokenizer

In [None]:
# Load the saved test split from disk. Later cells read the `asr_transcription`
# and `target_output` fields of its rows.
dataset = load_from_disk("/home/sulcm/datasets/t5/asr-correction-cs-v23/test")

# Create and compute eval data

In [None]:
# Checkpoint location passed to `from_pretrained` for both the tokenizer and the model.
T5_MODEL_NAME = "/home/sulcm/models/t5/t5-spellchecker-cs-v2"

In [None]:
# Load the "wer" metric via `evaluate`; the evaluation loop calls
# `wer_metric.compute(predictions=..., references=...)` once per comparison.
wer_metric = evaluate.load("wer")

In [None]:
# Display the loaded dataset object as the cell output.
dataset

In [None]:
# Task prefix prepended to every ASR transcription before tokenization.
prefix = "spell check: "

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer and model are both loaded from the same checkpoint.
t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_NAME)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_NAME).to(device)

In [None]:
# Per-row results of the evaluation loop. Every iteration adds exactly one
# entry to each list, so index i in any list refers to row i of `dataset`.
t5_correction_and_results = {
    "t5_correction": [],
    "w2v2_vs_target_wer": [],
    "t5_vs_target_wer": [],
    "w2v2_vs_t5_wer": []
}

# Plain beam search. The previous `do_sample=True` drew beams by random
# sampling without a seed, so the stored corrections and WER values differed
# from run to run; with sampling off the output no longer depends on an RNG.
generation_kwargs = {"max_new_tokens": 64, "num_beams": 4, "do_sample": False}

for ds_row in tqdm(dataset):
    asr_transcription = ds_row["asr_transcription"]
    target_output = ds_row["target_output"]

    inputs = t5_tokenizer(prefix + asr_transcription, return_tensors="pt").to(device)
    output_sequences = t5_model.generate(**inputs, **generation_kwargs)
    # `batch_decode` returns a list of strings (one per generated sequence);
    # it is passed on as-is wherever a list of predictions/references is expected.
    corrected_input = t5_tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

    # Raw ASR output scored against the target.
    t5_correction_and_results["w2v2_vs_target_wer"].append(
        wer_metric.compute(predictions=(asr_transcription,), references=(target_output,))
    )
    # T5 correction scored against the target.
    t5_correction_and_results["t5_vs_target_wer"].append(
        wer_metric.compute(predictions=corrected_input, references=(target_output,))
    )
    # How far the ASR output is from the T5 correction (the correction is the reference here).
    t5_correction_and_results["w2v2_vs_t5_wer"].append(
        wer_metric.compute(predictions=(asr_transcription,), references=corrected_input)
    )
    t5_correction_and_results["t5_correction"].extend(corrected_input)

In [None]:
# Persist the loop results as JSON. The output directory is created first so
# that a missing `./data` folder cannot discard the results of the (slow)
# evaluation loop with a FileNotFoundError.
results_path = "./data/error_eval_ds_v23_test_w_t5_v2.json"
os.makedirs(os.path.dirname(results_path), exist_ok=True)
with open(results_path, "w") as f:
    json.dump(t5_correction_and_results, f)

# Evaluation of results
---
!["error_classification"](./error_classification.svg)

In [None]:
# Reload previously stored evaluation results from JSON.
# NOTE(review): this reads the `..._t5_v4.json` file, while the compute section
# of this notebook uses the `v2` model and writes `..._t5_v2.json` — confirm
# which results file the analysis below is meant to describe.
with open("./data/error_eval_ds_v23_test_w_t5_v4.json", "r") as results_file:
    t5_correction_and_results = json.load(results_file)

In [None]:
def compare_outputs(idx: int) -> None:
    """Print the ASR transcription, the T5 correction and the target text of one row.

    Reads the notebook-level `dataset` and `t5_correction_and_results`. The
    first two printed lines also show the stored WER of that text against the
    target, formatted to four decimal places.

    Args:
        idx: Row position, used to index both `dataset` and the result lists.
    """
    row = dataset[idx]
    asr_wer = t5_correction_and_results['w2v2_vs_target_wer'][idx]
    t5_wer = t5_correction_and_results['t5_vs_target_wer'][idx]

    # The label strings are padded with spaces so the three texts line up.
    report_lines = [
        f"Wav2Vec2.0 Transcription (WER = {asr_wer:.4f}):    " + row['asr_transcription'],
        f"T5 Correction (WER = {t5_wer:.4f}):               " + t5_correction_and_results['t5_correction'][idx],
        "Target output:                              " + row['target_output'],
    ]
    print("\n".join(report_lines))

## T5 mistakes on ***correct*** W2V2 transcription

In [None]:
# Row indices where the Wav2Vec2 transcription has WER == 0 against the target.
correct_w2v2_transcription = np.flatnonzero(
    np.asarray(t5_correction_and_results["w2v2_vs_target_wer"]) == 0.0
)
# Row indices where the T5 correction has WER > 0 against the target.
incorrect_t5_correction = np.flatnonzero(
    np.asarray(t5_correction_and_results["t5_vs_target_wer"]) > 0.0
)
# Rows in both groups: the ASR output had WER 0 but the T5 output does not.
correct_asr_transcription_incorect_t5_correction = (
    set(correct_w2v2_transcription) & set(incorrect_t5_correction)
)
len(correct_asr_transcription_incorect_t5_correction)

In [None]:
# Display the row indices where the ASR output had WER 0 but the T5 output did not.
correct_asr_transcription_incorect_t5_correction

## T5 good corrections on ***bad*** W2V2 transcription

In [None]:
# Row indices where the Wav2Vec2 transcription has WER > 0 against the target.
incorrect_w2v2_transcription = np.flatnonzero(
    np.asarray(t5_correction_and_results["w2v2_vs_target_wer"]) > 0.0
)
# Row indices where the T5 correction has WER == 0 against the target.
correct_t5_correction = np.flatnonzero(
    np.asarray(t5_correction_and_results["t5_vs_target_wer"]) == 0.0
)
# Rows in both groups: the ASR output had errors and the T5 output has WER 0.
correct_t5_correction_on_bad_asr_transcription = (
    set(correct_t5_correction) & set(incorrect_w2v2_transcription)
)
len(correct_t5_correction_on_bad_asr_transcription)

In [None]:
# Display the row indices where the ASR output had errors and the T5 output has WER 0.
correct_t5_correction_on_bad_asr_transcription

## Correct ASR and correction

In [None]:
# Rows where both the ASR transcription and the T5 correction have WER == 0.
correct_asr_w_correction = set(correct_w2v2_transcription) & set(correct_t5_correction)
len(correct_asr_w_correction)

In [None]:
# Display the row indices where both the ASR and the T5 output have WER 0.
correct_asr_w_correction

## Incorrect ASR and incorrect correction

In [None]:
# Rows where both the ASR transcription and the T5 correction have WER > 0.
incorrect_asr_w_correction = set(incorrect_w2v2_transcription) & set(incorrect_t5_correction)
# List form of the same indices, used in a later cell for NumPy fancy indexing.
incorrect_asr_w_correction_idx = list(incorrect_asr_w_correction)
len(incorrect_asr_w_correction)

In [None]:
# Display the row indices where both the ASR and the T5 output have WER > 0.
incorrect_asr_w_correction

### Less incorrect ASR (T5 corrected some mistakes)

In [None]:
# Positions *within* `incorrect_asr_w_correction_idx` (not dataset row numbers)
# where the ASR WER is strictly greater than the T5 WER.
less_incorrect_asr_idx = np.flatnonzero(
    np.asarray(t5_correction_and_results["w2v2_vs_target_wer"])[incorrect_asr_w_correction_idx]
    > np.asarray(t5_correction_and_results["t5_vs_target_wer"])[incorrect_asr_w_correction_idx]
)
# Map those positions back to dataset row indices.
less_incorrect_asr = set(
    np.asarray(incorrect_asr_w_correction_idx)[less_incorrect_asr_idx]
)
len(less_incorrect_asr)

In [None]:
# Display the row indices where the T5 WER is lower than the ASR WER (both still > 0).
less_incorrect_asr

### More incorrect ASR (T5 made more mistakes than it repaired)

In [None]:
# Remaining rows of the "both incorrect" group after removing those T5 improved.
# NOTE(review): `less_incorrect_asr` is built with a strict `>` comparison, so
# this set also contains rows whose WER is unchanged by T5, not only rows
# where T5 made the WER worse.
more_incorrect_asr = incorrect_asr_w_correction - less_incorrect_asr
len(more_incorrect_asr)

In [None]:
# Display the row indices of the "both incorrect" group that T5 did not improve.
more_incorrect_asr