In [2]:
%cd /data/codes/prep_ps_pykaldi/
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    ConfusionMatrixDisplay
)
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch

import pickle
import json
import re

from torch.utils.data import DataLoader
from torch import nn

from src.dataset import PrepDataset

/data/codes/prep_ps_pykaldi


In [3]:
def convert_score_to_color(score, YELLOW_GREEN=80/50, RED_YELLOW=30/50):
    if RED_YELLOW is not None:
        LABEL2ID = {"GREEN": 0, "YELLOW": 1, "RED":2}
        red_index = score < RED_YELLOW
        yellow_index = ((score >= RED_YELLOW).int() & (score < YELLOW_GREEN).int()).bool()
        green_index = score >= YELLOW_GREEN
    else:
        LABEL2ID = {"GREEN": 0, "YELLOW": 1, "RED":1}
        RED_YELLOW = 30/50
        red_index = score < RED_YELLOW
        yellow_index = ((score >= RED_YELLOW).int() & (score < YELLOW_GREEN).int()).bool()
        green_index = score >= YELLOW_GREEN


    score[red_index] = LABEL2ID["RED"]
    score[yellow_index] = LABEL2ID["YELLOW"]
    score[green_index] = LABEL2ID["GREEN"]

    return score

def load_data(data_dir):
    phone_ids = np.load(f'{data_dir}/phone_ids.npy')
    word_ids = np.load(f'{data_dir}/word_ids.npy')
    
    phone_scores = np.load(f'{data_dir}/phone_scores.npy')
    word_scores = np.load(f'{data_dir}/word_scores.npy')
    sentence_scores = np.load(f'{data_dir}/sentence_scores.npy')

    durations = np.load(f'{data_dir}/duration.npy')
    gops = np.load(f'{data_dir}/gop.npy')
    wavlm_features = np.load(f'{data_dir}/wavlm_features.npy')

    return phone_ids, word_ids, phone_scores, word_scores, sentence_scores, durations, gops, wavlm_features


In [4]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from src.model import GOPT

train_dir = "/data/codes/prep_ps_pykaldi/exp/sm/train"

phone_ids, word_ids, phone_scores, word_scores, \
    sentence_scores, durations, gops, wavlm_features = load_data(train_dir)
trainset = PrepDataset(
    phone_ids, word_ids, 
    phone_scores, word_scores, sentence_scores, 
    durations, gops, wavlm_features
    )

trainloader = DataLoader(trainset, batch_size=8, shuffle=True, drop_last=False)

test_dir = "/data/codes/prep_ps_pykaldi/exp/sm/test"
phone_ids, word_ids, phone_scores, word_scores, \
    sentence_scores, durations, gops, wavlm_features = load_data(test_dir)
testset = PrepDataset(
    phone_ids, word_ids, 
    phone_scores, word_scores, sentence_scores, 
    durations, gops, wavlm_features
    )

testloader = DataLoader(testset, batch_size=8, shuffle=False, drop_last=True)

In [5]:
embed_dim=32
num_heads=1
depth=3
input_dim=851
num_phone=62
max_length=128

lr=3e-4
weight_decay=5e-7
betas=(0.95, 0.999)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gopt_model = GOPT(
    embed_dim=embed_dim, num_heads=num_heads, 
    depth=depth, input_dim=input_dim, 
    max_length=max_length, num_phone=num_phone).to(device)

