# Classifying Confusion with RNN

This notebook holds our confusion classification pipeline for various RNN architectures. Experiments occur with no data augmentation other than SMOTE oversampling of each training fold.


In [1]:
import torch
from torch import nn
from torch import optim
from torch import utils
from torch import autograd
from torchvision import transforms, datasets

import pickle
import os
import random
import numpy as np
import shutil
import matplotlib.pyplot as plt
import pandas as pd

import train
import utils
from networks import ConfusionLSTM
from networks import ConfusionGRU
from imblearn.over_sampling import SMOTE 

# --- Reproducibility --------------------------------------------------------
# One seed shared by every random number generator used in this notebook.
MANUAL_SEED = 1

random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)

# Turn off cuDNN benchmark mode and request deterministic cuDNN behaviour.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Hardware ---------------------------------------------------------------
# GPU index 2 is hard-coded; the CPU is used when CUDA is unavailable.
DEVICE = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")

# --- Data and model dimensions ----------------------------------------------
# Each data item is reshaped to (MAX_SEQUENCE_LENGTH, INPUT_SIZE) by the
# cross-validation code in this notebook.
MAX_SEQUENCE_LENGTH, INPUT_SIZE = 150, 14
HIDDEN_SIZE, OUTPUT_SIZE = 256, 2

BATCH_SIZE = 256

# --- Propagate settings to the helper modules -------------------------------
# `train` and `utils` are configured by assigning module-level attributes.
train.MAX_SEQUENCE_LENGTH = utils.MAX_SEQUENCE_LENGTH = MAX_SEQUENCE_LENGTH
train.INPUT_SIZE = utils.INPUT_SIZE = INPUT_SIZE
train.BATCH_SIZE = BATCH_SIZE

## Cross Validation

### The following cell implements 10-fold CV.

