In [9]:
%cd /data/codes/prep_ps_pykaldi/
import pandas as pd
import numpy as np
import pickle
import json
import re

/data/codes/prep_ps_pykaldi


In [10]:
def load_data(data_dir):
    """Load the pre-extracted phone-level arrays stored as ``.npy`` files in ``data_dir``.

    Files are read in this order: ``phone_ids.npy``, ``phone_scores.npy``,
    ``duration.npy``, ``gop.npy``, ``wavlm_features.npy``.

    Returns:
        Tuple ``(phone_ids, phone_scores, durations, gops, wavlm_features)`` of numpy arrays.
    """
    file_stems = ("phone_ids", "phone_scores", "duration", "gop", "wavlm_features")
    return tuple(np.load(f"{data_dir}/{stem}.npy") for stem in file_stems)


In [11]:
from torch.utils.data import Dataset, DataLoader
import torch

class PrepDataset(Dataset):
    """Map-style dataset returning per-utterance phone features, ids and scores.

    Each item is a dict with:
      * ``"features"``: ``gops`` concatenated with ``durations`` as one extra trailing column.
      * ``"phone_ids"``: the phone ids as a tensor.
      * ``"phone_scores"``: the phone scores divided by ``score_scale``.

    Args:
        phone_ids, phone_scores, durations, gops, wavlm_features: arrays indexed by
            utterance along axis 0. All must have the same number of utterances.
        score_scale: divisor applied to the phone scores (default ``50``).

    Raises:
        ValueError: if the input arrays disagree on the number of utterances, or
            ``score_scale`` is zero.
    """

    def __init__(self, phone_ids, phone_scores, durations, gops, wavlm_features, score_scale=50):
        # Validate up front: a length mismatch would otherwise only show up as an
        # IndexError partway through an epoch.
        lengths = {
            "phone_ids": len(phone_ids),
            "phone_scores": len(phone_scores),
            "durations": len(durations),
            "gops": len(gops),
            "wavlm_features": len(wavlm_features),
        }
        if len(set(lengths.values())) != 1:
            raise ValueError(f"All inputs must have the same number of utterances, got {lengths}")
        if score_scale == 0:
            raise ValueError("score_scale must be non-zero")

        self.phone_ids = phone_ids
        self.phone_scores = phone_scores
        self.gops = gops
        self.durations = durations
        self.wavlm_features = wavlm_features
        self.score_scale = score_scale

    def __len__(self):
        return self.phone_ids.shape[0]

    def parse_data(self, phone_ids, phone_scores, gops, durations, wavlm_features):
        """Convert one utterance's arrays to tensors and assemble the model inputs."""
        phone_ids = torch.tensor(phone_ids)
        durations = torch.tensor(durations)
        gops = torch.tensor(gops)
        # Negative (padding) scores stay negative after scaling, so `>= 0` masks still work.
        phone_scores = torch.tensor(phone_scores) / self.score_scale
        # wavlm_features is accepted for API compatibility but is not used as a model input yet.

        # Durations become one extra feature column appended after the GOP features.
        features = torch.cat([gops, durations.unsqueeze(-1)], dim=-1)
        return {
            "features": features,
            "phone_ids": phone_ids,
            "phone_scores": phone_scores,
        }

    def __getitem__(self, index):
        return self.parse_data(
            phone_ids=self.phone_ids[index],
            phone_scores=self.phone_scores[index],
            gops=self.gops[index],
            durations=self.durations[index],
            wavlm_features=self.wavlm_features[index],
        )

data_dir = "/data/codes/prep_ps_pykaldi/exp/sm/test"

phone_ids, phone_scores, durations, gops, wavlm_features = load_data(data_dir)
dataset_v1 = PrepDataset(phone_ids, phone_scores, durations, gops, wavlm_features)
dataloader = DataLoader(dataset_v1, batch_size=8)

