# Dataloaders

In [1]:
# Importing libraries
import torchio as tio
import glob
import numpy as np
import random
import os

from collections import OrderedDict
from pathlib import Path

from tqdm import tqdm
import time

import torchio as tio
from torchio.transforms import (RescaleIntensity,RandomFlip,Compose, HistogramStandardization, RandomAffine, RandomNoise, ToCanonical)

from torch.utils.data import DataLoader
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

import pickle

In [2]:
# Define dictionary describing assignment of participants to the groups

# Groups
train_groups=['ABIDE','Athletes','HCP','COBRE','Leipzig']
test_dev_groups=['UoN']
test_groups=['CHIASM']

# Per-dataset quality notes (MCIC is marked bad and is not listed in any group above):
# CHIASM good
# ABIDE good
# Athletes good
# COBRE good
# HCP good
# Leipzig good
# MCIC bad
# UoN good

# Splits: fractions of every training group. Subjects ranked past train+dev end up in 'test'.
# NOTE: test_split is not read by the code below; it is kept for documentation purposes.
train_split = 0.85
dev_split = 0.15
test_split = 0.0

# Root folder holding one sub-folder per dataset, each with one sub-folder per subject
input_root = '../../1_Data/1_Input/'

# Seed of the private RNG used for shuffling, so that the split is identical on every run
split_seed = 0


def find_subject_ids(group, root=input_root):
    """Return the sorted IDs of all subjects of `group` that have a brain mask.

    A subject is every folder `<root><group>/<id>/` containing
    `tmp_brain_mask.nii.gz`; the ID is the name of that folder.
    Sorting removes the dependence on the (arbitrary) order returned by glob.
    """
    pattern = root+group+'/*/tmp_brain_mask.nii.gz'
    # Path(...).parent.name is the subject folder regardless of the OS path separator
    return sorted(Path(path).parent.name for path in glob.glob(pattern))


def subject_files(group, sub_id, root=input_root):
    """Return the dictionary of image paths ('brain', 'probs', 'chiasm') of one subject."""
    path_to_folder = root+group+'/'+sub_id+'/'

    files = {}
    files['brain'] = path_to_folder+'t1w_1mm_iso.nii.gz'
    files['probs'] = path_to_folder+'tmp_brain_mask.nii.gz'
    files['chiasm'] = path_to_folder+'chiasm.nii.gz'
    return files


# Dictionary with study design: design[split]['all'][subject_id] -> dictionary of file paths
design = {}
for split in ('train', 'dev_train', 'dev_test', 'test'):
    design[split] = {'all': {}}

# Training data: every group is shuffled and divided into train / dev_train / test
rng = random.Random(split_seed)

for group in train_groups:

    # Indices of all subjects, in randomized (but reproducible) order
    ids = find_subject_ids(group)
    rng.shuffle(ids)

    # Find split boundaries (built-in int replaces np.int, which was removed from NumPy)
    train_idx = int(np.floor(len(ids)*train_split))
    dev_idx = int(np.floor(len(ids)*(train_split+dev_split)))

    for i, sub_id in enumerate(ids):

        files = subject_files(group, sub_id)

        if i+1 <= train_idx:
            design['train']['all'][sub_id] = files
        elif i+1 > dev_idx:
            design['test']['all'][sub_id] = files
        else:
            design['dev_train']['all'][sub_id] = files

# Dev data: held-out groups used entirely for validation
for group in test_dev_groups:
    for sub_id in find_subject_ids(group):
        design['dev_test']['all'][sub_id] = subject_files(group, sub_id)

# Test data: held-out groups used entirely for testing
for group in test_groups:
    for sub_id in find_subject_ids(group):
        design['test']['all'][sub_id] = subject_files(group, sub_id)

In [3]:
# Persisting the design is disabled; un-comment to overwrite the stored file
#with open('study_design.pkl', 'wb') as f:
#    pickle.dump(design, f)

