In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer, BertTokenizer, BertModel
from datasets import Dataset, load_from_disk
import evaluate
import torch

import numpy as np
import pandas as pd
import re

import matplotlib.pyplot as plt
import json

from tqdm import tqdm
from rapidfuzz.distance import Levenshtein, Opcodes

In [None]:
# Path handed to `from_pretrained` for both the T5 tokenizer and the T5 model.
# NOTE(review): absolute, user-specific path — adjust it when running on another machine.
T5_MODEL_NAME = "/home/sulcm/models/t5/t5-spellchecker-cs-v4"

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the dataset stored on disk; later cells read its "asr_transcription"
# and "target_output" fields.
# NOTE(review): absolute, user-specific path — adjust it when running on another machine.
dataset = load_from_disk("/home/sulcm/datasets/t5/asr-correction-cs-v23/test")

In [None]:
dataset

# Create and compute eval data

In [None]:
# Load the "wer" metric from the `evaluate` library; it scores each
# prediction against its reference in the inference loop.
wer_metric = evaluate.load("wer")

In [None]:
# Text prepended to every ASR transcription before it is tokenized for T5.
prefix = "spell check: "
# Tokenizer and model are both loaded from T5_MODEL_NAME; the model is moved to `device`.
t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_NAME)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_NAME).to(device)

In [None]:
# `generate` below is called with do_sample=True, i.e. decoding draws random
# numbers. Seeding the torch RNG first makes the corrections, and every WER
# derived from them, repeatable when this cell is re-run.
GENERATION_SEED = 42
torch.manual_seed(GENERATION_SEED)

# Per-sample results, kept as parallel lists indexed like `dataset`:
#   t5_correction       - text produced by T5 for the ASR transcription
#   w2v2_vs_target_wer  - WER of the raw ASR transcription against the target
#   t5_vs_target_wer    - WER of the T5 correction against the target
#   w2v2_vs_t5_wer      - WER of the ASR transcription against the T5 correction
t5_correction_and_results = {
    "t5_correction": [],
    "w2v2_vs_target_wer": [],
    "t5_vs_target_wer": [],
    "w2v2_vs_t5_wer": []
}

for ds_row in tqdm(dataset):
    asr_text = ds_row["asr_transcription"]
    target_text = ds_row["target_output"]

    inputs = t5_tokenizer(prefix + asr_text, return_tensors="pt").to(device)
    output_sequences = t5_model.generate(**inputs, max_new_tokens=64, num_beams=4, do_sample=True)
    # `batch_decode` returns a list of strings, so it can be passed to the
    # metric as-is and appended with `extend`.
    corrected_input = t5_tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

    t5_correction_and_results["w2v2_vs_target_wer"].append(wer_metric.compute(predictions=(asr_text,), references=(target_text,)))
    t5_correction_and_results["t5_vs_target_wer"].append(wer_metric.compute(predictions=corrected_input, references=(target_text,)))
    t5_correction_and_results["w2v2_vs_t5_wer"].append(wer_metric.compute(predictions=(asr_text,), references=corrected_input))
    t5_correction_and_results["t5_correction"].extend(corrected_input)

In [None]:
# with open("./data/error_eval_ds_v23_test_w_t5_v4.json", "w") as f:
#     json.dump(t5_correction_and_results, f)

# Evaluation of results
---
!["error_classification"](./error_classification.svg)

In [None]:
class ST6ErrorAnalysis():
    """Label vocabulary for classifying edit operations into word classes.

    Attributes:
        id2label: Class labels in file order, always followed by "other".
        label2id: Inverse lookup of `id2label` (label -> position).
        word_classes_examples: Values of the loaded definition, in the same
            order as the labels read from the file ("other" has no entry).
    """

    def __init__(self, classes_def_path: str="") -> None:
        """Load word-class definitions and build the label lookups.

        Args:
            classes_def_path: Path to a JSON file containing an object whose
                keys become class labels and whose values are stored as
                examples. With the default empty string nothing is loaded and
                "other" is the only label.
        """
        if classes_def_path:
            # JSON text is UTF-8; do not rely on the platform default encoding.
            with open(classes_def_path, "r", encoding="utf-8") as f:
                word_classes = json.load(f)
        else:
            word_classes = {}

        # "other" is appended last so it always gets the highest id.
        self.id2label = [*word_classes.keys(), "other"]
        self.label2id = {label: idx for idx, label in enumerate(self.id2label)}
        self.word_classes_examples = list(word_classes.values())

    @staticmethod
    def get_error_class(lev_ops: dict):
        """Walk over the edit operations in `lev_ops` (placeholder implementation).

        Declared as a static method because it takes no `self`; without the
        decorator a call on an instance would bind the instance to `lev_ops`
        and fail. Calls through the class keep working unchanged.

        Args:
            lev_ops: Dict with the parallel lists "action", "src" and
                "reference".

        Returns:
            None. Classification is not implemented yet.
        """
        # TODO: replace the placeholder output with the actual classification.
        for action, src, reference in zip(lev_ops["action"], lev_ops["src"], lev_ops["reference"]):
            print("hello")

