#### train

In [None]:
!export CUDA_VISIBLE_DEVICES=0

%cd /data/codes/apa/train/
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import numpy as np
import pickle
import json
import re
import os

from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch import nn
import torch

from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    ConfusionMatrixDisplay
)

from src.dataset import PrepDataset
from src.model import PrepModel
from src.utils.train import (
    load_data,
    convert_score_to_color,
    to_device,
    valid_phn,
    valid_utt,
    valid_wrd,
    to_cpu,
    load_pred_and_label,
    save_confusion_matrix_figure,
    validate,
    save
)


In [None]:
# Checkpoint directory handed to `validate`, the feature directory to read from,
# and the sibling directory that receives the filtered id list.
ckpt_dir = '/data/codes/apa/train/exp/ckpts/dev'
in_dir = "/data/codes/apa/train/data/feats/train/train-data-type-12"
out_dir = f'{in_dir}-filtered'

# makedirs(exist_ok=True) is safe to re-run, creates missing parents, and avoids
# the check-then-create race of `exists` + `mkdir`.
os.makedirs(out_dir, exist_ok=True)

In [None]:
# Dataset configuration. Note that `max_length` is re-assigned (to 256) by the
# model-configuration cell further down.
max_length = 128
relative2id_path = "/data/codes/apa/train/exp/dicts/relative2id.json"
phone2id_path = "/data/codes/apa/train/exp/dicts/phone_dict.json"

# load_data(in_dir) returns the sample ids followed by eleven `*_path` values,
# unpacked here in the order the helper returns them.
(
    ids,
    phone_ids_path,
    word_ids_path,
    phone_scores_path,
    word_scores_path,
    sentence_scores_path,
    fluency_score_path,
    intonation_score_path,
    durations_path,
    gops_path,
    relative_positions_path,
    wavlm_features_path,
) = load_data(in_dir)

dataset_kwargs = dict(
    ids=ids,
    phone_ids_path=phone_ids_path,
    word_ids_path=word_ids_path,
    phone_scores_path=phone_scores_path,
    word_scores_path=word_scores_path,
    sentence_scores_path=sentence_scores_path,
    fluency_score_path=fluency_score_path,
    intonation_scores_path=intonation_score_path,
    durations_path=durations_path,
    gops_path=gops_path,
    relative_positions_path=relative_positions_path,
    wavlm_features_path=wavlm_features_path,
    relative2id_path=relative2id_path,
    phone2id_path=phone2id_path,
    max_length=max_length,
)
dataset = PrepDataset(**dataset_kwargs)

# Training loader: shuffled batches of 8, incomplete final batch dropped.
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=1,
    pin_memory=True,
)

# Smoke test: index every item once so a broken sample surfaces here instead of
# in the middle of training.
for index in tqdm(range(len(dataset))):
    batch = dataset[index]

In [None]:
# Model hyper-parameters passed to PrepModel.
embed_dim = 32
num_heads = 1
depth = 3
input_dim = 855
num_phone = 44
# Rebinds the notebook-level `max_length` that the dataset cell set to 128; any
# cell run after this one sees 256.
max_length = 256

# Optimiser hyper-parameters.
lr = 1e-3
weight_decay = 5e-7
betas = (0.95, 0.999)

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

gopt_model = PrepModel(
    embed_dim=embed_dim,
    num_heads=num_heads,
    depth=depth,
    input_dim=input_dim,
    max_length=max_length,
    num_phone=num_phone,
    dropout=0.1,
)
gopt_model = gopt_model.to(device)

# Only parameters that require gradients are handed to the optimiser.
trainables = list(filter(lambda param: param.requires_grad, gopt_model.parameters()))

optimizer = torch.optim.Adam(
    trainables,
    lr=lr,
    betas=betas,
    weight_decay=weight_decay,
)

loss_fn = nn.MSELoss()

In [None]:
def _masked_mse(preds, labels):
    """Apply the notebook-level `loss_fn` only over positions whose label is >= 0.

    Predictions and labels are zeroed where the label is negative, `loss_fn` is
    applied to the full tensors, and the result is rescaled by
    (number of elements / number of valid positions) so the zeroed positions do
    not dilute the mean.
    """
    mask = labels >= 0
    loss = loss_fn(preds * mask, labels * mask)
    # clamp(min=1) avoids 0/0 = NaN when a batch has no valid label at all; in
    # that case the masked loss is already 0 and stays 0.
    valid_count = torch.sum(mask).clamp(min=1)
    return loss * mask.numel() / valid_count


