In [None]:
from tqdm import tqdm
import pandas as pd
import torchaudio
import librosa
import shutil
import json
import os

from pandarallel import pandarallel
# Set up pandarallel: 8 worker processes, progress bars enabled.
# NOTE(review): no `parallel_apply`/`parallel_map` call appears in the visible
# cells — confirm this initialisation is still needed.
pandarallel.initialize(progress_bar=True, nb_workers=8)

In [None]:
# Question type to process; used as the key into `type2path`.
data_type = 12
# Root directory holding the per-dataset folder (input and outputs).
# Override with the E2E_R_DATA_ROOT environment variable so the notebook can
# run on machines other than the author's; defaults to the original path.
data_root_dir = os.environ.get("E2E_R_DATA_ROOT", "/home/tuyendv/E2E-R/data/raw")

# Column whose distinct values define the groups that are down-sampled.
text_label = "text"
# Maximum number of rows kept per distinct `text_label` value.
n_sample_per_question_id = 268

In [None]:
# Source locations per question type. Only `metadata_path` is read in the
# visible cells (it determines the dataset name); `json_dir` and `audio_dir`
# are kept for reference.
type2path = {
    12: dict(
        json_dir="/data/metadata/apa-en/marking-data/12",
        audio_dir="/data/audio/prep-submission-audio/apa-type-12",
        metadata_path="/data/metadata/apa-en/merged-info/info_question_type-12_01082022_18092023.csv",
    ),
}

In [None]:
# Source locations for the selected question type.
path_dict = type2path[data_type]

# Dataset name = metadata file name without its extension. `splitext` strips
# only the final extension, so a dotted stem (e.g. "info.v2.csv") is not
# truncated at its first dot the way `.split(".")[0]` would.
data_name = os.path.splitext(os.path.basename(path_dict["metadata_path"]))[0]
data_dir = os.path.join(data_root_dir, data_name)

# Raw input and the filtered outputs all live in the per-dataset folder.
in_jsonl_path = os.path.join(data_dir, "metadata-raw.jsonl")
out_jsonl_path = os.path.join(data_dir, "metadata.jsonl")
out_csv_path = os.path.join(data_dir, "metadata.csv")

