In [1]:
%load_ext autoreload
%autoreload 2

import os
from tqdm.notebook import tqdm, trange

import torch
from torch import nn

from timm.utils import accuracy, AverageMeter

import cfg
from models import build_model
from data import build_loader
from utils.lr_scheduler import build_scheduler
from utils.optimizer import build_optimizer

In [2]:
def train(model, data_loader, criterion, optimizer):
    """Run one training epoch of ``model`` over ``data_loader``.

    Args:
        model: module to train; it is put into train mode. Each batch is moved
            to the device of the model's first parameter (CUDA if the model has
            no parameters, matching the previous hard-coded ``.cuda()``).
        data_loader: iterable yielding ``(images, labels)`` batches. This is the
            loader actually iterated (the previous version ignored it and read
            the notebook-global ``train_loader``).
        criterion: loss function, called as ``criterion(output, labels)``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.

    Returns:
        ``(avg_loss, avg_acc)``: batch-size-weighted means over the epoch of the
        loss and of the top-1 value returned by ``timm.utils.accuracy``.
    """
    model.train()

    # Follow the model's device rather than hard-coding .cuda(); for a model
    # living on the current CUDA device this is the same thing.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # The context manager closes the progress bar even if a batch raises.
    with tqdm(data_loader, leave=False) as pbar:
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            # Accuracy is only a metric, so compute it on a detached tensor.
            acc = accuracy(output.detach(), labels, topk=(1,))

            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            loss_meter.update(loss.item(), batch_size)
            acc_meter.update(acc[0].item(), batch_size)

            pbar.set_description(f"[Train] loss: {loss_meter.val:3.3f}, acc: {acc_meter.val:3.3f}")

    return loss_meter.avg, acc_meter.avg

def test(model, data_loader):
    model.eval()
    
    loss_meter = AverageMeter()
    acc_meter  = AverageMeter()
    
    pbar = tqdm(test_loader, leave=False)
    with torch.no_grad():
        for i, (images, labels) in enumerate( pbar ):
            images, labels = images.cuda(), labels.cuda()

            output = model(images)
            
            loss = criterion(output, labels)
            acc = accuracy(output, labels, topk=(1,))

            loss_meter.update(loss.item(), labels.size(0))
            acc_meter.update(acc[0].item(), labels.size(0))
            
            pbar.set_description(f"[Test] loss: {loss_meter.val:3.3f}, acc: {acc_meter.val:3.3f}")
    return loss_meter.avg, acc_meter.avg

In [3]:
# Build model, data, optimizer and schedule from the project config, then train
# and evaluate for cfg.epoch_num epochs.
# NOTE(review): no random seed is set in this notebook; confirm whether cfg /
# build_loader seed the RNGs, otherwise runs are not reproducible.
model = build_model()
# Move the model to the GPU *before* building the optimizer: torch.optim
# recommends constructing optimizers only after .cuda(), so they reference the
# final (GPU) parameters.
model.cuda()

train_dataset, train_loader, test_dataset, test_loader = build_loader()
optimizer = build_optimizer(model)
lr_scheduler = build_scheduler(optimizer)
criterion = nn.CrossEntropyLoss().cuda()

for epoch in range(1, cfg.epoch_num + 1):
    # Learning rate in effect for this epoch; the scheduler is stepped only
    # after the epoch's training and evaluation finish.
    lr = lr_scheduler.get_last_lr()[0]

    loss_train, acc_train = train(model, train_loader, criterion, optimizer)
    loss_test, acc_test   = test(model, test_loader)

    print(f"Epoch:{epoch:3}, lr={lr:.1e}, [Train] Loss:{loss_train:.2f}, Acc:{acc_train:.2f} | [Test] Loss:{loss_test:.2f}, Acc:{acc_test:.2f}", flush=True)

    lr_scheduler.step()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  1, lr=1.0e-01, [Train] Loss:2.18, Acc:25.04 | [Test] Loss:1.81, Acc:47.66


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  2, lr=3.0e-02, [Train] Loss:1.02, Acc:71.20 | [Test] Loss:1.43, Acc:61.81


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  3, lr=9.0e-03, [Train] Loss:0.45, Acc:87.28 | [Test] Loss:0.44, Acc:85.07


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  4, lr=2.7e-03, [Train] Loss:0.36, Acc:89.41 | [Test] Loss:0.34, Acc:90.50


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  5, lr=8.1e-04, [Train] Loss:0.35, Acc:90.03 | [Test] Loss:0.33, Acc:90.46


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  6, lr=2.4e-04, [Train] Loss:0.34, Acc:90.14 | [Test] Loss:0.33, Acc:90.54


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  7, lr=7.3e-05, [Train] Loss:0.34, Acc:90.26 | [Test] Loss:0.33, Acc:90.57


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  8, lr=2.2e-05, [Train] Loss:0.34, Acc:90.25 | [Test] Loss:0.33, Acc:90.55


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  9, lr=6.6e-06, [Train] Loss:0.34, Acc:90.21 | [Test] Loss:0.33, Acc:90.53


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch: 10, lr=2.0e-06, [Train] Loss:0.34, Acc:90.21 | [Test] Loss:0.33, Acc:90.52
