In [None]:
from glob import glob
from tqdm import tqdm
import pandas as pd
import numpy as np
import pickle
import torch
import json
import re

In [None]:
def map_phone(phones):
    """Return a new list where both schwa spellings ("SCHWA" and "AH0") become "AX".

    Every other phone is passed through unchanged.
    """
    schwa_labels = ("SCHWA", "AH0")
    return ["AX" if phone in schwa_labels else phone for phone in phones]

def get_phone_pure(phones):
    """Strip every digit (e.g. ARPAbet stress markers) from each phone label."""
    digit_pattern = re.compile(r"\d")
    return [digit_pattern.sub("", phone) for phone in phones]

def preprocess_metadata(metadata):
    """Select the columns used for feature extraction and normalise phone labels.

    - ``id`` is cast to ``str`` so it joins with the string ids read from the
      alignment files.
    - ``arpas`` is renamed to ``elsa_phone``. Schwa spellings are first unified to
      "AX" (map_phone), then stress digits are stripped (get_phone_pure).
    - ``trans`` only gets the schwa mapping; its digits are kept.

    Returns a new DataFrame; the caller's frame is not modified.
    """
    columns = [
        "id", "audio_path", "text", "arpas", "trans",
        "phone_scores", "word_scores", "word_ids", "utterance_scores",
    ]
    # .copy(): pandas flags a column-selected frame as a possible copy of the
    # caller's frame, so assigning columns to it raises SettingWithCopyWarning.
    metadata = metadata[columns].copy()
    metadata["id"] = metadata["id"].apply(str)
    metadata = metadata.rename(columns={"arpas": "elsa_phone"})
    # Order matters: map "AH0" -> "AX" before the digits are stripped, otherwise
    # "AH0" would become "AH" and never be mapped.
    metadata["elsa_phone"] = metadata["elsa_phone"].apply(map_phone).apply(get_phone_pure)
    metadata["trans"] = metadata["trans"].apply(map_phone)

    return metadata

