In [None]:
import logging
from typing import Sequence

import evaluate
import jiwer
import numpy as np

import torch

from datasets import Dataset, load_dataset, Audio, DatasetDict, load_from_disk
from IPython.display import Audio as AudioDisp
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, T5Tokenizer, T5ForConditionalGeneration

from asr_w_spellchecker import ST6

In [None]:
# Evaluation data: Czech ("cs") test split of Common Voice 11.
# Alternative source previously tried here: path="facebook/voxpopuli".
DATASET = dict(path="mozilla-foundation/common_voice_11_0", name="cs", split="test")

# Target audio rate: the dataset is resampled to it and it is passed to the wav2vec2 processor.
SAMPLING_RATE = 16000

# Local model checkpoint directories (absolute paths on the author's machine).
WAV2VEC_MODEL_NAME = "/home/sulcm/models/wav2vec2/wav2vec2-cs-v23"
T5_MODEL_NAME = "/home/sulcm/models/t5/t5-spellchecker-cs-v4"

In [None]:
# Download the configured split (keys of DATASET map onto load_dataset's path/name/split)
# and decode its audio column at SAMPLING_RATE.
dataset = load_dataset(**DATASET).cast_column(
    "audio",
    Audio(sampling_rate=SAMPLING_RATE),
)

In [None]:
# Load the saved ASR-correction dataset under its own name. Assigning it to `dataset`
# would replace the Common Voice split that later cells index for 'audio' / 'sentence'.
t5_dataset = load_from_disk("/home/sulcm/datasets/t5/asr-correction-cs-v23")

In [None]:
# Display the evaluation dataset object to check its columns and size.
dataset

In [None]:
# Combined ASR + spell-checking pipeline built from both local checkpoints.
# logging_level presumably controls ST6's internal logger (DEBUG = most output) — confirm in asr_w_spellchecker.
st6_model = ST6(
    wav2vec2_path=WAV2VEC_MODEL_NAME,
    t5_path=T5_MODEL_NAME,
    logging_level=logging.DEBUG,
)

In [None]:
# Standalone wav2vec2 processor + CTC model, used for the step-by-step inference cell below.
wav2vec_processor = Wav2Vec2Processor.from_pretrained(pretrained_model_name_or_path=WAV2VEC_MODEL_NAME)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(pretrained_model_name_or_path=WAV2VEC_MODEL_NAME)

In [None]:
# Standalone T5 spell-checker tokenizer + seq2seq model, used for token-length checks and manual generation.
t5_tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path=T5_MODEL_NAME)
t5_model = T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path=T5_MODEL_NAME)

In [None]:
# Tokenize every reference transcript the way the spell-checker receives its input
# ("spell check: " prefix), padded into one batch, so token lengths can be measured below.
# Common Voice keeps the transcript in "sentence" (as the sample cells use); the former
# "normalized_text" column appears to be left over from the commented-out VoxPopuli config
# and raises KeyError here. Lowercased like the reference in the sample-inspection cell.
inputs = t5_tokenizer(["spell check: " + text.lower() for text in dataset["sentence"]], return_tensors="pt", padding=True)

In [None]:
# Real (non-padding) tokens per example: attention_mask marks them directly, instead of
# relying on the pad token id being 0 as counting non-zero input_ids does.
token_lengths = inputs.attention_mask.sum(dim=1).numpy()
# q=100 is the maximum length; lower q (e.g. 99) shows how many tokens cover most inputs.
np.percentile(token_lengths, q=100)

In [None]:
# Metric objects keyed by name: BLEU (sacrebleu), word error rate, character error rate.
eval_metrics = dict(
    (metric_name, evaluate.load(metric_name))
    for metric_name in ("sacrebleu", "wer", "cer")
)
# Both names refer to the same dict object.
metrics = eval_metrics

In [None]:
# Pick one test example to inspect by hand. The reference is lowercased, presumably to
# match the casing of the pipeline output it is compared against later.
idx = 1500

input_audio = dataset[idx]['audio']
sentence = dataset[idx]['sentence'].lower()

print(sentence)
# Play the decoded, resampled array (exactly what the models receive) rather than
# input_audio['path'], which is not guaranteed to be a playable file on local disk
# (e.g. for streamed or archive-backed datasets).
AudioDisp(input_audio['array'], rate=input_audio['sampling_rate'])

In [None]:
# Run the pipeline on the single selected clip and display the result.
# NOTE(review): the meaning of the positional False is defined by ST6.__call__ (not visible
# here) — consider passing it by keyword for readability.
st6_model(input_audio['array'], False)

In [None]:
# Run the pipeline over every clip of the split in a single call.
outputs = st6_model(
    [decoded_audio["array"] for decoded_audio in dataset["audio"]]
)

In [None]:
# Inspect the batch result. Later cells treat it as two parallel sequences
# (outputs[0], outputs[1]) — presumably wav2vec2 transcriptions and T5 corrections; confirm against ST6.
outputs