In [None]:
# Replace the in-memory results with ones previously stored as JSON.
# NOTE(review): this file is tagged `t5_v19`, whereas T5_MODEL_NAME and the
# commented-out dump cell refer to `v4` — confirm this is the intended file.
results_path = "./data/error_eval_ds_v23_test_w_t5_v19.json"
with open(results_path, "r") as results_file:
    t5_correction_and_results = json.load(results_file)

In [None]:
def compare_outputs(idx: int) -> None:
    """Print the ASR transcription, the T5 correction and the target of one sample.

    The first two lines also show the stored WER against the target, formatted
    to four decimals. Reads the notebook-level `dataset` and
    `t5_correction_and_results`.

    Args:
        idx: Position of the sample in `dataset` and in the result lists.
    """
    sample = dataset[idx]
    asr_wer = t5_correction_and_results['w2v2_vs_target_wer'][idx]
    t5_wer = t5_correction_and_results['t5_vs_target_wer'][idx]

    # Labels are padded by hand so the three texts start in the same column.
    report_lines = [
        f"Wav2Vec2.0 Transcription (WER = {asr_wer:.4f}):    " + sample['asr_transcription'],
        f"T5 Correction (WER = {t5_wer:.4f}):               " + t5_correction_and_results['t5_correction'][idx],
        "Target output:                              " + sample['target_output'],
    ]
    print("\n".join(report_lines))

In [None]:
def levenshtein_ops(src: str, reference: str) -> dict:
    """Collect the non-"equal" Levenshtein opcodes that turn `src` into `reference`.

    Args:
        src: String the edit operations start from.
        reference: String the edit operations lead to.

    Returns:
        Dict with three parallel lists, one entry per non-"equal" opcode:
        "action" holds the opcode tag; "src" and "reference" hold tuples
        `(changed, in_context)`. `changed` is the slice of the respective
        string covered by the opcode, and `in_context` is that slice extended
        left to the previous space and right to the next space.
    """
    def with_context(text: str, start: int, end: int) -> tuple:
        changed = text[start:end]
        # Characters after the last space before `start`, and characters
        # up to the first space after `end`.
        left = text[:start].rpartition(" ")[2]
        right = text[end:].partition(" ")[0]
        return changed, left + changed + right

    collected = {"action": [], "src": [], "reference": []}
    opcodes = Opcodes.from_editops(Levenshtein.editops(src, reference))
    for op in opcodes:
        if op.tag == "equal":
            continue
        collected["action"].append(op.tag)
        collected["src"].append(with_context(src, op.src_start, op.src_end))
        collected["reference"].append(with_context(reference, op.dest_start, op.dest_end))
    return collected

## T5 mistakes on ***correct*** W2V2 transcription

In [None]:
# Sample indices whose ASR transcription already matches the target (WER == 0)
# and whose T5 output does not (WER > 0).
correct_w2v2_transcription = np.flatnonzero(np.asarray(t5_correction_and_results["w2v2_vs_target_wer"]) == 0.0)
incorrect_t5_correction = np.flatnonzero(np.asarray(t5_correction_and_results["t5_vs_target_wer"]) > 0.0)
# Samples that were correct after ASR but are no longer correct after T5.
correct_asr_transcription_incorect_t5_correction = set(correct_w2v2_transcription) & set(incorrect_t5_correction)
len(correct_asr_transcription_incorect_t5_correction)

In [None]:
correct_asr_transcription_incorect_t5_correction

In [None]:
# Inspect one sample: edit operations between its ASR transcription and its target.
idx = 600
levenshtein_ops(
    src=dataset[idx]["asr_transcription"],
    reference=dataset[idx]["target_output"],
)

In [None]:
compare_outputs(idx=idx)

## T5 good corrections on ***bad*** W2V2 transcription

In [None]:
# Sample indices whose ASR transcription differs from the target (WER > 0)
# and whose T5 output matches it (WER == 0).
incorrect_w2v2_transcription = np.flatnonzero(np.asarray(t5_correction_and_results["w2v2_vs_target_wer"]) > 0.0)
correct_t5_correction = np.flatnonzero(np.asarray(t5_correction_and_results["t5_vs_target_wer"]) == 0.0)
# Samples that T5 fully repaired.
correct_t5_correction_on_bad_asr_transcription = set(correct_t5_correction) & set(incorrect_w2v2_transcription)
len(correct_t5_correction_on_bad_asr_transcription)

In [None]:
correct_t5_correction_on_bad_asr_transcription

## Correct ASR and correction

In [None]:
# Samples where both the ASR transcription and the T5 output have WER == 0.
correct_asr_w_correction = set(correct_w2v2_transcription) & set(correct_t5_correction)
len(correct_asr_w_correction)

In [None]:
correct_asr_w_correction

## Incorrect ASR and incorrect correction

