In [None]:
import os
import time
import math
import torch
from scipy.io import wavfile
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

from ProcessLine.M0 import M0

### Setup

In [None]:
# Partition ids 1..10 -> "<two-digit number><F|M>": odd ids get the "F" suffix,
# even ids the "M" suffix (e.g. 1 -> "01F", 2 -> "01M", ..., 10 -> "05M").
part_id2str = {
    pid: "{:02d}{}".format((pid + 1) // 2, "F" if pid % 2 else "M")
    for pid in range(1, 11)
}
part_id = 5
# Sub-directory (under ./Models and ./Logs) used for this partition/model pair
mode = "{}/M0".format(part_id2str[part_id])

# --- optimisation hyper-parameters ---
epochs = 100
lr = 1e-5
accumulation_steps = 8   # optimizer steps once every this many samples
warm_up = 10             # warm-up length used by the learning-rate schedule

# --- checkpoint-saving policy ---
save_steps_flag = True   # save the model every `save_steps` epochs
save_steps = 10
save_best_flag = False   # save the model whenever dev accuracy improves

# --- resuming from a checkpoint (checkpoint_epoch == 0 means a fresh run) ---
checkpoint_epoch = 0
checkpoint_path = ""
time_stamp = ""

### CUDA Info

In [None]:
# Report the visible CUDA devices and choose the compute device.
# torch.cuda.get_device_name() raises when no CUDA device is usable, so it is
# only queried when CUDA is available; otherwise the notebook runs on the CPU.
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

### Model

In [None]:
class Model(nn.Module):
    """Thin wrapper around the ``M0`` network that namespaces its outputs under ``"m0"``."""

    def __init__(self):
        super().__init__()
        self.m0 = M0()

    def forward(self, speech, text):
        """Run ``M0`` on ``speech`` and return ``{"m0": <M0 output>}``.

        ``text`` is accepted for interface compatibility but is not used here.
        """
        return {"m0": self.m0(speech)}

    def loss(self, logits, label):
        """Delegate to ``M0.loss`` using the ``"m0"`` entry of ``logits`` and ``label``."""
        return self.m0.loss(logits["m0"], label)

# Build a fresh model; when resuming (checkpoint_epoch > 0) replace it with the
# whole-model pickle stored at checkpoint_path.
model = Model()
if checkpoint_epoch>0:
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint written on a GPU can also be resumed on a CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from a trusted source. Recent PyTorch releases may also need
    # weights_only=False to load a whole pickled module -- confirm against the
    # installed version.
    model = torch.load(checkpoint_path, map_location=device)
model.to(device)

print(model)

### Data

In [None]:
# Label name -> class id: 'e0'..'e3' map to 0..3
label2id = {"e{}".format(class_id): class_id for class_id in range(4)}
# Dataset
class dataset(Dataset):
    """In-memory dataset yielding ``(row index, waveform, text, label id)`` tuples.

    ``mode == "train"`` reads the ``.train.csv`` listing of the current
    partition; any other value reads the ``.test.csv`` listing. Every wav file
    named in the ``file`` column is read eagerly in the constructor.
    """

    def __init__(self, mode):
        split = "train" if mode == "train" else "test"
        df_data = pd.read_csv("./Data/iemocap/iemocap_"+part_id2str[part_id]+"."+split+".csv")

        # Parallel lists: position i in each list describes the same sample
        self.speech, self.label, self.text, self.index = [], [], [], []

        for row_index, row in tqdm(df_data.iterrows()):
            emotion, wav_path, transcript = row["emotion"], row["file"], row["text"]
            # wavfile.read returns (sample_rate, samples); only the samples are kept
            _, samples = wavfile.read(wav_path)
            self.speech.append(samples)
            self.label.append(torch.tensor(label2id[emotion]))
            self.text.append(transcript)
            self.index.append(row_index)

        self.len = len(self.label)
        print("Load <", mode,"> data successfully! \n\tTotal "+str(self.len)+" samples.")

    def __getitem__(self,index):
        """Return ``(csv row index, waveform, text, label)`` for sample ``index``."""
        columns = (self.index, self.speech, self.text, self.label)
        return tuple(column[index] for column in columns)

    def __len__(self):
        """Number of samples loaded by the constructor."""
        return self.len

# Load both splits into memory. Any mode other than "train" (here "dev") makes
# the dataset class read the partition's .test.csv listing.
train_dataset = dataset("train")
dev_dataset = dataset("dev")

# Dataloader
# batch_size=1: one utterance per step; larger effective batches are obtained
# through gradient accumulation in the training loop.
# shuffle=True reshuffles the training samples every epoch so that consecutive
# accumulation windows do not always contain the same samples in the same order.
# The dev loader keeps the CSV order.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True, num_workers=4)
dev_dataloader = DataLoader(dataset=dev_dataset, batch_size=1, num_workers=4)
len_train_dataloder = len(train_dataloader)
len_dev_dataloader = len(dev_dataloader)
print("Make dataloder successfully! \n\ttrain:",len_train_dataloder,"\n\tdev:",len_dev_dataloader)

### Logs

In [None]:
# time stamp: a fresh run gets a new stamp; a resumed run (checkpoint_epoch > 0)
# keeps the time_stamp set in the configuration so it appends to the same logs
if checkpoint_epoch==0:
    time_stamp = time.strftime("%Y-%m-%d_%H-%M", time.localtime())
# make dirs (exist_ok=True: no separate existence check, safe to re-run the cell)
if save_steps_flag or save_best_flag:
    os.makedirs("./Models/"+mode+"/", exist_ok=True)
os.makedirs("./Logs/"+mode+"/", exist_ok=True)
os.makedirs("./Logs/"+mode+"/Logits/"+time_stamp+"/", exist_ok=True)
# make log files; opened in "a+" so a resumed run appends and the handles can
# later be rewound and read back. They are intentionally left open because the
# following cells keep writing to / reading from them.
train_log = open("./Logs/"+mode+"/"+time_stamp+"_train.txt","a+")
dev_log = open("./Logs/"+mode+"/"+time_stamp+"_dev.txt","a+")
train_acc_log = open("./Logs/"+mode+"/"+time_stamp+"_train_acc.txt","a+")
dev_acc_log = open("./Logs/"+mode+"/"+time_stamp+"_dev_acc.txt","a+")
lr_log = open("./Logs/"+mode+"/"+time_stamp+"_lr.txt","a+")
# first line of the lr log records the base hyper-parameters of this run
lr_log.write("lr="+str(lr)+",accumulation_steps="+str(accumulation_steps)+"\n")
lr_log.flush()

### Train

In [None]:
def _current_lr():
    """Return the learning rate currently set on the optimizer's first parameter group."""
    return optimizer.param_groups[0]['lr']


def _save_checkpoint(epoch, dev_loss, dev_ctc_loss, dev_cls_loss, dev_acc):
    """Pickle the whole model under ./Models/<mode>/, encoding epoch, lr and dev metrics in the file name."""
    torch.save(model,"./Models/"+mode+"/"+time_stamp
            +"_Epoch"+str(epoch)
            +"_Lr"+str(_current_lr())
            +"_DevLoss"+str(np.around(dev_loss,3))
            +"_DevCTCLoss"+str(np.around(dev_ctc_loss,3))
            +"_DevCLSLoss"+str(np.around(dev_cls_loss,3))
            +"_DevAcc"+str(np.around(dev_acc,3))
            +".pt")


optimizer = torch.optim.AdamW([{'params': model.parameters(), 'lr': lr}])
# Multiplier applied to the base lr; the scheduler is stepped once per epoch.
# It grows linearly while (epoch+1) < warm_up and decays as (epoch+1)^-0.5
# afterwards, scaled by sqrt(512).
# NOTE(review): 512 is presumably a model dimension -- confirm.
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch:
        math.pow(512, 0.5)*min(math.pow(epoch+1,-0.5), (epoch+1)*math.pow(warm_up,-1.5))
    )