trainables = [p for p in gopt_model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(
    trainables, lr, 
    weight_decay=weight_decay, 
    betas=betas
)

loss_fn = nn.MSELoss()

In [6]:
def valid_phn(predict, target):
    preds, targs = [], []

    for i in range(predict.shape[0]):
        for j in range(predict.shape[1]):
            if target[i, j] >= 0:
                preds.append(predict[i, j])
                targs.append(target[i, j])
    targs = np.array(targs)
    preds = np.array(preds)

    mse = np.mean((targs - preds) ** 2)
    mae = np.mean(np.abs(targs - preds))
    corr = np.corrcoef(preds, targs)[0, 1]
    return mse, mae, corr


In [7]:
def valid_wrd(predict, target, word_id):
    preds, targs = [], []

    for i in range(target.shape[0]):
        prev_w_id, start_id = 0, 0
        # for each token
        for j in range(target.shape[1]):
            cur_w_id = word_id[i, j].int()
            # if a new word
            if cur_w_id != prev_w_id:
                # average each phone belongs to the word
                preds.append(np.mean(predict[i, start_id: j].numpy(), axis=0))
                targs.append(np.mean(target[i, start_id: j].numpy(), axis=0))

                if cur_w_id == -1:
                    break
                else:
                    prev_w_id = cur_w_id
                    start_id = j

    preds = np.array(preds)
    targs = np.array(targs).round(2)

    word_mse = np.mean((preds - targs) ** 2)
    wrd_mae = np.mean(np.abs(preds - targs))
    word_corr = np.corrcoef(preds, targs)[0, 1]
    
    return word_mse, wrd_mae, word_corr

In [8]:
def valid_utt(predict, target):
    utt_mse = np.mean(((predict[:, 0] - target[:, 0]) ** 2).numpy())
    utt_mae = np.mean((np.abs(predict[:, 0] - target[:, 0])).numpy())
    
    utt_corr = np.corrcoef(predict[:, 0], target[:, 0])[0, 1]
    return utt_mse, utt_mae, utt_corr


In [9]:
def to_device(batch, device):
    features = batch["features"].to(device)
    phone_ids = batch["phone_ids"].to(device)
    word_ids = batch["word_ids"]
    
    phone_labels = batch["phone_scores"].to(device)
    word_labels = batch["word_scores"].to(device)
    utterance_labels = batch["sentence_scores"].to(device)

    return features, phone_ids, word_ids, phone_labels, word_labels, utterance_labels

def to_cpu(preds, labels):
    preds = preds.detach().cpu().squeeze(-1)
    labels = labels.detach().cpu()

    return preds, labels

In [10]:
def load_pred_and_label(pred_path, label_path):
    pred = np.load(pred_path)
    label = np.load(label_path)

    pred = np.concatenate(pred)
    label = np.concatenate(label)
    index = label != -1    
    
    return label[index], pred[index]

def save_confusion_matrix_figure(
        fig_path, pred_path, label_path, YELLOW_GREEN=70/50, RED_YELLOW=35/50):
    
    label, pred = load_pred_and_label(pred_path=pred_path, label_path=label_path)
    
    actual = convert_score_to_color(
        torch.from_numpy(label), YELLOW_GREEN=YELLOW_GREEN, RED_YELLOW=RED_YELLOW)
    
    predicted = convert_score_to_color(
        torch.from_numpy(pred), YELLOW_GREEN=YELLOW_GREEN, RED_YELLOW=RED_YELLOW)
    
    cfs_mtr = confusion_matrix(actual, predicted)
    cfs_mtr = cfs_mtr / cfs_mtr.sum(axis=1, keepdims=True)
    if RED_YELLOW is not None:
        cm_display = ConfusionMatrixDisplay(
            confusion_matrix = cfs_mtr, display_labels = ["GREEN", "YELLOW", "RED"])
    else:
        cm_display = ConfusionMatrixDisplay(
            confusion_matrix = cfs_mtr, display_labels = ["CORRECT", "INCORRECT"])

    plt.title("Confusion Matrix")
    cm_display.plot(cmap='Blues')
    plt.savefig(fig_path) 

def save(epoch, output_dir, model, optimizer, phone_desicion_result, \
    phone_predicts, phone_labels, word_predicts, word_labels, utterance_predicts, utterance_labels):
    
    model_path = f'{output_dir}/model.pt'
    optimizer_path = f'{output_dir}/optimizer.pt'
    phone_desicion_result_path = f'{output_dir}/phone_result'

    phone_predict_path = f'{output_dir}/phn_pred.npy'
    phone_label_path = f'{output_dir}/phn_label.npy'
    word_predict_path = f'{output_dir}/wrd_pred.npy'
    word_label_path = f'{output_dir}/wrd_label.npy'
    utterance_predict_path = f'{output_dir}/utt_pred.npy'
    utterance_label_path = f'{output_dir}/utt_label.npy'

    three_class_fig_path = f'{output_dir}/confusion_matrix_three_class.png'
    two_class_fig_path = f'{output_dir}/confusion_matrix_two_class.png'

    with open(phone_desicion_result_path, "w") as f:
        f.write(phone_desicion_result)

    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optimizer_path)
    np.save(phone_predict_path, phone_predicts)
    np.save(phone_label_path, phone_labels)
    np.save(word_predict_path, word_predicts)
    np.save(word_label_path, word_labels)
    np.save(utterance_predict_path, utterance_predicts)
    np.save(utterance_label_path, utterance_labels)
    save_confusion_matrix_figure(three_class_fig_path, phone_predict_path, phone_label_path, YELLOW_GREEN=85/50, RED_YELLOW=35/50)
    save_confusion_matrix_figure(two_class_fig_path, phone_predict_path, phone_label_path, YELLOW_GREEN=85/50, RED_YELLOW=None)

    print(f'Save state dict and result to {output_dir}')

