In [1]:
import os, sys, time, random, torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm
import hyp
import models
import state_data as aug
import preprocess_data as prep

def store_result(store_epoch_acc_val, store_epoch_loss_val, store_qwk_epoch_loss_val, HIDDEN, ONE_HOT, DATA_AUG, training_dir):
    """Summarise a training run and prune checkpoints that are not "best".

    Prints the best validation accuracy / loss / QWK together with the
    (1-based) epoch at which each occurred, writes the same summary plus the
    hyper-parameters to ``<training_dir>/HYP.txt``, and deletes every
    ``checkpoint_<n>.*`` file in ``training_dir`` whose ``<n>`` is not one of
    those three best epochs.

    Args:
        store_epoch_acc_val: per-epoch validation accuracies (index 0 = epoch 1).
        store_epoch_loss_val: per-epoch validation losses.
        store_qwk_epoch_loss_val: per-epoch validation QWK scores.
        HIDDEN: hidden-layer configuration, written verbatim to HYP.txt.
        ONE_HOT: one-hot flag, written verbatim to HYP.txt.
        DATA_AUG: data-augmentation flag, written verbatim to HYP.txt.
        training_dir: directory holding the checkpoints of this run.

    Returns:
        None. If any of the metric lists is empty (e.g. the run was interrupted
        before the first validation pass finished) nothing is written or
        deleted.
    """
    if not (store_epoch_acc_val and store_epoch_loss_val and store_qwk_epoch_loss_val):
        # max()/min() would raise ValueError on an empty list; there is also
        # no "best" epoch to keep, so leave the directory untouched.
        print("\nNo completed validation epoch; nothing to summarise.")
        return

    most_acc = max(store_epoch_acc_val)
    min_loss = min(store_epoch_loss_val)
    qwk_max_loss = max(store_qwk_epoch_loss_val)

    # Epoch numbers are 1-based, list positions are 0-based.
    best_acc_epoch = store_epoch_acc_val.index(most_acc) + 1
    best_loss_epoch = store_epoch_loss_val.index(min_loss) + 1
    best_qwk_epoch = store_qwk_epoch_loss_val.index(qwk_max_loss) + 1

    print("\nHighest accuracy of {} occured at {}...\nMinimum loss occured at {}... \nMaximum QWK metric of {} occured at {}".format(
        most_acc, best_acc_epoch,
        best_loss_epoch,
        qwk_max_loss, best_qwk_epoch))

    with open(os.path.join(training_dir, "HYP.txt"), "w") as f:
        f.write("EPOCH = {} \n".format(hyp.EPOCHS))
        f.write("LR = {} \n".format(hyp.LR))
        f.write("HIDDEN_LAYERS = {} \n".format(HIDDEN))
        f.write("ONE_HOT = {} \n".format(ONE_HOT))
        f.write("DATA_AUG = {} \n".format(DATA_AUG))
        f.write("Highest accuracy of {} occured at {}...\nMinimum loss of {} occured at {}... \nMaximum QWK metric of {} occured at {}".format(
            most_acc, best_acc_epoch,
            min_loss, best_loss_epoch,
            qwk_max_loss, best_qwk_epoch))

    # Keep only the checkpoints belonging to one of the three best epochs.
    epochs_to_keep = {best_qwk_epoch, best_loss_epoch, best_acc_epoch}
    for entry in os.listdir(training_dir):
        if "checkpoint" not in entry:
            continue
        entry_path = os.path.join(training_dir, entry)
        # Expected shape is "checkpoint_<n>.<ext>"; skip anything else rather
        # than crashing on an unexpected name or trying to remove a directory.
        number = os.path.splitext(entry)[0].rpartition("_")[2]
        if not number.isdigit() or not os.path.isfile(entry_path):
            continue
        if int(number) not in epochs_to_keep:
            os.remove(entry_path)

def _mean_or_zero(values):
    """Return the mean of ``values`` as a 0-dim float tensor, or 0 if empty."""
    return torch.FloatTensor(values).mean() if values else 0