# Peek at one batch to sanity-check shapes. Batch-local names are used so the numpy
# arrays loaded above are not silently replaced by tensors (hidden-state hazard).
sample_batch = next(iter(dataloader))
sample_features = sample_batch["features"]
sample_phone_ids = sample_batch["phone_ids"]
sample_phone_scores = sample_batch["phone_scores"]

# Features, ids and scores must agree on (batch, seq_len).
assert sample_features.shape[:2] == sample_phone_ids.shape == sample_phone_scores.shape

print(sample_features.shape)
print(sample_phone_ids.shape)
print(sample_phone_scores.shape)

torch.Size([8, 32, 83])
torch.Size([8, 32])
torch.Size([8, 32])


In [12]:
import math
import warnings
import torch
import torch.nn as nn
import numpy as np

def get_sinusoid_encoding(n_position, d_hid):
    """Build a fixed sinusoidal positional-encoding table.

    Entry ``(pos, j)`` is ``pos / 10000 ** (2 * (j // 2) / d_hid)``. Even columns then
    get ``sin`` applied and odd columns get ``cos``.

    Returns:
        float32 tensor of shape ``(1, n_position, d_hid)``.
    """
    angles = np.array([
        [pos / np.power(10000, 2 * (col // 2) / d_hid) for col in range(d_hid)]
        for pos in range(n_position)
    ])
    # Both results are computed from the original angles before either slice is overwritten.
    even_cols, odd_cols = angles[:, 0::2], angles[:, 1::2]
    angles[:, 0::2] = np.sin(even_cols)  # dim 2i
    angles[:, 1::2] = np.cos(odd_cols)   # dim 2i+1

    return torch.FloatTensor(angles).unsqueeze(0)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Public wrapper: fill ``tensor`` in place with a normal distribution truncated to [a, b].

    Defaults give a standard normal clipped at +/-2. Returns the same tensor object.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Every position attends to every other position. No attention mask is applied, so
    padded positions, if present, also participate.

    Args:
        dim: input/output channel width; split evenly across ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias.
        qk_scale: optional override for the logit scale (falls back to ``head_dim ** -0.5``
            when falsy).
        attn_drop: dropout on attention weights.
        proj_drop: dropout after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale if qk_scale else per_head ** -0.5

        # Registration order (qkv -> proj) is kept so seeded initialisation is unchanged.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, seq_len, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q, k, v along dim 0.
        packed = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, head_dim)
        query, key, value = packed.permute(2, 0, 3, 1, 4)

        logits = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-1))

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C).
        merged = (weights @ value).transpose(1, 2).reshape(batch, seq_len, channels)
        return self.proj_drop(self.proj(merged))

class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when falsy.
    The same dropout module is applied after both the activation and the output layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.

    ``drop_path`` is accepted for signature compatibility, but stochastic depth is not
    implemented: the residual branches go through ``nn.Identity``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Creation order is kept (norm1, attn, drop_path, norm2, mlp) so seeded init is unchanged.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
        after_attn = x + self.drop_path(self.attn(self.norm1(x)))
        return after_attn + self.drop_path(self.mlp(self.norm2(after_attn)))