def load_jsonl(path):
    """Read a JSON Lines file into a DataFrame with one row per JSON object.

    Blank or whitespace-only lines are skipped instead of crashing ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as f:
        # json.loads tolerates surrounding whitespace, so no explicit strip is needed.
        records = [json.loads(line) for line in f if line.strip()]

    return pd.DataFrame(records)

import pickle


def load_gops(gop_paths):
    """Merge several GOP pickle files into one ``{utterance_id: gop}`` dict.

    Each file must unpickle to a dict. Files that cannot be opened or unpickled
    are skipped, as before, but now with a warning. KeyboardInterrupt and
    SystemExit are no longer swallowed.

    Raises:
        ValueError: if an utterance id appears in more than one file.
    """
    import warnings

    gops = {}
    for path in gop_paths:
        try:
            # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
            with open(path, "rb") as f:
                gop = pickle.load(f)
        except Exception as exc:
            warnings.warn(f"Skipping unreadable GOP file {path!r}: {exc}")
            continue

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        duplicates = gops.keys() & gop.keys()
        if duplicates:
            raise ValueError(
                f"{path}: {len(duplicates)} utterance id(s) already loaded, "
                f"e.g. {next(iter(duplicates))!r}"
            )

        gops.update(gop)

    return gops

In [None]:
# Dataset to process: pick a key of DATASETS. Each entry is
# (metadata JSONL, feature output dir, data dir whose sub-directories hold
# gop.pkl / ali.out).
DATASET = "train-12-v2"
DATASETS = {
    "train-12-v2": (
        "/data/codes/apa/train/prep_data/jsonl/train-data-type-12-v2.jsonl",
        "/data/codes/apa/train/exps/features/train/train-12-v2",
        "/data/codes/apa/train/data/train/train-12-v2",
    ),
    "test-in-short": (
        "/data/codes/apa/train/prep_data/jsonl/info_in_domain_short_sentence_testset.jsonl",
        "/data/codes/apa/train/exps/features/test/in-short",
        "/data/codes/apa/train/data/test/in-short",
    ),
    "test-out-short": (
        "/data/codes/apa/train/prep_data/jsonl/info_out_domain_short_sentence_testset.jsonl",
        "/data/codes/apa/train/exps/features/test/out-short",
        "/data/codes/apa/train/data/test/out-short",
    ),
    "train-10": (
        "/data/codes/apa/train/prep_data/jsonl/info_qt_10_trainset.jsonl",
        "/data/codes/apa/train/exps/features/train/train-10",
        "/data/codes/apa/train/data/train/train-type-10",
    ),
}
metadata_path, out_dir, data_dir = DATASETS[DATASET]

metadata = load_jsonl(metadata_path)
metadata = preprocess_metadata(metadata)
print(metadata.shape)
metadata.head(1)

In [None]:
import os

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(out_dir, exist_ok=True)

In [None]:
gop_path = f'{data_dir}/*/gop.pkl'
align_path = f'{data_dir}/*/ali.out'

# glob() order is filesystem-dependent; sort so loading order is reproducible.
alignment_paths = sorted(glob(align_path))
gop_paths = sorted(glob(gop_path))
# Fail early with a clear message instead of an opaque error further down.
assert alignment_paths, f"no alignment files match {align_path}"
assert gop_paths, f"no GOP files match {gop_path}"

In [None]:
def load_alignment(path):
    """Load one tab-separated alignment file with no header row.

    Column 1 becomes ``id`` and is kept as a string. Column 2 is JSON and is
    parsed into Python objects under ``alignment``.
    """
    frame = pd.read_csv(path, sep="\t", names=["id", "alignment"], dtype={"id": str})
    frame["alignment"] = [json.loads(raw) for raw in frame["alignment"]]
    return frame

def load_alignments(paths):
    """Concatenate several alignment files (each read by load_alignment).

    Returns a DataFrame with columns ``id`` and ``alignment`` and a fresh
    0..n-1 index.

    Raises:
        ValueError: if ``paths`` is empty.
    """
    paths = list(paths)
    if not paths:
        raise ValueError("load_alignments: no alignment files given (check data_dir / glob pattern)")

    # ignore_index replaces the old reset_index(inplace=True) + column reselection.
    alignments = pd.concat([load_alignment(path) for path in paths], ignore_index=True)
    return alignments[["id", "alignment"]]

alignments = load_alignments(alignment_paths)
gops = load_gops(gop_paths)

In [None]:
# Keep only utterances for which a GOP matrix was loaded. .copy() so the next cell
# can add columns to an independent frame (no SettingWithCopyWarning).
is_valid = alignments["id"].isin(list(gops))
print(alignments.shape)
alignments = alignments[is_valid].copy()
print(alignments.shape)

In [None]:
def extract_phonemes(alignments):
    """Return the base phone of every non-"SIL" alignment entry.

    Each entry's label (element 0) is cut at the first "_", which drops the
    position suffix (e.g. "AH0_B" -> "AH0"). Stress digits are then removed
    ("AH0" -> "AH").
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    return [
        re.sub(r"\d", "", entry[0].split("_")[0])
        for entry in alignments
        if entry[0] != "SIL"
    ]

def extract_durations(alignments, frame_shift=0.02):
    """Return frame count * ``frame_shift`` for every non-"SIL" entry, rounded to 4 places.

    Args:
        alignments: sequence of ``(label, start_frame, n_frames)`` entries.
        frame_shift: multiplier applied to each frame count. The default 0.02 is
            the value previously hard-coded here, so existing outputs are unchanged.

    NOTE(review): the WavLM cell repeats every encoder frame twice before indexing
    it with these alignment frames. If the encoder stride is 20 ms, alignment frames
    are 10 ms, and 0.02 yields twice the duration in seconds. Confirm the intended
    unit; changing the default changes the saved duration feature.
    """
    return [
        round(entry[2] * frame_shift, 4)
        for entry in alignments
        if entry[0] != "SIL"
    ]

def extract_relative_positions(alignments):
    """Collect the text after the last "_" in each non-"SIL" label.

    A label with no "_" is returned whole.
    """
    positions = []
    for entry in alignments:
        label = entry[0]
        if label == "SIL":
            continue
        positions.append(label.rsplit("_", 1)[-1])
    return positions


# Build a new frame with .assign instead of setting columns on `alignments`, which
# may be a boolean-filtered slice (SettingWithCopyWarning). Re-running is harmless.
alignments = alignments.assign(
    relative_positions=alignments["alignment"].apply(extract_relative_positions),
    prep_phone=alignments["alignment"].apply(extract_phonemes),
    duration=alignments["alignment"].apply(extract_durations),
)