def train(model, HIDDEN, ONE_HOT, DATA_AUG, data_train_loader, data_val_loader):
    """Train ``model`` for ``hyp.EPOCHS`` epochs, validating after each one.

    For every epoch the model is optimised with Adam on cross-entropy loss,
    a checkpoint is saved to ``<training_dir>/checkpoint_<epoch>.pth``, the
    validation loss / accuracy / quadratic-weighted kappa are computed and the
    validation curves are re-plotted to ``<training_dir>/Loss.png``.  When the
    loop finishes, or is stopped with Ctrl-C, ``store_result`` is called to
    summarise the run.

    Args:
        model: network to train; inputs are moved to CUDA before the forward
            pass, so the model is expected to live on the GPU.
        HIDDEN: hidden-layer sizes; only used for the directory name and the
            summary.
        ONE_HOT: one-hot flag; only used for the directory name and summary.
        DATA_AUG: augmentation flag; only used for the directory name and
            summary.
        data_train_loader: iterable yielding ``(X, y)`` training batches.
        data_val_loader: iterable yielding ``(X, y)`` validation batches.

    Returns:
        None.
    """
    print("Training...")
    training_dir = './training_{}+{}_{}_{}_{}'.format(ONE_HOT, DATA_AUG, len(HIDDEN), max(HIDDEN), time.time())
    os.mkdir(training_dir)
    os.mkdir(training_dir+'/misclassified')
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=hyp.LR, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

    qwk_loss = cohen_kappa_score
    ce_loss = nn.CrossEntropyLoss().cuda()
    store_epoch_loss = []
    store_qwk_epoch_loss = []
    store_epoch_loss_val = []
    store_qwk_epoch_loss_val = []
    store_epoch_acc_val = []
    try:
        for e in tqdm(range(hyp.EPOCHS)):
            epoch = e + 1

            # ---- training pass -------------------------------------------
            store_batch_loss = []
            store_qwk_batch_loss = []
            for X, y in data_train_loader:
                optimizer.zero_grad()
                prediction = model(X.cuda())
                # The loss needs prediction and target on the same device.
                y = y.to(prediction.device)
                batch_loss = ce_loss(prediction, y)
                batch_loss.backward()
                qwk_batch_loss = qwk_loss(y.detach().cpu().numpy(),
                                          np.argmax(prediction.detach().cpu().numpy(), axis=1),
                                          weights="quadratic")
                optimizer.step()
                # Store plain floats so no tensor / autograd reference is kept.
                store_batch_loss.append(batch_loss.item())
                store_qwk_batch_loss.append(qwk_batch_loss)

            # Epoch value = mean of the per-batch values, computed once.
            store_epoch_loss.append(_mean_or_zero(store_batch_loss))
            store_qwk_epoch_loss.append(_mean_or_zero(store_qwk_batch_loss))
            torch.save(model.state_dict(), "{}/checkpoint_{}.pth".format(training_dir, epoch))

            # ---- validation pass -----------------------------------------
            model.eval()
            store_batch_loss_val = []
            store_qwk_batch_loss_val = []
            store_batch_err_val = []
            misclassified_images = []
            with torch.no_grad():
                for X, y in data_val_loader:
                    prediction = model(X.cuda())
                    y = y.to(prediction.device)
                    batch_loss = ce_loss(prediction, y)
                    qwk_batch_loss = qwk_loss(y.cpu().numpy(),
                                              np.argmax(prediction.cpu().numpy(), axis=1),
                                              weights="quadratic")
                    # True where the predicted class differs from the target.
                    misclassified = prediction.max(-1)[-1].squeeze().cpu() != y.cpu()
                    misclassified_images.append(X[misclassified==1].cpu())
                    store_batch_loss_val.append(batch_loss.item())
                    store_qwk_batch_loss_val.append(qwk_batch_loss)
                    # Mean of the mask is the batch error rate, not accuracy.
                    store_batch_err_val.append(misclassified.float().mean().item())

            store_epoch_loss_val.append(_mean_or_zero(store_batch_loss_val))
            store_qwk_epoch_loss_val.append(_mean_or_zero(store_qwk_batch_loss_val))
            store_epoch_acc_val.append(1-_mean_or_zero(store_batch_err_val))

            # The first epoch is left out of the curves ([1:]).
            plt.plot(store_epoch_loss_val[1:], label="Validation Loss")
            plt.plot(store_qwk_epoch_loss_val[1:], label="Validation Metric(QWK)")
            plt.plot(store_epoch_acc_val[1:], label="Validation Accuracy")
            plt.legend()
            plt.grid()
            plt.savefig("{}/Loss.png".format(training_dir))
            plt.close()

            if len(misclassified_images) > 0:
                # NOTE(review): the per-epoch directory is created and the
                # samples are gathered, but nothing is written into it yet.
                misclassified_images = np.concatenate(misclassified_images,axis=0)
                validation_dir = training_dir+'/misclassified/checkpoint_{}'.format(epoch)
                os.mkdir(validation_dir)
            model.train()
    except KeyboardInterrupt:
        # A manual stop is treated like normal completion: summarise below.
        pass

    if store_epoch_acc_val:
        store_result(store_epoch_acc_val, store_epoch_loss_val, store_qwk_epoch_loss_val, HIDDEN, ONE_HOT, DATA_AUG, training_dir)
    else:
        # No validation pass finished, so there is no best epoch to report.
        print("\nNo completed validation epoch; skipping summary.")

