# Dataloaders

In [1]:
# Importing libraries
import torchio as tio
import glob
import numpy as np
import random
import os

from collections import OrderedDict
from pathlib import Path

from tqdm import tqdm
import time

import torchio as tio
from torchio.transforms import (RescaleIntensity,RandomFlip,Compose, HistogramStandardization, RandomAffine, RandomNoise, ToCanonical)

from torch.utils.data import DataLoader
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

import pickle

import copy

from Networks_Training import UNet_1_layer, UNet_2_layer

In [2]:
# List IDs of all participants with albinism
ids_albinism=['Nyst01','Nyst02','Nyst03','Nyst04','Nyst05','Nyst06','Nyst07','Nyst08','Nyst09',
              'Nyst10','Nyst11','Nyst12','Nyst13','Nyst16','Nyst20','Nyst21','Nyst24', 'Nyst25',
              'Nyst31','Nyst35','Nyst37','Nyst43','Nyst45','ALB1','ALB2','ALB3','ALB4',
              'ALB5','ALB6','ALB7','ALB8','ALB9'] 

In [3]:
# One-time creation of dictionary with listed ids of participants. It's fixed througout the experiment to better control for data poisoning

datasets=['ABIDE','Athletes','HCP','COBRE','Leipzig','UoN','CHIASM','MCIC']

subjects_dict={}

for dataset in datasets:
    
    if dataset in ['ABIDE','Athletes','HCP','COBRE','Leipzig','MCIC']:
       
        ids=[path.split('/')[-2] for path in glob.glob('../../1_Data/1_Input/'+dataset+'/*/mask_optic_chiasm.nii.gz')]
        random.shuffle(ids)
    
        subjects_dict[dataset]={}
        subjects_dict[dataset]['control']=ids
    
    if dataset in ['CHIASM','UoN']:
        
        ids_con = ids=[path.split('/')[-2] for path in glob.glob('../../1_Data/1_Input/'+dataset+'/*/mask_optic_chiasm.nii.gz') if path.split('/')[-2] not in ids_albinism]
        ids_alb = ids=[path.split('/')[-2] for path in glob.glob('../../1_Data/1_Input/'+dataset+'/*/mask_optic_chiasm.nii.gz') if path.split('/')[-2] in ids_albinism]
        
        random.shuffle(ids_con)
        random.shuffle(ids_alb)
        
        subjects_dict[dataset]={}
        subjects_dict[dataset]['control']=ids_con
        subjects_dict[dataset]['albinism']=ids_alb
    
#print(subjects_dict)

# Save the dictionary storing all the ids in fixed (beforehand randomized) order
#with open('../subjects_dict.pkl', 'wb') as f:
#    pickle.dump(subjects_dict, f)


In [3]:
# Load the general dictionary 
with open('../subjects_dict.pkl', 'rb') as f:
    subjects_dict = pickle.load(f)

In [6]:
for dataset in subjects_dict.keys():
    for group in subjects_dict[dataset].keys():
        print(dataset,group,len(subjects_dict[dataset][group]))

ABIDE control 355
Athletes control 42
HCP control 1065
COBRE control 60
Leipzig control 133
UoN control 20
UoN albinism 23
CHIASM control 10
CHIASM albinism 9
MCIC control 25


In [8]:
subjects_dict['CHIASM']['albinism']

['ALB2', 'ALB6', 'ALB1', 'ALB4', 'ALB3', 'ALB7', 'ALB9', 'ALB8', 'ALB5']

In [9]:
#total=0
#
#for dataset in subjects_dict.keys():
#    for label in subjects_dict[dataset].keys():
#        print(dataset,label,len(subjects_dict[dataset][label]))
#        
#        total+=len(subjects_dict[dataset][label])
#        
#print('total',total)

In [10]:
# Assign the participants to train/dev_train/dev_test/test groups for the purpose of AE training

split_training=[0.0,0.8,0.9,1.0,1.0]
split_testing=[0.0,0.0,0.0,0.15,1.0]

groups=['train','dev_train', 'dev_test', 'test']

design_ae_training={}

for i in range(len(groups)):
    
    if i==3:
        continue
    else:
        design_ae_training[groups[i]]={}

        for dataset in ['ABIDE','Athletes','HCP','COBRE','Leipzig','MCIC']:

            design_ae_training[groups[i]][dataset]={}

            number_participants = len(subjects_dict[dataset]['control'])
            start = np.int(np.floor(number_participants*split_training[i]))
            end = np.int(np.floor(number_participants*split_training[i+1]))

            #print(groups[i],dataset,len(subjects_dict[dataset]['control'][start:end]))
            design_ae_training[groups[i]][dataset]['control']=subjects_dict[dataset]['control'][start:end]
                       