def calculate_losses(
        phone_preds, phone_labels, word_preds, word_labels,
        utterance_preds, utterance_labels, fluency_preds, fluency_labels,
        intonation_preds, intonation_labels):
    """Compute the five training losses with the notebook-level `loss_fn`.

    Phone and word predictions have their dimension 2 squeezed away and are
    compared with their labels only where the label is >= 0. Utterance, fluency
    and intonation predictions have dimension 1 squeezed away and are compared
    with their labels directly.

    Returns:
        Tuple `(loss_phn, loss_utt, loss_word, loss_utt_flu, loss_utt_int)`.
    """
    # phone level
    loss_phn = _masked_mse(phone_preds.squeeze(2), phone_labels)

    # utterance level
    loss_utt = loss_fn(utterance_preds.squeeze(1), utterance_labels)
    loss_utt_flu = loss_fn(fluency_preds.squeeze(1), fluency_labels)
    loss_utt_int = loss_fn(intonation_preds.squeeze(1), intonation_labels)

    # word level
    loss_word = _masked_mse(word_preds.squeeze(2), word_labels)

    return loss_phn, loss_utt, loss_word, loss_utt_flu, loss_utt_int

In [None]:
# Training-loop state and loss weights.
global_step = 0          # number of optimizer steps taken
best_mse = 1e5           # running best, updated from validate()'s return value
num_epoch = 20
phone_weight = 1.0
word_weight = 1.0
utterance_weight = 1.0

cur_lr = lr
for epoch in range(num_epoch):
    # Step decay: from epoch 5 onwards, multiply the lr by 0.8 on every epoch
    # divisible by 3 (i.e. epochs 6, 9, 12, ...).
    if epoch >= 5 and epoch % 3 == 0:
        cur_lr = (4 / 5) * cur_lr
    # Write cur_lr to the optimizer on every epoch, so that re-running this cell
    # with an already-decayed optimizer keeps the displayed and real lr in sync.
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr

    gopt_model.train()
    train_tqdm = tqdm(dataloader, "Training")
    for batch in train_tqdm:
        optimizer.zero_grad()

        ids, features, phone_ids, word_ids, relative_positions, \
            phone_labels, word_labels, utterance_labels, \
            fluency_labels, intonation_labels = to_device(batch, device)

        utterance_preds, phone_preds, word_preds, flu_preds, int_preds = gopt_model(
            x=features.float(), phn=phone_ids.long(), rel_pos=relative_positions.long())

        loss_phn, loss_utt, loss_word, loss_utt_flu, loss_utt_int = calculate_losses(
            phone_preds=phone_preds,
            phone_labels=phone_labels,
            word_preds=word_preds,
            word_labels=word_labels,
            utterance_preds=utterance_preds,
            utterance_labels=utterance_labels,
            fluency_preds=flu_preds,
            fluency_labels=fluency_labels,
            intonation_preds=int_preds,
            intonation_labels=intonation_labels
        )

        # The three utterance-level losses are averaged so that together they
        # weigh as much as one phone-level or word-level term.
        loss = phone_weight*loss_phn + word_weight*loss_word + \
            utterance_weight*(loss_utt + loss_utt_flu + loss_utt_int)/3

        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(gopt_model.parameters(), 1.0)

        optimizer.step()

        global_step += 1
        train_tqdm.set_postfix(
            lr=cur_lr,
            loss=loss.item(),
            loss_phn=loss_phn.item(),
            loss_word=loss_word.item(),
            loss_utt=loss_utt.item())

    # NOTE(review): validation runs on the training loader (no held-out loader is
    # visible in this notebook) — presumably intentional for label filtering; confirm.
    valid_result = validate(
        epoch=epoch,
        optimizer=optimizer,
        gopt_model=gopt_model,
        testloader=dataloader,
        best_mse=best_mse,
        ckpt_dir=ckpt_dir,
        device=device)

    best_mse = valid_result["best_mse"]

#### infer

In [None]:
!export CUDA_VISIBLE_DEVICES=0

