In [1]:
import torch
import pickle

with open( "x_train_hist.p", "rb" ) as file:
    x_train_hist = pickle.load(file)
    
with open( "y_train_hist.p", "rb" ) as file:
    y_train_hist = pickle.load(file)
    
with open( "x_valid_hist.p", "rb" ) as file:
    x_valid_hist = pickle.load(file)
    
with open( "y_valid_hist.p", "rb" ) as file:
    y_valid_hist = pickle.load(file)

In [2]:
import torch
import random
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm

from mriqa_dataset import MRIQADataset
from networks import ClassicCNN, PhilsClassicCnn, CatNet

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Link: https://arxiv.org/abs/2003.04696



In [3]:
import torch_optimizer as optim
num_epochs = len(x_train_hist)

net = CatNet(num_classes=5)
net = net.cuda()

#optimizer = optim.Adam(net.parameters())
optimizer = optim.Ranger(
    net.parameters(),
    lr=9e-4,
    alpha=0.5,
    k=6,
    N_sma_threshhold=5,
    betas=(.95, 0.999),
    eps=1e-5,
    weight_decay=0
)

flat_lr = round(num_epochs * 0.7)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(num_epochs-flat_lr))
#lmbda = lambda epoch: 0.95
#scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)
ce = CrossEntropyLoss().cuda()

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']



print(f"I will do flat Learning rate for {flat_lr} epochs and then do {num_epochs-flat_lr} epochs cosine annealing")

I will do flat Learning rate for 217 epochs and then do 93 epochs cosine annealing


In [4]:
best_val_loss = 999

loss_csv = open('losses.csv', 'w')
loss_csv.write('epoch,training,validation\n')
num_mini_batches = 13
    