if __name__ == "__main__":
    # Train one model per (ONE_HOT, DATA_AUG) combination, always with the
    # hidden-layer configuration stored at index 2 of hyp.HIDDEN_LIST.
    for ONE_HOT in (0, 1):  # for MaturitySize, FurLength, Health
        for DATA_AUG in (0, 1):  # for state data
            data_train_loader, data_val_loader = prep.preprocess_data(ONE_HOT, DATA_AUG)
            hidden_layers = hyp.HIDDEN_LIST[2]
            # The model is built inline so no module-level name keeps it
            # alive once train() has returned.
            train(
                models.Model(hidden_layers, ONE_HOT, DATA_AUG).cuda(),
                hidden_layers,
                ONE_HOT,
                DATA_AUG,
                data_train_loader,
                data_val_loader,
            )

  0%|          | 0/1000 [00:00<?, ?it/s]

Training...


  0%|          | 2/1000 [00:05<49:19,  2.97s/it]



Highest accuracy of 0.3899478316307068 occured at 2...
Minimum loss occured at 2... 
Maximum QWK metric of 0.33045652508735657 occured at 2


  0%|          | 0/1000 [00:00<?, ?it/s]

Training...


  0%|          | 2/1000 [00:05<47:47,  2.87s/it]



Highest accuracy of 0.2976999282836914 occured at 1...
Minimum loss occured at 1... 
Maximum QWK metric of 0.03216612711548805 occured at 2


  0%|          | 0/1000 [00:00<?, ?it/s]

Training...


  0%|          | 2/1000 [00:05<45:49,  2.76s/it]



Highest accuracy of 0.39671987295150757 occured at 2...
Minimum loss occured at 2... 
Maximum QWK metric of 0.33799639344215393 occured at 2


  0%|          | 0/1000 [00:00<?, ?it/s]

Training...


  0%|          | 2/1000 [00:05<45:08,  2.71s/it]


Highest accuracy of 0.2771621346473694 occured at 1...
Minimum loss occured at 2... 
Maximum QWK metric of 0.01687517948448658 occured at 1





In [17]:
import pandas as pd
import json
import os

# Directory scanned for sentiment JSON files, one file per entry.
# Named `sentiment_dir` rather than `dir` so the builtin dir() stays usable.
sentiment_dir = "data/train_sentiment"

