In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn

import torchvision
import torchvision.transforms as transforms

from libauc.losses import AUCMLoss 
from libauc.optimizers import PESG 

import pandas as pd
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix, accuracy_score, balanced_accuracy_score, precision_score, recall_score, roc_auc_score

import matplotlib.pyplot as plt
import numpy as np
import os
import random
from time import sleep

from IPython.display import clear_output
from IPython.core.debugger import Pdb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

RANDOM_SEED = 123
BATCH_SIZE = 32
# LEARN_RATE = 0.1
LEARN_RATE = 0.001

MARGIN = 1.0
EPOCH_DECAY = 0.003
WEIGHT_DECAY = 0.0001

cuda


In [2]:
# Build dataset from csv files
class AwareDataset(torch.utils.data.Dataset):
    def __init__(self, csv_data, csv_outcome, csv_verbose, root_dir, target_classes=None, transform=None):
        self.data_raw = pd.read_csv(csv_data, header=None)
        self.data_out = pd.read_csv(csv_outcome)
        self.data_verb = pd.read_csv(csv_verbose)
        self.root_dir = root_dir
        self.transform = transform
        
        if target_classes != None:
            idx = False
            for i in target_classes:
                idx |= (self.data_out['Diagnosis']==i)
                for col in self.data_out.columns:
                    idx &= (~self.data_out[col].isna())
            self.data_raw = self.data_raw[idx].reset_index(drop=True)
            self.data_out = self.data_out[idx].reset_index(drop=True)
            self.data_verb = self.data_verb[idx].reset_index(drop=True)
            print(self.data_out.columns)
            print(self.data_verb.columns)
            print('# of samples:', len(self.data_raw))
        
#         idx_remove = np.where(self.data_verb['Age']>18)[0].tolist()
#         self.data_raw = self.data_raw.drop(idx_remove).reset_index(drop=True)
#         self.data_out = self.data_out.drop(idx_remove).reset_index(drop=True)
#         self.data_verb = self.data_verb.drop(idx_remove).reset_index(drop=True)
        
        n_pos = np.sum(self.data_out['Diagnosis']==2)
        n_neg = np.sum(self.data_out['Diagnosis']==0)
        print(n_pos,n_neg)
        
        self.data_out['Diagnosis'] = self.data_out['Diagnosis']/2
        
#         idx_remove = np.where(self.data_out['Diagnosis']==1)[0].tolist()
#         idx_remove = random.sample(idx_remove, n_pos-n_neg)
#         self.data_raw = self.data_raw.drop(idx_remove).reset_index(drop=True)
#         self.data_out = self.data_out.drop(idx_remove).reset_index(drop=True)
#         self.data_verb = self.data_verb.drop(idx_remove).reset_index(drop=True)
        
        if len(self.data_raw) != len(self.data_out) or len(self.data_raw) != len(self.data_verb):
            raise Exception("Inconsistent data length")

    def __len__(self):
        return len(self.data_raw)

    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        data = self.data_raw.iloc[idx, 34:118].values.astype('float32')
        data = (data-np.mean(data))/np.std(data)
        target = self.data_out.iloc[idx, :].values.astype('float32')
        verbose = self.data_verb.iloc[idx, :].values.astype('float32')

        return data, target, verbose

dataset = AwareDataset(csv_data = 'data/exhale_data_v6_ave.csv',
                       csv_outcome = 'data/exhale_outcome_v6_ave.csv',
                       csv_verbose = 'data/exhale_verbose_v6_ave.csv',
                       root_dir = 'data/',
                       target_classes = [0,2])

# dataiter = iter(trainloader)
# inputs, labels, info = dataiter.next()
# print(inputs.size())
# print(labels[0])
# plt.plot(inputs[0])
# plt.ylim([-5,5])

Index(['Diagnosis', 'FEV1', 'FEV1/FVC', 'FEF2575', 'FEV1_pred',
       'FEV1/FVC_pred', 'FEF2575_pred'],
      dtype='object')
Index(['ID', 'Trial', 'Age', 'Sex', 'Height', 'Weight'], dtype='object')
# of samples: 121
55 66


In [3]:
from models.MTK import Net
model = Net()
print(model)
model.to(device)
if device == 'cuda':
    model = torch.nn.DataParallel(model)
    cudnn.benchmark = True