In [11]:
@torch.no_grad()
def validate(epoch, gopt_model, testloader, best_mse, ckpt_dir):
    gopt_model.eval()
    A_phn, A_phn_target = [], []
    A_utt, A_utt_target = [], []
    A_wrd, A_wrd_target, A_wrd_id = [], [], []

    for batch in testloader:
        features, phone_ids, word_ids, \
            phone_labels, word_labels, utterance_labels = to_device(batch, device)
        
        utterance_preds, phone_preds, word_preds = gopt_model(
            x=features.float(), phn=phone_ids.long())
        
        phone_preds, phone_labels = to_cpu(phone_preds, phone_labels)
        word_preds, word_labels = to_cpu(word_preds, word_labels)
        utterance_preds, utterance_labels = to_cpu(utterance_preds, utterance_labels)
        
        A_phn.append(phone_preds), A_phn_target.append(phone_labels)
        A_utt.append(utterance_preds), A_utt_target.append(utterance_labels)
        A_wrd.append(word_preds), A_wrd_target.append(word_labels), A_wrd_id.append(word_ids)
    
    # phone level
    A_phn, A_phn_target  = torch.vstack(A_phn), torch.vstack(A_phn_target)
    decision_result = calculate_phone_decision_result(A_phn, A_phn_target)

    # word level
    A_word, A_word_target, A_word_id = torch.vstack(A_wrd), torch.vstack(A_wrd_target), torch.vstack(A_wrd_id) 

    # utterance level
    A_utt, A_utt_target = torch.vstack(A_utt), torch.vstack(A_utt_target)

    # valid_token_mse, mae, corr
    phn_mse, phn_mae, phn_corr = valid_phn(A_phn, A_phn_target)
    word_mse, wrd_mae, word_corr = valid_wrd(A_word, A_word_target, A_word_id)
    utt_mse, utt_mae, utt_corr = valid_utt(A_utt, A_utt_target)

    if phn_mse < best_mse:
        best_mse = phn_mse
        save(
            epoch=epoch,
            output_dir=ckpt_dir, 
            model=gopt_model, 
            optimizer=optimizer, 
            phone_desicion_result=decision_result, 
            phone_predicts=A_phn.numpy(), 
            phone_labels=A_phn_target.numpy(), 
            word_predicts=A_word.numpy(), 
            word_labels=A_word_target.numpy(), 
            utterance_predicts=A_utt.numpy(), 
            utterance_labels=A_utt_target.numpy()
        )

    print(f"### Validation result (epoch={epoch})")
    print("  Phone level:  MSE={:.3f}  MAE={:.3f}  PCC={:.3f} ".format(phn_mse, phn_mae, phn_corr))
    print("   Word level:  MSE={:.3f}  MAE={:.3f}  PCC={:.3f} ".format(word_mse, wrd_mae, word_corr))
    print("    Utt level:  MSE={:.3f}  MAE={:.3f}  PCC={:.3f} ".format(utt_mse, utt_mae, utt_corr))

    return {
        "phn_mse": phn_mse, 
        "phn_mae": phn_mae,
        "phn_corr": phn_corr,
        "word_mse": word_mse,
        "wrd_mae": wrd_mae,
        "word_corr": word_corr,
        "utt_mse": utt_mse,
        "utt_mae": utt_mae,
        "utt_corr": utt_corr,
        "best_mse": best_mse
    }

