In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
import seaborn as sn
from scipy.special import softmax
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, confusion_matrix, f1_score

from resnet import ResNet
from unet import UNet
from dnn import DNN
from utils import *

In [2]:
# Report the CUDA environment and choose the device every later cell uses.
cuda_available = torch.cuda.is_available()
print(cuda_available)

# current_device() and get_device_name() raise when no CUDA device is present,
# which would abort this cell before the CPU fallback below could take effect.
# Only query the GPU details when CUDA is actually available.
if cuda_available:
    print(torch.cuda.current_device())
    print(torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))

device = torch.device('cuda:0' if cuda_available else 'cpu')

True
0
1
NVIDIA GeForce GTX 1070


In [3]:
# Count how many samples each class folder holds. The counts feed the class
# re-weighting of the loss in a later cell.
cls_num_list = [0,0,0]
n = 0          # total number of samples across all classes
bigmat = []    # only used by the commented-out statistics code below

# NOTE(review): the folder order here presumably matches the class indices
# DatasetFolder assigns (alphabetically sorted folder names) - confirm.
for cls_idx,subfolder in enumerate(["AD","CN","MCI"]):
    class_dir = os.path.join('MRI_trial', subfolder)

    # Counting only needs the directory listing; loading every volume with
    # np.load just to discard it made this cell needlessly slow. Restrict the
    # count to .npy files so it agrees with the extension filter used when the
    # dataset is built.
    files = [name for name in os.listdir(class_dir) if name.lower().endswith('.npy')]

    cls_num_list[cls_idx] = len(files)
    n += len(files)

# bigmat = np.stack(bigmat)
# print(bigmat.shape)
print(cls_num_list)           

[245, 610, 769]


In [4]:
# mean_list = np.mean(bigmat,axis=(0,2,3))/255
# std_list = np.std(bigmat,axis=(0,2,3))/255

# std_list = np.std(bigmat[:100],axis=(0,2,3))/255

# for j in range(15):
#     std_list += np.std(bigmat[100*(j+1):100*(j+2)],axis=(0,2,3))/255

# std_list /= 16

# np.save("MRI_trial_mean.npy",mean_list)
# np.save("MRI_trial_std.npy",std_list)

In [5]:
# Read the cached normalisation statistics from disk (presumably written by
# the commented-out statistics cell above - confirm the files are current).
mean_list, std_list = np.load("MRI_trial_mean.npy"), np.load("MRI_trial_std.npy")

In [6]:
# Training pipeline: tensor conversion, random 256-pixel crop after 16 pixels
# of padding, a fixed-sigma 5x5 Gaussian blur, then normalisation.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(size=256, padding=16),
    # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0, hue=0),
    transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.2, 0.2)),
    transforms.Normalize(mean=mean_list, std=std_list),
])

# Evaluation pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean_list, std=std_list),
])

def npy_loader_train(path):
    """Load the ``.npy`` array at ``path`` and apply ``transform_train``.

    The array's axes are permuted with ``(1, 2, 0)`` (first axis moved to the
    end) before the transform runs - presumably converting channel-first data
    to the channel-last layout ``ToTensor`` expects; confirm against the data.

    Returns whatever ``transform_train`` produces for the permuted array.
    """
    reordered = np.load(path).transpose((1, 2, 0))
    return transform_train(reordered)

def npy_loader_test(path):
    """Load the ``.npy`` array at ``path`` and apply ``transform_test``.

    The array's axes are permuted with ``(1, 2, 0)`` (first axis moved to the
    end) before the transform runs - presumably converting channel-first data
    to the channel-last layout ``ToTensor`` expects; confirm against the data.

    Returns whatever ``transform_test`` produces for the permuted array.
    """
    reordered = np.load(path).transpose((1, 2, 0))
    return transform_test(reordered)

# One folder-per-class dataset over MRI_trial, restricted to .npy files.
# `extensions` is passed as a real one-element tuple: the previous `(".npy")`
# was just a parenthesised string.
# NOTE(review): every sample is read through npy_loader_test, so the
# augmentations in transform_train / npy_loader_train are never applied, even
# to samples later drawn for training. Confirm whether that is intentional.
dataset = datasets.DatasetFolder(root="MRI_trial", loader=npy_loader_test, extensions=(".npy",))

# trainloader = DataLoader(trainset, batch_size=2, shuffle=True)
# testloader = DataLoader(testset, batch_size=2, shuffle=False)

In [7]:
# Cross-validation configuration shared by the training cell below.
k_folds = 5
num_epochs = 10

# Fix the seeds so the fold membership (and torch-side sampling) is the same
# on every run. Without this, a checkpoint saved for "fold k" cannot be matched
# to the samples that were held out for it.
# NOTE(review): GPU kernels may still introduce run-to-run nondeterminism.
SEED = 42
torch.manual_seed(SEED)

