In [1]:
#import random, os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import os
from os.path import exists
 
from copy import deepcopy
#from random import sample


from tqdm import tqdm
from tqdm import tqdm
from aum import DatasetWithIndex
from aum import AUMCalculator

import sys

sys.path.insert(0, "../")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from data_iq.dataiq_class import *
from src.utils.utils import *
from src.models.neuralnets import *
from src.utils.data_loader import *

# Load dataset

In [3]:
# Name of the dataset handed to `load_dataset`.
dataset = 'support'

# `load_dataset` returns a 14-tuple; unpack it with related values grouped per line.
(
    train_loader, train_data,
    X_train, y_train,
    X_test, y_test,
    X_train_pd, y_train_pd,
    X_test_pd, y_test_pd,
    nlabels, corr_vals, column_ids, df,
) = load_dataset(dataset)

# Train models and assess

In [4]:
# Storage lists: one entry per repetition, each holding the Spearman output
# returned by the corresponding scoring helper.
aums = []
jtts = []
datamaps = []
dataiqs = []
grads = []

n_feats = X_train.shape[1]

# Hyperparameters are identical for every repetition, so they are set once here
# instead of being re-assigned on each pass through the loop.
N_RUNS = 5
delta = 0.001
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 10

# Pairs of network indices whose per-sample scores are compared.
combos = [(0, 1), (0, 2), (1, 2)]

# Directory handed to AUMCalculator; 'aum_values.csv' is read back from it below.
save_dir = '.'
aum_csv_path = os.path.join(save_dir, 'aum_values.csv')

criterion = torch.nn.NLLLoss()
# Explicit dim=1 (the class axis of a 2-D batch) matches what the implicit
# choice resolved to and avoids PyTorch's implicit-dimension deprecation warning.
log_softmax = nn.LogSoftmax(dim=1)

# Distinct loop variable names (run / net_idx / e) replace the original triple
# reuse of `i`, which only worked because `net_idx` was captured before the
# innermost loop overwrote `i`.
for run in tqdm(range(N_RUNS)):
    dataiq_list = []

    # Fresh, independently initialised networks for this repetition.
    nets = [
        Net1(input_size=n_feats, nlabels=nlabels),
        Net2(input_size=n_feats, nlabels=nlabels),
        Net3(input_size=n_feats, nlabels=nlabels),
    ]

    checkpoint_list = []

    overall_aum = []
    overall_jtt = []

    for net_idx in range(len(nets)):
        aum_calculator = AUMCalculator(save_dir, compressed=True)
        # DatasetWithIndex makes each batch yield (X, y, sample_ids), which the
        # AUM calculator needs to attribute margins to individual samples.
        train_loader = DataLoader(
            dataset=DatasetWithIndex(train_data),
            batch_size=BATCH_SIZE,
            shuffle=True,
        )

        ckpt_nets = []

        net = nets[net_idx]
        net.to(device)

        optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

        # Instantiate Data-IQ
        dataiq = DataIQ_Torch(X=X_train, y=y_train, sparse_labels=True)

        early_stopping = EarlyStopping(patience=2, delta=delta)

        for e in range(1, EPOCHS + 1):
            net.train()
            epoch_loss = 0
            epoch_acc = 0
            for X_batch, y_batch, sample_ids in train_loader:
                X_batch = X_batch.to(device)
                # Cast labels once; both the AUM update and NLLLoss get int64 targets.
                y_batch = y_batch.to(device).to(torch.int64)

                optimizer.zero_grad()
                y_pred = net(X_batch)
                aum_calculator.update(y_pred, y_batch, sample_ids.cpu().numpy())

                predicted = y_pred.detach().argmax(dim=1)

                loss = criterion(log_softmax(y_pred), y_batch)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                epoch_acc += (predicted == y_batch).sum().item() / len(y_batch)

            # Log to Data-IQ
            dataiq.on_epoch_end(net, device=device)
            print(f'Epoch {e+0:03}: | Loss: {epoch_loss/len(train_loader):.5f} | Acc: {epoch_acc/len(train_loader):.3f}')

            # NOTE(review): the value monitored here is the mean *training* loss
            # of the epoch; no separate validation split is evaluated in this cell.
            early_stopping(epoch_loss / len(train_loader), net)

            # Keep a snapshot of the network after every epoch.
            ckpt_nets.append(deepcopy(net))

            if early_stopping.early_stop:
                print("Early stopping")
                break

        checkpoint_list.append(ckpt_nets)
        dataiq_list.append(dataiq)

        # Reload the weights stored in checkpoint.pt — presumably written by
        # EarlyStopping on each improvement; confirm against its implementation.
        net.load_state_dict(torch.load("checkpoint.pt"))

        #########
        # AUM
        #########
        # Drop any stale file from a previous network/run before finalize()
        # produces the CSV that is read back immediately afterwards.
        if exists(aum_csv_path):
            os.remove(aum_csv_path)

        aum_calculator.finalize()
        aum_df = pd.read_csv(aum_csv_path)
        # Order the AUM values by sample id 0..n-1 with one indexed lookup
        # instead of scanning the whole frame once per sample (O(n^2)).
        # drop_duplicates keeps the first row per id, as `.values[0]` did.
        aum_by_id = aum_df.drop_duplicates('sample_id').set_index('sample_id')['aum']
        aum_scores = aum_by_id.loc[np.arange(aum_df.shape[0])].tolist()
        overall_aum.append(aum_scores)

        #########
        # JTT
        #########
        jtt_vec = compute_jtt(X_train, y_train, net)
        overall_jtt.append(jtt_vec)

    # Get the 2nd outputs since we care about spearman
    _, spearman_corr = get_scores_generic(overall_aum, combos)
    aums.append(spearman_corr)

    _, spearman_corr = get_scores_generic(overall_jtt, combos)
    jtts.append(spearman_corr)

    _, spearman_corr = get_dataiq_scores(dataiq_list, combos, feat='variability')
    datamaps.append(spearman_corr)

    _, spearman_corr = get_dataiq_scores(dataiq_list, combos, feat='aleatoric')
    dataiqs.append(spearman_corr)

    _, spearman_corr = get_grad_scores(dataiq_list, combos)
    grads.append(spearman_corr)

  X = F.softmax(self.output(X))


