In [None]:
!export CUDA_VISIBLE_DEVICES=0

%cd /data/codes/apa/train/
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import numpy as np

import pickle
import json
import re
import os

from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch

from src.utils.train import (
    load_data,
    to_device,
    validate
)

from src.dataset import PrepDataset
from src.model import PrepModel

In [None]:
ckpt_dir = '/data/codes/apa/train/exp/dev'
train_dir = "/data/codes/apa/train/data/feats/train/train-data-type-12/"
# FIXME(review): test_dir is the same directory as train_dir, so validation metrics
# (and best-checkpoint selection) are computed on the training data. Point this at a
# held-out split.
test_dir = "/data/codes/apa/train/data/feats/train/train-data-type-12/"


def build_prep_dataset(data_dir):
    """Load the pre-extracted features under `data_dir` and wrap them in a PrepDataset.

    `load_data` returns ten parallel collections in a fixed positional order; they are
    unpacked here and forwarded to `PrepDataset` by keyword.
    """
    (ids, phone_ids, word_ids, phone_scores, word_scores, sentence_scores,
     durations, gops, relative_positions, wavlm_features_path) = load_data(data_dir)

    return PrepDataset(
        ids=ids,
        phone_ids=phone_ids,
        word_ids=word_ids,
        durations=durations,
        gops=gops,
        phone_scores=phone_scores,
        relative_positions=relative_positions,
        sentence_scores=sentence_scores,
        wavlm_features_path=wavlm_features_path,
        word_scores=word_scores,
    )


