In [1]:
!export CUDA_VISIBLE_DEVICES=0

%cd /data/codes/apa/train/
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import numpy as np

import pickle
import json
import re
import os

from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch

from src.utils.train import (
    load_data,
    to_device,
    validate
)

from src.dataset import PrepDataset
from src.model import PrepModel

/data/codes/apa/train


In [2]:
ckpt_dir = '/data/codes/apa/train/exp/fine-tuning'
test_dir = "//data/codes/apa/train/data/feats/train/train-data-type-12-filtered/"
ckpt_path = '/data/codes/apa/train/exp/dev/ckpts-eph=14-mse=0.12870000302791595/model.pt'

max_length=128
relative2id_path="/data/codes/apa/train/exp/dicts/relative2id.json"
phone2id_path="/data/codes/apa/train/exp/dicts/phone_dict.json"

ids, phone_ids_path, word_ids_path, \
    phone_scores_path, word_scores_path, sentence_scores_path, fluency_score_path, intonation_score_path, \
    durations_path, gops_path, relative_positions_path, wavlm_features_path = load_data(test_dir)

testset = PrepDataset(
    ids=ids, 
    phone_ids_path=phone_ids_path, 
    word_ids_path=word_ids_path, 
    phone_scores_path=phone_scores_path, 
    word_scores_path=word_scores_path, 
    sentence_scores_path=sentence_scores_path, 
    fluency_score_path=fluency_score_path,
    intonation_scores_path=intonation_score_path,
    durations_path=durations_path, 
    gops_path=gops_path, 
    relative_positions_path=relative_positions_path, 
    wavlm_features_path=wavlm_features_path,
    relative2id_path=relative2id_path, 
    phone2id_path=phone2id_path,
    max_length=max_length,

)

testloader = DataLoader(
    testset, 
    num_workers=1,
    batch_size=64, 
    shuffle=False, 
    drop_last=True, 
)

In [3]:
embed_dim=32
num_heads=1
depth=3
input_dim=855
num_phone=44
max_length=256
dropout=0.1

lr=5e-4
weight_decay=5e-7
betas=(0.95, 0.999)

device = torch.device(
    "cuda:0" if torch.cuda.is_available() else "cpu")

gopt_model = PrepModel(
    embed_dim=embed_dim, 
    num_heads=num_heads, 
    depth=depth, 
    input_dim=input_dim, 
    max_length=max_length, 
    num_phone=num_phone, 
    dropout=dropout).to(device)

trainables = [p for p in gopt_model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(
    trainables, lr, 
    weight_decay=weight_decay, 
    betas=betas
)

loss_fn = nn.MSELoss()

In [4]:
state_dict = torch.load(ckpt_path, map_location="cpu")

gopt_model.load_state_dict(state_dict)

<All keys matched successfully>

In [5]:
valid_result = validate(
    epoch=-1, 
    optimizer=optimizer,
    gopt_model=gopt_model, 
    testloader=testloader, 
    best_mse=-1, 
    is_save=False,
    ckpt_dir=ckpt_dir,
    device=device)

### F1 Score: 
               precision    recall  f1-score   support

         0.0       0.97      0.90      0.93    199927
         1.0       0.21      0.54      0.31     14599
         2.0       0.84      0.64      0.73     31139

    accuracy                           0.84    245665
   macro avg       0.67      0.69      0.66    245665
weighted avg       0.91      0.84      0.87    245665

### Validation result (epoch=-1)
  Phone level (ACC): MSE=0.102  MAE=0.178  PCC=0.857 
   Word level (ACC): MSE=0.054  MAE=0.167  PCC=0.835 
    Utt level (ACC): MSE=0.047  MAE=0.160  PCC=0.837 
    Utt level (Intonation):  MSE=0.145  MAE=0.328  PCC=0.674
    Utt level (Fluency):  MSE=0.112  MAE=0.292  PCC=0.673