In [2]:
def cross_validate(model_type, 
                    folds,
                    epochs,
                    criterion_type,
                    optimizer_type,
                    confused_path,
                    not_confused_path,
                    print_every,
                    plot_every,
                    hidden_size,
                    num_layers,
                    down_sample_training=True,
                    early_stopping_metric='val_loss',
                    early_stopping_patience=10,
                    max_rate_decreases=300,
                    rate_decay_patience=5,
                    initial_learning_rate=0.001,
                    k_neighbors=11,
                    verbose=False):
    """
        Perform k-fold cross validation of an RNN classifier, oversampling the
        training portion of every fold with SMOTE.

        For each fold the function: splits the fold's training items into train and
        validation sets, optionally down-samples not_confused, applies SMOTE to the
        flattened training items, writes the resampled items plus copies of the
        validation/test items into temporary folders under '../dataset/augmented/',
        trains a fresh model with train.train, plots the training curves and
        evaluates the returned model on the fold's test items.

        Args:
            model_type (string): the type of RNN to use. Must be 'lstm', 'gru', or 'rnn'
            folds (int): the number of folds to run. Fold k uses entry k of the split lists
                stored in '../grouped_10_fold_split_list_3.pickle'
            epochs (int): the max number of epochs to train the model for each fold
            criterion_type (string): the name loss function to use for training. Currently must be 'NLLLoss'
            optimizer_type (string): the name of learning algorithm to use for training. ex 'Adam'
            confused_path (string): the path to the folder containing the confused data samples
            not_confused_path (string): the path to the folder containing the not_confused data samples
            print_every (int): the number of batches to train for before printing relevant stats
            plot_every (int): the number of batches to train for before recording relevant stats, which
                will be plotted after each fold
            hidden_size (int): the number of hidden units for each layer of the RNN
            num_layers (int): the number of hidden_unit sized layers of the RNN
            down_sample_training (boolean): if True, not_confused training items are randomly down sampled
                to at most 3 times the number of confused training items before SMOTE is applied
            early_stopping_metric (string): name of the metric forwarded to train.train for early stopping,
                ex 'val_loss' or 'val_auc'
            early_stopping_patience (int): number of epochs without improvement before stopping training
            max_rate_decreases (int): forwarded to train.train; presumably the maximum number of
                learning rate reductions allowed -- confirm against train.train
            rate_decay_patience (int): number of epochs without imporovement that can pass before
                reducing the learning rate (forwarded to train.train)
            initial_learning_rate (float): the first learning rate to be used by the optimizer
            k_neighbors (int): the k_neighbors argument given to SMOTE
            verbose (boolean): if True, function will print additional stats

        Returns: (list,list,list,list)
            cv_test_sensis (list): list containing the test sensitivity (recall) for each fold
            cv_test_specifs (list): list containing the test specificity for each fold
            cv_test_combinds (list): list containing (sensitivity + specificity) / 2 for each fold
            cv_test_aucs (list): list containing the test AUC for each fold

        Raises:
            ValueError: if model_type is not recognised, or folds is smaller than 1 or larger
                than the number of folds available in the split file
            FileNotFoundError: if confused_path or not_confused_path is not an existing directory
        
    """
    # Fail fast on bad arguments instead of part-way through a long run.
    if model_type not in ('lstm', 'gru', 'rnn'):
        raise ValueError("model_type must be 'lstm', 'gru', or 'rnn', got %r" % (model_type,))
    if folds < 1:
        raise ValueError("folds must be at least 1, got %r" % (folds,))
    for data_dir in (confused_path, not_confused_path):
        if not os.path.isdir(data_dir):
            raise FileNotFoundError("data directory not found: %s" % (data_dir,))

    # Resolve the model class once; every fold builds a fresh instance of it.
    if model_type == 'lstm':
        model_class = ConfusionLSTM
    elif model_type == 'gru':
        model_class = ConfusionGRU
    else:
        # ConfusionRNN is not imported at the top of this notebook, so import it here.
        # NOTE(review): assumes networks defines ConfusionRNN -- confirm.
        from networks import ConfusionRNN as model_class

    #ensure same items appear in folds, for reproducibility:
    # NOTE(review): pickle.load can execute arbitrary code; only use a trusted split file.
    with open('../grouped_10_fold_split_list_3.pickle', 'rb') as infile:
        split = pickle.load(infile)
    
    train_confused_splits = split[0]
    test_confused_splits = split[1]
    train_not_confused_splits = split[2]
    test_not_confused_splits = split[3]

    available_folds = min(len(train_confused_splits), len(test_confused_splits),
                          len(train_not_confused_splits), len(test_not_confused_splits))
    if folds > available_folds:
        raise ValueError("folds=%r exceeds the %d folds available in the split file"
                         % (folds, available_folds))

    # Temporary dataset folders, rebuilt for every fold. Each root holds one
    # sub-folder per class so that DatasetFolder can infer the labels.
    train_root = '../dataset/augmented/train_smote_200/'
    val_root = '../dataset/augmented/val_smote_200/'
    test_root = '../dataset/augmented/test_smote_200/'

    local_train_confused_path = train_root + 'confused/'
    local_val_confused_path = val_root + 'confused/'
    local_test_confused_path = test_root + 'confused/'
    local_train_not_confused_path = train_root + 'not_confused/'
    local_val_not_confused_path = val_root + 'not_confused/'
    local_test_not_confused_path = test_root + 'not_confused/'

    temp_dirs = (local_train_confused_path, local_val_confused_path, local_test_confused_path,
                 local_train_not_confused_path, local_val_not_confused_path,
                 local_test_not_confused_path)

    flat_size = MAX_SEQUENCE_LENGTH*INPUT_SIZE
    
    cv_val_accs = []
    cv_val_sensis = []
    cv_val_specifs = []
    cv_val_aucs = []
    
    cv_test_accs = []
    cv_test_sensis = []
    cv_test_specifs = []
    cv_test_aucs = []
    cv_test_combinds = []
    for k in range(folds):
        print("\nFold ", k+1)
        
        # Get data item file names for this fold and downsample not_confused to balance training set
        train_confused, \
        train_not_confused, \
        val_confused, \
        val_not_confused = \
        utils.get_train_val_split(train_confused_splits[k], train_not_confused_splits[k], percent_val_set=0.2)
        
        if down_sample_training:
            # Cap the sample size so a fold with few not_confused items does not raise.
            train_not_confused = random.sample(train_not_confused,
                                               k=min(len(train_not_confused), len(train_confused)*3))
        
        test_confused = test_confused_splits[k]
        test_not_confused = test_not_confused_splits[k]

        print("Number of confused training items before SMOTE: ", len(train_confused))
        print("Number of not_confused training items before SMOTE: ", len(train_not_confused))

        # Flatten every training item to one row; label 0 = confused, 1 = not_confused.
        num_confused = len(train_confused)
        num_train = num_confused + len(train_not_confused)
        y = np.zeros(shape=(num_train,))
        X = np.zeros(shape=(num_train, flat_size))
        for i, item in enumerate(train_confused):
            x = utils.pickle_loader2(os.path.join(confused_path, item))
            X[i] = np.reshape(x, flat_size)

        for i, item in enumerate(train_not_confused):
            x = utils.pickle_loader2(os.path.join(not_confused_path, item))
            X[i+num_confused] = np.reshape(x, flat_size)
            y[i+num_confused] = 1
            
        #sm = SMOTE(random_state=MANUAL_SEED, sampling_strategy=.08165)
        sm = SMOTE(random_state=MANUAL_SEED, k_neighbors=k_neighbors)
        # fit_resample replaces the deprecated fit_sample alias.
        X_res, y_res = sm.fit_resample(X,y)

        print("confused items in training set: ", len(y_res[y_res==0]))
        print("not_confused items in training set: ", len(y_res[y_res==1]))
        print("confused items in validation set: ", len(val_confused))
        print("not_confused items in validation set: ", len(val_not_confused))
        
        if verbose:
            print("\nTest confused items:\n")
            print(test_confused)

        # Remove any old directories and make new temp directories
        for temp_dir in temp_dirs:
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
            os.makedirs(temp_dir)

        # Write the resampled training items back out in their original 2-D shape.
        for i in range(len(X_res)):
            item = pd.DataFrame(np.reshape(X_res[i], (MAX_SEQUENCE_LENGTH,INPUT_SIZE)))
            if y_res[i] == 0:
                item.to_pickle(local_train_confused_path+'smote_confused_item_'+str(i)+'.pkl')
            else:
                item.to_pickle(local_train_not_confused_path+'smote_not_confused_item_'+str(i)+'.pkl')
        
        # Validation and test items are copied unchanged (no SMOTE).
        for i in val_confused:
            shutil.copy(src=os.path.join(confused_path, i),dst=local_val_confused_path+i)
        
        for i in test_confused:
            shutil.copy(src=os.path.join(confused_path, i),dst=local_test_confused_path+i)
            
        for i in val_not_confused:
            shutil.copy(src=os.path.join(not_confused_path, i),dst=local_val_not_confused_path+i)
        
        for i in test_not_confused:
            shutil.copy(src=os.path.join(not_confused_path, i),dst=local_test_not_confused_path+i)

        # Prepare training and validation data
        training_data = datasets.DatasetFolder(train_root, 
                                               loader=utils.pickle_loader2, 
                                               extensions='.pkl')

        validation_data = datasets.DatasetFolder(val_root, 
                                                 loader=utils.pickle_loader2, 
                                                 extensions='.pkl')
        
        test_data = datasets.DatasetFolder(test_root, 
                                                 loader=utils.pickle_loader2, 
                                                 extensions='.pkl')
        
        # DEVICE.type is 'cuda' or 'cpu'; it never contains the device index.
        test_data_loader = torch.utils.data.DataLoader(test_data, 
                                                  batch_size=BATCH_SIZE,
                                                  shuffle=False,
                                                  num_workers=10 if DEVICE.type == 'cuda' else 5,
                                                  pin_memory=True, drop_last=False)
        print("Training data: ", training_data)
        print("Validation data: ", validation_data)
        print("Test data: ", test_data)
        
        # Re-seed torch so the model for every fold is built from the same seed.
        torch.manual_seed(MANUAL_SEED)

        model = model_class(input_size=INPUT_SIZE, hidden_size=hidden_size, 
                            output_size=OUTPUT_SIZE, batch_size=BATCH_SIZE, num_layers=num_layers)
        if verbose:
            print(model)

        model = model.to(DEVICE)
        
        #save fresh model to clear any old ones out
        torch.save(model.state_dict(), './best_rnn_smote_200'+'_fold_'+str(k) +'.pt')

        
        # Train and evaluate for the k'th fold
        # (train.train is always called with verbose=True, independent of `verbose`.)
        model, \
        (training_accs, \
        validation_accs, \
        training_losses, \
        training_aucs, \
        validation_losses, \
        validation_recalls, \
        validation_specifs, \
        validation_aucs, \
        val_thresh) = train.train(model=model, 
                                epochs=epochs, 
                                criterion_type=criterion_type, 
                                optimizer_type=optimizer_type, 
                                training_data=training_data, 
                                val_data=validation_data, 
                                print_every=print_every,
                                plot_every=plot_every,
                                early_stopping=True,
                                early_stopping_metric=early_stopping_metric,
                                early_stopping_patience=early_stopping_patience,
                                rate_decay_patience=rate_decay_patience,
                                max_rate_decreases=max_rate_decreases,
                                initial_learning_rate=initial_learning_rate,
                                model_name='best_rnn_smote_200_fold_'+str(k),
                                verbose=True,
                                return_thresh=True)
        
        utils.plot_metrics(training_accs, training_losses, training_aucs,
                     validation_accs, validation_losses, validation_recalls, 
                     validation_specifs, validation_aucs)
        
        # store the last recorded validation metrics for the current fold of CV
        cv_val_accs.append(validation_accs[-1])
        cv_val_sensis.append(validation_recalls[-1])
        cv_val_specifs.append(validation_specifs[-1])
        cv_val_aucs.append(validation_aucs[-1])
        
        test_accuracy, \
        test_recall, \
        test_specificity, \
        test_auc, \
        test_loss = utils.check_metrics(model, test_data_loader)
        combined = (test_recall + test_specificity ) / 2.0

        cv_test_accs.append(test_accuracy)
        cv_test_sensis.append(test_recall)
        cv_test_specifs.append(test_specificity)
        cv_test_aucs.append(test_auc)
        cv_test_combinds.append(combined)

    #clean up temp directories (train, validation and test)
    for temp_dir in temp_dirs:
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        
    if verbose:
        print("\n%d-fold CV accuracy: %f"% (folds, sum(cv_val_accs)/len(cv_val_accs)))
        print("%d-fold CV sensitivity: %f "% (folds, sum(cv_val_sensis)/len(cv_val_sensis)))
        print("%d-fold CV specificity: %f "% (folds, sum(cv_val_specifs)/len(cv_val_specifs)))
        print("%d-fold CV AUC: %f "% (folds, sum(cv_val_aucs)/len(cv_val_aucs)))
        print("\n%d-fold test accuracy: %f"% (folds, sum(cv_test_accs)/len(cv_test_accs)))
        print("%d-fold test sensitivity: %f "% (folds, sum(cv_test_sensis)/len(cv_test_sensis)))
        print("%d-fold test specificity: %f "% (folds, sum(cv_test_specifs)/len(cv_test_specifs)))
        print("%d-fold test combined: %f "% (folds, sum(cv_test_combinds)/len(cv_test_combinds)))

        print("%d-fold test AUC: %f "% (folds, sum(cv_test_aucs)/len(cv_test_aucs)))    
    return cv_test_sensis, cv_test_specifs, cv_test_combinds, cv_test_aucs