class GOPT(nn.Module):
    """GOPT-style transformer producing utterance-, phone- and word-level outputs.

    Steps:
      1. Project per-phone features to ``embed_dim`` (only when ``input_dim != embed_dim``).
      2. Concatenate them with a learned projection of the one-hot phone identity.
      3. Fuse the result back to ``embed_dim``.
      4. Prepend a learnable CLS token and add learned positional embeddings.
      5. Run the sequence through ``depth`` transformer blocks.

    Args:
        embed_dim: token width inside the transformer.
        num_heads: attention heads per block.
        depth: number of transformer blocks.
        input_dim: width of the per-phone input features.
        max_length: maximum number of phone tokens (the positional table has one extra
            slot for the CLS token).
        num_phone: number of one-hot classes. Phone ids are shifted by +1 before one-hot
            encoding, so accepted ids lie in ``[-1, num_phone - 2]``.
    """

    def __init__(self, embed_dim, num_heads, depth, input_dim=84, max_length=50, num_phone=40):
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        # Modules/parameters are created (and randomly initialised) in a fixed order;
        # keep it stable so seeded runs reproduce the same initial weights.
        self.blocks = nn.ModuleList([Block(dim=embed_dim, num_heads=num_heads) for _ in range(depth)])

        # Learned positional table: slot 0 for CLS, slots 1..max_length for phone tokens.
        self.pos_embed = nn.Parameter(torch.zeros(1, max_length + 1, self.embed_dim))
        trunc_normal_(self.pos_embed, std=.02)

        self.in_proj = nn.Linear(self.input_dim, embed_dim)
        # Fuses [features ; phone embedding] (2 * embed_dim wide) back to embed_dim.
        self.linear = nn.Linear(embed_dim * 2, embed_dim)
        self.mlp_head_phn = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        self.mlp_head_word = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))

        self.num_phone = num_phone
        self.phn_proj = nn.Linear(num_phone, embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mlp_head_utt = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, 1))
        trunc_normal_(self.cls_token, std=.02)

    def forward(self, x, phn):
        """Score a batch.

        Args:
            x: per-phone features, ``(batch, seq, input_dim)``.
            phn: phone ids, ``(batch, seq)``. Shifted by +1 before one-hot encoding
                (presumably so a -1 padding id maps to class 0 -- confirm with the data prep).

        Returns:
            ``(utt, phone, word)`` with shapes ``(batch, 1)``, ``(batch, seq, 1)`` and
            ``(batch, seq, 1)``.
        """
        batch_size = x.shape[0]
        phone_onehot = torch.nn.functional.one_hot(phn.long() + 1, num_classes=self.num_phone).float()
        phone_embedding = self.phn_proj(phone_onehot)

        tokens = self.in_proj(x) if self.embed_dim != self.input_dim else x
        tokens = self.linear(torch.cat([tokens, phone_embedding], dim=-1))

        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = tokens + self.pos_embed[:, :tokens.shape[1], :]

        for block in self.blocks:
            tokens = block(tokens)

        phone_states = tokens[:, 1:]
        utterance_out = self.mlp_head_utt(tokens[:, 0])
        phone_out = self.mlp_head_phn(phone_states)
        word_out = self.mlp_head_word(phone_states)
        return utterance_out, phone_out, word_out

In [13]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

# Split a single data directory into train / held-out sets at a fixed utterance index.
# NOTE(review): both splits come from exp/sm/test -- confirm this is intended.
index = 15000  # number of utterances used for training; the remainder is held out
data_dir = "/data/codes/prep_ps_pykaldi/exp/sm/test"
phone_ids, phone_scores, durations, gops, wavlm_features = load_data(data_dir)
assert 0 < index < len(phone_ids), "split index must leave both train and test non-empty"

trainset = PrepDataset(phone_ids[:index], phone_scores[:index], durations[:index], gops[:index], wavlm_features[:index])
trainloader = DataLoader(trainset, batch_size=8, shuffle=True, drop_last=False)

# Validation metrics (MSE / MAE / PCC) are order-independent, so the held-out loader
# does not shuffle. This also avoids consuming the global RNG on every evaluation pass.
testset = PrepDataset(phone_ids[index:], phone_scores[index:], durations[index:], gops[index:], wavlm_features[index:])
testloader = DataLoader(testset, batch_size=8, shuffle=False, drop_last=False)


In [14]:
from torch.optim.lr_scheduler import MultiStepLR
from torch import nn

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Small GOPT: 3 single-head transformer blocks over 32-dim tokens.
gopt_model = GOPT(
    embed_dim=32,
    num_heads=1,
    depth=3,
    input_dim=83,
    max_length=128,
    num_phone=62,
).to(device)