Epoch 001: | Loss: 0.65455 | Acc: 0.567
Validation loss decreased (inf --> 0.654553).  Saving model ...
Epoch 002: | Loss: 0.57587 | Acc: 0.711
Validation loss decreased (0.654553 --> 0.575874).  Saving model ...
Epoch 003: | Loss: 0.55938 | Acc: 0.746
Validation loss decreased (0.575874 --> 0.559384).  Saving model ...
Epoch 004: | Loss: 0.55088 | Acc: 0.755
Validation loss decreased (0.559384 --> 0.550881).  Saving model ...
Epoch 005: | Loss: 0.54165 | Acc: 0.762
Validation loss decreased (0.550881 --> 0.541650).  Saving model ...
Epoch 006: | Loss: 0.53621 | Acc: 0.769
Validation loss decreased (0.541650 --> 0.536207).  Saving model ...
Epoch 007: | Loss: 0.52791 | Acc: 0.777
Validation loss decreased (0.536207 --> 0.527907).  Saving model ...
Epoch 008: | Loss: 0.52365 | Acc: 0.783
Validation loss decreased (0.527907 --> 0.523647).  Saving model ...
Epoch 009: | Loss: 0.51754 | Acc: 0.790
Validation loss decreased (0.523647 --> 0.517535).  Saving model ...
Epoch 010: | Loss: 0.511

  X = F.softmax(self.output(X))


Epoch 001: | Loss: 0.65691 | Acc: 0.672
Validation loss decreased (inf --> 0.656911).  Saving model ...
Epoch 002: | Loss: 0.58797 | Acc: 0.709
Validation loss decreased (0.656911 --> 0.587969).  Saving model ...
Epoch 003: | Loss: 0.55060 | Acc: 0.747
Validation loss decreased (0.587969 --> 0.550604).  Saving model ...
Epoch 004: | Loss: 0.54175 | Acc: 0.756
Validation loss decreased (0.550604 --> 0.541745).  Saving model ...
Epoch 005: | Loss: 0.53677 | Acc: 0.761
Validation loss decreased (0.541745 --> 0.536768).  Saving model ...
Epoch 006: | Loss: 0.53056 | Acc: 0.768
Validation loss decreased (0.536768 --> 0.530558).  Saving model ...
Epoch 007: | Loss: 0.52557 | Acc: 0.778
Validation loss decreased (0.530558 --> 0.525571).  Saving model ...
Epoch 008: | Loss: 0.52208 | Acc: 0.784
Validation loss decreased (0.525571 --> 0.522079).  Saving model ...
Epoch 009: | Loss: 0.51613 | Acc: 0.792
Validation loss decreased (0.522079 --> 0.516128).  Saving model ...
Epoch 010: | Loss: 0.512

  X = F.softmax(self.output(X))


