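# Dataloaders and k-fold fine-tuning

Builds the k-fold splits and the torchio dataloaders, then trains the pretrained chiasm classifier (after calling `freeze_classification()`) on each of the 8 folds.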

In [1]:
# Importing libraries
import torchio as tio
import glob
import numpy as np
import random
import os
import pickle

from collections import OrderedDict
from pathlib import Path

from tqdm import tqdm
import time

import torchio as tio
from torchio.transforms import (RescaleIntensity,RandomFlip,Compose, HistogramStandardization, CropOrPad, ToCanonical)

from sklearn.metrics import f1_score

from torch.utils.data import DataLoader
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

from Functions_classification_training import UNet_1_layer, UNet_2_layer, Classifier

In [2]:
# Load the subject IDs, organised as subjects_dict[dataset][label] -> list of subject IDs
with open('../subjects_dict.pkl', 'rb') as pickle_file:
    subjects_dict = pickle.load(pickle_file)

# Exclude subjects CHP1 and ACH1 from the CHIASM controls
chiasm_controls = subjects_dict['CHIASM']['control']
for excluded_subject in ('CHP1', 'ACH1'):
    chiasm_controls.remove(excluded_subject)

In [3]:
# Function used for splitting the list
def splitter(list_to_be_splitted, number_of_groups):
    a, b = divmod(len(list_to_be_splitted), number_of_groups)
    return (list_to_be_splitted[i*a+min(i,b):(i+1)*a+min(i+1,b)] for i in range(number_of_groups))

In [4]:
# Training loop: checkpoint on the best dev_train loss, early stopping, per-epoch loss/F1 history
def _run_epoch(loader, model, criterion, device, optimizer=None):
    """Run one pass over ``loader``; return (summed loss, targets, rounded predictions).

    With an ``optimizer`` the model is put in training mode and updated after every
    batch; without one it is put in evaluation mode and run under no_grad.
    ``summed_loss`` is the sum over batches of ``loss * batch_size``; ``targets`` and
    ``predictions`` are flat lists suitable for ``sklearn.metrics.f1_score``.
    """
    training = optimizer is not None
    model.train(training)

    summed_loss = 0.0
    targets, predictions = [], []

    with torch.set_grad_enabled(training):
        for batch in loader:
            data = batch['chiasm']['data'].to(device)
            labels = batch['label']

            # Column 0 of the model output is what the criterion scores; assumes the
            # model returns probabilities of shape (batch, 1) - TODO confirm.
            output = model(data)[:, 0]
            loss = criterion(output, labels.to(device).float())

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            summed_loss += loss.item() * len(labels)
            targets += labels.numpy().tolist()
            predictions += output.round().detach().cpu().numpy().tolist()

    return summed_loss, targets, predictions


def train_network(n_epochs, dataloaders, model, optimizer, criterion, device, save_path,
                  patience=1000, log_every=500):
    """Train ``model`` and track per-image loss and weighted F1 on train, dev_train and test1.

    Each epoch runs one optimisation pass over ``dataloaders['train']`` followed by
    evaluation passes over ``dataloaders['dev_train']`` and ``dataloaders['test1']``.
    Whenever the mean dev_train loss improves, the weights are saved to
    ``save_path/optimal_weights`` and the epoch number to ``save_path/number_epochs.txt``.

    Args:
        n_epochs: maximum number of epochs.
        dataloaders: dict with at least 'train', 'dev_train' and 'test1' DataLoaders whose
            batches contain ``batch['chiasm']['data']`` and ``batch['label']``.
        model: network whose output column 0 is fed to ``criterion``.
        optimizer: optimizer over (a subset of) the model parameters.
        criterion: loss taking (predictions, float targets), e.g. ``nn.BCELoss()``.
        device: torch device used for training and evaluation.
        save_path: directory for the checkpoint files.
        patience: stop after this many epochs without dev_train improvement.
        log_every: print a loss summary every ``log_every`` epochs (0 disables it).

    Returns:
        Tuple of per-epoch lists: (train_loss, dev_train_loss, test_loss,
        train_f1, dev_train_f1, test_f1); losses are per image.
    """
    splits = ('train', 'dev_train', 'test1')
    # Normalise by each loader's own dataset size instead of a notebook global
    sizes = {split: len(dataloaders[split].dataset) for split in splits}
    losses = {split: [] for split in splits}
    f1s = {split: [] for split in splits}

    best_dev_loss = float('inf')
    # Initialised so early stopping works even if the dev loss never improves (e.g. NaN)
    last_updated_epoch = 0

    model.to(device)

    for epoch in tqdm(range(1, n_epochs + 1)):
        epoch_loss = {}
        for split in splits:
            # Only the train split updates the weights
            summed_loss, targets, predictions = _run_epoch(
                dataloaders[split], model, criterion, device,
                optimizer if split == 'train' else None)
            epoch_loss[split] = summed_loss / sizes[split]
            losses[split].append(epoch_loss[split])
            f1s[split].append(f1_score(targets, predictions, average='weighted'))

        if log_every and epoch % log_every == 0:
            print('END OF EPOCH: {} \n Training loss per image: {:.6f}\n Training_dev loss per image: {:.6f}\n Test_dev loss per image: {:.6f}'.format(
                epoch, epoch_loss['train'], epoch_loss['dev_train'], epoch_loss['test1']))

        # Save the model when the dev_train loss reaches a new minimum, with its epoch number
        if epoch_loss['dev_train'] < best_dev_loss:
            best_dev_loss = epoch_loss['dev_train']
            last_updated_epoch = epoch
            torch.save(model.state_dict(), os.path.join(save_path, 'optimal_weights'))
            with open(os.path.join(save_path, 'number_epochs.txt'), 'w') as f:
                print('Epoch:', str(epoch), file=f)

        # Early stopping
        if epoch - last_updated_epoch >= patience:
            break

    return (losses['train'], losses['dev_train'], losses['test1'],
            f1s['train'], f1s['dev_train'], f1s['test1'])

In [5]:
# Dictionary with k-fold splits.
# design_kfold.pkl (read by the cross-validation cell) was created by this cell. Creation is
# disabled so that re-running the notebook keeps the existing fold assignment; set
# REBUILD_KFOLD_DESIGN = True to regenerate the file.
REBUILD_KFOLD_DESIGN = False


def build_kfold_design(subjects, n_groups=8, n_groups_chiasm_albinism=9):
    """Split every subject list of ``subjects`` into contiguous, near-equal groups.

    Args:
        subjects: nested dict ``subjects[dataset][label] -> list of subject IDs``.
        n_groups: number of groups for every (dataset, label) list.
        n_groups_chiasm_albinism: number of groups for CHIASM albinism, which supplies
            both a test2 and a test1 group in every fold.

    Returns:
        A new dict with the same keys whose values are lists of groups (via ``splitter``).
        ``subjects`` is left untouched: writing the groups back into subjects_dict would
        break any later cell that iterates over subject IDs.
    """
    return {
        dataset: {
            label: list(splitter(
                subject_ids,
                n_groups_chiasm_albinism if (dataset, label) == ('CHIASM', 'albinism') else n_groups,
            ))
            for label, subject_ids in labels.items()
        }
        for dataset, labels in subjects.items()
    }


if REBUILD_KFOLD_DESIGN:
    with open('design_kfold.pkl', 'wb') as f:
        pickle.dump(build_kfold_design(subjects_dict), f)

"\nfor dataset in subjects_dict.keys():\n    for label in subjects_dict[dataset].keys():\n        if(dataset=='CHIASM' and label=='albinism'):\n            subjects_dict[dataset][label]=list(splitter(subjects_dict[dataset][label],9))\n        else:\n            subjects_dict[dataset][label]=list(splitter(subjects_dict[dataset][label],8))\n            \n# Save the dictionary\nwith open('design_kfold.pkl','wb') as f:\n    pickle.dump(subjects_dict,f)\n"

In [6]:
# Histogram standardization (to mitigate cross-site differences) - shared by all datasets.
# NOTE(review): the landmarks are trained on every subject in subjects_dict, including those
# later assigned to test1/test2; consider fitting them on training subjects only.

# Paths of all chiasm images, in subjects_dict order
chiasm_paths = [
    '../../1_Data/1_Input/' + dataset + '/' + subject + '/chiasm.nii.gz'
    for dataset, labels in subjects_dict.items()
    for subject_ids in labels.values()
    for subject in subject_ids
]

chiasm_landmarks_path = Path('chiasm_landmarks.npy')

chiasm_landmarks = HistogramStandardization.train(chiasm_paths)
# torch.save writes a pickle-based file despite the .npy extension; read it back with torch.load
torch.save(chiasm_landmarks, chiasm_landmarks_path)

landmarks = {'chiasm': chiasm_landmarks}
standardize = HistogramStandardization(landmarks)

100%|████████████████████████████████████████████████████████████████████████| 1740/1740 [00:01<00:00, 1003.91it/s]


In [7]:
# Data preprocessing and augmentation - shared by all datasets

# Deterministic preprocessing applied to every split:
# canonical orientation, histogram standardization, intensities rescaled to [0, 1]
canonical = ToCanonical()
rescale = RescaleIntensity((0, 1))

# Random augmentation, used only for the training split: small affine jitter and random flips
affine = tio.RandomAffine(degrees=5, translation=(2, 2, 2), center='image')
flip = RandomFlip((0, 1, 2), flip_probability=0.5, p=0.5)

# Fixed output size for every split
crop = CropOrPad((24, 24, 8))

# Elastic deformation is not used:
# tio.transforms.RandomElasticDeformation(num_control_points=4, max_displacement=4, locked_borders=1)

shared_preprocessing = [canonical, standardize, rescale]
transform_train = Compose(shared_preprocessing + [affine, flip, crop])
transform_dev = Compose(shared_preprocessing + [crop])

In [8]:
# K-fold cross-validation: for every fold, build the splits and the torchio dataloaders,
# load the pretrained classifier, call freeze_classification() on it and train it.
#
# Splits for fold `iteration` (design_kfold.pkl: CHIASM albinism has 9 groups, every other
# (dataset, label) list has 8):
#   test2     - CHIASM albinism group `iteration` + control group `iteration` of
#               ABIDE, Athletes, HCP, COBRE, Leipzig and MCIC
#   test1     - next remaining group of every label of CHIASM and UoN
#   dev_train - next remaining group of every (dataset, label); wraps around to the first
#               remaining group on the last fold
#   train     - all remaining groups
# In train and dev_train the albinism subjects are repeated to roughly match the controls.

N_FOLDS = 8
INPUT_DIR = '../../1_Data/1_Input'
OUTPUT_ROOT = '../../1_Data/4_K-fold_combined_extraction_learning_rate'
# Weights loaded into the classifier before training, one sub-folder per fold
PRETRAINED_ROOT = '../../1_Data/4_K-fold_combined'
PRETRAINED_LEARNING_RATE = 5e-05
PUBLIC_CONTROL_DATASETS = ['ABIDE', 'Athletes', 'HCP', 'COBRE', 'Leipzig', 'MCIC']
LABEL_CODES = {'control': 0, 'albinism': 1}
HISTORY_NAMES = ['train_loss', 'dev_train_loss', 'test_loss', 'train_f1', 'dev_train_f1', 'test_f1']

groups = ['train', 'dev_train', 'test1', 'test2']


def _split_fold(kfold_design, fold):
    """Assign the groups of ``kfold_design`` to test2/test1/dev_train/train for one fold.

    ``kfold_design[dataset][label]`` is a list of groups of subject IDs. The lists are
    consumed (assigned groups are popped), so pass a freshly loaded design per fold.
    Returns ``{split: {dataset: {label: [subject IDs]}}}``.
    """
    design = {'test2': {}, 'test1': {}, 'dev_train': {}, 'train': {}}

    design['test2']['CHIASM'] = {'albinism': kfold_design['CHIASM']['albinism'].pop(fold), 'control': []}
    for dataset in PUBLIC_CONTROL_DATASETS:
        design['test2'][dataset] = {'control': kfold_design[dataset]['control'].pop(fold)}

    for dataset in ['CHIASM', 'UoN']:
        design['test1'][dataset] = {
            label: remaining.pop(fold) for label, remaining in kfold_design[dataset].items()
        }

    for dataset, labels in kfold_design.items():
        # The modulo wraps to the first remaining group once `fold` runs past the end
        design['dev_train'][dataset] = {
            label: remaining.pop(fold % len(remaining)) for label, remaining in labels.items()
        }

    for dataset, labels in kfold_design.items():
        design['train'][dataset] = {
            label: [subject for group in remaining for subject in group]
            for label, remaining in labels.items()
        }

    return design


def _assert_disjoint(design):
    """Raise AssertionError if a (dataset, subject) pair appears in more than one split."""
    split_of = {}
    for split, split_design in design.items():
        for dataset, labels in split_design.items():
            for subject_ids in labels.values():
                for subject in subject_ids:
                    previous = split_of.setdefault((dataset, subject), split)
                    assert previous == split, '{}/{} is in both {} and {}'.format(dataset, subject, previous, split)


def _make_subjects(group_design, upsample):
    """Create labelled torchio Subjects for one split of a fold design.

    ``group_design`` maps ``dataset -> label -> [subject IDs]``. With ``upsample`` every
    albinism subject is added ``num_control // num_albinism`` times (at least once), so
    both classes are roughly balanced.
    """
    repeats = {label: 1 for label in LABEL_CODES}
    if upsample:
        num_control = sum(len(labels.get('control', [])) for labels in group_design.values())
        num_albinism = sum(len(labels.get('albinism', [])) for labels in group_design.values())
        if num_albinism == 0:
            raise ValueError('Cannot upsample albinism: the split contains no albinism subjects')
        # max(1, ...) keeps albinism subjects even if they outnumber the controls
        repeats['albinism'] = max(1, num_control // num_albinism)

    subjects = []
    for dataset, labels in group_design.items():
        for label, subject_ids in labels.items():
            for _ in range(repeats[label]):
                subjects += [
                    tio.Subject(
                        chiasm=tio.Image(os.path.join(INPUT_DIR, dataset, subject, 'chiasm.nii.gz'), type=tio.INTENSITY),
                        label=LABEL_CODES[label],
                    )
                    for subject in subject_ids
                ]
    return subjects


os.makedirs(OUTPUT_ROOT, exist_ok=True)

# Shared by all folds
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCELoss()

model_parameters = [[1, 2, 2, 1, 256]]
learning_rates = [0.0005]
n_epochs = 8000

for iteration in range(N_FOLDS):

    output_folder = os.path.join(OUTPUT_ROOT, str(iteration))
    os.makedirs(output_folder, exist_ok=True)

    # Reload the design for every fold because _split_fold pops groups out of it
    with open('design_kfold.pkl', 'rb') as f:
        kfold_design = pickle.load(f)

    design_kfold_combined = _split_fold(kfold_design, iteration)
    _assert_disjoint(design_kfold_combined)

    with open(os.path.join(output_folder, 'kfold_design_' + str(iteration) + '.pkl'), 'wb') as f:
        pickle.dump(design_kfold_combined, f)

    print(iteration)

    # torchio subjects per split; albinism is upsampled in train and dev_train only
    dict_kfold_combined_training = {
        group: _make_subjects(group_design, upsample=group in ('train', 'dev_train'))
        for group, group_design in design_kfold_combined.items()
    }

    datasets_list = {
        group: tio.SubjectsDataset(subjects, transform=transform_train if group == 'train' else transform_dev)
        for group, subjects in dict_kfold_combined_training.items()
    }

    # Only the training loader is shuffled; evaluation losses and F1 do not depend on order
    dataloaders_chiasm = {
        group: DataLoader(dataset=datasets_list[group], batch_size=10, shuffle=(group == 'train'), num_workers=8)
        for group in groups
    }

    for parameters in model_parameters:
        run_name = '_'.join(str(parameter) for parameter in parameters)

        for learning_rate in learning_rates:

            classifying_network = Classifier(*parameters)
            pretrained_path = os.path.join(PRETRAINED_ROOT, str(iteration),
                                           run_name + '_' + str(PRETRAINED_LEARNING_RATE), 'optimal_weights')
            # map_location='cpu' allows GPU-saved weights to load on a CPU-only machine;
            # train_network moves the model to `device` afterwards
            classifying_network.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
            classifying_network.freeze_classification()

            # Optimise only the parameters that still require gradients after freezing
            optimizer = torch.optim.Adam(
                params=[p for p in classifying_network.parameters() if p.requires_grad], lr=learning_rate)

            # Ends with a path separator; file names are appended to it
            data_folder = os.path.join(output_folder, run_name + '_' + str(learning_rate), '')
            os.makedirs(data_folder, exist_ok=True)

            # Train & save weights
            histories = train_network(n_epochs, dataloaders_chiasm, classifying_network, optimizer,
                                      criterion, device, data_folder)

            # Save losses and F1 scores
            for name, history in zip(HISTORY_NAMES, histories):
                with open(data_folder + name + '.pkl', 'wb') as f:
                    pickle.dump(history, f)
    

0


  6%|████▌                                                                    | 500/8000 [26:37<6:38:59,  3.19s/it]

END OF EPOCH: 500 
 Training loss per image: 0.011216
 Training_dev loss per image: 0.065965
 Test_dev loss per image: 1.751088


 12%|█████████                                                               | 1000/8000 [52:45<5:55:58,  3.05s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.011320
 Training_dev loss per image: 0.044103
 Test_dev loss per image: 1.160551


 14%|█████████▉                                                              | 1098/8000 [57:47<6:03:18,  3.16s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

1


  6%|████▌                                                                    | 500/8000 [25:33<6:21:55,  3.06s/it]

END OF EPOCH: 500 
 Training loss per image: 0.002533
 Training_dev loss per image: 0.028369
 Test_dev loss per image: 2.116751


 12%|█████████                                                               | 1000/8000 [51:05<5:59:56,  3.09s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.002027
 Training_dev loss per image: 0.038738
 Test_dev loss per image: 3.067864


 19%|█████████████▏                                                        | 1500/8000 [1:16:37<5:31:38,  3.06s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.004184
 Training_dev loss per image: 0.036440
 Test_dev loss per image: 3.039875


 24%|█████████████████▏                                                    | 1959/8000 [1:40:06<5:08:43,  3.07s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

2


  6%|████▌                                                                    | 500/8000 [25:32<6:23:52,  3.07s/it]

END OF EPOCH: 500 
 Training loss per image: 0.045569
 Training_dev loss per image: 0.265317
 Test_dev loss per image: 2.471508


 12%|█████████                                                               | 1000/8000 [51:04<5:57:38,  3.07s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.037037
 Training_dev loss per image: 0.192158
 Test_dev loss per image: 2.615460


 19%|█████████████▏                                                        | 1500/8000 [1:16:39<5:31:58,  3.06s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.033535
 Training_dev loss per image: 0.237462
 Test_dev loss per image: 2.783979


 24%|████████████████▌                                                     | 1895/8000 [1:36:53<5:12:09,  3.07s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

3


  6%|████▌                                                                    | 500/8000 [25:35<6:23:59,  3.07s/it]

END OF EPOCH: 500 
 Training loss per image: 0.002272
 Training_dev loss per image: 0.014593
 Test_dev loss per image: 0.922518


 12%|█████████                                                               | 1000/8000 [51:11<5:56:14,  3.05s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.002389
 Training_dev loss per image: 0.021338
 Test_dev loss per image: 0.784763


 19%|█████████████▏                                                        | 1500/8000 [1:16:46<5:32:24,  3.07s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.001205
 Training_dev loss per image: 0.018415
 Test_dev loss per image: 1.074311


 25%|█████████████████▌                                                    | 2000/8000 [1:42:22<5:07:10,  3.07s/it]

END OF EPOCH: 2000 
 Training loss per image: 0.002099
 Training_dev loss per image: 0.024548
 Test_dev loss per image: 1.088626


 31%|█████████████████████▉                                                | 2500/8000 [2:07:58<4:41:20,  3.07s/it]

END OF EPOCH: 2500 
 Training loss per image: 0.001193
 Training_dev loss per image: 0.021254
 Test_dev loss per image: 0.978087


 33%|██████████████████████▉                                               | 2625/8000 [2:14:25<4:35:14,  3.07s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

4


  6%|████▌                                                                    | 500/8000 [25:32<6:21:18,  3.05s/it]

END OF EPOCH: 500 
 Training loss per image: 0.032740
 Training_dev loss per image: 0.125357
 Test_dev loss per image: 2.646341


 12%|█████████                                                               | 1000/8000 [51:04<5:55:28,  3.05s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.027955
 Training_dev loss per image: 0.158279
 Test_dev loss per image: 2.841789


 19%|█████████████▏                                                        | 1500/8000 [1:16:36<5:32:19,  3.07s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.028300
 Training_dev loss per image: 0.130969
 Test_dev loss per image: 2.874740


 25%|█████████████████▌                                                    | 2000/8000 [1:42:13<5:08:21,  3.08s/it]

END OF EPOCH: 2000 
 Training loss per image: 0.018851
 Training_dev loss per image: 0.099332
 Test_dev loss per image: 2.283489


 28%|███████████████████▊                                                  | 2264/8000 [1:55:50<4:53:30,  3.07s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

5


  6%|████▌                                                                    | 500/8000 [25:45<6:25:34,  3.08s/it]

END OF EPOCH: 500 
 Training loss per image: 0.024899
 Training_dev loss per image: 0.095526
 Test_dev loss per image: 2.932807


 12%|█████████                                                               | 1000/8000 [51:32<6:00:12,  3.09s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.032119
 Training_dev loss per image: 0.039189
 Test_dev loss per image: 1.664609


 15%|██████████▌                                                           | 1204/8000 [1:02:06<5:50:33,  3.10s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

6


  6%|████▌                                                                    | 500/8000 [25:48<6:27:14,  3.10s/it]

END OF EPOCH: 500 
 Training loss per image: 0.005129
 Training_dev loss per image: 0.000283
 Test_dev loss per image: 1.387370


 12%|█████████                                                               | 1000/8000 [51:37<6:00:22,  3.09s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.009497
 Training_dev loss per image: 0.000971
 Test_dev loss per image: 1.013940


 19%|█████████████▏                                                        | 1500/8000 [1:17:27<5:35:04,  3.09s/it]

END OF EPOCH: 1500 
 Training loss per image: 0.005320
 Training_dev loss per image: 0.001847
 Test_dev loss per image: 0.885683


 25%|█████████████████▌                                                    | 2000/8000 [1:43:16<5:08:55,  3.09s/it]

END OF EPOCH: 2000 
 Training loss per image: 0.001372
 Training_dev loss per image: 0.000124
 Test_dev loss per image: 0.641684


 31%|█████████████████████▉                                                | 2500/8000 [2:09:10<4:47:28,  3.14s/it]

END OF EPOCH: 2500 
 Training loss per image: 0.003129
 Training_dev loss per image: 0.000105
 Test_dev loss per image: 0.971308


 32%|██████████████████████▋                                               | 2599/8000 [2:14:26<4:39:22,  3.10s/it]
  0%|                                                                                     | 0/8000 [00:00<?, ?it/s]

7


  6%|████▌                                                                    | 500/8000 [25:48<6:28:39,  3.11s/it]

END OF EPOCH: 500 
 Training loss per image: 0.005498
 Training_dev loss per image: 0.064644
 Test_dev loss per image: 0.065955


 12%|█████████                                                               | 1000/8000 [51:38<6:02:04,  3.10s/it]

END OF EPOCH: 1000 
 Training loss per image: 0.005475
 Training_dev loss per image: 0.296456
 Test_dev loss per image: 0.042625


 13%|█████████                                                               | 1001/8000 [51:44<6:01:49,  3.10s/it]