print("start training")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    net.train()
    
    train_batches = x_train_hist[epoch]
    train_labels = y_train_hist[epoch]

    # train loop
    for sample, label in tqdm(zip(train_batches, train_labels), total=len(train_batches), leave=False):       
        sample = sample.cuda()
        label = label.cuda()

        prediction = net(sample)
        loss = ce(prediction, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
        optimizer.step()
        epoch_loss += loss.item()
    
    # do annealing after 70% of the training time
    if epoch > flat_lr:
        scheduler.step()

    print('[{}] train-loss: {}'.format(epoch, epoch_loss / num_mini_batches))
    loss_csv.write(str(epoch) + ',' + str(epoch_loss / num_mini_batches))
    loss_csv.flush()

    # validation loop
    net.eval()
    mean_validation_loss = 0
    num_validation_mini_batches = 11
    
    if epoch % 3 == 0:
        valid_batches = x_valid_hist[epoch//5]
        valid_labels = y_valid_hist[epoch//5]
        with torch.no_grad():
            for sample, label in tqdm(zip(valid_batches, valid_labels), total=len(valid_batches), leave=False):              
                sample = sample.cuda()
                label = label.cuda()

                prediction = net(sample)
                validation_loss = ce(prediction, label)

                mean_validation_loss += validation_loss.item()
            print(f'[{epoch}] validation-loss: {mean_validation_loss / num_validation_mini_batches} Lerning Rate {get_lr(optimizer)}')
            loss_csv.write(',' + str(mean_validation_loss / num_validation_mini_batches) + '\n')
            loss_csv.flush()

        # save best model
        if mean_validation_loss <= best_val_loss:
            torch.save({'epoch': epoch,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': validation_loss.item()}, 'checkpoint_best')
            best_val_loss = mean_validation_loss

print('DONE.')

start training


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value)


[0] train-loss: 0.49916683710538423


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[0] validation-loss: 0.43635866858742456 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[1] train-loss: 0.5042969080118033


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[2] train-loss: 0.502527851324815


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[3] train-loss: 0.49011211211864764


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[3] validation-loss: 0.4354281642220237 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[4] train-loss: 0.4894049626130324


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[5] train-loss: 0.4703590044608483


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[6] train-loss: 0.4594911520297711


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[6] validation-loss: 0.4453528035770763 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[7] train-loss: 0.46967409207270694


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[8] train-loss: 0.4401281705269447


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[9] train-loss: 0.4788716664681068


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[9] validation-loss: 0.4659023935144598 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[10] train-loss: 0.41562962532043457


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[11] train-loss: 0.4464631997621976


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[12] train-loss: 0.43993000800792986


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[12] validation-loss: 0.4266378012570468 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[13] train-loss: 0.470884607388423


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[14] train-loss: 0.4073796639075646


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[15] train-loss: 0.4105930970265315


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[15] validation-loss: 0.3861066861586137 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[16] train-loss: 0.40913432378035325


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[17] train-loss: 0.4116976811335637


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[18] train-loss: 0.4236655326989981


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[18] validation-loss: 0.34754781289534137 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[19] train-loss: 0.42576966835902286


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[20] train-loss: 0.38862869372734654


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[21] train-loss: 0.40481441754561204


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[21] validation-loss: 0.35848441990939056 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[22] train-loss: 0.41857856053572434


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[23] train-loss: 0.39249067123119646


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[24] train-loss: 0.3916969574414767


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[24] validation-loss: 0.33690397305922076 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[25] train-loss: 0.3837587099808913


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[26] train-loss: 0.38879701724419224


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[27] train-loss: 0.3746478282488309


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[27] validation-loss: 0.34350822188637475 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[28] train-loss: 0.3741900370671199


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[29] train-loss: 0.40846112141242397


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[30] train-loss: 0.36036332753988415


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[30] validation-loss: 0.32118126479062165 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[31] train-loss: 0.3762082503392146


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[32] train-loss: 0.36641988387474644


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[33] train-loss: 0.3324785049145038


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[33] validation-loss: 0.31186134164983575 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[34] train-loss: 0.35815569987663853


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[35] train-loss: 0.3632758122224074


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[36] train-loss: 0.32744239843808687


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[36] validation-loss: 0.3447266600348733 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[37] train-loss: 0.32603997450608474


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[38] train-loss: 0.33413854929117054


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[39] train-loss: 0.3163247246008653


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[39] validation-loss: 0.30219819329001685 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[40] train-loss: 0.3254401317009559


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[41] train-loss: 0.31918196494762713


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[42] train-loss: 0.37064001193413365


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[42] validation-loss: 0.3085981932553378 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[43] train-loss: 0.33168972455538237


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[44] train-loss: 0.32873876278217024


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[45] train-loss: 0.32709018083719105


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[45] validation-loss: 0.31720795414664527 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[46] train-loss: 0.3144616622191209


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[47] train-loss: 0.35395089938090396


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[48] train-loss: 0.3299949077459482


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[48] validation-loss: 0.2996195771477439 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[49] train-loss: 0.3230504531126756


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[50] train-loss: 0.31540873417487514


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[51] train-loss: 0.31526952523451585


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[51] validation-loss: 0.29608034003864636 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[52] train-loss: 0.31543976068496704


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[53] train-loss: 0.32322662610274094


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[54] train-loss: 0.3014892477255601


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[54] validation-loss: 0.30930680578405206 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[55] train-loss: 0.3341328455851628


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[56] train-loss: 0.332170491035168


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[57] train-loss: 0.2936213337458097


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[57] validation-loss: 0.34855798157778656 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[58] train-loss: 0.32551870896266055


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[59] train-loss: 0.301493016573099


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[60] train-loss: 0.31627015425608707


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[60] validation-loss: 0.29563313180750067 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[61] train-loss: 0.32457224222329945


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[62] train-loss: 0.33328341520749605


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[63] train-loss: 0.29879181660138643


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[63] validation-loss: 0.3059271357276223 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[64] train-loss: 0.32299867043128383


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[65] train-loss: 0.32037102717619675


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[66] train-loss: 0.29560611798213077


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[66] validation-loss: 0.36945805766365747 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[67] train-loss: 0.3501461102412297


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[68] train-loss: 0.3010344367760878


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[69] train-loss: 0.3148274284142714


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[69] validation-loss: 0.26892654462294147 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[70] train-loss: 0.30844642565800595


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[71] train-loss: 0.31354783131526065


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[72] train-loss: 0.3120528001051683


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[72] validation-loss: 0.35635369474237616 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[73] train-loss: 0.3383937157117404


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[74] train-loss: 0.3160783556791452


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[75] train-loss: 0.29494993503277117


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[75] validation-loss: 0.33061734112826263 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[76] train-loss: 0.30615771275300246


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[77] train-loss: 0.312660093490894


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[78] train-loss: 0.3415491580963135


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[78] validation-loss: 0.2892673774199052 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[79] train-loss: 0.2940917656971858


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[80] train-loss: 0.3221342380230243


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[81] train-loss: 0.3176697355050307


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[81] validation-loss: 0.33147496526891534 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[82] train-loss: 0.3082576348231389


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[83] train-loss: 0.31940802244039684


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[84] train-loss: 0.31426969399819005


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[84] validation-loss: 0.27651014111258765 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[85] train-loss: 0.3092080492239732


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[86] train-loss: 0.31339472990769607


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[87] train-loss: 0.3408208718666664


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[87] validation-loss: 0.3071114258332686 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[88] train-loss: 0.3114491609426645


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[89] train-loss: 0.32675135135650635


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[90] train-loss: 0.3222773762849661


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[90] validation-loss: 0.2688968506726352 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[91] train-loss: 0.3153402897027823


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[92] train-loss: 0.2874324000798739


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[93] train-loss: 0.32406485080718994


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[93] validation-loss: 0.3493157083337957 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[94] train-loss: 0.31860289665368885


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[95] train-loss: 0.30930669949604916


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[96] train-loss: 0.2955138454070458


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[96] validation-loss: 0.2694063295017589 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[97] train-loss: 0.30423314296282256


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[98] train-loss: 0.30714727823550886


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[99] train-loss: 0.30335070995184094


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[99] validation-loss: 0.2864048751917752 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[100] train-loss: 0.30884966941980213


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[101] train-loss: 0.29455482959747314


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[102] train-loss: 0.31493149353907657


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[102] validation-loss: 0.2763637900352478 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[103] train-loss: 0.3126489290824303


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[104] train-loss: 0.30881510789577776


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[105] train-loss: 0.2991935473221999


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[105] validation-loss: 0.25867423144253815 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[106] train-loss: 0.32594149846297044


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[107] train-loss: 0.31903865245672375


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[108] train-loss: 0.3139895750926091


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[108] validation-loss: 0.3955482244491577 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[109] train-loss: 0.3162131951405452


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[110] train-loss: 0.3129394375360929


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[111] train-loss: 0.3037064579816965


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[111] validation-loss: 0.34258883649652655 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[112] train-loss: 0.3087924076960637


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[113] train-loss: 0.30215129026999843


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[114] train-loss: 0.3128982186317444


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[114] validation-loss: 0.35693802616812964 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[115] train-loss: 0.3555680834330045


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[116] train-loss: 0.32548458301104033


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[117] train-loss: 0.30502288616620576


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[117] validation-loss: 0.26113755052739923 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[118] train-loss: 0.3039194024526156


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[119] train-loss: 0.3152375267102168


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[120] train-loss: 0.346311289530534


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[120] validation-loss: 0.2923421968113292 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[121] train-loss: 0.32046616535920364


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[122] train-loss: 0.3111461951182439


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[123] train-loss: 0.30502350055254424


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[123] validation-loss: 0.3503661805933172 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[124] train-loss: 0.3016143716298617


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[125] train-loss: 0.2978878984084496


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[126] train-loss: 0.3146821673099811


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[126] validation-loss: 0.28378289396112616 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[127] train-loss: 0.31943899393081665


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[128] train-loss: 0.3167330026626587


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[129] train-loss: 0.32672399282455444


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[129] validation-loss: 0.3405072797428478 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[130] train-loss: 0.3221147839839642


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[131] train-loss: 0.32345022605015683


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[132] train-loss: 0.32837312955122727


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[132] validation-loss: 0.3776499032974243 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[133] train-loss: 0.28995400208693284


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[134] train-loss: 0.31305742263793945


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[135] train-loss: 0.2994735057537372


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[135] validation-loss: 0.37174390662800183 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[136] train-loss: 0.2898865800637465


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[137] train-loss: 0.3140704585955693


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[138] train-loss: 0.3414018016595107


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[138] validation-loss: 0.3023890798742121 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[139] train-loss: 0.2984007459420424


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[140] train-loss: 0.33249151706695557


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[141] train-loss: 0.33526955201075626


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[141] validation-loss: 0.2579460089856928 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[142] train-loss: 0.3408766251343947


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[143] train-loss: 0.30871509588681734


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[144] train-loss: 0.30197779031900257


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[144] validation-loss: 0.3207788142290982 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[145] train-loss: 0.2885102996459374


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[146] train-loss: 0.3241365597798274


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[147] train-loss: 0.3038995357660147


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[147] validation-loss: 0.3171453909440474 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[148] train-loss: 0.301070369206942


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[149] train-loss: 0.2948710001431979


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[150] train-loss: 0.32707683398173404


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[150] validation-loss: 0.35632275451313367 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[151] train-loss: 0.29379133077768177


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[152] train-loss: 0.3191033418361957


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[153] train-loss: 0.29189038276672363


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[153] validation-loss: 0.2921054796739058 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[154] train-loss: 0.2921450642439035


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[155] train-loss: 0.3500946851877066


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[156] train-loss: 0.3054706545976492


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[156] validation-loss: 0.34213819287040015 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[157] train-loss: 0.31606406431931716


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[158] train-loss: 0.31609919437995326


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[159] train-loss: 0.3046766519546509


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[159] validation-loss: 0.3606835061853582 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[160] train-loss: 0.33852439201795137


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[161] train-loss: 0.3469184637069702


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[162] train-loss: 0.29824226636153


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[162] validation-loss: 0.29703729261051526 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[163] train-loss: 0.2933225402465233


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[164] train-loss: 0.31843378910651576


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[165] train-loss: 0.30042025217643153


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[165] validation-loss: 0.36604767495935614 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[166] train-loss: 0.3388558534475473


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[167] train-loss: 0.3358268646093515


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[168] train-loss: 0.3089428956692035


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[168] validation-loss: 0.3395573767748746 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[169] train-loss: 0.30189181291140044


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[170] train-loss: 0.30003820474331194


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[171] train-loss: 0.336382833810953


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[171] validation-loss: 0.2786218090490861 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[172] train-loss: 0.296482306260329


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[173] train-loss: 0.30093106398215663


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[174] train-loss: 0.32129904398551357


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[174] validation-loss: 0.4243858944286 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[175] train-loss: 0.29919975996017456


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[176] train-loss: 0.34680174864255464


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[177] train-loss: 0.33698779344558716


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[177] validation-loss: 0.27981660582802514 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[178] train-loss: 0.3034809644405658


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[179] train-loss: 0.2912757213299091


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[180] train-loss: 0.29517079775150007


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[180] validation-loss: 0.2978478886864402 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[181] train-loss: 0.2879293836080111


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[182] train-loss: 0.31428150488780093


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[183] train-loss: 0.31780024675222546


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[183] validation-loss: 0.42082535136829724 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[184] train-loss: 0.32415918661997867


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[185] train-loss: 0.3070365465604342


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[186] train-loss: 0.32680498178188616


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[186] validation-loss: 0.396645415912975 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[187] train-loss: 0.2935000933133639


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[188] train-loss: 0.3011901103533231


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[189] train-loss: 0.32601800790199864


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[189] validation-loss: 0.36783991076729516 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[190] train-loss: 0.28886584593699527


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[191] train-loss: 0.34337463745704067


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[192] train-loss: 0.3154652898128216


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[192] validation-loss: 0.36620145494287665 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[193] train-loss: 0.3179017901420593


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[194] train-loss: 0.3027208172357999


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[195] train-loss: 0.3213662000802847


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[195] validation-loss: 0.2775180014696988 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[196] train-loss: 0.32424697509178746


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[197] train-loss: 0.2885848650565514


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[198] train-loss: 0.31181347828644973


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[198] validation-loss: 0.3043572794307362 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[199] train-loss: 0.281401473742265


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[200] train-loss: 0.2875552544227013


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[201] train-loss: 0.29764520204984224


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[201] validation-loss: 0.30994129180908203 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[202] train-loss: 0.2980296199138348


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[203] train-loss: 0.3079675527719351


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[204] train-loss: 0.3217271382992084


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[204] validation-loss: 0.2822950103066184 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[205] train-loss: 0.3181366966320918


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[206] train-loss: 0.3125396829385024


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[207] train-loss: 0.3145298361778259


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[207] validation-loss: 0.37261034141887317 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[208] train-loss: 0.3441704511642456


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[209] train-loss: 0.3225138233258174


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[210] train-loss: 0.29488208202215344


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[210] validation-loss: 0.29186035286296497 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[211] train-loss: 0.31593127434070295


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[212] train-loss: 0.3220343222984901


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[213] train-loss: 0.30219903817543614


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[213] validation-loss: 0.39624163237485016 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[214] train-loss: 0.29720289890582746


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[215] train-loss: 0.29598080194913423


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[216] train-loss: 0.28536383922283465


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[216] validation-loss: 0.3867712129246105 Lerning Rate 0.0009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[217] train-loss: 0.29438483715057373


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[218] train-loss: 0.3231753431833707


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[219] train-loss: 0.3125157677210294


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[219] validation-loss: 0.25549998066642066 Lerning Rate 0.0008989733766060318


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[220] train-loss: 0.28666096467238206


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[221] train-loss: 0.31021294685510487


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[222] train-loss: 0.3422632400806134


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[222] validation-loss: 0.31615048105066473 Lerning Rate 0.0008935964078916617


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[223] train-loss: 0.3158107537489671


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[224] train-loss: 0.288642576107612


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[225] train-loss: 0.30715877276200515


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[225] validation-loss: 0.33223006942055444 Lerning Rate 0.0008836675397504736


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[226] train-loss: 0.2983368864426246


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[227] train-loss: 0.2912676059282743


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[228] train-loss: 0.29744547605514526


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[228] validation-loss: 0.34846278754147614 Lerning Rate 0.0008692886558055009


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[229] train-loss: 0.3057612409958473


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[230] train-loss: 0.3285032006410452


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[231] train-loss: 0.3323679337134728


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[231] validation-loss: 0.2870998165824197 Lerning Rate 0.0008506073028637585


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[232] train-loss: 0.29346493574289173


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[233] train-loss: 0.29313435004307675


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[234] train-loss: 0.300510443173922


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[234] validation-loss: 0.33442715081301605 Lerning Rate 0.0008278151768863377


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[235] train-loss: 0.2967744515492366


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[236] train-loss: 0.30513554352980393


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[237] train-loss: 0.3299296865096459


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[237] validation-loss: 0.27481306141073053 Lerning Rate 0.0008011461559284415


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[238] train-loss: 0.2969381809234619


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[239] train-loss: 0.28917725269611066


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[240] train-loss: 0.28740598605229306


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[240] validation-loss: 0.3033814755353061 Lerning Rate 0.0007708739002340492


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[241] train-loss: 0.30670333825624907


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[242] train-loss: 0.3078686182315533


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[243] train-loss: 0.3426579741331247


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[243] validation-loss: 0.3016004508191889 Lerning Rate 0.0007373090441114927


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[244] train-loss: 0.2939205215527461


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[245] train-loss: 0.30604697649295515


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[246] train-loss: 0.3299674620995155


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[246] validation-loss: 0.2898300723596053 Lerning Rate 0.0007007960084050964


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[247] train-loss: 0.28475894377781796


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[248] train-loss: 0.2905223002800575


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[249] train-loss: 0.29245434816067034


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[249] validation-loss: 0.2607030922716314 Lerning Rate 0.0006617094662712402


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[250] train-loss: 0.3109155847476079


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[251] train-loss: 0.30292986447994524


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[252] train-loss: 0.2937870759230394


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[252] validation-loss: 0.37185352498834784 Lerning Rate 0.0006204504985247593


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[253] train-loss: 0.3023057671693655


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[254] train-loss: 0.30868467000814587


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[255] train-loss: 0.33084623171732974


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[255] validation-loss: 0.30325224182822486 Lerning Rate 0.000577442478007037


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[256] train-loss: 0.34381941648629993


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[257] train-loss: 0.2903720415555514


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[258] train-loss: 0.30217755757845366


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[258] validation-loss: 0.26418739557266235 Lerning Rate 0.0005331267252077353


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[259] train-loss: 0.2900801209303049


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[260] train-loss: 0.3035689959159264


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[261] train-loss: 0.3000643986922044


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[261] validation-loss: 0.2745296196504073 Lerning Rate 0.00048795797971937037


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[262] train-loss: 0.29189242307956403


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[263] train-loss: 0.2876467475524315


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[264] train-loss: 0.3069488772979149


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[264] validation-loss: 0.26643919402902777 Lerning Rate 0.000442399733993731


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[265] train-loss: 0.29511202298677885


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[266] train-loss: 0.28203398906267607


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[267] train-loss: 0.30303258620775664


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[267] validation-loss: 0.29189797423102637 Lerning Rate 0.00039691947728211923


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[268] train-loss: 0.28627116405046904


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[269] train-loss: 0.30444091099959153


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[270] train-loss: 0.2927669286727905


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[270] validation-loss: 0.2673622098836032 Lerning Rate 0.00035198389856301674


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[271] train-loss: 0.3244528632897597


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[272] train-loss: 0.33154059373415434


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[273] train-loss: 0.314312613927401


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[273] validation-loss: 0.2946392731233077 Lerning Rate 0.00030805409768163496


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[274] train-loss: 0.29164913525948155


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[275] train-loss: 0.2972509402495164


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[276] train-loss: 0.2829719919424791


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[276] validation-loss: 0.2612178163094954 Lerning Rate 0.0002655808538415357


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[277] train-loss: 0.2882932103597201


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[278] train-loss: 0.2814563191854037


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[279] train-loss: 0.2919858464827904


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[279] validation-loss: 0.26401940800926904 Lerning Rate 0.00022500000000000008


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[280] train-loss: 0.3082723938501798


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[281] train-loss: 0.2868053683867821


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[282] train-loss: 0.2867582211127648


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[282] validation-loss: 0.2602863203395497 Lerning Rate 0.0001867279506321116


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[283] train-loss: 0.30642381539711583


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[284] train-loss: 0.28461209627298206


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[285] train-loss: 0.2949374547371498


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[285] validation-loss: 0.24833854220130228 Lerning Rate 0.00015115742875474234


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[286] train-loss: 0.31627442286564755


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[287] train-loss: 0.2848939253733708


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[288] train-loss: 0.2875528335571289


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[288] validation-loss: 0.24687653238123114 Lerning Rate 0.00011865343605696116


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[289] train-loss: 0.28112385823176456


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[290] train-loss: 0.2806984002773578


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[291] train-loss: 0.28177035313386184


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[291] validation-loss: 0.2524741237813776 Lerning Rate 8.954950748877685e-05


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[292] train-loss: 0.2785424773509686


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[293] train-loss: 0.2786542452298678


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[294] train-loss: 0.3222447175246019


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[294] validation-loss: 0.24965945157137784 Lerning Rate 6.414428874120698e-05


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[295] train-loss: 0.3251510170789865


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[296] train-loss: 0.2896566161742577


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[297] train-loss: 0.28329981290377104


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[297] validation-loss: 0.24987887252460828 Lerning Rate 4.269847173735541e-05


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[298] train-loss: 0.28746349536455595


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[299] train-loss: 0.28011125784653884


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[300] train-loss: 0.28352511387604934


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[300] validation-loss: 0.25240504199808295 Lerning Rate 2.5432119580508876e-05


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[301] train-loss: 0.306500269816472


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[302] train-loss: 0.2839350700378418


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[303] train-loss: 0.2860154142746559


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[303] validation-loss: 0.2524645816196095 Lerning Rate 1.252240840890424e-05


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[304] train-loss: 0.3012666152073787


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[305] train-loss: 0.28084248762864333


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[306] train-loss: 0.28483779155291045


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[306] validation-loss: 0.2510438453067433 Lerning Rate 4.101809328792519e-06


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[307] train-loss: 0.2901087999343872


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[308] train-loss: 0.29894136924010056


HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

[309] train-loss: 0.28191441756028396


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

[309] validation-loss: 0.2511053518815474 Lerning Rate 2.567290816268995e-07
DONE.


In [5]:
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [6]:
prediction

tensor([[8.1751e-09, 7.7866e-07, 9.1112e-08, 1.0000e+00, 1.1654e-09],
        [7.8917e-09, 7.9131e-07, 1.0044e-07, 1.0000e+00, 1.7672e-09],
        [2.6293e-07, 1.0000e+00, 5.9445e-11, 5.9596e-10, 8.3633e-08],
        [5.4529e-07, 5.5853e-08, 1.0000e+00, 2.9399e-08, 1.1039e-06],
        [2.1192e-08, 2.3380e-12, 1.0000e+00, 4.1937e-11, 5.7587e-11],
        [9.9927e-01, 1.3875e-04, 5.0521e-04, 7.6006e-05, 5.2926e-06],
        [9.2417e-09, 2.5784e-07, 7.7077e-08, 1.0000e+00, 3.5888e-10],
        [4.8748e-12, 1.0000e+00, 9.2828e-12, 9.4536e-10, 1.0514e-09],
        [1.1424e-08, 1.0000e+00, 2.1505e-10, 4.7224e-08, 8.8493e-09],
        [4.8963e-06, 3.7303e-07, 1.4507e-09, 7.5827e-11, 9.9999e-01],
        [8.9398e-06, 3.5119e-09, 9.9999e-01, 3.8491e-08, 2.3634e-08],
        [1.7540e-07, 5.6175e-10, 1.0000e+00, 4.0932e-10, 1.5227e-08],
        [1.0735e-06, 6.0700e-08, 1.9198e-06, 2.2887e-09, 1.0000e+00],
        [8.3475e-07, 1.5465e-08, 1.0000e+00, 2.0683e-08, 8.1588e-07],
        [9.9997e-01,

In [7]:
label

tensor([3, 3, 1, 2, 2, 0, 3, 1, 1, 4, 2, 2, 4, 2, 0, 1, 2, 3, 3, 1, 3],
       device='cuda:0')

In [8]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_train_hist, y_train_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = net(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [9]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.98      0.99      0.99      6339
           1       1.00      0.96      0.98      6175
           2       0.99      0.99      0.99      6165
           3       1.00      1.00      1.00      6208
           4       0.96      0.99      0.98      6113

    accuracy                           0.99     31000
   macro avg       0.99      0.99      0.99     31000
weighted avg       0.99      0.99      0.99     31000



In [11]:
confusion_matrix(all_labels, all_predictions)

array([[6295,    1,   40,    0,    3],
       [ 185, 5806,   23,    0,  161],
       [   7,    0, 6158,    0,    0],
       [   0,    0,    0, 6208,    0],
       [  41,    7,   32,    0, 6033]], dtype=int64)

In [10]:
confusion_matrix(all_labels, all_predictions)

array([[6283,    3,   20,    0,   33],
       [  52, 5898,    9,    0,  216],
       [  40,    0, 6125,    0,    0],
       [   0,    0,    0, 6208,    0],
       [  28,    2,   13,    0, 6070]], dtype=int64)

In [11]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_valid_hist, y_valid_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = net(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [12]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.98      1.00      0.99      2727
           1       0.99      0.96      0.97      2714
           2       0.99      1.00      0.99      2710
           3       1.00      1.00      1.00      2714
           4       0.98      0.98      0.98      2735

    accuracy                           0.99     13600
   macro avg       0.99      0.99      0.99     13600
weighted avg       0.99      0.99      0.99     13600



In [13]:
print(confusion_matrix(all_labels, all_predictions))

[[2717    0    6    0    4]
 [  45 2605    8    0   56]
 [   2    0 2707    0    1]
 [   0    0    0 2714    0]
 [  10   25   11    0 2689]]


In [10]:
torch.save({'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': validation_loss.item()}, 'final')

In [11]:
model = torch.load("checkpoint_best")

In [12]:
model = CatNet(num_classes=5)
optimizer = optim.Ranger(
    net.parameters(),
    lr=1e-3,
    alpha=0.5,
    k=6,
    N_sma_threshhold=5,
    betas=(.95, 0.999),
    eps=1e-5,
    weight_decay=0
)

checkpoint = torch.load("checkpoint_best")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
model.cuda()

CatNet(
  (block1): Sequential(
    (0): Conv2d(1, 82, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (1): BatchNorm2d(82, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): mish_layer()
  )
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): skip_connection_block(
    (conv_il): Conv2d(82, 82, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (batn_il): BatchNorm2d(82, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_ol): Conv2d(82, 82, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batn_ol): BatchNorm2d(82, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_sl): Conv2d(82, 82, kernel_size=(1, 1), stride=(2, 2))
    (batn_sl): BatchNorm2d(82, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block3): skip_connection_block(
    (conv_il): Conv2d(82, 82, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (batn_il): BatchNor

In [13]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_valid_hist, y_valid_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = model(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [14]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.99      0.99      0.99      2727
           1       1.00      0.97      0.98      2714
           2       0.99      1.00      0.99      2710
           3       1.00      1.00      1.00      2714
           4       0.98      1.00      0.99      2735

    accuracy                           0.99     13600
   macro avg       0.99      0.99      0.99     13600
weighted avg       0.99      0.99      0.99     13600



In [15]:
print(confusion_matrix(all_labels, all_predictions))

[[2698    3   16    0   10]
 [   8 2642   10    1   53]
 [   0    0 2708    0    2]
 [   0    0    0 2714    0]
 [   6    6    0    0 2723]]