In [None]:
# Attach alignment-derived columns; utterances missing from either side are dropped.
alignment_columns = ["id", "alignment", "prep_phone", "relative_positions", "duration"]
metadata = metadata.merge(alignments[alignment_columns], on="id", how="inner")

metadata.head(1)

In [None]:
def count_match(elsa, prep, scores, score_threshold=30):
    """Return 1 if the canonical (ELSA) phones agree with the aligner's phones, else 0.

    Args:
        elsa: canonical phone sequence.
        prep: phone sequence recovered from the forced alignment.
        scores: per-phone scores, indexed in step with ``elsa``.
        score_threshold: a position where the phones differ is tolerated only if
            its score is below this value. Presumably a low score marks a
            mispronounced phone, where a mismatch is expected.

    Sequences of different length never match. Every later feature (GOP rows,
    durations, phone ids) is indexed per phone. ``zip`` would silently truncate
    and let misaligned utterances through.
    """
    if len(elsa) != len(prep):
        return 0

    for index, (phone_1, phone_2) in enumerate(zip(elsa, prep)):
        if phone_1 != phone_2 and scores[index] >= score_threshold:
            return 0

    return 1

def _phones_agree(row):
    # Row-wise adapter so DataFrame.apply can call count_match.
    return count_match(
        elsa=row["elsa_phone"],
        prep=row["prep_phone"],
        scores=row["phone_scores"],
    )

is_matched = metadata.apply(_phones_agree, axis=1)
print(is_matched.sum())
print(is_matched.shape)

In [None]:
metadata = metadata[is_matched.astype(bool)]  # is_matched holds 0/1 per utterance

In [None]:
# Distribution of canonical phone counts per utterance; informs the MAX_LENGTH cutoff below.
lengths = metadata["elsa_phone"].apply(len)
ax = lengths.hist(bins=100)
ax.set(title="Canonical phones per utterance", xlabel="number of phones", ylabel="utterances");

In [None]:
MAX_LENGTH = 32  # per-phone arrays are padded to this length; utterances with >= MAX_LENGTH phones are dropped

print(metadata.shape)
# Recompute the mask here instead of reusing `lengths` from the histogram cell, so
# this cell stays correct if re-run. Use .copy() because later cells assign columns
# to the result.
metadata = metadata[metadata["elsa_phone"].apply(len) < MAX_LENGTH].copy()
print(metadata.shape)

In [None]:
# print(metadata["id"][0])
# for gop in gops[metadata["id"][0]]:
#     print(gop[0:5])

### Extract GOP (goodness of pronunciation) features

In [None]:
# Zero-pad each utterance's GOP matrix (one row per phone) to MAX_LENGTH rows and
# save everything as one float32 array of shape (num_utterances, MAX_LENGTH, gop_dim).
gop_features = []
for index in tqdm(metadata.index):
    _id = metadata["id"][index]
    # float32 matches what the previous tolist() -> torch.tensor round trip produced.
    gop = np.asarray(gops[_id], dtype=np.float32)
    if len(gop) > MAX_LENGTH:
        # A longer matrix would otherwise surface later as an opaque stack-shape error.
        raise ValueError(f"utterance {_id}: {len(gop)} GOP rows exceed MAX_LENGTH={MAX_LENGTH}")

    padded = np.zeros((MAX_LENGTH, gop.shape[1]), dtype=np.float32)
    padded[: len(gop)] = gop
    gop_features.append(padded)

gop_features = np.stack(gop_features, axis=0)
np.save(f'{out_dir}/gop.npy', gop_features)
gop_features = None  # release memory

### Extract relative positions

In [None]:
path = "/data/codes/apa/train/resources/relative2id.json"
# Use a context manager so the file handle is closed deterministically.
with open(path, "r", encoding="utf-8") as relative2id_file:
    relative2id = json.load(relative2id_file)

In [None]:
def convert_relative_position_to_id(relative_positions):
    """Translate a list of relative-position tags into their ids in ``relative2id``.

    Raises KeyError for a tag missing from the table.
    """
    return [relative2id[tag] for tag in relative_positions]

