In [1]:
import torch
import pickle

with open( "x_train_hist.p", "rb" ) as file:
    x_train_hist = pickle.load(file)
    
with open( "y_train_hist.p", "rb" ) as file:
    y_train_hist = pickle.load(file)
    
with open( "x_valid_hist.p", "rb" ) as file:
    x_valid_hist = pickle.load(file)
    
with open( "y_valid_hist.p", "rb" ) as file:
    y_valid_hist = pickle.load(file)

In [2]:
import torch
import random
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm

from mriqa_dataset import MRIQADataset
from networks import ClassicCNN, PhilsClassicCnn, CatNet, CustomResNet

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Link: https://arxiv.org/abs/2003.04696



In [3]:
# set random seeds for reproducibility
random.seed(21062020)
np.random.seed(21062020)
torch.manual_seed(21062020)
torch.cuda.manual_seed(21062020)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(torch.cuda.current_device())
torch.cuda.set_device(0)

net = CustomResNet(num_classes=5)
net = net.cuda()

optimizer = optim.SGD(net.parameters(), lr=1e-3)
ce = CrossEntropyLoss().cuda()

num_epochs = len(x_train_hist)
num_mini_batches = 13

loss_csv = open('losses.csv', 'w')
loss_csv.write('epoch,training,validation\n')

