In [None]:
import pandas as pd
import os
from glob import glob
import json
from pandarallel import pandarallel
from tqdm import tqdm
import torchaudio
import random
import re

# Set up pandarallel with 8 workers and progress bars enabled; this must run
# before any `.parallel_apply(...)` call made later in the notebook.
pandarallel.initialize(nb_workers=8, progress_bar=True)

In [None]:
# Input locations per question type (keyed by the integer type id):
#   json_dir      - directory holding one `<id>.json` annotation file per sample
#   audio_dir     - directory holding one `<id>.wav` recording per sample
#   metadata_path - CSV with one row per sample
type2path = {
    12: {
        "json_dir": "/data/metadata/apa-en/marking-data/12",
        "audio_dir": "/data/audio/prep-submission-audio/apa-type-12",
        "metadata_path": "/data/metadata/apa-en/merged-info/info_question_type-12_01082022_18092023.csv"
    },
}

In [None]:
# Question type to process; selects the matching entry of `type2path`.
_type = 12
path_dict = type2path[_type]

# The output directory is named after the metadata CSV (basename up to its first ".").
data_root_dir = "/home/tuyendv/E2E-R/data/raw"
data_name = os.path.basename(path_dict["metadata_path"]).split(".")[0]
data_dir = os.path.join(data_root_dir, data_name)

# makedirs(exist_ok=True) also creates missing parent directories and is safe
# to re-run, unlike an exists() check followed by mkdir().
os.makedirs(data_dir, exist_ok=True)

# Destination of the extracted records (one JSON object per line).
out_raw_json_path = f'{data_dir}/metadata-raw.jsonl'

In [None]:
# Gather every path used by the rest of the notebook in one place: the three
# input locations are copied from `path_dict`, then the output file is added.
hparams = {
    key: path_dict[key]
    for key in ("json_dir", "audio_dir", "metadata_path")
}
hparams["out_jsonl_path"] = out_raw_json_path

# Load the per-sample metadata table and preview its first two rows.
metadata = pd.read_csv(hparams["metadata_path"])
metadata.head(2)

In [None]:
def is_valid_audio(audio_id, sample_rate=16000):
    """Check that the recording for `audio_id` exists, loads, and has the expected rate.

    Args:
        audio_id: sample id; the file is looked up as
            `<hparams["audio_dir"]>/<audio_id>.wav`.
        sample_rate: required sampling rate in Hz (defaults to 16000).

    Returns:
        True when the file exists, `torchaudio.load` succeeds on it and the
        reported sampling rate equals `sample_rate`; False otherwise.
    """
    abs_path = os.path.join(hparams["audio_dir"], f'{audio_id}.wav')
    if not os.path.exists(abs_path):
        return False

    try:
        # Only the sampling rate is needed; the waveform itself is discarded.
        _, sr = torchaudio.load(abs_path)
    except Exception:
        # Unreadable or corrupt audio counts as invalid. `Exception` (rather
        # than a bare except) lets KeyboardInterrupt / SystemExit propagate so
        # a long validation run can still be interrupted.
        return False

    return sr == sample_rate

# Validate every sample's audio in parallel; `is_exist` holds one
# is_valid_audio result per metadata row.
is_exist =  metadata.id.parallel_apply(is_valid_audio)
print(metadata.shape)
# Keep only the rows whose audio passed validation (shape printed before/after).
metadata = metadata[is_exist]
print(metadata.shape)

In [None]:
def filter_data(data, min_time=1.0, max_time=6.0, max_words=4):
    """Keep only rows whose duration and word count fall inside the given bounds.

    Args:
        data: DataFrame with `total_time` and `word_count` columns.
        min_time: exclusive lower bound on `total_time`.
        max_time: exclusive upper bound on `total_time`.
        max_words: exclusive upper bound on `word_count`.

    Returns:
        A new DataFrame holding the rows that satisfy all three bounds, in
        their original order. The input frame is not modified. The shape is
        printed before and after filtering.
    """
    print(f'### shape before filtering: {data.shape}')
    keep = (
        (data.total_time > min_time)
        & (data.total_time < max_time)
        & (data.word_count < max_words)
    )
    data = data[keep]
    print(f'### shape after filtering: {data.shape}')
    return data

# Rebinds `metadata` to the filtered frame, so re-running this cell filters
# the already-filtered rows again.
metadata = filter_data(metadata)

In [None]:
def norm_text(text):
    """Normalize a transcript: strip punctuation, collapse whitespace, upper-case.

    Args:
        text: input string.

    Returns:
        `text` with each of , . ! ? : ; replaced by a space, every whitespace
        run collapsed to a single space, leading/trailing whitespace removed,
        and all letters upper-cased.
    """
    text = re.sub(r"[\,\.\!\?\:\;]", " ", text)
    # Raw string: "\s" in a plain string literal is an invalid escape sequence
    # and triggers a DeprecationWarning/SyntaxWarning on recent Python versions.
    text = re.sub(r"\s+", " ", text).strip()
    text = text.upper()

    return text

def is_valid_phoneme(phoneme):
    """Accept a phoneme annotation only if its error label is one-to-one.

    `phoneme["phoneme_error_arpabet"]` is either the literal "normal" or a
    string whose parts are separated by " - ". The first part and the last
    part must each be a single token (contain no space); otherwise the
    phoneme is rejected.

    Returns:
        True for "normal" or a single-token first/last part, False otherwise.
    """
    error_label = phoneme["phoneme_error_arpabet"]
    if error_label == "normal":
        return True

    parts = error_label.split(" - ")
    first_part, last_part = parts[0], parts[-1]
    # A part made of several space-separated tokens disqualifies the phoneme.
    return " " not in first_part and " " not in last_part

