# Dataloaders

In [1]:
# Importing libraries
import torchio as tio
import glob
import numpy as np
import random
import os
import pickle

from collections import OrderedDict
from pathlib import Path

from tqdm import tqdm
import time

import torchio as tio
from torchio.transforms import (RescaleIntensity,RandomFlip,Compose, HistogramStandardization, CropOrPad, ToCanonical)

from sklearn.metrics import f1_score

from torch.utils.data import DataLoader
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

from Functions_classification_training import UNet_1_layer, UNet_2_layer, Classifier

In [2]:
# Load the subject IDs: subjects_dict[dataset][label] -> list of subject IDs
with open('../subjects_dict.pkl', 'rb') as pickle_file:
    subjects_dict = pickle.load(pickle_file)

# Exclude CHP1 and ACH1 from the CHIASM controls
for excluded_id in ('CHP1', 'ACH1'):
    subjects_dict['CHIASM']['control'].remove(excluded_id)

In [3]:
# Function used for splitting the list
def splitter(list_to_be_splitted, number_of_groups):
    a, b = divmod(len(list_to_be_splitted), number_of_groups)
    return (list_to_be_splitted[i*a+min(i,b):(i+1)*a+min(i+1,b)] for i in range(number_of_groups))

In [4]:
# Training helpers: one pass over a dataloader, and the full training loop with early stopping
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over `loader`, collecting the loss, the targets and the rounded predictions.

    When `optimizer` is given, the model is put in training mode and updated after
    every batch. Otherwise it is put in eval mode and no gradients are tracked.

    Returns:
        (summed_loss, targets, predictions):
        - summed_loss: each batch's loss multiplied by the batch size, summed over batches.
          Divide by the dataset size for a per-image loss, assuming the criterion
          averages over the batch (the default reduction).
        - targets: the integer labels.
        - predictions: the model outputs rounded (0/1 when the model outputs
          probabilities, as BCELoss requires).
    """
    training = optimizer is not None
    model.train(training)

    summed_loss = 0.0
    targets = []
    predictions = []
    # Gradients are only tracked while training; evaluation runs without them.
    with torch.set_grad_enabled(training):
        for batch in loader:
            # Inputs never need gradients: only the trainable parameters are optimised.
            data = batch['chiasm']['data'].to(device)
            labels = batch['label']

            if training:
                optimizer.zero_grad()

            # Column 0 is the value fed to the loss; keep predictions 1-D to match the targets.
            output = model(data)[:, 0]
            loss = criterion(output, labels.to(device).float())

            if training:
                loss.backward()
                optimizer.step()

            summed_loss += loss.item() * len(labels)
            targets += labels.numpy().tolist()
            predictions += output.round().detach().cpu().numpy().tolist()

    return summed_loss, targets, predictions


def train_network(n_epochs, dataloaders, model, optimizer, criterion, device, save_path,
                  patience=1000, log_every=500):
    """Train `model` with early stopping on the dev_train loss and record per-epoch metrics.

    Each epoch does three things:
    - trains on dataloaders['train'];
    - evaluates on dataloaders['dev_train'] and dataloaders['test1'];
    - on a new dev_train loss minimum, writes the state_dict to
      save_path + 'optimal_weights' and the epoch to save_path + 'number_epochs.txt'.

    Training stops once `patience` epochs pass without such an improvement.

    Args:
        n_epochs: maximum number of epochs.
        dataloaders: dict with (at least) 'train', 'dev_train' and 'test1' DataLoaders
            whose batches provide batch['chiasm']['data'] and batch['label'].
        model: network whose output column 0 is passed to `criterion`.
        optimizer: optimizer over the trainable parameters of `model`.
        criterion: loss called as criterion(output[:, 0], float labels).
        device: torch.device the model and batches are moved to.
        save_path: prefix (a folder path ending in '/' in this notebook) for saved files.
        patience: epochs without dev_train improvement before stopping (default 1000).
        log_every: print per-image losses every this many epochs (default 500).

    Returns:
        Six lists with one entry per completed epoch:
        - per-image losses for train, dev_train and test1;
        - weighted F1 scores for train, dev_train and test1.
    """
    splits = ('train', 'dev_train', 'test1')
    # Normalise by each loader's own dataset size instead of relying on a notebook global.
    n_images = {split: len(dataloaders[split].dataset) for split in splits}
    losses = {split: [] for split in splits}
    f1_scores = {split: [] for split in splits}

    valid_loss_min = float('inf')  # np.Inf no longer exists in NumPy >= 2.0
    # Initialised so early stopping still works if dev_train never improves (e.g. NaN losses).
    last_updated_epoch = 0

    model.to(device)

    for epoch in tqdm(range(1, n_epochs + 1)):
        epoch_loss = {}
        for split in splits:
            summed_loss, targets, predictions = _run_epoch(
                model, dataloaders[split], criterion, device,
                optimizer=optimizer if split == 'train' else None)
            epoch_loss[split] = summed_loss
            losses[split].append(summed_loss / n_images[split])
            f1_scores[split].append(f1_score(targets, predictions, average='weighted'))

        if epoch % log_every == 0:
            print('END OF EPOCH: {} \n Training loss per image: {:.6f}\n Training_dev loss per image: {:.6f}\n Test_dev loss per image: {:.6f}'.format(
                epoch, losses['train'][-1], losses['dev_train'][-1], losses['test1'][-1]))

        # Checkpoint on a new dev_train minimum (summed loss; the dataset size is constant)
        if epoch_loss['dev_train'] < valid_loss_min:
            valid_loss_min = epoch_loss['dev_train']
            torch.save(model.state_dict(), save_path + 'optimal_weights')
            last_updated_epoch = epoch

            with open(save_path + 'number_epochs.txt', 'w') as f:
                print('Epoch:', str(epoch), file=f)

        # Early stopping
        if epoch - last_updated_epoch >= patience:
            break

    # Return the metric histories (the best weights are on disk under save_path)
    return (losses['train'], losses['dev_train'], losses['test1'],
            f1_scores['train'], f1_scores['dev_train'], f1_scores['test1'])

In [5]:
# Dictionary with splits
# One-time generator of design_kfold.pkl. It is disabled by wrapping it in a
# string literal: running this cell only evaluates and displays the string.
# When enabled, it does two things:
#   - splits every subject list into contiguous groups with `splitter`
#     (9 groups for CHIASM albinism, 8 for every other list);
#   - pickles the result to design_kfold.pkl, which the k-fold cell reads back.
# NOTE(review): it overwrites subjects_dict[dataset][label] in place with lists
# of groups. Any later cell that iterates subjects_dict expecting subject IDs
# (e.g. the histogram-standardization path builder) would then break.
# If re-enabling, run it on a copy of subjects_dict.
'''
for dataset in subjects_dict.keys():
    for label in subjects_dict[dataset].keys():
        if(dataset=='CHIASM' and label=='albinism'):
            subjects_dict[dataset][label]=list(splitter(subjects_dict[dataset][label],9))
        else:
            subjects_dict[dataset][label]=list(splitter(subjects_dict[dataset][label],8))
            
# Save the dictionary
with open('design_kfold.pkl','wb') as f:
    pickle.dump(subjects_dict,f)
'''

"\nfor dataset in subjects_dict.keys():\n    for label in subjects_dict[dataset].keys():\n        if(dataset=='CHIASM' and label=='albinism'):\n            subjects_dict[dataset][label]=list(splitter(subjects_dict[dataset][label],9))\n        else:\n            subjects_dict[dataset][label]=list(splitter(subjects_dict[dataset][label],8))\n            \n# Save the dictionary\nwith open('design_kfold.pkl','wb') as f:\n    pickle.dump(subjects_dict,f)\n"

In [6]:
# Histogram standardization (mitigates cross-site intensity differences).
# The landmarks are learned once and shared by all datasets and all folds.
# NOTE(review): the landmarks are fitted on every subject in subjects_dict,
# presumably including subjects that later fall into test1/test2. Confirm this
# is intended.
chiasm_paths = [
    '../../1_Data/1_Input/'+dataset+'/'+subject+'/chiasm.nii.gz'
    for dataset, labels in subjects_dict.items()
    for subject_ids in labels.values()
    for subject in subject_ids
]

# Learn the landmarks from all chiasm images and keep a copy on disk
chiasm_landmarks_path = Path('chiasm_landmarks.npy')
chiasm_landmarks = HistogramStandardization.train(chiasm_paths)
torch.save(chiasm_landmarks, chiasm_landmarks_path)

# Transform mapping each subject's 'chiasm' image onto the learned landmarks
landmarks = {'chiasm': chiasm_landmarks}
standardize = HistogramStandardization(landmarks)

100%|████████████████████████████████████████████████████████████████████████| 1740/1740 [00:01<00:00, 1024.10it/s]


In [7]:
# Data preprocessing and augmentation - shared by all datasets

# Deterministic preprocessing for every split (crop/pad is applied last, after any augmentation)
canonical = ToCanonical()              # reorient to canonical orientation
rescale = RescaleIntensity((0,1))      # rescale intensities to [0, 1]
crop = CropOrPad((24,24,8))            # fixed 24x24x8 volume

# Random augmentation, used for the training split only
affine = tio.RandomAffine(degrees=5, translation=(2,2,2), center='image')
flip = RandomFlip((0,1,2), flip_probability=0.5, p=0.5)
# Disabled alternative: elastic deformation
#elastic = tio.transforms.RandomElasticDeformation(num_control_points=4, max_displacement=4, locked_borders=1)

# Shared head of both pipelines; train adds augmentation before the final crop
preprocess = [canonical, standardize, rescale]
transform_train = Compose(preprocess + [affine, flip, crop])
transform_dev = Compose(preprocess + [crop])

In [8]:
# K-fold fine-tuning.
# design_kfold.pkl maps design[dataset][label] to a list of subject-ID groups
# (8 groups per list, 9 for CHIASM albinism, per the disabled generator cell;
# TODO confirm the pickle matches). Each fold allocates them as follows:
#   test2     : group `fold` of CHIASM albinism + group `fold` of every public control dataset
#   test1     : index `fold` of what remains of every CHIASM / UoN list (controls and albinism)
#   dev_train : index `fold` of what remains of every list (index 0 on the last fold)
#   train     : every remaining group, flattened
# In train and dev_train, albinism subjects are repeated to roughly balance the controls.
# For each fold, a pre-trained Classifier is loaded and its classification part frozen.
# The remaining trainable parameters are then fine-tuned.

N_FOLDS = 8
OUTPUT_ROOT = '../../1_Data/4_K-fold_combined_extraction'
PRETRAINED_ROOT = '../../1_Data/4_K-fold_combined/'
PUBLIC_CONTROL_DATASETS = ['ABIDE', 'Athletes', 'HCP', 'COBRE', 'Leipzig', 'MCIC']
CLINICAL_DATASETS = ['CHIASM', 'UoN']  # the datasets that contain albinism subjects
LABEL_CODES = {'control': 0, 'albinism': 1}
UPSAMPLED_SPLITS = ('train', 'dev_train')
BATCH_SIZE = 10
NUM_WORKERS = 8

groups = ['train', 'dev_train', 'test1', 'test2']


def _fold_design(kfold_design, fold):
    """Allocate the k-fold groups to test2 / test1 / dev_train / train for one fold.

    `kfold_design[dataset][label]` is a list of groups (lists of subject IDs).
    Groups are popped from it, so pass a freshly loaded design.
    """
    design = {}

    # test2: CHIASM albinism plus the public control datasets (no CHIASM controls)
    design['test2'] = {'CHIASM': {'albinism': kfold_design['CHIASM']['albinism'].pop(fold),
                                  'control': []}}
    for dataset in PUBLIC_CONTROL_DATASETS:
        design['test2'][dataset] = {'control': kfold_design[dataset]['control'].pop(fold)}

    # test1: controls and albinism from the CHIASM and UoN datasets
    design['test1'] = {}
    for dataset in CLINICAL_DATASETS:
        design['test1'][dataset] = {label: group_list.pop(fold)
                                    for label, group_list in kfold_design[dataset].items()}

    # dev_train: one group of every list. The last fold wraps to index 0 because
    # typically only N_FOLDS - 1 groups are left at this point.
    dev_index = 0 if fold == N_FOLDS - 1 else fold
    design['dev_train'] = {dataset: {label: group_list.pop(dev_index)
                                     for label, group_list in labels.items()}
                           for dataset, labels in kfold_design.items()}

    # train: all remaining groups, flattened into one list of subject IDs
    design['train'] = {dataset: {label: [subject for group in group_list for subject in group]
                                 for label, group_list in labels.items()}
                       for dataset, labels in kfold_design.items()}
    return design


def _chiasm_subject(dataset, subject, label_code):
    """Build a tio.Subject holding one subject's chiasm image and its numeric label."""
    return tio.Subject(chiasm=tio.Image('../../1_Data/1_Input/'+dataset+'/'+subject+'/chiasm.nii.gz', type=tio.INTENSITY),
                       label=label_code)