# Replace each utterance's tag list with its id list. Not idempotent: JSON object
# keys are strings, so re-running this cell on the id lists raises KeyError.
metadata["relative_positions"] = metadata["relative_positions"].apply(convert_relative_position_to_id)

In [None]:
# Right-pad each utterance's relative-position ids with the PAD id up to MAX_LENGTH,
# stack them row-wise and save.
pad_position_id = relative2id["PAD"]
relative_positions = torch.stack(
    [
        torch.tensor(position_ids + [pad_position_id] * (MAX_LENGTH - len(position_ids)))
        for position_ids in metadata["relative_positions"]
    ],
    dim=0,
).numpy()
np.save(f'{out_dir}/relative_positions.npy', relative_positions)
relative_positions = None

### Extract sentence scores

In [None]:
# Utterance-level scores, one row per utterance (not padded).
sentence_scores = torch.tensor(
    [scores.copy() for scores in metadata["utterance_scores"]]
).numpy()
print(sentence_scores.shape)
np.save(f'{out_dir}/sentence_scores.npy', sentence_scores)
sentence_scores = None

### Extract word scores

In [None]:
# Spread each word-level score over that word's phones. word_ids holds, per phone,
# an index into the utterance's word_scores. Then right-pad with -1 to MAX_LENGTH.
word_scores = []
for per_word, phone_to_word in zip(metadata["word_scores"], metadata["word_ids"]):
    per_phone = [per_word[wid] for wid in phone_to_word]
    per_phone.extend([-1] * (MAX_LENGTH - len(per_phone)))
    word_scores.append(torch.tensor(per_phone))

word_scores = torch.stack(word_scores, dim=0).numpy()
print(word_scores.shape)
np.save(f'{out_dir}/word_scores.npy', word_scores)
word_scores = None

### Extract word ids

In [None]:
# Per-phone word indices, right-padded with -1 to MAX_LENGTH.
word_ids = torch.stack(
    [torch.tensor(ids + [-1] * (MAX_LENGTH - len(ids))) for ids in metadata["word_ids"]],
    dim=0,
).numpy()
print(word_ids.shape)
np.save(f'{out_dir}/word_ids.npy', word_ids)
word_ids = None

### Extract duration features

In [None]:
# Per-phone durations (see extract_durations), right-padded with 0 to MAX_LENGTH.
durations = torch.stack(
    [
        torch.tensor(phone_durations + [0] * (MAX_LENGTH - len(phone_durations)))
        for phone_durations in metadata["duration"]
    ],
    dim=0,
).numpy()
np.save(f'{out_dir}/duration.npy', durations)
durations = None

### Extract phone scores

In [None]:
# Per-phone scores, right-padded with -1 to MAX_LENGTH.
phone_scores = torch.stack(
    [torch.tensor(scores + [-1] * (MAX_LENGTH - len(scores))) for scores in metadata["phone_scores"]],
    dim=0,
).numpy()
np.save(f'{out_dir}/phone_scores.npy', phone_scores)
phone_scores = None

### Extract phone ids

In [None]:
# Phone symbol -> integer id table; the next cell also reads its "PAD" entry.
phone_dict_path = "/data/codes/apa/train/resources/phone_dict.json"
with open(phone_dict_path, encoding="utf-8") as f:
    phone_dict = json.loads(f.read())

In [None]:
# Map each utterance's canonical phones to phone_dict ids, right-pad with the PAD id
# to MAX_LENGTH, and save as one stacked array.
phone_ids = []

pad_token_id = phone_dict["PAD"]
for phones in metadata["elsa_phone"]:
    # Digits are stripped again defensively. Raw string because "\d" in a plain
    # literal is an invalid escape sequence (SyntaxWarning since Python 3.12).
    ids = [phone_dict[re.sub(r"\d", "", phone)] for phone in phones]
    ids += [pad_token_id] * (MAX_LENGTH - len(ids))
    phone_ids.append(torch.tensor(ids))

phone_ids = torch.stack(phone_ids, dim=0)
phone_ids = phone_ids.numpy()
np.save(f'{out_dir}/phone_ids.npy', phone_ids)
phone_ids = None

### Extract WavLM features