encoder = nn.Sequential(
    nn.Linear(84, 100),
    nn.ReLU(),
    nn.Linear(100, 48),
    nn.ReLU(),
    nn.Linear(48, 48))
encoder.load_state_dict(torch.load('./checkpoint/encoder_0.04.pth')['net'])
encoder.to(device)
if device == 'cuda':
    encoder = torch.nn.DataParallel(encoder)
    cudnn.benchmark = True

Net(
  (regressor1): Sequential(
    (0): Linear(in_features=48, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=1, bias=True)
    (5): Sigmoid()
  )
  (regressor2): Sequential(
    (0): Linear(in_features=48, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=1, bias=True)
    (5): Sigmoid()
  )
  (classifier): Sequential(
    (0): Linear(in_features=50, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=50, bias=True)
    (3): ReLU()
    (4): Linear(in_features=50, out_features=2, bias=True)
    (5): Softmax(dim=None)
  )
)


In [4]:
# criterion1 = AUCMLoss()
# criterion2 = nn.MSELoss()
# optimizer = PESG(model, 
#                  loss_fn=criterion1,
#                  lr=LEARN_RATE,
#                  momentum=0.9,
#                  margin=MARGIN, 
#                  epoch_decay=EPOCH_DECAY, 
#                  weight_decay=WEIGHT_DECAY)

criterion1 = nn.CrossEntropyLoss(weight=torch.Tensor([1, 1]).to(device))
# criterion1 = nn.CrossEntropyLoss(weight=torch.Tensor([2.5, 0.625]).to(device))
criterion2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARN_RATE)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

writer = SummaryWriter()

def train(epoch):
    clear_output(wait=False)
    print('Epoch: %d | TRAIN' % epoch)
    model.train()
    train_loss = 0
    train_correct = 0
    total = 0
    for batch_idx, (inputs, labels, info) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        inputs = encoder(inputs)
        outputs_pred, outputs_fev1, outputs_fev1_fvc = model(inputs)
        
        loss1 = criterion1(outputs_pred, labels[:,0].long())
        loss2 = criterion2(outputs_fev1, labels[:,4])
        loss3 = criterion2(outputs_fev1_fvc, labels[:,5])
        loss = loss1 + (loss2 + loss3)*1e-3
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, preds = torch.max(outputs_pred, 1)
        total += labels.size(0)
        train_correct += torch.sum(preds == labels[:,0].data)

        print('Batch: %d/%d | Loss: %.3f | Acc: %.3f%% (%d/%d)'
              % (batch_idx+1, len(trainloader), train_loss/(batch_idx+1), 100.*train_correct/total, train_correct, total), end='\r')
        
    writer.add_scalar('Loss/train', train_loss/(batch_idx+1), epoch)
    writer.add_scalar('Acc/train', train_correct/total, epoch)
        
