# Dataloaders

In [1]:
# Importing libraries
import torchio as tio
import glob
import numpy as np
import random
import os

from collections import OrderedDict
from pathlib import Path

from tqdm import tqdm
import time

import torchio as tio
from torchio.transforms import (RescaleIntensity,RandomFlip,Compose, HistogramStandardization)

from torch.utils.data import DataLoader
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

In [2]:
# Define dictionary describing assignment of participants to the groups

# Groups
train_groups=['HCP']
test_dev_groups=['UoN']
test_groups=['CHIASM']

# Splits (fractions of every training group)
train_split = 0.8
dev_split = 0.1
test_split = 0.1  # not read below: whatever is left after train + train_dev goes to 'test'

# Folder holding one sub-folder per group, each with one sub-folder per subject
input_root = '../../1_Data/1_Input/'

# Seed of the private random generator used for the split, so that the
# assignment of subjects is reproducible after a kernel restart
split_seed = 0


def list_subject_ids(group):
    """Return the sorted names of all entries found in the folder of `group`."""
    # glob returns paths in arbitrary order; sorting makes the listing deterministic
    return sorted(os.path.basename(path) for path in glob.glob(input_root+group+'/*'))


def subject_files(group, sub_id):
    """Return the dict of image paths ('brain', 'probs', 'chiasm') of one subject."""
    path_to_folder = input_root+group+'/'+sub_id+'/'
    return {'brain': path_to_folder+'brain_skull-stripped.nii.gz',
            'probs': path_to_folder+'sampling_distribution.nii.gz',
            'chiasm': path_to_folder+'chiasm.nii.gz'}


# Dictionary with study design: split name -> subject id -> file paths
design = {}

design['train']={}
design['train_dev']={}
design['test_dev']={}
design['test']={}

# Training data: every training group is divided into train / train_dev / test
split_rng = random.Random(split_seed)

for group in train_groups:

    # Indices of all subjects
    ids = list_subject_ids(group)

    # Randomize order (seeded, see split_seed)
    split_rng.shuffle(ids)

    # Positions at which the shuffled list is cut.
    # The built-in int replaces np.int, which was removed from NumPy.
    train_idx = int(np.floor(len(ids)*train_split))
    dev_idx = int(np.floor(len(ids)*(train_split+dev_split)))

    for i, sub_id in enumerate(ids):

        if i < train_idx:
            split = 'train'
        elif i >= dev_idx:
            split = 'test'
        else:
            split = 'train_dev'

        design[split][sub_id] = subject_files(group, sub_id)

# Dev data: all subjects of these groups go to 'test_dev'
for group in test_dev_groups:
    for sub_id in list_subject_ids(group):
        design['test_dev'][sub_id] = subject_files(group, sub_id)

# Test data: all subjects of these groups go to 'test'
for group in test_groups:
    for sub_id in list_subject_ids(group):
        design['test'][sub_id] = subject_files(group, sub_id)

In [3]:
# TODO: save the `design` dictionary to disk (not implemented yet)

In [4]:
# One list of torchio subjects per split. Every subject holds two images:
# 't1' (the file registered under 'brain') and 'probs' (the file registered
# under 'probs', later used by the weighted patch sampler).
subjects_list = {
    split: [
        tio.Subject(
            t1=tio.Image(files['brain'], type=tio.INTENSITY),
            probs=tio.Image(files['probs'], type=tio.INTENSITY),
        )
        for files in design[split].values()
    ]
    for split in ('train', 'train_dev', 'test_dev', 'test')
}

In [5]:
'''# Data preprocessing and augmentation

# Histogram standardization (to mitigate cross-site differences)

# For t1
t1_paths = [design['train'][sub]['brain'] for sub in design['train']]
t1_landmarks_path = Path('t1_landmarks.npy')

t1_landmarks = HistogramStandardization.train(t1_paths)
torch.save(t1_landmarks, t1_landmarks_path)

# For probs
probs_paths = [design['train'][sub]['probs'] for sub in design['train']]
probs_landmarks_path = Path('probs_landmarks.npy')

probs_landmarks = HistogramStandardization.train(probs_paths)
torch.save(probs_landmarks, probs_landmarks_path)

landmarks={'t1':t1_landmarks,#'t1_landmarks.npy',
            'probs': probs_landmarks}# 'probs_landmarks.npy'}

torch.save(landmarks, 'path_to_landmarks.pth')

standardize = HistogramStandardization('path_to_landmarks.pth')
'''
# Rescale intensities to the [0, 1] range
rescale = RescaleIntensity(out_min_max=(0, 1))

