In [1]:
%matplotlib inline
import cv2
import numpy as np
import os
import sys
import pandas as pd
from random import shuffle
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_confusion_matrix
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image
import nibabel as nib

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [2]:
def create_df(csvpath):
    #     Create dataframe from CSV
    table = pd.read_csv(csvpath)
    return table[['Image Data ID', 'Subject', 'Group', 'Sex', 'Age']]


def create_paths(datapath):
    #     Create paths to all nested images
    imagepaths = []
    for root, dirs, files in os.walk(datapath, topdown=False):
        for name in files:
            imagepaths.append(os.path.join(root, name))
    return imagepaths


def prepare_data(datapath, csvpath):
    #     Create training set
    train_data = []
    imagepaths = create_paths(datapath)
    table = create_df(csvpath)

    for imagepath in tqdm(imagepaths[:-1]):

        #       Find ID of each image, then lookup values using it
        #         13th index for linux, 10th for windows
        path = imagepath.split('\\')[13]
        img_id = path[path.find('_I') + 2:-4]
        subject_group = table.loc[table['Image Data ID'] == int(
            img_id)]["Group"].reset_index(drop=True)[0]
        subject_age = table.loc[table['Image Data ID'] == int(
            img_id)]["Age"].reset_index(drop=True)[0]
        subject_sex = table.loc[table['Image Data ID'] == int(
            img_id)]["Sex"].reset_index(drop=True)[0]
        subject_id = table.loc[table['Image Data ID'] == int(
            img_id)]["Subject"].reset_index(drop=True)[0]

        #       Nibabel to load NIFTI Images
        img = nib.load(imagepath)
        imgdata = img.get_fdata()

        train_data.append((imgdata, subject_group))

    return train_data

