In [3]:
import os, sys, time, random, torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm
import hyp
import models
import state_data as aug
import preprocess_data as prep

def store_result(store_epoch_acc_val, store_epoch_loss_val, store_qwk_epoch_loss_val, HIDDEN, ONE_HOT, DATA_AUG, training_dir):
    """Summarise a training run and prune all but the best checkpoints.

    Prints the best validation accuracy, loss and QWK together with the
    (1-based) epoch at which each occurred, writes the same summary plus the
    run's hyper-parameters to ``<training_dir>/HYP.txt``, and deletes every
    ``checkpoint_<epoch>.<ext>`` file in ``training_dir`` whose epoch is not
    one of those three best epochs.

    Args:
        store_epoch_acc_val: per-epoch validation accuracies (higher is better).
        store_epoch_loss_val: per-epoch validation losses (lower is better).
        store_qwk_epoch_loss_val: per-epoch validation QWK scores (higher is better).
        HIDDEN: hidden-layer specification, recorded verbatim in HYP.txt.
        ONE_HOT: one-hot flag, recorded verbatim in HYP.txt.
        DATA_AUG: data-augmentation flag, recorded verbatim in HYP.txt.
        training_dir: directory holding the run's checkpoints.

    If any of the three metric lists is empty (e.g. the run was interrupted
    before the first epoch finished) there is nothing to rank, so a notice is
    printed and the function returns without touching ``training_dir``.
    """
    if not (store_epoch_acc_val and store_epoch_loss_val and store_qwk_epoch_loss_val):
        # max()/min() would raise ValueError on an empty sequence.
        print("\nNo completed epochs to summarise for {}".format(training_dir))
        return

    most_acc = max(store_epoch_acc_val)
    min_loss = min(store_epoch_loss_val)
    qwk_max_loss = max(store_qwk_epoch_loss_val)

    # Epoch numbers are 1-based; list.index returns the first occurrence on ties.
    acc_epoch = store_epoch_acc_val.index(most_acc) + 1
    loss_epoch = store_epoch_loss_val.index(min_loss) + 1
    qwk_epoch = store_qwk_epoch_loss_val.index(qwk_max_loss) + 1

    # One summary string shared by the console and HYP.txt so they cannot drift apart.
    summary = "Highest accuracy of {} occured at {}...\nMinimum loss of {} occured at {}... \nMaximum QWK metric of {} occured at {}".format(
        most_acc, acc_epoch,
        min_loss, loss_epoch,
        qwk_max_loss, qwk_epoch)
    print("\n" + summary)

    with open(os.path.join(training_dir, "HYP.txt"), "w") as f:
        f.write("EPOCH = {} \n".format(hyp.EPOCHS))
        f.write("LR = {} \n".format(hyp.LR))
        f.write("HIDDEN_LAYERS = {} \n".format(HIDDEN))
        f.write("ONE_HOT = {} \n".format(ONE_HOT))
        f.write("DATA_AUG = {} \n".format(DATA_AUG))
        f.write(summary)

    epochs_to_keep = {acc_epoch, loss_epoch, qwk_epoch}
    for checkpoint in os.listdir(training_dir):
        if "checkpoint" not in checkpoint:
            continue
        try:
            # File names look like "checkpoint_<epoch>.pth".
            checkpoint_num = int(checkpoint[checkpoint.index("_") + 1:checkpoint.index(".")])
        except ValueError:
            # Not a "checkpoint_<epoch>.<ext>" name; leave it alone.
            continue
        if checkpoint_num not in epochs_to_keep:
            os.remove(os.path.join(training_dir, checkpoint))

def _mean_or_zero(values):
    """Return the float32 mean of *values* as a Python float (0.0 if empty)."""
    if not values:
        return 0.0
    return torch.tensor(values, dtype=torch.float32).mean().item()


def _batch_qwk(y, prediction):
    """Return the quadratic-weighted Cohen's kappa between labels and argmax predictions."""
    return float(cohen_kappa_score(y.detach().cpu().numpy(),
                                   np.argmax(prediction.detach().cpu().numpy(), axis=1),
                                   weights="quadratic"))