In [None]:
# Samples where both the ASR transcription and the T5 output have WER > 0.
incorrect_asr_w_correction = set(incorrect_w2v2_transcription) & set(incorrect_t5_correction)
# List form of the same indices, used for NumPy fancy indexing.
incorrect_asr_w_correction_idx = [*incorrect_asr_w_correction]
len(incorrect_asr_w_correction)

In [None]:
incorrect_asr_w_correction

### Less incorrect ASR (T5 corrected some mistakes)

In [None]:
# WER of the ASR transcription and of the T5 output, restricted to the samples
# where both are wrong (same order as `incorrect_asr_w_correction_idx`).
w2v2_wer_subset = np.array(t5_correction_and_results["w2v2_vs_target_wer"])[incorrect_asr_w_correction_idx]
t5_wer_subset = np.array(t5_correction_and_results["t5_vs_target_wer"])[incorrect_asr_w_correction_idx]
# Positions within the subset where T5 strictly lowered the WER ...
less_incorrect_asr_idx = np.flatnonzero(w2v2_wer_subset > t5_wer_subset)
# ... mapped back to indices into the dataset.
less_incorrect_asr = set(np.array(incorrect_asr_w_correction_idx)[less_incorrect_asr_idx])
len(less_incorrect_asr)

In [None]:
less_incorrect_asr

### More incorrect ASR (T5 made more mistakes than it repaired)

In [None]:
# Everything in the "both wrong" set that T5 did not strictly improve.
# Note: this also keeps samples whose WER is the same before and after T5.
more_incorrect_asr = incorrect_asr_w_correction - less_incorrect_asr
len(more_incorrect_asr)

In [None]:
more_incorrect_asr

# Measuring semantic closeness between the reference sentence and the inferred ones

In [None]:
# Tokenizer and encoder are loaded from the same checkpoint name; the encoder
# is moved to `device`.
BERT_MODEL_NAME = "fav-kky/FERNET-C5"
bert_tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
bert_model = BertModel.from_pretrained(BERT_MODEL_NAME).to(device)

In [None]:
# Cosine similarity computed along the last dimension of its two inputs.
cosine_sim = torch.nn.CosineSimilarity(dim=-1)

In [None]:
def _cls_cosine_similarity(hypothesis: str, reference: str) -> float:
    """Cosine similarity between the first-token embeddings of two sentences.

    Both sentences are encoded in one padded batch with `bert_tokenizer` and
    `bert_model`; the hidden state at position 0 of each sequence is compared
    with `cosine_sim`.

    Args:
        hypothesis: Sentence to compare against the reference.
        reference: Reference sentence.

    Returns:
        The similarity as a Python float.
    """
    encoded = bert_tokenizer([hypothesis, reference], padding=True, return_tensors="pt").to(device)
    # Inference only: without no_grad every forward pass would build an
    # autograd graph that is never used.
    with torch.no_grad():
        cls_emb = bert_model(**encoded).last_hidden_state[:, 0, :]
        return cosine_sim(cls_emb[0], cls_emb[1]).item()


# Per-sample similarity to the reference sentence, indexed like `dataset`:
#   sim_w2v2_to_ref - raw ASR transcription vs. reference
#   sim_t5_to_ref   - T5 correction vs. reference
semantic_sim = {
    "sim_w2v2_to_ref": [],
    "sim_t5_to_ref": []
}
for i in tqdm(range(len(dataset))):
    row = dataset[i]
    ref = row["target_output"]

    semantic_sim["sim_w2v2_to_ref"].append(_cls_cosine_similarity(row["asr_transcription"], ref))
    semantic_sim["sim_t5_to_ref"].append(_cls_cosine_similarity(t5_correction_and_results["t5_correction"][i], ref))

In [None]:
np.mean(semantic_sim["sim_w2v2_to_ref"])

In [None]:
np.min(semantic_sim["sim_w2v2_to_ref"])

In [None]:
np.std(semantic_sim["sim_w2v2_to_ref"])

In [None]:
np.mean(semantic_sim["sim_t5_to_ref"])

In [None]:
np.min(semantic_sim["sim_t5_to_ref"])

In [None]:
np.std(semantic_sim["sim_t5_to_ref"])

In [None]:
dataset[441]["asr_transcription"]

In [None]:
t5_correction_and_results["t5_correction"][441]

In [None]:
dataset[441]["target_output"]

In [None]:
# Pick one sample and keep its three texts for the step-by-step cells.
idx = 750
sample = dataset[idx]
w2v2_output = sample["asr_transcription"]
t5_output = t5_correction_and_results["t5_correction"][idx]
ref = sample["target_output"]

In [None]:
w2v2_output

In [None]:
t5_output

In [None]:
ref

In [None]:
inputs = bert_tokenizer([t5_output, ref], padding=True, return_tensors="pt").to(device)

In [None]:
# Inference only: run the encoder without building an autograd graph.
with torch.no_grad():
    embeddings = bert_model(**inputs)
# Hidden state at position 0 of each sequence in the batch.
cls_emb = embeddings.last_hidden_state[:, 0, :]

In [None]:
cls_emb

In [None]:
cls_emb

In [None]:
cosine_sim.forward(cls_emb[0], cls_emb[1])