In [4]:
class ADNI(Dataset):
    def __init__(self, datapath, csvpath, transform=None):
        """
        Args:
            datapath (string): Directory with all the images.
            csvpath (string): Path to CSV
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.table = create_df(csvpath)
        self.imagepaths = create_paths(datapath)[:-1]
        self.transform = transform

    def __len__(self):
        #         Returns the length of the dataset
        return len(self.table)

    def __getitem__(self, idx):
        #         Returns a tuple of the image and its group/label
        imgsize = 224

        if torch.is_tensor(idx):
            idx = idx.tolist()

        imagepath = self.imagepaths[idx]
        #         10th index for windows, 13th for linux
        idpath = imagepath.split('/')[13]
        img_id = idpath[idpath.find('_I') + 2:-4]
        group = self.table.loc[self.table['Image Data ID'] == int(
            img_id)]["Group"].reset_index(drop=True)[0]
        group_to_id = {'CN': 0, 'MCI': 1, 'AD': 2}
        group_id = group_to_id[group]

        imgdata = nib.load(imagepath).get_fdata()
        #         temporary, for only one slice
        imgdata = cv2.resize(imgdata[100, :, :], (imgsize, imgsize))
        imgdata = imgdata.reshape(1, imgsize, imgsize)
        imgdata = torch.from_numpy(imgdata)
        #         Apply transform to sample
        if self.transform:
            imgdata = self.transform(imgdata)

        sample = (imgdata, torch.tensor(group_id))
        return sample

In [53]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 32, 5)
        self.conv3 = nn.Conv2d(32, 64, 5)
        self.conv4 = nn.Conv2d(64, 128, 5)
        self.fc1 = nn.Linear(128*10*10, 224)
        self.fc2 = nn.Linear(224,100)
        self.fc3 = nn.Linear(100,3)
        self.drop1 = nn.Dropout(p=0.1)
        self.drop2 = nn.Dropout(p=0.5)
        x = torch.randn(224, 224).view(-1, 1, 224, 224)
    def forward(self, x):
        x = self.pool(self.drop1(F.relu(self.conv1(x))))
        x = self.pool(self.drop1(F.relu(self.conv2(x))))
        x = self.pool(self.drop1(F.relu(self.conv3(x))))
        x = self.pool(self.drop1(F.relu(self.conv4(x))))
#         print(x.shape)
        x = x.view(-1, 128*10*10)
        x = F.relu(self.fc1(x))
        x = self.drop2(x)
        x = F.relu(self.fc2(x))
        x = self.drop2(x)
        x = self.fc3(x)
        return x

In [62]:
datapath = r"/media/swang/Windows/Users/swang/Downloads/ADNI1_Complete_1Yr_1.5T"
csvpath = r"/media/swang/Windows/Users/swang/Downloads/ADNI1_Complete_1Yr_1.5T_7_08_2020.csv"
dataset = ADNI(datapath, csvpath)
#  transform = transforms.Normalize(mean=[192.1213], std=[215.9763])

lengths = [
    int(len(dataset) * 0.8),
    int(len(dataset) * 0.05),
    int(len(dataset) * 0.15) + 1
]


trainset, valset, testset = random_split(dataset, lengths)
image_datasets = {'train': trainset, 'val': valset, 'test': testset}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=16, shuffle=True, num_workers=4)
              for x in ['train', 'val', 'test']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}  


# train_loader = DataLoader(trainset, batch_size, shuffle=True, num_workers)
# dev_loader = DataLoader(devset, batch_size, shuffle=True, num_workers)
# test_loader = DataLoader(testset, batch_size, shuffle=True, num_workers)

# for batch_x, batch_y in train_loader:
#     batch_x, batch_y = batch_x.to(device), batch_y.to(device)
#     print(batch_x.shape, batch_y.shape)

In [55]:
def train(net):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for epoch in range(10):

        running_loss = 0.0
        correct = 0
        total = 0
        for data in tqdm(dataloaders['train'], total = dataset_sizes['train']//16):
            batch_x, batch_y = data
            batch_x, batch_y = batch_x.to(device, dtype=torch.float), batch_y.to(device)
            net.zero_grad()
            optimizer.zero_grad()
            outputs = net(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch: {epoch + 1}, Loss: {running_loss/(dataset_sizes['train']//16)}")
        print('---------------------------------------------------------')
        running_loss = 0.0

    with torch.no_grad():
        for image, group in tqdm(dataloaders['val'], total = dataset_sizes['val']//16):
            image, group = image.to(device, dtype=torch.float), group.to(device)
            outputs = net(image)
            _, predicted = torch.max(outputs.data, 1)
            total += len(group)
            correct += (predicted == group).sum().item()
    #         print((predicted, group))  
        print(f"Accuracy = {100 * correct / total}")

# For finding optimal weights and preventing overfitting:
#  - Set model to eval mode (and maybe to torch.no_grad()) with torch.no_grad():
#  - Save the model (pytorch automatically pickles model and saves it I think look up how to do this)
#  - Run the model on validation set and save its accuracy
#  - If validation accuracy has been declining for like last 5 epochs or something just take the most accurate one and load weights


In [56]:
if __name__ == '__main__':
    device = torch.device("cuda:0")
    net = Net().to(device)
    train(net)

HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 1, Loss: 1.0723267223751336
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 2, Loss: 1.0637481155102713
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 3, Loss: 1.0552606906807214
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 4, Loss: 1.0573667219856329
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 5, Loss: 1.0599884108493203
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 6, Loss: 1.0545184800499363
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 7, Loss: 1.0512872866371221
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 8, Loss: 1.0526386154325384
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 9, Loss: 1.0507566353730988
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=114.0), HTML(value='')))


Epoch: 10, Loss: 1.052500265732146
---------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))


Accuracy = 42.98245614035088


In [58]:
torch.cuda.empty_cache()


In [69]:
from torchsummary import summary
summary(net, (1, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 6, 220, 220]             156
           Dropout-2          [-1, 6, 220, 220]               0
         MaxPool2d-3          [-1, 6, 110, 110]               0
            Conv2d-4         [-1, 32, 106, 106]           4,832
           Dropout-5         [-1, 32, 106, 106]               0
         MaxPool2d-6           [-1, 32, 53, 53]               0
            Conv2d-7           [-1, 64, 49, 49]          51,264
           Dropout-8           [-1, 64, 49, 49]               0
         MaxPool2d-9           [-1, 64, 24, 24]               0
           Conv2d-10          [-1, 128, 20, 20]         204,928
          Dropout-11          [-1, 128, 20, 20]               0
        MaxPool2d-12          [-1, 128, 10, 10]               0
           Linear-13                  [-1, 224]       2,867,424
          Dropout-14                  [

In [61]:
for i in range(20):
    print(torch.argmax(net(dataset[i][0].reshape(1,1,224,224).to(device, dtype=torch.float))))
    print(dataset[i][1])

tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(2)
tensor(1, device='cuda:0')
tensor(2)
tensor(1, device='cuda:0')
tensor(2)
tensor(1, device='cuda:0')
tensor(2)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(0)
tensor(1, device='cuda:0')
tensor(1)
tensor(1, device='cuda:0')
tensor(1)


In [57]:
def test_accuracy(net):
    correct = 0
    total = 0
    with torch.no_grad():
        for image, group in tqdm(dataloaders['test'], total = dataset_sizes['test']//16):
            image, group = image.to(device, dtype=torch.float), group.to(device)
            outputs = net(image)
            _, predicted = torch.max(outputs.data, 1)
            total += len(group)
            correct += (predicted == group).sum().item()
    #         print((predicted, group))  
        print(f"Accuracy = {100 * correct / total}")

test_accuracy(net)

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))


Accuracy = 50.43478260869565


In [None]:
torch.cuda.empty_cache()

In [None]:
means = []
stds = []
for img in tqdm(dataset):
    means.append(torch.mean(img[0]))
    stds.append(torch.std(img[0]))

mean = torch.mean(torch.tensor(means))
std = torch.mean(torch.tensor(stds))
print(mean, std)