In [None]:
# Mean test metric of each repetition of 10-fold CV (one entry per repetition).
sens, spec, comb, auc = [], [], [], []

for i in range(10):
    # Use a different seed for every repetition. Note that cross_validate
    # re-seeds torch with MANUAL_SEED before building each fold's model and
    # passes MANUAL_SEED to SMOTE as its random_state.
    random.seed(MANUAL_SEED+i)
    np.random.seed(MANUAL_SEED+i)
    torch.manual_seed(MANUAL_SEED+i)

    sens_list, spec_list, comb_list, auc_list = cross_validate(
        model_type='gru',
        folds=10,
        epochs=100,
        criterion_type='NLLLoss',
        optimizer_type='Adam',
        confused_path='../dataset/augmented/confused_highly_valid_new/',
        not_confused_path='../dataset/augmented/not_confused_highly_valid_new/',
        print_every=1,
        plot_every=1,
        hidden_size=HIDDEN_SIZE,
        num_layers=1,
        early_stopping_metric='val_auc',
        early_stopping_patience=30,
        rate_decay_patience=10,
        initial_learning_rate=0.003,
        verbose=True,
    )

    # add mean of each measure for 10-fold CV to list
    sens += [np.mean(sens_list)]
    spec += [np.mean(spec_list)]
    comb += [np.mean(comb_list)]
    auc += [np.mean(auc_list)]