In [None]:
def load_jsonl_data(path):
    """Load a JSON-Lines file into a DataFrame, one row per JSON object.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped
    rather than passed to ``json.loads``, which would raise on them.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        pandas.DataFrame built from the parsed objects.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Stream the file line by line; json.loads tolerates surrounding whitespace.
        records = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(records)

def save_jsonl_data_row_level(data, path):
    """Write ``data`` to ``path`` as JSON Lines, one JSON object per row.

    Rows are taken from ``DataFrame.to_dict(orient="records")`` instead of
    ``data.loc[index]``. A ``loc`` row lookup upcasts the row to a common
    dtype, so an int column in an otherwise-float frame was written as
    ``1.0``. With duplicated index labels, ``loc`` returns a DataFrame,
    which produced a nested dict instead of one record.

    Args:
        data: DataFrame to save; column names become the JSON keys.
        path: Output file path (overwritten, UTF-8 encoded).
    """
    records = data.to_dict(orient="records")
    with open(path, "w", encoding="utf-8") as f:
        for record in tqdm(records):
            f.write(json.dumps(record) + "\n")

    print(f'###saved jsonl data to: {path}')
    
    
# Gather the run's I/O locations and filtering knobs in one mapping.
hparams = dict(
    in_jsonl_path=in_jsonl_path,
    out_csv_path=out_csv_path,
    out_jsonl_path=out_jsonl_path,
    n_sample_per_question_id=n_sample_per_question_id,
    text_label=text_label,
)

# Load the raw metadata and preview the first row.
metadata = load_jsonl_data(path=hparams["in_jsonl_path"])
metadata.head(1)

In [None]:
def filter_data_with_text(data, text_label="text", n_sample_per_question_id=268,
                          random_state=None):
    """Cap the number of rows kept per distinct value of ``text_label``.

    Groups with at least ``n_sample_per_question_id`` rows are randomly
    down-sampled to exactly that many rows; smaller groups are kept whole.
    Rows whose ``text_label`` is NaN form no group and are dropped
    (``groupby`` default).

    Args:
        data: Input DataFrame containing the ``text_label`` column.
        text_label: Column whose values define the groups.
        n_sample_per_question_id: Maximum number of rows kept per group.
        random_state: Forwarded to ``DataFrame.sample`` for reproducibility.
            ``None`` (default) keeps the previous unseeded behaviour. An int
            seeds every group identically; pass a ``numpy.random.Generator``
            for independent draws across groups.

    Returns:
        DataFrame of kept rows with original index labels. If ``data``
        yields no groups, an empty DataFrame with the same columns is returned.
    """
    print(f'### shape before filtering: {data.shape}')
    kept_groups = []
    for _, group in data.groupby(text_label):
        if group.shape[0] >= n_sample_per_question_id:
            group = group.sample(n_sample_per_question_id, random_state=random_state)
        kept_groups.append(group)

    # pd.concat([]) raises "No objects to concatenate"; return an empty frame
    # with the same columns instead.
    filtered_data = pd.concat(kept_groups) if kept_groups else data.iloc[0:0].copy()
    print(f'### shape after filtering: {filtered_data.shape}')
    return filtered_data

# Keep at most `n_sample_per_question_id` rows per distinct text.
filtered_metadata = filter_data_with_text(
    metadata,
    hparams["text_label"],
    hparams["n_sample_per_question_id"],
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def plot_score_distribution(data, label, phn_score_label="phone_scores", wrd_score_label="word_scores", utt_score_label="utterance_score", bins=100):
    """Plot score histograms plus a per-row phone-score count histogram.

    There is one panel per score column. The phone and word columns are
    exploded so that every element is counted. A final panel shows
    ``len(row[phn_score_label])`` for each row.

    Args:
        data: DataFrame containing the three score columns. The phone and word
            columns are expected to hold sequences (they are exploded, and
            ``len`` is applied to the phone column).
        label: Title for the whole figure.
        phn_score_label: Column with per-phone scores.
        wrd_score_label: Column with per-word scores.
        utt_score_label: Column with one utterance score per row.
        bins: Number of histogram bins per panel.

    Returns:
        The matplotlib Figure.
    """
    names = [phn_score_label, wrd_score_label, utt_score_label]
    n_panels = len(names) + 1  # one per score column + one for lengths

    fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(4 * n_panels, 4))

    for ax, name in zip(axes, names):
        # Direct column selection gives the same values as the former row-wise
        # apply, without a Python-level loop over rows.
        scores = data[name] if name == utt_score_label else data[name].explode()
        sns.histplot(data=scores, bins=bins, ax=ax)
        ax.set_xlabel(name)

    lengths = data[phn_score_label].apply(len)
    sns.histplot(data=lengths, bins=bins, ax=axes[-1])
    axes[-1].set_xlabel("length")
    axes[-1].set_ylabel("")

    # plt.title would title only the pyplot-current axes (the last panel);
    # suptitle labels the figure as a whole.
    fig.suptitle(label)
    return fig

# Visualise the capped dataset; the resulting figure is written to disk in the save cell.
fig = plot_score_distribution(
    data=filtered_metadata,
    label=os.path.basename(out_jsonl_path),
)

In [None]:
# Drop rows whose `arpas` sequence has `max_arpas_length` or more entries.
# NOTE(review): `arpas` presumably holds the ARPAbet phone sequence — confirm.
# NOTE: `fig` from the plotting cell was drawn before this filter, so it still
# includes the rows removed here.
max_arpas_length = 124  # exclusive upper bound on len(row["arpas"])
length = filtered_metadata["arpas"].apply(len)
filtered_metadata = filtered_metadata.loc[length < max_arpas_length]

In [None]:
##save data visualization
fig.savefig(f'{data_dir}/data-visualize.png')

##save data in kaldi format
filtered_metadata[["id", "text"]].to_csv(
    hparams["out_csv_path"], sep="|", index=None, header=None)

##save data in jsonl format
save_jsonl_data_row_level(
    data=filtered_metadata, path=hparams["out_jsonl_path"])