def is_valid_word(word):
    """Tell whether a word has exactly one phoneme entry per transcribed phone.

    Compares the length of `word["phonemes"]` with the number of
    whitespace-separated tokens in `word["trans_arpabet"]`.
    """
    expected_count = len(word["trans_arpabet"].split())
    return len(word["phonemes"]) == expected_count
            
def _parse_utterance(raw_utterance, sample_id, audio_path):
    """Flatten one annotated utterance into a single record.

    Args:
        raw_utterance: one element of the annotation file's "utterance" list;
            read keys are "sentence", "nativeness_score" and "words".
        sample_id: id stored in the record's "id" field.
        audio_path: path stored in the record's "audio_path" field.

    Returns:
        The record dict, or None as soon as a word fails `is_valid_word` or a
        phoneme fails `is_valid_phoneme`.
    """
    # Numeric class stored for each phoneme-level decision label.
    decision2color = {
        "correct": 2,
        "warning": 1,
        "error": 0
    }

    utt_text = []
    utt_arpas = []
    utt_trans = []
    utt_phone_scores = []
    utt_decisions = []
    utt_word_scores = []
    utt_word_ids = []

    for word_id, word in enumerate(raw_utterance["words"]):
        # The word-level transcription uses AX in place of AH0. The
        # phoneme-level "trans_arpabet" values below are left untouched.
        word["trans_arpabet"] = word["trans_arpabet"].replace("AH0", "AX")

        wrd_score = word["nativeness_score"]
        wrd_text = norm_text(word["text"])
        wrd_arpa = word["trans_arpabet"].split()

        if not is_valid_word(word):
            return None

        for phoneme in word["phonemes"]:
            if not is_valid_phoneme(phoneme):
                # Reject the utterance immediately instead of walking the
                # remaining words of an utterance that will be dropped anyway.
                return None

            decision = decision2color[phoneme["decision"]]
            raw_score = phoneme["nativeness_score"]

            # Negative phoneme scores are clamped to 0.
            utt_phone_scores.append(raw_score if raw_score >= 0 else 0)
            utt_word_ids.append(word_id)
            utt_trans.append(phoneme["trans_arpabet"])
            utt_decisions.append(decision)

        utt_text.append(wrd_text)
        utt_word_scores.append(wrd_score)
        utt_arpas.extend(wrd_arpa)

    return {
        "id": sample_id,
        "raw_text": raw_utterance["sentence"],
        "text": " ".join(utt_text),
        # utt_id / start_time / end_time are not available here and stay None;
        # intonation and fluency scores are written as 0.
        "utt_id": None,
        "start_time": None,
        "end_time": None,
        "arpas": utt_arpas,
        "trans": utt_trans,
        "phone_scores": utt_phone_scores,
        "word_scores": utt_word_scores,
        "decisions": utt_decisions,
        "word_ids": utt_word_ids,
        "utterance_score": raw_utterance["nativeness_score"],
        "intonation_score": 0,
        "fluency_score": 0,
        "audio_path": audio_path
    }


def parse_json_file(json_path):
    """Parse one annotation JSON file into a list of utterance records.

    The sample id is the file's basename up to its first "."; the audio path
    is `<hparams["audio_dir"]>/<id>.wav`. Utterances containing an invalid
    word or phoneme are skipped.

    Args:
        json_path: path of the annotation file.

    Returns:
        A list of record dicts (possibly empty). Any ordinary error while
        reading or parsing the file (missing file, malformed JSON, missing
        key, ...) yields an empty list for the whole file.
    """
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            content = json.load(f)

        sample_id = os.path.basename(json_path).split(".")[0]
        audio_path = os.path.join(hparams["audio_dir"], f'{sample_id}.wav')

        utterances = []
        for raw_utterance in content["utterance"]:
            utterance = _parse_utterance(raw_utterance, sample_id, audio_path)
            if utterance is not None:
                utterances.append(utterance)

        return utterances
    except Exception:
        # Best effort: a broken file contributes no records. `Exception`
        # (not a bare except) keeps KeyboardInterrupt / SystemExit working.
        return []

# Parse each sample's `<id>.json` annotation file in parallel; every entry of
# `extracted_data` is the list returned by parse_json_file for that sample.
extracted_data = metadata.id.parallel_apply(
    lambda x: parse_json_file(os.path.join(hparams["json_dir"], f'{x}.json')))
extracted_data.head()

In [None]:
def save_jsonl_data_col_level(data, path):
    """Write the given samples to `path` as JSON Lines (one JSON object per line).

    Args:
        data: a pandas Series of JSON-serializable samples; every element is
            written, in order. Any other object is handled as before via
            `data.to_dict().values()`.
        path: output file path; the file is overwritten and encoded as UTF-8.
    """
    if isinstance(data, pd.Series):
        # Take the values directly. Going through `to_dict()` keys the samples
        # by index label, so rows sharing a label (e.g. after `explode()`)
        # would overwrite each other and be silently lost.
        samples = data.tolist()
    else:
        samples = data.to_dict().values()

    with open(path, "w", encoding="utf-8") as f:
        for sample in tqdm(samples):
            json_obj = json.dumps(sample)

            f.write(f'{json_obj}\n')
    print('saved data to: ', path)

# Flatten the per-sample lists into one row per utterance and drop the missing
# entries, then write the result to the output JSONL file.
# NOTE(review): explode() repeats the source row's index label on every row it
# produces, so `data` can carry duplicate index labels - confirm downstream
# code does not key on the index.
data = extracted_data.explode().dropna()
save_jsonl_data_col_level(data=data, path=hparams["out_jsonl_path"])