In [1]:
import torch
import pickle

with open( "x_train_hist.p", "rb" ) as file:
    x_train_hist = pickle.load(file)
    
with open( "y_train_hist.p", "rb" ) as file:
    y_train_hist = pickle.load(file)
    
with open( "x_valid_hist.p", "rb" ) as file:
    x_valid_hist = pickle.load(file)
    
with open( "y_valid_hist.p", "rb" ) as file:
    y_valid_hist = pickle.load(file)

In [2]:
import torch
import random
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm

from mriqa_dataset import MRIQADataset
from networks import ClassicCNN, PhilsClassicCnn, CatNet

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Link: https://arxiv.org/abs/2003.04696



In [9]:
# set random seeds for reproducibility
random.seed(21062020)
np.random.seed(21062020)
torch.manual_seed(21062020)
torch.cuda.manual_seed(21062020)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(torch.cuda.current_device())
torch.cuda.set_device(0)

net = ClassicCNN(num_classes=5)
net = net.cuda()

optimizer = optim.Adam(net.parameters())
ce = CrossEntropyLoss().cuda()
num_epochs = len(x_train_hist)
num_mini_batches = 13

loss_csv = open('losses.csv', 'w')
loss_csv.write('epoch,training,validation\n')

best_val_loss = 999
print("start training")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    net.train()
    
    train_batches = x_train_hist[epoch]
    train_labels = y_train_hist[epoch]

    # train loop
    for sample, label in tqdm(zip(train_batches, train_labels), total=len(train_batches), leave=False):       
        sample = sample.cuda()
        label = label.cuda()

        prediction = net(sample)
        loss = ce(prediction, label)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print('[{}] train-loss: {}'.format(epoch, epoch_loss / num_mini_batches))
    loss_csv.write(str(epoch) + ',' + str(epoch_loss / num_mini_batches))
    loss_csv.flush()

    # validation loop
    net.eval()
    mean_validation_loss = 0
    num_validation_mini_batches = 11

    if epoch % 3 == 0:
        valid_batches = x_valid_hist[epoch//5]
        valid_labels = y_valid_hist[epoch//5]
        
        with torch.no_grad():
            for sample, label in tqdm(zip(valid_batches, valid_labels), total=len(valid_batches), leave=False):              
                sample = sample.cuda()
                label = label.cuda()

                prediction = net(sample)
                validation_loss = ce(prediction, label)

                mean_validation_loss += validation_loss.item()
                
        print('[{}] validation-loss: {}'.format(epoch, mean_validation_loss / num_validation_mini_batches))
        loss_csv.write(',' + str(mean_validation_loss / num_validation_mini_batches) + '\n')
        loss_csv.flush()

    # save best model
    if mean_validation_loss <= best_val_loss:
        torch.save({'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': validation_loss.item()}, 'checkpoint_best_default')
        best_val_loss = mean_validation_loss

print('DONE.')

0
start training


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[0] train-loss: 0.42341728393848127


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[0] validation-loss: 0.4380675662647594


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[1] train-loss: 0.3507945720966046


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[2] train-loss: 0.38625192642211914


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[3] train-loss: 0.35492175358992356


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[3] validation-loss: 0.7539175640452992


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[4] train-loss: 0.35259028581472546


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[5] train-loss: 0.3265841557429387


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[6] train-loss: 0.3194340467453003


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[6] validation-loss: 0.8955813104456122


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[7] train-loss: 0.3502597717138437


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[8] train-loss: 0.3467416442357577


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[9] train-loss: 0.3254014620414147


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[9] validation-loss: 0.9784387891942804


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[10] train-loss: 0.2619530741985028


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[11] train-loss: 0.30314395977900577


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[12] train-loss: 0.3160180082687965


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[12] validation-loss: 0.8985635800795122


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[13] train-loss: 0.3661050017063434


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[14] train-loss: 0.25600262788625866


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[15] train-loss: 0.20970718677227312


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[15] validation-loss: 1.351287841796875


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[16] train-loss: 0.20951837301254272


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[17] train-loss: 0.24815439260922945


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[18] train-loss: 0.2439953638957097


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[18] validation-loss: 1.496446045962247


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[19] train-loss: 0.34168365368476283


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[20] train-loss: 0.3262795714231638


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[21] train-loss: 0.27080288300147426


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[21] validation-loss: 1.535530523820357


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[22] train-loss: 0.30969395545812756


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[23] train-loss: 0.21452608016821054


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[24] train-loss: 0.3365901754452632


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[24] validation-loss: 2.118982575156472


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[25] train-loss: 0.25214550586847156


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[26] train-loss: 0.29902841494633603


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[27] train-loss: 0.24929282298454872


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[27] validation-loss: 1.7768485329367898


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[28] train-loss: 0.3670247976596539


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[29] train-loss: 0.3916344321691073


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[30] train-loss: 0.3534027406802544


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[30] validation-loss: 2.7952571348710493


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[31] train-loss: 0.41274643861330473


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[32] train-loss: 0.3726689632122333


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[33] train-loss: 0.1973001269193796


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[33] validation-loss: 3.254867900501598


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[34] train-loss: 0.19262322783470154


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[35] train-loss: 0.3420147391465994


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[36] train-loss: 0.3088432137782757


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[36] validation-loss: 4.513531511480158


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[37] train-loss: 0.28187994773571307


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[38] train-loss: 0.24100836194478548


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[39] train-loss: 0.2713494667640099


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[39] validation-loss: 3.432167400013317


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[40] train-loss: 0.31826329689759475


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[41] train-loss: 0.4420559589679425


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[42] train-loss: 0.3142971350596501


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[42] validation-loss: 2.837898774580522


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[43] train-loss: 0.3181886902222267


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[44] train-loss: 0.38805202795909


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[45] train-loss: 0.2697858581176171


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[45] validation-loss: 2.7287784923206675


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[46] train-loss: 0.5349390323345478


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[47] train-loss: 0.3494921555885902


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[48] train-loss: 0.4311420229765085


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[48] validation-loss: 1.7628337686712092


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[49] train-loss: 0.42794187710835385


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[50] train-loss: 0.20396982935758737


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[51] train-loss: 0.28523128766279954


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[51] validation-loss: 0.7967987060546875


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[52] train-loss: 0.1870497052486126


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[53] train-loss: 0.38926950784829945


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[54] train-loss: 0.2222824555176955


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[54] validation-loss: 0.6971090706911954


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[55] train-loss: 0.21793772624089167


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[56] train-loss: 0.3934914286320026


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[57] train-loss: 0.2922425132531386


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[57] validation-loss: 1.019420558756048


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[58] train-loss: 0.4393705954918495


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[59] train-loss: 0.3432920300043546


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[60] train-loss: 0.4095709782380324


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[60] validation-loss: 0.8764701973308217


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[61] train-loss: 0.29751409934117246


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[62] train-loss: 0.49853490866147554


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[63] train-loss: 0.334253911788647


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[63] validation-loss: 0.7370816360820424


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[64] train-loss: 0.32731178632149327


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[65] train-loss: 0.1918871975862063


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[66] train-loss: 0.30471922342593855


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[66] validation-loss: 0.9118036573583429


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[67] train-loss: 0.26399338245391846


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[68] train-loss: 0.26685526967048645


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[69] train-loss: 0.33771945192263675


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[69] validation-loss: 0.6935779506509955


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[70] train-loss: 0.20428738227257362


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[71] train-loss: 0.1652642167531527


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[72] train-loss: 0.18008524179458618


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[72] validation-loss: 0.5655303326520053


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[73] train-loss: 0.3197782956636869


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[74] train-loss: 0.1728502741226783


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[75] train-loss: 0.27641671895980835


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[75] validation-loss: 0.8844837492162531


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[76] train-loss: 0.26454304731809175


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[77] train-loss: 0.2956060813023494


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[78] train-loss: 0.3646992353292612


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[78] validation-loss: 0.8848142407157205


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[79] train-loss: 0.22430206262148344


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[80] train-loss: 0.3023696725185101


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[81] train-loss: 0.2349852048433744


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[81] validation-loss: 0.6220294562253085


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[82] train-loss: 0.3676414994093088


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[83] train-loss: 0.44942540847338164


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[84] train-loss: 0.43424217517559344


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[84] validation-loss: 0.6233757409182462


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[85] train-loss: 0.2571628735615657


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[86] train-loss: 0.2398904378597553


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[87] train-loss: 0.4979912271866432


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[87] validation-loss: 0.4458650025454434


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[88] train-loss: 0.45757810886089617


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[89] train-loss: 0.36885859416081357


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[90] train-loss: 0.275381672840852


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[90] validation-loss: 0.40796609358354047


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[91] train-loss: 0.32488630138910735


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[92] train-loss: 0.1877019588763897


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[93] train-loss: 0.3417330773977133


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[93] validation-loss: 0.3121967315673828


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[94] train-loss: 0.21393106533930853


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[95] train-loss: 0.3006858779833867


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[96] train-loss: 0.23512874658291155


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[96] validation-loss: 0.38178780945864593


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[97] train-loss: 0.21107676854500404


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[98] train-loss: 0.28266456035467297


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[99] train-loss: 0.21728535340382502


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[99] validation-loss: 0.5220266038721258


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[100] train-loss: 0.18921577013455904


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[101] train-loss: 0.14460445596621588


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[102] train-loss: 0.22514711664273188


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[102] validation-loss: 0.6021915999325839


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[103] train-loss: 0.18322674127725455


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[104] train-loss: 0.19800645800737235


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[105] train-loss: 0.19319369701238778


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[105] validation-loss: 0.6845428076657382


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[106] train-loss: 0.2498872050872216


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[107] train-loss: 0.18577548403006333


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[108] train-loss: 0.18078124064665574


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[108] validation-loss: 0.5557658130472357


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[109] train-loss: 0.21797972000562227


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[110] train-loss: 0.21102570111934954


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[111] train-loss: 0.19614078219120318


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[111] validation-loss: 0.4094858711416071


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[112] train-loss: 0.2464060095640329


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[113] train-loss: 0.24066236385932335


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[114] train-loss: 0.33390371616070086


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[114] validation-loss: 0.30911002375862817


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[115] train-loss: 0.4003307590117821


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[116] train-loss: 0.4638450833467337


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[117] train-loss: 0.2425688046675462


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[117] validation-loss: 0.25621467286890204


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[118] train-loss: 0.28571218710679275


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[119] train-loss: 0.3466767301926246


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[120] train-loss: 0.31854531398186314


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[120] validation-loss: 0.18586986715143378


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[121] train-loss: 0.3425790575834421


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[122] train-loss: 0.2917117201364957


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[123] train-loss: 0.31102484923142654


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[123] validation-loss: 0.16689279675483704


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[124] train-loss: 0.21823823910493118


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[125] train-loss: 0.21599270518009478


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[126] train-loss: 0.3502857272441571


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[126] validation-loss: 0.20109735293821854


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[127] train-loss: 0.3888457509187552


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[128] train-loss: 0.18462701485707209


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[129] train-loss: 0.2794533371925354


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[129] validation-loss: 0.20378374511545355


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[130] train-loss: 0.2654757591394278


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[131] train-loss: 0.27094635596642125


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[132] train-loss: 0.20992767352324265


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[132] validation-loss: 0.1680654829198664


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[133] train-loss: 0.20808581205514762


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[134] train-loss: 0.3365643849739662


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[135] train-loss: 0.2489586197412931


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[135] validation-loss: 0.24300229007547552


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[136] train-loss: 0.2196094852227431


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[137] train-loss: 0.22048388077662542


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[138] train-loss: 0.27990391621222865


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[138] validation-loss: 0.23267974094911056


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[139] train-loss: 0.18457243992732122


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[140] train-loss: 0.22200559652768648


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[141] train-loss: 0.30425719114450306


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[141] validation-loss: 0.26386457681655884


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[142] train-loss: 0.31439295640358556


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[143] train-loss: 0.2676405356480525


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[144] train-loss: 0.1959784489411574


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[144] validation-loss: 0.24870925058018079


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[145] train-loss: 0.1920710985477154


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[146] train-loss: 0.20016279587378868


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[147] train-loss: 0.19753258044903094


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[147] validation-loss: 0.2524265646934509


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[148] train-loss: 0.15589526066413292


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[149] train-loss: 0.1831818360548753


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[150] train-loss: 0.239134655548976


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[150] validation-loss: 0.19151342998851428


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[151] train-loss: 0.15461141329545242


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[152] train-loss: 0.2919951035426213


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[153] train-loss: 0.15272740217355582


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[153] validation-loss: 0.20206820422952826


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[154] train-loss: 0.17593777179718018


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[155] train-loss: 0.2910252075928908


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[156] train-loss: 0.2471809799854572


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[156] validation-loss: 0.23173268274827438


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[157] train-loss: 0.2078191890166356


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[158] train-loss: 0.14884090652832618


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[159] train-loss: 0.16884370950552133


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[159] validation-loss: 0.2408955151384527


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[160] train-loss: 0.18717697033515343


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[161] train-loss: 0.17510931079204267


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[162] train-loss: 0.17270711522835952


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[162] validation-loss: 0.18860119581222534


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[163] train-loss: 0.15381132868620065


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[164] train-loss: 0.20310905346503624


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[165] train-loss: 0.21413850784301758


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[165] validation-loss: 0.2822602011940696


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[166] train-loss: 0.2655196121105781


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[167] train-loss: 0.30775060332738435


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[168] train-loss: 0.20753845343222985


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[168] validation-loss: 0.2714631774208762


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[169] train-loss: 0.18611975358082697


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[170] train-loss: 0.27955958247184753


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[171] train-loss: 0.25326473667071414


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[171] validation-loss: 0.2453188571062955


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[172] train-loss: 0.16770145067801842


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[173] train-loss: 0.15709512279583857


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[174] train-loss: 0.2243136282150562


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[174] validation-loss: 0.24059524861249057


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[175] train-loss: 0.15183493724236122


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[176] train-loss: 0.15821207945163435


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[177] train-loss: 0.22625751678760236


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[177] validation-loss: 0.21193553100932727


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[178] train-loss: 0.16156264222585237


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[179] train-loss: 0.16290463851048395


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[180] train-loss: 0.16167932748794556


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[180] validation-loss: 0.24381887370889838


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[181] train-loss: 0.22620294415033781


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[182] train-loss: 0.20772253091518694


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[183] train-loss: 0.28377363085746765


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[183] validation-loss: 0.22388346086848865


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[184] train-loss: 0.14745347774945772


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[185] train-loss: 0.2169713836449843


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[186] train-loss: 0.13830609504993147


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[186] validation-loss: 0.19007555463097311


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[187] train-loss: 0.10376079036639287


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[188] train-loss: 0.20176064509611863


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[189] train-loss: 0.19942782933895403


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[189] validation-loss: 0.20599958571520718


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[190] train-loss: 0.1893136455462529


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[191] train-loss: 0.20728067251352164


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[192] train-loss: 0.16365732137973493


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[192] validation-loss: 0.23590533841740002


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[193] train-loss: 0.19807085624107948


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[194] train-loss: 0.20607528778222892


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[195] train-loss: 0.1851056951742906


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[195] validation-loss: 0.20725520090623337


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[196] train-loss: 0.3777736792197594


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[197] train-loss: 0.16442282841755793


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[198] train-loss: 0.23489312942211443


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[198] validation-loss: 0.17716248468919235


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[199] train-loss: 0.15524354577064514


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[200] train-loss: 0.1596132585635552


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[201] train-loss: 0.17378923296928406


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[201] validation-loss: 0.2196437499739907


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[202] train-loss: 0.27523988026839036


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[203] train-loss: 0.17630068155435416


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[204] train-loss: 0.1679877627354402


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[204] validation-loss: 0.25539523363113403


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[205] train-loss: 0.27506731565182024


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[206] train-loss: 0.2254018783569336


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[207] train-loss: 0.22793070857341474


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[207] validation-loss: 0.18999731540679932


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[208] train-loss: 0.21311988968115586


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[209] train-loss: 0.2991527708677145


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[210] train-loss: 0.16829432890965387


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[210] validation-loss: 0.15418747067451477


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[211] train-loss: 0.23552852869033813


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[212] train-loss: 0.12393495211234459


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[213] train-loss: 0.20754629373550415


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[213] validation-loss: 0.16572658311236987


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[214] train-loss: 0.14480966329574585


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[215] train-loss: 0.3220497713639186


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[216] train-loss: 0.1593986818423638


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[216] validation-loss: 0.166817616332661


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[217] train-loss: 0.3470430695093595


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[218] train-loss: 0.25882338560544527


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[219] train-loss: 0.18449369302162758


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[219] validation-loss: 0.15524968775835904


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[220] train-loss: 0.12961703538894653


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[221] train-loss: 0.355316352385741


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[222] train-loss: 0.30544958206323475


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[222] validation-loss: 0.09425569664348256


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[223] train-loss: 0.1576694387655992


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[224] train-loss: 0.10357460264976208


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[225] train-loss: 0.35181012749671936


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[225] validation-loss: 0.08461710810661316


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[226] train-loss: 0.11648302582594064


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[227] train-loss: 0.11903101893571708


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[228] train-loss: 0.21225548249024612


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[228] validation-loss: 0.08621614358641884


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[229] train-loss: 0.1726722029539255


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[230] train-loss: 0.28469257859083325


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[231] train-loss: 0.4436582854160896


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[231] validation-loss: 0.06359713930975307


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[232] train-loss: 0.1835676087782933


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[233] train-loss: 0.10669443469781142


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[234] train-loss: 0.12003496060004601


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[234] validation-loss: 0.09791587564078244


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[235] train-loss: 0.07104930167014782


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[236] train-loss: 0.17863426758692816


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[237] train-loss: 0.42866922800357526


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[237] validation-loss: 0.20801458846439014


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[238] train-loss: 0.10138524266389701


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[239] train-loss: 0.11493583825918344


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[240] train-loss: 0.12566890166356012


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[240] validation-loss: 0.2803910049525174


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[241] train-loss: 0.11350958393170284


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[242] train-loss: 0.13458298032100385


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[243] train-loss: 0.46206493790333086


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[243] validation-loss: 0.37547601894898847


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[244] train-loss: 0.19472890633803147


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[245] train-loss: 0.3217836137001331


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[246] train-loss: 0.2585893686001117


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[246] validation-loss: 0.5248021862723611


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[247] train-loss: 0.2588069484784053


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[248] train-loss: 0.24282116844103888


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[249] train-loss: 0.13198582713420576


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[249] validation-loss: 0.5173618468371305


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[250] train-loss: 0.2517496874699226


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[251] train-loss: 0.23950735193032485


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[252] train-loss: 0.22110399833092323


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[252] validation-loss: 0.5398709557273171


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[253] train-loss: 0.3229786639030163


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[254] train-loss: 0.17893857910082892


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[255] train-loss: 0.22296856687619135


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[255] validation-loss: 0.2806470069018277


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[256] train-loss: 0.40876664565159726


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[257] train-loss: 0.19548027790509737


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[258] train-loss: 0.28181283061320966


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[258] validation-loss: 0.21916675025766547


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[259] train-loss: 0.15436282524695763


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[260] train-loss: 0.2477232596048942


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[261] train-loss: 0.09044450521469116


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[261] validation-loss: 0.3113218871029941


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[262] train-loss: 0.13890516757965088


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[263] train-loss: 0.22224258459531343


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[264] train-loss: 0.15459217933508065


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[264] validation-loss: 0.3102596673098477


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[265] train-loss: 0.13612683232013995


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[266] train-loss: 0.13798677692046532


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[267] train-loss: 0.1532556964800908


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[267] validation-loss: 0.2939804738218134


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[268] train-loss: 0.08216489507601811


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[269] train-loss: 0.21270855573507455


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[270] train-loss: 0.14308025057499224


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[270] validation-loss: 0.15816489132967862


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[271] train-loss: 0.3014990503971393


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[272] train-loss: 0.35457545748123753


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[273] train-loss: 0.12761941322913536


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[273] validation-loss: 0.15858991037715564


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[274] train-loss: 0.13192653197508591


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[275] train-loss: 0.09029131440015939


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[276] train-loss: 0.23993620047202477


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[276] validation-loss: 0.15729716691103848


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[277] train-loss: 0.13938506749960092


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[278] train-loss: 0.2816247733739706


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[279] train-loss: 0.21012987540318415


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[279] validation-loss: 0.1539155434478413


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[280] train-loss: 0.3878474602332482


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[281] train-loss: 0.16102816049869245


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[282] train-loss: 0.10898060982043926


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[282] validation-loss: 0.2257348190654408


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[283] train-loss: 0.2923880104835217


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[284] train-loss: 0.13460693909571722


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[285] train-loss: 0.1437571128973594


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[285] validation-loss: 0.24196385795419867


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[286] train-loss: 0.23814266461592454


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[287] train-loss: 0.1413058523948376


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[288] train-loss: 0.1631349279330327


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[288] validation-loss: 0.20974843339486557


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[289] train-loss: 0.1381316643494826


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[290] train-loss: 0.10876666238674751


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[291] train-loss: 0.14847710728645325


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[291] validation-loss: 0.16818052530288696


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[292] train-loss: 0.14295242153681242


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[293] train-loss: 0.14317604899406433


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[294] train-loss: 0.23985415697097778


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[294] validation-loss: 0.14985445954582907


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[295] train-loss: 0.2131994916842534


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[296] train-loss: 0.12357217073440552


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[297] train-loss: 0.10828957190880409


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[297] validation-loss: 0.11291038990020752


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[298] train-loss: 0.169912746319404


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[299] train-loss: 0.09695321779984695


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[300] train-loss: 0.10055775367296658


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[300] validation-loss: 0.11102031035856767


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[301] train-loss: 0.2804058836056636


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[302] train-loss: 0.23382666477790245


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[303] train-loss: 0.10160934466582078


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[303] validation-loss: 0.12032041495496576


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[304] train-loss: 0.2544305290167148


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[305] train-loss: 0.08974884335811321


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[306] train-loss: 0.11974337697029114


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[306] validation-loss: 0.1426882418719205


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[307] train-loss: 0.19191545706528884


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[308] train-loss: 0.37299243761942935


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[309] train-loss: 0.20720622860468352


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[309] validation-loss: 0.14478055997328323
DONE.


In [10]:
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [11]:
prediction

tensor([[ -8.9970,  -4.4122, -16.9409,  15.0178,  -2.2451],
        [ -0.8023, -13.6114,  -1.6850,  17.6273, -11.4828],
        [ -0.9889,   0.5246,  -2.2444,  -2.6369,   3.2851],
        [  2.7983,  -5.5457,   3.3898,  -2.1274,  -1.0825],
        [  1.2322,  -1.0881,   1.8989,  -3.3888,   1.4466],
        [  5.0590,  -4.1397,   1.3992,  -4.7918,  -2.9194],
        [ -3.7646,  -3.1720,  -5.9572,   6.3695,   0.4226],
        [ -2.4624,   5.6199, -16.9658,  -2.3868,   0.4178],
        [  1.1986,  -0.1129,  -2.8071,  -3.4948,  -0.1528],
        [  1.0674,  -0.9836,   2.2941,  -4.1562,   2.2164],
        [  1.6949,  -2.7851,   3.0106,  -3.4684,   1.0700],
        [  1.4913,  -2.0166,   3.0791,  -5.3501,   2.5055],
        [ -1.8201,   0.4987,  -1.5850,  -3.0029,   4.8967],
        [  2.6256,  -5.0120,   3.5869,  -2.6930,  -0.3711],
        [  4.7152,  -2.3928,   0.4477,  -5.8855,  -2.4076],
        [  3.7542,   0.1290,  -2.8193,  -5.3375,  -2.7252],
        [  1.1781,  -0.8517,   1.7547,  

In [12]:
label

tensor([3, 3, 1, 2, 2, 0, 3, 1, 1, 4, 2, 2, 4, 2, 0, 1, 2, 3, 3, 1, 3],
       device='cuda:0')

In [13]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_train_hist, y_train_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = net(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [14]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.69      0.93      0.79      6339
           1       0.99      0.38      0.55      6175
           2       0.86      0.92      0.89      6165
           3       1.00      0.99      1.00      6208
           4       0.73      0.86      0.79      6113

    accuracy                           0.82     31000
   macro avg       0.85      0.82      0.80     31000
weighted avg       0.85      0.82      0.80     31000



In [15]:
confusion_matrix(all_labels, all_predictions)

array([[5888,    4,  322,    0,  125],
       [1961, 2331,  104,    1, 1778],
       [ 425,    0, 5696,    0,   44],
       [   6,    0,    0, 6163,   39],
       [ 284,   15,  539,    0, 5275]], dtype=int64)

In [16]:
confusion_matrix(all_labels, all_predictions)

array([[5888,    4,  322,    0,  125],
       [1961, 2331,  104,    1, 1778],
       [ 425,    0, 5696,    0,   44],
       [   6,    0,    0, 6163,   39],
       [ 284,   15,  539,    0, 5275]], dtype=int64)

In [17]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_valid_hist, y_valid_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = net(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [18]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.69      0.94      0.80      2727
           1       0.99      0.38      0.55      2714
           2       0.87      0.92      0.90      2710
           3       1.00      0.99      0.99      2714
           4       0.72      0.87      0.79      2735

    accuracy                           0.82     13600
   macro avg       0.85      0.82      0.80     13600
weighted avg       0.85      0.82      0.80     13600



In [19]:
print(confusion_matrix(all_labels, all_predictions))

[[2551    3  116    0   57]
 [ 811 1028   48    0  827]
 [ 186    0 2502    0   22]
 [   2    0    0 2682   30]
 [ 132    5  208    0 2390]]


In [20]:
torch.save({'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': validation_loss.item()}, 'final_default')

In [22]:
model = torch.load("checkpoint_best_default")

In [26]:
model = ClassicCNN(num_classes=5)
optim.Adam(net.parameters())

checkpoint = torch.load("checkpoint_best_default")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
model.cuda()

ClassicCNN(
  (block1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (block2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block3): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (block4): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dil

In [27]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_valid_hist, y_valid_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = model(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [28]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.66      0.94      0.78      2727
           1       1.00      0.33      0.50      2714
           2       0.85      0.91      0.88      2710
           3       1.00      0.99      1.00      2714
           4       0.73      0.86      0.79      2735

    accuracy                           0.81     13600
   macro avg       0.85      0.81      0.79     13600
weighted avg       0.85      0.81      0.79     13600



In [29]:
print(confusion_matrix(all_labels, all_predictions))

[[2563    1  117    0   46]
 [ 959  897   57    0  801]
 [ 212    0 2479    0   19]
 [   2    0    0 2691   21]
 [ 140    3  250    0 2342]]