Epoch 001: | Loss: 0.61714 | Acc: 0.636
Validation loss decreased (inf --> 0.617145).  Saving model ...
Epoch 002: | Loss: 0.56487 | Acc: 0.740
Validation loss decreased (0.617145 --> 0.564869).  Saving model ...
Epoch 003: | Loss: 0.54942 | Acc: 0.756
Validation loss decreased (0.564869 --> 0.549421).  Saving model ...
Epoch 004: | Loss: 0.53382 | Acc: 0.769
Validation loss decreased (0.549421 --> 0.533815).  Saving model ...
Epoch 005: | Loss: 0.52843 | Acc: 0.778
Validation loss decreased (0.533815 --> 0.528431).  Saving model ...
Epoch 006: | Loss: 0.52368 | Acc: 0.781
Validation loss decreased (0.528431 --> 0.523680).  Saving model ...
Epoch 007: | Loss: 0.51175 | Acc: 0.793
Validation loss decreased (0.523680 --> 0.511753).  Saving model ...
Epoch 008: | Loss: 0.50502 | Acc: 0.803
Validation loss decreased (0.511753 --> 0.505019).  Saving model ...
Epoch 009: | Loss: 0.49866 | Acc: 0.809
Validation loss decreased (0.505019 --> 0.498657).  Saving model ...
Epoch 010: | Loss: 0.494

 20%|██        | 1/5 [00:46<03:04, 46.08s/it]

Epoch 001: | Loss: 0.63487 | Acc: 0.681
Validation loss decreased (inf --> 0.634866).  Saving model ...
Epoch 002: | Loss: 0.58488 | Acc: 0.681
Validation loss decreased (0.634866 --> 0.584876).  Saving model ...
Epoch 003: | Loss: 0.56000 | Acc: 0.729
Validation loss decreased (0.584876 --> 0.559998).  Saving model ...
Epoch 004: | Loss: 0.54357 | Acc: 0.753
Validation loss decreased (0.559998 --> 0.543570).  Saving model ...
Epoch 005: | Loss: 0.53584 | Acc: 0.766
Validation loss decreased (0.543570 --> 0.535839).  Saving model ...
Epoch 006: | Loss: 0.53046 | Acc: 0.773
Validation loss decreased (0.535839 --> 0.530457).  Saving model ...
Epoch 007: | Loss: 0.52700 | Acc: 0.775
Validation loss decreased (0.530457 --> 0.526996).  Saving model ...
Epoch 008: | Loss: 0.52166 | Acc: 0.783
Validation loss decreased (0.526996 --> 0.521658).  Saving model ...
Epoch 009: | Loss: 0.51585 | Acc: 0.791
Validation loss decreased (0.521658 --> 0.515853).  Saving model ...
Epoch 010: | Loss: 0.511

 40%|████      | 2/5 [01:31<02:17, 45.76s/it]

Epoch 001: | Loss: 0.64927 | Acc: 0.625
Validation loss decreased (inf --> 0.649267).  Saving model ...
Epoch 002: | Loss: 0.57877 | Acc: 0.688
Validation loss decreased (0.649267 --> 0.578769).  Saving model ...
Epoch 003: | Loss: 0.55864 | Acc: 0.745
Validation loss decreased (0.578769 --> 0.558643).  Saving model ...
Epoch 004: | Loss: 0.54259 | Acc: 0.757
Validation loss decreased (0.558643 --> 0.542594).  Saving model ...
Epoch 005: | Loss: 0.53734 | Acc: 0.764
Validation loss decreased (0.542594 --> 0.537339).  Saving model ...
Epoch 006: | Loss: 0.53231 | Acc: 0.772
Validation loss decreased (0.537339 --> 0.532313).  Saving model ...
Epoch 007: | Loss: 0.52711 | Acc: 0.776
Validation loss decreased (0.532313 --> 0.527110).  Saving model ...
Epoch 008: | Loss: 0.52088 | Acc: 0.784
Validation loss decreased (0.527110 --> 0.520878).  Saving model ...
Epoch 009: | Loss: 0.51579 | Acc: 0.792
Validation loss decreased (0.520878 --> 0.515793).  Saving model ...
Epoch 010: | Loss: 0.511

 60%|██████    | 3/5 [02:17<01:31, 45.67s/it]

