In [1]:
import torch
import pickle

with open( "x_train_hist.p", "rb" ) as file:
    x_train_hist = pickle.load(file)
    
with open( "y_train_hist.p", "rb" ) as file:
    y_train_hist = pickle.load(file)
    
with open( "x_valid_hist.p", "rb" ) as file:
    x_valid_hist = pickle.load(file)
    
with open( "y_valid_hist.p", "rb" ) as file:
    y_valid_hist = pickle.load(file)

In [2]:
import torch
import random
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm

from mriqa_dataset import MRIQADataset
from networks import ClassicCNN, PhilsClassicCnn, CatNet

If you use TorchIO for your research, please cite the following paper:
Pérez-García et al., TorchIO: a Python library for efficient loading,
preprocessing, augmentation and patch-based sampling of medical images
in deep learning. Link: https://arxiv.org/abs/2003.04696



In [3]:
import torch_optimizer as optim
num_epochs = len(x_train_hist)

net = CatNet(num_classes=5)
net = net.cuda()

#optimizer = optim.Adam(net.parameters())
optimizer = optim.Ranger(
    net.parameters(),
    lr=1e-3,
    alpha=0.5,
    k=6,
    N_sma_threshhold=5,
    betas=(.95, 0.999),
    eps=1e-5,
    weight_decay=0
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs*13)
ce = CrossEntropyLoss().cuda()

best_val_loss = 999

loss_csv = open('losses.csv', 'w')
loss_csv.write('epoch,training,validation\n')
num_mini_batches = 13
    
