In [1]:
!export CUDA_VISIBLE_DEVICES=1

In [2]:
%cd /data/codes/apa/train/
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    ConfusionMatrixDisplay
)
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch

import pickle
import json
import re

from torch.utils.data import DataLoader
from torch import nn

from src.dataset import PrepDataset
import os

/data/codes/apa/train


In [3]:
def convert_score_to_color(score, YELLOW_GREEN=80/50, RED_YELLOW=30/50):
    if RED_YELLOW is not None:
        LABEL2ID = {"GREEN": 0, "YELLOW": 1, "RED":2}
        red_index = score < RED_YELLOW
        yellow_index = ((score >= RED_YELLOW).int() & (score < YELLOW_GREEN).int()).bool()
        green_index = score >= YELLOW_GREEN
    else:
        LABEL2ID = {"GREEN": 0, "YELLOW": 1, "RED":1}
        RED_YELLOW = 30/50
        red_index = score < RED_YELLOW
        yellow_index = ((score >= RED_YELLOW).int() & (score < YELLOW_GREEN).int()).bool()
        green_index = score >= YELLOW_GREEN


    score[red_index] = LABEL2ID["RED"]
    score[yellow_index] = LABEL2ID["YELLOW"]
    score[green_index] = LABEL2ID["GREEN"]

    return score

def load_data(data_dir):
    phone_ids = np.load(f'{data_dir}/phone_ids.npy')
    word_ids = np.load(f'{data_dir}/word_ids.npy')
    
    phone_scores = np.load(f'{data_dir}/phone_scores.npy')
    word_scores = np.load(f'{data_dir}/word_scores.npy')
    sentence_scores = np.load(f'{data_dir}/sentence_scores.npy')

    durations = np.load(f'{data_dir}/duration.npy')
    gops = np.load(f'{data_dir}/gop.npy')
    # wavlm_features = np.load(f'{data_dir}/wavlm_features.npy')
    wavlm_features_path = f'{data_dir}/wavlm_features'

    relative_positions = np.load(f'{data_dir}/relative_positions.npy')

    return phone_ids, word_ids, phone_scores, word_scores, sentence_scores, durations, gops, relative_positions, wavlm_features_path


In [4]:
from torch.utils.data import Dataset, DataLoader
from src.indexed_dataset import IndexedDataset
import torch

class PrepDataset(Dataset):
    def __init__(self, phone_ids, word_ids, phone_scores, word_scores, \
            sentence_scores, durations, gops, relative_positions, wavlm_features_path):
        self.phone_ids = phone_ids
        self.word_ids = word_ids

        self.phone_scores = phone_scores
        self.word_scores = word_scores
        self.sentence_scores = sentence_scores

        self.gops = gops
        self.durations = durations
        self.wavlm_features = IndexedDataset(wavlm_features_path)
        self.relative_positions = relative_positions

    def __len__(self):
        return self.phone_ids.shape[0]
    
    def parse_data(self, phone_ids, word_ids, phone_scores, word_scores, \
            sentence_scores, durations, gops, wavlm_features, relative_positions):
        
        phone_ids = torch.tensor(phone_ids)
        word_ids = torch.tensor(word_ids)

        phone_scores = torch.tensor(phone_scores).float().clone()
        word_scores = torch.tensor(word_scores).float().clone()
        sentence_scores = torch.tensor(sentence_scores).float().clone()

        phone_scores[phone_scores != -1] /= 50
        word_scores[word_scores != -1] /= 50
        sentence_scores /= 50

        durations = torch.tensor(durations)
        gops = torch.tensor(gops)
        wavlm_features = torch.tensor(wavlm_features)
        relative_positions = torch.tensor(relative_positions)

        features = torch.concat([gops, durations.unsqueeze(-1), wavlm_features], dim=-1)        
        return {
            "features": features,
            "phone_ids": phone_ids,
            "word_ids": word_ids,
            "phone_scores":phone_scores,
            "word_scores":word_scores,
            "sentence_scores":sentence_scores,
            "relative_positions": relative_positions
        }
        
    def __getitem__(self, index):
        phone_ids = self.phone_ids[index]
        word_ids = self.word_ids[index]

        phone_scores = self.phone_scores[index]
        word_scores = self.word_scores[index]
        sentence_scores = self.sentence_scores[index]

        gops = self.gops[index]
        durations = self.durations[index]
        wavlm_features = self.wavlm_features[index]
        relative_positions = self.relative_positions[index]

        return self.parse_data(
            phone_ids=phone_ids,
            word_ids=word_ids,
            phone_scores=phone_scores,
            word_scores=word_scores,
            sentence_scores=sentence_scores,
            gops=gops,
            durations=durations,
            wavlm_features=wavlm_features,
            relative_positions=relative_positions
        )

