In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
import seaborn as sn
from scipy.special import softmax
from sklearn.metrics import classification_report, confusion_matrix, f1_score

from resnet import ResNet
from unet import UNet
from dnn import DNN
from utils import *

In [1]:
# [245, 610, 769]
print([769/245, 769/610, 1)

3.1387755102040815 1.2606557377049181 1


In [2]:
print(torch.cuda.is_available())
print(torch.cuda.current_device())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

True
0
1
NVIDIA GeForce GTX 1070


## Loaders

### MRI

In [3]:
num_180 = 0
num_166 = 0
num_162 = 0
num_other = 0

cls_num_list_train = [0,0,0]
cls_num_list_test = [0,0,0]
n = 0
bigmat = []

for folder in ["train","test"]:
    for cls_idx,subfolder in enumerate(["AD","CN","MCI"]):
        files = os.listdir('MRI_all/' + folder + '/' + subfolder)
        
        for file in files:
            mat = np.load('MRI_all/' + folder + '/' + subfolder + '/' + file)
            
            if folder=="train":
                bigmat.append(mat)
                n += 1
                cls_num_list_train[cls_idx] += 1
            if folder=="test":
                cls_num_list_test[cls_idx] +=1
            
            # print(folder, subfolder, file, mat.shape)
            # if(mat.shape[0] != 166):
            #     os.remove(folder + '/' + subfolder + '/' + file)
            if(mat.shape[0] == 166):
                # mat = mat[2:164]
                num_166 += 1
            elif(mat.shape[0] == 180):
                # mat = mat[9:171]
                num_180 += 1
                # mat = mat[7:173]
            elif(mat.shape[0] == 162):
                num_162 += 1
            else:
                num_other += 1
            # np.save(folder + '/' + subfolder + '/' + file, mat)
            
print(num_166, num_180, num_162)
print(cls_num_list_train, cls_num_list_test)

bigmat = np.stack(bigmat)
print(bigmat.shape)

# data = bigmat/255
# mean_list = np.mean(data,axis=(0,2,3))
# std_list = np.std(data,axis=(0,2,3))

mean_list = np.load("MRI_mean.npy")
std_list = np.load("MRI_std.npy")

transform_train = transforms.Compose([transforms.ToTensor(),
                                     transforms.RandomCrop(256, padding=16),
                                     # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0, hue=0),
                                     transforms.GaussianBlur(kernel_size = (5,5), sigma=(0.2,0.2)),
                                     transforms.Normalize(mean_list, std_list)
                                    ])
transform_test = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean_list, std_list)])

def npy_loader_train(path):
    mat = np.load(path)
    mat = np.transpose(mat, (1,2,0))
    mat = transform_train(mat)
    return mat

def npy_loader_test(path):
    mat = np.load(path)
    mat = np.transpose(mat, (1,2,0))
    mat = transform_test(mat)
    return mat

trainset_MRI = datasets.DatasetFolder(root="MRI_all/train", loader=npy_loader_train, extensions=(".npy"))
testset_MRI = datasets.DatasetFolder(root="MRI_all/test", loader=npy_loader_test, extensions=(".npy"))

trainloader_MRI = DataLoader(trainset_MRI, batch_size=2, shuffle=False)
testloader_MRI = DataLoader(testset_MRI, batch_size=2, shuffle=False)

0 0 1624
[207, 483, 607] [38, 127, 162]


### EHR

In [None]:
cls_num_list_train = [0,0,0]
cls_num_list_test = [0,0,0]
n = 0
bigmat = []

for folder in ["train","test"]:
    for cls_idx,subfolder in enumerate(["AD","CN","MCI"]):
        files = os.listdir('EHR_all/' + folder + '/' + subfolder)
        
        for file in files:
            mat = np.load('EHR_all/' + folder + '/' + subfolder + '/' + file)
            
            if folder=="train":
                bigmat.append(mat)
                n += 1
                cls_num_list_train[cls_idx] += 1
            if folder=="test":
                cls_num_list_test[cls_idx] +=1
            
            # print(folder, subfolder, file, mat.shape)
            
print(cls_num_list_train, cls_num_list_test)

bigmat = np.stack(bigmat)
print(bigmat.shape)

data = bigmat
mean_list = np.mean(data,axis=0)
std_list = np.std(data,axis=0) + 0.01