def test(epoch):
    global best_acc
    best_acc = 0
    # clear_output(wait=False)
    print('Epoch: %d | TEST' % epoch)
    model.eval()
    test_loss = 0
    test_correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, labels, info) in enumerate(testloader):
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = encoder(inputs)
            outputs_pred, outputs_fev1, outputs_fev1_fvc = model(inputs)

            loss1 = criterion1(outputs_pred, labels[:,0].long())
            loss2 = criterion2(outputs_fev1, labels[:,4])
            loss3 = criterion2(outputs_fev1_fvc, labels[:,5])
            loss = loss1 + (loss2 + loss3)*1e-3
            
            test_loss += loss.item()
            _, preds = torch.max(outputs_pred, 1)
            total += labels.size(0)
            test_correct += torch.sum(preds == labels[:,0].data)

            print('Batch: %d/%d | Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (batch_idx+1, len(testloader), test_loss/(batch_idx+1), 100.*test_correct/total, test_correct, total), end='\r')
        
    writer.add_scalar('Loss/test', test_loss/(batch_idx+1), epoch)
    writer.add_scalar('Acc/test', test_correct/total, epoch)
    
    # Save checkpoint.
    acc = 100.*test_correct/total
    if acc > best_acc:
        print('Saving.. | Acc:%.3f%%\n' % acc)
        state = {
            'net': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt_MTK.pth')
        best_acc = acc

metrics = {'Accuracy':[], 
          'BalancedAccuracy':[],
          'Precision':[],
          'Recall':[],
          'Specificity':[],
          'AUROC':[]}

predictions = {'Score0':[], 'Score1':[], 'Prediction':[], 'FEV1_est':[], 'FEV1/FVC_est':[],
               'Diagnosis':[], 'FEV1':[], 'FEV1/FVC':[], 'FEF2575':[],
               'FEV1_pred':[], 'FEV1/FVC_pred':[], 'FEF2575_pred':[],
               'ID':[], 'Trial':[], 'Age':[], 'Sex':[], 'Height':[], 'Weight':[]}               

def evaluate(fold):
    with torch.no_grad():
        for batch_idx, (inputs, labels, info) in enumerate(testloader):
            inputs, labels = inputs.to(device), labels.to(device)
            inputs = encoder(inputs)
            outputs_pred, outputs_fev1, outputs_fev1_fvc = model(inputs)
            _, preds = torch.max(outputs_pred, 1)
            tn, fp, fn, tp = confusion_matrix(labels[:,0].cpu(), preds.cpu()).ravel()
            
            metrics['Accuracy'].append(accuracy_score(labels[:,0].cpu(), preds.cpu()))
            metrics['BalancedAccuracy'].append(balanced_accuracy_score(labels[:,0].cpu(), preds.cpu()))
            metrics['Precision'].append(precision_score(labels[:,0].cpu(), preds.cpu()))
            metrics['Recall'].append(recall_score(labels[:,0].cpu(), preds.cpu()))
            metrics['Specificity'].append(tn/(tn+fp))
            metrics['AUROC'].append(roc_auc_score(labels[:,0].cpu(), outputs_pred[:,1].cpu()))
            
            print(outputs_pred[:,0].cpu().numpy())
            predictions['Score0'].extend(outputs_pred[:,0].cpu().numpy())
            predictions['Score1'].extend(outputs_pred[:,1].cpu().numpy())
            predictions['Prediction'].extend(preds.cpu().numpy())
            predictions['Diagnosis'].extend(labels[:,0].long().cpu().numpy())
            predictions['FEV1_est'].extend(outputs_fev1[:,0].cpu().numpy())
            predictions['FEV1/FVC_est'].extend(outputs_fev1_fvc[:,0].cpu().numpy())
            predictions['FEV1'].extend(labels[:,1].cpu().numpy())
            predictions['FEV1/FVC'].extend(labels[:,2].cpu().numpy())
            predictions['FEF2575'].extend(labels[:,3].cpu().numpy())
            predictions['FEV1_pred'].extend(labels[:,4].cpu().numpy())
            predictions['FEV1/FVC_pred'].extend(labels[:,5].cpu().numpy())
            predictions['FEF2575_pred'].extend(labels[:,6].cpu().numpy())
            predictions['ID'].extend(info[:,0].cpu().numpy())
            predictions['Trial'].extend(info[:,1].cpu().numpy())
            predictions['Age'].extend(info[:,2].cpu().numpy())
            predictions['Sex'].extend(info[:,3].cpu().numpy())
            predictions['Height'].extend(info[:,4].cpu().numpy())
            predictions['Weight'].extend(info[:,5].cpu().numpy())
            
#             print('Fold: ', fold)
#             print(torch.cat((preds,labels[:,0]), 2))
#             print(confusion_matrix(labels[:,0].cpu(), preds.cpu()))
#             print('Accuracy:          %.2f' % accuracy_score(labels[:,0].cpu(), preds.cpu()))
#             print('Balanced Accuracy: %.2f' % balanced_accuracy_score(labels[:,0].cpu(), preds.cpu()))
#             print('Precision:         %.2f' % precision_score(labels[:,0].cpu(), preds.cpu()))
#             print('Recall:            %.2f' % recall_score(labels[:,0].cpu(), preds.cpu()))
#             print('AUROC:             %.2f' % roc_auc_score(labels[:,0].cpu(), outputs_pred[:,1].cpu()))

def weight_reset(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()
        
def subject_accuracy(output, labels, ids):
    unique_ids = torch.unique(ids)
    for i in unique_ids:
        print(i)

In [5]:
idx = list(range(0,len(dataset)))
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_SEED)
for i, (train_idx, test_idx) in enumerate(skf.split(idx, dataset.data_out['Diagnosis'])):
    model.apply(weight_reset)
    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
    testloader = torch.utils.data.DataLoader(dataset, batch_size=len(test_idx), sampler=test_sampler)
    for epoch in range(200):
        train(epoch)
        test(epoch)
        if isinstance(criterion1, AUCMLoss):
            optimizer.update_regularizer()
        if 'scheduler' in globals():
            scheduler.step()
    
    model.load_state_dict(torch.load('./checkpoint/ckpt_MTK.pth')['net'])

    evaluate(i)
    print(metrics)
    os.rename('./checkpoint/ckpt_MTK.pth', f'./checkpoint/ckpt_MTK_{best_acc:.2f}.pth')

Epoch: 199 | TRAIN
Epoch: 199 | TEST: 1.026 | Acc: 69.072% (67/97)
Saving.. | Acc:62.500%46 | Acc: 62.500% (15/24)

[0.4225707  0.55744696 0.32489306 0.17405713 0.2179031  0.64161175
 0.43087277 0.6094339  0.46348938 0.7344333  0.73258007 0.23296742
 0.28449577 0.61162853 0.35487267 0.4861696  0.63921314 0.53512454
 0.42013776 0.37249818 0.51960915 0.25033778 0.45208523 0.5593583 ]
{'Accuracy': [0.6, 0.625, 0.5833333333333334, 0.5, 0.625], 'BalancedAccuracy': [0.5844155844155844, 0.5979020979020979, 0.5664335664335665, 0.5104895104895105, 0.6328671328671329], 'Precision': [0.5555555555555556, 0.75, 0.5714285714285714, 0.4666666666666667, 0.5714285714285714], 'Recall': [0.45454545454545453, 0.2727272727272727, 0.36363636363636365, 0.6363636363636364, 0.7272727272727273], 'Specificity': [0.7142857142857143, 0.9230769230769231, 0.7692307692307693, 0.38461538461538464, 0.5384615384615384], 'AUROC': [0.6818181818181818, 0.7342657342657343, 0.6853146853146853, 0.46153846153846145, 0.65734265

In [6]:
metrics_df = pd.DataFrame(metrics)
print(metrics_df)
print(metrics_df.mean(axis=0))

predictions_df = pd.DataFrame(predictions)
print(predictions_df)

   Accuracy  BalancedAccuracy  Precision    Recall  Specificity     AUROC
0  0.600000          0.584416   0.555556  0.454545     0.714286  0.681818
1  0.625000          0.597902   0.750000  0.272727     0.923077  0.734266
2  0.583333          0.566434   0.571429  0.363636     0.769231  0.685315
3  0.500000          0.510490   0.466667  0.636364     0.384615  0.461538
4  0.625000          0.632867   0.571429  0.727273     0.538462  0.657343
Accuracy            0.586667
BalancedAccuracy    0.578422
Precision           0.583016
Recall              0.490909
Specificity         0.665934
AUROC               0.644056
dtype: float64
       Score0    Score1  Prediction   FEV1_est  FEV1/FVC_est  Diagnosis  FEV1  \
0    0.038800  0.961200           1  94.114609     93.413284          0  3.96   
1    0.986024  0.013976           0  91.937126     93.013824          1  1.90   
2    0.999121  0.000879           0  91.749153     91.928444          0  3.14   
3    0.951940  0.048061           0  91.003

In [7]:
idx_pos = predictions_df['Prediction']==predictions_df['Diagnosis']
preds_pos = predictions_df.loc[idx_pos,:]
idx_neg = predictions_df['Prediction']!=predictions_df['Diagnosis']
preds_neg = predictions_df.loc[idx_neg,:]
for i in range(len(dataset)):
    print(dataset.data_raw[i])
    Pdb.set_trace()
print(preds_neg.mean(axis=0))
plt.hist(predictions_df['Age'])
plt.show()
plt.hist(preds_neg['Age'])
plt.show()
os.makedirs('outputs', exist_ok=True)  
preds_neg.to_csv('outputs/preds_neg.csv', index=False)

0      1.2
1      1.2
2      1.2
3      1.2
4      1.2
      ... 
116    1.2
117    1.2
118    1.2
119    1.2
120    1.2
Name: 0, Length: 121, dtype: float64


TypeError: set_trace() missing 1 required positional argument: 'self'

In [None]:
len(dataset)