In [1]:
from glob import glob
import pandas as pd
import json
import os

from pandarallel import pandarallel
from jiwer import wer
from tqdm import tqdm
import difflib

# Register DataFrame.progress_apply (tqdm) and DataFrame.parallel_apply (pandarallel),
# both of which are used by later cells.
N_WORKERS = 16
tqdm.pandas()
pandarallel.initialize(progress_bar=True, nb_workers=N_WORKERS)

INFO: Pandarallel will run on 16 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [2]:
# Maximum WER between the two systems' predictions for a sample to count as "clean"
# (used for the training filter and embedded in the output file name).
threshold = 5e-2

In [3]:
def load_cached_metadata(metadata_dirs):
    """Load every JSON Lines record from the ``*.jsonl`` files in the given directories.

    Args:
        metadata_dirs: iterable of directory paths. Only files directly inside each
            directory whose name matches ``*.jsonl`` are read (no recursion).

    Returns:
        list[dict]: the parsed records, ordered by directory (as given), then by
        file name, then by line. Blank lines are skipped.
    """
    metadata = []
    for metadata_dir in metadata_dirs:
        # glob() returns files in filesystem order; sort so the record order (and
        # therefore which duplicate a later drop_duplicates keeps) is reproducible.
        for filepath in sorted(glob(os.path.join(metadata_dir, "*.jsonl"))):
            # `with` closes each file; explicit UTF-8 avoids depending on the locale.
            with open(filepath, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:  # tolerate blank / trailing empty lines
                        metadata.append(json.loads(line))
    return metadata

In [4]:
# Kaldi-style output directories for the inference and training subsets.
kaldi_data_root = "/data/asr-research/src/kaldi/data"
infer_data_dir = os.path.join(kaldi_data_root, "f88_infer")
train_data_dir = os.path.join(kaldi_data_root, "f88_train")

In [5]:
# Create both output directories, including any missing parent directories.
# exist_ok=True makes this a no-op when they already exist (and avoids the
# check-then-create race of os.path.exists + os.mkdir).
os.makedirs(train_data_dir, exist_ok=True)
os.makedirs(infer_data_dir, exist_ok=True)

In [6]:
# Predictions from the w2v2 model, split across two part directories.
metadata_dirs = [
    "/data/asr-research/data/s2_stt_metadata_w2v2_part_1",
    "/data/asr-research/data/s2_stt_metadata_w2v2_part_2",
]
metadata_v1 = pd.DataFrame(load_cached_metadata(metadata_dirs))
# The file name (not the full path) is the join key against the other system.
metadata_v1["id"] = metadata_v1["audio_filepath"].map(os.path.basename)
# NOTE(review): dedup is by full path while the merge key is the basename `id`;
# if two directories hold the same file name, the later merge on `id` will
# multiply rows — verify `id` is unique here.
metadata_v1 = metadata_v1.drop_duplicates("audio_filepath")
metadata_v1.duration.sum() / 3600

np.float64(3098.1582365829545)

In [7]:
# Predictions from the NeMo CTC model; its `pred_text` column is renamed to
# `pred` to match the w2v2 frame, and its `text` column is dropped.
metadata_dirs = ["/data/asr-research/data/s2_stt_metadata_nemo_ctc"]
metadata_v2 = (
    pd.DataFrame(load_cached_metadata(metadata_dirs))
    .assign(id=lambda frame: frame["audio_filepath"].map(os.path.basename))
    .drop_duplicates("audio_filepath")
    .drop(columns=["text"])
    .rename(columns={"pred_text": "pred"})
)
metadata_v2.duration.sum() / 3600

np.float64(3128.4959655861726)

In [8]:
# metadata_v1 is loaded from the s2_stt_metadata_w2v2_* directories and metadata_v2
# from s2_stt_metadata_nemo_ctc. pandas applies `suffixes` as (left, right), so the
# left frame (w2v2) must get "_w2v2" and the right frame (NeMo) "_nemo". The previous
# order labelled the NeMo predictions as `pred_w2v2` and vice versa.
df = pd.merge(metadata_v1, metadata_v2, on="id", how="inner", suffixes=("_w2v2", "_nemo"))
# Disagreement between the two systems, with the w2v2 prediction as the reference.
df["wer"] = df.progress_apply(lambda row: wer(row["pred_w2v2"], row["pred_nemo"]), axis=1)
df.shape

100%|██████████| 2879186/2879186 [01:14<00:00, 38767.40it/s]


(2879186, 10)

In [None]:
# Hours of audio present in both prediction sets (inner join on id).
df["duration_w2v2"].sum() / 3600

In [None]:
# Hours of audio whose cross-system WER is at most 0.08.
df.loc[df["wer"] <= 0.08, "duration_w2v2"].sum() / 3600

In [None]:
# Duration distribution (seconds) of the samples that pass the training filter;
# uses the same `wer <= threshold` condition as the training subset instead of a
# separately hard-coded `< 0.05`.
ax = df.loc[df.wer <= threshold, "duration_w2v2"].hist(bins=100)
ax.set(title=f"Segment duration, WER <= {threshold}", xlabel="duration (s)", ylabel="count");

In [12]:
def get_the_same_segments(text_1, text_2, min_size=3, autojunk=True):
    """Return the runs of tokens shared by two token sequences.

    Uses ``difflib.SequenceMatcher`` to find matching blocks between ``text_1`` and
    ``text_2`` and keeps the blocks that are at least ``min_size`` tokens long.

    Args:
        text_1: first token sequence (callers pass ``str.split()`` output).
        text_2: second token sequence.
        min_size: minimum block length, in tokens, to keep. Defaults to 3.
        autojunk: forwarded to ``SequenceMatcher``. With the default ``True``, difflib
            treats very frequent tokens as junk once ``text_2`` has 200+ tokens, which
            can split otherwise-contiguous matches. Pass ``False`` to disable that.

    Returns:
        list[str]: each kept block's tokens (taken from ``text_1``) joined by single
        spaces, in the order the blocks occur.
    """
    matcher = difflib.SequenceMatcher(None, text_1, text_2, autojunk=autojunk)
    return [
        " ".join(text_1[block.a: block.a + block.size])
        for block in matcher.get_matching_blocks()
        # also drops the trailing size-0 sentinel that get_matching_blocks appends
        if block.size >= min_size
    ]

def save(metadata, data_dir, audio_col="audio_filepath_w2v2", text_col="pred_w2v2"):
    """Write ``metadata`` as a Kaldi-style data directory plus a JSONL dump.

    Files written into ``data_dir`` (which must already exist):
      - ``meta.jsonl``: one JSON object per row of ``metadata``, in its current order.
      - ``wav.scp``: ``<id>\\t<audio path>`` lines.
      - ``text``: ``<id>\\t<transcript>`` lines.
      - ``spk2utt`` / ``utt2spk``: ``<id>\\t<id>`` lines (each utterance is its own speaker).

    The four Kaldi files are sorted by ``id``.

    Args:
        metadata: DataFrame with an ``id`` column plus ``audio_col`` and ``text_col``.
        data_dir: existing output directory.
        audio_col: column holding the audio file path.
        text_col: column holding the transcript.
    """
    wavscp_path = f'{data_dir}/wav.scp'
    text_path = f'{data_dir}/text'
    spk2utt_path = f'{data_dir}/spk2utt'
    utt2spk_path = f'{data_dir}/utt2spk'
    meta_path = f'{data_dir}/meta.jsonl'

    # ensure_ascii=False writes raw non-ASCII characters, so the file must be opened
    # as UTF-8 explicitly rather than with the platform's default encoding.
    with open(meta_path, "w", encoding="utf-8") as f:
        # One pass over the frame instead of a .loc lookup per row.
        for record in metadata.to_dict(orient="records"):
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        print(f'###saved to: {meta_path}')

    # Sort once and reuse the same order for every Kaldi file.
    ordered = metadata.sort_values("id")
    utt_ids = ordered["id"].tolist()

    def _write_keyed_file(path, values):
        """Write one ``<utt id>\\t<value>`` line per utterance to ``path``."""
        with open(path, "w", encoding="utf-8") as f:
            for utt_id, value in zip(utt_ids, values):
                f.write(f"{utt_id}\t{value}\n")
            print(f'###saved to: {path}')

    _write_keyed_file(wavscp_path, ordered[audio_col].tolist())
    _write_keyed_file(text_path, ordered[text_col].tolist())
    # Speaker id == utterance id, so spk2utt and utt2spk have identical content.
    _write_keyed_file(spk2utt_path, utt_ids)
    _write_keyed_file(utt2spk_path, utt_ids)

In [None]:
# Samples where the two systems disagree by more than `threshold` form the inference
# set. Using the config constant (not a literal 0.05) keeps this the exact complement
# of the `wer <= threshold` training filter.
filtered_metadata = df[df.wer > threshold].copy()
# Word runs (>= 3 tokens) on which both systems agree despite the high overall WER.
filtered_metadata["matching_segments"] = filtered_metadata.parallel_apply(
    lambda row: get_the_same_segments(
        text_1=row["pred_w2v2"].split(),
        text_2=row["pred_nemo"].split(),
    ),
    axis=1,
)
filtered_metadata.duration_w2v2.sum() / 3600, filtered_metadata.shape[0]

In [15]:
# save(metadata=filtered_metadata, data_dir=infer_data_dir)
# print(f'###saved infer data to {infer_data_dir}')

In [None]:
# Training subset: rows with WER <= threshold, capped at 32 rows per distinct
# transcript (randomly sampled with random_state=42) so no single repeated
# transcript can contribute more than 32 samples.
max_per_text = 32
clean_rows = df[df.wer <= threshold]
filtered_metadata = pd.concat([
    group.sample(max_per_text, random_state=42) if len(group) > max_per_text else group
    for _, group in clean_rows.groupby("pred_w2v2")
])
filtered_metadata.duration_w2v2.sum() / 3600

In [17]:
# save(metadata=filtered_metadata, data_dir=train_data_dir)
# print(f'###saved train data to {train_data_dir}')

In [18]:
# Keep only the *_w2v2 columns and rename them to plain
# audio_filepath / duration / text for the output manifest.
output_columns = {
    "audio_filepath_w2v2": "audio_filepath",
    "duration_w2v2": "duration",
    "pred_w2v2": "text",
}
saved_df = filtered_metadata.loc[:, list(output_columns)].rename(columns=output_columns)

In [19]:
# Output manifest for the training subset; the WER threshold is embedded in the file name.
clean_metadata_filepath = "/data/asr-research/data/metadata/f88_wer_{}_v1.jsonl".format(threshold)

In [None]:
# Make sure the destination directory exists before writing.
os.makedirs(os.path.dirname(clean_metadata_filepath), exist_ok=True)

# ensure_ascii=False writes raw non-ASCII text, so open the file as UTF-8 explicitly
# instead of depending on the platform's default encoding.
with open(clean_metadata_filepath, "w", encoding="utf-8") as f:
    # One pass over the frame instead of a .loc lookup per row (millions of rows).
    for row in tqdm(saved_df.to_dict(orient="records")):
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

print(f'###saved metadata to {clean_metadata_filepath}')