In [5]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

ckpt_dir = '/data/codes/apa/train/exps/test'
train_dir = "/data/codes/apa/train/exps/features/train/merged"
test_dir = "/data/codes/apa/train/exps/features/test/out-long"

phone_ids, word_ids, phone_scores, word_scores, sentence_scores, durations, \
    gops, relative_positions, wavlm_features_path = load_data(train_dir)
trainset = PrepDataset(
    phone_ids, word_ids, 
    phone_scores, word_scores, sentence_scores, 
    durations, gops, relative_positions, wavlm_features_path)

trainloader = DataLoader(trainset, batch_size=8, shuffle=True, drop_last=False, pin_memory=True, num_workers=1)

phone_ids, word_ids, phone_scores, word_scores, sentence_scores, durations, \
    gops, relative_positions, wavlm_features_path = load_data(test_dir)
    
testset = PrepDataset(
    phone_ids, word_ids, 
    phone_scores, word_scores, sentence_scores, 
    durations, gops, relative_positions, wavlm_features_path)

testloader = DataLoader(testset, batch_size=64, shuffle=False, drop_last=True, num_workers=1)

In [None]:
from src.model import PrepModel

In [None]:
embed_dim=32
num_heads=1
depth=3
input_dim=853
num_phone=43
max_length=128

lr=1e-3
weight_decay=5e-7
betas=(0.95, 0.999)

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

gopt_model = PrepModel(
    embed_dim=embed_dim, num_heads=num_heads, 
    depth=depth, input_dim=input_dim, 
    max_length=max_length, num_phone=num_phone, dropout=0.1).to(device)

