In [1]:
import torch
import torchvision
import csv
import numpy as np
import torch.nn.functional as F
from torch import Tensor
from functools import partial
from typing import Any, Callable, List, Optional, Type, Union
import os
import pydicom
from PIL import Image
import random
import nibabel as nib
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch import nn
import sklearn.metrics
import matplotlib

# Compute device for the model and batches: the GPU when CUDA is available, else the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Machine selector: 'F' (default) or 'E' (commented alternative). It is also the
# drive letter Data_load reads the k-fold folders from, and it picks the GPU index below.
device_name = 'F' #E
# Only touch the CUDA runtime when it exists; set_device fails on CPU-only setups.
if torch.cuda.is_available():
    torch.cuda.set_device(0 if device_name == 'F' else 1)

In [2]:
# Shared training hyperparameters: passes over the training folds, and volumes per mini-batch.
epochs, batch_size = 50, 5

class Data_load:
    """Index the case folders of a 5-fold split and hold one fold out for testing.

    Folders are read from ``<device_name>:/k-fold validation/<fold>/<class>/`` for
    folds 1-5 and classes ``Normal``, ``MCI``, ``AD`` (labels 0, 1, 2). Every entry
    listed in such a folder becomes a ``[folder + entry, label]`` pair. Fold ``num``
    is the test split; all other folds form the training split, which is shuffled
    with the global ``random`` module.

    Attributes:
        x_data_name / x_data_test_name: one list of pairs per (fold, class) for the
            training / test folds, in ``os.listdir`` order.
        final_data_name / final_data_test_name: the same pairs flattened into a
            single list each (the training list is shuffled).
        normal_size, mci_size, ad_size: number of entries per class in the test fold.
        test_len: total number of entries in the test fold.
    """

    def __init__(self, num):
        self.len_ = 0  # NOTE(review): not read anywhere in the visible code
        self.x_data_name = []
        self.x_data_test_name = []
        self.normal_size = 0
        self.mci_size = 0
        self.ad_size = 0
        self.test_len = 0
        class_names = ['Normal', 'MCI', 'AD']
        # Name of the test-fold size attribute for each label, indexed by label.
        size_attrs = ('normal_size', 'mci_size', 'ad_size')

        for fold in range(1, 6):
            for label, class_name in enumerate(class_names):
                folder = f"{device_name}:/k-fold validation/{fold}/{class_name}/"
                entries = os.listdir(folder)
                group = [[folder + entry, label] for entry in entries]
                if fold == num:
                    self.x_data_test_name.append(group)
                    setattr(self, size_attrs[label], len(entries))
                    self.test_len += len(entries)
                else:
                    self.x_data_name.append(group)

        self.final_data_name = [pair for group in self.x_data_name for pair in group]
        self.final_data_test_name = [pair for group in self.x_data_test_name for pair in group]

        random.shuffle(self.final_data_name)

    def out(self):
        """Return ``(training_pairs, test_pairs)``."""
        return self.final_data_name, self.final_data_test_name

    def random(self):
        """Reshuffle the training pairs in place and return that same list."""
        random.shuffle(self.final_data_name)
        return self.final_data_name