def train(model, HIDDEN, ONE_HOT, DATA_AUG, data_train_loader, data_val_loader):
    """Train *model* for ``hyp.EPOCHS`` epochs, validating after every epoch.

    A fresh ``./training_<ONE_HOT>+<DATA_AUG>_<len(HIDDEN)>_<max(HIDDEN)>_<timestamp>``
    directory is created. After each epoch the model weights are saved there
    as ``checkpoint_<epoch>.pth``, the validation curves are re-plotted to
    ``Loss.png``, and an (empty) ``misclassified/checkpoint_<epoch>``
    directory is created. When training finishes, or is interrupted with
    Ctrl-C, ``store_result`` is called to summarise the run, provided at
    least one epoch was validated.

    Args:
        model: network to optimise; it is called on CUDA tensors.
        HIDDEN: hidden-layer specification (used for the directory name and summary).
        ONE_HOT: one-hot flag (used for the directory name and summary).
        DATA_AUG: data-augmentation flag (used for the directory name and summary).
        data_train_loader: iterable of ``(X, y)`` training batches.
        data_val_loader: iterable of ``(X, y)`` validation batches.

    Note:
        Epoch-level loss, QWK and accuracy are unweighted means of the
        per-batch values, so a smaller final batch counts as much as a full one.
    """
    print("Training...")
    training_dir = './training_{}+{}_{}_{}_{}'.format(ONE_HOT, DATA_AUG, len(HIDDEN), max(HIDDEN), time.time())
    os.mkdir(training_dir)
    os.mkdir(training_dir+'/misclassified')
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=hyp.LR, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

    ce_loss = nn.CrossEntropyLoss().cuda()
    store_epoch_loss = []
    store_qwk_epoch_loss = []
    store_epoch_loss_val = []
    store_qwk_epoch_loss_val = []
    store_epoch_acc_val = []
    try:
        for e in tqdm(range(hyp.EPOCHS)):
            epoch = e + 1

            # ---- training pass ----
            train_losses = []
            train_qwks = []
            for X, y in data_train_loader:
                # Move both inputs and labels to the GPU; a no-op if already there.
                X, y = X.cuda(), y.cuda()
                optimizer.zero_grad()
                prediction = model(X)
                batch_loss = ce_loss(prediction, y)
                batch_loss.backward()
                optimizer.step()
                # .item() stores a plain float so the autograd graph can be freed.
                train_losses.append(batch_loss.item())
                train_qwks.append(_batch_qwk(y, prediction))

            store_epoch_loss.append(_mean_or_zero(train_losses))
            store_qwk_epoch_loss.append(_mean_or_zero(train_qwks))
            torch.save(model.state_dict(), "{}/checkpoint_{}.pth".format(training_dir, epoch))

            # ---- validation pass ----
            model.eval()
            val_losses = []
            val_qwks = []
            val_error_rates = []
            misclassified_images = []
            with torch.no_grad():
                for X, y in data_val_loader:
                    X, y = X.cuda(), y.cuda()
                    prediction = model(X)
                    val_losses.append(ce_loss(prediction, y).item())
                    val_qwks.append(_batch_qwk(y, prediction))
                    misclassified = prediction.argmax(dim=-1) != y
                    misclassified_images.append(X[misclassified].cpu())
                    # Fraction of wrong predictions in this batch.
                    val_error_rates.append(misclassified.float().mean().item())

            store_epoch_loss_val.append(_mean_or_zero(val_losses))
            store_qwk_epoch_loss_val.append(_mean_or_zero(val_qwks))
            # Accuracy is the complement of the mean per-batch error rate.
            store_epoch_acc_val.append(1 - _mean_or_zero(val_error_rates))

            # The first epoch is left out of the plot, as in the original code.
            plt.plot(store_epoch_loss_val[1:], label="Validation Loss")
            plt.plot(store_qwk_epoch_loss_val[1:], label="Validation Metric(QWK)")
            plt.plot(store_epoch_acc_val[1:], label="Validation Accuracy")
            plt.legend()
            plt.grid()
            plt.savefig("{}/Loss.png".format(training_dir))
            plt.close()

            if misclassified_images:
                # NOTE(review): the misclassified samples are gathered and the
                # directory is created, but nothing is written into it yet.
                misclassified_images = torch.cat(misclassified_images, dim=0)
                validation_dir = training_dir+'/misclassified/checkpoint_{}'.format(epoch)
                os.mkdir(validation_dir)
            model.train()
    except KeyboardInterrupt:
        # Ctrl-C stops training early; fall through and summarise what exists.
        pass

    # Nothing to rank if no epoch was validated (e.g. interrupted during epoch 1).
    if store_epoch_acc_val:
        store_result(store_epoch_acc_val, store_epoch_loss_val, store_qwk_epoch_loss_val, HIDDEN, ONE_HOT, DATA_AUG, training_dir)