import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import numpy as np
import pickle
import torch
import json
import re
import os

from torch.utils.data import Dataset, DataLoader
from torch import nn

from src.utils.train import (
    to_device,
    load_data,
    load_id
)

In [None]:
# Rebuild the dataset for inference.
# NOTE(review): this cell relies on names bound by earlier cells of the notebook:
# `in_dir`, `relative2id_path`, `phone2id_path`, `max_length` (whatever value was
# assigned last) and `PrepDataset` (not imported by the infer-section import
# cell). Confirm these are the intended values when running this section alone.
ids, phone_ids_path, word_ids_path, \
    phone_scores_path, word_scores_path, sentence_scores_path, fluency_score_path, intonation_score_path, \
    durations_path, gops_path, relative_positions_path, wavlm_features_path = load_data(in_dir)

dataset = PrepDataset(
    ids=ids, 
    phone_ids_path=phone_ids_path, 
    word_ids_path=word_ids_path, 
    phone_scores_path=phone_scores_path, 
    word_scores_path=word_scores_path, 
    sentence_scores_path=sentence_scores_path, 
    fluency_score_path=fluency_score_path,
    intonation_scores_path=intonation_score_path,
    durations_path=durations_path, 
    gops_path=gops_path, 
    relative_positions_path=relative_positions_path, 
    wavlm_features_path=wavlm_features_path,
    relative2id_path=relative2id_path, 
    phone2id_path=phone2id_path,
    max_length=max_length,
)

# One sample per batch (the inference loop asserts batch size 1). Shuffling is
# disabled so the saved scores and ids come out in the same order on every run.
dataloader = DataLoader(
    dataset, 
    batch_size=1, 
    num_workers=1,
    shuffle=False, 
    drop_last=False, 
    pin_memory=True, 
)


In [None]:
# embed_dim=32
# num_heads=1
# depth=3
# input_dim=855
# num_phone=44
# max_length=128

# lr=1e-3
# weight_decay=5e-7
# betas=(0.95, 0.999)

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# gopt_model = PrepModel(
#     embed_dim=embed_dim, num_heads=num_heads, 
#     depth=depth, input_dim=input_dim, 
#     max_length=max_length, num_phone=num_phone, dropout=0.1).to(device)

In [None]:
# ckpt_path = "/data/codes/apa/train/exp/ckpts/dev/ckpts-eph=24-mse=0.17229999601840973/model.pt"
# state_dict = torch.load(ckpt_path, map_location="cpu")
# gopt_model.eval()
# gopt_model.load_state_dict(state_dict)

In [None]:
# Accumulators: one "<utterance id>_<phone index>" string per kept phone, plus the
# matching predicted and reference phone scores.
sample_ids, prep_scores, elsa_scores = [], [], []

gopt_model.eval()
for batch in tqdm(dataloader):
    (batch_ids, features, phone_ids, word_ids, relative_positions,
     phone_labels, word_labels, utterance_labels,
     fluency_labels, intonation_labels) = to_device(batch, device)

    with torch.no_grad():
        # NOTE(review): the method is called as `forwar_phn`; confirm this
        # spelling matches the definition in src.model.
        phone_score = gopt_model.forwar_phn(
            x=features.float(), phn=phone_ids.long()).squeeze(-1)

    # The id construction below uses batch_ids[0], so only batch size 1 is valid.
    assert phone_score.shape[0] == 1

    # Positions labelled -1 are excluded from ids and from both score lists.
    valid = phone_labels != -1
    sample_ids.extend(f'{batch_ids[0]}_{index}' for index in range(valid.sum()))

    elsa_scores.append(phone_labels[valid])
    prep_scores.append(phone_score[valid])

# Flatten the per-utterance tensors into single 1-D tensors on the CPU.
elsa_scores = torch.concat(elsa_scores).cpu()
prep_scores = torch.concat(prep_scores).cpu()

In [None]:
# NOTE(review): this writes the phone-level ids ("<utterance id>_<index>") to
# `{in_dir}/id`. If load_data(in_dir) reads the dataset's utterance ids from that
# same file, this overwrites it — confirm before running.
# The explicit encoding matches how the filtered id file is written later on.
with open(f'{in_dir}/id', "w", encoding="utf-8") as f:
    f.write("\n".join(sample_ids))

