In [None]:
%cd /data/codes/apa/train

from glob import glob
from tqdm import tqdm
import pandas as pd
import numpy as np
import pickle
import torch
import json
import os
import re

from src.dataset import (
    IndexedDataset,
    IndexedDatasetBuilder
)

In [None]:
def map_phone(phones):
    """Return a new list in which the labels "SCHWA" and "AH0" become "AX".

    Every other label is passed through unchanged and the order is kept.
    """
    schwa_labels = ("SCHWA", "AH0")
    return ["AX" if label in schwa_labels else label for label in phones]

def get_phone_pure(phones):
    """Remove every digit character from each phone label (e.g. "AH1" -> "AH")."""
    digit = re.compile(r"\d")
    stripped = []
    for label in phones:
        stripped.append(digit.sub("", label))
    return stripped

def preprocess_metadata(metadata):
    """Select and normalise the metadata columns used by the rest of the notebook.

    Args:
        metadata: DataFrame containing at least the columns listed in
            ``columns`` below.

    Returns:
        A new DataFrame (the input is not modified) in which
        * ``id`` is converted with ``str``,
        * ``arpas`` is renamed to ``elsa_phone`` and passed through
          ``map_phone`` and then ``get_phone_pure``,
        * ``trans`` is passed through ``map_phone``.
    """
    columns = [
        "id", "audio_path", "text", "arpas", "trans", "phone_scores",
        "word_scores", "word_ids", "utterance_score", "fluency_score",
        "intonation_score",
    ]
    # Explicit copy: assigning columns on the bare selection makes pandas
    # emit SettingWithCopyWarning because it is tracked as derived data.
    metadata = metadata[columns].copy()
    metadata["id"] = metadata["id"].apply(str)
    metadata = metadata.rename(columns={"arpas": "elsa_phone"})
    metadata["elsa_phone"] = (
        metadata["elsa_phone"].apply(map_phone).apply(get_phone_pure)
    )
    metadata["trans"] = metadata["trans"].apply(map_phone)

    return metadata

def load_jsonl(path):
    """Read a JSON Lines file into a DataFrame with one row per record.

    Args:
        path: path of a UTF-8 text file holding one JSON document per line.

    Returns:
        ``pd.DataFrame`` built from the decoded records, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.

    Blank or whitespace-only lines are skipped so that a stray empty line
    does not abort the whole load.
    """
    records = []
    with open(path, "r", encoding="utf-8") as f:
        # Iterate the handle directly instead of materialising every line.
        for line in f:
            line = line.strip()
            if not line:
                continue
            records.append(json.loads(line))

    return pd.DataFrame(records)

def load_gops(gop_paths):
    """Merge the pickled GOP dictionaries stored at ``gop_paths`` into one dict.

    Args:
        gop_paths: iterable of paths to pickle files, each holding a mapping.

    Returns:
        A single dict containing the entries of every file that could be read.

    Raises:
        AssertionError: if a key appears in more than one file.

    Files that cannot be opened or unpickled are skipped (best effort), but a
    warning naming the file is emitted so the loss is visible.
    """
    import warnings

    gops = {}
    for path in gop_paths:
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # point this at files produced by a trusted pipeline.
            with open(path, "rb") as f:
                gop = pickle.load(f)
        except Exception as error:
            # `Exception` (not a bare except) so that KeyboardInterrupt and
            # SystemExit still stop the notebook.
            warnings.warn(f"Skipping unreadable GOP file {path}: {error!r}")
            continue

        duplicated = gops.keys() & gop.keys()
        assert not duplicated, (
            f"{len(duplicated)} key(s) in {path} already loaded, "
            f"e.g. {list(duplicated)[:5]}"
        )

        gops.update(gop)

    return gops

In [None]:
# Directory that receives every artefact written by this notebook.
out_dir = "/data/codes/apa/train/data/feats/train/merge"