trainables = list(filter(lambda param: param.requires_grad, gopt_model.parameters()))

lr = 3e-4
optimizer = torch.optim.Adam(trainables, lr=lr, weight_decay=5e-7, betas=(0.95, 0.999))

# Halve the LR after scheduler steps 10, 15, ..., 95.
# NOTE(review): the training loop never calls scheduler.step(), so this currently has no effect.
scheduler = MultiStepLR(optimizer, milestones=list(range(10, 100, 5)), gamma=0.5, last_epoch=-1)

loss_fn = nn.MSELoss()


In [15]:
def valid_phn(audio_output, target):
    """Compute MSE, MAE and Pearson correlation over non-padded phone positions.

    Args:
        audio_output: predicted scores (numpy array, list, or torch tensor). Must have the
            same number of elements as ``target``, e.g. ``(batch, seq)`` or ``(batch, seq, 1)``.
        target: reference scores, ``(batch, seq)``. Padded positions are marked by
            negative values and are excluded.

    Returns:
        ``(mse, mae, corr)`` as Python floats, computed in float64. All three are NaN when
        there are no valid positions. ``corr`` is NaN when either side is constant.

    Raises:
        ValueError: if ``audio_output`` cannot be reshaped to ``target``'s shape.
    """
    def _to_numpy(values):
        # Accept torch tensors (including ones on GPU or requiring grad) as well as arrays.
        if hasattr(values, "detach"):
            values = values.detach().cpu().numpy()
        return np.asarray(values, dtype=np.float64)

    target_np = _to_numpy(target)
    pred_np = _to_numpy(audio_output).reshape(target_np.shape)

    # Vectorised selection of valid (non-negative target) positions.
    valid = target_np >= 0
    pred_valid = pred_np[valid]
    target_valid = target_np[valid]
    if target_valid.size == 0:
        nan = float("nan")
        return nan, nan, nan

    diff = target_valid - pred_valid
    valid_token_mse = float(np.mean(diff ** 2))
    valid_token_mae = float(np.mean(np.abs(diff)))
    corr = float(np.corrcoef(pred_valid, target_valid)[0, 1])
    return valid_token_mse, valid_token_mae, corr


In [16]:
from tqdm import tqdm

NUM_EPOCHS = 50

# NOTE: LR warm-up and scheduler.step() are currently disabled; the LR stays at `lr`.
global_step = 0
for epoch in range(NUM_EPOCHS):
    # ---- training ----
    gopt_model.train()
    train_tqdm = tqdm(trainloader, desc=f"Training epoch {epoch + 1}/{NUM_EPOCHS}")

    for batch in train_tqdm:
        # Batch-local names so the global numpy arrays (phone_ids, ...) are not clobbered.
        batch_features = batch["features"].to(device).float()
        batch_phone_ids = batch["phone_ids"].to(device).long()
        batch_labels = batch["phone_scores"].to(device)

        # Padded positions carry negative labels; skip batches with no valid phones,
        # which would otherwise yield a NaN loss and corrupt the weights.
        mask = batch_labels >= 0
        if not mask.any():
            continue

        optimizer.zero_grad()
        _, phone_preds, _ = gopt_model(x=batch_features, phn=batch_phone_ids)
        phone_preds = phone_preds.squeeze(2)

        # MSE averaged over valid (non-padded) phones only.
        loss_phn = loss_fn(phone_preds[mask], batch_labels[mask].to(phone_preds.dtype))

        loss_phn.backward()
        optimizer.step()
        global_step += 1

        train_tqdm.set_postfix(loss_phn=loss_phn.item())

    # ---- validation: eval mode, no autograd graph ----
    gopt_model.eval()
    A_phn, A_phn_target = [], []
    with torch.no_grad():
        for batch in testloader:
            _, phone_preds, _ = gopt_model(
                x=batch["features"].to(device).float(),
                phn=batch["phone_ids"].to(device).long(),
            )
            A_phn.append(phone_preds[:, :, 0].cpu())
            A_phn_target.append(batch["phone_scores"])

    A_phn, A_phn_target = torch.cat(A_phn, dim=0), torch.cat(A_phn_target, dim=0)
    phn_mse, phn_mae, phn_corr = valid_phn(A_phn, A_phn_target)
    print(f"### Validation result: MSE={phn_mse:.4f} MAE={phn_mae:.4f} PCC={phn_corr:.4f}")

