<a href="https://colab.research.google.com/github/rfarwell/MPhys/blob/main/CNN4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# To do list
1. Figure out normalization when creating the datasets
2. Get TensorBoard working to check whether the weights are changing during training
3. Issue: the network currently reaches its approximate minimum loss after just 7 batches of the first epoch. This may be because the network is very simple, or because of other factors.

In [1]:
!pip install torch torchvision
!pip install opencv-contrib-python
!pip install scikit-learn
!pip install SimpleITK

import numpy as np
import random
import os
import matplotlib.pyplot as plt
import SimpleITK as sitk
import torch

from mpl_toolkits.mplot3d import Axes3D
from torch.nn import Module
from torch.nn import Conv3d
from torch.nn import Linear
from torch.nn import MaxPool2d
from torch.nn import ReLU
from torch.nn import LogSoftmax
from torch import flatten
from torch import nn
from torch import reshape
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torch.optim import Adam
import torchvision.models as models
from torch.autograd import Variable

Collecting SimpleITK
  Downloading SimpleITK-2.1.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (48.4 MB)
[K     |████████████████████████████████| 48.4 MB 48.3 MB/s 
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.1.1


In [2]:
# Mount Google Drive inside the Colab runtime at /content/gdrive so that the
# project data stored on the Drive can be read through the normal file system.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
# Use the GPU if one is available, otherwise fall back to the CPU
if torch.cuda.is_available() :
  device = 'cuda'
else :
  device = 'cpu'
print(f'Using {device} device')

# Folder on the mounted Google Drive that holds the project data, and the name
# of the clinical-data CSV file stored inside it
clinical_data_filename = "COLAB-Clinical-Data.csv"
project_folder = "/content/gdrive/MyDrive/MPhys/Data"
print(os.path.join(project_folder, clinical_data_filename))

Using cuda device
/content/gdrive/MyDrive/MPhys/Data/COLAB-Clinical-Data.csv


In [4]:
def equalise_array_lengths(array_1, array_2) :
  """
  Truncates the longer of two sequences so that both have the same length.

  The first min(len(array_1), len(array_2)) elements of each sequence are kept.
  A sequence that is already short enough is returned unchanged (the very same
  object); a truncated one is returned as a new slice.

  Returns the tuple (array_1, array_2).

  Rory Farwell 02/12/2021
  """
  common_length = min(len(array_1), len(array_2))
  if len(array_1) != common_length :
    array_1 = array_1[:common_length]
  if len(array_2) != common_length :
    array_2 = array_2[:common_length]
  return (array_1, array_2)

def remove_same_elements(small_array, long_array) :
  """
  Removes every element of small_array from long_array and returns long_array.

  In the context this is used in, all the elements of small_array were drawn from long_array.
  long_array is modified in place (the returned list is the same object). For each
  element of small_array, only the first equal element of long_array is removed.

  Raises ValueError if an element of small_array is not present in long_array.
  """
  # Iterate over a snapshot of small_array so that the removals cannot disturb the
  # iteration, even when both arguments are the same list
  for element in list(small_array) :
    long_array.remove(element)
  return long_array

def create_subgroup(input_array, original_array_length, desired_percentage) :
  """
  Returns a random subgroup (e.g. the training group) drawn from input_array.

  The subgroup contains int(original_array_length * desired_percentage) elements,
  chosen without replacement by random.sample. desired_percentage is a fraction
  (e.g. 0.7 for 70%) of original_array_length, not of len(input_array).
  """
  n_to_draw = int(original_array_length * desired_percentage)
  return random.sample(input_array, n_to_draw)
  

# Open the clinical-data CSV file and convert it to an array of strings
# NOTE(review): no header row is removed here, so the float casts below assume the file has none — confirm
metadata_file = os.path.join(project_folder, clinical_data_filename)
metadata = np.genfromtxt(metadata_file, comments = '%', dtype="str", delimiter=",")
print(f"Length of metadata array is {len(metadata)}")

# NOTE(review): outcome_type and which_patients are not used by the labelling below
outcome_type = 1 #int(input("Select which outcome you are aiming to predict \n(1=Locoregional, 2=Distant Metastasis, 3=Death):"))
which_patients = 1 #int(input("Do you want to include patients whose last follow up is before the check day? (no = 0, yes = 1):"))
patient_with_event = []
patient_no_event = []
outcomes_train = []
outcomes_test = []
images = []

# The timeframe (in days) at which the binary outcome predicted by the CNN is evaluated
check_day = 365 * 1.5

# Column 0 is used as the patient ID, column 8 as the time marker (days) and
# column 9 as the dead status (1 = dead, 0 = alive)
patient_IDs = metadata[:,0]
time_markers = metadata[:,8].astype(np.float32)
dead_statuses = metadata[:,9].astype(np.float32)

check_day_dead_statuses = []

counter = 0

dead_counter = 0
alive_counter = 0
no_info_counter = 0

dead_patient_array = []   # [patient ID, 1] for patients dead at check_day
alive_patient_array = []  # [patient ID, 0] for patients known to be alive at check_day

for patient_ID, temp_dead_status, temp_time_marker in zip(patient_IDs, dead_statuses, time_markers) :
  died_by_check_day = temp_dead_status == 1 and temp_time_marker <= check_day
  died_after_check_day = temp_dead_status == 1 and temp_time_marker > check_day
  followed_up_to_check_day = temp_dead_status == 0 and temp_time_marker >= check_day

  if died_by_check_day :
    # The patient had died on or before check_day
    check_day_dead_statuses.append(1)
    dead_patient_array.append([patient_ID, 1])
    dead_counter += 1
  elif died_after_check_day or followed_up_to_check_day :
    # The patient was still alive at check_day (they either died later or were
    # last seen alive on or after check_day)
    check_day_dead_statuses.append(0)
    alive_patient_array.append([patient_ID, 0])
    alive_counter += 1
  else :
    # The status at check_day is unknown: the last follow-up of a living patient
    # was before check_day, or the status/time value is not usable. Counting these
    # here means that every patient ends up in exactly one of the three counters.
    no_info_counter += 1
print(f"Dead counter after {check_day} days: {dead_counter}")
print(f"Alive counter after {check_day} days: {alive_counter}")
print(f"No-info counter after {check_day} days: {no_info_counter}")

# print(len(dead_patient_array), dead_patient_array)
# print(len(alive_patient_array), alive_patient_array)

training_array = []
testing_array = []
validation_array = []

# Seed Python's random number generator so that the same patients end up in the same
# groups every time this cell is run
SPLIT_SEED = 42
random.seed(SPLIT_SEED)

random.shuffle(dead_patient_array) #shuffling both arrays to ensure random selection of patient data
random.shuffle(alive_patient_array)

# equalising the length of the 'dead' and 'alive' arrays so that we can ensure optimum training proportions.
# Copies are taken so that removing patients below never alters dead_patient_array / alive_patient_array.
equalised_dead, equalised_alive = equalise_array_lengths(dead_patient_array, alive_patient_array)
new_dead_patient_array = list(equalised_dead)
new_alive_patient_array = list(equalised_alive)
print(f"The alive and dead arrays have been sorted (randomly) so that they are both of length {len(new_dead_patient_array)}")

equalised_array_length = len(new_alive_patient_array)

def _draw_subgroup(remaining_patients, fraction) :
  """
  Draws int(equalised_array_length * fraction) patients at random from remaining_patients,
  removes them from remaining_patients (in place) and returns them as a list.
  """
  subgroup = create_subgroup(remaining_patients, equalised_array_length, fraction)
  remove_same_elements(subgroup, remaining_patients)
  return subgroup

# 70% training, 15% testing and 15% validation, drawn separately for each class so that
# every group holds equally many 'dead' and 'alive' patients and no patient is in two groups
train_patients_dead = _draw_subgroup(new_dead_patient_array, 0.7)
train_patients_alive = _draw_subgroup(new_alive_patient_array, 0.7)

test_patients_dead = _draw_subgroup(new_dead_patient_array, 0.15)
test_patients_alive = _draw_subgroup(new_alive_patient_array, 0.15)

validate_patients_dead = _draw_subgroup(new_dead_patient_array, 0.15)
validate_patients_alive = _draw_subgroup(new_alive_patient_array, 0.15)

# Whatever is left in new_dead_patient_array / new_alive_patient_array at this point
# was not assigned to any group

outcomes_train = train_patients_dead + train_patients_alive
outcomes_test = test_patients_dead + test_patients_alive
outcomes_validate = validate_patients_dead + validate_patients_alive

print(len(outcomes_train))

Length of metadata array is 100
Dead counter after 547.5 days: 60
Alive counter after 547.5 days: 40
No-info counter after 547.5 days: 0
The alive and dead arrays have been sorted (randomly) so that they are both of length 40
56


In [5]:
import os
from torchvision.io import read_image

# Normalize class added at 10pm 12/12/2021
class Normalize():
  """
  Transform that standardises a single volume to mean = 0 and std = 1.

  The mean and standard deviation are computed over the whole volume that is
  passed in, so every volume is normalised independently of the others.
  """
  def __init__(self):
    pass

  def __call__(self,vol):
    """
    Returns (vol - vol.mean()) / vol.std().

    A constant volume has a standard deviation of zero; in that case only the mean
    is subtracted, so the result is all zeros instead of NaN/inf values.
    """
    centred = vol - vol.mean()
    spread = vol.std()
    if spread == 0 :
      return centred
    return centred/spread

# Transform applied to every image volume: convert it to a tensor, then standardise it
# with Normalize() so that each volume has mean = 0 and std = 1
# (added at 11:00pm 13/12/2021 to normalize the inputs).
# An earlier version used transforms.Normalize(mean = 0.5, std = 0.5) instead; it was
# always overwritten by this definition, so it has been removed.
# NOTE(review): transforms.ToTensor() treats a 3-D array as H x W x C and moves the last
# axis to the front — confirm that this axis order is what is wanted for the volumes
transform = transforms.Compose(
    [transforms.ToTensor(), Normalize() ]
)


class ImageDataset(Dataset) :
  """
  Dataset of image volumes stored as "<patient ID>-GTV-1.nii" files in one directory.

  annotations is a sequence of [patient ID, label] pairs. Item idx is returned as the
  tuple (image, label), where the image is read with SimpleITK and converted to an
  array before the optional transforms are applied.
  """
  def __init__(self, annotations, img_dir, transform = transform, target_transform = None) :
    # transform is applied to every image and target_transform to every label;
    # either can be None to leave the image/label untouched
    self.target_transform = target_transform
    self.transform = transform
    self.img_dir = img_dir
    self.img_labels = annotations

  def __len__(self) :
    # One item per [patient ID, label] pair
    return len(self.img_labels)

  def __getitem__(self,idx) :
    entry = self.img_labels[idx]
    # The file name is the patient ID followed by a fixed suffix
    file_name = entry[0] + "-GTV-1.nii"
    image = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(self.img_dir, file_name)))
    label = entry[1]
    if self.transform :
      image = self.transform(image)
    if self.target_transform :
      label = self.target_transform(label)
    return image,label