# Restore the stored design; it replaces whatever `design` held before this cell
# NOTE(review): unpickling can execute arbitrary code -- only load files you trust
design = pickle.loads(Path('study_design.pkl').read_bytes())

In [4]:
# Build subjects_list[group][dataset]: a list with one tio.Subject per participant
subjects_list = {}

for group, datasets in design.items():
    group_subjects = {}

    for dataset, participants in datasets.items():
        # Each subject carries the brain image as 't1' and the mask image as 'probs'
        group_subjects[dataset] = [
            tio.Subject(
                t1=tio.Image(paths['brain'], type=tio.INTENSITY),
                probs=tio.Image(paths['probs'], type=tio.INTENSITY),
            )
            for paths in participants.values()
        ]

    subjects_list[group] = group_subjects

In [5]:
# Intensity rescaling to the range (0, 1)
rescale = RescaleIntensity(out_min_max=(0, 1))
# Random flipping along axes 0, 1 and 2
flip = RandomFlip(axes=(0, 1, 2), flip_probability=0.5, p=0.25)
# Affine transformations (currently disabled)
#affine = RandomAffine(degrees=30)

# Training data is rescaled and randomly flipped; every other split is only rescaled.
# Standardization is left out for now.
transform_train = Compose([rescale, flip])
transform_dev = Compose([rescale])

In [6]:
# Wrap every list of subjects in a tio.SubjectsDataset: data[group][dataset]
data = {}

for group, datasets in subjects_list.items():

    # Only the 'train' group gets the augmenting transform
    transform = transform_train if group == 'train' else transform_dev

    data[group] = {
        dataset: tio.SubjectsDataset(subjects, transform=transform)
        for dataset, subjects in datasets.items()
    }

In [7]:
# Patch sampling configuration
patch_size = (24, 24, 8)
samples_per_volume = 5
queue_length = 500

# Weighted sampler that uses the subject's 'probs' image as its probability map
sampler = tio.data.WeightedSampler(patch_size=patch_size, probability_map='probs')

In [9]:
# One DataLoader per dataset, each fed by its own patch queue: dataloader[group][dataset]

dataloader = {}

for group, datasets in data.items():
    loaders = {}

    for dataset, subjects_dataset in datasets.items():
        # The queue loads subjects with 6 workers; the DataLoader itself only batches patches
        patch_queue = tio.Queue(subjects_dataset, queue_length, samples_per_volume, sampler,
                                num_workers=6, shuffle_subjects=True, shuffle_patches=True)
        loaders[dataset] = DataLoader(patch_queue, batch_size=25, num_workers=0)

    dataloader[group] = loaders


In [14]:
# Testing
# Disabled smoke test of the dataloader: the code is wrapped in a bare triple-quoted
# string, so the cell only evaluates (and echoes) a string and runs nothing.
# NOTE(review): inside it, `targets` reads the 't1' key while its inline comment
# mentions 'brain' -- confirm the intended key before re-enabling.
'''
num_epochs = 1

model = torch.nn.Identity()

for epoch_index in range(num_epochs):
    for patches_batch in dataloader['dev_train']['all']:
        #print(patches_batch)
        inputs = patches_batch['t1'][tio.DATA]  # key 't1' is in subject
        targets = patches_batch['t1'][tio.DATA]  # key 'brain' is in subject
        logits = model(inputs)  # model being an instance of torch.nn.Module
'''

"\nnum_epochs = 1\n\nmodel = torch.nn.Identity()\n\nfor epoch_index in range(num_epochs):\n    for patches_batch in dataloader['dev_train']['all']:\n        #print(patches_batch)\n        inputs = patches_batch['t1'][tio.DATA]  # key 't1' is in subject\n        targets = patches_batch['t1'][tio.DATA]  # key 'brain' is in subject\n        logits = model(inputs)  # model being an instance of torch.nn.Module\n"