def npy_loader_train(path):
    mat = np.load(path)
    # mat = transform_train(mat)
    mat = (mat - mean_list)/std_list
    mat = torch.Tensor(mat)
    return mat

def npy_loader_test(path):
    mat = np.load(path)
    # mat = transform_test(mat)
    mat = (mat - mean_list)/std_list
    mat = torch.Tensor(mat)
    return mat

trainset_EHR = datasets.DatasetFolder(root="EHR_all/train", loader=npy_loader_train, extensions=(".npy"))
testset_EHR = datasets.DatasetFolder(root="EHR_all/test", loader=npy_loader_test, extensions=(".npy"))

trainloader_EHR = DataLoader(trainset_EHR, batch_size=2, shuffle=False)
testloader_EHR = DataLoader(testset_EHR, batch_size=2, shuffle=False)

[34, 52, 89] [8, 21, 35]
(175, 100)


## Models

In [None]:
# Initialize
model_MRI = ResNet(in_channels= 162, num_classes=3).to(device)
model_EHR = DNN(in_channels= 100, num_classes=3).to(device)

# Load saved model weights
model_MRI.load_state_dict(torch.load("checkpoint/" + "ResNet.pth"))
model_EHR.load_state_dict(torch.load("checkpoint/" + "DNN.pth"))

model_MRI.eval()
model_EHR.eval()

DNN(
  (fc1): Linear(in_features=100, out_features=10, bias=True)
  (fc1_drop): Dropout(p=0.25, inplace=False)
  (fc2): Linear(in_features=10, out_features=3, bias=True)
)

### Trainset

In [None]:
running_corrects_MRI = 0; running_corrects_EHR = 0; running_corrects_total = 0;
y = []; yhat_MRI = []; yhat_EHR = []; yhat_total = []

for data_MRI, data_EHR in zip(trainloader_MRI, trainloader_EHR):
    
    inputs, labels = data_MRI
    inputs, labels = inputs.to(device), labels.to(device)
    with torch.no_grad():
        outputs_MRI = model_MRI(inputs)
    
    inputs, labels = data_EHR
    inputs, labels = inputs.to(device), labels.to(device)
    with torch.no_grad():
        outputs_EHR = model_EHR(inputs)
    
    probs_MRI = F.softmax(outputs_MRI, -1)
    probs_EHR = F.softmax(outputs_EHR, -1)
    entropy_MRI = get_entropy(probs_MRI)
    entropy_EHR = get_entropy(probs_EHR)
    
    # Combine probability vectors
    # probs_total = (probs_MRI + probs_EHR)/2
    probs_total = probs_MRI/entropy_MRI + probs_EHR/entropy_EHR
    
    _, preds_MRI = torch.max(outputs_MRI, 1)
    _, preds_EHR = torch.max(outputs_EHR, 1)
    _, preds_total = torch.max(probs_total, 1)
    
    y.append(labels.tolist())
    yhat_MRI.append(preds_MRI.tolist())
    yhat_EHR.append(preds_EHR.tolist())
    yhat_total.append(preds_total.tolist())
    
    running_corrects_MRI += torch.sum(preds_MRI == labels.data)
    running_corrects_EHR += torch.sum(preds_EHR == labels.data)
    running_corrects_total += torch.sum(preds_total == labels.data)

test_accuracy_MRI = (running_corrects_MRI.float() / len(trainset_MRI))
test_accuracy_EHR = (running_corrects_EHR.float() / len(trainset_EHR))
test_accuracy_total = (running_corrects_total.float() / len(trainset_EHR))

y = np.hstack(y)
yhat_MRI = np.hstack(yhat_MRI)
yhat_EHR = np.hstack(yhat_EHR)
yhat_total = np.hstack(yhat_total)

# print("MRI | Accuracy: {:.4f}\n".
#       format(test_accuracy_MRI.item()))
# print("EHR | Accuracy: {:.4f}\n".
#       format(test_accuracy_EHR.item()))

print("MRI:\n", classification_report(y, yhat_MRI, target_names=trainset_MRI.classes, digits=4))
print("EHR:\n ", classification_report(y, yhat_EHR, target_names=trainset_EHR.classes, digits=4))
print("Total:\n ", classification_report(y, yhat_total, target_names=trainset_EHR.classes, digits=4))

MRI:
               precision    recall  f1-score   support

          AD     0.6667    0.0588    0.1081        34
          CN     1.0000    0.0385    0.0741        52
         MCI     0.5235    1.0000    0.6873        89

    accuracy                         0.5314       175
   macro avg     0.7301    0.3658    0.2898       175