# Persist scores and ids so the filter section can reload them from disk.
np.save(f'{in_dir}/infer-prep_scores.npy', prep_scores.numpy())
np.save(f'{in_dir}/infer-elsa_scores.npy', elsa_scores.numpy())
np.save(f'{in_dir}/infer-phone_ids.npy', np.asarray(sample_ids))

#### filter

In [None]:
%cd /data/codes/apa/train

from sklearn.metrics import classification_report
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch

from src.utils.train import (
    convert_score_to_color
)

In [None]:
# Reload the arrays that the inference section saved under `in_dir`.
# Note: this rebinds `ids`, which earlier cells bound to the first value returned
# by load_data(in_dir); from here on it holds the saved phone-level ids.
prep_scores = np.load(f'{in_dir}/infer-prep_scores.npy')
elsa_scores = np.load(f'{in_dir}/infer-elsa_scores.npy')
ids = np.load(f'{in_dir}/infer-phone_ids.npy')

In [None]:
def _scores_to_decisions(scores):
    """Run `convert_score_to_color` on a copy of `scores` and return an int NumPy array."""
    # The clone keeps the helper from touching the memory shared with `scores`.
    colors = convert_score_to_color(torch.from_numpy(scores).clone())
    return colors.int().numpy()


prep_decisions = _scores_to_decisions(prep_scores)
elsa_decisions = _scores_to_decisions(elsa_scores)

# Reference decisions are treated as ground truth, model decisions as predictions.
print(classification_report(y_true=elsa_decisions, y_pred=prep_decisions))

In [None]:
# Display names for the integer decision classes.
id2label = {0:"GREEN", 1:"YELLOW", 2:"RED"}

# One row per phone: its id, both decisions and both raw scores.
df = pd.DataFrame(
    {
        "id": ids.tolist(),
        "prep": prep_decisions.tolist(),
        "elsa": elsa_decisions.tolist(),
        "prep_score": prep_scores.tolist(),
        "elsa_score": elsa_scores.tolist(),
    }
)

df["prep"] = df.prep.apply(lambda x: id2label[x])
df["elsa"] = df.elsa.apply(lambda x: id2label[x])
# Phone ids were built as "<utterance id>_<index>", so strip only the LAST
# "_<index>" suffix; split("_")[0] would truncate utterance ids that contain "_".
df["uid"] = df.id.apply(lambda x: x.rsplit("_", 1)[0])
# Absolute gap between the model score and the reference score.
df["diff"] = np.abs(df["prep_score"] - df["elsa_score"])
df.head(2)

In [None]:
# Per-utterance count of phones whose score gap exceeds 30/50.
large_gap_counts = df.loc[df["diff"] > 30/50, "uid"].value_counts()
phones_per_uid = df.uid.value_counts()

# Total utterances, utterances with more than one large-gap phone, and utterances
# where more than 10% of the phones have a large gap.
print(df.uid.nunique())
print((large_gap_counts > 1).sum())
print((large_gap_counts / phones_per_uid > 0.1).sum())

In [None]:
# Alternative criterion (disabled): drop utterances with more than one large-gap phone.
# ignore_samples = df[df["diff"] > 30/50].uid.value_counts() > 1

# Active criterion: drop utterances where more than 10% of the phones have a score
# gap above 30/50. Utterances with no large-gap phone become NaN in the division
# and therefore compare as False.
large_gap_ratio = df.loc[df["diff"] > 30/50, "uid"].value_counts() / df.uid.value_counts()
ignore_samples = large_gap_ratio > 0.1
ignore_samples = ignore_samples[ignore_samples]

keep_rows = ~df.uid.isin(ignore_samples.index)
filtered_samples = df[keep_rows]

# Row count before and after filtering, then the number of surviving utterances.
print(df.shape)
print(filtered_samples.shape)
print(filtered_samples.uid.unique().shape)

In [None]:
# Write the surviving utterance ids, one per line, to the filtered directory.
id_path = f"{out_dir}/id"
kept_uids = filtered_samples.uid.unique().tolist()
with open(id_path, "w", encoding="utf-8") as f:
    f.write("\n".join(kept_uids))

In [None]:
# Show where the filtered id list was written.
id_path