In [15]:
#inputs.shape
# Disabled preview of one batch of patches (bare string, not executed): it would show
# slice 5 along the last axis of every patch on a 5x8 grid.
# NOTE(review): `inputs` is only assigned inside the disabled test loop, so that loop
# has to be re-enabled first.
'''
fig = plt.figure(figsize=(20, 10))

for i in range(inputs.shape[0]):
    plt.subplot(5,8,i+1)
    plt.imshow(inputs[i,0,:,:,5],cmap='gray');
    
plt.show()
'''

"\nfig = plt.figure(figsize=(20, 10))\n\nfor i in range(inputs.shape[0]):\n    plt.subplot(5,8,i+1)\n    plt.imshow(inputs[i,0,:,:,5],cmap='gray');\n    \nplt.show()\n"

In [16]:
# Disabled TorchIO CropOrPad example (bare string, not executed); 'T1w' and 'T2w' are
# placeholder image paths and the snippet is unrelated to the variables of this notebook.
'''import torchio as tio
t1 = tio.ScalarImage('T1w')
t2 = tio.ScalarImage('T2w')
subject = tio.Subject(T1w=t1, T2w=t2)
cp = tio.CropOrPad((512, 512, 408))
subject = tio.Subject(T1w=cp(t1), T2w=cp(t2))
subject.plot(reorient=False)
'''

"import torchio as tio\nt1 = tio.ScalarImage('T1w')\nt2 = tio.ScalarImage('T2w')\nsubject = tio.Subject(T1w=t1, T2w=t2)\ncp = tio.CropOrPad((512, 512, 408))\nsubject = tio.Subject(T1w=cp(t1), T2w=cp(t2))\nsubject.plot(reorient=False)\n"

# Network and parameters

In [17]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

cuda


In [18]:
# Cropped U-Net copied from Overfitting Model

class UNet(nn.Module):
    """3D convolutional autoencoder laid out like a two-level U-Net.

    Unlike a regular U-Net there are NO skip connections: the decoder sees only
    the bottleneck, so the bottleneck is a true latent representation of the input.

    Args:
        in_channels: number of channels of the input volume.
        out_channels: number of channels of the reconstructed volume.
        init_features: number of feature maps produced by the first encoder block.
        scaling: factor by which the number of feature maps grows at every level
            (the bottleneck has init_features*scaling**2 feature maps).

    The input is halved twice by 2x2x2 average pooling and doubled twice by stride-2
    transposed convolutions, so the output has the input's spatial size whenever
    every spatial dimension is divisible by 4.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=10, scaling=2):
        super().__init__()

        # Encoding layers
        self.encoder1 = self.unet_block(in_channels, init_features, "enc1")
        self.pool1 = nn.AvgPool3d(kernel_size=2, stride=2, padding=0)
        self.encoder2 = self.unet_block(init_features, init_features*scaling, name='enc2')
        self.pool2 = nn.AvgPool3d(kernel_size=2, stride=2, padding=0)

        # Bottleneck layer
        self.bottleneck = self.unet_block(init_features*scaling, init_features*scaling**2, name='bottleneck')

        # Decoding layers (plain upsampling path; nothing is merged from the encoder)
        self.upconv2 = nn.ConvTranspose3d(init_features*scaling**2, init_features*scaling, kernel_size=2, stride=2)
        self.decoder2 = self.unet_block(init_features*scaling, init_features*scaling, name='dec2')

        self.upconv1 = nn.ConvTranspose3d(init_features*scaling, init_features, kernel_size=2, stride=2)
        self.decoder1 = self.unet_block(init_features, init_features, name='dec1')

        # Final convolution - output equals number of output channels
        self.conv = nn.Conv3d(init_features, out_channels, kernel_size=1)

    def _encode(self, x):
        """Run the encoder path and return the bottleneck activations."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        return self.bottleneck(self.pool2(enc2))

    def forward(self, x):
        """Reconstruct `x`; the sigmoid keeps every output value within [0, 1]."""

        # Encoding down to the bottleneck
        bottleneck = self._encode(x)

        # Decoding: upconvolution followed by a convolutional block, twice
        dec2 = self.decoder2(self.upconv2(bottleneck))
        dec1 = self.decoder1(self.upconv1(dec2))

        out_conv = self.conv(dec1)

        return torch.sigmoid(out_conv)

    def unet_block(self, in_channels, features, name):
        """Return a (conv 3x3x3 -> batch norm -> ReLU) x 2 block.

        `name` prefixes the layer names and therefore the keys of the state dict.
        """

        return nn.Sequential(OrderedDict([(name+'conv1',nn.Conv3d(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1, bias=False)),
                             (name+'bnorm1', nn.BatchNorm3d(num_features=features)),
                             (name+'relu1', nn.ReLU(inplace=True)),
                             (name+'conv2', nn.Conv3d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False)),
                             (name+'bnorm2', nn.BatchNorm3d(num_features=features)),
                             (name+'relu2', nn.ReLU(inplace=True))])
                            )

    def output_latent_representations(self, x, verbose=True):
        """Return the bottleneck (latent) activations computed for `x`.

        Args:
            x: input batch, as accepted by `forward`.
            verbose: when True (default) the shapes of the input and of the
                latent representation are printed.
        """

        if verbose:
            print(x.shape)

        bottleneck = self._encode(x)

        if verbose:
            print(bottleneck.shape)

        return bottleneck