# Directory that holds one "<patient ID>-GTV-1.nii" file per patient
mask_dir = os.path.join(project_folder, "Textured_Masks")

def _print_sample_summary(dataset, n_samples = 2) :
  """
  Prints, for each of the first n_samples items of dataset, the first voxel of the
  image, the mean and standard deviation of the image, and the label.
  """
  for sample_index in range(n_samples) :
    # Indexing the dataset reads the file from disk, so load each sample only once
    image, label = dataset[sample_index]
    print(image[0][0][0])
    print(image.mean(), image.std())
    print(label)

# Sanity check: the first training samples without normalisation ...
_print_sample_summary(ImageDataset(outcomes_train, mask_dir, transform = transforms.ToTensor()))

# ... and the datasets that are actually used, with the normalising transform applied
training_data = ImageDataset(outcomes_train, mask_dir, transform = transform)
validation_data = ImageDataset(outcomes_validate, mask_dir, transform = transform)
test_data = ImageDataset(outcomes_test, mask_dir, transform = transform)
_print_sample_summary(training_data)



tensor(-1024., dtype=torch.float64)
tensor(-1021.4073, dtype=torch.float64) tensor(52.0958, dtype=torch.float64)
1
tensor(-1024.)
tensor(-1022.9258) tensor(32.4395)
1
tensor(-0.0498, dtype=torch.float64)
tensor(-1.3302e-16, dtype=torch.float64) tensor(1.0000, dtype=torch.float64)
1
tensor(-0.0331)
tensor(5.7648e-07) tensor(1.)
1