# Per-repetition means first ...
print("sensitivities: ", sens)
print("specificities: ", spec)
print("combined: ", comb)
print("auc: ", auc)

# ... then the grand mean over all repetitions.
print("average sensitivity: ", np.mean(sens))
print("average specificity: ", np.mean(spec))
print("average combined: ", np.mean(comb))
print("average auc: ", np.mean(auc))

In [None]:
# Output from the cell above had to be cleared (too long to save), so it is recorded here
# as comments so that this cell stays runnable:
#
# sensitivities:  [0.7731520562770562, 0.7601109307359308, 0.7341314935064934, 0.7123430735930735, 0.7725541125541125, 0.7586958874458875, 0.7608306277056277, 0.7702137445887447, 0.7300676406926406, 0.7728030303030302]
# specificities:  [0.784921678662956, 0.7906244110088444, 0.8175996158638646, 0.8202607975190039, 0.7508108806063986, 0.8089859522746448, 0.7913672840412677, 0.7983119960739928, 0.8116556030981371, 0.8340328734099112]
# combined:  [0.7790368674700061, 0.7753676708723876, 0.7758655546851791, 0.7663019355560388, 0.7616824965802556, 0.783840919860266, 0.7760989558734476, 0.7842628703313687, 0.7708616218953889, 0.8034179518564708]
# auc:  [0.8117319063151832, 0.7925897904826205, 0.8159519482062955, 0.8122335544460914, 0.795228013742219, 0.8223236337368904, 0.8124265814100273, 0.8216778763418311, 0.7996696354163962, 0.8207653776013885]
# average sensitivity:  0.7544902597402596
# average specificity:  0.800857109255902
# average combined:  0.7776736844980808
# average auc:  0.8104598317698943