In [19]:
#unet = UNet(1,1,2,2) # Size of latent representation = (1/64) * init_features * scaling**2
#unet.to(device)

In [20]:
#print(sum(p.numel() for p in unet.parameters() if p.requires_grad))

In [21]:
# try processing
# Disabled check of the network output (bare string, not executed): it would pass
# `inputs` through `unet` and show slice 5 of every output patch on a 10x10 grid.
# NOTE(review): `unet` and `inputs` are only created in commented-out / disabled cells.
'''
outputs = unet(inputs.to(device))
outputs = outputs.cpu().detach().numpy()

fig = plt.figure(figsize=(20, 10))

for i in range(outputs.shape[0]):
    plt.subplot(10,10,i+1)
    plt.imshow(outputs[i,0,:,:,5],cmap='gray');
    
plt.show()
'''

"\noutputs = unet(inputs.to(device))\noutputs = outputs.cpu().detach().numpy()\n\nfig = plt.figure(figsize=(20, 10))\n\nfor i in range(outputs.shape[0]):\n    plt.subplot(10,10,i+1)\n    plt.imshow(outputs[i,0,:,:,5],cmap='gray');\n    \nplt.show()\n"

In [22]:
# Number of epochs
n_epochs = 25

# Criterion: mean squared error (the Dice loss alternative is disabled)
#criterion = DiceLoss()
criterion = nn.MSELoss(reduction='mean')

In [23]:
#outcome = criterion(inputs, torch.Tensor(outputs))
#print(outcome)

# Training