# Print the document-level sentiment magnitude stored in each file.
for filename in os.listdir(sentiment_dir):
    # Text before the first "." of the file name (the whole name if it has no
    # dot) -- presumably the pet ID; currently not used further in this cell.
    petid = filename.partition(".")[0]
    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(os.path.join(sentiment_dir, filename), encoding="utf-8") as f:
        # Bound to `sentiment`, not `data`, so the `torch.utils.data as data`
        # alias imported earlier in the notebook is not overwritten.
        sentiment = json.load(f)
    print(sentiment["documentSentiment"]["magnitude"])

0.5
2.8
2
2.8
2.3
0.8
1.1
0.8
1
3.6
2.1
0.7
1.3
2.9
4.5
7.6
1.1
1.3
1.1
1.7
0.9
0.1
0.8
2.2
3.7
0.5
1.6
3.1
0.5
1.7
1.8
1.2
1.2
3
0.1
1.8
0.7
3.9
3.8
0.4
1
3.6
1.9
2.7
0
2.8
1.5
0.8
3
0.8
0.1
3.8
4.8
0.9
1.9
1.7
2.5
0.9
0.6
2.1
1
0.7
0.8
4.7
2.4
0.4
0.9
1.9
2.1
6.2
0.6
0
0
2.5
2.1
0.9
2.4
2.6
0.9
3.3
1.1
2.6
4.5
2.6
0.1
0.1
2.8
3.5
5.1
7.8
0.1
2.7
5.4
2.9
2.9
1.1
2.1
2.9
0.8
0.9
0
0.5
1
1.7
0.1
2.5
5.2
1.7
4.9
3.7
1.9
1.5
0.5
0.7
2
3.7
0.8
1.4
0
0.7
0.2
0.2
6.5
3.5
0.4
0.7
1.3
3.9
1.9
3.9
3.2
2.3
0.4
6.5
0.1
1.7
4.1
1.2
2.6
3
0.9
0.7
0.4
1.1
0.6
0.6
3.3
0.7
0.8
3.2
3.3
2.4
4.2
2.3
2.8
1.7
1
1.5
2.2
0.9
0
0.9
3.2
0.2
1.3
1
2.4
0.1
3.8
5.2
0.8
0.5
3.7
0.8
1.1
0.2
3.3
2.7
0
0.4
0
2
0.1
3.9
3.2
0.3
0.3
2.4
2.5
3.4
1.3
1.9
0.3
4.5
2
1.1
1.7
8.1
2.4
0.1
0.2
1.8
0.2
1.9
1.1
3.9
5.4
1.6
8
0.3
4.3
1.9
0.5
1.7
0.8
1.2
1.4
3.7
2.4
10
2.2
3.7
0.6
0
1.2
0.2
0
2.8
8.8
0.4
4.8
2.7
3.1
0
0.9
1.8
1.9
0.8
2.4
2.3
0.3
2.7
1
2.5
3.2
1.1
0.8
3.2
0.2
2.8
3.5
4.9
1.4
2.9
1.7
1
1.9
0.2
0.8
2.9
0
0.5
0.6
4.4
1

