In [None]:
%cd /data/codes/prep_ps_pykaldi/
import pandas as pd
import numpy as np
import pickle
import json
import re

In [None]:
def load_data(data_dir):
    """Load the five feature arrays saved as ``.npy`` files inside ``data_dir``.

    Files are read with ``np.load`` in this order: ``phone_ids.npy``,
    ``phone_scores.npy``, ``duration.npy``, ``gop.npy``, ``wavlm_features.npy``.

    Returns:
        Tuple ``(phone_ids, phone_scores, durations, gops, wavlm_features)``.
    """
    array_paths = (
        f'{data_dir}/phone_ids.npy',
        f'{data_dir}/phone_scores.npy',
        f'{data_dir}/duration.npy',
        f'{data_dir}/gop.npy',
        f'{data_dir}/wavlm_features.npy',
    )
    return tuple(np.load(path) for path in array_paths)


In [None]:
from torch.utils.data import Dataset, DataLoader
import torch

class PrepDataset(Dataset):
    """Map-style dataset over pre-extracted per-phone features.

    All five constructor arguments are indexed by sample along their first axis;
    the dataset length is taken from ``phone_ids``.
    """

    def __init__(self, phone_ids, phone_scores, durations, gops, wavlm_features):
        self.phone_ids = phone_ids
        self.phone_scores = phone_scores
        self.gops = gops
        self.durations = durations
        self.wavlm_features = wavlm_features

    def __len__(self):
        return self.phone_ids.shape[0]

    def parse_data(self, phone_ids, phone_scores, gops, durations, wavlm_features):
        """Turn one sample into tensors.

        Returns a dict with
          * ``"features"``: ``gops``, ``durations`` (appended as one trailing
            column) and ``wavlm_features`` concatenated along the last axis;
          * ``"phone_ids"``: ``phone_ids`` as a tensor;
          * ``"phone_scores"``: a float copy of ``phone_scores`` where every entry
            other than the ``-1`` marker has been divided by 50.
        """
        scores = torch.tensor(phone_scores).float().clone()
        is_scored = scores != -1
        scores[is_scored] /= 50

        feature_parts = (
            torch.tensor(gops),
            torch.tensor(durations)[..., None],
            torch.tensor(wavlm_features),
        )
        return {
            "features": torch.cat(feature_parts, dim=-1),
            "phone_ids": torch.tensor(phone_ids),
            "phone_scores": scores,
        }

    def __getitem__(self, index):
        return self.parse_data(
            wavlm_features=self.wavlm_features[index],
            durations=self.durations[index],
            gops=self.gops[index],
            phone_scores=self.phone_scores[index],
            phone_ids=self.phone_ids[index],
        )

# Smoke test: build the test split and inspect the shapes of a single batch.
data_dir = "/data/codes/prep_ps_pykaldi/exp/sm/test"

phone_ids, phone_scores, durations, gops, wavlm_features = load_data(data_dir)
dataset_v1 = PrepDataset(phone_ids, phone_scores, durations, gops, wavlm_features)
dataloader = DataLoader(dataset_v1, batch_size=8)

# Take one batch under its own name so the loaded numpy arrays are not rebound
# to batch tensors. Raises StopIteration if the split is empty, which is the
# desired outcome for a smoke test.
sample_batch = next(iter(dataloader))
print(sample_batch["features"].shape)
print(sample_batch["phone_ids"].shape)
print(sample_batch["phone_scores"].shape)

# Release every reference to the smoke-test data, including the raw arrays,
# which would otherwise stay alive as notebook globals.
del dataset_v1, dataloader, sample_batch
del phone_ids, phone_scores, durations, gops, wavlm_features

In [None]:
import math
import warnings
import torch
import torch.nn as nn
import numpy as np

def get_sinusoid_encoding(n_position, d_hid):
    """Build a fixed sine/cosine positional-encoding table.

    Entry ``[pos, j]`` is ``pos / 10000 ** (2 * (j // 2) / d_hid)`` passed through
    ``sin`` for even ``j`` and ``cos`` for odd ``j``.

    Returns:
        ``torch.FloatTensor`` of shape ``(1, n_position, d_hid)``.
    """
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid
    denominators = np.power(10000, exponents)
    angles = np.arange(n_position)[:, None] / denominators[None, :]

    angles[:, 0::2] = np.sin(angles[:, 0::2])  # even dims: sine
    angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd dims: cosine
    return torch.FloatTensor(angles).unsqueeze(0)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

class Attention(nn.Module):
    """Multi-head self-attention over all positions (no padding mask).

    Args:
        dim: input/output feature size; must be divisible by ``num_heads``.
        num_heads: number of heads; each head has ``dim // num_heads`` channels.
        qkv_bias: add a bias to the fused query/key/value projection.
        qk_scale: scale applied to ``q @ k^T``; ``None`` means ``head_dim ** -0.5``.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error in forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Compare against None so that an explicit qk_scale of 0.0 is honoured
        # instead of being silently replaced by the default.
        self.scale = qk_scale if qk_scale is not None else head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; returns the same shape."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale   # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):
    """Position-wise feed-forward network: ``fc1 -> act -> drop -> fc2 -> drop``.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    falsy. A single dropout module is shared by both dropout steps.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features if hidden_features else in_features
        output_dim = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.
    ``drop_path`` is accepted for signature compatibility but is not used: both
    residual branches always pass through ``nn.Identity``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth is not implemented; ``drop_path`` is ignored.
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)

class GOPT(nn.Module):
    """Transformer that produces utterance-, phone- and word-level scores.

    Per-position input features are projected to ``embed_dim`` (only when
    ``embed_dim != input_dim``), concatenated with a learned embedding of the
    one-hot phone id, fused by a linear layer, prefixed with a learned CLS token,
    offset by a learned positional embedding and passed through ``depth`` blocks.

    Args:
        embed_dim: model width.
        num_heads: attention heads per block.
        depth: number of transformer blocks.
        input_dim: size of the per-position input features.
        max_length: maximum number of phone positions per sequence (the
            positional embedding has one extra slot for the CLS token).
        num_phone: number of one-hot classes for the phone ids after the +1
            shift applied in ``forward``.
    """

    def __init__(self, embed_dim, num_heads, depth, input_dim=84, max_length=50, num_phone=40):
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.blocks = nn.ModuleList(
            [
                Block(dim=embed_dim, num_heads=num_heads) 
                for i in range(depth)
                ]
            )

        # +1 slot for the CLS token prepended in forward().
        self.pos_embed = nn.Parameter(torch.zeros(1, max_length+1, self.embed_dim))
        trunc_normal_(self.pos_embed, std=.02)

        self.in_proj = nn.Linear(self.input_dim, embed_dim)
        # Fuses [projected features ; phone embedding] back to embed_dim.
        self.linear = nn.Linear(embed_dim * 2, embed_dim)
        self.mlp_head_phn = nn.Sequential(
            nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))

        self.mlp_head_word= nn.Sequential(
            nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))

        self.num_phone = num_phone
        self.phn_proj = nn.Linear(num_phone, embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mlp_head_utt = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))

        trunc_normal_(self.cls_token, std=.02)

    def forward(self, x, phn):
        """Score a batch.

        Args:
            x: features of shape ``(B, T, input_dim)``.
            phn: phone ids of shape ``(B, T)``. They are shifted by +1 before
                one-hot encoding, so ``-1`` maps to class 0 (presumably the
                padding id — TODO confirm) and ids must lie in
                ``[-1, num_phone - 2]``.

        Returns:
            ``(u, p, w)``: utterance score ``(B, 1)`` read from the CLS position,
            and phone / word scores ``(B, T, 1)`` from the remaining positions.

        Raises:
            ValueError: if ``T`` exceeds the ``max_length`` the model was built with.
        """
        B, T = x.shape[0], x.shape[1]
        max_positions = self.pos_embed.shape[1] - 1  # one slot is reserved for CLS
        if T > max_positions:
            # Without this check the pos_embed addition below fails with an
            # unhelpful broadcasting error.
            raise ValueError(
                f"sequence length {T} exceeds max_length={max_positions} of the positional embedding")

        phn_one_hot = torch.nn.functional.one_hot(phn.long()+1, num_classes=self.num_phone).float()
        phn_embed = self.phn_proj(phn_one_hot)

        if self.embed_dim != self.input_dim:
            x = self.in_proj(x)

        x = torch.cat([x, phn_embed], dim=-1)
        x = self.linear(x)

        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_embed[:,:x.shape[1],:]

        for blk in self.blocks:
            x = blk(x)
        u = self.mlp_head_utt(x[:, 0])
        p = self.mlp_head_phn(x[:, 1:])
        w = self.mlp_head_word(x[:, 1:])
        return u, p, w

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

# Both splits live under the same experiment directory.
EXP_DIR = "/data/codes/prep_ps_pykaldi/exp/sm"
BATCH_SIZE = 8


def _build_loader(split_dir, shuffle, batch_size=BATCH_SIZE):
    """Load the arrays in ``split_dir`` into a ``PrepDataset`` and wrap it in a ``DataLoader``.

    Returns:
        ``(dataset, loader)``.
    """
    split_set = PrepDataset(*load_data(split_dir))
    loader = DataLoader(split_set, batch_size=batch_size, shuffle=shuffle, drop_last=False)
    return split_set, loader


trainset, trainloader = _build_loader(f"{EXP_DIR}/train", shuffle=True)
# MSE/MAE/PCC do not depend on sample order, so the evaluation split is not
# shuffled; this keeps evaluation passes deterministic.
testset, testloader = _build_loader(f"{EXP_DIR}/test", shuffle=False)


In [None]:
from torch.optim.lr_scheduler import MultiStepLR
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (on all devices) so weight initialisation and the
# training DataLoader's shuffling are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

gopt_model = GOPT(
    embed_dim=32, num_heads=1, 
    depth=3, input_dim=851, 
    max_length=128, num_phone=62).to(device)

trainables = [p for p in gopt_model.parameters() if p.requires_grad]

lr = 3e-4
optimizer = torch.optim.Adam(
    trainables, lr, weight_decay=5e-7, betas=(0.95, 0.999))

# Halves the LR after 10, 15, ..., 95 calls to scheduler.step().
# NOTE(review): nothing in this notebook currently calls scheduler.step(),
# so as written this scheduler does not change the learning rate.
scheduler = MultiStepLR(
    optimizer, list(range(10, 100, 5)), gamma=0.5, last_epoch=-1)

loss_fn = nn.MSELoss()

In [None]:
def _to_float64_array(values):
    """Return ``values`` as a float64 NumPy array (torch tensors are detached and moved to CPU)."""
    if hasattr(values, "detach"):
        values = values.detach().cpu().double().numpy()
    return np.asarray(values, dtype=np.float64)


def valid_phn(audio_output, target):
    """Compute MSE, MAE and Pearson correlation over valid (non-padded) phones.

    Positions whose ``target`` is negative are treated as padding and ignored.
    Metrics are computed in float64 so float32 inputs do not leak rounding
    artefacts into the reported numbers.

    Args:
        audio_output: predictions of shape ``(B, T)`` or ``(B, T, 1)``; a NumPy
            array or a torch tensor (on any device).
        target: reference scores of shape ``(B, T)``; negative entries mark padding.

    Returns:
        ``(mse, mae, corr)`` as Python floats. All three are ``nan`` when no
        position is valid; ``corr`` is ``nan`` when either side is constant.

    Raises:
        ValueError: if the prediction and target shapes do not match.
    """
    pred = _to_float64_array(audio_output)
    ref = _to_float64_array(target)
    # Accept the raw (B, T, 1) phone-head output as well as (B, T).
    if pred.ndim == ref.ndim + 1 and pred.shape[-1] == 1:
        pred = pred[..., 0]
    if pred.shape != ref.shape:
        raise ValueError(f"prediction shape {pred.shape} does not match target shape {ref.shape}")

    valid = ref >= 0
    if not valid.any():
        nan = float("nan")
        return nan, nan, nan

    pred, ref = pred[valid], ref[valid]
    err = ref - pred
    mse = float(np.mean(err ** 2))
    mae = float(np.mean(np.abs(err)))
    corr = float(np.corrcoef(pred, ref)[0, 1])
    return mse, mae, corr


In [None]:
from tqdm import tqdm

NUM_EPOCHS = 50


def _masked_phone_loss(phone_preds, phone_labels):
    """MSE over valid phones only (label >= 0); padded positions are excluded entirely.

    ``phone_preds`` is the ``(B, T, 1)`` phone-head output and ``phone_labels`` is ``(B, T)``.
    """
    mask = phone_labels >= 0
    return loss_fn(phone_preds.squeeze(2)[mask], phone_labels[mask])


def _train_one_epoch(model, loader):
    """Run one optimisation pass over ``loader``; increments the global ``global_step``."""
    global global_step
    model.train()
    progress = tqdm(loader, "Training")
    for batch in progress:
        optimizer.zero_grad()

        features = batch["features"].to(device)
        phone_ids = batch["phone_ids"].to(device)
        phone_labels = batch["phone_scores"].to(device)

        _, phone_preds, _ = model(x=features.float(), phn=phone_ids.long())
        loss_phn = _masked_phone_loss(phone_preds, phone_labels)

        loss_phn.backward()
        optimizer.step()

        global_step += 1
        progress.set_postfix(loss_phn=loss_phn.item())


@torch.no_grad()
def _evaluate(model, loader):
    """Phone-level ``(MSE, MAE, PCC)`` of ``model`` on ``loader`` via ``valid_phn``.

    Runs in eval mode and without building an autograd graph.
    """
    model.eval()
    all_preds, all_targets = [], []
    for batch in loader:
        features = batch["features"].to(device)
        phone_ids = batch["phone_ids"].to(device)
        _, phone_preds, _ = model(x=features.float(), phn=phone_ids.long())
        all_preds.append(phone_preds[:, :, 0].cpu())
        all_targets.append(batch["phone_scores"].cpu())
    return valid_phn(torch.vstack(all_preds), torch.vstack(all_targets))


# NOTE: neither LR warm-up nor scheduler.step() is applied, so training runs at
# the constant learning rate `lr` given to the optimizer.
global_step = 0
best_pcc, best_epoch = float("-inf"), None
for epoch in range(NUM_EPOCHS):
    _train_one_epoch(gopt_model, trainloader)
    phn_mse, phn_mae, phn_corr = _evaluate(gopt_model, testloader)
    # Format explicitly: round() on float32 metrics printed artefacts like 0.2248000055551529.
    print(f"### Epoch {epoch + 1} validation result: "
          f"MSE={phn_mse:.4f} MAE={phn_mae:.4f} PCC={phn_corr:.4f}")
    if phn_corr > best_pcc:
        best_pcc, best_epoch = phn_corr, epoch + 1

print(f"Best validation PCC={best_pcc:.4f} at epoch {best_epoch}")

In [None]:
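# Log of a previous 50-epoch run (one Training / Validation line pair per epoch).
# Validation PCC peaks at 0.7658 (epochs 18-19) and MSE bottoms out at 0.1761 (epochs 18 and 21).
# By epoch 50, PCC has dropped to 0.7357 and MSE has risen to 0.2021, which suggests overfitting after ~20 epochs.
# Training: 100%|██████████| 1875/1875 [00:06<00:00, 282.83it/s, loss_phn=0.383] 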
# ### Validation result: MSE=0.2248000055551529 MAE=0.32100000977516174 PCC=0.697
# Training: 100%|██████████| 1875/1875 [00:06<00:00, 269.10it/s, loss_phn=0.121] 
# ### Validation result: MSE=0.20890000462532043 MAE=0.3012000024318695 PCC=0.7107
# Training: 100%|██████████| 1875/1875 [00:07<00:00, 266.09it/s, loss_phn=0.408] 
# ### Validation result: MSE=0.20309999585151672 MAE=0.3140999972820282 PCC=0.7242
# Training: 100%|██████████| 1875/1875 [00:07<00:00, 242.88it/s, loss_phn=0.2]   
# ### Validation result: MSE=0.19470000267028809 MAE=0.2847999930381775 PCC=0.7351
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 226.88it/s, loss_phn=0.0942]
# ### Validation result: MSE=0.19599999487400055 MAE=0.3019999861717224 PCC=0.7397
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 211.98it/s, loss_phn=0.0843]
# ### Validation result: MSE=0.20020000636577606 MAE=0.2888000011444092 PCC=0.7336
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 215.95it/s, loss_phn=0.17]  
# ### Validation result: MSE=0.19990000128746033 MAE=0.3253999948501587 PCC=0.7476
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 217.42it/s, loss_phn=0.235] 
# ### Validation result: MSE=0.19089999794960022 MAE=0.2535000145435333 PCC=0.7488
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 215.28it/s, loss_phn=0.196] 
# ### Validation result: MSE=0.18160000443458557 MAE=0.25870001316070557 PCC=0.7551
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 229.91it/s, loss_phn=0.0789]
# ### Validation result: MSE=0.1860000044107437 MAE=0.27230000495910645 PCC=0.7551
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 220.74it/s, loss_phn=0.245] 
# ### Validation result: MSE=0.18070000410079956 MAE=0.2531000077724457 PCC=0.7567
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 217.68it/s, loss_phn=0.2]   
# ### Validation result: MSE=0.17980000376701355 MAE=0.2590999901294708 PCC=0.7587
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 223.91it/s, loss_phn=0.269] 
# ### Validation result: MSE=0.18520000576972961 MAE=0.2809000015258789 PCC=0.7549
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 218.08it/s, loss_phn=0.136] 
# ### Validation result: MSE=0.1770000010728836 MAE=0.2590999901294708 PCC=0.7639
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 231.73it/s, loss_phn=0.222] 
# ### Validation result: MSE=0.17919999361038208 MAE=0.27320000529289246 PCC=0.7616
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 218.86it/s, loss_phn=0.14]  
# ### Validation result: MSE=0.1808999925851822 MAE=0.2766999900341034 PCC=0.7592
# Training: 100%|██████████| 1875/1875 [00:07<00:00, 243.76it/s, loss_phn=0.109] 
# ### Validation result: MSE=0.1776999980211258 MAE=0.25119999051094055 PCC=0.7633
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 224.56it/s, loss_phn=0.216] 
# ### Validation result: MSE=0.1761000007390976 MAE=0.2637999951839447 PCC=0.7658
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 225.93it/s, loss_phn=0.105] 
# ### Validation result: MSE=0.17810000479221344 MAE=0.2831999957561493 PCC=0.7658
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 228.92it/s, loss_phn=0.147] 
# ### Validation result: MSE=0.17800000309944153 MAE=0.23960000276565552 PCC=0.7648
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 226.74it/s, loss_phn=0.0989]
# ### Validation result: MSE=0.1761000007390976 MAE=0.2502000033855438 PCC=0.7652
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 216.72it/s, loss_phn=0.136] 
# ### Validation result: MSE=0.17949999868869781 MAE=0.2687000036239624 PCC=0.7608
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 213.79it/s, loss_phn=0.0404]
# ### Validation result: MSE=0.17720000445842743 MAE=0.25189998745918274 PCC=0.7622
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 224.79it/s, loss_phn=0.0969]
# ### Validation result: MSE=0.1834000051021576 MAE=0.26460000872612 PCC=0.7539
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 220.28it/s, loss_phn=0.174] 
# ### Validation result: MSE=0.17720000445842743 MAE=0.2492000013589859 PCC=0.7639
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 219.94it/s, loss_phn=0.253] 
# ### Validation result: MSE=0.17839999496936798 MAE=0.251800000667572 PCC=0.7603
# Training: 100%|██████████| 1875/1875 [00:07<00:00, 235.49it/s, loss_phn=0.118] 
# ### Validation result: MSE=0.17749999463558197 MAE=0.26579999923706055 PCC=0.764
# Training: 100%|██████████| 1875/1875 [00:07<00:00, 236.08it/s, loss_phn=0.175] 
# ### Validation result: MSE=0.1818999946117401 MAE=0.24539999663829803 PCC=0.7586
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 230.76it/s, loss_phn=0.188] 
# ### Validation result: MSE=0.18610000610351562 MAE=0.27649998664855957 PCC=0.7561
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 226.30it/s, loss_phn=0.0649]
# ### Validation result: MSE=0.1850000023841858 MAE=0.25200000405311584 PCC=0.7616
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 222.24it/s, loss_phn=0.277] 
# ### Validation result: MSE=0.18240000307559967 MAE=0.2583000063896179 PCC=0.7576
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 233.27it/s, loss_phn=0.118] 
# ### Validation result: MSE=0.18219999969005585 MAE=0.25850000977516174 PCC=0.7564
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 222.01it/s, loss_phn=0.176] 
# ### Validation result: MSE=0.18449999392032623 MAE=0.2515999972820282 PCC=0.7514
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 224.62it/s, loss_phn=0.158] 
# ### Validation result: MSE=0.19040000438690186 MAE=0.24060000479221344 PCC=0.7489
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 224.26it/s, loss_phn=0.115] 
# ### Validation result: MSE=0.19280000030994415 MAE=0.23849999904632568 PCC=0.7519
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 227.97it/s, loss_phn=0.103] 
# ### Validation result: MSE=0.18389999866485596 MAE=0.23899999260902405 PCC=0.757
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 228.36it/s, loss_phn=0.0634]
# ### Validation result: MSE=0.1860000044107437 MAE=0.24860000610351562 PCC=0.7525
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 229.89it/s, loss_phn=0.104] 
# ### Validation result: MSE=0.18809999525547028 MAE=0.24690000712871552 PCC=0.7492
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 223.95it/s, loss_phn=0.102] 
# ### Validation result: MSE=0.18729999661445618 MAE=0.26499998569488525 PCC=0.7522
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 232.23it/s, loss_phn=0.165] 
# ### Validation result: MSE=0.188400000333786 MAE=0.24269999563694 PCC=0.7507
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 228.30it/s, loss_phn=0.0935]
# ### Validation result: MSE=0.18790000677108765 MAE=0.24289999902248383 PCC=0.7523
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 227.63it/s, loss_phn=0.0525] 
# ### Validation result: MSE=0.19089999794960022 MAE=0.2556999921798706 PCC=0.746
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 229.75it/s, loss_phn=0.0636] 
# ### Validation result: MSE=0.19419999420642853 MAE=0.2393999993801117 PCC=0.7454
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 232.12it/s, loss_phn=0.0951]
# ### Validation result: MSE=0.18770000338554382 MAE=0.24400000274181366 PCC=0.7504
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 232.98it/s, loss_phn=0.111] 
# ### Validation result: MSE=0.19140000641345978 MAE=0.24040000140666962 PCC=0.7481
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 226.29it/s, loss_phn=0.0311]
# ### Validation result: MSE=0.1923999935388565 MAE=0.2410999983549118 PCC=0.7451
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 226.37it/s, loss_phn=0.124] 
# ### Validation result: MSE=0.1940000057220459 MAE=0.2451000064611435 PCC=0.7458
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 218.81it/s, loss_phn=0.078] 
# ### Validation result: MSE=0.19059999287128448 MAE=0.25780001282691956 PCC=0.749
# Training: 100%|██████████| 1875/1875 [00:07<00:00, 241.92it/s, loss_phn=0.0821]
# ### Validation result: MSE=0.19480000436306 MAE=0.25189998745918274 PCC=0.7401
# Training: 100%|██████████| 1875/1875 [00:08<00:00, 218.41it/s, loss_phn=0.0773] 
# ### Validation result: MSE=0.2020999938249588 MAE=0.25769999623298645 PCC=0.7357