trainables = [p for p in gopt_model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(
    trainables, lr, 
    weight_decay=weight_decay, 
    betas=betas
)

loss_fn = nn.MSELoss()

In [None]:
state_dict = torch.load("/data/codes/apa/train/exps/test/ckpts-eph=10-mse=0.16089999675750732/model.pt")
gopt_model.load_state_dict(state_dict)

In [None]:
def valid_phn(predict, target):
    preds, targs = [], []

    for i in range(predict.shape[0]):
        for j in range(predict.shape[1]):
            if target[i, j] >= 0:
                preds.append(predict[i, j])
                targs.append(target[i, j])
    targs = np.array(targs)
    preds = np.array(preds)

    mse = np.mean((targs - preds) ** 2)
    mae = np.mean(np.abs(targs - preds))
    corr = np.corrcoef(preds, targs)[0, 1]
    return mse, mae, corr


In [None]:
def valid_wrd(predict, target, word_id):
    preds, targs = [], []

    for i in range(target.shape[0]):
        prev_w_id, start_id = 0, 0
        # for each token
        for j in range(target.shape[1]):
            cur_w_id = word_id[i, j].int()
            # if a new word
            if cur_w_id != prev_w_id:
                # average each phone belongs to the word
                preds.append(np.mean(predict[i, start_id: j].numpy(), axis=0))
                targs.append(np.mean(target[i, start_id: j].numpy(), axis=0))

                if cur_w_id == -1:
                    break
                else:
                    prev_w_id = cur_w_id
                    start_id = j

    preds = np.array(preds)
    targs = np.array(targs).round(2)

    word_mse = np.mean((preds - targs) ** 2)
    wrd_mae = np.mean(np.abs(preds - targs))
    word_corr = np.corrcoef(preds, targs)[0, 1]
    
    return word_mse, wrd_mae, word_corr

In [None]:
def valid_utt(predict, target):
    utt_mse = np.mean(((predict[:, 0] - target[:, 0]) ** 2).numpy())
    utt_mae = np.mean((np.abs(predict[:, 0] - target[:, 0])).numpy())
    
    utt_corr = np.corrcoef(predict[:, 0], target[:, 0])[0, 1]
    return utt_mse, utt_mae, utt_corr


In [None]:
def to_device(batch, device):
    features = batch["features"].to(device)
    phone_ids = batch["phone_ids"].to(device)
    relative_positions = batch["relative_positions"].to(device)
    word_ids = batch["word_ids"]
    
    phone_labels = batch["phone_scores"].to(device)
    word_labels = batch["word_scores"].to(device)
    utterance_labels = batch["sentence_scores"].to(device)

    return features, phone_ids, word_ids, relative_positions, phone_labels, word_labels, utterance_labels

def to_cpu(preds, labels):
    preds = preds.detach().cpu().squeeze(-1)
    labels = labels.detach().cpu()

    return preds, labels

In [None]:
def load_pred_and_label(pred_path, label_path):
    pred = np.load(pred_path)
    label = np.load(label_path)

    pred = np.concatenate(pred)
    label = np.concatenate(label)
    index = label != -1    
    
    return label[index], pred[index]

def save_confusion_matrix_figure(
        fig_path, pred_path, label_path, YELLOW_GREEN=80/50, RED_YELLOW=30/50):
    
    label, pred = load_pred_and_label(pred_path=pred_path, label_path=label_path)
    
    actual = convert_score_to_color(
        torch.from_numpy(label), YELLOW_GREEN=YELLOW_GREEN, RED_YELLOW=RED_YELLOW)
    
    predicted = convert_score_to_color(
        torch.from_numpy(pred), YELLOW_GREEN=YELLOW_GREEN, RED_YELLOW=RED_YELLOW)
    
    cfs_mtr = confusion_matrix(actual, predicted)
    cfs_mtr = cfs_mtr / cfs_mtr.sum(axis=1, keepdims=True)
    if RED_YELLOW is not None:
        cm_display = ConfusionMatrixDisplay(
            confusion_matrix = cfs_mtr, display_labels = ["GREEN", "YELLOW", "RED"])
    else:
        cm_display = ConfusionMatrixDisplay(
            confusion_matrix = cfs_mtr, display_labels = ["CORRECT", "INCORRECT"])

    plt.title("Confusion Matrix")
    cm_display.plot(cmap='Blues')
    plt.savefig(fig_path) 
    plt.close()

def save(epoch, output_dir, model, optimizer, phone_desicion_result, \
    phone_predicts, phone_labels, word_predicts, word_labels, utterance_predicts, utterance_labels):
    
    model_path = f'{output_dir}/model.pt'
    optimizer_path = f'{output_dir}/optimizer.pt'
    phone_desicion_result_path = f'{output_dir}/phone_result'

    phone_predict_path = f'{output_dir}/phn_pred.npy'
    phone_label_path = f'{output_dir}/phn_label.npy'
    word_predict_path = f'{output_dir}/wrd_pred.npy'
    word_label_path = f'{output_dir}/wrd_label.npy'
    utterance_predict_path = f'{output_dir}/utt_pred.npy'
    utterance_label_path = f'{output_dir}/utt_label.npy'

    three_class_fig_path = f'{output_dir}/confusion_matrix_three_class.png'
    two_class_fig_path = f'{output_dir}/confusion_matrix_two_class.png'

    with open(phone_desicion_result_path, "w") as f:
        f.write(phone_desicion_result)

    torch.save(model.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optimizer_path)
    np.save(phone_predict_path, phone_predicts)
    np.save(phone_label_path, phone_labels)
    np.save(word_predict_path, word_predicts)
    np.save(word_label_path, word_labels)
    np.save(utterance_predict_path, utterance_predicts)
    np.save(utterance_label_path, utterance_labels)
    save_confusion_matrix_figure(three_class_fig_path, phone_predict_path, phone_label_path, YELLOW_GREEN=80/50, RED_YELLOW=40/50)
    save_confusion_matrix_figure(two_class_fig_path, phone_predict_path, phone_label_path, YELLOW_GREEN=80/50, RED_YELLOW=None)

    print(f'Save state dict and result to {output_dir}')

In [None]:
@torch.no_grad()
def validate(epoch, gopt_model, testloader, best_mse, ckpt_dir):
    gopt_model.eval()
    A_phn, A_phn_target = [], []
    A_utt, A_utt_target = [], []
    A_wrd, A_wrd_target, A_wrd_id = [], [], []

    for batch in testloader:
        features, phone_ids, word_ids, relative_positions,\
            phone_labels, word_labels, utterance_labels = to_device(batch, device)
        
        utterance_preds, phone_preds, word_preds = gopt_model(
            x=features.float(), phn=phone_ids.long(), rel_pos=relative_positions.long())
        
        phone_preds, phone_labels = to_cpu(phone_preds, phone_labels)
        word_preds, word_labels = to_cpu(word_preds, word_labels)
        utterance_preds, utterance_labels = to_cpu(utterance_preds, utterance_labels)
        
        A_phn.append(phone_preds), A_phn_target.append(phone_labels)
        A_utt.append(utterance_preds), A_utt_target.append(utterance_labels)
        A_wrd.append(word_preds), A_wrd_target.append(word_labels), A_wrd_id.append(word_ids)
    
    # phone level
    A_phn, A_phn_target  = torch.vstack(A_phn), torch.vstack(A_phn_target)
    decision_result = calculate_phone_decision_result(A_phn, A_phn_target)

    # word level
    A_word, A_word_target, A_word_id = torch.vstack(A_wrd), torch.vstack(A_wrd_target), torch.vstack(A_wrd_id) 

    # utterance level
    A_utt, A_utt_target = torch.vstack(A_utt), torch.vstack(A_utt_target)

    # valid_token_mse, mae, corr
    phn_mse, phn_mae, phn_corr = valid_phn(A_phn, A_phn_target)
    word_mse, wrd_mae, word_corr = valid_wrd(A_word, A_word_target, A_word_id)
    utt_mse, utt_mae, utt_corr = valid_utt(A_utt, A_utt_target)

    if phn_mse < best_mse:
        best_mse = phn_mse
    ckpt_dir = f'{ckpt_dir}/ckpts-eph={epoch}-mse={round(phn_mse, 4)}'
    os.makedirs(ckpt_dir)
    
    save(
        epoch=epoch,
        output_dir=ckpt_dir, 
        model=gopt_model, 
        optimizer=optimizer, 
        phone_desicion_result=decision_result,
        phone_predicts=A_phn.numpy(), 
        phone_labels=A_phn_target.numpy(), 
        word_predicts=A_word.numpy(), 
        word_labels=A_word_target.numpy(), 
        utterance_predicts=A_utt.numpy(), 
        utterance_labels=A_utt_target.numpy()
    )
    
    with open(f'{ckpt_dir}/pcc', "w") as f:
        f.write("Phone level:  MSE={:.3f}  MAE={:.3f}  PCC={:.3f} \n".format(phn_mse, phn_mae, phn_corr))
        f.write("Word level:  MSE={:.3f}  MAE={:.3f}  PCC={:.3f} \n".format(word_mse, wrd_mae, word_corr))
        f.write("Utt level:  MSE={:.3f}  MAE={:.3f}  PCC={:.3f} \n".format(utt_mse, utt_mae, utt_corr))

    print(f"### Validation result (epoch={epoch})")
    print("  Phone level:  MSE={:.3f}  MAE={:.3f}  PCC={:.3f} ".format(phn_mse, phn_mae, phn_corr))
    print("   Word level:  MSE={:.3f}  MAE={:.3f}  PCC={:.3f} ".format(word_mse, wrd_mae, word_corr))
    print("    Utt level:  MSE={:.3f}  MAE={:.3f}  PCC={:.3f} ".format(utt_mse, utt_mae, utt_corr))

    return {
        "phn_mse": phn_mse, 
        "phn_mae": phn_mae,
        "phn_corr": phn_corr,
        "word_mse": word_mse,
        "wrd_mae": wrd_mae,
        "word_corr": word_corr,
        "utt_mse": utt_mse,
        "utt_mae": utt_mae,
        "utt_corr": utt_corr,
        "best_mse": best_mse
    }

def calculate_phone_decision_result(A_phn, A_phn_target):
    indices = A_phn_target != -1
    _label = A_phn_target[indices].clone()
    _pred = A_phn[indices].clone()

    converted_pred = convert_score_to_color(_pred).view(-1)
    converted_label = convert_score_to_color(_label).view(-1)

    result = classification_report(y_true=converted_label, y_pred=converted_pred)
    print("### F1 Score: \n", result)

    return result



In [None]:
def calculate_losses(phone_preds, phone_labels, word_preds, word_labels, utterance_preds, utterance_labels):
    # phone level
    mask = phone_labels >=0
    phone_preds = phone_preds.squeeze(2) * mask
    phone_labels = phone_labels * mask
    
    loss_phn = loss_fn(phone_preds, phone_labels)
    loss_phn = loss_phn * (mask.shape[0] * mask.shape[1]) / torch.sum(mask)

    # utterance level
    loss_utt = loss_fn(utterance_preds.squeeze(1) ,utterance_labels)
    # loss_utt = torch.tensor(0)

    # word level
    mask = word_labels >= 0      
    word_preds = word_preds.squeeze(2) * mask
    word_labels = word_labels * mask
    
    loss_word = loss_fn(word_preds, word_labels)
    loss_word = loss_word * (mask.shape[0] * mask.shape[1]) / torch.sum(mask)

    return loss_phn, loss_utt, loss_word

In [None]:
global_step = 0
best_mse = 1e5
num_epoch = 50 
phone_weight = 1.0
word_weight = 1.0
utterance_weight = 1.0

cur_lr = lr
for epoch in range(num_epoch):
    if epoch >= 10 and epoch % 3 == 0:
        cur_lr = (4 / 5) * cur_lr 
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr

    gopt_model.train()
    train_tqdm = tqdm(trainloader, "Training")
    for batch in train_tqdm:
        optimizer.zero_grad()

        features, phone_ids, word_ids, relative_positions,\
            phone_labels, word_labels, utterance_labels = to_device(batch, device)
        
        utterance_preds, phone_preds, word_preds = gopt_model(
            x=features.float(), phn=phone_ids.long(), rel_pos=relative_positions.long())
                
        loss_phn, loss_utt, loss_word = calculate_losses(
            phone_preds=phone_preds, 
            phone_labels=phone_labels, 
            word_preds=word_preds, 
            word_labels=word_labels, 
            utterance_preds=utterance_preds, 
            utterance_labels=utterance_labels)

        loss = phone_weight*loss_phn + word_weight*loss_word + utterance_weight*loss_utt
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(gopt_model.parameters(), 1.0)
        
        optimizer.step()
        
        global_step += 1
        train_tqdm.set_postfix(
            lr=cur_lr,
            loss=loss.item(), 
            loss_phn=loss_phn.item(), 
            loss_word=loss_word.item(), 
            loss_utt=loss_utt.item())
    
    valid_result = validate(
        epoch=epoch, 
        gopt_model=gopt_model, 
        testloader=testloader, 
        best_mse=best_mse, 
        ckpt_dir=ckpt_dir)
    
    best_mse = valid_result["best_mse"]
    global_step += 1

Training: 100%|██████████| 33968/33968 [02:34<00:00, 220.42it/s, loss=0.318, loss_phn=0.27, loss_utt=0.03, loss_word=0.0181, lr=0.001]       


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.93      0.83      0.88     61683
         1.0       0.20      0.57      0.29      6508
         2.0       0.81      0.42      0.56     12113

    accuracy                           0.75     80304
   macro avg       0.65      0.61      0.58     80304
weighted avg       0.86      0.75      0.78     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=0-mse=0.19220000505447388
### Validation result (epoch=0)
  Phone level:  MSE=0.192  MAE=0.271  PCC=0.742 
   Word level:  MSE=0.137  MAE=0.271  PCC=0.735 
    Utt level:  MSE=0.050  MAE=0.177  PCC=0.784 


Training: 100%|██████████| 33968/33968 [02:38<00:00, 213.79it/s, loss=1.04, loss_phn=0.361, loss_utt=0.304, loss_word=0.377, lr=0.001]       


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.93      0.87      0.90     61683
         1.0       0.21      0.53      0.30      6508
         2.0       0.82      0.42      0.56     12113

    accuracy                           0.77     80304
   macro avg       0.65      0.61      0.58     80304
weighted avg       0.85      0.77      0.80     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=1-mse=0.1873999983072281
### Validation result (epoch=1)
  Phone level:  MSE=0.187  MAE=0.250  PCC=0.754 
   Word level:  MSE=0.136  MAE=0.262  PCC=0.752 
    Utt level:  MSE=0.052  MAE=0.177  PCC=0.807 


Training: 100%|██████████| 33968/33968 [02:39<00:00, 213.19it/s, loss=0.266, loss_phn=0.133, loss_utt=0.018, loss_word=0.115, lr=0.001]      


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.94      0.83      0.88     61683
         1.0       0.21      0.58      0.30      6508
         2.0       0.80      0.51      0.62     12113

    accuracy                           0.76     80304
   macro avg       0.65      0.64      0.60     80304
weighted avg       0.86      0.76      0.80     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=2-mse=0.17430000007152557
### Validation result (epoch=2)
  Phone level:  MSE=0.174  MAE=0.251  PCC=0.771 
   Word level:  MSE=0.124  MAE=0.260  PCC=0.763 
    Utt level:  MSE=0.041  MAE=0.160  PCC=0.829 


Training: 100%|██████████| 33968/33968 [02:39<00:00, 212.45it/s, loss=0.193, loss_phn=0.146, loss_utt=0.0186, loss_word=0.0287, lr=0.001]    


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.82      0.88     61683
         1.0       0.20      0.61      0.30      6508
         2.0       0.82      0.48      0.61     12113

    accuracy                           0.75     80304
   macro avg       0.65      0.64      0.60     80304
weighted avg       0.87      0.75      0.79     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=3-mse=0.17110000550746918
### Validation result (epoch=3)
  Phone level:  MSE=0.171  MAE=0.261  PCC=0.776 
   Word level:  MSE=0.123  MAE=0.255  PCC=0.766 
    Utt level:  MSE=0.040  MAE=0.159  PCC=0.831 


Training: 100%|██████████| 33968/33968 [02:40<00:00, 212.16it/s, loss=0.0796, loss_phn=0.0467, loss_utt=0.00585, loss_word=0.0271, lr=0.001] 


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.94      0.86      0.90     61683
         1.0       0.22      0.54      0.31      6508
         2.0       0.82      0.48      0.61     12113

    accuracy                           0.78     80304
   macro avg       0.66      0.63      0.60     80304
weighted avg       0.86      0.78      0.81     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=4-mse=0.16910000145435333
### Validation result (epoch=4)
  Phone level:  MSE=0.169  MAE=0.237  PCC=0.780 
   Word level:  MSE=0.128  MAE=0.250  PCC=0.772 
    Utt level:  MSE=0.045  MAE=0.166  PCC=0.829 


Training: 100%|██████████| 33968/33968 [02:40<00:00, 211.63it/s, loss=0.141, loss_phn=0.0969, loss_utt=0.0192, loss_word=0.0246, lr=0.001]   


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.94      0.85      0.89     61683
         1.0       0.21      0.58      0.31      6508
         2.0       0.83      0.46      0.59     12113

    accuracy                           0.77     80304
   macro avg       0.66      0.63      0.60     80304
weighted avg       0.87      0.77      0.80     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=5-mse=0.16899999976158142
### Validation result (epoch=5)
  Phone level:  MSE=0.169  MAE=0.243  PCC=0.779 
   Word level:  MSE=0.121  MAE=0.244  PCC=0.773 
    Utt level:  MSE=0.045  MAE=0.165  PCC=0.840 


Training: 100%|██████████| 33968/33968 [02:40<00:00, 211.58it/s, loss=0.401, loss_phn=0.333, loss_utt=0.04, loss_word=0.0283, lr=0.001]      


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.94      0.86      0.90     61683
         1.0       0.22      0.56      0.32      6508
         2.0       0.82      0.50      0.62     12113

    accuracy                           0.78     80304
   macro avg       0.66      0.64      0.61     80304
weighted avg       0.86      0.78      0.81     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=6-mse=0.16410000622272491
### Validation result (epoch=6)
  Phone level:  MSE=0.164  MAE=0.236  PCC=0.786 
   Word level:  MSE=0.116  MAE=0.243  PCC=0.782 
    Utt level:  MSE=0.041  MAE=0.159  PCC=0.835 


Training: 100%|██████████| 33968/33968 [02:40<00:00, 211.42it/s, loss=0.15, loss_phn=0.0951, loss_utt=0.00779, loss_word=0.0468, lr=0.001]   


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.84      0.89     61683
         1.0       0.21      0.60      0.31      6508
         2.0       0.82      0.50      0.62     12113

    accuracy                           0.77     80304
   macro avg       0.66      0.64      0.61     80304
weighted avg       0.87      0.77      0.80     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=7-mse=0.1607999950647354
### Validation result (epoch=7)
  Phone level:  MSE=0.161  MAE=0.241  PCC=0.790 
   Word level:  MSE=0.117  MAE=0.250  PCC=0.779 
    Utt level:  MSE=0.035  MAE=0.148  PCC=0.855 


Training: 100%|██████████| 33968/33968 [02:41<00:00, 210.83it/s, loss=0.461, loss_phn=0.191, loss_utt=0.115, loss_word=0.156, lr=0.001]      


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.94      0.86      0.90     61683
         1.0       0.22      0.58      0.32      6508
         2.0       0.85      0.44      0.58     12113

    accuracy                           0.78     80304
   macro avg       0.67      0.63      0.60     80304
weighted avg       0.87      0.78      0.80     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=8-mse=0.16609999537467957
### Validation result (epoch=8)
  Phone level:  MSE=0.166  MAE=0.244  PCC=0.784 
   Word level:  MSE=0.121  MAE=0.248  PCC=0.782 
    Utt level:  MSE=0.044  MAE=0.165  PCC=0.841 


Training: 100%|██████████| 33968/33968 [02:40<00:00, 211.21it/s, loss=0.0746, loss_phn=0.0379, loss_utt=0.0179, loss_word=0.0187, lr=0.001]  


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.84      0.89     61683
         1.0       0.21      0.59      0.31      6508
         2.0       0.81      0.52      0.64     12113

    accuracy                           0.77     80304
   macro avg       0.66      0.65      0.61     80304
weighted avg       0.87      0.77      0.80     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=9-mse=0.16179999709129333
### Validation result (epoch=9)
  Phone level:  MSE=0.162  MAE=0.239  PCC=0.789 
   Word level:  MSE=0.114  MAE=0.241  PCC=0.787 
    Utt level:  MSE=0.034  MAE=0.144  PCC=0.858 


Training: 100%|██████████| 33968/33968 [02:52<00:00, 197.02it/s, loss=0.246, loss_phn=0.177, loss_utt=0.00494, loss_word=0.0642, lr=0.001]   


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.94      0.86      0.90     61683
         1.0       0.22      0.58      0.32      6508
         2.0       0.84      0.47      0.61     12113

    accuracy                           0.78     80304
   macro avg       0.67      0.64      0.61     80304
weighted avg       0.87      0.78      0.81     80304

Save state dict and result to /data/codes/apa/train/exps/test/ckpts-eph=10-mse=0.16089999675750732
### Validation result (epoch=10)
  Phone level:  MSE=0.161  MAE=0.237  PCC=0.791 
   Word level:  MSE=0.114  MAE=0.239  PCC=0.785 
    Utt level:  MSE=0.036  MAE=0.147  PCC=0.854 


Training: 100%|██████████| 33968/33968 [02:44<00:00, 206.93it/s, loss=0.864, loss_phn=0.74, loss_utt=0.0514, loss_word=0.0723, lr=0.001]     


### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.95      0.83      0.88     61683
         1.0       0.21      0.61      0.31      6508
         2.0       0.82      0.51      0.63     12113

    accuracy                           0.76     80304
   macro avg       0.66      0.65      0.61     80304
weighted avg       0.87      0.76      0.80     80304