def calculate_phone_decision_result(A_phn, A_phn_target):
    indices = A_phn_target != -1
    _label = A_phn_target[indices].clone()
    _pred = A_phn[indices].clone()

    converted_pred = convert_score_to_color(_pred).view(-1)
    converted_label = convert_score_to_color(_label).view(-1)

    result = classification_report(y_true=converted_label, y_pred=converted_pred)
    print("### F1 Score: \n", result)

    return result


In [12]:
def calculate_losses(phone_preds, phone_labels, word_preds, word_labels, utterance_preds, utterance_labels):
    # phone level
    mask = phone_labels >=0
    phone_preds = phone_preds.squeeze(2) * mask
    phone_labels = phone_labels * mask
    
    loss_phn = loss_fn(phone_preds, phone_labels)
    loss_phn = loss_phn * (mask.shape[0] * mask.shape[1]) / torch.sum(mask)

    # utterance level
    loss_utt = loss_fn(utterance_preds.squeeze(1) ,utterance_labels)

    # word level
    mask = word_labels >= 0      
    word_preds = word_preds.squeeze(2) * mask
    word_labels = word_labels * mask
    
    loss_word = loss_fn(word_preds, word_labels)
    loss_word = loss_word * (mask.shape[0] * mask.shape[1]) / torch.sum(mask)

    return loss_phn, loss_utt, loss_word

In [13]:
global_step = 0
best_mse = 1e5
num_epoch = 50 
phone_weight = 1.0
word_weight = 1.0
utterance_weight = 1.0
ckpt_dir = '/data/codes/prep_ps_pykaldi/exp/ckpts'

for epoch in range(num_epoch):
    gopt_model.train()
    train_tqdm = tqdm(trainloader, "Training")
    for batch in train_tqdm:
        optimizer.zero_grad()

        features, phone_ids, word_ids, \
            phone_labels, word_labels, utterance_labels = to_device(batch, device)

        utterance_preds, phone_preds, word_preds = gopt_model(
            x=features.float(), phn=phone_ids.long())
        
        loss_phn, loss_utt, loss_word = calculate_losses(
            phone_preds=phone_preds, 
            phone_labels=phone_labels, 
            word_preds=word_preds, 
            word_labels=word_labels, 
            utterance_preds=utterance_preds, 
            utterance_labels=utterance_labels)

        # total loss
        loss = phone_weight*loss_phn + word_weight*loss_word + utterance_weight*loss_utt
        
        loss.backward()
        optimizer.step()
        
        global_step += 1
        train_tqdm.set_postfix(
            loss=loss.item(), 
            loss_phn=loss_phn.item(), 
            loss_word=loss_word.item(), 
            loss_utt=loss_utt.item())
    
    valid_result = validate(
        epoch=epoch, 
        gopt_model=gopt_model, 
        testloader=testloader, 
        best_mse=best_mse, 
        ckpt_dir=ckpt_dir)
    
    best_mse = valid_result["best_mse"]
    


Training: 100%|██████████| 18740/18740 [03:42<00:00, 84.14it/s, loss=0.212, loss_phn=0.141, loss_utt=0.0263, loss_word=0.0452]    


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.94      0.83      0.88     99093
         1.0       0.20      0.51      0.29     10946
         2.0       0.74      0.53      0.62     19040

    accuracy                           0.76    129079
   macro avg       0.63      0.62      0.60    129079