Epoch 001: | Loss: 0.62820 | Acc: 0.682
Validation loss decreased (inf --> 0.628198).  Saving model ...
Epoch 002: | Loss: 0.57077 | Acc: 0.682
Validation loss decreased (0.628198 --> 0.570769).  Saving model ...
Epoch 003: | Loss: 0.56072 | Acc: 0.737
Validation loss decreased (0.570769 --> 0.560719).  Saving model ...
Epoch 004: | Loss: 0.55416 | Acc: 0.758
Validation loss decreased (0.560719 --> 0.554164).  Saving model ...
Epoch 005: | Loss: 0.54993 | Acc: 0.760
Validation loss decreased (0.554164 --> 0.549932).  Saving model ...
Epoch 006: | Loss: 0.54480 | Acc: 0.769
Validation loss decreased (0.549932 --> 0.544802).  Saving model ...
Epoch 007: | Loss: 0.54021 | Acc: 0.771
Validation loss decreased (0.544802 --> 0.540209).  Saving model ...
Epoch 008: | Loss: 0.53569 | Acc: 0.780
Validation loss decreased (0.540209 --> 0.535694).  Saving model ...
Epoch 009: | Loss: 0.53015 | Acc: 0.787
Validation loss decreased (0.535694 --> 0.530154).  Saving model ...
Epoch 010: | Loss: 0.526

 80%|████████  | 4/5 [03:02<00:45, 45.63s/it]

Epoch 001: | Loss: 0.63245 | Acc: 0.663
Validation loss decreased (inf --> 0.632445).  Saving model ...
Epoch 002: | Loss: 0.57401 | Acc: 0.703
Validation loss decreased (0.632445 --> 0.574008).  Saving model ...
Epoch 003: | Loss: 0.55680 | Acc: 0.749
Validation loss decreased (0.574008 --> 0.556805).  Saving model ...
Epoch 004: | Loss: 0.54286 | Acc: 0.759
Validation loss decreased (0.556805 --> 0.542860).  Saving model ...
Epoch 005: | Loss: 0.53461 | Acc: 0.769
Validation loss decreased (0.542860 --> 0.534612).  Saving model ...
Epoch 006: | Loss: 0.52920 | Acc: 0.774
Validation loss decreased (0.534612 --> 0.529204).  Saving model ...
Epoch 007: | Loss: 0.52220 | Acc: 0.783
Validation loss decreased (0.529204 --> 0.522197).  Saving model ...
Epoch 008: | Loss: 0.51864 | Acc: 0.789
Validation loss decreased (0.522197 --> 0.518645).  Saving model ...
Epoch 009: | Loss: 0.51214 | Acc: 0.795
Validation loss decreased (0.518645 --> 0.512139).  Saving model ...
Epoch 010: | Loss: 0.506

100%|██████████| 5/5 [03:48<00:00, 45.66s/it]


In [5]:
# Mean of the collected Spearman correlations for each method,
# keyed by method name in the same order as before.
results_mean = {
    method: np.mean(scores)
    for method, scores in (
        ('aum', aums),
        ('jtt', jtts),
        ('datamaps', datamaps),
        ('dataiq', dataiqs),
        ('gradn', grads),
    )
}

results_mean

{'aum': 0.916190798530897,
 'jtt': 0.7298764801900242,
 'datamaps': 0.9047176065270961,
 'dataiq': 0.948829636213856,
 'gradn': 0.8754141693086199}

In [6]:
# Standard deviation of the collected Spearman correlations for each method,
# keyed by method name in the same order as before.
results_std = {
    method: np.std(scores)
    for method, scores in (
        ('aum', aums),
        ('jtt', jtts),
        ('datamaps', datamaps),
        ('dataiq', dataiqs),
        ('gradn', grads),
    )
}

results_std

{'aum': 0.004698694963591204,
 'jtt': 0.022488183177470037,
 'datamaps': 0.013541113950829782,
 'dataiq': 0.010409463769110547,
 'gradn': 0.042576085289189}