<a href="https://colab.research.google.com/github/yecatstevir/teambrainiac/blob/main/source/DL/Group_3DCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Learning with PyTorch
## 3D Convolutional Neural Network on Group Brain fMRI
Contributors: Ben Merrill, Stacey Rivet Beck

This CNN is implemented in a similar form to the fMRI CNN discussed in the paper by [Wang et al.](https://arxiv.org/pdf/1801.09858.pdf)

You can see the architecture of the CNN in the class ConvNet below. Throughout model training, we used cross-entropy loss and backpropagation to update the weights, keeping track of both accuracy and loss at each epoch. Due to RAM limitations, we trained the model on 4 files. Each file was split into two sets of 756 images, and each set was used for training. The first pass through each of the 8 datasets entailed 10 epochs. The second pass used one epoch per dataset, to avoid overfitting the model.

Once training was complete, the model measured its error on the validation set while training for one epoch on it. Finally, the model runs on the test data and returns the predictions and the error.

Although this notebook centers on training the model, it also includes the validation and testing steps below.

## Importing Dataset and Labels

In [None]:
# Mount Google Drive at /content/gdrive; later cells read and write
# checkpoints and metrics under '/content/gdrive/My Drive/'.
from google.colab import drive
drive.mount(mountpoint='/content/gdrive')

In [None]:
# Clone the entire repo.
# NOTE(review): -l (--local) and -s (--shared) only affect clones from a local
# path; for an https URL git does not apply them — confirm they are needed.
!git clone -l -s https://github.com/yecatstevir/teambrainiac.git

# Change directory into the cloned repo's DL folder so the local modules
# imported later (access_data_dl, process_dl) are importable from the cwd.
%cd teambrainiac/source/DL

# Uncomment to list the folder contents.
# !ls

### Load path_config.py

In [None]:
from google.colab import files

# Upload path_config.py (it provides mat_path, imported in the next cell),
# then report each uploaded file's name and size.
uploaded = files.upload()

for upload_name, upload_bytes in uploaded.items():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=upload_name, length=len(upload_bytes)))

## Import Packages

In [None]:
# Install packages that may be missing from the Colab runtime:
# boto3 (imported in the next cell) and nilearn.
!pip install boto3
!pip install nilearn

In [None]:
# General Library Imports
import scipy.io
import os
import pickle
import numpy as np
import nibabel as nib
import pandas as pd
import boto3
import tempfile
import tqdm
import random
from path_config import mat_path
from botocore.exceptions import ClientError
from collections import defaultdict
from sklearn.preprocessing import StandardScaler

# From Local Directory
from access_data_dl import *
from process_dl import *

# Pytroch Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import TensorDataset

#import torchvision.transforms as transforms
from torch.nn import ReLU, CrossEntropyLoss, Conv3d, Module, Softmax, AdaptiveAvgPool3d
from torch.optim import Adam, SGD

#from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader


## Build Model

In [None]:
class ConvNet(nn.Module):
  """3D CNN mapping a single-channel volume to logits for 2 classes.

  Four Conv3d stages (1->32, 32->64, 64->64, 64->128 channels), each applying
  the convolution, then ReLU, then BatchNorm3d. These are followed by global
  average pooling over the three spatial dimensions and two Linear layers
  (128 -> 64 -> 2).

  Input: (batch, 1, D, H, W). With no padding, the 7^3/stride-2 convolution
  and the two 3^3/stride-2 convolutions need every spatial size to be >= 19.
  Output: (batch, 2) raw logits (no softmax is applied).
  """

  # (in_channels, out_channels, cubic kernel size, stride) for conv1..conv4.
  _CONV_SPECS = ((1, 32, 1, 1), (32, 64, 7, 2), (64, 64, 3, 2), (64, 128, 3, 2))

  def __init__(self):
    super().__init__()
    # Registration order (conv1, bn1, conv2, bn2, ...) fixes both the order in
    # which weights are randomly initialized and the state_dict key order.
    for k, (c_in, c_out, size, step) in enumerate(self._CONV_SPECS, start=1):
      setattr(self, 'conv%d' % k, nn.Conv3d(in_channels=c_in,
                                             out_channels=c_out,
                                             kernel_size=(size,) * 3,
                                             stride=(step,) * 3))
      setattr(self, 'bn%d' % k, nn.BatchNorm3d(c_out))
    # Global average pool: collapses D, H and W to 1 -> (batch, 128, 1, 1, 1).
    self.pool1 = nn.AdaptiveAvgPool3d((1, 1, 1))
    self.fc1 = nn.Linear(128, 64)
    self.fc2 = nn.Linear(64, 2)  # one logit per class

  def forward(self, xb):
    for k in range(1, len(self._CONV_SPECS) + 1):
      conv_layer = getattr(self, 'conv%d' % k)
      norm_layer = getattr(self, 'bn%d' % k)
      xb = norm_layer(F.relu(conv_layer(xb)))  # conv2 (7^3 kernel) dominates run time
    pooled = torch.flatten(self.pool1(xb), start_dim=1)  # (batch, 128)
    # NOTE(review): there is no activation between fc1 and fc2, so together
    # they act as a single affine map — confirm this is intended.
    return self.fc2(self.fc1(pooled))

    