# Random flip along any of the three spatial axes: the transform is applied
# with probability p, and then each axis is flipped with flip_probability
flip = RandomFlip(axes=(0, 1, 2), flip_probability=0.5, p=0.25)

# Transform pipelines - the flip is data augmentation and therefore part of
# the training pipeline only (histogram standardization is left out for now)
transform_train = Compose([rescale, flip])
transform_dev = Compose([rescale])

100%|██████████| 888/888 [01:35<00:00,  9.34it/s]
100%|██████████| 888/888 [01:05<00:00, 13.62it/s]


In [6]:
# Torchio (PyTorch-compatible) datasets, one per split.
# Only 'train' gets the augmenting pipeline; all other splits are just rescaled.
dataset = {
    split: tio.SubjectsDataset(
        subjects_list[split],
        transform=transform_train if split == 'train' else transform_dev,
    )
    for split in ('train', 'train_dev', 'test_dev', 'test')
}

In [7]:
# Patch sampling configuration
patch_size = (24, 24, 8)   # spatial size of every sampled patch
queue_length = 600         # passed to tio.Queue as its maximum length
samples_per_volume = 5     # passed to tio.Queue: patches drawn from each subject

# Patches are sampled using the subject's 'probs' image as the probability map
sampler = tio.data.WeightedSampler(patch_size=patch_size, probability_map='probs')

In [8]:
# Patch dataloaders, one per split: a torchio Queue (8 loading workers) fills
# up with patches drawn by the shared sampler, and a DataLoader batches them
dataloader = {
    split: DataLoader(
        tio.Queue(dataset[split], queue_length, samples_per_volume, sampler, num_workers=8),
        batch_size=50,
    )
    for split in ('train', 'train_dev', 'test_dev', 'test')
}

In [9]:
# Testing

#num_epochs = 1

#model = torch.nn.Identity()

#for epoch_index in range(num_epochs):
#    for patches_batch in dataloader['test_dev']:
#        inputs = patches_batch['t1'][tio.DATA]  # key 't1' is in subject
#        targets = patches_batch['t1'][tio.DATA]  # key 'brain' is in subject
#        logits = model(inputs)  # model being an instance of torch.nn.Module
        
#inputs.shape
'''
fig = plt.figure(figsize=(20, 10))

for i in range(inputs.shape[0]):
    plt.subplot(5,8,i+1)
    plt.imshow(inputs[i,0,:,:,5],cmap='gray');
    
plt.show()'''

"\nfig = plt.figure(figsize=(20, 10))\n\nfor i in range(inputs.shape[0]):\n    plt.subplot(5,8,i+1)\n    plt.imshow(inputs[i,0,:,:,5],cmap='gray');\n    \nplt.show()"

# Network and parameters

In [10]:
# Pick the compute device: the GPU when CUDA is available, the CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


In [24]:
# U-Net architecture