In [3]:
# Per-repetition mean test sensitivities and specificities, copied from the
# saved output of the repeated cross-validation run.
sensitivities = [
    0.7731520562770562,
    0.7601109307359308,
    0.7341314935064934,
    0.7123430735930735,
    0.7725541125541125,
    0.7586958874458875,
    0.7608306277056277,
    0.7702137445887447,
    0.7300676406926406,
    0.7728030303030302,
]
specificities = [
    0.784921678662956,
    0.7906244110088444,
    0.8175996158638646,
    0.8202607975190039,
    0.7508108806063986,
    0.8089859522746448,
    0.7913672840412677,
    0.7983119960739928,
    0.8116556030981371,
    0.8340328734099112,
]
# One value per line.
print(*sensitivities, sep="\n")

0.7731520562770562
0.7601109307359308
0.7341314935064934
0.7123430735930735
0.7725541125541125
0.7586958874458875
0.7608306277056277
0.7702137445887447
0.7300676406926406
0.7728030303030302


In [4]:
# One specificity per line (the list is defined in the previous cell).
print(*specificities, sep="\n")

0.784921678662956
0.7906244110088444
0.8175996158638646
0.8202607975190039
0.7508108806063986
0.8089859522746448
0.7913672840412677
0.7983119960739928
0.8116556030981371
0.8340328734099112


In [5]:
# Per-repetition mean of (sensitivity + specificity) / 2, copied from the
# saved output of the repeated cross-validation run.
combined = [
    0.7790368674700061,
    0.7753676708723876,
    0.7758655546851791,
    0.7663019355560388,
    0.7616824965802556,
    0.783840919860266,
    0.7760989558734476,
    0.7842628703313687,
    0.7708616218953889,
    0.8034179518564708,
]
# One value per line.
print(*combined, sep="\n")

0.7790368674700061
0.7753676708723876
0.7758655546851791
0.7663019355560388
0.7616824965802556
0.783840919860266
0.7760989558734476
0.7842628703313687
0.7708616218953889
0.8034179518564708