# Do the same for test
for i in [2,3]:
    
    if i==3:
        design_ae_training[groups[i]]={}
    
    for dataset in ['CHIASM','UoN']:
            
        design_ae_training[groups[i]][dataset]={}
        
        for label in ['control','albinism']:
            
            design_ae_training[groups[i]][dataset][label]={}

            number_participants = len(subjects_dict[dataset][label])
            start = np.int(np.floor(number_participants*split_testing[i]))
            end = np.int(np.floor(number_participants*split_testing[i+1]))

            #print(groups[i],dataset,label,len(subjects_dict[dataset][label][start:end]))
            design_ae_training[groups[i]][dataset][label]=subjects_dict[dataset][label][start:end]
            
# Check the number of participants in each group

#for a in design_ae_training.keys():
#    print('\n')
#    for b in design_ae_training[a].keys():
#        for c in design_ae_training[a][b].keys():
#            print(a,b,c, len(design_ae_training[a][b][c]))

# Save
#with open('design_ae_training.pkl', 'wb') as f:
#    pickle.dump(design_ae_training, f)


In [2]:
# Load the dictionary for AE training 
with open('design_ae_training.pkl', 'rb') as f:
    design_ae_training = pickle.load(f)

In [3]:
for a in design_ae_training.keys():
    print('\n')
    for b in design_ae_training[a].keys():
        for c in design_ae_training[a][b].keys():
            print(a,b,c, len(design_ae_training[a][b][c]))



train ABIDE control 284
train Athletes control 33
train HCP control 852
train COBRE control 48
train Leipzig control 106
train MCIC control 20


dev_train ABIDE control 35
dev_train Athletes control 4
dev_train HCP control 106
dev_train COBRE control 6
dev_train Leipzig control 13
dev_train MCIC control 2


dev_test ABIDE control 36
dev_test Athletes control 5
dev_test HCP control 107
dev_test COBRE control 6
dev_test Leipzig control 14
dev_test MCIC control 3
dev_test CHIASM control 1
dev_test CHIASM albinism 1
dev_test UoN control 3
dev_test UoN albinism 3


test CHIASM control 9
test CHIASM albinism 8
test UoN control 17
test UoN albinism 20


In [4]:
design_ae_training