class UNet(nn.Module):
    """3D U-Net style encoder-decoder with three resolution levels and a bottleneck.

    The encoder applies three conv blocks, each followed by 2x average pooling;
    the decoder mirrors it with transposed convolutions and concatenates the
    matching encoder output (skip connection) before every decoder block.
    The output passes through tanh, so it lies in (-1, 1).

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        init_features: number of feature maps of the first encoder block; it
            doubles at each deeper level (bottleneck: 8 * init_features).

    Input is expected as a 5-D tensor (batch, channels, three spatial dims).
    Because of the three 2x poolings, each spatial dimension has to be
    divisible by 8, otherwise the skip concatenations fail on a size mismatch.

    NOTE(review): the notebook rescales inputs to [0, 1] while tanh also covers
    negative values - confirm tanh is the intended output activation.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=10):
        super(UNet, self).__init__()

        # Parameter determining depth of layers when going down the network
        features = init_features

        # Encoding layers
        self.encoder1 = self.unet_block(in_channels, features, name='enc1')
        self.pool1 = nn.AvgPool3d(kernel_size=2, stride=2, padding=0)
        self.encoder2 = self.unet_block(features, features*2, name='enc2')
        self.pool2 = nn.AvgPool3d(kernel_size=2, stride=2, padding=0)
        self.encoder3 = self.unet_block(features*2, features*4, name='enc3')
        self.pool3 = nn.AvgPool3d(kernel_size=2, stride=2, padding=0)

        # Bottleneck layer
        self.bottleneck = self.unet_block(features*4, features*4*2, name='bottleneck')

        # Decoding layers (where merge with previous encoding layers occurs).
        # Every decoder block receives twice the channels produced by its
        # upconvolution, because the skip connection is concatenated first.
        self.upconv3 = nn.ConvTranspose3d(features*4*2, features*4, kernel_size=2, stride=2)
        self.decoder3 = self.unet_block(features*4*2, features*4, name='dec3')

        self.upconv2 = nn.ConvTranspose3d(features*2*2, features*2, kernel_size=2, stride=2)
        self.decoder2 = self.unet_block(features*2*2, features*2, name='dec2')

        self.upconv1 = nn.ConvTranspose3d(features*2, features, kernel_size=2, stride=2)
        self.decoder1 = self.unet_block(features*2, features, name='dec1')

        # Final convolution - output equals number of output channels
        self.conv = nn.Conv3d(features, out_channels, kernel_size=1)

    def _encode(self, x):
        """Run the encoder and bottleneck; return (enc1, enc2, enc3, bottleneck)."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        bottleneck = self.bottleneck(self.pool3(enc3))
        return enc1, enc2, enc3, bottleneck

    def forward(self,x):
        """Return the tanh-activated network output for input `x`."""

        # Encoding and bottleneck
        enc1, enc2, enc3, bottleneck = self._encode(x)

        # Upconvolving, concatenating data from respective encoding phase and executing UNet block
        dec3 = self.upconv3(bottleneck)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        out_conv = self.conv(dec1)

        return torch.tanh(out_conv)

    def unet_block(self, in_channels, features, name):
        """Build one block: two (3x3x3 conv, batch norm, ReLU) stages.

        `name` prefixes the layer names inside the returned nn.Sequential, and
        therefore also the corresponding keys of the model's state dict.
        """

        return nn.Sequential(OrderedDict([(name+'conv1',nn.Conv3d(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False)),
                             (name+'bnorm1', nn.BatchNorm3d(num_features=features)),
                             (name+'relu1', nn.ReLU(inplace=True)),
                             (name+'conv2', nn.Conv3d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False)),
                             (name+'bnorm2', nn.BatchNorm3d(num_features=features)),
                             (name+'relu2', nn.ReLU(inplace=True))])
                            )

    def output_latent_representations(self,x):
        """Return the bottleneck feature maps computed for input `x`."""

        return self._encode(x)[3]

In [25]:
# Small U-Net (1 input channel, 1 output channel, 2 initial features) moved to
# the selected device; the module returned by .to() is displayed by the cell
unet = UNet(in_channels=1, out_channels=1, init_features=2)
unet.to(device)