Training: 100%|██████████| 1875/1875 [00:07<00:00, 243.89it/s, loss_phn=0.282] 


### Validation result: MSE=0.23899999260902405 MAE=0.3587000072002411 PCC=0.6648


Training: 100%|██████████| 1875/1875 [00:07<00:00, 235.96it/s, loss_phn=0.256] 


### Validation result: MSE=0.22609999775886536 MAE=0.3158999979496002 PCC=0.6814


Training: 100%|██████████| 1875/1875 [00:08<00:00, 227.73it/s, loss_phn=0.347] 


### Validation result: MSE=0.21979999542236328 MAE=0.3091999888420105 PCC=0.6931


Training: 100%|██████████| 1875/1875 [00:07<00:00, 243.54it/s, loss_phn=0.307] 


### Validation result: MSE=0.21879999339580536 MAE=0.2883000075817108 PCC=0.7017


Training: 100%|██████████| 1875/1875 [00:08<00:00, 232.94it/s, loss_phn=0.17]  


### Validation result: MSE=0.21549999713897705 MAE=0.322299987077713 PCC=0.7027


Training: 100%|██████████| 1875/1875 [00:07<00:00, 238.55it/s, loss_phn=0.103] 


### Validation result: MSE=0.2102999985218048 MAE=0.3140000104904175 PCC=0.7127


Training: 100%|██████████| 1875/1875 [00:07<00:00, 236.69it/s, loss_phn=0.215] 


### Validation result: MSE=0.20319999754428864 MAE=0.2922999858856201 PCC=0.7199


Training: 100%|██████████| 1875/1875 [00:07<00:00, 238.75it/s, loss_phn=0.183] 


### Validation result: MSE=0.20250000059604645 MAE=0.2768999934196472 PCC=0.7244


Training: 100%|██████████| 1875/1875 [00:08<00:00, 230.50it/s, loss_phn=0.198] 


### Validation result: MSE=0.20090000331401825 MAE=0.28780001401901245 PCC=0.7236


Training: 100%|██████████| 1875/1875 [00:07<00:00, 243.33it/s, loss_phn=0.211] 


### Validation result: MSE=0.21160000562667847 MAE=0.2563999891281128 PCC=0.7235


Training: 100%|██████████| 1875/1875 [00:08<00:00, 230.15it/s, loss_phn=0.275] 


### Validation result: MSE=0.19820000231266022 MAE=0.2815000116825104 PCC=0.7284


Training: 100%|██████████| 1875/1875 [00:07<00:00, 239.84it/s, loss_phn=0.117] 


### Validation result: MSE=0.20110000669956207 MAE=0.2937000095844269 PCC=0.7276


Training: 100%|██████████| 1875/1875 [00:07<00:00, 238.79it/s, loss_phn=0.1]   


### Validation result: MSE=0.20000000298023224 MAE=0.2565000057220459 PCC=0.7352


Training: 100%|██████████| 1875/1875 [00:08<00:00, 234.34it/s, loss_phn=0.23]  


### Validation result: MSE=0.19709999859333038 MAE=0.2702000141143799 PCC=0.7319


Training: 100%|██████████| 1875/1875 [00:07<00:00, 242.37it/s, loss_phn=0.173] 


### Validation result: MSE=0.20200000703334808 MAE=0.28290000557899475 PCC=0.7221