weighted avg       0.85      0.76      0.79    129079

Save state dict and result to /data/codes/prep_ps_pykaldi/exp/ckpts
### Validation result (epoch=0)
  Phone level:  MSE=0.189  MAE=0.255  PCC=0.750 
   Word level:  MSE=0.082  MAE=0.211  PCC=0.692 
    Utt level:  MSE=0.067  MAE=0.188  PCC=0.719 


Training: 100%|██████████| 18740/18740 [03:42<00:00, 84.12it/s, loss=0.237, loss_phn=0.129, loss_utt=0.0407, loss_word=0.0673]    


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.81      0.88     99093
         1.0       0.21      0.55      0.30     10946
         2.0       0.73      0.57      0.64     19040

    accuracy                           0.75    129079
   macro avg       0.63      0.65      0.61    129079
weighted avg       0.85      0.75      0.79    129079

Save state dict and result to /data/codes/prep_ps_pykaldi/exp/ckpts
### Validation result (epoch=1)
  Phone level:  MSE=0.181  MAE=0.255  PCC=0.764 
   Word level:  MSE=0.084  MAE=0.217  PCC=0.711 
    Utt level:  MSE=0.072  MAE=0.199  PCC=0.732 


Training: 100%|██████████| 18740/18740 [03:42<00:00, 84.15it/s, loss=0.146, loss_phn=0.0833, loss_utt=0.03, loss_word=0.0331]     


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.82      0.88     99093
         1.0       0.21      0.51      0.30     10946
         2.0       0.72      0.61      0.66     19040

    accuracy                           0.76    129079
   macro avg       0.62      0.65      0.61    129079
weighted avg       0.85      0.76      0.80    129079

### Validation result (epoch=2)
  Phone level:  MSE=0.182  MAE=0.248  PCC=0.771 
   Word level:  MSE=0.073  MAE=0.199  PCC=0.732 
    Utt level:  MSE=0.065  MAE=0.190  PCC=0.742 


Training: 100%|██████████| 18740/18740 [02:41<00:00, 115.70it/s, loss=0.174, loss_phn=0.0936, loss_utt=0.0345, loss_word=0.0461]   


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.82      0.88     99093
         1.0       0.21      0.53      0.30     10946
         2.0       0.74      0.60      0.66     19040

    accuracy                           0.77    129079
   macro avg       0.63      0.65      0.62    129079
weighted avg       0.86      0.77      0.80    129079

Save state dict and result to /data/codes/prep_ps_pykaldi/exp/ckpts
### Validation result (epoch=3)
  Phone level:  MSE=0.174  MAE=0.243  PCC=0.777 
   Word level:  MSE=0.075  MAE=0.206  PCC=0.731 
    Utt level:  MSE=0.063  MAE=0.187  PCC=0.746 


Training: 100%|██████████| 18740/18740 [03:42<00:00, 84.25it/s, loss=0.293, loss_phn=0.197, loss_utt=0.0146, loss_word=0.081]     


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.83      0.88     99093
         1.0       0.21      0.56      0.30     10946
         2.0       0.76      0.54      0.63     19040

    accuracy                           0.76    129079
   macro avg       0.64      0.64      0.61    129079
weighted avg       0.86      0.76      0.80    129079

### Validation result (epoch=4)
  Phone level:  MSE=0.174  MAE=0.248  PCC=0.770 
   Word level:  MSE=0.075  MAE=0.205  PCC=0.722 
    Utt level:  MSE=0.062  MAE=0.187  PCC=0.748 


Training: 100%|██████████| 18740/18740 [03:42<00:00, 84.17it/s, loss=0.285, loss_phn=0.147, loss_utt=0.06, loss_word=0.0773]      


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.96      0.80      0.87     99093
         1.0       0.21      0.58      0.30     10946
         2.0       0.74      0.61      0.67     19040

    accuracy                           0.75    129079
   macro avg       0.63      0.66      0.61    129079