kfold = KFold(n_splits=k_folds, shuffle=True, random_state=SEED)

# per_cls_weights = [0.55,0.35,0.1]
# criterion = nn.CrossEntropyLoss(weight=torch.tensor(per_cls_weights))

# Class weights are derived from the per-class sample counts and handed to the
# focal loss; `reweight` and `FocalLoss` come from the project's utils module.
per_cls_weights = reweight(cls_num_list)
criterion = FocalLoss(weight=torch.tensor(per_cls_weights, device=device))

criterion = criterion.to(device)

# For fold results
results = {}

In [8]:
# K-fold Cross Validation model evaluation
#
# For every fold: build train/test loaders over the shared `dataset`, train a
# fresh ResNet for `num_epochs`, evaluate after each epoch, then save the
# weights, plot the loss curves and store the report for the fold.

# torch.save does not create missing directories. Make sure the target exists
# before training starts, so a missing folder cannot throw away a finished fold.
os.makedirs('checkpoint', exist_ok=True)

for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):

    # Print
    print(f'FOLD {fold}')
    print('--------------------------------')

    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
    
    # Define data loaders for training and testing data in this fold
    # NOTE(review): both loaders read from `dataset`, which is built with
    # npy_loader_test, so no training-time augmentation is applied here.
    trainloader = torch.utils.data.DataLoader(
                      dataset, 
                      batch_size=2, sampler=train_subsampler)
    testloader = torch.utils.data.DataLoader(
                      dataset,
                      batch_size=2, sampler=test_subsampler)

    # A new model and optimizer per fold, so no weights leak between folds
    model = ResNet(in_channels= 162, num_classes=3).to(device)
    # model = UNet(in_channels=162, num_classes=3).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # Adam optimizer
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.5)

    # Run the training loop for defined number of epochs
    train_loss_list = []
    test_loss_list = []

    for epoch in range(num_epochs):  # num epochs
        
        # Training
        model.train() # Set to train mode
        running_loss = 0.0
        running_corrects = 0

        for i, data in enumerate(trainloader): # Get data batch-wise
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero out gradients
            optimizer.zero_grad()

            outputs = model(inputs) # forward pass
            loss = criterion(outputs, labels) # Get loss
            loss.backward() # Backward pass
            optimizer.step() # Optimize model weights

            _, preds = torch.max(outputs, 1) # Get predictions
            # Weight the batch loss by batch size so the division below gives
            # a per-sample average (assumes the criterion returns a batch mean
            # - TODO confirm against FocalLoss)
            running_loss += loss.detach() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

        # Get loss and accuracy
        train_loss = (running_loss / len(train_subsampler))
        train_loss_list.append(train_loss.item())
        train_accuracy = (running_corrects.float() / len(train_subsampler))

        # Testing
        model.eval() # Set to eval mode
        running_loss = 0.0
        running_corrects = 0
        # Reset every epoch: after the loop these hold the last epoch's labels
        # and predictions, one inner list per batch.
        y = []; yhat = []

        for i, data in enumerate(testloader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass and loss only; keep both out of the autograd graph
            with torch.no_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            y.append(labels.tolist())
            yhat.append(preds.tolist())
            running_loss += loss.detach() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)
        
        # scheduler.step() # Update lr
        
        # Get loss and accuracy
        test_loss = (running_loss / len(test_subsampler))
        test_loss_list.append(test_loss.item())
        test_accuracy = (running_corrects.float() / len(test_subsampler))

        # Display loss, accuracy values after each epoch
        print("Fold:{} | Epoch:{} | Train loss: {:.4f} | Test loss: {:.4f} | Train acc.: {:.4f} | Test acc.: {:.4f}\n"
                  .format(fold, epoch, train_loss.item(),test_loss.item(),train_accuracy.item(),test_accuracy.item()))

    # Saving the model
    save_path = f'checkpoint/ResNet-fold-{fold}.pth'
    # save_path = f'checkpoint/UNet-fold-{fold}.pth'
    
    torch.save(model.state_dict(), save_path)

    # Loss curves
    plot_loss(train_loss_list,test_loss_list)

    # Report (built from the final epoch's labels and predictions)
    results[fold] = print_report(y,yhat, class_names = dataset.classes)

FOLD 0
--------------------------------
Fold:0 | Epoch:0 | Train loss: 0.0123 | Test loss: 0.0020 | Train acc.: 0.3718 | Test acc.: 0.3785

Fold:0 | Epoch:1 | Train loss: 0.0019 | Test loss: 0.0019 | Train acc.: 0.3449 | Test acc.: 0.4554

Fold:0 | Epoch:2 | Train loss: 0.0019 | Test loss: 0.0019 | Train acc.: 0.3687 | Test acc.: 0.1692



KeyboardInterrupt: 