weighted avg     0.6929    0.5314    0.3925       175

EHR:
                precision    recall  f1-score   support

          AD     0.5667    1.0000    0.7234        34
          CN     0.7895    0.5769    0.6667        52
         MCI     0.7792    0.6742    0.7229        89

    accuracy                         0.7086       175
   macro avg     0.7118    0.7504    0.7043       175
weighted avg     0.7410    0.7086    0.7063       175

Total:
                precision    recall  f1-score   support

          AD     0.5789    0.9706    0.7253        34
          CN     1.0000    0.0385    0.0741        52
         MCI     0.5948    0.7753    0.6732        89

### Testset

In [None]:
running_corrects_MRI = 0; running_corrects_EHR = 0; running_corrects_total = 0;
y = []; yhat_MRI = []; yhat_EHR = []; yhat_total = []

for data_MRI, data_EHR in zip(testloader_MRI, testloader_EHR):
    
    inputs, labels = data_MRI
    inputs, labels = inputs.to(device), labels.to(device)
    with torch.no_grad():
        outputs_MRI = model_MRI(inputs)
    
    inputs, labels = data_EHR
    inputs, labels = inputs.to(device), labels.to(device)
    with torch.no_grad():
        outputs_EHR = model_EHR(inputs)
    
    probs_MRI = F.softmax(outputs_MRI, -1)
    probs_EHR = F.softmax(outputs_EHR, -1)
    entropy_MRI = get_entropy(probs_MRI)
    entropy_EHR = get_entropy(probs_EHR)
    
    # Combine probability vectors
    # probs_total = (probs_MRI + probs_EHR)/2
    probs_total = probs_MRI/entropy_MRI + probs_EHR/entropy_EHR
    
    _, preds_MRI = torch.max(outputs_MRI, 1)
    _, preds_EHR = torch.max(outputs_EHR, 1)
    _, preds_total = torch.max(probs_total, 1)
    
    y.append(labels.tolist())
    yhat_MRI.append(preds_MRI.tolist())
    yhat_EHR.append(preds_EHR.tolist())
    yhat_total.append(preds_total.tolist())
    
    running_corrects_MRI += torch.sum(preds_MRI == labels.data)
    running_corrects_EHR += torch.sum(preds_EHR == labels.data)
    running_corrects_total += torch.sum(preds_total == labels.data)

test_accuracy_MRI = (running_corrects_MRI.float() / len(testset_MRI))
test_accuracy_EHR = (running_corrects_EHR.float() / len(testset_EHR))
test_accuracy_total = (running_corrects_total.float() / len(testset_EHR))

y = np.hstack(y)
yhat_MRI = np.hstack(yhat_MRI)
yhat_EHR = np.hstack(yhat_EHR)
yhat_total = np.hstack(yhat_total)

# print("MRI | Accuracy: {:.4f}\n".
#       format(test_accuracy_MRI.item()))
# print("EHR | Accuracy: {:.4f}\n".
#       format(test_accuracy_EHR.item()))

print("MRI:\n", classification_report(y, yhat_MRI, target_names=trainset_MRI.classes, digits=4))
print("EHR:\n ", classification_report(y, yhat_EHR, target_names=trainset_EHR.classes, digits=4))
print("Total:\n ", classification_report(y, yhat_total, target_names=trainset_EHR.classes, digits=4))

MRI:
               precision    recall  f1-score   support

          AD     0.0000    0.0000    0.0000         8
          CN     0.0000    0.0000    0.0000        21
         MCI     0.5469    1.0000    0.7071        35

    accuracy                         0.5469        64
   macro avg     0.1823    0.3333    0.2357        64
weighted avg     0.2991    0.5469    0.3867        64

EHR:
                precision    recall  f1-score   support

          AD     0.4211    1.0000    0.5926         8
          CN     0.6667    0.5714    0.6154        21
         MCI     0.8148    0.6286    0.7097        35

    accuracy                         0.6562        64
   macro avg     0.6342    0.7333    0.6392        64
weighted avg     0.7170    0.6562    0.6641        64

Total:
                precision    recall  f1-score   support

          AD     0.3889    0.8750    0.5385         8
          CN     0.0000    0.0000    0.0000        21
         MCI     0.6087    0.8000    0.6914        35

  _warn_prf(average, modifier, msg_start, len(result))