In [6]:
# Number of volumes per batch, shared by all three loaders
BATCH_SIZE = 4

# Only the training data needs to be shuffled: the order of the samples has no effect on
# the loss and accuracy computed on the test and validation sets
train_dataloader = DataLoader(training_data, batch_size = BATCH_SIZE, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size = BATCH_SIZE, shuffle = False)
validation_dataloader = DataLoader(validation_data, batch_size = BATCH_SIZE, shuffle = False)


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """
    3D convolutional network for binary classification of single-channel volumes.

    Four convolution + leaky-ReLU + max-pool stages are followed by three fully
    connected layers. The output is one raw score (logit) per class with shape
    (batch, 2); no activation is applied to it.

    The first fully connected layer expects exactly 256 features per sample, so the
    convolutional stages must reduce the input to 256 x 1 x 1 x 1. This is the case
    for inputs of shape (batch, 1, 264, 264, 264).
    """
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1,4,2,2)
        self.pool = nn.MaxPool3d(2,2)
        # NOTE(review): kernel size 22 (stride 1) differs from the kernel 2 / stride 2 of
        # the other convolutions and holds most of the parameters — confirm it is not a
        # typo for (4,16,2,2)
        self.conv2 = nn.Conv3d(4,16,22)
        self.conv3 = nn.Conv3d(16,64,2,2)
        self.conv4 = nn.Conv3d(64,256,2,2)
        self.fc1 = nn.Linear(256,64)
        self.fc2 = nn.Linear(64,16)
        self.fc3 = nn.Linear(16,2)

    # Defining the forward pass
    def forward(self, x):
        """Returns the class scores, shape (batch, 2), for a batch of volumes x."""
        x = self.pool(F.leaky_relu(self.conv1(x)))
        x = self.pool(F.leaky_relu(self.conv2(x)))
        x = self.pool(F.leaky_relu(self.conv3(x)))
        x = self.pool(F.leaky_relu(self.conv4(x)))
        # Flatten everything except the batch dimension. Unlike view(-1, 256), this never
        # folds left-over spatial positions into the batch dimension: an input of the
        # wrong size fails in fc1 instead of silently changing the batch size.
        x = torch.flatten(x, 1)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        return x
  
