In [None]:
import os, sys, glob
import git

# Work from the repository root so the relative `data/...` paths used in
# later cells resolve regardless of where the notebook was launched.
root = git.Repo(
    '.',
    search_parent_directories=True,
).working_tree_dir
os.chdir(root)
print(f"Changed working directory to {root}")

import pandas as pd

In [None]:
# Sort the globbed paths: glob's result order is filesystem-dependent, and
# read_transcriptions lets a later file overwrite an earlier one for a
# duplicate utterance id, so an unsorted order could make the loaded
# transcripts (and the final score) differ between machines.
asr_outputs = sorted(glob.glob('data/asr_outputs/*.csv'))
vsr_outputs = sorted(glob.glob('data/vsr_outputs/*.csv'))
print(f"{len(asr_outputs)=} {len(vsr_outputs)=}")

def read_transcriptions(paths):
    """Load transcription CSVs into a ``{utterance_id: text}`` dict.

    Each CSV must have exactly two columns: the first holds a media file
    path, the second the transcription. The path is reduced to its last
    three components with the extension removed, e.g.
    ``<prefix>/id08701/z8t-KFSoYLI/00478.mp4`` -> ``id08701/z8t-KFSoYLI/00478``,
    so outputs from different systems can be matched by id.

    Rows whose transcription is missing, the literal string ``'None'``, or
    blank after stripping whitespace are skipped. If the same id appears in
    several rows or files, the last one read wins.

    Args:
        paths: iterable of CSV file paths.

    Returns:
        dict mapping utterance id -> whitespace-stripped transcription.
    """
    transcriptions = {}
    for path in paths:
        # dtype=str keeps transcriptions such as "7" as text instead of
        # letting pandas infer a numeric column (numbers have no .strip()).
        df = pd.read_csv(path, dtype=str)

        # path_col is the first column, text_col the second
        path_col, text_col = df.columns

        for file_path, text in zip(df[path_col], df[text_col]):
            if pd.isna(text) or text == 'None':
                continue
            text = text.strip()
            if not text:
                # A blank transcript can't be scored: WER is undefined for an
                # empty reference.
                continue
            # `<prefix path>/id08701/z8t-KFSoYLI/00478.<ext>` -> `id08701/z8t-KFSoYLI/00478`
            utterance_id = '/'.join(os.path.splitext(file_path)[0].split('/')[-3:])
            transcriptions[utterance_id] = text
    return transcriptions

# utterance id -> transcript, one map per system; ASR is loaded first.
asr = read_transcriptions(paths=asr_outputs)
vsr = read_transcriptions(paths=vsr_outputs)

In [None]:
from tqdm import tqdm
from jiwer import wer, cer

# Score each VSR hypothesis against the ASR transcript of the same utterance.
# jiwer's argument order is (reference, hypothesis), so ASR is the reference.
# `wer` (not `cer`) is used because the scores are stored in `wers` and
# reported as "Average WER" in the next cell.
wers = {}
for k in tqdm(vsr.keys()):
    if k in asr:
        wers[k] = wer(asr[k], vsr[k])

In [None]:
# Average WER. This is a macro average: the mean of per-utterance scores, so
# every utterance counts equally regardless of its length (unlike a
# corpus-level WER, which pools errors over all reference words).
if wers:
    avg_wer = sum(wers.values()) / len(wers)
    print(f"Average WER: {avg_wer:.4f}")
else:
    # Make an empty overlap visible instead of silently printing nothing;
    # it usually means ASR and VSR ids were normalized differently.
    print("No utterance ids shared by ASR and VSR outputs; nothing to average.")