trainset = build_prep_dataset(train_dir)
trainloader = DataLoader(
    trainset,
    batch_size=8,
    num_workers=1,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

testset = build_prep_dataset(test_dir)
# drop_last=False: evaluation must see every sample. With drop_last=True the last
# partial batch (up to 63 samples) was silently skipped, and a test set smaller than
# one batch produced no batches at all.
testloader = DataLoader(
    testset,
    batch_size=64,
    num_workers=1,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

In [None]:
# --- Model hyper-parameters ---
embed_dim = 32
num_heads = 1
depth = 3
input_dim = 855    # per-position input feature width fed to PrepModel -- TODO confirm it matches the features in train_dir
num_phone = 44     # presumably the size of the phone inventory used for phone_ids -- TODO confirm
max_length = 128   # presumably the maximum sequence length PrepModel supports -- TODO confirm
dropout = 0.1

# --- Optimiser hyper-parameters ---
lr = 1e-3
weight_decay = 5e-7
betas = (0.95, 0.999)

# Seed the RNGs before building the model so that weight initialisation and the
# DataLoader shuffle order (drawn from torch's global RNG when iteration starts)
# are reproducible across kernel restarts.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

gopt_model = PrepModel(
    embed_dim=embed_dim,
    num_heads=num_heads,
    depth=depth,
    input_dim=input_dim,
    max_length=max_length,
    num_phone=num_phone,
    dropout=dropout,
).to(device)

# Only optimise parameters that are not frozen.
trainables = [p for p in gopt_model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(
    trainables,
    lr=lr,
    weight_decay=weight_decay,
    betas=betas,
)

loss_fn = nn.MSELoss()

In [None]:
def _masked_mse(preds, labels):
    """Mean squared error restricted to positions whose label is non-negative.

    Negative labels mark padded / unscored positions and are excluded from both the
    summed error and the element count. Returns 0 (not NaN) when no position is valid.
    """
    mask = labels >= 0
    sq_err = (preds - labels) ** 2
    # Zero out padded positions via torch.where so whatever the model emits there
    # cannot enter the sum.
    sq_err = torch.where(mask, sq_err, torch.zeros_like(sq_err))
    # clamp(min=1): a batch with no valid labels yields 0 / 1 = 0 instead of 0 / 0 = NaN,
    # which would otherwise propagate NaN gradients into every parameter.
    return sq_err.sum() / mask.sum().clamp(min=1)


def calculate_losses(phone_preds, phone_labels, word_preds, word_labels, utterance_preds, utterance_labels):
    """Compute the phone-, utterance- and word-level regression losses.

    Args:
        phone_preds: phone-level predictions with a trailing singleton dim that is
            squeezed at dim 2 (presumably (batch, n_phones, 1) -- TODO confirm).
        phone_labels: phone-level targets shaped like the squeezed predictions;
            values < 0 mark padding.
        word_preds: word-level predictions, same layout as `phone_preds`.
        word_labels: word-level targets, same convention as `phone_labels`.
        utterance_preds: utterance-level predictions, squeezed at dim 1.
        utterance_labels: utterance-level targets.

    Returns:
        (loss_phn, loss_utt, loss_word): scalar tensors. The phone and word losses are
        MSE averaged over valid (label >= 0) positions only; the utterance loss is
        plain mean-reduced MSE.
    """
    loss_phn = _masked_mse(phone_preds.squeeze(2), phone_labels)
    loss_utt = nn.functional.mse_loss(utterance_preds.squeeze(1), utterance_labels)
    loss_word = _masked_mse(word_preds.squeeze(2), word_labels)
    return loss_phn, loss_utt, loss_word

In [None]:
global_step = 0
best_mse = 1e5
num_epoch = 50
phone_weight = 1.0
word_weight = 1.0
utterance_weight = 1.0

# Step-decay LR schedule: from `lr_decay_start` on, every epoch divisible by
# `lr_decay_every` multiplies the learning rate by `lr_decay_factor`
# (i.e. epochs 12, 15, 18, ...).
lr_decay_start = 10
lr_decay_every = 3
lr_decay_factor = 4 / 5
max_grad_norm = 1.0

cur_lr = lr
# Reset the optimiser to the base LR so that re-running this cell does not silently
# continue from the decayed LR left behind by a previous run (while cur_lr claims `lr`).
for param_group in optimizer.param_groups:
    param_group['lr'] = cur_lr

for epoch in range(num_epoch):
    if epoch >= lr_decay_start and epoch % lr_decay_every == 0:
        cur_lr = lr_decay_factor * cur_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = cur_lr

    gopt_model.train()
    train_tqdm = tqdm(trainloader, "Training")
    for batch in train_tqdm:
        optimizer.zero_grad()

        ids, features, phone_ids, word_ids, relative_positions, \
            phone_labels, word_labels, utterance_labels = to_device(batch, device)

        utterance_preds, phone_preds, word_preds = gopt_model(
            x=features.float(), phn=phone_ids.long(), rel_pos=relative_positions.long())

        loss_phn, loss_utt, loss_word = calculate_losses(
            phone_preds=phone_preds,
            phone_labels=phone_labels,
            word_preds=word_preds,
            word_labels=word_labels,
            utterance_preds=utterance_preds,
            utterance_labels=utterance_labels)

        loss = phone_weight * loss_phn + word_weight * loss_word + utterance_weight * loss_utt

        loss.backward()
        # Clip the global gradient norm before the update to keep steps bounded.
        torch.nn.utils.clip_grad_norm_(gopt_model.parameters(), max_grad_norm)
        optimizer.step()

        # global_step counts optimiser updates: exactly one per training batch.
        global_step += 1
        train_tqdm.set_postfix(
            lr=cur_lr,
            loss=loss.item(),
            loss_phn=loss_phn.item(),
            loss_word=loss_word.item(),
            loss_utt=loss_utt.item())

    # validate() returns a dict carrying the (possibly improved) best validation MSE;
    # presumably it also writes a checkpoint into ckpt_dir on improvement -- TODO confirm.
    valid_result = validate(
        epoch=epoch,
        optimizer=optimizer,
        gopt_model=gopt_model,
        testloader=testloader,
        best_mse=best_mse,
        ckpt_dir=ckpt_dir,
        device=device)

    best_mse = valid_result["best_mse"]