In [24]:
# Training loop: returns the per-epoch loss histories, the best weights are written to disk
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Pass every batch of `loader` through `model` once.

    With an optimizer the model is put in training mode and updated after every
    batch; without one it is put in evaluation mode and gradients are disabled.
    The reconstruction target of every batch is the input itself.

    Returns:
        (total_loss, n_batches): sum of the batch losses and number of batches seen.
    """
    training = optimizer is not None

    if training:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    n_batches = 0

    with torch.set_grad_enabled(training):
        for batch in loader:

            data = batch['t1']['data'].to(device)

            if training:
                optimizer.zero_grad()

            output = model(data)
            loss = criterion(output, data)

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1

    return total_loss, n_batches


def _mean_loss(total_loss, n_batches):
    """Return the loss per batch, or NaN when the loader produced no batch at all."""
    return total_loss/n_batches if n_batches else float('nan')


def train_network(n_epochs, dataloaders, model, optimizer, criterion, device, save_path):
    """Train `model` as an autoencoder and keep the weights of the best epoch.

    Args:
        n_epochs: number of passes over the training data.
        dataloaders: nested mapping; dataloaders[split]['all'] must be an iterable of
            batches for split in 'train', 'dev_train' and 'dev_test'. Every batch
            must provide the input tensor as batch['t1']['data'].
        model: torch module that is trained to reproduce its input.
        optimizer: optimizer built on the parameters of `model`.
        criterion: loss called as criterion(output, input).
        device: torch device the model and the batches are moved to.
        save_path: prefix of the weights file; the state dict is saved to
            save_path+'optimal_weights', so a folder has to end with a separator.

    Returns:
        Three lists with one per-batch mean loss per epoch, for the 'train',
        'dev_train' and 'dev_test' data respectively. An entry is NaN when the
        corresponding dataloader was empty.
    """

    track_train_loss = []
    track_dev_train_loss = []
    track_dev_test_loss = []

    # float('inf') instead of np.Inf, which is no longer available in NumPy 2
    valid_loss_min = float('inf')

    model.to(device)

    for epoch in tqdm(range(1, n_epochs+1)):

        # Training
        train_total, train_batches = _run_epoch(model, dataloaders['train']['all'], criterion, device, optimizer)

        # Validation on two datasets
        dev_train_total, dev_train_batches = _run_epoch(model, dataloaders['dev_train']['all'], criterion, device)
        dev_test_total, dev_test_batches = _run_epoch(model, dataloaders['dev_test']['all'], criterion, device)

        train_loss = _mean_loss(train_total, train_batches)
        dev_train_loss = _mean_loss(dev_train_total, dev_train_batches)
        dev_test_loss = _mean_loss(dev_test_total, dev_test_batches)

        track_train_loss.append(train_loss)
        track_dev_train_loss.append(dev_train_loss)
        track_dev_test_loss.append(dev_test_loss)

        # Print summary of epoch
        print('END OF EPOCH: {} \tTraining loss per batch: {:.6f}\tTraining_dev loss per batch: {:.6f}\tTest_dev loss per batch: {:.6f}'.format(epoch, train_loss, dev_train_loss, dev_test_loss))

        # Save the model if the validation loss reached a new minimum. The criterion is
        # the sum of the batch losses of both validation sets (totals, not per-batch means).
        if dev_train_total + dev_test_total < valid_loss_min:
            valid_loss_min = dev_train_total + dev_test_total
            torch.save(model.state_dict(),save_path+'optimal_weights')

    # Return the loss histories (the best weights are stored on disk)
    return track_train_loss, track_dev_train_loss, track_dev_test_loss

In [25]:
# Every entry is [init_features, scaling] of one network; the longer list is the full sweep
#model_parameters=[[64,1],[32,1],[16,2],[16,1],[8,2],[8,1],[4,4],[4,2],[4,1],[2,4],[2,2],[2,1],[1,8],[1,4],[1,2],[1,1]]
model_parameters=[[2,4],[16,1],[4,2],[2,2],[8,2],[4,4],[64,1],[1,8]]

folder='../../1_Data/2_Trained_AE/'

for parameters in model_parameters:

    print(parameters)
    init_features, scaling = parameters

    # A fresh network and its own Adam optimizer for every configuration
    unet = UNet(in_channels=1, out_channels=1, init_features=init_features, scaling=scaling)
    optimizer = torch.optim.Adam(unet.parameters(), lr=0.005)

    # Output folder named '<init_features>_<scaling>'
    data_folder = '{}/{}_{}/'.format(folder, init_features, scaling)
    os.makedirs(data_folder, exist_ok=True)

    # Train; the best weights are saved into data_folder by train_network
    train_loss, dev_train_loss, dev_test_loss = train_network(
        n_epochs, dataloader, unet, optimizer, criterion, device, data_folder)

    # Store every loss history in its own pickle file
    loss_histories = (
        ('train_loss.pkl', train_loss),
        ('dev_train_loss.pkl', dev_train_loss),
        ('dev_test_loss.pkl', dev_test_loss),
    )
    for file_name, history in loss_histories:
        with open(data_folder+file_name, 'wb') as f:
            pickle.dump(history, f)



[1, 1]


  4%|▍         | 1/25 [33:27<13:22:54, 2007.27s/it]

END OF EPOCH: 1 	Training loss per batch: 0.021654	Training_dev loss per batch: 0.009881	Test_dev loss per batch: 0.008477


  8%|▊         | 2/25 [1:03:47<12:27:55, 1951.09s/it]

END OF EPOCH: 2 	Training loss per batch: 0.006293	Training_dev loss per batch: 0.004711	Test_dev loss per batch: 0.003947


 12%|█▏        | 3/25 [1:33:15<11:35:18, 1896.29s/it]

END OF EPOCH: 3 	Training loss per batch: 0.006213	Training_dev loss per batch: 0.004244	Test_dev loss per batch: 0.002689


 16%|█▌        | 4/25 [2:04:17<11:00:01, 1885.80s/it]

END OF EPOCH: 4 	Training loss per batch: 0.005990	Training_dev loss per batch: 0.004316	Test_dev loss per batch: 0.002692


 20%|██        | 5/25 [2:33:31<10:15:30, 1846.54s/it]

END OF EPOCH: 5 	Training loss per batch: 0.006079	Training_dev loss per batch: 0.004398	Test_dev loss per batch: 0.002063


 24%|██▍       | 6/25 [3:03:06<9:37:55, 1825.02s/it] 

END OF EPOCH: 6 	Training loss per batch: 0.005818	Training_dev loss per batch: 0.004326	Test_dev loss per batch: 0.002290


 28%|██▊       | 7/25 [3:34:49<9:14:32, 1848.45s/it]

END OF EPOCH: 7 	Training loss per batch: 0.005797	Training_dev loss per batch: 0.004415	Test_dev loss per batch: 0.001974
END OF EPOCH: 8 	Training loss per batch: 0.005939	Training_dev loss per batch: 0.004125	Test_dev loss per batch: 0.002412


 36%|███▌      | 9/25 [4:39:08<8:23:01, 1886.37s/it]

END OF EPOCH: 9 	Training loss per batch: 0.005888	Training_dev loss per batch: 0.005010	Test_dev loss per batch: 0.002049


 40%|████      | 10/25 [5:11:32<7:55:52, 1903.51s/it]

END OF EPOCH: 10 	Training loss per batch: 0.005790	Training_dev loss per batch: 0.004245	Test_dev loss per batch: 0.002489
END OF EPOCH: 11 	Training loss per batch: 0.005830	Training_dev loss per batch: 0.004184	Test_dev loss per batch: 0.001751


 48%|████▊     | 12/25 [6:13:04<6:46:02, 1874.07s/it]

END OF EPOCH: 12 	Training loss per batch: 0.005925	Training_dev loss per batch: 0.004099	Test_dev loss per batch: 0.001886


 52%|█████▏    | 13/25 [6:43:07<6:10:33, 1852.79s/it]

END OF EPOCH: 13 	Training loss per batch: 0.005597	Training_dev loss per batch: 0.004163	Test_dev loss per batch: 0.002226


 56%|█████▌    | 14/25 [7:13:10<5:36:55, 1837.82s/it]

END OF EPOCH: 14 	Training loss per batch: 0.005876	Training_dev loss per batch: 0.004321	Test_dev loss per batch: 0.001676


 60%|██████    | 15/25 [7:43:19<5:04:50, 1829.03s/it]

END OF EPOCH: 15 	Training loss per batch: 0.005710	Training_dev loss per batch: 0.004663	Test_dev loss per batch: 0.001992
END OF EPOCH: 16 	Training loss per batch: 0.006005	Training_dev loss per batch: 0.003908	Test_dev loss per batch: 0.001818


 68%|██████▊   | 17/25 [8:43:37<4:02:28, 1818.60s/it]

END OF EPOCH: 17 	Training loss per batch: 0.006151	Training_dev loss per batch: 0.004062	Test_dev loss per batch: 0.001745


 72%|███████▏  | 18/25 [9:13:30<3:31:15, 1810.74s/it]

END OF EPOCH: 18 	Training loss per batch: 0.005835	Training_dev loss per batch: 0.004259	Test_dev loss per batch: 0.002180


 76%|███████▌  | 19/25 [9:46:01<3:05:16, 1852.78s/it]

END OF EPOCH: 19 	Training loss per batch: 0.005988	Training_dev loss per batch: 0.003973	Test_dev loss per batch: 0.002087


 80%|████████  | 20/25 [10:14:12<2:30:21, 1804.33s/it]

END OF EPOCH: 20 	Training loss per batch: 0.005988	Training_dev loss per batch: 0.004068	Test_dev loss per batch: 0.001753


 84%|████████▍ | 21/25 [10:42:55<1:58:40, 1780.01s/it]

END OF EPOCH: 21 	Training loss per batch: 0.005896	Training_dev loss per batch: 0.003974	Test_dev loss per batch: 0.002085


 88%|████████▊ | 22/25 [11:11:45<1:28:15, 1765.07s/it]

END OF EPOCH: 22 	Training loss per batch: 0.005845	Training_dev loss per batch: 0.004052	Test_dev loss per batch: 0.001973


 92%|█████████▏| 23/25 [11:40:36<58:29, 1754.68s/it]  

END OF EPOCH: 23 	Training loss per batch: 0.005489	Training_dev loss per batch: 0.004083	Test_dev loss per batch: 0.001988


 96%|█████████▌| 24/25 [12:09:30<29:08, 1748.62s/it]

END OF EPOCH: 24 	Training loss per batch: 0.005861	Training_dev loss per batch: 0.004169	Test_dev loss per batch: 0.001916


100%|██████████| 25/25 [12:38:31<00:00, 1820.46s/it]

END OF EPOCH: 25 	Training loss per batch: 0.005777	Training_dev loss per batch: 0.004529	Test_dev loss per batch: 0.001997
[32, 1]



  0%|          | 0/25 [00:00<?, ?it/s]

END OF EPOCH: 1 	Training loss per batch: 0.006281	Training_dev loss per batch: 0.003072	Test_dev loss per batch: 0.000938


  8%|▊         | 2/25 [57:36<11:02:10, 1727.42s/it]

END OF EPOCH: 2 	Training loss per batch: 0.004664	Training_dev loss per batch: 0.003160	Test_dev loss per batch: 0.000709


 12%|█▏        | 3/25 [1:26:45<10:35:46, 1733.94s/it]

END OF EPOCH: 3 	Training loss per batch: 0.004152	Training_dev loss per batch: 0.005915	Test_dev loss per batch: 0.000950
END OF EPOCH: 4 	Training loss per batch: 0.004472	Training_dev loss per batch: 0.002191	Test_dev loss per batch: 0.000570


 20%|██        | 5/25 [2:24:21<9:37:24, 1732.24s/it] 

END OF EPOCH: 5 	Training loss per batch: 0.003979	Training_dev loss per batch: 0.002097	Test_dev loss per batch: 0.000587


 24%|██▍       | 6/25 [2:53:30<9:10:06, 1737.18s/it]

END OF EPOCH: 6 	Training loss per batch: 0.003242	Training_dev loss per batch: 0.002919	Test_dev loss per batch: 0.001896


 28%|██▊       | 7/25 [3:22:24<8:40:53, 1736.29s/it]

END OF EPOCH: 7 	Training loss per batch: 0.002900	Training_dev loss per batch: 0.002060	Test_dev loss per batch: 0.001044


 32%|███▏      | 8/25 [3:51:17<8:11:43, 1735.47s/it]

END OF EPOCH: 8 	Training loss per batch: 0.003096	Training_dev loss per batch: 0.003948	Test_dev loss per batch: 0.000991
END OF EPOCH: 9 	Training loss per batch: 0.002763	Training_dev loss per batch: 0.001789	Test_dev loss per batch: 0.000383


 40%|████      | 10/25 [4:49:10<7:14:00, 1736.04s/it]

END OF EPOCH: 10 	Training loss per batch: 0.002790	Training_dev loss per batch: 0.005005	Test_dev loss per batch: 0.000578


 44%|████▍     | 11/25 [5:18:05<6:44:59, 1735.67s/it]

END OF EPOCH: 11 	Training loss per batch: 0.002682	Training_dev loss per batch: 0.002534	Test_dev loss per batch: 0.000402


 48%|████▊     | 12/25 [5:46:47<6:15:07, 1731.36s/it]

END OF EPOCH: 12 	Training loss per batch: 0.002626	Training_dev loss per batch: 0.001829	Test_dev loss per batch: 0.000428


 52%|█████▏    | 13/25 [6:15:32<5:45:55, 1729.63s/it]

END OF EPOCH: 13 	Training loss per batch: 0.002508	Training_dev loss per batch: 0.001839	Test_dev loss per batch: 0.000534
END OF EPOCH: 14 	Training loss per batch: 0.002342	Training_dev loss per batch: 0.001714	Test_dev loss per batch: 0.000447


 60%|██████    | 15/25 [7:13:02<4:47:47, 1726.75s/it]

END OF EPOCH: 15 	Training loss per batch: 0.002244	Training_dev loss per batch: 0.001839	Test_dev loss per batch: 0.000654


 64%|██████▍   | 16/25 [7:42:01<4:19:35, 1730.60s/it]

END OF EPOCH: 16 	Training loss per batch: 0.002115	Training_dev loss per batch: 0.001997	Test_dev loss per batch: 0.000576


 68%|██████▊   | 17/25 [8:11:22<3:51:57, 1739.64s/it]

END OF EPOCH: 17 	Training loss per batch: 0.002131	Training_dev loss per batch: 0.001769	Test_dev loss per batch: 0.000409
END OF EPOCH: 18 	Training loss per batch: 0.002072	Training_dev loss per batch: 0.001722	Test_dev loss per batch: 0.000345


 76%|███████▌  | 19/25 [9:11:27<2:57:39, 1776.57s/it]

END OF EPOCH: 19 	Training loss per batch: 0.001947	Training_dev loss per batch: 0.002005	Test_dev loss per batch: 0.000592
END OF EPOCH: 20 	Training loss per batch: 0.002088	Training_dev loss per batch: 0.001694	Test_dev loss per batch: 0.000421


 84%|████████▍ | 21/25 [10:13:58<2:01:48, 1827.11s/it]

END OF EPOCH: 21 	Training loss per batch: 0.002013	Training_dev loss per batch: 0.001991	Test_dev loss per batch: 0.000450


 88%|████████▊ | 22/25 [10:44:50<1:31:43, 1834.50s/it]

END OF EPOCH: 22 	Training loss per batch: 0.001892	Training_dev loss per batch: 0.002136	Test_dev loss per batch: 0.000537
END OF EPOCH: 23 	Training loss per batch: 0.001843	Training_dev loss per batch: 0.001608	Test_dev loss per batch: 0.000363


 96%|█████████▌| 24/25 [11:49:32<31:32, 1892.94s/it]  

END OF EPOCH: 24 	Training loss per batch: 0.001864	Training_dev loss per batch: 0.002856	Test_dev loss per batch: 0.000345


100%|██████████| 25/25 [12:19:48<00:00, 1775.53s/it]

END OF EPOCH: 25 	Training loss per batch: 0.001782	Training_dev loss per batch: 0.003122	Test_dev loss per batch: 0.000843
[2, 4]



  0%|          | 0/25 [13:28<?, ?it/s]


KeyboardInterrupt: 