In [None]:
%cd /data/codes/prep_ps_pykaldi/
import pandas as pd
import numpy as np
import pickle
import json
import re

In [None]:
def load_jsonl(path):
    """Read a JSON Lines file into a list of parsed records.

    Blank / whitespace-only lines (e.g. separator lines or extra newlines at the end of
    the file) are skipped instead of making ``json.loads`` raise.

    Args:
        path: Path of the ``.jsonl`` file, read as UTF-8.

    Returns:
        One parsed JSON value per non-blank line, in file order.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Iterate the file lazily; json.loads tolerates the trailing newline itself.
        return [json.loads(line) for line in f if line.strip()]

path = "prep_data/info_in_domain_short_sentence_testset.jsonl"
# One row per JSON record in the test-set manifest.
metadata = pd.DataFrame(load_jsonl(path))
metadata.head()

In [None]:
# Merged GOP dump; the extract_* helpers below look entries up as data[str(id)].
# NOTE: pickle.load can execute arbitrary code -- only load files this pipeline produced.
with open('/data/codes/prep_ps_pykaldi/exp/sm/test/merged_gop.pkl', 'rb') as f:
    data = pickle.load(f)

# Keep only utterances that have an entry in the GOP dump (membership over the dict's keys).
metadata = metadata.loc[metadata["id"].isin(data.keys())]
metadata.head(2)

In [None]:
def extract_gop_feature(id, gop_data=None):
    """Stack the per-phone GOP feature vectors of one utterance, skipping ``"SIL"`` phones.

    Args:
        id: Utterance id; looked up as ``str(id)``.
        gop_data: Mapping from id to a sample dict with parallel ``"gopt"`` (feature
            vectors) and ``"phones"[0]`` (phone labels) sequences. Defaults to the
            notebook-global ``data`` GOP dump, so existing ``extract_gop_feature(id)``
            calls are unchanged.

    Returns:
        ``np.stack`` of the kept feature vectors, i.e. one row per non-SIL phone
        (assuming each GOP entry is a 1-D vector).

    Raises:
        ValueError: if the utterance has no non-SIL phones (nothing to stack).
    """
    if gop_data is None:
        gop_data = data
    sample = gop_data[str(id)]
    # zip pairs each feature vector with its phone label; silence positions are dropped.
    features = [
        np.asarray(feature)
        for feature, phoneme in zip(sample["gopt"], sample["phones"][0])
        if phoneme != "SIL"
    ]
    if not features:
        raise ValueError(f"utterance {id!r} has no non-SIL phones to extract GOP features from")
    return np.stack(features)

def extract_phonemes(id, gop_data=None):
    """Return the phone labels of one utterance from the GOP dump, with ``"SIL"`` removed.

    Each label is cut at the first ``"_"`` and stripped of digits, e.g. ``"AH1_B"``
    becomes ``"AH"``.

    Args:
        id: Utterance id; looked up as ``str(id)``.
        gop_data: Mapping from id to a sample dict whose ``"phones"[0]`` is the phone
            sequence. Defaults to the notebook-global ``data`` GOP dump.

    Returns:
        List of cleaned phone labels, in order.
    """
    if gop_data is None:
        gop_data = data
    sample = gop_data[str(id)]
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and a SyntaxWarning since Python 3.12).
    return [
        re.sub(r"\d", "", phoneme.split("_")[0])
        for phoneme in sample["phones"][0]
        if phoneme != "SIL"
    ]

# Per utterance: stacked GOP features and the matching phone labels (silence removed).
metadata["features"] = metadata["id"].apply(extract_gop_feature)
metadata["kaldi_phoneme"] = metadata["id"].apply(extract_phonemes)
metadata.head(2)

In [None]:
align_path = "/data/codes/prep_ps_pykaldi/exp/sm/test/merged_align.out"
# Headerless TSV: <utterance id> \t <JSON list of [phone, start, duration] triples>.
align_df = pd.read_csv(align_path, sep="\t", header=None, names=["id", "alignment"])

def extract_duration(alignment, frame_shift=0.02):
    """Return the duration of every non-silence phone in an alignment.

    Args:
        alignment: JSON string encoding a list of ``[phone, start, duration]`` triples.
            The duration is multiplied by ``frame_shift``, so it is presumably a frame
            count (TODO confirm against the aligner output).
        frame_shift: Multiplier applied to each duration; defaults to the 0.02 that was
            previously hard-coded here, so existing calls are unchanged.

    Returns:
        ``round(duration * frame_shift, 4)`` for each phone, with ``"SIL"`` entries skipped.
    """
    return [
        round(n_frames * frame_shift, 4)
        for phoneme, _start, n_frames in json.loads(alignment)
        if phoneme != "SIL"
    ]

def extract_align_phonemes(alignment):
    """Return the non-silence phone labels of an alignment, cut at the first ``"_"``.

    This used to be a second ``extract_phonemes`` definition. It silently shadowed the
    GOP-based ``extract_phonemes(id)`` from the earlier cell, so re-running that cell
    crashed. It has a distinct name now.

    Args:
        alignment: JSON string encoding a list of ``[phone, start, duration]`` triples.

    Returns:
        ``phone.split("_")[0]`` for each triple whose phone is not ``"SIL"``.
        Unlike the GOP-based helper, digits are NOT stripped here.
    """
    return [
        phoneme.split("_")[0]
        for phoneme, _start, _duration in json.loads(alignment)
        if phoneme != "SIL"
    ]

align_df["durations"] = align_df["alignment"].apply(extract_duration)
align_df["phonemes"] = align_df["alignment"].apply(extract_align_phonemes)
# Ids are compared as strings when merging with metadata.
align_df["id"] = align_df["id"].apply(str)
align_df.head()

In [None]:
# Left join on id: every metadata row is kept; rows without an alignment get NaN durations.
metadata = metadata.merge(align_df[["id", "durations"]], on="id", how="left")

In [None]:
from sklearn.preprocessing import StandardScaler
import pickle

# Standardise GOP features column-wise: one row per non-SIL phone across all utterances.
# NOTE(review): this fits on every row of `metadata`, including the rows later used for
# validation (positions >= 15000 in the split cell) -- consider fitting on training rows only.
features = np.concatenate(metadata["features"].to_list())
scaler = StandardScaler().fit(features)

with open('resources/scaler.pkl', 'wb') as f:
    pickle.dump(scaler, f)

In [None]:
# phone_dict = []
# for word in metadata["arpas"].to_list():
#     phone_dict += word
# phone_dict = list(set(phone_dict))
# phone_dict.sort()

# phone_dict.insert(0, "PAD")
# vocab= {}
# for key in phone_dict:
#     key = re.sub("\d", "", key)
#     if key not in vocab:
#         vocab[key] = len(vocab)
# phone_dict = {re.sub("\d", "", key):value for value, key in enumerate(phone_dict)}

# with open("resources/phone_dict.json", "w", encoding="utf-8") as f:
#     json_obj = json.dumps(vocab, indent=4, ensure_ascii=False)
#     f.write(json_obj)

# Phone symbol -> integer id vocabulary (presumably written by the commented-out code above).
with open("resources/phone_dict.json", "r", encoding="utf-8") as f:
    phone_dict = json.loads(f.read())

In [None]:
# Map each phone in "arpas" to its vocabulary id after removing digits (e.g. "AH0" -> "AH").
# Raw string: "\d" in a plain literal is an invalid escape sequence
# (SyntaxWarning since Python 3.12).
metadata["phones"] = metadata["arpas"].apply(
    lambda word: [phone_dict[re.sub(r"\d", "", phone)] for phone in word]
)

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch

class PrepDataset(Dataset):
    """Torch dataset over an utterance-level ``metadata`` DataFrame.

    Each item holds the utterance's features (transformed by a pickled scaler) with
    the per-phone duration appended as a last column, the phone ids, and the phone
    scores divided by 50. ``collate_fn`` pads/truncates every field to a fixed length
    so items can be batched.

    Args:
        metadata: DataFrame with ``features``, ``phones``, ``phone_scores`` and
            ``durations`` columns, one row per utterance. Rows are accessed by
            position, so any index is fine.
        scaler_path: Pickled fitted scaler applied to ``features``. Defaults to the
            path the scaler cell writes.
    """
    def __init__(self, metadata, scaler_path='resources/scaler.pkl'):
        self.metadata = metadata
        # `with` closes the file (the previous `pickle.load(open(...))` leaked the handle).
        # pickle can execute arbitrary code: only load scalers this project produced.
        with open(scaler_path, 'rb') as f:
            self.scaler = pickle.load(f)

    def __len__(self):
        return len(self.metadata)

    def parse_data(self, features, phones, phone_scores, durations):
        """Convert one utterance to tensors; durations become an extra feature column."""
        features = torch.tensor(features)
        durations = torch.tensor(durations)
        phones = torch.tensor(phones)
        phone_scores = torch.tensor(phone_scores) / 50

        features = torch.concat([features, durations.unsqueeze(-1)], dim=-1)
        return {
            "features": features,
            "phones": phones,
            "phone_scores": phone_scores
        }

    def __getitem__(self, index):
        # .iloc = positional access. Plain `series[index]` is label-based and raises
        # KeyError (or returns the wrong row) when the index is not 0..n-1.
        frame = self.metadata
        features = self.scaler.transform(frame["features"].iloc[index])
        return self.parse_data(
            features=features,
            phones=frame["phones"].iloc[index],
            phone_scores=frame["phone_scores"].iloc[index],
            durations=frame["durations"].iloc[index],
        )

    def pad_1d(self, input_ids, pad_value=-1, max_length=64):
        """Pad/truncate 1-D tensors to ``max_length`` and stack them.

        Padding uses each input's own dtype and device, so integer phone ids stay
        integers (they used to be promoted to float32).
        ``max_length=None`` pads to the longest input. The input list is not modified.

        Returns:
            ``{"input_ids": (batch, max_length), "attention_mask": (batch, max_length)}``,
            where the mask is ``input_ids != pad_value``.
        """
        if max_length is None:
            max_length = max(len(sample) for sample in input_ids)

        padded, attention_masks = [], []
        for ids in input_ids:
            ids = ids[:max_length]
            n_pad = max_length - ids.size(0)
            if n_pad > 0:
                filler = torch.full((n_pad,), pad_value, dtype=ids.dtype, device=ids.device)
                ids = torch.cat((ids, filler))
            padded.append(ids)
            attention_masks.append(ids != pad_value)

        return {
            "input_ids": torch.vstack(padded),
            "attention_mask": torch.vstack(attention_masks)
        }

    def pad_2d(self, inputs, pad_value=0, max_length=64):
        """Pad/truncate 2-D ``(rows, dim)`` tensors along rows and stack them.

        The padding width and dtype come from each input (the feature width used to be
        hard-coded to 83). ``max_length=None`` pads to the longest input.

        Returns:
            Tensor of shape ``(batch, max_length, dim)``.
        """
        if max_length is None:
            max_length = max(sample.size(0) for sample in inputs)

        padded = []
        for sample in inputs:
            sample = sample[:max_length]
            n_pad = max_length - sample.size(0)
            if n_pad > 0:
                filler = torch.full(
                    (n_pad, sample.size(1)), pad_value, dtype=sample.dtype, device=sample.device)
                sample = torch.cat((sample, filler))
            padded.append(sample)
        return torch.stack(padded, dim=0)

    def collate_fn(self, batch):
        """Batch items: features padded with 0, phone ids with 0, phone scores with -1."""
        features = [sample["features"] for sample in batch]
        phones = [sample["phones"] for sample in batch]
        phone_scores = [sample["phone_scores"] for sample in batch]

        return {
            "features": self.pad_2d(features, pad_value=0),
            "phones": self.pad_1d(phones, pad_value=0)["input_ids"],
            # -1 marks padding; the training loss masks on phone_scores >= 0.
            "phone_scores": self.pad_1d(phone_scores, pad_value=-1)["input_ids"],
        }

dataset = PrepDataset(metadata)
dataloader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

# Smoke test: take only the first batch and print the padded tensor shapes.
for batch in dataloader:
    features, phones, phone_scores = batch["features"], batch["phones"], batch["phone_scores"]
    print(features.shape, phones.shape, phone_scores.shape, sep="\n")
    break

In [None]:
import math
import warnings
import torch
import torch.nn as nn
import numpy as np

def get_sinusoid_encoding(n_position, d_hid):
    """Build a fixed sinusoidal position table.

    Column ``j`` of row ``pos`` holds ``pos / 10000 ** (2 * (j // 2) / d_hid)``. The sine
    is applied to even columns and the cosine to odd columns, so each (even, odd)
    column pair shares one frequency.

    Returns:
        float32 tensor of shape ``(1, n_position, d_hid)``.
    """
    positions = np.arange(n_position).reshape(-1, 1)
    pair_index = np.arange(d_hid) // 2
    angles = positions / np.power(10000, 2 * pair_index / d_hid)

    angles[:, 0::2] = np.sin(angles[:, 0::2])
    angles[:, 1::2] = np.cos(angles[:, 1::2])
    return torch.FloatTensor(angles).unsqueeze(0)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with N(mean, std**2) samples truncated to [a, b]; returns it."""
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)

class Attention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, dim)`` input.

    There is no padding mask: every token attends to every other token.
    ``qk_scale`` overrides the default ``head_dim ** -0.5``. A falsy value (None / 0)
    falls back to the default.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5

        # Registration order (qkv, attn_drop, proj, proj_drop) fixes the init RNG order.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, n_tokens, channels = x.shape
        head_dim = channels // self.num_heads
        # -> (3, batch, heads, tokens, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(batch, n_tokens, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        context = (weights @ v).transpose(1, 2).reshape(batch, n_tokens, channels)
        return self.proj_drop(self.proj(context))