4.6
3.2
4.3
1.9
4.9
0.8
0.4
2.9
1
2
2.3
1.9
1.1
2
1.8
1.4
7.9
1
0.2
0.6
2.1
0.9
1
0.5
4.3
2.7
1
3.3
1.7
1.9
0.4
5.4
1.8
0
0.8
2.3
3.9
1.8
0.4
2
1.6
1.9
1
2.1
2.4
0.3
3.3
13.9
2.3
0.4
4.6
0.2
5.4
1.9
3
1.7
0.8
3.2
2.8
4.4
0.6
5.4
0
3.2
4.6
3.4
1.7
0.3
0.9
1.6
1.7
2.4
0.2
0
1.7
6.2
0.1
1.9
3.1
0.5
1.9
1.7
0
4.5
0.7
0.8
1.4
3.5
0.3
3.9
1.2
2.9
2.2
0.8
4.7
0.9
3
1.3
0.9
2
0.8
1.8
0.3
1.3
3.3
2.3
19
6.2
1.9
3.5
3.6
2.2
0.9
0.4
3.9
2.1
2.4
0.5
1.1
1
2.7
2.4
0.2
2.5
7.7
1.7
1.8
1.6
1.7
0.1
4.2
1.3
1.9
0.2
5.1
0
3.4
1.4
1.4
2.5
3
3.4
3.3
1.4
1.2
1.2
1.3
2.5
3.9
3.4
2.6
0.2
0
0.8
1
1.9
0.7
0.7
0
1.9
0.9
1.2
1.4
1.4
0.1
1.8
2.8
0.9
0.8
2.4
3.4
0
2
1.2
1
4.3
0.8
2.9
1.4
1.1
0
0.8
2.8
2
1
2.6
1.9
0.8
1.5
4.4
0.9
0.5
6.3
2.4
0.6
5.1
1.8
1.8
0.9
3.2
0.2
2.6
2.9
3.6
2.7
0.3
0.7
3.1
0
0
4.7
6.9
2.5
1.9
0.3
0.5
1
1
1.2
3.5
0.5
0.7
3.2
0.4
0.6
1.5
0
0.6
3.6
0.5
1.8
3.7
5.7
1.3
1.7
5
0.8
12.5
4.7
1.6
0.9
3.3
1.4
7.6
1.3
0.9
2.3
1.9
0.5
4.2
6.5
1.9
1.6
2.7
0.9
1.7
6.4
1.2
0
1.5
1.7
2.6
2
1.6
1
0.2
3
3.9
3

0
2
0.9
0.2
0.3
2.4
3.9
1.4
2
1.9
0.2
1.1
1.7
1.7
1.1
2.3
2.6
1.3
2
4.3
1.5
0.1
2.4
6.2
0.7
1.2
0.8
0.9
9.2
1.7
0.3
1.5
3.2
0.9
3.8
2.8
5.8
1.8
0.8
2.4
0.9
1.7
1
0.9
0.9
4.2
1.1
0.4
3.2
1.1
0.9
0.9
1.5
1.8
3.7
1.9
1.8
2.4
1.4
1
4.2
0.1
2.4
2.2
0.5
1.2
1.4
0
2
1
2
5.9
0.9
0.9
1.7
4.2
3.6
0
4.8
0.9
1.7
2.9
0.3
1.9
3.1
1.4
4.8
2
1
3.1
0.4
0.9
0.4
3
0
2.1
1.2
1.8
2.7
5.3
1.5
2.6
0.8
0.6
1.2
3.4
3.5
2.3
0.9
2
1.3
4
5.6
0.2
16.7
0.8
1
0.3
1.4
0
2.2
1.1
0.6
8.5
1.4
0.6
0.4
5.1
0.2
1.5
0
5.3
0.6
1.6
1.4
1.2
0.1
1.8
6.6
0.2
0.3
1.8
4.5
2
2.6
1.5
1.9
2.7
2
0.8
1.4
0.6
1.6
14.7
0.7
1.7
1.6
5.6
2.8
0.9
1.2
1.2
3.6
1.3
1.5
2.7
1.7
2.7
1.6
0
4.1
5.9
1
1.7
0.4
1.8
4.3
0.8
0.5
2.1
3.8
3.1
1.3
4.6
1.8
1.2
0
1.3
9.2
3.3
3.5
1.5
1.6
0.5
2.8
1.5
4.2
1.1
0.8
0.4
2.4
1.2
1.4
2.6
6
1.3
0.8
0.5
5.9
0
0.8
1.9
2.4
3.8
2
3.1
0.9
3
1.3
2.6
0.7
1.6
0.8
1.3
3.1
1.2
3.3
2.3
0.9
4.2
0.3
1.4
1.7
1.9
1
1.2
0.9
1.3
2.1
0
4
2.5
3
1.3
1.9
2.6
1.3
2.8
3.3
2
6.6
1.1
1.1
5.2
3.5
3.3
0.4
2.5
0.1
1.2
1.8
2.6
0
0.8
0.6
0
1.4
0.