UNet(
  (encoder1): Sequential(
    (enc1conv1): Conv3d(1, 2, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (enc1bnorm1): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu1): ReLU(inplace=True)
    (enc1conv2): Conv3d(2, 2, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (enc1bnorm2): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc1relu2): ReLU(inplace=True)
  )
  (pool1): AvgPool3d(kernel_size=2, stride=2, padding=0)
  (encoder2): Sequential(
    (enc2conv1): Conv3d(2, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (enc2bnorm1): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (enc2relu1): ReLU(inplace=True)
    (enc2conv2): Conv3d(4, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (enc2bnorm2): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True, tra

In [26]:
#sum(p.numel() for p in unet.parameters() if p.requires_grad)

#print(torch.version.cuda)

In [27]:
# try processing

#outputs = unet(inputs.to(device))
#outputs = outputs.cpu().detach().numpy()

#fig = plt.figure(figsize=(20, 10))

#for i in range(outputs.shape[0]):
#    plt.subplot(10,10,i+1)
#    plt.imshow(outputs[i,0,:,:,5],cmap='gray');
    
#plt.show()

In [28]:
# Loss function: mean squared error
# criterion = DiceLoss()
criterion = nn.MSELoss()

# Optimizer: Adam over all parameters of the U-Net
optimizer = torch.optim.Adam(unet.parameters(), lr=5e-4)

# Number of epochs
# NOTE(review): the training cell passes 200 to train_network directly and
# does not read this variable - confirm which value is intended.
n_epochs = 100

In [29]:
#outcome = criterion(inputs.to('cpu'), outputs.to('cpu'))

In [30]:
# INVESTIGATE CROSS ENTROPY LOSS

#print(criterion(1,2))

#print(criterion(inputs,inputs.to('cpu')))


In [31]:
#inputs[0].shape

In [32]:
'''fig = plt.figure(figsize=(20, 10))

for i in range(inputs.shape[0]):
    plt.subplot(10,10,i+1)
    plt.imshow(inputs[i,0,:,:,5],cmap='gray');
    
plt.show()'''

"fig = plt.figure(figsize=(20, 10))\n\nfor i in range(inputs.shape[0]):\n    plt.subplot(10,10,i+1)\n    plt.imshow(inputs[i,0,:,:,5],cmap='gray');\n    \nplt.show()"

In [33]:
'''fig = plt.figure(figsize=(20, 10))

outputs = outputs.cpu().detach().numpy()

for i in range(outputs.shape[0]):
    plt.subplot(10,10,i+1)
    plt.imshow(outputs[i,0,:,:,5],cmap='gray');
    
plt.show()'''

"fig = plt.figure(figsize=(20, 10))\n\noutputs = outputs.cpu().detach().numpy()\n\nfor i in range(outputs.shape[0]):\n    plt.subplot(10,10,i+1)\n    plt.imshow(outputs[i,0,:,:,5],cmap='gray');\n    \nplt.show()"

# Training

In [34]:
# Function returning trained model
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Pass once over `loader`; return (sum of batch losses, number of batches).

    The loss compares the model output with the model input itself
    (batch['t1']['data']). When `optimizer` is given the weights are updated
    after every batch; otherwise gradients are disabled.
    """
    total_loss = 0.0
    n_batches = 0
    training = optimizer is not None

    with torch.set_grad_enabled(training):
        for batch in loader:

            data = batch['t1']['data'].to(device)

            if training:
                optimizer.zero_grad()

            output = model(data)
            loss = criterion(output, data)

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1

    return total_loss, n_batches


def _mean_loss(total_loss, n_batches):
    """Return the loss per batch, or NaN when the loader yielded no batch."""
    return total_loss / n_batches if n_batches else float('nan')


# Function returning trained model
def train_network(n_epochs, dataloaders, model, optimizer, criterion, device, print_every, save_path):
    """Train `model` to reproduce its input and track the loss on three splits.

    Args:
        n_epochs: number of passes over dataloaders['train'].
        dataloaders: mapping with the keys 'train', 'train_dev' and 'test_dev';
            each value yields batches where batch['t1']['data'] is a tensor.
        model: network to train; it is moved to `device`.
        optimizer: optimizer holding the parameters of `model`.
        criterion: loss called as criterion(output, input).
        device: device on which the computation runs.
        print_every: currently unused; kept so that existing calls keep working.
        save_path: file that receives model.state_dict() whenever the summed
            loss on 'train_dev' and 'test_dev' reaches a new minimum.

    Returns:
        (model, train losses, train_dev losses, test_dev losses), where each
        loss list has one per-batch mean value per epoch.
    """

    track_train_loss = []
    track_train_dev_loss = []
    track_test_dev_loss = []

    # float('inf') instead of np.Inf, which is no longer available in NumPy 2
    valid_loss_min = float('inf')

    model.to(device)

    for epoch in tqdm(range(1, n_epochs+1)):

        # Training
        model.train()
        train_total, n_train = _run_epoch(model, dataloaders['train'], criterion, device, optimizer)

        # Validation on two datasets
        model.eval()
        train_dev_total, n_train_dev = _run_epoch(model, dataloaders['train_dev'], criterion, device)
        test_dev_total, n_test_dev = _run_epoch(model, dataloaders['test_dev'], criterion, device)

        train_loss = _mean_loss(train_total, n_train)
        train_dev_loss = _mean_loss(train_dev_total, n_train_dev)
        # The test_dev curve is computed from the test_dev loss (it used to be
        # computed from the train_dev loss by mistake)
        test_dev_loss = _mean_loss(test_dev_total, n_test_dev)

        track_train_loss.append(train_loss)
        track_train_dev_loss.append(train_dev_loss)
        track_test_dev_loss.append(test_dev_loss)

        # Print summary of epoch
        print('END OF EPOCH: {} \tTraining loss per batch: {:.6f}\tTraining_dev loss per batch: {:.6f}\tTest_dev loss per batch: {:.6f}'.format(epoch, train_loss, train_dev_loss, test_dev_loss))

        ## Save the model if reached min validation loss
        # (compared on the summed, not per-batch, losses of both dev splits)
        valid_total = train_dev_total + test_dev_total
        if valid_total < valid_loss_min:
            valid_loss_min = valid_total
            torch.save(model.state_dict(),save_path)

    # return trained model
    return model, track_train_loss, track_train_dev_loss, track_test_dev_loss

In [35]:
# Train for 200 epochs on the patch dataloaders; train_network writes the
# model state to '200ep_00005lr.pt' whenever the dev losses reach a new minimum
trained_model, train_loss, train_dev_loss, test_dev_loss = train_network(
    n_epochs=200,
    dataloaders=dataloader,
    model=unet,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    print_every=500,
    save_path='200ep_00005lr.pt',
)

  0%|          | 0/200 [00:00<?, ?it/s]


RuntimeError: Given transposed=1, weight of size [4, 4, 2, 2, 2], expected input[50, 8, 6, 6, 2] to have 4 channels, but got 8 channels instead

In [None]:
# Learning curves: per-epoch loss on the three splits, y axis clipped to [0, 0.05]
for curve, label in ((train_loss, 'train'),
                     (train_dev_loss, 'train_dev'),
                     (test_dev_loss, 'test_dev')):
    plt.plot(curve, label=label)
plt.ylim([0,0.05])
plt.legend()

In [None]:
# Load trained model
# map_location places the stored tensors on the currently selected device, so
# a checkpoint written on a GPU can also be loaded on a CPU-only machine.
# NOTE(review): this file name differs from the '200ep_00005lr.pt' written by
# the training cell - confirm which checkpoint is intended.
unet.load_state_dict(torch.load('300ep_000005lr.pt', map_location=device))

In [None]:
# Report the mean per-batch loss of the model on the train, train_dev and
# test_dev patch loaders (the 'test' split is not evaluated here)

unet.eval()

with torch.no_grad():

    for group in ('train', 'train_dev', 'test_dev'):

        total = 0
        count = 0

        for count, batch in enumerate(dataloader[group], start=1):

            data = batch['t1']['data'].to(device)

            # The input itself is the target of the loss
            output = unet(data)
            total += criterion(output,data).item()

        print(group, ': loss per batch = ', total/count)

In [None]:
# Visualize the output (train_dev and test_dev)

# Generate sample from two dev datasets
# next(enumerate(...)) returns the tuple (0, first_batch): the batch itself is
# therefore found at index 1 of train_dev_input / test_dev_input
train_dev_input = next(enumerate(dataloader['train_dev']))
test_dev_input = next(enumerate(dataloader['test_dev']))

In [None]:
# Forward pass of the U-Net on the sampled train_dev and test_dev batches
# (index 1 of each sample holds the batch)
train_dev_output, test_dev_output = (
    unet(sample[1]['t1']['data'].to(device))
    for sample in (train_dev_input, test_dev_input)
)

In [None]:
# Bottleneck (latent) feature maps of the same two sampled batches

train_dev_latent, test_dev_latent = (
    unet.output_latent_representations(sample[1]['t1']['data'].to(device))
    for sample in (train_dev_input, test_dev_input)
)

In [None]:
# Display the shape of the bottleneck feature maps of the test_dev batch
test_dev_latent.shape

In [None]:
# Visualize train_dev sample

# Everything is converted to NumPy on the CPU, so that the plotted residual
# (input - output) does not mix a torch tensor with a NumPy array
inputs = train_dev_input[1]['t1']['data'].cpu().numpy()
latent = train_dev_latent.cpu().detach().numpy()
outputs = train_dev_output.cpu().detach().numpy()

fig = plt.figure(figsize=(20,75))

# NOTE(review): the grid has one row per batch element, but only the first
# third of the batch is drawn - confirm the empty lower part is intended.
n_rows = outputs.shape[0]
n_cols = 6

for i in range(n_rows//3):

    # One row per patch: slice 5 of the patch and of its reconstruction,
    # slice 0 of the first three latent channels, and the residual
    panels = [('input', inputs[i,0,:,:,5]),
              ('latent channel 0', latent[i,0,:,:,0]),
              ('latent channel 1', latent[i,1,:,:,0]),
              ('latent channel 2', latent[i,2,:,:,0]),
              ('output', outputs[i,0,:,:,5]),
              ('input - output', inputs[i,0,:,:,5]-outputs[i,0,:,:,5])]

    for col, (title, panel) in enumerate(panels):

        plt.subplot(n_rows, n_cols, n_cols*i+col+1)
        plt.imshow(panel, cmap='gray')

        # Column titles on the first row only
        if i == 0:
            plt.title(title)

plt.show()

In [None]:
# Visualize test_dev sample

# Everything is converted to NumPy on the CPU, so that the plotted residual
# (input - output) does not mix a torch tensor with a NumPy array
inputs = test_dev_input[1]['t1']['data'].cpu().numpy()
latent = test_dev_latent.cpu().detach().numpy()
outputs = test_dev_output.cpu().detach().numpy()

fig = plt.figure(figsize=(20,75))

# NOTE(review): the grid has one row per batch element, but only the first
# third of the batch is drawn - confirm the empty lower part is intended.
n_rows = outputs.shape[0]
n_cols = 3

for i in range(n_rows//3):

    # One row per patch: slice 5 of the patch, of its reconstruction and of the residual
    panels = [('input', inputs[i,0,:,:,5]),
              ('output', outputs[i,0,:,:,5]),
              ('input - output', inputs[i,0,:,:,5]-outputs[i,0,:,:,5])]

    for col, (title, panel) in enumerate(panels):

        plt.subplot(n_rows, n_cols, n_cols*i+col+1)
        plt.imshow(panel, cmap='gray')

        # Column titles on the first row only
        if i == 0:
            plt.title(title)

plt.show()

# Evaluate and visualize results for chiasms


In [None]:
# One list of torchio subjects per split, each holding only the image
# registered under 'chiasm' (stored under the key 't1', as the network input)
subjects_chiasms_list = {
    split: [
        tio.Subject(t1=tio.Image(files['chiasm'], type=tio.INTENSITY))
        for files in design[split].values()
    ]
    for split in ('train', 'train_dev', 'test_dev', 'test')
}

In [None]:
# Chiasm-only datasets, one per split; every split is only intensity-rescaled
dataset_chiasms = {
    split: tio.SubjectsDataset(subjects_chiasms_list[split], transform=Compose([rescale]))
    for split in ('train', 'train_dev', 'test_dev', 'test')
}

In [None]:
# Standard (non-patch) dataloaders over the chiasm images:
# shuffled batches of 10 subjects, loaded by 8 workers
dataloader_chiasms = {
    split: DataLoader(dataset=dataset_chiasms[split], batch_size=10, shuffle=True, num_workers=8)
    for split in ('train', 'train_dev', 'test_dev', 'test')
}

In [None]:
# Visualize results

unet.eval()

for group in ['train', 'train_dev', 'test_dev']:

    print(group)

    # next(enumerate(...)) returns (0, first_batch); the batch is at index 1
    batch = next(enumerate(dataloader_chiasms[group]))

    data = batch[1]['t1']['data']

    # Inference only: without no_grad the forward pass would also record an
    # autograd graph for the whole batch, which is never used here
    with torch.no_grad():
        output = unet(data.to(device))

    inputs = data.cpu().numpy()
    outputs = output.cpu().numpy()

    fig = plt.figure(figsize=(20,20))

    n_rows = outputs.shape[0]

    for i in range(n_rows):

        # One row per image: slice 5 of the input, of the output and of the residual
        panels = [('input', inputs[i,0,:,:,5]),
                  ('output', outputs[i,0,:,:,5]),
                  ('input - output', inputs[i,0,:,:,5]-outputs[i,0,:,:,5])]

        for col, (title, panel) in enumerate(panels):

            plt.subplot(n_rows, 3, 3*i+col+1)
            plt.imshow(panel, cmap='gray')

            # Column titles on the first row only
            if i == 0:
                plt.title(title)

    plt.show()

In [None]:
# Draw one batch from the 'train' chiasm loader and print the shape of its image tensor
batch = next(enumerate(dataloader_chiasms['train']))
batch_index, chiasm_batch = batch
print(chiasm_batch['t1']['data'].shape)