class Mlp(nn.Module):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when falsy.
    A single Dropout module (rate ``drop``) is reused at both positions.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm transformer layer: ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``.

    NOTE(review): ``drop_path`` is accepted but ignored; the residual branches always
    go through ``nn.Identity`` (no stochastic depth).
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Same registration order as before, so seeded initialisation is unchanged.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        after_attn = x + self.drop_path(self.attn(self.norm1(x)))
        return after_attn + self.drop_path(self.mlp(self.norm2(after_attn)))

class GOPT(nn.Module):
    """Transformer that scores phones/utterances from per-phone feature vectors.

    Each position fuses its (optionally projected) feature vector with a linear
    embedding of the one-hot phone id. A learnable CLS token is then prepended and
    learnable positional embeddings are added. The sequence runs through ``depth``
    ``Block`` layers. Three LayerNorm+Linear heads read out:
    - one value from the CLS token (``mlp_head_utt``);
    - one value per phone position (``mlp_head_phn``);
    - another value per phone position (``mlp_head_word``).

    Args:
        embed_dim: Transformer width.
        num_heads: Attention heads per block.
        depth: Number of ``Block`` layers.
        input_dim: Width of the incoming feature vectors.
        max_length: Longest phone sequence the positional table covers
            (``max_length + 1`` rows, counting the CLS token).
        num_phone: Number of one-hot classes used for ``phn + 1``.
    """
    def __init__(self, embed_dim, num_heads, depth, input_dim=84, max_length=50, num_phone=40):
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.blocks = nn.ModuleList(Block(dim=embed_dim, num_heads=num_heads) for _ in range(depth))

        # Creation / init order is kept as-is so seeded runs reproduce the same weights.
        self.pos_embed = nn.Parameter(torch.zeros(1, max_length + 1, self.embed_dim))
        trunc_normal_(self.pos_embed, std=.02)

        self.in_proj = nn.Linear(self.input_dim, embed_dim)
        # Fuses [features ; phone embedding] (2 * embed_dim wide) back to embed_dim.
        self.linear = nn.Linear(2 * embed_dim, embed_dim)
        self.mlp_head_phn = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        self.mlp_head_word = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))

        self.num_phone = num_phone
        self.phn_proj = nn.Linear(num_phone, embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mlp_head_utt = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        trunc_normal_(self.cls_token, std=.02)

    def forward(self, x, phn):
        """Return ``(utterance, phone, word)`` head outputs.

        ``x`` is assumed ``(batch, seq, input_dim)`` and ``phn`` ``(batch, seq)`` integer ids
        (TODO confirm). There is no padding mask, so padded positions take part in attention.
        """
        batch = x.shape[0]
        # Ids are shifted by +1 before one-hot encoding (so -1 maps to class 0).
        # NOTE(review): confirm this shift matches the padding value used by the collate fn.
        one_hot = torch.nn.functional.one_hot(phn.long() + 1, num_classes=self.num_phone).float()
        phn_embed = self.phn_proj(one_hot)

        if self.embed_dim != self.input_dim:
            x = self.in_proj(x)

        tokens = self.linear(torch.cat([x, phn_embed], dim=-1))
        tokens = torch.cat((self.cls_token.expand(batch, -1, -1), tokens), dim=1)
        tokens = tokens + self.pos_embed[:, :tokens.shape[1], :]

        for blk in self.blocks:
            tokens = blk(tokens)

        phone_tokens = tokens[:, 1:]
        return (
            self.mlp_head_utt(tokens[:, 0]),
            self.mlp_head_phn(phone_tokens),
            self.mlp_head_word(phone_tokens),
        )

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

# Positional split: the first `index` utterances are used for training, the rest for validation.
index = 15000
# Explicit .iloc slices + reset_index(drop=True) build fresh frames with a 0..n-1 index,
# instead of mutating slices of `metadata` in place. The old inplace reset also added a
# stray "index" column.
train_df = metadata.iloc[:index].reset_index(drop=True)
test_df = metadata.iloc[index:].reset_index(drop=True)

trainset = PrepDataset(train_df)
trainloader = DataLoader(
    dataset=trainset,
    batch_size=8,
    collate_fn=trainset.collate_fn,
    shuffle=True,
    drop_last=False
)

testset = PrepDataset(test_df)
testloader = DataLoader(
    dataset=testset,
    batch_size=1,
    collate_fn=testset.collate_fn,
    shuffle=False,
    drop_last=True  # no-op with batch_size=1
)

In [None]:
from torch.optim.lr_scheduler import MultiStepLR
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global torch RNG so weight initialisation is reproducible. The shuffling
# train DataLoader (built without its own generator) also draws from this RNG by default.
SEED = 42
torch.manual_seed(SEED)

# input_dim must equal the collated feature width (GOP columns + 1 appended duration column).
# GOPT one-hot encodes `phones + 1`, so num_phone must exceed the largest phone id + 1.
gopt_model = GOPT(
    embed_dim=32, num_heads=1,
    depth=3, input_dim=83,
    max_length=128, num_phone=62).to(device)

trainables = [p for p in gopt_model.parameters() if p.requires_grad]

lr = 3e-4
optimizer = torch.optim.Adam(
    trainables, lr, weight_decay=5e-7, betas=(0.95, 0.999))

# Halves the LR at scheduler steps 10, 15, ..., 95.
# NOTE(review): scheduler.step() is commented out in the training loop, so this never fires.
scheduler = MultiStepLR(
    optimizer, list(range(10, 100, 5)), gamma=0.5, last_epoch=-1)

loss_fn = nn.MSELoss()


In [None]:
def valid_phn(audio_output, target):
    """Score phone-level predictions against targets, ignoring padded positions.

    Args:
        audio_output: ``(batch, seq)`` predictions, as a numpy array or a CPU tensor
            that does not require grad.
        target: Targets with the same shape; negative entries mark padding and are skipped.

    Returns:
        ``(mae, pcc)``: the mean ABSOLUTE error and the Pearson correlation over the
        valid (``target >= 0``) positions. The first value is an MAE, not a squared error.
    """
    predictions = np.asarray(audio_output)
    targets = np.asarray(target)

    # Boolean-mask selection replaces the per-element Python double loop. It keeps the
    # same row-major order, so the correlation is unchanged.
    valid = targets >= 0
    valid_token_pred = predictions[valid]
    valid_token_target = targets[valid]

    valid_token_mae = np.mean(np.abs(valid_token_target - valid_token_pred))
    corr = np.corrcoef(valid_token_pred, valid_token_target)[0, 1]
    return valid_token_mae, corr


In [None]:
from tqdm import tqdm

global_step = 0
for epoch in range(50):
    gopt_model.train()
    train_tqdm = tqdm(trainloader, "Training")

    for batch in train_tqdm:
        optimizer.zero_grad()

        features = batch["features"].to(device)
        phones = batch["phones"].to(device)
        phone_labels = batch["phone_scores"].to(device)

        # warm_up_step = 100
        # if global_step <= warm_up_step and global_step % 5 == 0:
        #     warm_lr = (global_step / warm_up_step) * lr
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = warm_lr

        utterance_preds, phone_preds, word_preds = gopt_model(x=features.float(), phn=phones.long())

        # Padded positions carry label -1. Zero both prediction and label there, then
        # rescale the full-tensor MSE mean so it equals the mean over valid positions only.
        mask = phone_labels >= 0
        phone_preds = phone_preds.squeeze(2) * mask
        phone_labels = phone_labels * mask

        loss_phn = loss_fn(phone_preds, phone_labels)
        loss_phn = loss_phn * (mask.shape[0] * mask.shape[1]) / torch.sum(mask)

        loss_phn.backward()
        optimizer.step()

        # NOTE(review): the LR scheduler is never stepped (kept disabled as before).
        # if global_step > warm_up_step:
        #     scheduler.step()

        global_step += 1
        train_tqdm.set_postfix(loss_phn=loss_phn.item())

    # Validation: eval mode (disables dropout) and no_grad (no autograd graph is built).
    gopt_model.eval()
    A_phn, A_phn_target = [], []
    with torch.no_grad():
        for batch in testloader:
            features = batch["features"].to(device)
            phones = batch["phones"].to(device)
            phone_labels = batch["phone_scores"].to(device)

            utterance_preds, phone_preds, word_preds = gopt_model(x=features.float(), phn=phones.long())

            A_phn.append(phone_preds.cpu()[:, :, 0])
            A_phn_target.append(phone_labels.cpu())

    A_phn, A_phn_target = torch.vstack(A_phn), torch.vstack(A_phn_target)
    # valid_phn returns a mean ABSOLUTE error; the variable name `phn_mse` is historical.
    phn_mse, phn_corr = valid_phn(A_phn, A_phn_target)
    print(f"### Validation result: MAE={phn_mse} PCC={phn_corr}")