# Create the network and move its parameters to the selected device (GPU or CPU)
model = CNN().to(device)

In [8]:
from torchsummary import summary
# Print the output shape and number of parameters of every layer for a batch of four
# single-channel 264 x 264 x 264 volumes. The device is passed explicitly because
# summary() otherwise assumes a GPU, which fails when the notebook runs on the CPU.
summary(model, (1,264,264,264), batch_size = 4, device = device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1      [4, 4, 132, 132, 132]              36
         MaxPool3d-2         [4, 4, 66, 66, 66]               0
            Conv3d-3        [4, 16, 45, 45, 45]         681,488
         MaxPool3d-4        [4, 16, 22, 22, 22]               0
            Conv3d-5        [4, 64, 11, 11, 11]           8,256
         MaxPool3d-6           [4, 64, 5, 5, 5]               0
            Conv3d-7          [4, 256, 2, 2, 2]         131,328
         MaxPool3d-8          [4, 256, 1, 1, 1]               0
            Linear-9                    [4, 64]          16,448
           Linear-10                    [4, 16]           1,040
           Linear-11                     [4, 2]              34
Total params: 838,630
Trainable params: 838,630
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 280.76
Forwar

In [9]:
# loss and optimizer
learning_rate = 0.001
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

In [23]:
#training_loop
num_epochs = 1
n_total_steps = len(train_dataloader)
train_loss = []      # per epoch: sum of the (batch-mean) losses of all training batches
valid_loss = []      # per epoch: sum of the (batch-mean) losses of all validation batches
avg_train_loss = []  # per epoch: mean loss per training image
avg_valid_loss = []  # per epoch: mean loss per validation image


def _prepare_batch(images, labels) :
  """
  Converts one batch from a DataLoader into the tensors the network and the loss expect.

  images is given a channel dimension of size 1, i.e. (batch, 1, depth, height, width),
  and cast to float. labels (0 or 1) are turned into one-hot rows, e.g.
  (1,0,0,1) --> [(0,1), (1,0), (1,0), (0,1)], because nn.BCEWithLogitsLoss() needs targets
  with the same shape as the network output. A label other than 0 or 1 raises an error
  instead of producing a row of uninitialised values.

  Returns (images, hot_labels), both on the selected device.
  """
  images = images.reshape(images.shape[0], 1, *images.shape[-3:]).float()
  hot_labels = F.one_hot(labels.long(), num_classes = 2).float()
  return images.to(device), hot_labels.to(device)


def _evaluate(dataloader) :
  """
  Runs the model over dataloader without computing gradients.

  Returns (summed_batch_loss, summed_sample_loss, n_correct, n_samples), where
  summed_batch_loss is the sum of the batch-mean losses and summed_sample_loss weights
  every batch loss by the number of images in that batch.
  """
  summed_batch_loss = 0
  summed_sample_loss = 0
  n_correct = 0
  n_samples = 0
  model.eval()
  with torch.no_grad():
    for images, labels in dataloader :
      images, hot_labels = _prepare_batch(images, labels)
      outputs = model(images)
      batch_loss = criterion(outputs, hot_labels).item()
      summed_batch_loss += batch_loss
      summed_sample_loss += batch_loss * labels.shape[0]

      # max returns (value, index)
      _unused, predictions = torch.max(outputs, 1)
      _unused, targets = torch.max(hot_labels, 1)
      n_samples += labels.shape[0]
      n_correct += (predictions == targets).sum().item()
  return summed_batch_loss, summed_sample_loss, n_correct, n_samples


for epoch in range(num_epochs):
  # ======================================= TRAINING LOOP ======================================
  model.train()
  epoch_train_loss = 0         # sum of the batch losses, used for plotting train vs valid loss curves
  epoch_train_sample_loss = 0  # batch losses weighted by batch size, used for the per-image average
  n_training_samples = 0

  print(f'Training for epoch {epoch+1}')
  print('=============================================')

  for i, (images, labels) in enumerate(train_dataloader):
    # Add the channel dimension, build the one-hot labels and move both to the device
    images, hot_labels = _prepare_batch(images, labels)

    # Forward pass
    outputs = model(images)
    loss = criterion(outputs, hot_labels)

    # Backwards pass
    optimizer.zero_grad() # Clear the gradients of the previous batch
    loss.backward()
    optimizer.step()

    # Add the number of images in this batch to n_training_samples which will
    # be used when calculating the average loss per image in the training set
    n_in_batch = labels.shape[0]
    n_training_samples += n_in_batch
    print(f'Number of samples completed after {i+1} batches = {n_training_samples}')

    # Updating the total training loss of this epoch
    # Printing loss for current batch and the new updated total
    # training loss for this epoch
    batch_loss = loss.item()
    print(f'Loss of batch {i+1} = {batch_loss:.2f}')
    epoch_train_loss += batch_loss
    # The batch loss is a mean over the batch, so weight it by the batch size before
    # averaging over all images (the last batch can be smaller than the others)
    epoch_train_sample_loss += batch_loss * n_in_batch
    print(f'Total training loss after batch {i+1} = {epoch_train_loss:.2f}')

    # Print a progress statement to ensure the network is running
    if (i+1)%7 == 0 :
      print(f'Epoch {epoch+1}/{num_epochs}, step {i+1}/{n_total_steps}, loss = {batch_loss:.4f}')

  # Append the train_loss list with the total training loss for this epoch
  train_loss.append(epoch_train_loss)

  # Append the avg_train_loss list with the average training loss per image of this epoch
  avg_train_loss.append(epoch_train_sample_loss/max(n_training_samples, 1))

  print(f'Training loss array at end of epoch {epoch + 1}: {train_loss}. Total number of images used = {n_training_samples}')
  print(f'Finished training for epoch {epoch+1}')

  #================================================ VALIDATION LOOP =================================================
  print(f'Validation for epoch {epoch+1}')
  print('=============================================')
  valid_epoch_loss, valid_sample_loss, n_valid_correct, n_valid_samples = _evaluate(validation_dataloader)

  valid_loss.append(valid_epoch_loss)
  avg_valid_loss.append(valid_sample_loss/max(n_valid_samples, 1))
  acc = (100*n_valid_correct)/max(n_valid_samples, 1)
  print(f'Accuracy on validation set for epoch {epoch+1} = {acc:.1f}%')

  print(f'Finished validation for epoch {epoch+1}')
  print('=============================================')

print('FINISHED TRAINING')
print(f'Training losses = {train_loss}')
print(f'Average training losses = {avg_train_loss}')
print(f'Validation losses = {valid_loss}')
#testing
_test_batch_loss, _test_sample_loss, n_correct, n_samples = _evaluate(test_dataloader)
acc = (100*n_correct)/max(n_samples, 1)
print(f'Accuracy on test set = {acc:.1f}%')


# time renewal


Training for epoch 1
Number of samples completed after 1 batches = 4
Loss of batch 1 = 0.67
Total training loss after batch 1 = 0.67
Number of samples completed after 2 batches = 8
Loss of batch 2 = 0.70
Total training loss after batch 2 = 1.38
Number of samples completed after 3 batches = 12
Loss of batch 3 = 0.64
Total training loss after batch 3 = 2.02
Number of samples completed after 4 batches = 16
Loss of batch 4 = 0.70
Total training loss after batch 4 = 2.72
Number of samples completed after 5 batches = 20
Loss of batch 5 = 0.70
Total training loss after batch 5 = 3.42
Number of samples completed after 6 batches = 24
Loss of batch 6 = 0.65
Total training loss after batch 6 = 4.07
Number of samples completed after 7 batches = 28
Loss of batch 7 = 0.70
Total training loss after batch 7 = 4.77
Epoch 1/1, step 7/14, loss = 0.6974
Number of samples completed after 8 batches = 32
Loss of batch 8 = 0.67
Total training loss after batch 8 = 5.44
Number of samples completed after 9 batch