if __name__ == "__main__":
    # Run one training job per combination of the two preprocessing switches.
    for ONE_HOT in [0, 1]:  # for MaturitySize, FurLength, Health
        for DATA_AUG in [0, 1]:  # for state data
            loaders = prep.preprocess_data(ONE_HOT, DATA_AUG)
            data_train_loader, data_val_loader = loaders
            # Every run uses the third hidden-layer configuration from hyp.
            network = models.Model(hyp.HIDDEN_LIST[2], ONE_HOT, DATA_AUG).cuda()
            train(
                network,
                hyp.HIDDEN_LIST[2],
                ONE_HOT,
                DATA_AUG,
                data_train_loader,
                data_val_loader,
            )


  0%|          | 0/1000 [00:00<?, ?it/s][A

Training...



  0%|          | 1/1000 [00:02<46:24,  2.79s/it][A
  0%|          | 2/1000 [00:05<45:59,  2.77s/it][A
  0%|          | 3/1000 [00:08<45:40,  2.75s/it][A
  0%|          | 4/1000 [00:10<45:22,  2.73s/it][A


Highest accuracy of 0.4391006827354431 occured at 4...
Minimum loss occured at 4... 
Maximum QWK metric of 0.4511677324771881 occured at 4




  0%|          | 0/1000 [00:00<?, ?it/s][A[A

Training...




  0%|          | 1/1000 [00:02<43:59,  2.64s/it][A[A

  0%|          | 2/1000 [00:05<44:03,  2.65s/it][A[A


Highest accuracy of 0.33313632011413574 occured at 2...
Minimum loss occured at 2... 
Maximum QWK metric of 0.08528832346200943 occured at 2





  0%|          | 0/1000 [00:00<?, ?it/s][A[A[A

Training...





  0%|          | 1/1000 [00:02<44:26,  2.67s/it][A[A[A


  0%|          | 2/1000 [00:05<44:49,  2.70s/it][A[A[A


Highest accuracy of 0.3620961308479309 occured at 2...
Minimum loss occured at 2... 
Maximum QWK metric of 0.2737426161766052 occured at 2






  0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A

Training...






  0%|          | 1/1000 [00:02<43:43,  2.63s/it][A[A[A[A



  0%|          | 2/1000 [00:05<43:49,  2.63s/it][A[A[A[A



  0%|          | 3/1000 [00:07<43:53,  2.64s/it][A[A[A[A



  0%|          | 4/1000 [00:10<44:02,  2.65s/it][A[A[A[A



  0%|          | 5/1000 [00:13<44:10,  2.66s/it][A[A[A[A



  1%|          | 6/1000 [00:15<44:15,  2.67s/it][A[A[A[A



  1%|          | 7/1000 [00:18<44:15,  2.67s/it][A[A[A[A



  1%|          | 8/1000 [00:21<44:19,  2.68s/it][A[A[A[A



  1%|          | 9/1000 [00:24<44:20,  2.68s/it][A[A[A[A



  1%|          | 10/1000 [00:26<44:12,  2.68s/it][A[A[A[A



  1%|          | 11/1000 [00:29<44:20,  2.69s/it][A[A[A[A


Highest accuracy of 0.3657653331756592 occured at 8...
Minimum loss occured at 11... 
Maximum QWK metric of 0.20535168051719666 occured at 8


in main function: 0 0






  0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A

in model: 0 0
Training...






  0%|          | 1/1000 [00:02<44:39,  2.68s/it][A[A[A[A


Highest accuracy of 0.3918193578720093 occured at 1...
Minimum loss of 1.364376187324524 occured at 1... 
Maximum QWK metric of 0.28678271174430847 occured at 1
in main function: 0 1







  0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A[A

in model: 0 1
Training...







  0%|          | 1/1000 [00:02<43:26,  2.61s/it][A[A[A[A[A




[A[A[A[A[A


Highest accuracy of 0.24059295654296875 occured at 1...
Minimum loss of 1.531187653541565 occured at 1... 
Maximum QWK metric of -0.04076163470745087 occured at 1
in main function: 1 0







  0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A[A

in model: 1 0
Training...







  0%|          | 1/1000 [00:02<43:02,  2.58s/it][A[A[A[A[A


Highest accuracy of 0.35421591997146606 occured at 1...
Minimum loss of 1.3958059549331665 occured at 1... 
Maximum QWK metric of 0.26124417781829834 occured at 1
in main function: 1 1








  0%|          | 0/1000 [00:00<?, ?it/s][A[A[A[A[A[A

in model: 1 1
Training...








  0%|          | 1/1000 [00:02<42:35,  2.56s/it][A[A[A[A[A[A





[A[A[A[A[A[A


Highest accuracy of 0.29395681619644165 occured at 1...
Minimum loss of 1.5496681928634644 occured at 1... 
Maximum QWK metric of -0.017652668058872223 occured at 1