0.9
2.5
1.8
1.4
2.7
1.5
1.1
3.5
0.9
1.2
0.9
3.2
0.9
2.3
1.1
0.5
0
2.4
3.7
0
1.3
4.7
3.7
1.8
1
1
3.5
0
2.7
0.2
3.2
3.1
1.8
1.3
0.2
0.4
1.7
1
5.4
0.3
6
2.3
2.4
2
1.9
0.3
1.7
1.7
2.1
1.2
0.1
1.9
0
1.2
4
4.2
1.6
5.5
0.2
0
3.4
2.3
0.6
2.3
0.1
1.4
1.9
1.3
0.6
0
2.6
0.3
7.3
0.9
0.8
0.1
0.9
1.5
1.7
3.2
0.3
3.8
0.1
9.7
3.9
1.6
8.3
3.2
0
0
1.8
1.9
6
3.9
3.2
3.6
1.2
3.8
1.1
2.9
0.2
0.9
3.5
2.7
1.1
3.3
0
1.8
1
1.6
1.2
0.9
1.8
0.6
2.1
2.5
3.8
0.6
0.9
2.6
1.5
2.8
2
2
0.6
4.7
1.7
1.4
0
1.5
3.3
1
0.2
5.9
1.5
3.6
3.4
2
3.8
0.8
2.5
2.1
2.4
2.9
1
5.3
0.9
1.8
0.3
0.4
1
0.8
0
0.8
1.4
2.3
1.7
0.9
1
6.9
0.9
4.2
2.4
2.5
0.5
1.3
0.9
2.7
0.9
0.3
0
0.8
0.8
4.2
3.2
1
1
5
1
1.3
0.7
0.8
8.8
0.5
3
2.1
1
1.3
1.7
2.9
5.8
1.9
1.8
4.4
3.9
0.4
3.2
5
0.2
0.4
0
0.8
3.4
2.9
2.6
5.7
3.8
1.3
2.6
1.1
1.2
1.5
1.9
0.9
2.1
2.3
2.9
2.6
0.4
2.7
4.3
0.8
1.2
3.2
1.3
0.5
0.7
7.2
2
0.9
0.5
3.6
4.1
0.8
2.5
3
0.2
2.9
0
4.4
2.3
1
1.5
2.6
1.9
0.9
1.4
3.5
0.9
2.5
4.4
6.9
0.2
3.8
0.2
1.7
1
4.6
0.9
4.1
1.8
1.8
2.7
3.6
1.5
0.9
1.9
0.6
0.3
5.4