def training(data, batch_size):
    """Yield ``(X, y)`` mini-batches of volumes and one-hot labels.

    Despite its name, this generator is used for both the training and the test
    split.

    Args:
        data: sequence of ``[case_dir, label]`` items (as built by ``Data_load``).
            The volume is loaded from ``case_dir + '/1/final/1_t1_final.mnc'``,
            and ``label`` must be 0, 1 or 2.
        batch_size: number of items per batch. When ``len(data)`` is not a
            multiple of it, the final batch holds the remainder.

    Yields:
        X: float32 tensor of shape ``(B, 1, *volume_shape)`` on ``device``. Each
           volume is rescaled so its maximum becomes 255.
        y: float32 one-hot tensor of shape ``(B, 3)`` on ``device``.

    Raises:
        ValueError: if ``batch_size`` < 1 or an item's label is not 0, 1 or 2.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    one_hot = np.eye(3, dtype=np.float32)
    for start in range(0, len(data), batch_size):
        volumes = []
        labels = []
        for item in data[start:start + batch_size]:
            case_dir, label = item[0], item[1]
            # Fail loudly: silently skipping an unknown label would misalign X and y.
            if label not in (0, 1, 2):
                raise ValueError(f"unexpected label {label!r} for {case_dir!r}")
            img = nib.load(case_dir + '/1/final/1_t1_final.mnc').get_fdata()
            peak = img.max()
            # Rescale so the brightest voxel is 255; leave an all-zero volume as-is
            # instead of dividing by zero (which would fill it with NaNs).
            if peak != 0:
                img = 255 * img / peak
            volumes.append(img)
            labels.append(int(label))

        # Insert the channel axis: (B, *volume_shape) -> (B, 1, *volume_shape).
        x_batch = torch.from_numpy(np.stack(volumes)[:, np.newaxis].astype(np.float32)).to(device)
        y_batch = torch.from_numpy(one_hot[labels]).to(device)
        yield (x_batch, y_batch)

In [3]:
class CNN(nn.Module):
    """Small 3-D CNN mapping a single-channel volume to ``num_classes`` logits.

    Architecture:
    - Three Conv3d + LeakyReLU stages, with 3x max-pooling after the first two.
    - A flatten step followed by one linear classifier.

    ``fc_in_features`` must equal the number of features left after ``layer_5`` for
    the volume size being used. The default 3072 (= 256 channels x 12 voxels) is the
    original hard-coded value; TODO confirm it against the actual MRI shape.

    The attribute names ``layer_1`` ... ``layer_6`` and their creation order are
    kept unchanged. They are the ``state_dict`` keys of saved models, and the order
    fixes the parameter initialization sequence.

    Args:
        fc_in_features: input width of the final linear layer (default 3072).
        num_classes: number of output logits (default 3: Normal, MCI, AD).
    """

    def __init__(self, fc_in_features=3072, num_classes=3):
        super().__init__()

        self.layer_1 = nn.Conv3d(1, 64, 5, stride=2)
        self.layer_2 = nn.MaxPool3d(3)
        self.layer_3 = nn.Conv3d(64, 128, 3, stride=2)
        self.layer_4 = nn.MaxPool3d(3)
        self.layer_5 = nn.Conv3d(128, 256, 3, stride=1)
        self.layer_6 = nn.Linear(fc_in_features, num_classes)
        self.flatten = torch.nn.Flatten()
        self.lrelu = nn.LeakyReLU()

    def forward(self, x):
        """Map a ``(N, 1, D, H, W)`` batch to ``(N, num_classes)`` logits."""
        x = self.layer_2(self.lrelu(self.layer_1(x)))
        x = self.layer_4(self.lrelu(self.layer_3(x)))
        x = self.lrelu(self.layer_5(x))
        return self.layer_6(self.flatten(x))


In [4]:

# Train one CNN per cross-validation fold and save each whole model to disk.
# NOTE(review): only folds 1-3 are trained, although Data_load reads folds 1-5; confirm intent.

# Per-class CrossEntropyLoss weights for (Normal, MCI, AD). The loss expects a
# floating-point weight tensor, so build it as float32 rather than int64.
nSamples = [1,2,3]
weights = torch.tensor(nSamples, dtype=torch.float32).to(device)
SEED = 0  # base seed; each fold uses SEED + fold number

for tt in range(1,4,1):
    # Seed every RNG in play (Data_load shuffling, numpy, weight init) for reproducible folds.
    random.seed(SEED + tt)
    np.random.seed(SEED + tt)
    torch.manual_seed(SEED + tt)

    model = CNN().to(device)
    loss_fn = nn.CrossEntropyLoss(weight=weights, reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    data = Data_load(tt)
    for t in range(epochs):
        X_training = data.random()  # reshuffle the training pairs each epoch
        model.train()
        for batch, (X, y) in enumerate(training(X_training, batch_size)):
            pred = model(X)
            batch_loss_result = loss_fn(pred, y)
            optimizer.zero_grad()
            batch_loss_result.backward()
            optimizer.step()

    torch.save(model,"3D_CNN_"+str(tt)+".pt")


In [5]:
# Disabled experiment kept as a bare string literal: it is never executed, only
# echoed as the cell's output. It contains a fuller per-fold training + evaluation
# loop. Each epoch it produces:
#   - the test loss;
#   - precision-recall and ROC curves with AP/AUC per class (one-vs-rest);
#   - per-class sensitivity/precision/specificity/F1 and their macro averages;
#   - a confusion matrix `A` indexed [predicted, true];
# and it draws a learning-curve plot at the end of each fold.
# To re-enable it, move the code into a real cell or function.
# NOTE(review): if re-enabled, note the following:
#   - the bare `except: pass` blocks hide ZeroDivisionError/NameError when a class
#     is absent from the test fold (later lines then read undefined names such as
#     `temp_accuracy`);
#   - `weights` is built from an int list; consider dtype=torch.float32 for
#     CrossEntropyLoss.
"""
nSamples = [1,2,3]
weights = torch.tensor(nSamples).to(device)
print(weights)
for tt in range(1,4,1):
    accuracy_list=list()
    roc_normal_list = list()
    roc_mci_list = list()
    roc_ad_list = list()
    training_list=list()
    test_list=list()
    model = CNN().to(device)
    loss_fn = nn.CrossEntropyLoss(weight= weights, reduction='sum')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    data = Data_load(tt)
    _,X_test = data.out()
    accuracy = 0
    
    for t in range(epochs):
        print(f"Epoch {t}\n-------------------------------")
        loss_sum=0
        X_training = data.random()
        model.train()
        for batch, (X, y) in enumerate(training(X_training,batch_size)):
            pred = model(X)
            batch_loss_result = loss_fn(pred, y)
            optimizer.zero_grad()
            batch_loss_result.backward()
            optimizer.step()
            loss_sum+=batch_loss_result
       
        loss_sum=loss_sum/(batch+1)
        training_list.append(loss_sum)
        print("train_average_loss",loss_sum)
        A=np.zeros((3,3))
        model.eval()
        with torch.no_grad():
            loss_sum=0
            final_pred=[]
            final_y=[]
            
            accuracy_sum=0
            precision = 0
            recall = 0
            specificity=0
            q = 0
            w = 0
            e = 0
            
            precision1 = 0
            recall1 = 0
            specificity1=0
            q1 = 0
            w1 = 0
            e1 = 0
            
            precision2 = 0
            recall2 = 0
            specificity2=0
            q2 = 0
            w2 = 0
            e2 = 0
            for batch, (X, y) in enumerate(training(X_test,batch_size)):
                pred = model(X)
                pred_softmax = F.softmax(pred,dim=1)
                #print(pred,pred_softmax,y)
                for i in range(y.shape[0]):
                    final_pred.append(pred_softmax[i].to('cpu').numpy())
                    final_y.append(y[i].to('cpu').numpy())
                batch_loss_result = loss_fn(pred, y)
                loss_sum+=batch_loss_result
                
                for i in range(y.shape[0]):
                    if torch.argmax(y[i], dim=0) == 0:
                        if torch.argmax(pred[i], dim=0) == 0:
                            accuracy_sum+= 1
                    elif torch.argmax(y[i], dim=0) == 1:
                        if torch.argmax(pred[i], dim=0) == 1:
                            accuracy_sum+= 1
                    elif torch.argmax(y[i], dim=0) == 2:
                        if torch.argmax(pred[i], dim=0) == 2:
                            accuracy_sum+= 1
                    
                    if torch.argmax(y[i], dim=0) == 0:
                        q+=1
                        if torch.argmax(pred[i], dim=0) == 0:
                            recall += 1
                    if torch.argmax(pred[i], dim=0) == 0:
                        w+=1
                        if torch.argmax(y[i], dim=0) == 0:
                            precision +=1
                    if torch.argmax(y[i], dim=0) != 0:
                        e+=1
                        if torch.argmax(pred[i], dim=0) != 0:
                            specificity +=1
                    
                    if torch.argmax(y[i], dim=0) == 1:
                        q1+=1
                        if torch.argmax(pred[i], dim=0) == 1:
                            recall1 += 1
                    if torch.argmax(pred[i], dim=0) == 1:
                        w1+=1
                        if torch.argmax(y[i], dim=0) == 1:
                            precision1 +=1
                    if torch.argmax(y[i], dim=0) != 1:
                        e1+=1
                        if torch.argmax(pred[i], dim=0) != 1:
                            specificity1 +=1
                    
                    if torch.argmax(y[i], dim=0) == 2:
                        q2+=1
                        if torch.argmax(pred[i], dim=0) == 2:
                            recall2 += 1
                    if torch.argmax(pred[i], dim=0) == 2:
                        w2+=1
                        if torch.argmax(y[i], dim=0) == 2:
                            precision2 +=1
                    if torch.argmax(y[i], dim=0) != 2:
                        e2+=1
                        if torch.argmax(pred[i], dim=0) != 2:
                            specificity2 +=1
                
                pred_index = torch.argmax(pred,dim=1)
                y_index = torch.argmax(y,dim=1)
                for i in range(pred.shape[0]):
                    A[pred_index[i],y_index[i]]+=1
            final_y = np.array(final_y).astype(int)
            final_pred=np.array(final_pred)
            
            pre_graph, re_graph, thresholds = sklearn.metrics.precision_recall_curve(final_y[:,0],final_pred[:,0])
            auc_pr = sklearn.metrics.average_precision_score(final_y[:,0],final_pred[:,0])
            pre_graph1, re_graph1, thresholds1 = sklearn.metrics.precision_recall_curve(final_y[:,1],final_pred[:,1])
            auc_pr1 = sklearn.metrics.average_precision_score(final_y[:,1],final_pred[:,1])
            pre_graph2, re_graph2, thresholds2 = sklearn.metrics.precision_recall_curve(final_y[:,2],final_pred[:,2])
            auc_pr2 = sklearn.metrics.average_precision_score(final_y[:,2],final_pred[:,2])

            fig = plt.figure(figsize=(6, 6))
            ax1 = fig.add_subplot()
            base_x = np.array([0,1])
            nor_base_y = np.array([round(data.normal_size/data.test_len,3),round(data.normal_size/data.test_len,3)])
            mci_base_y = np.array([round(data.mci_size/data.test_len,3),round(data.mci_size/data.test_len,3)])
            ad_base_y = np.array([round(data.ad_size/data.test_len,3),round(data.ad_size/data.test_len,3)])
            line1=ax1.plot(re_graph, pre_graph, label='Normal vs Others, AUC: '+str(round(auc_pr,3)), color='red')
            line2=ax1.plot(re_graph1, pre_graph1, label='MCI vs Others, AUC: '+str(round(auc_pr1,3)), color='blue')
            line3=ax1.plot(re_graph2, pre_graph2, label='Dementia vs Others, AUC: '+str(round(auc_pr2,3)), color='green')
            line4=ax1.plot(base_x,nor_base_y,linestyle='--', color='red', label='Normal baseline, AUC: '+str(round(data.normal_size/data.test_len,3)))
            line5=ax1.plot(base_x,mci_base_y,linestyle='--', color='blue', label='MCI baseline, AUC: '+str(round(data.mci_size/data.test_len,3)))
            line6=ax1.plot(base_x,ad_base_y,linestyle='--', color='green', label='Dementia baseline, AUC: '+str(round(data.ad_size/data.test_len,3)))
            #ax2.set_ylabel('Accuracy')
            ax1.set_xlabel('Recall')
            ax1.set_ylabel('Precision')
            lines = line1 + line2 + line3 + line4 + line5 + line6
            labels = [l.get_label() for l in lines]
            ax1.legend(lines, labels, loc='upper left', bbox_to_anchor=(1, 1))
            plt.grid(True)
            print("Normal AP: ",round(auc_pr,3))
            print("MCI AP: ",round(auc_pr1,3))
            print("Dementia AP: ",round(auc_pr2,3))
            plt.show()
            plt.close()
            
            random_roc_x = np.array([0,1])
            random_roc_y = np.array([0,1])
            fpr, tpr, thresholds_roc = sklearn.metrics.roc_curve(final_y[:,0],final_pred[:,0])
            auc_roc = sklearn.metrics.roc_auc_score(final_y[:,0],final_pred[:,0])
            fpr1, tpr1, thresholds_roc1 = sklearn.metrics.roc_curve(final_y[:,1],final_pred[:,1])
            auc_roc1 = sklearn.metrics.roc_auc_score(final_y[:,1],final_pred[:,1])
            fpr2, tpr2, thresholds_roc2 = sklearn.metrics.roc_curve(final_y[:,2],final_pred[:,2])
            auc_roc2 = sklearn.metrics.roc_auc_score(final_y[:,2],final_pred[:,2])

            fig = plt.figure(figsize=(6, 6))
            ax1 = fig.add_subplot()
            line1=ax1.plot(fpr, tpr, label='Normal vs Others, AUC: '+str(round(auc_roc,3)), color='red')
            line2=ax1.plot(fpr1, tpr1, label='MCI vs Others, AUC: '+str(round(auc_roc1,3)), color='blue')
            line3=ax1.plot(fpr2, tpr2, label='Dementia vs Others, AUC: '+str(round(auc_roc2,3)), color='green')
            line4=ax1.plot(random_roc_x,random_roc_y, label='Baseline, AUC: '+str(0.5),linestyle = '--', color='black')
            #ax2.set_ylabel('Accuracy')
            ax1.set_xlabel('False Positive Rate')
            ax1.set_ylabel('True Positive Rate')
            lines = line1 + line2 + line3 + line4
            labels = [l.get_label() for l in lines]
            ax1.legend(lines, labels, loc='upper left', bbox_to_anchor=(1, 1))
            plt.grid(True)
            print("Normal AUC: ",round(auc_roc,3))
            print("MCI AUC: ",round(auc_roc1,3))
            print("Dementia AUC: ",round(auc_roc2,3))
            plt.show()
            plt.close()
            print("-----------------")
         
            #print("test_average_loss",loss_sum)
            try:
                recall_normal = 100*recall/q
                print("Normal_sensitivity", round(recall_normal,1))
            except:
                pass
            try:
                precision_normal = 100*precision/w
                print("Normal_precision", round(precision_normal,1))
            except:
                pass
            try:
                specificity_normal = 100*specificity/e
                print("Normal_specificity", round(specificity_normal,1))
            except:
                pass
            try:
                F1_normal = 2*recall_normal*precision_normal/(recall_normal+precision_normal)
                print('Normal_F1-score', round(F1_normal,1))
            except:
                pass
            
            print("-----------------")
            #print("test_average_loss",loss_sum)
            try:
                recall_mci = 100*recall1/q1
                print("MCI_sensitivity", round(recall_mci,1))
            except:
                pass
            try:
                precision_mci = 100*precision1/w1
                print("MCI_precision", round(precision_mci,1))
            except:
                pass
            try:
                specificity_mci = 100*specificity1/e1
                print("MCI_specificity", round(specificity_mci,1))
            except:
                pass
            try:
                F1_mci = 2*recall_mci*precision_mci/(recall_mci+precision_mci)
                print('MCI_F1-score', round(F1_mci,1))
            except:
                pass
            
            print("-----------------")
            #print("test_average_loss",loss_sum)
            try:
                recall_ad = 100*recall2/q2
                print("AD_sensitivity", round(recall_ad,1))
            except:
                pass
            try:
                precision_ad = 100*precision2/w2
                print("AD_precision", round(precision_ad,1))
            except:
                pass
            try:
                specificity_ad = 100*specificity2/e2
                print("AD_specificity", round(specificity_ad,1))
            except:
                pass
            try:
                F1_ad = 2*recall_ad*precision_ad/(recall_ad+precision_ad)
                print('AD_F1-score', round(F1_ad,1))
            except:
                pass
            
            print("-----------------")
            try:
                temp_accuracy = 100*accuracy_sum/len(X_test)
                print("Total_Accuracy", round(temp_accuracy,1))
            except:
                pass
            print("-----------------")
            #print("test_average_loss",loss_sum)
            try:
                macro_recall = (recall_normal + recall_mci + recall_ad)/3
                print("Macro_sensitivity", round(macro_recall,1))
            except:
                pass
            try:
                macro_precision = (precision_normal + precision_mci + precision_ad)/3
                print("Macro_precision", round(macro_precision,1))
            except:
                pass
            try:
                macro_specificity = (specificity_normal + specificity_mci + specificity_ad)/3
                print("Macro_specificity", round(macro_specificity,1))
            except:
                pass
            try:
                macro_F1 = 2*macro_recall*macro_precision/(macro_recall+macro_precision)
                print('Macro_F1-score', round(macro_F1,1))
            except:
                pass
            print("-----------------")
            print(A)
            
            loss_sum=loss_sum/(batch+1)
            accuracy_list.append(temp_accuracy/100)
            roc_normal_list.append(round(auc_roc,3))
            roc_mci_list.append(round(auc_roc1,3))
            roc_ad_list.append(round(auc_roc2,3))
            test_list.append(loss_sum)
            print("test_average_loss",loss_sum)
        
        if temp_accuracy > accuracy:
            accuracy = temp_accuracy
            #torch.save(model,"Resnet_"+str(accuracy)+".pt")
    time_steps = list(range(1, len(training_list) + 1))

    # Plot the learning curve
    for i in range(len(training_list)):
        training_list[i] = training_list[i].cpu().detach().numpy()
        test_list[i] = test_list[i].cpu().detach().numpy()

    fig = plt.figure(figsize=(8, 6))
    ax1 = fig.add_subplot()
    line1=ax1.plot(time_steps, training_list, label='Training Loss',color='red')
    line2=ax1.plot(time_steps, test_list, label='Test Loss',color='blue')
    ax2 = plt.twinx()
    line3=ax2.plot(time_steps, roc_normal_list, label='AUC_Normal',color='green')
    line4=ax2.plot(time_steps, roc_mci_list, label='AUC_MCI',color='orange')
    line5=ax2.plot(time_steps, roc_ad_list, label='AUC_AD',color='gray')
    line6=ax2.plot(time_steps, accuracy_list, label='Accuracy',color='brown')
    #ax2.set_ylabel('Accuracy')
    ax1.set_xlabel('Time')
    ax1.set_ylabel('Loss')

    lines = line1 + line2 + line3 + line4 + line5 + line6
    labels = [l.get_label() for l in lines]
    ax1.legend(lines, labels, loc='upper left', bbox_to_anchor=(1.1, 1))
    plt.grid(True)
    plt.show()
    print("Done!")
"""

'\naccuracy_list=list()\nroc_normal_list = list()\nroc_mci_list = list()\nroc_ad_list = list()\ntraining_list=list()\ntest_list=list()\nnSamples = [1,2,3]\nweights = torch.tensor(nSamples).to(device)\nprint(weights)\nfor tt in range(1,4,1):\n    model = CNN().to(device)\n    loss_fn = nn.CrossEntropyLoss(weight= weights, reduction=\'sum\')\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)\n    data = Data_load(tt)\n    _,X_test = data.out()\n    accuracy = 0\n    \n    for t in range(epochs):\n        print(f"Epoch {t}\n-------------------------------")\n        loss_sum=0\n        X_training = data.random()\n        model.train()\n        for batch, (X, y) in enumerate(training(X_training,batch_size)):\n            pred = model(X)\n            batch_loss_result = loss_fn(pred, y)\n            optimizer.zero_grad()\n            batch_loss_result.backward()\n            optimizer.step()\n            loss_sum+=batch_loss_result\n       \n        loss_sum=loss_sum/(batch+1)\n