best_val_loss = 999
print("start training")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    net.train()
    
    train_batches = x_train_hist[epoch]
    train_labels = y_train_hist[epoch]

    # train loop
    for sample, label in tqdm(zip(train_batches, train_labels), total=len(train_batches), leave=False):       
        sample = sample.cuda()
        label = label.cuda()

        prediction = net(sample)
        loss = ce(prediction, label)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print('[{}] train-loss: {}'.format(epoch, epoch_loss / num_mini_batches))
    loss_csv.write(str(epoch) + ',' + str(epoch_loss / num_mini_batches))
    loss_csv.flush()

    # validation loop
    net.eval()
    mean_validation_loss = 0
    num_validation_mini_batches = 11

    if epoch % 3 == 0:
        valid_batches = x_valid_hist[epoch//5]
        valid_labels = y_valid_hist[epoch//5]
        
        with torch.no_grad():
            for sample, label in tqdm(zip(valid_batches, valid_labels), total=len(valid_batches), leave=False):              
                sample = sample.cuda()
                label = label.cuda()

                prediction = net(sample)
                validation_loss = ce(prediction, label)

                mean_validation_loss += validation_loss.item()
                
        print('[{}] validation-loss: {}'.format(epoch, mean_validation_loss / num_validation_mini_batches))
        loss_csv.write(',' + str(mean_validation_loss / num_validation_mini_batches) + '\n')
        loss_csv.flush()

    # save best model
    if mean_validation_loss <= best_val_loss:
        torch.save({'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': validation_loss.item()}, 'checkpoint_best_resnet')
        best_val_loss = mean_validation_loss

print('DONE.')

0
start training


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[0] train-loss: 0.5162584965045636


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[0] validation-loss: 0.43978590315038507


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[1] train-loss: 0.4385003034885113


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[2] train-loss: 0.432023837016179


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[3] train-loss: 0.2980826244904445


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[3] validation-loss: 0.42897891998291016


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[4] train-loss: 0.20040242259319013


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[5] train-loss: 0.23034498324761024


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[6] train-loss: 0.23673685238911554


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[6] validation-loss: 0.45749278502030805


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[7] train-loss: 0.211608153123122


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[8] train-loss: 0.09005270554469182


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[9] train-loss: 0.10968477450884305


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[9] validation-loss: 0.4225595539266413


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[10] train-loss: 0.09766369198377316


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[11] train-loss: 0.061711093554129966


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[12] train-loss: 0.060427469129745774


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[12] validation-loss: 0.0796541730788621


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[13] train-loss: 0.1334477410866664


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[14] train-loss: 0.13870279376323408


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[15] train-loss: 0.04571559108220614


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[15] validation-loss: 0.3694729381664233


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[16] train-loss: 0.04349151425636732


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[17] train-loss: 0.1567737746697206


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[18] train-loss: 0.20525278724156892


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[18] validation-loss: 2.238165508617054


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[19] train-loss: 0.09309402452065395


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[20] train-loss: 0.15790944489148948


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[21] train-loss: 0.4659235288317387


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[21] validation-loss: 3.4697695645419033


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[22] train-loss: 0.07353349259266487


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[23] train-loss: 0.16744095774797293


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[24] train-loss: 0.18205050780222967


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[24] validation-loss: 6.9783968491987745


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[25] train-loss: 0.08189062201059781


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[26] train-loss: 0.19509415901624239


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[27] train-loss: 0.22711083751458389


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[27] validation-loss: 23.016347018155184


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[28] train-loss: 0.22835334218465364


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[29] train-loss: 0.4033811825972337


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[30] train-loss: 0.2156569751409384


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[30] validation-loss: 5.035468795082786


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[31] train-loss: 0.6528343535386599


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[32] train-loss: 0.25868804638202375


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[33] train-loss: 0.14208664343907282


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[33] validation-loss: 0.8529327782717618


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[34] train-loss: 0.5794443304722126


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[35] train-loss: 0.4838333404981173


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[36] train-loss: 0.14329782586831313


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[36] validation-loss: 3.452334317294034


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[37] train-loss: 0.639651894569397


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[38] train-loss: 0.21854849962087777


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[39] train-loss: 0.43313992940462553


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[39] validation-loss: 3.1736054853959517


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[40] train-loss: 0.15705020152605498


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[41] train-loss: 0.41973568155215335


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[42] train-loss: 0.4004477560520172


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[42] validation-loss: 1.709108677777377


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[43] train-loss: 0.3458973719523503


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[44] train-loss: 0.21239726360027605


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[45] train-loss: 0.11003547677626976


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[45] validation-loss: 2.375606060028076


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[46] train-loss: 0.22185360468350923


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[47] train-loss: 1.0447702682935274


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[48] train-loss: 0.3462610153051523


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[48] validation-loss: 2.1917893669821997


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[49] train-loss: 0.24907652919109052


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[50] train-loss: 0.30840553687169003


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[51] train-loss: 0.14526575345259446


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[51] validation-loss: 4.216013474897905


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[52] train-loss: 0.1885862533862774


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[53] train-loss: 0.3337179101430453


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[54] train-loss: 0.2142328665806697


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[54] validation-loss: 4.3604042746803975


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[55] train-loss: 0.21680684273059553


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[56] train-loss: 0.3399096612746899


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[57] train-loss: 0.22216472258934608


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[57] validation-loss: 2.658657507462935


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[58] train-loss: 0.15181005688814017


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[59] train-loss: 0.20033534215046808


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[60] train-loss: 0.29915528572522676


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[60] validation-loss: 0.8297571485692804


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[61] train-loss: 0.40607417546785796


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[62] train-loss: 0.1602623726312931


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[63] train-loss: 0.2560026783209581


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[63] validation-loss: 0.6929769082502886


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[64] train-loss: 0.4646459542787992


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[65] train-loss: 0.13322760279362017


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[66] train-loss: 0.18938795419839713


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[66] validation-loss: 0.5035536722703413


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[67] train-loss: 0.15162855157485375


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[68] train-loss: 0.1456572528068836


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[69] train-loss: 0.48122867254110485


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[69] validation-loss: 0.610137170011347


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[70] train-loss: 0.20630833735832801


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[71] train-loss: 0.3120485223256625


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[72] train-loss: 0.2991821765899658


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[72] validation-loss: 0.6086036508733575


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[73] train-loss: 0.5839065955235407


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[74] train-loss: 0.43113361642910886


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[75] train-loss: 0.3455226513055655


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[75] validation-loss: 1.2663353789936413


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[76] train-loss: 0.4229755401611328


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[77] train-loss: 0.31132249190257144


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[78] train-loss: 1.1825026640525231


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[78] validation-loss: 1.5163829976862127


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[79] train-loss: 0.4427418250304002


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[80] train-loss: 0.27219481422350955


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[81] train-loss: 0.3334883818259606


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[81] validation-loss: 0.1514267081564123


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[82] train-loss: 0.22033635011086097


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[83] train-loss: 0.35428924285448515


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[84] train-loss: 0.523651333955618


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[84] validation-loss: 0.4146591750058261


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[85] train-loss: 0.25534003056012666


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[86] train-loss: 0.3045806838915898


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[87] train-loss: 1.010917333456186


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[87] validation-loss: 1.4725320122458718


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[88] train-loss: 0.3579455201442425


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[89] train-loss: 1.1163872595016773


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[90] train-loss: 0.701353118969844


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[90] validation-loss: 0.9932767694646661


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[91] train-loss: 0.6249597164300772


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[92] train-loss: 0.576836503469027


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[93] train-loss: 0.5441123292996333


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[93] validation-loss: 1.4377552379261365


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[94] train-loss: 0.3287309683286227


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[95] train-loss: 0.583855913235591


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[96] train-loss: 0.5115087857613196


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[96] validation-loss: 2.0220447887073862


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[97] train-loss: 0.43021984283740705


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[98] train-loss: 0.34427529114943284


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[99] train-loss: 0.33985060911912185


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[99] validation-loss: 2.3192830085754395


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[100] train-loss: 0.6616461185308603


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[101] train-loss: 0.43842754455713123


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[102] train-loss: 0.24887604438341582


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[102] validation-loss: 2.7091686075383965


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[103] train-loss: 0.38646392868115353


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[104] train-loss: 0.24207094082465538


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[105] train-loss: 0.09460184092705066


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[105] validation-loss: 1.741043134169145


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[106] train-loss: 0.23025870323181152


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[107] train-loss: 0.4119010338416466


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[108] train-loss: 0.30974538738910967


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[108] validation-loss: 1.5574248053810813


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[109] train-loss: 0.2491456178518442


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[110] train-loss: 0.41289918697797334


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[111] train-loss: 0.38785072473379284


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[111] validation-loss: 0.8033023747530851


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[112] train-loss: 0.15882619069172785


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[113] train-loss: 0.1987407528437101


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[114] train-loss: 0.44817898823664737


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[114] validation-loss: 0.6762187101624229


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[115] train-loss: 0.8476101251748892


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[116] train-loss: 0.9406792567326472


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[117] train-loss: 0.2157846368276156


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[117] validation-loss: 1.1575958945534446


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[118] train-loss: 0.4188659099432138


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[119] train-loss: 2.411003213662368


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[120] train-loss: 2.3710128527421217


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[120] validation-loss: 3.282827290621671


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[121] train-loss: 0.47028053723848784


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[122] train-loss: 0.9574272540899423


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[123] train-loss: 1.5189459094634423


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[123] validation-loss: 4.788549076427113


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[124] train-loss: 0.6236601150952853


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[125] train-loss: 0.7746518208430364


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[126] train-loss: 1.9374905732961802


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[126] validation-loss: 7.198960564353249


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[127] train-loss: 1.2301386503072886


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[128] train-loss: 0.35352116364699143


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[129] train-loss: 0.5245729409731351


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[129] validation-loss: 1.556033806367354


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[130] train-loss: 0.6158557855165921


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[131] train-loss: 0.26788904116703915


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[132] train-loss: 0.5548319266392634


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[132] validation-loss: 1.210873387076638


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[133] train-loss: 0.4683874799655034


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[134] train-loss: 1.490951982828287


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[135] train-loss: 0.5193558014356173


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[135] validation-loss: 1.2488508224487305


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[136] train-loss: 3.6341281670790453


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[137] train-loss: 1.0522357317117543


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[138] train-loss: 2.6529003519278307


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[138] validation-loss: 1.6202103224667637


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[139] train-loss: 0.28411851020959705


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[140] train-loss: 0.517912704211015


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[141] train-loss: 0.7857431906920213


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[141] validation-loss: 1.5591056563637473


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[142] train-loss: 1.1644750833511353


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[143] train-loss: 0.6745834717383752


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[144] train-loss: 0.6673956559254572


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[144] validation-loss: 1.0699769800359553


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[145] train-loss: 0.45308318504920375


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[146] train-loss: 0.6055065164199243


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[147] train-loss: 1.0042479221637433


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[147] validation-loss: 0.9643911881880327


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[148] train-loss: 3.080463097645686


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[149] train-loss: 0.6844574625675495


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[150] train-loss: 0.9661425535495465


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[150] validation-loss: 0.8775351697748358


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[151] train-loss: 0.6944759488105774


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[152] train-loss: 0.44315186830667347


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[153] train-loss: 0.4485545525184044


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[153] validation-loss: 2.1251635551452637


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[154] train-loss: 0.48571857466147494


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[155] train-loss: 0.9042889705071082


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[156] train-loss: 0.67147440635241


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[156] validation-loss: 3.252166357907382


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[157] train-loss: 0.8855238877809964


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[158] train-loss: 0.7463860534704648


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[159] train-loss: 1.0983143953176646


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[159] validation-loss: 3.929164409637451


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[160] train-loss: 0.9805286480830266


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[161] train-loss: 1.0044844700739934


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[162] train-loss: 0.770670799108652


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[162] validation-loss: 3.688605481928045


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[163] train-loss: 0.48267873433920055


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[164] train-loss: 1.0852878093719482


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[165] train-loss: 1.3823694449204664


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[165] validation-loss: 4.4565323916348545


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[166] train-loss: 2.0014943984838633


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[167] train-loss: 1.7241859344335704


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[168] train-loss: 1.2791144572771513


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[168] validation-loss: 3.825811039317738


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[169] train-loss: 0.47690995839925915


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[170] train-loss: 2.393272014764639


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[171] train-loss: 0.7682941051629874


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[171] validation-loss: 1.0512027090246028


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[172] train-loss: 0.619324280665471


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[173] train-loss: 0.32346580120233387


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[174] train-loss: 0.49285901051301223


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[174] validation-loss: 3.7886055166071113


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[175] train-loss: 0.46043739181298476


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[176] train-loss: 0.9428167343139648


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[177] train-loss: 0.4787313754741962


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[177] validation-loss: 4.570853320035067


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[178] train-loss: 0.20384882963620699


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[179] train-loss: 0.5137428320371188


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[180] train-loss: 0.5098917415508857


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[180] validation-loss: 3.547732266512784


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[181] train-loss: 0.5062405329484206


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[182] train-loss: 0.7422468387163602


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[183] train-loss: 1.5520326174222505


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[183] validation-loss: 3.9419144717129795


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[184] train-loss: 0.29402121672263515


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[185] train-loss: 0.681527506846648


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[186] train-loss: 0.627348386324369


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[186] validation-loss: 3.4557254964655097


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[187] train-loss: 0.795488777068945


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[188] train-loss: 0.5132680397767287


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[189] train-loss: 0.9479586069400494


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[189] validation-loss: 2.708764596418901


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[190] train-loss: 0.4798119251544659


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[191] train-loss: 0.41841456523308385


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[192] train-loss: 0.26482749443787795


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[192] validation-loss: 2.6437007730657403


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[193] train-loss: 0.34636470904717076


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[194] train-loss: 0.3959942001562852


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[195] train-loss: 0.325640660065871


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[195] validation-loss: 1.8624036528847434


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[196] train-loss: 0.7667930859785813


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[197] train-loss: 0.6449989355527438


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[198] train-loss: 1.1521089260394757


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[198] validation-loss: 1.7311008193276145


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[199] train-loss: 0.8257865080466638


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[200] train-loss: 0.23301921670253462


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[201] train-loss: 0.28190128619854266


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[201] validation-loss: 1.7505127733403987


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[202] train-loss: 0.3446185542986943


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[203] train-loss: 0.48292805598332333


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[204] train-loss: 1.0832784130023076


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[204] validation-loss: 1.8860454992814497


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[205] train-loss: 0.8240513251377986


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[206] train-loss: 1.0490835309028625


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[207] train-loss: 0.7470352282890906


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[207] validation-loss: 1.6743242090398616


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[208] train-loss: 0.9722543817300063


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[209] train-loss: 0.6380740037331214


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[210] train-loss: 0.4130225640076857


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[210] validation-loss: 1.0490215691653164


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[211] train-loss: 0.41427721885534435


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[212] train-loss: 0.4795520993379446


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[213] train-loss: 0.5250042310127845


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[213] validation-loss: 0.6193168380043723


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[214] train-loss: 0.20618299337533805


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[215] train-loss: 0.1863419092618502


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[216] train-loss: 0.1575589134142949


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[216] validation-loss: 0.473603909665888


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[217] train-loss: 0.8315379069401667


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[218] train-loss: 0.7600076748774602


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[219] train-loss: 0.719850283402663


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[219] validation-loss: 0.8946479884060946


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[220] train-loss: 0.37333860305639416


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[221] train-loss: 0.4096728884256803


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[222] train-loss: 0.8258827466231126


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[222] validation-loss: 1.3885713707317004


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[223] train-loss: 0.9117445785265702


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[224] train-loss: 0.2844029481594379


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[225] train-loss: 0.461589882007012


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[225] validation-loss: 1.4807695042003284


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[226] train-loss: 0.5472511419883141


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[227] train-loss: 0.3084603456350473


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[228] train-loss: 0.46225832058833194


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[228] validation-loss: 2.1408632451837715


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[229] train-loss: 0.3283095153478476


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[230] train-loss: 0.4454022737649771


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[231] train-loss: 1.0343853326944203


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[231] validation-loss: 1.4545958692377263


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[232] train-loss: 0.8045375347137451


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[233] train-loss: 1.0346691058232234


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[234] train-loss: 0.7935400926149808


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[234] validation-loss: 2.163989543914795


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[235] train-loss: 0.8875889961536114


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[236] train-loss: 0.6732832009975727


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[237] train-loss: 0.5376463990945083


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[237] validation-loss: 2.929401917891069


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[238] train-loss: 0.5060095970447247


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[239] train-loss: 0.32549489461458647


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[240] train-loss: 0.1946441875054286


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[240] validation-loss: 1.7627518393776633


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[241] train-loss: 0.4144339148814862


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[242] train-loss: 0.18877461094122666


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[243] train-loss: 1.7014841529039235


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[243] validation-loss: 1.67660344730724


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[244] train-loss: 0.29706220901929414


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[245] train-loss: 0.8539372865970318


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[246] train-loss: 0.2876010216199435


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[246] validation-loss: 2.1445266116749155


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[247] train-loss: 0.8086815063770001


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[248] train-loss: 0.5323668351540198


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[249] train-loss: 0.20153138041496277


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[249] validation-loss: 0.7655422362414274


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[250] train-loss: 0.5049044260611901


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[251] train-loss: 0.5001212541873639


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[252] train-loss: 0.2522405982017517


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[252] validation-loss: 0.8514665907079523


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[253] train-loss: 0.31891534420160145


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[254] train-loss: 0.3100348435915433


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[255] train-loss: 0.753196445795206


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[255] validation-loss: 0.9733191836964


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[256] train-loss: 0.9295583229798537


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[257] train-loss: 0.569498781974499


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[258] train-loss: 0.9380511389328883


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[258] validation-loss: 1.851506146517667


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[259] train-loss: 0.3103666305541992


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[260] train-loss: 0.7304783876125629


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[261] train-loss: 0.6831599749051608


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[261] validation-loss: 1.1487130251797764


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[262] train-loss: 0.4127205701974722


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[263] train-loss: 0.4642445147037506


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[264] train-loss: 0.8162430799924411


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[264] validation-loss: 1.9225089333274148


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[265] train-loss: 0.4621471625107985


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[266] train-loss: 0.5052532610984949


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[267] train-loss: 0.5756851159609281


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[267] validation-loss: 2.6301734230735083


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[268] train-loss: 0.37648360087321353


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[269] train-loss: 0.3820997063930218


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[270] train-loss: 0.1422868279310373


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[270] validation-loss: 1.203499598936601


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[271] train-loss: 0.189783153625635


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[272] train-loss: 0.409117185152494


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[273] train-loss: 0.19329298918063825


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[273] validation-loss: 0.7986316897652366


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[274] train-loss: 0.23641795836962187


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[275] train-loss: 0.253397503724465


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[276] train-loss: 0.35768510286624616


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[276] validation-loss: 0.7843843806873668


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[277] train-loss: 0.30378912962399995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[278] train-loss: 1.1317318219404955


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[279] train-loss: 0.34801002190663266


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[279] validation-loss: 1.0000936768271707


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[280] train-loss: 0.4499189257621765


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[281] train-loss: 0.5100174775490394


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[282] train-loss: 0.31982452135819656


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[282] validation-loss: 1.0185739994049072


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[283] train-loss: 1.3196001603053167


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[284] train-loss: 0.5791657796272864


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[285] train-loss: 0.3394708266625038


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[285] validation-loss: 0.8385577418587424


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[286] train-loss: 0.6118014867489154


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[287] train-loss: 0.425413415982173


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[288] train-loss: 0.5259061318177444


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[288] validation-loss: 0.5479197285392068


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[289] train-loss: 0.3617287599123441


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[290] train-loss: 0.4108109199083768


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[291] train-loss: 0.36334670048493606


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[291] validation-loss: 0.49589749899777497


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[292] train-loss: 0.29678109746712905


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[293] train-loss: 0.16986804283582246


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[294] train-loss: 0.3172937516982739


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[294] validation-loss: 0.22207770564339377


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[295] train-loss: 1.937154648395685


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[296] train-loss: 0.6250267945803128


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[297] train-loss: 0.12443934724881099


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[297] validation-loss: 0.5994829914786599


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[298] train-loss: 0.23720374818031603


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[299] train-loss: 0.1586704391699571


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[300] train-loss: 0.15417841306099525


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[300] validation-loss: 0.5377556437795813


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[301] train-loss: 0.3144130248289842


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[302] train-loss: 2.2424109394733724


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[303] train-loss: 0.32712826362022984


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[303] validation-loss: 0.6220245903188532


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[304] train-loss: 0.45596903562545776


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[305] train-loss: 0.24879914980668288


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[306] train-loss: 0.27803564071655273


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[306] validation-loss: 0.6875350150195035


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[307] train-loss: 0.28599586624365586


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[308] train-loss: 0.6391289417560284


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[309] train-loss: 0.44680827397566575


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[309] validation-loss: 0.7705133828249845
DONE.


In [4]:
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [5]:
prediction

tensor([[-1.2149e+02,  5.6693e+01, -1.5790e+02,  1.2513e+02,  9.7539e+01],
        [-4.9917e+01, -1.1710e+01, -1.3494e+00,  5.9592e+01,  3.7572e+00],
        [-7.8551e+00,  1.5121e+01, -2.3986e+01,  5.4095e+00,  1.1265e+01],
        [ 2.0342e+01, -1.6284e+01,  1.6081e+01, -5.8594e+00, -1.4447e+01],
        [ 3.4142e+00, -9.1235e-01,  2.1889e+00, -4.1047e+00, -6.6818e-01],
        [ 2.0658e+01, -1.1394e+01,  1.0917e+01, -4.7126e+00, -1.5542e+01],
        [-9.9185e+01,  6.3831e+01, -1.5177e+02,  9.7483e+01,  8.9504e+01],
        [-2.1625e+02,  1.9040e+02, -4.5689e+02,  2.2397e+02,  2.5720e+02],
        [ 1.9817e+01,  1.1169e+00, -1.7972e+00, -5.1613e+00, -1.4060e+01],
        [ 2.9316e+00, -2.7252e-01,  4.8315e-01, -2.4243e+00, -7.8901e-01],
        [ 1.2003e+01, -8.4898e+00,  9.5590e+00, -6.5386e+00, -6.7070e+00],
        [ 1.1108e+01, -8.6098e+00,  8.9377e+00, -4.6214e+00, -6.9799e+00],
        [-3.8377e+01,  2.4375e+01, -5.0766e+01,  2.1482e+01,  4.2994e+01],
        [ 1.9014e+01, -1.

In [6]:
label

tensor([3, 3, 1, 2, 2, 0, 3, 1, 1, 4, 2, 2, 4, 2, 0, 1, 2, 3, 3, 1, 3],
       device='cuda:0')

In [7]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_train_hist, y_train_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = net(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [8]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.38      0.99      0.55      6339
           1       0.90      0.45      0.60      6175
           2       0.36      0.00      0.01      6165
           3       0.99      0.89      0.94      6208
           4       0.75      0.72      0.74      6113

    accuracy                           0.61     31000
   macro avg       0.68      0.61      0.57     31000
weighted avg       0.67      0.61      0.57     31000



In [9]:
confusion_matrix(all_labels, all_predictions)

array([[6279,   25,    0,    0,   35],
       [2637, 2806,    4,   20,  708],
       [6123,    2,   23,    0,   17],
       [   0,    4,    0, 5527,  677],
       [1317,  298,   37,   59, 4402]], dtype=int64)

In [10]:
confusion_matrix(all_labels, all_predictions)

array([[6279,   25,    0,    0,   35],
       [2637, 2806,    4,   20,  708],
       [6123,    2,   23,    0,   17],
       [   0,    4,    0, 5527,  677],
       [1317,  298,   37,   59, 4402]], dtype=int64)

In [11]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_valid_hist, y_valid_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = net(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [12]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.38      0.99      0.55      2727
           1       0.90      0.47      0.62      2714
           2       0.44      0.00      0.01      2710
           3       0.98      0.88      0.93      2714
           4       0.75      0.75      0.75      2735

    accuracy                           0.62     13600
   macro avg       0.69      0.62      0.57     13600
weighted avg       0.69      0.62      0.57     13600



In [13]:
print(confusion_matrix(all_labels, all_predictions))

[[2697   15    0    0   15]
 [1112 1270    5   14  313]
 [2690    0    7    0   13]
 [   0    0    0 2377  337]
 [ 511  125    4   31 2064]]


In [14]:
torch.save({'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': validation_loss.item()}, 'final_resnet')

In [22]:
model = torch.load("checkpoint_best_resnet")

In [16]:
model = CustomResNet(num_classes=5)
checkpoint = torch.load("checkpoint_best_default")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
model.cuda()

CustomResNet(
  (pre): Sequential(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [17]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_valid_hist, y_valid_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = model(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [18]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.72      0.27      0.40      2727
           1       0.00      0.00      0.00      2714
           2       0.00      0.00      0.00      2710
           3       0.57      1.00      0.73      2714
           4       0.14      0.40      0.21      2735

    accuracy                           0.34     13600
   macro avg       0.29      0.34      0.27     13600
weighted avg       0.29      0.34      0.27     13600



  _warn_prf(average, modifier, msg_start, len(result))


In [19]:
print(confusion_matrix(all_labels, all_predictions))

[[ 743    0    0   37 1947]
 [ 166    0    0  325 2223]
 [ 123    0    0   46 2541]
 [   0    0    0 2713    1]
 [   0    0    0 1628 1107]]