{'train': {'ABIDE': {'control': ['A00032339',
    'A00032673',
    'A00032696',
    'A00032633',
    'A00032348',
    'A00032743',
    'A00032611',
    'A00032358',
    'A00033283',
    'A00032540',
    'A00032781',
    'A00033275',
    'A00032638',
    'A00032554',
    'A00032634',
    'A00032808',
    'A00032701',
    'A00032372',
    'A00032594',
    'A00032541',
    'A00032530',
    'A00032658',
    'A00032601',
    'A00032786',
    'A00032799',
    'A00033259',
    'A00032703',
    'A00032618',
    'A00032723',
    'A00033263',
    'A00032735',
    'A00032539',
    'A00032403',
    'A00032748',
    'A00032368',
    'A00032382',
    'A00032367',
    'A00032556',
    'A00032797',
    'A00032636',
    'A00032809',
    'A00032815',
    'A00032599',
    'A00032387',
    'A00032295',
    'A00032779',
    'A00033277',
    'A00032798',
    'A00033281',
    'A00032775',
    'A00032774',
    'A00032379',
    'A00032374',
    'A00032814',
    'A00032583',
    'A00032794',
    'A00032376',
  

In [12]:
# Create dictionary with data required for creation of dataset and dataloader
dict_ae_training=copy.deepcopy(design_ae_training)

'''
# train and dev_train (combined, randomized data)
for group in ['train','dev_train']:
    
    all_subjects=[]
    
    for dataset in design_ae_training[group].keys():
        for label in design_ae_training[group][dataset].keys():
                        
            all_subjects+=[tio.Subject(t1=tio.Image('../../1_Data/1_Input/'+dataset+'/'+subject+'/t1w_1mm_iso_brain.nii.gz', type = tio.INTENSITY),
                                        probs=tio.Image('../../1_Data/1_Input/'+dataset+'/'+subject+'/sampling_distribution.nii.gz', type = tio.INTENSITY)) 
                                        for subject in design_ae_training[group][dataset][label]]

    dict_ae_training[group]=all_subjects
'''           
for group in design_ae_training.keys():
    for dataset in design_ae_training[group].keys():
        for label in design_ae_training[group][dataset].keys():
                        
            dict_ae_training[group][dataset][label]=[tio.Subject(t1=tio.Image('../../1_Data/1_Input/'+dataset+'/'+subject+'/t1w_1mm_iso_brain.nii.gz', type = tio.INTENSITY),
                                                              probs=tio.Image('../../1_Data/1_Input/'+dataset+'/'+subject+'/sampling_distribution.nii.gz', type = tio.INTENSITY)) 
                                                  for subject in design_ae_training[group][dataset][label]]

In [13]:
# Histogram standardization (to mitigate cross-site differences)
# Standardization is performed on all datasets

# Save paths of all images
images_paths=[]
probs_paths=[]

for group in design_ae_training.keys():
    for dataset in design_ae_training[group].keys():
        for label in design_ae_training[group][dataset].keys():
            for subject in design_ae_training[group][dataset][label]:
                images_paths.append('../../1_Data/1_Input/'+dataset+'/'+subject+'/t1w_1mm_iso_brain.nii.gz')
                probs_paths.append('../../1_Data/1_Input/'+dataset+'/'+subject+'/sampling_distribution.nii.gz')

images_landmarks_paths = Path('images_landmarks.npy') 
probs_landmarks_paths = Path('probs_landmarks.npy') 

images_landmarks = HistogramStandardization.train(images_paths)
probs_landmarks = HistogramStandardization.train(probs_paths)

torch.save(images_landmarks, images_landmarks_paths)
torch.save(probs_landmarks, probs_landmarks_paths)

landmarks={'t1': images_landmarks,
          'probs': probs_landmarks}

standardize = HistogramStandardization(landmarks)

100%|██████████| 1742/1742 [09:44<00:00,  2.98it/s]
100%|██████████| 1742/1742 [06:52<00:00,  4.22it/s]


In [14]:
# Transforms

# Rescale
rescale = RescaleIntensity((0,1))

# Flip
flip = RandomFlip((0,1,2), flip_probability=0.5, p=0.25)

# Composing transforms 
transform_train = Compose([standardize, rescale, flip]) 
transform_dev = Compose([standardize, rescale]) 

In [15]:
# Create Torchio dataset

dataset_ae_training = {}


# train
#dataset_ae_training['train']=tio.SubjectsDataset(dict_ae_training['train'], transform=transform_train)
          
# dev_train
#dataset_ae_training['dev_train']=tio.SubjectsDataset(dict_ae_training['dev_train'], transform=transform_dev)
    
# dev_test and test
for group in dict_ae_training.keys():
    
    dataset_ae_training[group]={}
    
    for dataset in dict_ae_training[group].keys():
        
        dataset_ae_training[group][dataset]={}
            
        for label in dict_ae_training[group][dataset].keys():
            
            dataset_ae_training[group][dataset][label]=tio.SubjectsDataset(dict_ae_training[group][dataset][label], transform=transform_dev)
            

In [16]:
# Sampler
patch_size = (24,24,8)
queue_length = 200
samples_per_volume = 5

sampler = tio.data.WeightedSampler(patch_size,'probs')

In [17]:
# Concatenate train and dev_train datasets (dev_test and test remain as they are)
concatenated_datasets={}

for group in ['train','dev_train']:

    entry=[]
    
    for dataset in dataset_ae_training[group].keys():
    
        for labels in dataset_ae_training[group][dataset].keys():
            
            entry.append(dataset_ae_training[group][dataset][labels])
   
    #print(entry)
    concatenated_datasets[group]=torch.utils.data.ConcatDataset(entry)

In [18]:
# Define dataloader

dataloader = {}

# train & dev_train
dataloader['train']= DataLoader(tio.Queue(concatenated_datasets['train'], queue_length, samples_per_volume, sampler, num_workers=6, shuffle_subjects=True, shuffle_patches=True), batch_size=20, num_workers=0)
dataloader['dev_train']= DataLoader(tio.Queue(concatenated_datasets['dev_train'], queue_length, samples_per_volume, sampler, num_workers=6, shuffle_subjects=True, shuffle_patches=True), batch_size=20, num_workers=0)

# dev_test and test
'''
for group in ['dev_test','test']:
    dataloader[group]={}
    for dataset in dataset_ae_training[group].keys():
        dataloader[group][dataset]={}
        for label in dataset_ae_training[group][dataset].keys():
            dataloader[group][dataset][label]=DataLoader(tio.Queue(dataset_ae_training[group][dataset][label], queue_length, samples_per_volume, sampler, num_workers=6, shuffle_subjects=True, shuffle_patches=True), batch_size = 25, num_workers=0)
'''

"\nfor group in ['dev_test','test']:\n    dataloader[group]={}\n    for dataset in dataset_ae_training[group].keys():\n        dataloader[group][dataset]={}\n        for label in dataset_ae_training[group][dataset].keys():\n            dataloader[group][dataset][label]=DataLoader(tio.Queue(dataset_ae_training[group][dataset][label], queue_length, samples_per_volume, sampler, num_workers=6, shuffle_subjects=True, shuffle_patches=True), batch_size = 25, num_workers=0)\n"

In [19]:
# Testing
#model = torch.nn.Identity()

#for patches_batch in dataloader['dev_test']['MCIC']['control']:
    #print(patches_batch)
#    inputs = patches_batch['t1'][tio.DATA]  # key 't1' is in subject
#    targets = patches_batch['t1'][tio.DATA]  # key 'brain' is in subject
#    logits = model(inputs)  # model being an instance of torch.nn.Module

In [20]:
#inputs.shape

#fig = plt.figure(figsize=(20, 10))

#for i in range(inputs.shape[0]):
#    plt.subplot(5,8,i+1)
#    plt.imshow(inputs[i,0,:,:,5],cmap='gray');
    
#plt.show()

# Network and parameters

In [21]:
# Try setting CUDA if possible
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu") 
    
print(device)

cuda


In [22]:
#print(sum(p.numel() for p in model.parameters() if p.requires_grad))

In [26]:
# Criterion
#criterion = DiceLoss()
criterion = nn.MSELoss()

# Number of epochs
n_epochs = 50

# Training

In [27]:
# Function returning trained model
def train_network(n_epochs, dataloaders, model, optimizer, criterion, device, save_path):
    
    track_train_loss = []
    track_dev_train_loss = []
    
    valid_loss_min = np.Inf
    
    model.to(device)
        
    for epoch in tqdm(range(1, n_epochs+1)):
        
        # Initialize loss monitoring variables
        train_loss = 0.0
        dev_train_loss = 0.0
        
        i=0
        j=0
        
        # train
        model.train()
        
        for batch in dataloaders['train']:
            
            data = batch['t1']['data'].to(device)
            
            optimizer.zero_grad()
            
            output = model(data)
            loss = criterion(output, data)
            loss.backward()
            
            optimizer.step()
            
            train_loss += loss.item()
            i+=1
            
        track_train_loss.append(train_loss/i)
        
        # dev_train
        model.eval()
        
        for batch in dataloaders['dev_train']:
            
            data = batch['t1']['data'].to(device)
            
            with torch.no_grad():
                
                output = model(data)
                loss = criterion(output,data)
                
                dev_train_loss += loss.item()
                j+=1
                
        track_dev_train_loss.append(dev_train_loss/j)

        print('END OF EPOCH: {} \tTraining loss per batch: {:.6f}\tTraining_dev loss per image: {:.6f}'.format(epoch, train_loss/i, dev_train_loss/j))
           
        ## Save the model if reached min validation loss
        if dev_train_loss  < valid_loss_min:
            valid_loss_min = dev_train_loss
            torch.save(model.state_dict(),save_path+'optimal_weights')
            last_updated_epoch = epoch
            
            with open(save_path+'number_epochs.txt','w') as f:
                print('Epoch:', str(epoch), file=f)
                
        # Early stopping
        if (epoch - last_updated_epoch) == 5:
            break
                        
    # return trained model
    return track_train_loss, track_dev_train_loss         

In [28]:
# training for 1-layer network

model_parameters=[[4,1],[2,2]]

folder='../../1_Data/2_Trained_AE/'

for parameters in model_parameters:
    
    print(parameters)
        
    # Initialize the proper model
    unet = UNet_1_layer(1,1,parameters[0],parameters[1])
    
    # Optimizer
    optimizer = torch.optim.Adam(params=unet.parameters(), lr=0.001)
    
    # Create output folder
    data_folder = folder+'/1_layer_'+str(parameters[0])+'_'+str(parameters[1])+'/'
    os.makedirs(data_folder, exist_ok=True)
    
    # Train & save weights
    train_loss, dev_train_loss = train_network(n_epochs, dataloader, unet, optimizer, criterion, device, data_folder)
    
    # Save losses
    with open(data_folder+'train_loss.pkl', 'wb') as f:
        pickle.dump(train_loss, f)
        
    with open(data_folder+'dev_train_loss.pkl', 'wb') as f:
        pickle.dump(dev_train_loss, f)


[4, 1]


  2%|▏         | 1/50 [07:58<6:30:28, 478.14s/it]

END OF EPOCH: 1 	Training loss per batch: 0.038422	Training_dev loss per image: 0.006503


  4%|▍         | 2/50 [15:39<6:18:22, 472.97s/it]

END OF EPOCH: 2 	Training loss per batch: 0.006484	Training_dev loss per image: 0.004503


  6%|▌         | 3/50 [23:22<6:08:20, 470.21s/it]

END OF EPOCH: 3 	Training loss per batch: 0.005105	Training_dev loss per image: 0.003910


  8%|▊         | 4/50 [31:08<5:59:24, 468.79s/it]

END OF EPOCH: 4 	Training loss per batch: 0.004654	Training_dev loss per image: 0.003374


 10%|█         | 5/50 [38:50<5:50:03, 466.74s/it]

END OF EPOCH: 5 	Training loss per batch: 0.004197	Training_dev loss per image: 0.002973


 12%|█▏        | 6/50 [46:52<5:45:44, 471.47s/it]

END OF EPOCH: 6 	Training loss per batch: 0.004109	Training_dev loss per image: 0.003070


 14%|█▍        | 7/50 [55:19<5:45:25, 482.00s/it]

END OF EPOCH: 7 	Training loss per batch: 0.003854	Training_dev loss per image: 0.003616


 16%|█▌        | 8/50 [1:02:58<5:32:35, 475.14s/it]

END OF EPOCH: 8 	Training loss per batch: 0.003889	Training_dev loss per image: 0.003260


 18%|█▊        | 9/50 [1:10:49<5:23:49, 473.89s/it]

END OF EPOCH: 9 	Training loss per batch: 0.003795	Training_dev loss per image: 0.002787


 20%|██        | 10/50 [1:18:41<5:15:28, 473.21s/it]

END OF EPOCH: 10 	Training loss per batch: 0.003894	Training_dev loss per image: 0.003921


 22%|██▏       | 11/50 [1:26:38<5:08:29, 474.60s/it]

END OF EPOCH: 11 	Training loss per batch: 0.003791	Training_dev loss per image: 0.002502


 24%|██▍       | 12/50 [1:34:23<4:58:39, 471.57s/it]

END OF EPOCH: 12 	Training loss per batch: 0.003565	Training_dev loss per image: 0.003515


 26%|██▌       | 13/50 [1:42:15<4:50:53, 471.72s/it]

END OF EPOCH: 13 	Training loss per batch: 0.003503	Training_dev loss per image: 0.002342


 28%|██▊       | 14/50 [1:50:27<4:46:39, 477.76s/it]

END OF EPOCH: 14 	Training loss per batch: 0.003566	Training_dev loss per image: 0.003747


 30%|███       | 15/50 [1:58:10<4:36:09, 473.41s/it]

END OF EPOCH: 15 	Training loss per batch: 0.003471	Training_dev loss per image: 0.002422


 32%|███▏      | 16/50 [2:06:10<4:29:17, 475.22s/it]

END OF EPOCH: 16 	Training loss per batch: 0.003439	Training_dev loss per image: 0.002778


 34%|███▍      | 17/50 [2:14:02<4:20:58, 474.51s/it]

END OF EPOCH: 17 	Training loss per batch: 0.003464	Training_dev loss per image: 0.002866


 36%|███▌      | 18/50 [2:22:09<4:14:58, 478.06s/it]

END OF EPOCH: 18 	Training loss per batch: 0.003450	Training_dev loss per image: 0.002227


 38%|███▊      | 19/50 [2:29:56<4:05:15, 474.69s/it]

END OF EPOCH: 19 	Training loss per batch: 0.003428	Training_dev loss per image: 0.002248


 40%|████      | 20/50 [2:37:50<3:57:19, 474.64s/it]

END OF EPOCH: 20 	Training loss per batch: 0.003188	Training_dev loss per image: 0.002287


 42%|████▏     | 21/50 [2:46:02<3:51:54, 479.82s/it]

END OF EPOCH: 21 	Training loss per batch: 0.003418	Training_dev loss per image: 0.002314


 44%|████▍     | 22/50 [2:53:44<3:41:27, 474.55s/it]

END OF EPOCH: 22 	Training loss per batch: 0.003148	Training_dev loss per image: 0.002883


 46%|████▌     | 23/50 [3:01:40<3:33:38, 474.77s/it]

END OF EPOCH: 23 	Training loss per batch: 0.003043	Training_dev loss per image: 0.002141


 48%|████▊     | 24/50 [3:09:22<3:24:06, 471.02s/it]

END OF EPOCH: 24 	Training loss per batch: 0.003181	Training_dev loss per image: 0.002536


 50%|█████     | 25/50 [3:17:20<3:17:07, 473.08s/it]

END OF EPOCH: 25 	Training loss per batch: 0.003165	Training_dev loss per image: 0.001952


 52%|█████▏    | 26/50 [3:25:05<3:08:21, 470.89s/it]

END OF EPOCH: 26 	Training loss per batch: 0.002984	Training_dev loss per image: 0.002425


 54%|█████▍    | 27/50 [3:33:22<3:03:30, 478.73s/it]

END OF EPOCH: 27 	Training loss per batch: 0.003039	Training_dev loss per image: 0.002425


 56%|█████▌    | 28/50 [3:41:19<2:55:15, 477.99s/it]

END OF EPOCH: 28 	Training loss per batch: 0.002976	Training_dev loss per image: 0.002053


 58%|█████▊    | 29/50 [3:48:59<2:45:28, 472.81s/it]

END OF EPOCH: 29 	Training loss per batch: 0.003106	Training_dev loss per image: 0.002112


 58%|█████▊    | 29/50 [3:56:51<2:51:30, 490.04s/it]
  0%|          | 0/50 [00:00<?, ?it/s]

END OF EPOCH: 30 	Training loss per batch: 0.003234	Training_dev loss per image: 0.002024
[2, 2]


  2%|▏         | 1/50 [07:30<6:08:16, 450.95s/it]

END OF EPOCH: 1 	Training loss per batch: 0.034390	Training_dev loss per image: 0.013824


  4%|▍         | 2/50 [15:13<6:03:32, 454.42s/it]

END OF EPOCH: 2 	Training loss per batch: 0.009287	Training_dev loss per image: 0.004303


  6%|▌         | 3/50 [22:42<5:54:44, 452.86s/it]

END OF EPOCH: 3 	Training loss per batch: 0.005186	Training_dev loss per image: 0.003077


  8%|▊         | 4/50 [30:34<5:51:33, 458.56s/it]

END OF EPOCH: 4 	Training loss per batch: 0.004202	Training_dev loss per image: 0.003128


 10%|█         | 5/50 [38:15<5:44:30, 459.34s/it]

END OF EPOCH: 5 	Training loss per batch: 0.004055	Training_dev loss per image: 0.002799


 12%|█▏        | 6/50 [45:44<5:34:28, 456.11s/it]

END OF EPOCH: 6 	Training loss per batch: 0.003820	Training_dev loss per image: 0.004038


 14%|█▍        | 7/50 [53:23<5:27:33, 457.07s/it]

END OF EPOCH: 7 	Training loss per batch: 0.003829	Training_dev loss per image: 0.002892


 16%|█▌        | 8/50 [1:00:51<5:18:06, 454.43s/it]

END OF EPOCH: 8 	Training loss per batch: 0.003557	Training_dev loss per image: 0.002401


 18%|█▊        | 9/50 [1:08:30<5:11:26, 455.76s/it]

END OF EPOCH: 9 	Training loss per batch: 0.003595	Training_dev loss per image: 0.002753


 20%|██        | 10/50 [1:15:53<5:01:14, 451.86s/it]

END OF EPOCH: 10 	Training loss per batch: 0.003568	Training_dev loss per image: 0.002254


 22%|██▏       | 11/50 [1:23:50<4:58:37, 459.41s/it]

END OF EPOCH: 11 	Training loss per batch: 0.003499	Training_dev loss per image: 0.002693


 24%|██▍       | 12/50 [1:31:31<4:51:14, 459.84s/it]

END OF EPOCH: 12 	Training loss per batch: 0.003374	Training_dev loss per image: 0.002421


 26%|██▌       | 13/50 [1:38:57<4:41:00, 455.70s/it]

END OF EPOCH: 13 	Training loss per batch: 0.003378	Training_dev loss per image: 0.002136


 28%|██▊       | 14/50 [1:46:37<4:34:13, 457.03s/it]

END OF EPOCH: 14 	Training loss per batch: 0.003243	Training_dev loss per image: 0.002252


 30%|███       | 15/50 [1:54:05<4:25:04, 454.42s/it]

END OF EPOCH: 15 	Training loss per batch: 0.003381	Training_dev loss per image: 0.002239


 32%|███▏      | 16/50 [2:01:44<4:18:13, 455.71s/it]

END OF EPOCH: 16 	Training loss per batch: 0.003089	Training_dev loss per image: 0.002205


 34%|███▍      | 17/50 [2:09:36<4:13:18, 460.58s/it]

END OF EPOCH: 17 	Training loss per batch: 0.003112	Training_dev loss per image: 0.001919


 36%|███▌      | 18/50 [2:17:02<4:03:20, 456.27s/it]

END OF EPOCH: 18 	Training loss per batch: 0.003173	Training_dev loss per image: 0.001922


 38%|███▊      | 19/50 [2:24:42<3:56:17, 457.34s/it]

END OF EPOCH: 19 	Training loss per batch: 0.003013	Training_dev loss per image: 0.002114


 40%|████      | 20/50 [2:32:10<3:47:11, 454.38s/it]

END OF EPOCH: 20 	Training loss per batch: 0.003063	Training_dev loss per image: 0.002101


 42%|████▏     | 21/50 [2:39:50<3:40:25, 456.05s/it]

END OF EPOCH: 21 	Training loss per batch: 0.002949	Training_dev loss per image: 0.001955


 42%|████▏     | 21/50 [2:47:16<3:51:00, 477.94s/it]

END OF EPOCH: 22 	Training loss per batch: 0.003013	Training_dev loss per image: 0.001946





In [29]:
# training for 2-layer network
model_parameters=[[8,2],[32,1]]

folder='../../1_Data/2_Trained_AE/'

for parameters in model_parameters:
    
    print(parameters)
        
    # Initialize the proper model
    unet = UNet_2_layer(1,1,parameters[0],parameters[1])
    
    # Optimizer
    optimizer = torch.optim.Adam(params=unet.parameters(), lr=0.0025)
    
    # Create output folder
    data_folder = folder+'/2_layer_'+str(parameters[0])+'_'+str(parameters[1])+'/'
    os.makedirs(data_folder, exist_ok=True)
    
    # Train & save weights
    train_loss, dev_train_loss = train_network(n_epochs, dataloader, unet, optimizer, criterion, device, data_folder)
    
    # Save losses
    with open(data_folder+'train_loss.pkl', 'wb') as f:
        pickle.dump(train_loss, f)
        
    with open(data_folder+'dev_train_loss.pkl', 'wb') as f:
        pickle.dump(dev_train_loss, f)

  0%|          | 0/50 [00:00<?, ?it/s]

[8, 2]


  2%|▏         | 1/50 [07:41<6:16:31, 461.05s/it]

END OF EPOCH: 1 	Training loss per batch: 0.012957	Training_dev loss per image: 0.003832


  4%|▍         | 2/50 [15:31<6:11:10, 463.97s/it]

END OF EPOCH: 2 	Training loss per batch: 0.005026	Training_dev loss per image: 0.003958


  6%|▌         | 3/50 [22:58<5:59:28, 458.91s/it]

END OF EPOCH: 3 	Training loss per batch: 0.004664	Training_dev loss per image: 0.003663


  8%|▊         | 4/50 [30:37<5:51:46, 458.83s/it]

END OF EPOCH: 4 	Training loss per batch: 0.004118	Training_dev loss per image: 0.004219


 10%|█         | 5/50 [38:04<5:41:31, 455.36s/it]

END OF EPOCH: 5 	Training loss per batch: 0.004112	Training_dev loss per image: 0.002975


 12%|█▏        | 6/50 [45:44<5:34:51, 456.63s/it]

END OF EPOCH: 6 	Training loss per batch: 0.003918	Training_dev loss per image: 0.002639


 14%|█▍        | 7/50 [53:12<5:25:20, 453.96s/it]

END OF EPOCH: 7 	Training loss per batch: 0.003662	Training_dev loss per image: 0.002798


 16%|█▌        | 8/50 [1:00:52<5:19:06, 455.86s/it]

END OF EPOCH: 8 	Training loss per batch: 0.003820	Training_dev loss per image: 0.003276


 18%|█▊        | 9/50 [1:08:47<5:15:24, 461.56s/it]

END OF EPOCH: 9 	Training loss per batch: 0.003617	Training_dev loss per image: 0.002903


 20%|██        | 10/50 [1:16:15<5:05:00, 457.52s/it]

END OF EPOCH: 10 	Training loss per batch: 0.003512	Training_dev loss per image: 0.003181


 22%|██▏       | 11/50 [1:23:53<4:57:26, 457.61s/it]

END OF EPOCH: 11 	Training loss per batch: 0.003486	Training_dev loss per image: 0.002473


 24%|██▍       | 12/50 [1:31:24<4:48:32, 455.59s/it]

END OF EPOCH: 12 	Training loss per batch: 0.003439	Training_dev loss per image: 0.003348


 26%|██▌       | 13/50 [1:39:02<4:41:32, 456.55s/it]

END OF EPOCH: 13 	Training loss per batch: 0.003327	Training_dev loss per image: 0.004369


 28%|██▊       | 14/50 [1:46:32<4:32:38, 454.40s/it]

END OF EPOCH: 14 	Training loss per batch: 0.003474	Training_dev loss per image: 0.002614


 30%|███       | 15/50 [1:54:25<4:28:17, 459.93s/it]

END OF EPOCH: 15 	Training loss per batch: 0.003122	Training_dev loss per image: 0.002044


 32%|███▏      | 16/50 [2:02:06<4:20:57, 460.50s/it]

END OF EPOCH: 16 	Training loss per batch: 0.003392	Training_dev loss per image: 0.002617


 34%|███▍      | 17/50 [2:09:34<4:11:06, 456.56s/it]

END OF EPOCH: 17 	Training loss per batch: 0.003339	Training_dev loss per image: 0.002213


 36%|███▌      | 18/50 [2:17:12<4:03:47, 457.11s/it]

END OF EPOCH: 18 	Training loss per batch: 0.003184	Training_dev loss per image: 0.002260


 38%|███▊      | 19/50 [2:24:41<3:54:49, 454.50s/it]

END OF EPOCH: 19 	Training loss per batch: 0.003224	Training_dev loss per image: 0.002369


 38%|███▊      | 19/50 [2:32:20<4:08:32, 481.06s/it]
  0%|          | 0/50 [00:00<?, ?it/s]

END OF EPOCH: 20 	Training loss per batch: 0.003229	Training_dev loss per image: 0.002161
[32, 1]


  2%|▏         | 1/50 [07:27<6:05:35, 447.67s/it]

END OF EPOCH: 1 	Training loss per batch: 0.010321	Training_dev loss per image: 0.005772


  4%|▍         | 2/50 [15:22<6:04:35, 455.73s/it]

END OF EPOCH: 2 	Training loss per batch: 0.005212	Training_dev loss per image: 0.005696


  6%|▌         | 3/50 [23:00<5:57:33, 456.45s/it]

END OF EPOCH: 3 	Training loss per batch: 0.005092	Training_dev loss per image: 0.005019


  8%|▊         | 4/50 [30:26<5:47:38, 453.45s/it]

END OF EPOCH: 4 	Training loss per batch: 0.004372	Training_dev loss per image: 0.008498


 10%|█         | 5/50 [38:05<5:41:22, 455.16s/it]

END OF EPOCH: 5 	Training loss per batch: 0.004153	Training_dev loss per image: 0.003603


 12%|█▏        | 6/50 [45:34<5:32:16, 453.10s/it]

END OF EPOCH: 6 	Training loss per batch: 0.003990	Training_dev loss per image: 0.003001


 14%|█▍        | 7/50 [53:12<5:25:55, 454.77s/it]

END OF EPOCH: 7 	Training loss per batch: 0.003755	Training_dev loss per image: 0.005343


 16%|█▌        | 8/50 [1:00:36<5:15:55, 451.33s/it]

END OF EPOCH: 8 	Training loss per batch: 0.003795	Training_dev loss per image: 0.003454


 18%|█▊        | 9/50 [1:08:33<5:13:44, 459.14s/it]

END OF EPOCH: 9 	Training loss per batch: 0.003639	Training_dev loss per image: 0.002609


 20%|██        | 10/50 [1:16:13<5:06:16, 459.41s/it]

END OF EPOCH: 10 	Training loss per batch: 0.003466	Training_dev loss per image: 0.002531


 22%|██▏       | 11/50 [1:23:39<4:56:03, 455.48s/it]

END OF EPOCH: 11 	Training loss per batch: 0.003382	Training_dev loss per image: 0.002417


 24%|██▍       | 12/50 [1:31:20<4:49:26, 457.02s/it]

END OF EPOCH: 12 	Training loss per batch: 0.003408	Training_dev loss per image: 0.003950


 26%|██▌       | 13/50 [1:38:47<4:39:54, 453.90s/it]

END OF EPOCH: 13 	Training loss per batch: 0.003307	Training_dev loss per image: 0.002261


 28%|██▊       | 14/50 [1:46:27<4:33:26, 455.72s/it]

END OF EPOCH: 14 	Training loss per batch: 0.003274	Training_dev loss per image: 0.002300


 30%|███       | 15/50 [1:54:18<4:28:35, 460.44s/it]

END OF EPOCH: 15 	Training loss per batch: 0.003200	Training_dev loss per image: 0.002333


 32%|███▏      | 16/50 [2:01:46<4:18:48, 456.71s/it]

END OF EPOCH: 16 	Training loss per batch: 0.003069	Training_dev loss per image: 0.002144


 34%|███▍      | 17/50 [2:09:27<4:11:56, 458.06s/it]

END OF EPOCH: 17 	Training loss per batch: 0.003069	Training_dev loss per image: 0.002197


 36%|███▌      | 18/50 [2:16:55<4:02:40, 455.00s/it]

END OF EPOCH: 18 	Training loss per batch: 0.003002	Training_dev loss per image: 0.002003


 38%|███▊      | 19/50 [2:24:34<3:55:43, 456.25s/it]

END OF EPOCH: 19 	Training loss per batch: 0.003063	Training_dev loss per image: 0.002813


 40%|████      | 20/50 [2:32:01<3:46:44, 453.48s/it]

END OF EPOCH: 20 	Training loss per batch: 0.003264	Training_dev loss per image: 0.002257


 42%|████▏     | 21/50 [2:39:40<3:39:59, 455.16s/it]

END OF EPOCH: 21 	Training loss per batch: 0.003046	Training_dev loss per image: 0.002284


 44%|████▍     | 22/50 [2:47:32<3:34:41, 460.06s/it]

END OF EPOCH: 22 	Training loss per batch: 0.003171	Training_dev loss per image: 0.002207


 44%|████▍     | 22/50 [2:55:01<3:42:45, 477.33s/it]

END OF EPOCH: 23 	Training loss per batch: 0.002955	Training_dev loss per image: 0.002400