In [None]:
def levenstein(ref_word: Sequence[str], comp_word: Sequence[str], verbose: bool = True) -> int:
    """Return the Levenshtein (edit) distance between two sequences.

    Elements are compared with ``!=``, so two strings give a character-level
    distance and two lists of words give a word-level distance.

    Uses the two-row dynamic-programming formulation. As a rough alignment
    diagnostic, whenever the smallest value in the current DP row increases,
    the current ``ref_word`` element is printed next to the ``comp_word``
    element at the first column holding that minimum, or ``None`` when that
    column lies past the end of ``comp_word``.

    Args:
        ref_word: Reference sequence.
        comp_word: Sequence compared against the reference.
        verbose: Print the diagnostic pairs described above (default True,
            matching the previous behaviour).

    Returns:
        Minimum number of insertions, deletions and substitutions that turn
        ``ref_word`` into ``comp_word``.
    """
    prev_row = list(range(len(comp_word) + 1))
    min_err = 0
    for i, ref_item in enumerate(ref_word):
        curr_row = [i + 1]
        for j, comp_item in enumerate(comp_word):
            curr_row.append(min(
                prev_row[j + 1] + 1,                    # delete ref_item
                curr_row[j] + 1,                        # insert comp_item
                prev_row[j] + (ref_item != comp_item),  # substitute / match
            ))
        current_min_err = min(curr_row)
        if current_min_err > min_err:
            min_err = current_min_err
            if verbose:
                col = curr_row.index(min_err)
                # The row has len(comp_word) + 1 columns, so the last column has no
                # comp_word element at that index (indexing it used to raise IndexError).
                print(ref_item, comp_word[col] if col < len(comp_word) else None)
        prev_row = curr_row
    return prev_row[-1]

In [None]:
# Word-level alignment diagnostic for every output pair (presumably wav2vec2 vs. T5 text),
# followed by the word edit distance that levenstein returns (previously discarded).
for i, (ref, pred) in enumerate(zip(*outputs)):
    print(str(i) + ":")
    word_edits = levenstein(ref_word=ref.split(), comp_word=pred.split())
    print(f"  word edits: {word_edits}")

In [None]:
idx = 10
# Reference transcript followed by the two pipeline outputs for the same clip, one per line.
print("\n".join(map(str, [dataset[idx]["sentence"], outputs[0][idx], outputs[1][idx]])))

In [None]:
# Score a fixed hypothesis (apparently copied from an earlier pipeline output) against the
# selected reference sentence with each loaded metric.
{
    metric_name: scorer.compute(
        predictions=['okresy nemají v současnosti na rozdíl od krajů právní funkci'],
        references=[sentence],
    )
    for metric_name, scorer in metrics.items()
}

In [None]:
# Manual wav2vec2 inference on the selected clip. Uses its own name so it does not
# overwrite `inputs`, the T5 token batch that the token-length cell reads.
wav2vec_inputs = wav2vec_processor(input_audio['array'], sampling_rate=SAMPLING_RATE, return_tensors='pt')

with torch.no_grad():
    logits = wav2vec_model(**wav2vec_inputs).logits

# Greedy decoding: most likely vocabulary entry along the last axis of the logits.
pred_ids = torch.argmax(logits, dim=-1)
transcription = wav2vec_processor.batch_decode(pred_ids)
transcription

In [None]:
# Spell-correct the wav2vec2 transcription(s) with T5, using the "spell check: " task prefix.
# Padding allows more than one transcription per batch.
t5_inputs = t5_tokenizer(["spell check: " + text for text in transcription], return_tensors="pt", padding=True)

# A corrected sentence is expected to be about as long as its input, so the generation
# budget is sized from the prefixed input length plus a small margin; the previous fixed
# max_new_tokens=20 cut off longer sentences.
output_sequences = t5_model.generate(**t5_inputs, max_new_tokens=t5_inputs.input_ids.shape[-1] + 8)

t5_tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

Scratch checks: inspect the T5 ASR-correction dataset and the raw wav2vec2 checkpoint weights.

In [None]:
# Saved ASR-correction dataset (indexed by split below, e.g. "test").
t5_dataset = load_from_disk(dataset_path="/home/sulcm/datasets/t5/asr-correction-cs-v23")

In [None]:
# Peek at one example from the test split of the ASR-correction dataset.
t5_dataset["test"][1]

In [None]:
# Load the raw checkpoint file. Despite the name, `model` holds the saved mapping of
# parameter name -> tensor, which the following cells inspect via .keys() and .shape.
# map_location="cpu" lets a checkpoint saved from GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only use it on trusted checkpoints
# (torch >= 1.13 also accepts weights_only=True to restrict loading to tensors).
model = torch.load("/home/sulcm/models/wav2vec2/wav2vec2-cs-v1/pytorch_model.bin", map_location="cpu")

In [None]:
# List the checkpoint's keys (parameter names such as the ones inspected below).
model.keys()

In [None]:
# Shape of the final layer-norm weight of encoder layer 11.
model['wav2vec2.encoder.layers.11.final_layer_norm.weight'].shape

In [None]:
# Shape of the output projection (lm_head); presumably (vocab_size, hidden_size) — confirm.
model['lm_head.weight'].shape