0.7
3
0.1
1.4
4.9
1.6
3.1
1.8
0.4
1.4
1.2
5.6
0.5
2.3
0
1.2
1.5
3.5
0
1.2
1.5
5.1
1.9
0.4
0.6
4.7
0.3
0.8
0
0.9
1.1
2.6
2.8
2.1
1.4
1
2
1.6
1.7
6.7
2.3
4
0.8
2.5
4
2.5
0.2
1.7
0.5
1.1
2.5
3.9
0.4
1.4
0
1
1.7
7.8
0.1
3
11.7
1.2
2.8
0.6
0.6
1.4
2.1
5
0.3
3.4
0
3.1
5
3
0.3
0.9
1.8
0.5
2.6
1.2
2.6
0.8
1.3
1.5
2.1
3.2
1.2
1.4
0.5
0.5
1.1
0.5
0.4
1.3
0.7
1.4
3.2
2.7
3.7
4.2
0.3
0
1.9
1.5
1.3
2.3
0.3
2.4
0.8
0.5
0.2
2
5.6
3.1
1.3
0.2
0
0.4
0
1.6
1.5
1.7
1.2
4.8
1.3
0.8
0.9
1.7
1.7
1.3
3.3
2.1
1.9
0.5
0.8
1.4
1.7
2
6.4
2.5
0.9
3.7
0
0.4
2.8
4
2.2
0.1
3.3
3
0.8
0.9
0.1
5
1.1
0.5
0.5
1.3
0.4
0.9
0.1
3.7
0.5
1.2
0.7
2.3
3.6
1.2
1.2
1.6
0.8
0.2
0.4
3.4
1.7
2.1
1.1
1.9
0.8
0
1.5
8
4.1
1.3
0.2
2.1
2.4
1
2
0.1
1
2.7
2.8
3.7
7.7
2.5
0.9
1.2
0.8
4.5
3.3
2.4
0.2
3.4
3.3
5.4
0.8
3.8
1.2
1.9
2.5
3.4
2.3
1
2.1
1.2
1.5
2.2
1.8
3.1
1.4
7.7
1.1
0.8
0.8
2.9
2
1.2
1.2
0.9
0
6.5
0.2
1.1
1.3
0.8
1.3
3.3
1.2
0.1
1
0.4
2.5
1.1
0.5
0
0.5
1.2
8.9
1.6
7.7
3
1.1
0.9
0.2
2.4
0.9
0.7
1.8
3.2
7.8
3.6
3.4
0.3
1.7
3.6
9.1
1

1.4
3.5
2.5
2.8
0
1
1
1.8
3.7
1.2
0.9
2.2
4.3
0.7
1.2
0.4
0.9
3.7
2
1.4
16.1
5.9
3.1
1.6
1.8
3.2
3
1.1
3
0.5
1.3
3.3
1.4
3
0.3
4.6
1.4
8.5
2.6
4.9
2.3
0.6
0.1
0.1
1.6
1.4
1.8
0.9
1.6
2.9
2.3
3.2
5.1
2.3
1.2
0.7
0.7
1.7
4.1
2.1
1.3
0
0
8.4
3.8
16.7
1.2
1.2
1.1
0.9
0.1
3.2
0.5
0
0.1
4.8
2.7
1.8
1
2
0.5
1.8
0.2
3.8
0.2
1.8
1.5
0.9
4
2.5
2.2
0.1
4.4
0.3
3
2.9
0
1.1
0.5
2.9
1.1
3
4.1
0.8
0
4.1
2.6
0.4
3.1
2.7
0.4
0.6
0.8
3.9
1.7
3
0.8
1.5
2.5
3
3.3
7.9
5.2
1.8
1.3
0
1.7
3.7
2.9
13.2
0.4
1.4
3.7
3.2
1.1
0
6.6
1.2
0.1
5.1
2.7
1.7
1.2
0
2.8
0
3.3
0.3
2.7
0.4
3.4
0.9
3.4
0.8
2
1.5
0.7
2.8
0.7
3
2
9
0.1
2.4
1.7
1.1
3.8
1.6
0.1
2.5
0.9
1.8
5.9
1.7
0.1
2.6
2.8
0.5
1.9
0.9
0.7
0.1
3.2
3.2
0.1
2.7
2
6.5
2.2
0
3.1
7.8
1.6
2.1
1.6
0.5
8.1
0.7
1
1.9
2.1
3.7
0.1
1.4
0.2
0.6
0
1.4
3.2
2.3
2.8
0.6
3.2
2
0.2
0.1
2.2
0.6
2.1
0.4
0.9
0.7
0
1.7
2.4
0.6
2.3
1.3
0.4
2.1
1.3
2.8
2.4
1.6
1.5
1.5
1.1
6.8
2.2
1.6
1.8
2
2.6
1.6
2
1.1
3.4
3.8
0.9
0.1
0.6
3.1
0.8
4.3
0.6
3.3
1.5
0.3
2.8
1.7
9.8
2.6
1
2.3
2.5
1.2
5.8
4