Training: 100%|██████████| 1875/1875 [00:07<00:00, 247.96it/s, loss_phn=0.101] 


### Validation result: MSE=0.19339999556541443 MAE=0.2581000030040741 PCC=0.739


Training: 100%|██████████| 1875/1875 [00:07<00:00, 239.24it/s, loss_phn=0.257] 


### Validation result: MSE=0.19380000233650208 MAE=0.27880001068115234 PCC=0.7397


Training: 100%|██████████| 1875/1875 [00:07<00:00, 249.08it/s, loss_phn=0.184] 


### Validation result: MSE=0.19130000472068787 MAE=0.2791000008583069 PCC=0.7401


Training: 100%|██████████| 1875/1875 [00:08<00:00, 232.38it/s, loss_phn=0.156] 


### Validation result: MSE=0.1898999959230423 MAE=0.2727000117301941 PCC=0.7418


Training: 100%|██████████| 1875/1875 [00:08<00:00, 230.19it/s, loss_phn=0.234] 


### Validation result: MSE=0.18930000066757202 MAE=0.2621000111103058 PCC=0.7448


Training: 100%|██████████| 1875/1875 [00:08<00:00, 228.14it/s, loss_phn=0.288] 


### Validation result: MSE=0.1898999959230423 MAE=0.2554999887943268 PCC=0.7458


Training: 100%|██████████| 1875/1875 [00:08<00:00, 230.51it/s, loss_phn=0.26]  


### Validation result: MSE=0.18960000574588776 MAE=0.2709999978542328 PCC=0.742


Training: 100%|██████████| 1875/1875 [00:08<00:00, 232.39it/s, loss_phn=0.211] 


### Validation result: MSE=0.18889999389648438 MAE=0.26669999957084656 PCC=0.7465


Training: 100%|██████████| 1875/1875 [00:08<00:00, 233.17it/s, loss_phn=0.176] 


### Validation result: MSE=0.1867000013589859 MAE=0.2700999975204468 PCC=0.7477


Training: 100%|██████████| 1875/1875 [00:08<00:00, 232.88it/s, loss_phn=0.311] 


### Validation result: MSE=0.18860000371932983 MAE=0.2791999876499176 PCC=0.7468


Training: 100%|██████████| 1875/1875 [00:07<00:00, 243.86it/s, loss_phn=0.182] 


### Validation result: MSE=0.18479999899864197 MAE=0.27320000529289246 PCC=0.7502


Training: 100%|██████████| 1875/1875 [00:07<00:00, 257.68it/s, loss_phn=0.116] 


### Validation result: MSE=0.19020000100135803 MAE=0.2572000026702881 PCC=0.7437


Training: 100%|██████████| 1875/1875 [00:07<00:00, 240.51it/s, loss_phn=0.269] 


### Validation result: MSE=0.1873999983072281 MAE=0.2540999948978424 PCC=0.7482


Training: 100%|██████████| 1875/1875 [00:07<00:00, 247.55it/s, loss_phn=0.14]  


### Validation result: MSE=0.1867000013589859 MAE=0.27000001072883606 PCC=0.7481


Training: 100%|██████████| 1875/1875 [00:07<00:00, 236.64it/s, loss_phn=0.0949]


### Validation result: MSE=0.18889999389648438 MAE=0.2605000138282776 PCC=0.7449


Training: 100%|██████████| 1875/1875 [00:07<00:00, 248.28it/s, loss_phn=0.258] 


### Validation result: MSE=0.18860000371932983 MAE=0.2694999873638153 PCC=0.7451


Training: 100%|██████████| 1875/1875 [00:07<00:00, 234.97it/s, loss_phn=0.148] 


### Validation result: MSE=0.18379999697208405 MAE=0.2603999972343445 PCC=0.7516


Training: 100%|██████████| 1875/1875 [00:07<00:00, 247.90it/s, loss_phn=0.129] 