In [None]:
%cd /data/codes/apa/train
import torch
from src.models.wavlm_model import WavLM, WavLMConfig
from tqdm import tqdm
import librosa
import pandas as pd
import json

In [None]:
pretrained_path = "/data/codes/apa/train/exps/ckpts/wavlm-base+.pt"
# Load onto the CPU first so a checkpoint saved from a GPU device that does not
# exist here still loads; load_state_dict copies the weights into the CUDA model.
# torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(pretrained_path, map_location="cpu")

config = WavLMConfig(checkpoint['cfg'])
model = WavLM(config).eval().cuda()
model.load_state_dict(checkpoint['model'])

In [None]:
def extract_feature(alignment, features):
    """Mean-pool frame-level features over each non-"SIL" alignment segment.

    Args:
        alignment: non-empty sequence of ``(phoneme, start_frame, duration)``
            triples in frames. Entries labelled "SIL" are skipped.
        features: 2-D tensor whose first axis indexes the same frames as
            ``alignment``. Extra trailing frames are ignored. If there are fewer
            frames than the alignment covers, the matmul below raises.

    Returns:
        ``(pooled, phonemes)``. ``pooled`` is a CPU tensor with one row per
        non-SIL segment (the mean of its frames). ``phonemes`` lists those
        segments' labels in alignment order.
    """
    n_frames = alignment[-1][1] + alignment[-1][2]

    # Frame -> segment id. Frames not covered by any non-SIL segment go to an extra
    # "silence" bucket whose id is always one past the last segment. (Previously
    # it was max()+1, which pointed at no existing class when there were no
    # uncovered frames, so the final [:-1] dropped the last real phone.)
    phonemes = []
    frame_to_segment = torch.full((n_frames,), -1, dtype=torch.long)
    for phoneme, start_frame, duration in alignment:
        if phoneme == "SIL":
            continue
        frame_to_segment[start_frame:start_frame + duration] = len(phonemes)
        phonemes.append(phoneme)

    silence_id = len(phonemes)
    frame_to_segment[frame_to_segment == -1] = silence_id

    one_hot = torch.nn.functional.one_hot(frame_to_segment, num_classes=silence_id + 1)
    # Follow the features' device/dtype instead of hard-coding .cuda().
    one_hot = one_hot.to(device=features.device, dtype=features.dtype)
    # Normalise each column so a segment's weights sum to 1 (its frame mean).
    # clamp avoids 0/0 for an empty segment or an empty silence bucket.
    weights = one_hot / one_hot.sum(0, keepdim=True).clamp(min=1)

    features = features[:n_frames]
    pooled = torch.matmul(weights.transpose(0, 1), features)

    # Last row is the silence bucket.
    return pooled[:-1].cpu(), phonemes

In [None]:
# Encode each utterance with WavLM, mean-pool the frames of every non-SIL phone
# (extract_feature), zero-pad to MAX_LENGTH phones and save as one array.
wavlm_features = []
for index in tqdm(metadata.index):
    audio_path = metadata["audio_path"][index]
    alignment = metadata["alignment"][index]

    wav, sr = librosa.load(audio_path, sr=16000)

    input_values = torch.from_numpy(wav).unsqueeze(0).cuda()
    with torch.no_grad():
        features = model.extract_features(input_values)[0]

    # Repeat every encoder frame twice so encoder frames line up with the alignment
    # frame indices used by extract_feature. Use a separate name: the old code reused
    # `index` here, clobbering the row label needed for the report below.
    frame_index = torch.arange(features.shape[1]).unsqueeze(-1)
    expanded_index = frame_index.expand((-1, 2)).flatten()
    features = features[0][expanded_index]

    features, phonemes = extract_feature(alignment, features)
    if len(features) != len(phonemes):
        print(metadata["id"][index])

    # Pad with the model's own hidden size instead of a hard-coded 768.
    padding = torch.zeros(MAX_LENGTH - len(features), features.shape[1], dtype=features.dtype)
    features = torch.cat([features, padding], dim=0)
    wavlm_features.append(features.unsqueeze(0).numpy())

# np.row_stack is deprecated since NumPy 2.0; concatenating along axis 0 is equivalent.
wavlm_features = np.concatenate(wavlm_features, axis=0)
np.save(f'{out_dir}/wavlm_features.npy', wavlm_features)
wavlm_features = None