def _split_subjects(design):
    """Turn a fold design into lists of tio.Subject per split, upsampling albinism in train / dev_train."""
    subjects_per_split = {}
    for split, datasets in design.items():
        albinism_repeats = 1
        if split in UPSAMPLED_SPLITS:
            num_control = sum(len(labels['control']) for labels in datasets.values())
            num_albinism = sum(len(datasets[dataset]['albinism'])
                               for dataset in datasets if dataset in CLINICAL_DATASETS)
            # Repeat albinism subjects ~num_control/num_albinism times, but never drop them
            albinism_repeats = max(1, num_control // num_albinism)

        subjects = []
        for dataset, labels in datasets.items():
            for label, subject_ids in labels.items():
                repeats = albinism_repeats if label == 'albinism' else 1
                subjects += [_chiasm_subject(dataset, subject, LABEL_CODES[label])
                             for _ in range(repeats) for subject in subject_ids]
        subjects_per_split[split] = subjects
    return subjects_per_split


# Try setting CUDA if possible
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

criterion = nn.BCELoss()

model_parameters = [[1, 2, 2, 1, 256]]
learning_rates = [0.00005]
n_epochs = 8000

os.makedirs(OUTPUT_ROOT, exist_ok=True)

for iteration in range(N_FOLDS):

    output_folder = OUTPUT_ROOT + '/' + str(iteration)
    os.makedirs(output_folder, exist_ok=True)

    # Reload the design for every fold: _fold_design consumes it by popping groups
    with open('design_kfold.pkl', 'rb') as f:
        kfold_design = pickle.load(f)

    design_kfold_combined = _fold_design(kfold_design, iteration)

    # Save the design
    with open(output_folder + '/kfold_design_' + str(iteration) + '.pkl', 'wb') as f:
        pickle.dump(design_kfold_combined, f)

    print(iteration)

    # Kept as a module-level name: train_network may read it as a global
    dict_kfold_combined_training = _split_subjects(design_kfold_combined)

    datasets_list = {
        group: tio.SubjectsDataset(subjects, transform=transform_train if group == 'train' else transform_dev)
        for group, subjects in dict_kfold_combined_training.items()
    }

    # Only the training loader needs shuffling; evaluation metrics do not depend on order
    dataloaders_chiasm = {
        group: DataLoader(dataset=datasets_list[group], batch_size=BATCH_SIZE,
                          shuffle=(group == 'train'), num_workers=NUM_WORKERS)
        for group in groups
    }

    folder = output_folder

    for parameters in model_parameters:
        for learning_rate in learning_rates:

            # Run name, e.g. '1_2_2_1_256_5e-05'; the pre-trained runs use the same naming
            run_name = '_'.join(str(p) for p in parameters) + '_' + str(learning_rate)

            # Load the pre-trained model and freeze its classification part.
            # map_location lets GPU-saved weights load on a CPU-only machine.
            classifying_network = Classifier(*parameters)
            classifying_network.load_state_dict(torch.load(
                PRETRAINED_ROOT + str(iteration) + '/' + run_name + '/optimal_weights',
                map_location='cpu'))
            classifying_network.freeze_classification()

            # Optimise only the parameters that are still trainable
            optimizer = torch.optim.Adam(
                params=[p for p in classifying_network.parameters() if p.requires_grad],
                lr=learning_rate)

            data_folder = folder + '/' + run_name + '/'
            os.makedirs(data_folder, exist_ok=True)

            # Train & save weights
            train_loss, dev_train_loss, test_loss, train_f1, dev_train_f1, test_f1 = train_network(
                n_epochs, dataloaders_chiasm, classifying_network, optimizer, criterion, device, data_folder)

            # Save the metric histories, one pickle per metric
            histories = {'train_loss': train_loss, 'dev_train_loss': dev_train_loss, 'test_loss': test_loss,
                         'train_f1': train_f1, 'dev_train_f1': dev_train_f1, 'test_f1': test_f1}
            for name, values in histories.items():
                with open(data_folder + name + '.pkl', 'wb') as f:
                    pickle.dump(values, f)
                
    

0


  6%|████▌                                                                    | 500/8000 [24:26<6:22:27,  3.06s/it]

END OF EPOCH: 500 
 Training loss per image: 0.009157
 Training_dev loss per image: 0.036462
 Test_dev loss per image: 1.068799


 12%|█████████                                                               | 1000/8000 [50:02<5:40:48,  2.92s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.006921
 Training_dev loss per image: 0.036471
 Test_dev loss per image: 0.923181


 14%|██████████▎                                                             | 1141/8000 [57:00<5:42:44,  3.00s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

1


  6%|████▌                                                                    | 500/8000 [25:04<6:43:09,  3.23s/it]

END OF EPOCH: 500 
 Training loss per image: 0.003691
 Training_dev loss per image: 0.060238
 Test_dev loss per image: 2.662048


 12%|█████████                                                               | 1000/8000 [50:47<5:59:15,  3.08s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.001856
 Training_dev loss per image: 0.070412
 Test_dev loss per image: 2.156930


 19%|█████████████▏                                                        | 1500/8000 [1:15:35<5:18:29,  2.94s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.003721
 Training_dev loss per image: 0.117950
 Test_dev loss per image: 2.292683


 24%|████████████████▋                                                     | 1905/8000 [1:35:51<5:06:41,  3.02s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

2


  6%|████▌                                                                    | 500/8000 [25:35<6:33:03,  3.14s/it]

END OF EPOCH: 500 
 Training loss per image: 0.057826
 Training_dev loss per image: 0.261303
 Test_dev loss per image: 3.654088


 12%|█████████                                                               | 1000/8000 [51:00<6:02:34,  3.11s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.057665
 Training_dev loss per image: 0.344043
 Test_dev loss per image: 4.143208


 14%|█████████▊                                                              | 1094/8000 [55:55<5:53:02,  3.07s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

3


  6%|████▌                                                                    | 500/8000 [26:57<7:19:46,  3.52s/it]

END OF EPOCH: 500 
 Training loss per image: 0.003512
 Training_dev loss per image: 0.011938
 Test_dev loss per image: 1.079044


 12%|█████████                                                               | 1000/8000 [55:13<6:35:51,  3.39s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.003265
 Training_dev loss per image: 0.008890
 Test_dev loss per image: 0.799368


 19%|█████████████▏                                                        | 1500/8000 [1:23:30<6:51:54,  3.80s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.002852
 Training_dev loss per image: 0.013108
 Test_dev loss per image: 0.635923


 21%|██████████████▎                                                       | 1642/8000 [1:32:41<5:58:55,  3.39s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

4


  6%|████▌                                                                    | 500/8000 [28:22<7:05:50,  3.41s/it]

END OF EPOCH: 500 
 Training loss per image: 0.053056
 Training_dev loss per image: 0.085836
 Test_dev loss per image: 1.745799


 12%|█████████                                                               | 1000/8000 [56:42<6:33:01,  3.37s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.041341
 Training_dev loss per image: 0.141856
 Test_dev loss per image: 2.617060


 19%|█████████████▏                                                        | 1500/8000 [1:21:21<5:18:17,  2.94s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.030415
 Training_dev loss per image: 0.161283
 Test_dev loss per image: 2.444938


 19%|█████████████▏                                                        | 1512/8000 [1:21:59<5:51:49,  3.25s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

5


  6%|████▌                                                                    | 500/8000 [24:27<6:07:23,  2.94s/it]

END OF EPOCH: 500 
 Training loss per image: 0.038950
 Training_dev loss per image: 0.011035
 Test_dev loss per image: 0.893373


 12%|█████████                                                               | 1000/8000 [48:56<5:42:39,  2.94s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.040143
 Training_dev loss per image: 0.018810
 Test_dev loss per image: 0.946135


 19%|█████████████▏                                                        | 1500/8000 [1:13:24<5:18:04,  2.94s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.027130
 Training_dev loss per image: 0.028395
 Test_dev loss per image: 1.102117


 21%|██████████████▍                                                       | 1654/8000 [1:20:59<5:10:43,  2.94s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

6


  6%|████▌                                                                    | 500/8000 [24:30<6:05:38,  2.93s/it]

END OF EPOCH: 500 
 Training loss per image: 0.002311
 Training_dev loss per image: 0.000384
 Test_dev loss per image: 0.047278


 12%|█████████                                                               | 1000/8000 [49:01<6:04:55,  3.13s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.004718
 Training_dev loss per image: 0.000301
 Test_dev loss per image: 0.054434


 19%|█████████████▏                                                        | 1500/8000 [1:15:08<5:33:35,  3.08s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.002375
 Training_dev loss per image: 0.000245
 Test_dev loss per image: 0.062858


 25%|█████████████████▌                                                    | 2000/8000 [1:41:27<5:01:03,  3.01s/it]

END OF EPOCH: 2000 
 Training loss per image: 0.003550
 Training_dev loss per image: 0.000233
 Test_dev loss per image: 0.020619


 31%|█████████████████████▉                                                | 2500/8000 [2:06:53<4:32:44,  2.98s/it]

END OF EPOCH: 2500 
 Training loss per image: 0.000659
 Training_dev loss per image: 0.000520
 Test_dev loss per image: 0.023185


 38%|██████████████████████████▎                                           | 3000/8000 [2:31:53<4:05:09,  2.94s/it]

END OF EPOCH: 3000 
 Training loss per image: 0.003278
 Training_dev loss per image: 0.000303
 Test_dev loss per image: 0.512356


 43%|██████████████████████████████▏                                       | 3450/8000 [2:54:06<3:49:37,  3.03s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

7


  6%|████▌                                                                    | 500/8000 [25:43<6:24:35,  3.08s/it]

END OF EPOCH: 500 
 Training loss per image: 0.003416
 Training_dev loss per image: 0.015067
 Test_dev loss per image: 0.061564


 12%|█████████                                                               | 1000/8000 [51:21<5:50:44,  3.01s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.003011
 Training_dev loss per image: 0.017248
 Test_dev loss per image: 0.074811


 19%|█████████████▏                                                        | 1500/8000 [1:16:12<5:21:39,  2.97s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.002374
 Training_dev loss per image: 0.018701
 Test_dev loss per image: 0.021141


 24%|████████████████▊                                                     | 1921/8000 [1:37:07<5:07:20,  3.03s/it]