## Functions to Run the Model

In [None]:
def run_cnn(model, epochs, learning_rate, loss_func, opt, dl, val=False, test=False):
  """Train ``model`` on the batches in ``dl``, or with ``test=True`` only evaluate it.

  Args:
    model: network to run; each batch is moved to the device of its parameters.
    epochs: maximum number of passes over ``dl``.
    learning_rate: not used here (the optimizer ``opt`` already carries the
      learning rate); kept so existing call sites keep working.
    loss_func: callable ``loss_func(logits, labels)`` returning a scalar tensor.
    opt: optimizer over ``model``'s parameters.
    dl: iterable of ``(inputs, labels)`` batches; inputs are cast to float.
    val: record per-batch predictions and labels. The model is still trained.
    test: record per-batch predictions and labels without updating the model
      (eval mode, autograd disabled).

  Returns:
    ``(model, metrics_dict)``. ``metrics_dict['epoch_<k>']`` is
    ``{'accuracy': [...], 'loss': [...]}`` with one Python float per batch.
    When ``val`` or ``test`` is set, ``'preds_<i>'`` / ``'labels_<i>'`` hold
    batch ``i``'s logits and labels (detached, on CPU), and ``'labels'`` holds
    all labels of the last epoch concatenated in batch order.

  Training stops early under either of two conditions. The first is when a
  batch is classified perfectly; this applies to plain training only, so
  val/test outputs are never truncated. The second is when an epoch's summed
  batch accuracy falls below that of both preceding epochs.
  """
  metrics_dict = {}
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cpu')
  record = val or test
  training = not test

  for epoch in range(1, 1 + epochs):
    accuracy_list = []
    loss_list = []
    epoch_labels = []
    # Train mode makes BatchNorm use batch statistics; eval mode uses running stats.
    model.train(training)
    print('epoch', epoch)
    for i, (xb, yb) in enumerate(dl):
      print('batch', i)

      # The model may live on the GPU while the DataLoader yields CPU tensors.
      xb = xb.float().to(device)
      yb = yb.to(device)

      with torch.set_grad_enabled(training):
        pred = model(xb)
        loss_batch = loss_func(pred, yb)
      accuracy_batch = accuracy(pred, yb)

      # Store Python floats: holding the loss tensor would keep every batch's
      # autograd graph (and its activations) alive for the whole run.
      loss_list.append(loss_batch.item())
      accuracy_list.append(accuracy_batch.item())

      print('Batch Loss', loss_batch)
      print('Batch Accuracy', accuracy_batch)

      if record:
        batch_labels = yb.detach().cpu()
        metrics_dict['preds_' + str(i)] = pred.detach().cpu()
        metrics_dict['labels_' + str(i)] = batch_labels
        epoch_labels.append(batch_labels)
      elif int(accuracy_batch) == 1:
        # Early stopping (plain training only): stop before updating on a perfect batch.
        metrics_dict['epoch_' + str(epoch)] = {'accuracy': accuracy_list, 'loss': loss_list}
        print('Perfect Accuracy\nStopping early to avoid overfitting\n\n')
        return model, metrics_dict

      if training:
        loss_batch.backward()
        opt.step()
        opt.zero_grad()

    model.eval()
    metrics_dict['epoch_' + str(epoch)] = {'accuracy': accuracy_list, 'loss': loss_list}
    if epoch_labels:
      metrics_dict['labels'] = torch.cat(epoch_labels)

    print('epoch', epoch, 'finished\n')

    # Other early stopping criterion; it needs two earlier epochs to compare against.
    if epoch >= 3:
      previous = [sum(metrics_dict['epoch_' + str(epoch - k)]['accuracy']) for k in (2, 1)]
      current = sum(accuracy_list)
      if all(past > current for past in previous):
        print('Early stop to avoid overfitting\nModel accuracy dropped below both of the previous two epochs')
        return model, metrics_dict

  return model, metrics_dict


def accuracy(out, yb):
  """Return the fraction of rows of ``out`` whose argmax equals ``yb`` (0-d float tensor)."""
  preds = torch.argmax(out, dim=1)
  return (preds == yb).float().mean()

## Setting Up the Model and Parameters

In [None]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Build the network and move its parameters to that device.
model = ConvNet().to(device)

# Default hyperparameters (the validation and test cells later override epochs).
epochs, learning_rate = 10, 0.001
loss_func = F.cross_entropy
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Decide Train, Validate, or Test
Next, decide which part of the pipeline you are running. To avoid running out of RAM, run only one of the three sections per session: training, validation, or testing.

## Train

The training data is saved in 4 files. This section loads the recent parameters if needed, loads the training data of choice, and trains the model implemented above.

In [None]:
# Load older models for continued training.
load_recent_model = True
# NOTE(review): the training cell below saves 'cnn_train<index>_<round>.pt';
# this name follows a different pattern — confirm it is the intended checkpoint.
train_path = '/content/gdrive/My Drive/cnn_train_2_round_3.pt'

