In [2]:
import pandas as pd
import numpy as np
import json
from glob import glob
import os

from datasets import Dataset, Audio, load_dataset
from tqdm import tqdm

In [42]:
DATASET_NAME = "mozilla-foundation/common_voice_11_0"
NAME = "cs"
SPLIT = "test"
SAMPLING_RATE = 16_000

In [43]:
hf_dataset = load_dataset(DATASET_NAME, NAME, split=SPLIT)
hf_dataset = hf_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))

In [44]:
wav_lenghts = np.array([len(example['audio']['array']) for example in tqdm(hf_dataset.to_iterable_dataset(), total=len(hf_dataset))])

100%|██████████| 7714/7714 [00:26<00:00, 289.54it/s]


In [45]:
wav_lenghts_secs = wav_lenghts/SAMPLING_RATE

In [46]:
np.mean(wav_lenghts_secs)

4.5338156922478605

In [47]:
total_secs = np.sum(wav_lenghts_secs)
total_secs, total_secs/3600

(34973.85425, 9.714959513888887)

Training table sumarisation

In [3]:
train_paths = glob('/home/sulcm/models/wav2vec2/*')

In [11]:
def format_latex_table(results: dict, best_results: dict) -> str:
    formated_table = ''
    for run, metrics in results.items():
        formated_table += run + ' & ' + ' & '.join([f'\\textbf{{{v}}}' if run == best_results[m][0] else v for m, v in metrics.items()]) + ' \\\\\n'
    
    return formated_table

In [19]:
metrics = ['wer', 'cer']
best_results = dict.fromkeys(metrics, ('', 1.0))
table_prep = {}

for train_res in train_paths:
    path2res = train_res + '/all_results.json'
    if not os.path.exists(path2res):
        continue

    with open(path2res, 'r') as f:
        results = json.load(f)
    
    table_prep[train_res.split('/')[-1].split('-')[-1]] = {m: f'{100.0*results[f"eval_{m}"]:.02f}' for m in metrics}

    for metric in metrics:
        if best_results[metric][1] > results[f"eval_{metric}"]:
            best_results[metric] = (train_res.split('/')[-1].split('-')[-1], results[f"eval_{metric}"])

print(table_prep)
print(best_results)

{'v3': {'wer': '12.63', 'cer': '2.92'}, 'v1': {'wer': '16.49', 'cer': '3.70'}, 'baseline': {'wer': '12.73', 'cer': '2.91'}, 'v2': {'wer': '14.47', 'cer': '3.24'}, 'v4': {'wer': '13.42', 'cer': '3.07'}, 'v5': {'wer': '12.90', 'cer': '2.95'}, 'v6': {'wer': '11.58', 'cer': '2.66'}, 'v7': {'wer': '11.73', 'cer': '2.71'}, 'v8': {'wer': '11.71', 'cer': '2.79'}, 'v13': {'wer': '14.80', 'cer': '3.29'}, 'v16': {'wer': '14.83', 'cer': '3.37'}}
{'wer': ('v6', 0.11579490391878697), 'cer': ('v6', 0.026599265812120118)}


In [20]:
print(format_latex_table(table_prep, best_results))

v3 & 12.63 & 2.92 \\
v1 & 16.49 & 3.70 \\
baseline & 12.73 & 2.91 \\
v2 & 14.47 & 3.24 \\
v4 & 13.42 & 3.07 \\
v5 & 12.90 & 2.95 \\
v6 & \textbf{11.58} & \textbf{2.66} \\
v7 & 11.73 & 2.71 \\
v8 & 11.71 & 2.79 \\
v13 & 14.80 & 3.29 \\
v16 & 14.83 & 3.37 \\