# JSONL metadata files to merge.
metadata_path = [
    "/data/codes/apa/train/data/metadata/jsonl/train-data-type-10.jsonl",
    "/data/codes/apa/train/data/metadata/jsonl/train-data-type-12.jsonl"
]

# Feature directories; each is expected to contain an `id` file (read below).
feat_dir = [
    "/data/codes/apa/train/data/feats/train/train-data-type-10-filtered/",
    "/data/codes/apa/train/data/feats/train/train-data-type-12-filtered/",
]

# Directories searched for `*/gop.pkl` and `*/ali.out`.
data_dir = [
    "/data/codes/apa/train/data/train/train-data-type-10",
    "/data/codes/apa/train/data/train/train-data-type-12"
]

# ignore_index=True: every file is loaded with its own 0..n-1 index, so a
# plain concat would leave duplicate row labels in the merged frame.
metadata = pd.concat(
    [load_jsonl(path) for path in metadata_path], ignore_index=True
)
metadata = preprocess_metadata(metadata)
metadata.head(1)

In [None]:
# exist_ok=True makes this a no-op when the directory is already there and
# avoids the gap between a separate existence check and the creation.
os.makedirs(out_dir, exist_ok=True)

In [None]:
# Number of phones per utterance; `lengths` is also read by the filter cell below.
lengths = metadata["elsa_phone"].apply(len)
# Plot the distribution of utterance lengths using 100 bins.
lengths.hist(bins=100)

In [None]:
# Upper bound on sequence length; 3 positions are held back in the filter
# below (NOTE(review): presumably reserved for extra tokens — confirm).
MAX_LENGTH = 128

print(metadata.shape)
# Recompute the lengths from the current frame instead of reusing `lengths`
# from the histogram cell, so this cell stays correct when it is re-run or
# when `metadata` has changed since that cell was executed.
phone_counts = metadata["elsa_phone"].apply(len)
metadata = metadata[phone_counts < MAX_LENGTH - 3]
print(metadata.shape)

In [None]:
# Collect the utterance ids listed in the `id` file of every feature directory.
id_frames = []
for feature_dir in feat_dir:  # not `dir`, which would shadow the builtin
    id_path = os.path.join(feature_dir, "id")

    # dtype=str keeps the ids as text so they compare equal to metadata["id"],
    # which preprocess_metadata converts with str().
    ids = pd.read_csv(id_path, names=["id"], dtype={'id': str})
    id_frames.append(ids.set_index("id"))

id_df = pd.concat(id_frames)
print(metadata.shape)
# Keep only the utterances whose id appears in at least one feature directory.
metadata = metadata[metadata.id.isin(id_df.index)]
print(metadata.shape)

In [None]:
# Gather every alignment file and GOP pickle found one level below each data dir.
alignment_paths, gop_paths = [], []
for source_dir in data_dir:  # not `dir`, which would shadow the builtin
    gop_path = os.path.join(source_dir, "*", "gop.pkl")
    align_path = os.path.join(source_dir, "*", "ali.out")

    alignment_paths += glob(align_path)
    gop_paths += glob(gop_path)

In [None]:
def load_alignment(path):
    """Read one tab-separated alignment file.

    Each line holds an id and a JSON-encoded alignment. The result is a
    DataFrame with a string ``id`` column and an ``alignment`` column that
    contains the decoded JSON value.
    """
    frame = pd.read_csv(
        path, sep="\t", names=["id", "alignment"], dtype={"id": str}
    )
    decoded = frame["alignment"].apply(json.loads)

    return frame.assign(alignment=decoded)

def load_alignments(paths):
    """Load every alignment file in ``paths`` and stack them into one frame.

    Returns a DataFrame with the columns ``id`` and ``alignment`` and a fresh
    0..n-1 index.
    """
    frames = [load_alignment(path) for path in paths]
    stacked = pd.concat(frames, ignore_index=True)

    return stacked[["id", "alignment"]]