if load_recent_model:
  # map_location lets a checkpoint saved from a GPU session load on a
  # CPU-only runtime (and vice versa).
  model.load_state_dict(torch.load(train_path, map_location=device))

In [None]:
%%time
# Pick the training partition to load (1, 2, 3, or 4); each is a separate pickle.
train_index = 4

# Load the training data with access_load_data (from the star-imported local
# modules). The result is indexed by 'images' and 'labels' in the next cell.
# NOTE(review): the meaning of the second argument (False) is not visible here.
path = 'dl/partition_train_%i.pkl'%(train_index)
train_images = access_load_data(path, False)

In [None]:
# Run the model on the loaded partition, split into n_train_dataset_portions
# chunks so that only one chunk is wrapped in a DataLoader at a time.

# Additional hyperparameters
n_train_set_images = int(train_images['images'].shape[0])
n_train_dataset_portions = 2
bs = 54
# Images per chunk: ceiling division yields at most n_train_dataset_portions
# chunks (no tiny leftover chunk); max(1, ...) keeps range() from getting step 0.
portion_size = max(1, -(-n_train_set_images // n_train_dataset_portions))

metrics_dict = {}
for i, image_index in enumerate(range(0, n_train_set_images, portion_size)):

  # Slice a whole chunk of images and the matching labels.
  x_train = train_images['images'][image_index:image_index + portion_size]
  y_train = train_images['labels'][image_index:image_index + portion_size]

  ds = TensorDataset(x_train, y_train)
  dl = DataLoader(ds, batch_size=bs, shuffle=True)

  model, metrics = run_cnn(model, epochs, learning_rate, loss_func, opt, dl)
  metrics_dict['round_' + str(i)] = metrics

  # Save all metrics so far; the with-block guarantees the file is flushed and closed.
  metrics_path = "/content/gdrive/My Drive/metrics_dict_train_%i_%i.pkl" % (train_index, i)
  with open(metrics_path, "wb") as f:
    pickle.dump(metrics_dict, f)

  print('Saving model')
  model_path = '/content/gdrive/My Drive/cnn_train%i_%i.pt' % (train_index, i)
  torch.save(model.state_dict(), model_path)

  print('Finished with set', i, 'of', n_train_dataset_portions, 'images\nStarting next set.\n\n')

## Validation
Validation trains for one epoch and records the predictions and accuracies computed before each backpropagation step.

In [None]:
# Load the fully trained model weights before validating.
model_path = '/content/gdrive/My Drive/cnn_train_complete.pt'
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime (and vice versa).
model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
%time
# Load validation data
data_path = 'dl/partition_val.pkl'
val_images = access_load_data(data_path, False)
val_images['images'].shape

In [None]:
# One pass over the validation set. val=True makes run_cnn record per-batch
# predictions and labels; the model is still trained on these batches.
bs = 84
epochs = 1

x_val = val_images['images']
y_val = val_images['labels']

ds = TensorDataset(x_val, y_val)
dl = DataLoader(ds, batch_size=bs)  # unshuffled, so batches follow dataset order

model, val_metrics = run_cnn(model, epochs, learning_rate, loss_func, opt, dl, val=True)

# The with-block closes the file so the pickle is fully written to Drive.
metrics_path = "/content/gdrive/My Drive/metrics_dict_val.pkl"
with open(metrics_path, "wb") as f:
  pickle.dump(val_metrics, f)

print('Saving model')
model_path = '/content/gdrive/My Drive/cnn_val.pt'
torch.save(model.state_dict(), model_path)

print('Finshed with validation set')

## Test

In [None]:
# Load model
model_path = '/content/gdrive/My Drive/cnn_val.pt'
model.load_state_dict(torch.load(model_path))

In [None]:
%%time
# Load one test partition, selected by test_partition_n.
test_partition_n = 2
data_path = 'dl/partition_test_%i.pkl'%(test_partition_n)
test_images = access_load_data(data_path, False)
# Display the shape of the loaded image tensor as the cell output.
test_images['images'].shape

In [None]:
# One pass over the test partition. test=True makes run_cnn record per-batch
# predictions and labels in the returned metrics.
bs = 84
epochs = 1

x_test = test_images['images']
y_test = test_images['labels']

ds = TensorDataset(x_test, y_test)
dl = DataLoader(ds, batch_size=bs)  # unshuffled, so batches follow dataset order

# The result keeps the name val_metrics because the next cell reads it.
model, val_metrics = run_cnn(model, epochs, learning_rate, loss_func, opt, dl, test=True)

# The with-block closes the file so the pickle is fully written to Drive.
metrics_path = "/content/gdrive/My Drive/metrics_dict_test_%i.pkl" % (test_partition_n)
with open(metrics_path, "wb") as f:
  pickle.dump(val_metrics, f)

print('Finshed with test set', test_partition_n)

In [None]:
# Re-save the test metrics from the previous cell to the same path.
# NOTE(review): this repeats the previous cell's dump — presumably added to make
# sure the file was fully written; it is redundant if that write completed.
metrics_path = "/content/gdrive/My Drive/metrics_dict_test_%i.pkl" % (test_partition_n)
with open(metrics_path, "wb") as f:
  pickle.dump(val_metrics, f)