print("start training")
for epoch in range(num_epochs):
    epoch_loss = 0.0
    net.train()
    
    train_batches = x_train_hist[epoch]
    train_labels = y_train_hist[epoch]

    # train loop
    for sample, label in tqdm(zip(train_batches, train_labels), total=len(train_batches), leave=False):       
        sample = sample.cuda()
        label = label.cuda()

        prediction = net(sample)
        loss = ce(prediction, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 0.75)
        optimizer.step()

        epoch_loss += loss.item()

    print('[{}] train-loss: {}'.format(epoch, epoch_loss / num_mini_batches))
    loss_csv.write(str(epoch) + ',' + str(epoch_loss / num_mini_batches))
    loss_csv.flush()

    # validation loop
    net.eval()
    mean_validation_loss = 0
    num_validation_mini_batches = 11
    
    if epoch % 5 == 0:
        valid_batches = x_valid_hist[epoch//5]
        valid_labels = y_valid_hist[epoch//5]
        with torch.no_grad():
            for sample, label in tqdm(zip(valid_batches, valid_labels), total=len(valid_batches), leave=False):              
                sample = sample.cuda()
                label = label.cuda()

                prediction = net(sample)
                validation_loss = ce(prediction, label)

                mean_validation_loss += validation_loss.item()
            print('[{}] validation-loss: {}'.format(epoch, mean_validation_loss / num_validation_mini_batches))
            loss_csv.write(',' + str(mean_validation_loss / num_validation_mini_batches) + '\n')
            loss_csv.flush()

        # save best model
        if mean_validation_loss <= best_val_loss:
            torch.save({'epoch': epoch,
                        'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': validation_loss.item()}, 'checkpoint_best')
            best_val_loss = mean_validation_loss

print('DONE.')

start training


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

	addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
	addcmul_(Tensor tensor1, Tensor tensor2, *, Number value)


[0] train-loss: 1.5663099105541523


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[0] validation-loss: 1.6248558868061413


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[1] train-loss: 1.4350514870423536


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[2] train-loss: 1.343301617182218


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[3] train-loss: 1.296697708276602


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[4] train-loss: 1.319832104903001


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[5] train-loss: 1.3185409215780406


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[5] validation-loss: 1.3161169724030928


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[6] train-loss: 1.3369301099043627


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[7] train-loss: 1.296456318635207


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[8] train-loss: 1.1709290100977972


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[9] train-loss: 1.1543013132535493


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[10] train-loss: 1.1512054709287791


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[10] validation-loss: 1.1704936677759343


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[11] train-loss: 1.1887794091151311


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[12] train-loss: 1.1692465406197767


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[13] train-loss: 1.1074888889606183


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[14] train-loss: 1.175271327678974


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[15] train-loss: 1.19504609474769


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[15] validation-loss: 0.908751823685386


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[16] train-loss: 1.1725320770190313


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[17] train-loss: 1.1324917582365184


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[18] train-loss: 1.010369571355673


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[19] train-loss: 1.007484426865211


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[20] train-loss: 1.0286092849878163


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[20] validation-loss: 0.8291559273546393


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[21] train-loss: 1.0032782096129198


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[22] train-loss: 1.0709409759594843


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[23] train-loss: 0.96380726649211


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[24] train-loss: 0.9493971604567307


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[25] train-loss: 0.9657615790000329


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[25] validation-loss: 0.9359039501710371


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[26] train-loss: 0.9411839521848239


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[27] train-loss: 0.9290196162003738


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[28] train-loss: 1.0019193887710571


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[29] train-loss: 1.0233785280814538


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[30] train-loss: 0.8743363389602075


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[30] validation-loss: 1.0269771868532354


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[31] train-loss: 0.961085741336529


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[32] train-loss: 0.937915366429549


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[33] train-loss: 0.9274464066211994


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[34] train-loss: 0.9302069544792175


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[35] train-loss: 0.8937221490419828


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[35] validation-loss: 0.6343347376043146


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[36] train-loss: 0.8376940259566674


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[37] train-loss: 0.8696103646205022


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[38] train-loss: 0.8091880862529461


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[39] train-loss: 0.945244312286377


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[40] train-loss: 0.8257232124988849


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[40] validation-loss: 0.8560248071497137


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[41] train-loss: 0.8685288704358615


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[42] train-loss: 0.8153815315319941


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[43] train-loss: 0.8122140444242038


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[44] train-loss: 0.8229582401422354


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[45] train-loss: 0.8480713092363797


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[45] validation-loss: 0.6638910391113975


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[46] train-loss: 0.8407336404690375


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[47] train-loss: 0.7034790240801297


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[48] train-loss: 0.7212734589209924


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[49] train-loss: 0.7709763371027433


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[50] train-loss: 0.7834874391555786


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[50] validation-loss: 0.6243175241080198


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[51] train-loss: 0.6970295172471267


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[52] train-loss: 0.7949426449262179


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[53] train-loss: 0.7144562900066376


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[54] train-loss: 0.7893509406309861


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[55] train-loss: 0.8123552249028132


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[55] validation-loss: 0.8027955456213518


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[56] train-loss: 0.7514541240838858


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[57] train-loss: 0.7525284038140223


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[58] train-loss: 0.792020871089055


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[59] train-loss: 0.7482600372571212


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[60] train-loss: 0.7084405055412879


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[60] validation-loss: 0.590921862558885


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[61] train-loss: 0.7108936951710627


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[62] train-loss: 0.5958938873731173


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[63] train-loss: 0.6669151828839228


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[64] train-loss: 0.5320503092729129


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[65] train-loss: 0.6453406787835635


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[65] validation-loss: 0.47464531660079956


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[66] train-loss: 0.5823764801025391


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[67] train-loss: 0.6127231877583724


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[68] train-loss: 0.5779384443393121


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[69] train-loss: 0.660896014708739


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[70] train-loss: 0.5697948909722842


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[70] validation-loss: 0.6274202384731986


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[71] train-loss: 0.5993930720365964


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[72] train-loss: 0.5935636323231918


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[73] train-loss: 0.5824650571896479


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[74] train-loss: 0.5802050599685082


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[75] train-loss: 0.6674860578316909


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[75] validation-loss: 0.4724848527799953


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[76] train-loss: 0.659316924902109


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[77] train-loss: 0.5011428204866556


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[78] train-loss: 0.5760716944932938


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[79] train-loss: 0.5117720846946423


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[80] train-loss: 0.5548631388407487


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[80] validation-loss: 0.47996575994925067


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[81] train-loss: 0.5153241386780372


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[82] train-loss: 0.5926739940276513


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[83] train-loss: 0.542760255245062


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[84] train-loss: 0.6254781805551969


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[85] train-loss: 0.5536352373086489


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[85] validation-loss: 0.47074379975145514


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[86] train-loss: 0.5508189384753888


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[87] train-loss: 0.4673963257899651


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[88] train-loss: 0.5491276589723734


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[89] train-loss: 0.6116063801141886


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[90] train-loss: 0.5418936564372137


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[90] validation-loss: 0.36782884326848114


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[91] train-loss: 0.41918506759863633


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[92] train-loss: 0.5276372341009287


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[93] train-loss: 0.4969700391475971


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[94] train-loss: 0.4453338315853706


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[95] train-loss: 0.530036788720351


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[95] validation-loss: 0.3434470417824658


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[96] train-loss: 0.44548457860946655


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[97] train-loss: 0.5128662471587841


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[98] train-loss: 0.4376853314729837


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[99] train-loss: 0.4480158526163835


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[100] train-loss: 0.5419931182494531


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[100] validation-loss: 0.42586165395649994


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[101] train-loss: 0.3447047036427718


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[102] train-loss: 0.4400196694410764


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[103] train-loss: 0.41423561710577744


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[104] train-loss: 0.42909066952191866


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[105] train-loss: 0.4600322682123918


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[105] validation-loss: 0.3729186383160678


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[106] train-loss: 0.45558157104712266


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[107] train-loss: 0.4269019021437718


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[108] train-loss: 0.42725971111884486


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[109] train-loss: 0.39204625670726484


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[110] train-loss: 0.5101531102107122


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[110] validation-loss: 0.519530553709377


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[111] train-loss: 0.4200785435163058


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[112] train-loss: 0.4810237231162878


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[113] train-loss: 0.39123003299419695


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[114] train-loss: 0.3663769456056448


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[115] train-loss: 0.44212077672664934


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[115] validation-loss: 0.3494498512961648


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[116] train-loss: 0.46329203935769886


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[117] train-loss: 0.5627250098265134


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[118] train-loss: 0.3750616002541322


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[119] train-loss: 0.43592159335429853


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[120] train-loss: 0.4095888023193066


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[120] validation-loss: 0.28401614319194446


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[121] train-loss: 0.40917163972671217


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[122] train-loss: 0.4790360423234793


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[123] train-loss: 0.30039523656551653


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[124] train-loss: 0.39283884259370655


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[125] train-loss: 0.36178649159578175


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[125] validation-loss: 0.3601430221037431


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[126] train-loss: 0.39969143386070544


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[127] train-loss: 0.49484338897925156


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[128] train-loss: 0.3670039176940918


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[129] train-loss: 0.44891321888336766


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[130] train-loss: 0.4224446690999545


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[130] validation-loss: 0.3202317018400539


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[131] train-loss: 0.3898599629218762


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[132] train-loss: 0.453537975366299


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[133] train-loss: 0.2610333149249737


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[134] train-loss: 0.3631214407774118


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[135] train-loss: 0.4231100449195275


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[135] validation-loss: 0.39376775513995776


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[136] train-loss: 0.38381341787484974


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[137] train-loss: 0.3252615286753728


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[138] train-loss: 0.4471455571743158


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[139] train-loss: 0.3449232876300812


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[140] train-loss: 0.34609674031917864


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[140] validation-loss: 0.25406064601107076


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[141] train-loss: 0.2700243385938498


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[142] train-loss: 0.3676438216979687


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[143] train-loss: 0.28409592921917254


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[144] train-loss: 0.3209722271332374


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[145] train-loss: 0.3070245820742387


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[145] validation-loss: 0.2429404773495414


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[146] train-loss: 0.37617475940630984


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[147] train-loss: 0.3379121995889224


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[148] train-loss: 0.3071840829574145


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[149] train-loss: 0.29101608808224017


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[150] train-loss: 0.4010657255466168


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[150] validation-loss: 0.17803989215330643


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[151] train-loss: 0.3403455294095553


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[152] train-loss: 0.22469207415213951


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[153] train-loss: 0.31432802172807545


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[154] train-loss: 0.293023393704341


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[155] train-loss: 0.32450634011855495


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[155] validation-loss: 0.1751932745630091


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[156] train-loss: 0.24262640338677627


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[157] train-loss: 0.27466550584022814


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[158] train-loss: 0.3969537547001472


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[159] train-loss: 0.295732339987388


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[160] train-loss: 0.23089269720591032


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[160] validation-loss: 1.5743366046385332


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[161] train-loss: 0.2529308818853818


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[162] train-loss: 0.32370662689208984


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[163] train-loss: 0.2880419401022104


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[164] train-loss: 0.31170074756328875


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[165] train-loss: 0.3195749062758226


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[165] validation-loss: 0.23432310061021286


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[166] train-loss: 0.22323602323348707


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[167] train-loss: 0.23533281683921814


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[168] train-loss: 0.3139888667143308


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[169] train-loss: 0.362834625519239


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[170] train-loss: 0.29613868319071257


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[170] validation-loss: 0.9067453173073855


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[171] train-loss: 0.3573579467259921


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[172] train-loss: 0.3572752899848498


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[173] train-loss: 0.24163535237312317


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[174] train-loss: 0.40969256254342884


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[175] train-loss: 0.43124955204816967


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[175] validation-loss: 0.15276278961788525


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[176] train-loss: 0.2719179712809049


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[177] train-loss: 0.3290046017903548


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[178] train-loss: 0.2897542417049408


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[179] train-loss: 0.42372660453502947


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[180] train-loss: 0.20812406906714806


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[180] validation-loss: 0.21705276315862482


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[181] train-loss: 0.19693977328447196


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[182] train-loss: 0.3463057027413295


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[183] train-loss: 0.30352464891397035


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[184] train-loss: 0.26838953678424543


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[185] train-loss: 0.28411136682216936


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[185] validation-loss: 0.3914773301644759


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[186] train-loss: 0.2820671177827395


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[187] train-loss: 0.2157354515332442


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[188] train-loss: 0.2674389985891489


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[189] train-loss: 0.3331296764887296


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[190] train-loss: 0.1989699097780081


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[190] validation-loss: 0.33228324489160016


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[191] train-loss: 0.27381793237649477


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[192] train-loss: 0.24523580762056205


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[193] train-loss: 0.26145283992473894


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[194] train-loss: 0.35072561181508577


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[195] train-loss: 0.33475208282470703


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[195] validation-loss: 0.2931963828476993


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[196] train-loss: 0.4093075188306662


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[197] train-loss: 0.3319578766822815


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[198] train-loss: 0.2282063364982605


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[199] train-loss: 0.30541696686011094


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[200] train-loss: 0.20873530552937433


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[200] validation-loss: 0.19141494465822523


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[201] train-loss: 0.20184644139730012


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[202] train-loss: 0.27218876893703753


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[203] train-loss: 0.23576723383023188


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[204] train-loss: 0.32423400305784666


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[205] train-loss: 0.23786732325187096


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[205] validation-loss: 0.3990857668898322


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[206] train-loss: 0.23426352326686567


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[207] train-loss: 0.2754016105945294


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[208] train-loss: 0.20761186113724342


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[209] train-loss: 0.20198637934831473


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[210] train-loss: 0.22788914006489974


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[210] validation-loss: 0.1389593556523323


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[211] train-loss: 0.29000500990794253


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[212] train-loss: 0.22749988849346453


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[213] train-loss: 0.1789399775174948


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[214] train-loss: 0.28413620820412266


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[215] train-loss: 0.19841093054184547


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[215] validation-loss: 0.16926401582631198


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[216] train-loss: 0.18967708486777085


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[217] train-loss: 0.2670312959414262


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[218] train-loss: 0.17140851341761076


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[219] train-loss: 0.258374768954057


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[220] train-loss: 0.21293906752879804


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

[220] validation-loss: 0.24854535626416857


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[221] train-loss: 0.3396165370941162


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[222] train-loss: 0.24585402126495653


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[223] train-loss: 0.19631077005312994


HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

[224] train-loss: 0.24056900349947122
DONE.


In [4]:
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [5]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_train_hist, y_train_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = net(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [6]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.97      0.92      0.94      4592
           1       0.99      0.94      0.97      4372
           2       0.94      0.98      0.96      4542
           3       1.00      1.00      1.00      4555
           4       0.93      0.99      0.96      4439

    accuracy                           0.97     22500
   macro avg       0.97      0.97      0.97     22500
weighted avg       0.97      0.97      0.97     22500



In [None]:
confusion_matrix(all_labels, all_predictions)

In [7]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_valid_hist, y_valid_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = net(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [8]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.97      0.90      0.93       775
           1       0.98      0.94      0.96       862
           2       0.93      0.97      0.95       881
           3       1.00      1.00      1.00       855
           4       0.92      0.99      0.96       877

    accuracy                           0.96      4250
   macro avg       0.96      0.96      0.96      4250
weighted avg       0.96      0.96      0.96      4250



In [18]:
print(confusion_matrix(all_labels, all_predictions))

[[694   8  52   0  21]
 [ 19 808  11   0  24]
 [  1   0 854   0  26]
 [  0   0   0 855   0]
 [  0   6   4   0 867]]


In [10]:
torch.save({'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': validation_loss.item()}, 'final')

In [26]:
model = torch.load("checkpoint_best")

In [33]:
model = CatNet(num_classes=5)
optimizer = optim.Ranger(
    net.parameters(),
    lr=1e-3,
    alpha=0.5,
    k=6,
    N_sma_threshhold=5,
    betas=(.95, 0.999),
    eps=1e-5,
    weight_decay=0
)

checkpoint = torch.load("checkpoint_best")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
model.cuda()

CatNet(
  (block1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): mish_layer()
  )
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (block2): skip_connection_block(
    (conv_il): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (batn_il): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_ol): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (batn_ol): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_sl): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2))
    (batn_sl): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (block3): skip_connection_block(
    (conv_il): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
    (batn_il): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

In [37]:
all_labels = []
all_predictions = []

with torch.no_grad():
    for batch, labels in tqdm(zip(x_valid_hist, y_valid_hist)):
        for minibatch, label in zip(batch, labels):
            sample = minibatch.cuda()

            prediction = model(sample)
            all_predictions.append(prediction.cpu())
            
            all_labels.append(label)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [38]:
all_predictions = [item.argmax().tolist() for sublist in all_predictions for item in sublist]
all_labels = [item.tolist() for sublist in all_labels for item in sublist]
print(classification_report(all_labels, all_predictions))

              precision    recall  f1-score   support

           0       0.97      0.90      0.94       775
           1       0.92      0.94      0.93       862
           2       0.99      0.90      0.95       881
           3       1.00      0.99      1.00       855
           4       0.87      0.98      0.92       877

    accuracy                           0.95      4250
   macro avg       0.95      0.94      0.95      4250
weighted avg       0.95      0.95      0.95      4250