weighted avg       0.86      0.75      0.79    129079

Save state dict and result to /data/codes/prep_ps_pykaldi/exp/ckpts
### Validation result (epoch=5)
  Phone level:  MSE=0.170  MAE=0.249  PCC=0.784 
   Word level:  MSE=0.082  MAE=0.226  PCC=0.745 
    Utt level:  MSE=0.074  MAE=0.214  PCC=0.759 


Training: 100%|██████████| 18740/18740 [03:43<00:00, 84.01it/s, loss=0.296, loss_phn=0.137, loss_utt=0.08, loss_word=0.0789]      


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.96      0.78      0.86     99093
         1.0       0.20      0.61      0.30     10946
         2.0       0.75      0.61      0.67     19040

    accuracy                           0.74    129079
   macro avg       0.64      0.67      0.61    129079
weighted avg       0.87      0.74      0.79    129079

Save state dict and result to /data/codes/prep_ps_pykaldi/exp/ckpts
### Validation result (epoch=6)
  Phone level:  MSE=0.169  MAE=0.261  PCC=0.789 
   Word level:  MSE=0.084  MAE=0.228  PCC=0.747 
    Utt level:  MSE=0.071  MAE=0.207  PCC=0.762 


Training: 100%|██████████| 18740/18740 [03:43<00:00, 84.01it/s, loss=0.145, loss_phn=0.0901, loss_utt=0.0196, loss_word=0.035]    


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.84      0.89     99093
         1.0       0.22      0.59      0.32     10946
         2.0       0.79      0.53      0.64     19040

    accuracy                           0.77    129079
   macro avg       0.66      0.65      0.62    129079
weighted avg       0.87      0.77      0.80    129079

Save state dict and result to /data/codes/prep_ps_pykaldi/exp/ckpts
### Validation result (epoch=7)
  Phone level:  MSE=0.159  MAE=0.237  PCC=0.791 
   Word level:  MSE=0.067  MAE=0.190  PCC=0.753 
    Utt level:  MSE=0.056  MAE=0.174  PCC=0.765 


Training: 100%|██████████| 18740/18740 [03:43<00:00, 83.80it/s, loss=0.22, loss_phn=0.13, loss_utt=0.0472, loss_word=0.043]       


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.96      0.78      0.86     99093
         1.0       0.20      0.63      0.30     10946
         2.0       0.76      0.59      0.66     19040

    accuracy                           0.74    129079
   macro avg       0.64      0.66      0.61    129079
weighted avg       0.87      0.74      0.79    129079

### Validation result (epoch=8)
  Phone level:  MSE=0.166  MAE=0.266  PCC=0.789 
   Word level:  MSE=0.077  MAE=0.222  PCC=0.757 
    Utt level:  MSE=0.067  MAE=0.207  PCC=0.771 


Training: 100%|██████████| 18740/18740 [03:42<00:00, 84.13it/s, loss=0.441, loss_phn=0.284, loss_utt=0.0506, loss_word=0.107]     


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.85      0.90     99093
         1.0       0.23      0.55      0.32     10946
         2.0       0.78      0.56      0.65     19040

    accuracy                           0.78    129079
   macro avg       0.65      0.65      0.62    129079
weighted avg       0.86      0.78      0.81    129079

### Validation result (epoch=9)
  Phone level:  MSE=0.159  MAE=0.228  PCC=0.793 
   Word level:  MSE=0.066  MAE=0.187  PCC=0.759 
    Utt level:  MSE=0.055  MAE=0.171  PCC=0.772 


Training: 100%|██████████| 18740/18740 [03:43<00:00, 83.95it/s, loss=0.136, loss_phn=0.0694, loss_utt=0.0312, loss_word=0.0355]   


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.83      0.89     99093
         1.0       0.22      0.56      0.32     10946
         2.0       0.76      0.62      0.68     19040

    accuracy                           0.77    129079
   macro avg       0.64      0.67      0.63    129079
weighted avg       0.86      0.77      0.81    129079