# Load all alignment files into one DataFrame and merge all GOP pickles into one dict.
alignments = load_alignments(alignment_paths)
gops = load_gops(gop_paths)

In [None]:
# Keep only the alignments whose id is a key of `gops`.
is_valid = alignments["id"].apply(lambda utt_id: utt_id in gops)
print(alignments.shape)
# .copy(): new columns are assigned to this frame further down, which would
# raise SettingWithCopyWarning on a bare boolean selection.
alignments = alignments[is_valid].copy()
print(alignments.shape)

In [None]:
def extract_phonemes(alignments):
    """Return the bare phone label of every non-silence alignment entry.

    Args:
        alignments: sequence of entries whose first element is a label such
            as ``"AH1_B"``; entries labelled exactly ``"SIL"`` are skipped.

    Returns:
        List of labels with the part after the first ``_`` and all digits
        removed (``"AH1_B"`` -> ``"AH"``).
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence and
    # triggers a SyntaxWarning on recent Python versions.
    digit = re.compile(r"\d")
    phonemes = [
        digit.sub("", phoneme[0].split("_")[0]) for phoneme in alignments
        if phoneme[0] != "SIL"
    ]
    return phonemes

def extract_durations(alignments):
    """Return the duration of every non-silence alignment entry.

    The third element of each entry is multiplied by 0.02 and rounded to four
    decimals (NOTE(review): 0.02 is presumably the frame shift in seconds —
    confirm). Entries labelled exactly "SIL" are skipped.
    """
    durations = []
    for entry in alignments:
        if entry[0] == "SIL":
            continue
        durations.append(round(entry[2] * 0.02, 4))
    return durations

def extract_relative_positions(alignments):
    """Return the text after the last "_" of every non-silence label.

    A label without "_" is returned whole. Entries labelled exactly "SIL"
    are skipped.
    """
    positions = []
    for entry in alignments:
        label = entry[0]
        if label == "SIL":
            continue
        positions.append(label.split("_")[-1])
    return positions


# Derive the per-phone columns from each raw alignment, one extractor per column.
for column, extractor in (
    ("relative_positions", extract_relative_positions),
    ("prep_phone", extract_phonemes),
    ("duration", extract_durations),
):
    alignments[column] = alignments["alignment"].apply(extractor)

In [None]:
# Inner join on "id": only utterances present in both frames survive.
alignment_columns = ["id", "alignment", "prep_phone", "relative_positions", "duration"]
metadata = metadata.merge(alignments[alignment_columns], how="inner", on="id")

metadata.head(1)

In [None]:
def count_match(elsa, prep, scores):
    """Return 1 when the aligned phones agree with the reference phones, else 0.

    Args:
        elsa: reference phone labels of one utterance.
        prep: phone labels recovered from the alignment.
        scores: per-phone scores, indexed in parallel with ``elsa``.

    Returns:
        1 if both sequences have the same length and every position either
        carries the same label or has a score below 40; otherwise 0.
    """
    # zip() silently stops at the shorter sequence, so without this check a
    # pair of different lengths would be reported as a match.
    if len(elsa) != len(prep):
        return 0

    for index, (phone_1, phone_2) in enumerate(zip(elsa, prep)):
        # A label mismatch is tolerated only where the score is below 40.
        if phone_1 != phone_2 and not scores[index] < 40:
            return 0

    return 1

# 1 for rows whose reference phones agree with the aligned phones, else 0.
is_matched = metadata.apply(
    lambda row: count_match(
        elsa=row["elsa_phone"], prep=row["prep_phone"], scores=row["phone_scores"]
    ),
    axis=1,
)
# .copy(): the "relative_positions" column of this frame is reassigned later,
# which would raise SettingWithCopyWarning on a bare boolean selection.
metadata = metadata[is_matched == 1].copy()
print(is_matched.sum())
print(is_matched.shape)

### Extract alignment

In [None]:
def preprocess_alignments(alignment):
    """Drop "SIL" entries and shorten each label to the text before its first "_".

    ``alignment`` is a sequence of ``(phone, start, duration)`` triples; the
    result is a list of ``[phone, start, duration]`` lists in the same order.
    """
    return [
        [phone.split("_")[0], start, duration]
        for phone, start, duration in alignment
        if phone != "SIL"
    ]

In [None]:
# Write one JSON-encoded, silence-free alignment per line, in metadata row order.
with open(f'{out_dir}/alignment', "w", encoding="utf-8") as f:
    for _alignment in tqdm(metadata["alignment"]):
        _alignment = preprocess_alignments(_alignment)

        _alignment = json.dumps(_alignment, ensure_ascii=False)

        f.write(f'{_alignment}\n')

### Extract audio_path

In [None]:
# Write one audio path per line, in metadata row order.
with open(f'{out_dir}/wav_path', "w", encoding="utf-8") as f:
    for audio_path in tqdm(metadata["audio_path"]):
        print(audio_path, file=f)

### Extract id

In [None]:
# Write one utterance id per line, in metadata row order.
with open(f'{out_dir}/id', "w", encoding="utf-8") as f:
    for _id in tqdm(metadata["id"]):
        print(_id, file=f)

### Extract gop feature

In [None]:
indexed_path = f'{out_dir}/gop'

# Store the GOP features of every utterance, looked up by id, in metadata row order.
builder = IndexedDatasetBuilder(indexed_path)
for utt_id in tqdm(metadata["id"]):
    builder.add_item(item=np.array(gops[utt_id]))

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset at `indexed_path` and display its first item.
temp = IndexedDataset(indexed_path)
temp[0]

### Extract Relative Position

In [None]:
# Lookup table from relative-position label to id.
path = "/data/codes/apa/train/exp/dicts/relative2id.json"
# Context manager so the file handle is closed as soon as the JSON is parsed.
with open(path, "r", encoding="utf-8") as f:
    relative2id = json.load(f)

In [None]:
def convert_relative_position_to_id(relative_positions):
    """Translate each relative-position label through the ``relative2id`` table.

    Raises ``KeyError`` for a label that is not a key of ``relative2id``.
    """
    return [relative2id[label] for label in relative_positions]

# Overwrites the column in place: the labels are replaced by their ids.
metadata["relative_positions"] = metadata["relative_positions"].apply(convert_relative_position_to_id)

In [None]:
indexed_path = f'{out_dir}/relative_positions'

# One array of relative-position ids per utterance, in metadata row order.
builder = IndexedDatasetBuilder(indexed_path)
for position_ids in metadata["relative_positions"]:
    builder.add_item(item=np.array(position_ids))
builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/relative_positions'
temp = IndexedDataset(indexed_path)
temp[0]

### Extract sentence scores

In [None]:
indexed_path = f'{out_dir}/sentence_scores'

# Store one "utterance_score" value per utterance, in metadata row order.
builder = IndexedDatasetBuilder(indexed_path)
for index in metadata.index:
    # NOTE(review): .copy() requires the stored value to expose .copy()
    # (e.g. a NumPy scalar) — confirm the dtype of this column.
    sentence_score = metadata["utterance_score"][index].copy()

    builder.add_item(item=sentence_score)

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/sentence_scores'
temp = IndexedDataset(indexed_path)
temp[0]

### Extract fluency scores

In [None]:
indexed_path = f'{out_dir}/fluency_score'

# Store one "fluency_score" value per utterance, in metadata row order.
builder = IndexedDatasetBuilder(indexed_path)
for index in metadata.index:
    # Must read "fluency_score": reading "utterance_score" here would make
    # this dataset a duplicate of sentence_scores.
    # NOTE(review): .copy() requires a value exposing .copy() (e.g. a NumPy
    # scalar) — confirm the dtype of this column.
    fluency_score = metadata["fluency_score"][index].copy()

    builder.add_item(item=fluency_score)

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/fluency_score'
temp = IndexedDataset(indexed_path)
temp[0]

### Extract intonation score

In [None]:
indexed_path = f'{out_dir}/intonation_score'

# Store one "intonation_score" value per utterance, in metadata row order.
builder = IndexedDatasetBuilder(indexed_path)
for index in metadata.index:
    # Must read "intonation_score": reading "utterance_score" here would make
    # this dataset a duplicate of sentence_scores.
    # NOTE(review): .copy() requires a value exposing .copy() (e.g. a NumPy
    # scalar) — confirm the dtype of this column.
    intonation_score = metadata["intonation_score"][index].copy()

    builder.add_item(item=intonation_score)

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/intonation_score'
temp = IndexedDataset(indexed_path)
temp[0]

### Extract word scores

In [None]:
indexed_path = f'{out_dir}/word_scores'

builder = IndexedDatasetBuilder(indexed_path)
for word_score, word_id in zip(metadata["word_scores"], metadata["word_ids"]):
    # Look up the word score for every entry of word_ids (NOTE(review):
    # presumably one word index per phone — confirm).
    word_score_in_phone_levels = np.array([word_score[wid] for wid in word_id])
    builder.add_item(item=word_score_in_phone_levels)

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/word_scores'
temp = IndexedDataset(indexed_path)
temp[0]

### Extract word ids

In [None]:
indexed_path = f'{out_dir}/word_ids'

# One array of word ids per utterance, in metadata row order.
builder = IndexedDatasetBuilder(indexed_path)
for word_id in metadata["word_ids"]:
    builder.add_item(item=np.array(word_id))

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/word_ids'
temp = IndexedDataset(indexed_path)
temp[0]

### Extract duration feature

In [None]:
indexed_path = f'{out_dir}/duration'

# One array of per-phone durations per utterance, in metadata row order.
builder = IndexedDatasetBuilder(indexed_path)
for duration in metadata["duration"]:
    builder.add_item(item=np.array(duration))

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/duration'
temp = IndexedDataset(indexed_path)
temp[0]

### Extract phone scores

In [None]:
indexed_path = f'{out_dir}/phone_scores'

# One array of per-phone scores per utterance, in metadata row order.
builder = IndexedDatasetBuilder(indexed_path)
for phone_score in metadata["phone_scores"]:
    builder.add_item(item=np.array(phone_score))

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/phone_scores'
temp = IndexedDataset(indexed_path)
temp[0]

### Extract phone ids

In [None]:
# Load the phone-label lookup table that the next cell uses to encode phones.
phone_dict_path =  "/data/codes/apa/train/exp/dicts/phone_dict.json"
with open(phone_dict_path, "r", encoding="utf-8") as f:
    phone_dict = json.load(f)

In [None]:
indexed_path = f'{out_dir}/phone_ids'

# Raw string: "\d" in a plain literal is an invalid escape sequence and
# triggers a SyntaxWarning on recent Python versions.
digit_pattern = re.compile(r"\d")

# One array of phone ids per utterance, in metadata row order.
builder = IndexedDatasetBuilder(indexed_path)
for index in metadata.index:
    phoneme = metadata["elsa_phone"][index].copy()

    # Strip digits, then translate each label through phone_dict
    # (KeyError if a label is missing from the table).
    phoneme = [digit_pattern.sub("", phn) for phn in phoneme]
    phoneme = [phone_dict[phn] for phn in phoneme]

    phoneme = np.array(phoneme)
    builder.add_item(item=phoneme)

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/phone_ids'
temp = IndexedDataset(indexed_path)
temp[0]

### Extract WavLM Feature

In [None]:
import torch
from src.models.wavlm_model import WavLM, WavLMConfig
from tqdm import tqdm
import librosa
import pandas as pd
import json

In [None]:
# Location of the pretrained WavLM checkpoint.
pretrained_path = "/data/codes/apa/train/exp/torch/wavlm-base+.pt"
# NOTE(review): no map_location is passed, so tensors are restored onto the
# device they were saved from — confirm that this is intended.
checkpoint = torch.load(pretrained_path)

# Build the model from the configuration stored under 'cfg', switch it to
# eval mode and move it to the GPU.
config = WavLMConfig(checkpoint['cfg'])
model = WavLM(config).eval().cuda()
# Restore the weights stored under 'model'; the call's return value is the cell output.
model.load_state_dict(checkpoint['model'])

In [None]:
def extract_feature(alignment, features):
    """Average frame-level ``features`` over every non-silence alignment segment.

    Args:
        alignment: sequence of ``(phoneme, start_frame, duration)`` triples;
            the last entry determines the total number of frames.
        features: 2-D tensor with one row per frame. Rows beyond the end of
            the alignment are ignored.

    Returns:
        ``(pooled, phonemes)`` where ``pooled`` is a CPU tensor holding the
        mean feature row of each kept segment and ``phonemes`` lists the
        labels of the segments that are not ``"SIL"``, in order.

    The pooling runs on the device ``features`` lives on, so the function
    works on both CPU and GPU tensors.
    """
    total_frames = alignment[-1][1] + alignment[-1][2]
    # -1 marks frames that no kept segment covers (silence or gaps).
    indices = -1 * torch.ones(total_frames)

    phonemes = []
    for phoneme, start_frame, duration in alignment:
        if phoneme == "SIL":
            continue
        end_frame = start_frame + duration
        indices[start_frame:end_frame] = len(phonemes)
        phonemes.append(phoneme)

    has_uncovered = bool((indices == -1).any())
    if has_uncovered:
        # Route uncovered frames into one extra trailing bucket, which is
        # dropped again after pooling.
        indices[indices == -1] = indices.max() + 1

    num_classes = int(indices.max().item()) + 1
    weights = torch.nn.functional.one_hot(indices.long(), num_classes=num_classes)
    weights = weights.to(features.device)
    # Normalise each column so the matmul below yields a per-bucket mean.
    weights = weights / weights.sum(0, keepdim=True)

    if features.shape[0] != weights.shape[0]:
        features = features[0:weights.shape[0]]
    pooled = torch.matmul(weights.transpose(0, 1), features)

    if has_uncovered:
        pooled = pooled[:-1]

    return pooled.cpu(), phonemes

In [None]:
indexed_path = f'{out_dir}/wavlm_features'
builder = IndexedDatasetBuilder(indexed_path)

# `step` counts iterations; `index` stays the row label for the whole
# iteration so the lookups below (including the mismatch report) hit the
# right row.
for step, index in enumerate(tqdm(metadata.index)):
    audio_path = metadata["audio_path"][index]
    alignment = metadata["alignment"][index]

    wav, sr = librosa.load(audio_path, sr=16000)

    input_values = torch.from_numpy(wav).unsqueeze(0).cuda()
    with torch.no_grad():
        features = model.extract_features(input_values)[0]

    # Release cached GPU memory once every 1000 utterances.
    if step % 1000 == 0:
        torch.cuda.empty_cache()

    # Repeat every frame twice (NOTE(review): presumably to match the frame
    # rate of the alignment — confirm).
    frame_index = torch.arange(features.shape[1]).unsqueeze(-1)
    expanded_index = frame_index.expand((-1, 2)).flatten()
    features = features[0][expanded_index]

    features, phonemes = extract_feature(alignment, features)
    if len(features) != len(phonemes):
        # Report utterances whose pooled features do not line up with the phones.
        print(metadata["id"][index])

    builder.add_item(item=features.numpy())

builder.finalize()

In [None]:
# Sanity check: open the indexed dataset that was just written and display its first item.
indexed_path = f'{out_dir}/wavlm_features'
temp = IndexedDataset(indexed_path)
temp[0]