# Train!
# `steps` and the running loss sums are reset only after an optimizer step, so
# an accumulation window may span an epoch boundary.
steps = 0
loss_item = 0
loss_ctc_item = 0
loss_cls_item = 0
max_dev_acc = 0
for epoch in tqdm(range(1+checkpoint_epoch,epochs+checkpoint_epoch+1)):
    logits_record = {}  # save single epoch logits
    train_acc_num = 0
    model.train()
    for index,(data_index, speech, text, label) in enumerate(train_dataloader):
        steps += 1
        label = label.to(device)
        logits = model(speech, text)
        output = torch.argmax(logits["m0"]["cls"], dim=-1)
        train_acc_num += (output == label).sum().item()
        loss = model.loss(logits, {"cls": label, "ctc": text})
        loss_total, loss_ctc, loss_cls = loss["total"], loss["ctc"], loss["cls"]
        loss_item += loss_total.cpu().item()
        loss_ctc_item += loss_ctc.cpu().item()
        loss_cls_item += loss_cls.cpu().item()
        # scale so that the accumulated gradient is the mean over the window
        (loss_total/accumulation_steps).backward()
        if steps%accumulation_steps == 0:
            optimizer.step()
            print((index+1)//accumulation_steps,"/",
                    len_train_dataloder//accumulation_steps,
                    "-->",np.around(loss_item/accumulation_steps,4),
                    "\tctc_loss:",np.around(loss_ctc_item/accumulation_steps,4),
                    "\tcls_loss:",np.around(loss_cls_item/accumulation_steps,4),
                    "\tlr:",_current_lr()
                    )
            # logging
            train_log.write(str(loss_item/accumulation_steps)+"\t"+str(loss_ctc_item/accumulation_steps)+"\t"+str(loss_cls_item/accumulation_steps)+"\n")
            train_log.flush()
            # reset
            optimizer.zero_grad()
            steps = 0
            loss_item = 0
            loss_ctc_item = 0
            loss_cls_item = 0
    # logging
    train_acc = 100*train_acc_num / len(train_dataset)
    train_acc_log.write(str(train_acc)+"\n")
    train_acc_log.flush()
    # update learning rate
    scheduler.step()

    ## dev
    model.eval()
    dev_loss = 0
    dev_ctc_loss = 0
    dev_cls_loss = 0
    dev_index = 0
    dev_acc_num = 0
    with torch.no_grad():
        for index,(data_index, speech, text, label) in enumerate(dev_dataloader):
            label = label.to(device)
            logits = model(speech, text)
            output = torch.argmax(logits["m0"]["cls"], dim=-1)
            dev_acc_num += (output == label).sum().item()
            loss = model.loss(logits, {"cls": label, "ctc": text})
            dev_loss += loss["total"].cpu().item()
            dev_ctc_loss += loss["ctc"].cpu().item()
            dev_cls_loss += loss["cls"].cpu().item()
            dev_index += 1
            # keyed by the sample's CSV row index (batch_size is 1, hence .item())
            logits_record[data_index.item()] = logits["m0"]["cls"].detach().cpu().tolist()
    # mean losses over the dev samples
    dev_loss /= dev_index
    dev_ctc_loss /= dev_index
    dev_cls_loss /= dev_index
    dev_acc = 100 * dev_acc_num / len(dev_dataset)
    # logging: one CSV of dev logits per epoch plus the scalar logs
    df_logits_record = pd.DataFrame.from_dict(logits_record, orient="index")
    df_logits_record.columns = ["logits"]
    df_logits_record.index.name = "index"
    df_logits_record.to_csv("./Logs/"+mode+"/Logits/"+time_stamp+"/"+str(epoch)+".csv")
    dev_log.write(str(dev_loss)+"\n")
    dev_log.flush()
    dev_acc_log.write(str(dev_acc)+"\n")
    dev_acc_log.flush()
    lr_log.write(str(_current_lr())+"\n")
    lr_log.flush()
    print("Epoch:",epoch,
            "\tDev_Loss:",round(dev_loss, 3),
            "\tDev_CTC_Loss:",round(dev_ctc_loss, 3),
            "\tDev_CLS_Loss:",round(dev_cls_loss, 3),
            "\tDev_acc:",round(dev_acc, 2),
            "%\tTrain_acc:",round(train_acc, 2),
            "%\tAcc_num:",dev_acc_num)
    # Checkpointing: periodic saves and best-dev-accuracy saves use the same
    # file name, so the model is written once even when both conditions hold.
    periodic_save = save_steps_flag and epoch%save_steps == 0
    best_save = save_best_flag and epoch>epochs/2 and max_dev_acc<dev_acc
    if best_save:
        max_dev_acc = dev_acc
    if periodic_save or best_save:
        _save_checkpoint(epoch, dev_loss, dev_ctc_loss, dev_cls_loss, dev_acc)

### Plot

In [None]:
# Rewind the still-open accuracy logs so they can be read back from the start.
dev_acc_log.seek(0,0)
train_acc_log.seek(0,0)
# The logs hold one bare number per line with no header row. header=None stops
# pandas from consuming the first epoch's accuracy as the column name (which
# would silently drop that data point); names= labels the single column.
df_dev_acc = pd.read_csv(dev_acc_log, header=None, names=["dev acc"])
df_train_acc = pd.read_csv(train_acc_log, header=None, names=["train acc"])
# Draw both curves on the same axes and save the figure next to the logs.
ax = df_dev_acc.plot()
df_train_acc.plot(ax=ax)
plt.xlabel("epochs")
plt.ylabel("acc")
plt.savefig("./Logs/"+mode+"/"+time_stamp+".png")