### Validation result: MSE=0.18930000066757202 MAE=0.2676999866962433 PCC=0.7428


Training: 100%|██████████| 1875/1875 [00:08<00:00, 233.79it/s, loss_phn=0.126] 


### Validation result: MSE=0.19269999861717224 MAE=0.2612999975681305 PCC=0.7386


Training: 100%|██████████| 1875/1875 [00:08<00:00, 233.68it/s, loss_phn=0.165] 


### Validation result: MSE=0.1890999972820282 MAE=0.2680000066757202 PCC=0.7429


Training: 100%|██████████| 1875/1875 [00:07<00:00, 245.13it/s, loss_phn=0.262] 


### Validation result: MSE=0.1882999986410141 MAE=0.26100000739097595 PCC=0.747


Training: 100%|██████████| 1875/1875 [00:07<00:00, 250.13it/s, loss_phn=0.17]  


### Validation result: MSE=0.18880000710487366 MAE=0.2551000118255615 PCC=0.7459


Training: 100%|██████████| 1875/1875 [00:08<00:00, 233.97it/s, loss_phn=0.223] 


### Validation result: MSE=0.19380000233650208 MAE=0.2606000006198883 PCC=0.7405


Training: 100%|██████████| 1875/1875 [00:07<00:00, 235.82it/s, loss_phn=0.166] 


### Validation result: MSE=0.19249999523162842 MAE=0.2759000062942505 PCC=0.7405


Training: 100%|██████████| 1875/1875 [00:07<00:00, 243.85it/s, loss_phn=0.159] 


### Validation result: MSE=0.18940000236034393 MAE=0.2671999931335449 PCC=0.7433


Training: 100%|██████████| 1875/1875 [00:07<00:00, 241.48it/s, loss_phn=0.164] 


### Validation result: MSE=0.193900004029274 MAE=0.25760000944137573 PCC=0.7378


Training: 100%|██████████| 1875/1875 [00:07<00:00, 246.68it/s, loss_phn=0.179] 


### Validation result: MSE=0.1949000060558319 MAE=0.2757999897003174 PCC=0.7351


Training: 100%|██████████| 1875/1875 [00:07<00:00, 241.08it/s, loss_phn=0.0553]


### Validation result: MSE=0.18960000574588776 MAE=0.25679999589920044 PCC=0.7432


Training: 100%|██████████| 1875/1875 [00:07<00:00, 254.76it/s, loss_phn=0.127] 


### Validation result: MSE=0.2070000022649765 MAE=0.2476000040769577 PCC=0.7295


Training: 100%|██████████| 1875/1875 [00:07<00:00, 244.53it/s, loss_phn=0.21]  


### Validation result: MSE=0.1987999975681305 MAE=0.2718999981880188 PCC=0.7316


Training: 100%|██████████| 1875/1875 [00:07<00:00, 249.18it/s, loss_phn=0.24]  


### Validation result: MSE=0.20579999685287476 MAE=0.2833000123500824 PCC=0.7313


Training: 100%|██████████| 1875/1875 [00:07<00:00, 240.35it/s, loss_phn=0.112] 


### Validation result: MSE=0.19629999995231628 MAE=0.2517000138759613 PCC=0.736


Training: 100%|██████████| 1875/1875 [00:08<00:00, 232.98it/s, loss_phn=0.214] 


### Validation result: MSE=0.19740000367164612 MAE=0.26330000162124634 PCC=0.7349


Training: 100%|██████████| 1875/1875 [00:07<00:00, 249.65it/s, loss_phn=0.145] 


### Validation result: MSE=0.20170000195503235 MAE=0.25690001249313354 PCC=0.7289


Training: 100%|██████████| 1875/1875 [00:08<00:00, 232.57it/s, loss_phn=0.126] 


### Validation result: MSE=0.19930000603199005 MAE=0.2775000035762787 PCC=0.7315
