In [1]:
%load_ext autoreload
%autoreload 2

import os
from tqdm.notebook import tqdm, trange

import torch
from torch import nn

from timm.utils import accuracy, AverageMeter

from models import build_model
from data import build_loader
from utils.lr_scheduler import build_scheduler
from utils.optimizer import build_optimizer

In [2]:
def train(model, data_loader, criterion, optimizer):
    """Run one training epoch over ``data_loader`` and return averaged metrics.

    Args:
        model: network to train. It is switched to train mode, and each batch is
            moved to the device of its first parameter.
        data_loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss function, called as ``criterion(output, labels)``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.

    Returns:
        ``(avg_loss, avg_acc)``, both averaged over samples (each batch is
        weighted by its size). ``avg_acc`` is top-1 accuracy as computed by
        ``timm.utils.accuracy``, which reports percentages.
    """
    model.train()
    # Follow the model's placement instead of hard-coding .cuda(), so the same
    # function also works for a CPU model (e.g. when debugging).
    device = next(model.parameters()).device

    loss_meter = AverageMeter()
    acc_meter  = AverageMeter()

    # Iterate the loader that was passed in (not a notebook-global one).
    # The context manager closes the progress bar even if an exception is raised.
    with tqdm(data_loader, leave=False) as pbar:
        for images, labels in pbar:
            optimizer.zero_grad()

            images, labels = images.to(device), labels.to(device)

            output = model(images)

            loss = criterion(output, labels)
            acc = accuracy(output, labels, topk=(1,))

            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            loss_meter.update(loss.item(), batch_size)
            acc_meter.update(acc[0].item(), batch_size)

            pbar.set_description(f"[Train] loss: {loss_meter.val:3.3f}, acc: {acc_meter.val:3.3f}")

    return loss_meter.avg, acc_meter.avg

def test(model, data_loader, criterion=None):
    """Evaluate ``model`` on ``data_loader`` without gradients; return averaged metrics.

    Args:
        model: network to evaluate. It is switched to eval mode, and each batch
            is moved to the device of its first parameter.
        data_loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss function, called as ``criterion(output, labels)``.
            Defaults to ``nn.CrossEntropyLoss()``, so the function no longer
            depends on a notebook-global ``criterion``.

    Returns:
        ``(avg_loss, avg_acc)``, both averaged over samples (each batch is
        weighted by its size). ``avg_acc`` is top-1 accuracy as computed by
        ``timm.utils.accuracy``, which reports percentages.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    model.eval()
    # Follow the model's placement instead of hard-coding .cuda().
    device = next(model.parameters()).device

    loss_meter = AverageMeter()
    acc_meter  = AverageMeter()

    # Iterate the loader that was passed in (not a notebook-global one).
    # The progress bar is closed on exit, matching train().
    with torch.no_grad(), tqdm(data_loader, leave=False) as pbar:
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            output = model(images)

            loss = criterion(output, labels)
            acc = accuracy(output, labels, topk=(1,))

            batch_size = labels.size(0)
            loss_meter.update(loss.item(), batch_size)
            acc_meter.update(acc[0].item(), batch_size)

            pbar.set_description(f"[Test] loss: {loss_meter.val:3.3f}, acc: {acc_meter.val:3.3f}")

    return loss_meter.avg, acc_meter.avg

In [15]:
# Training configuration.
epoch_num = 6
seed = 0             # seeds torch's RNG so weight init and data shuffling are repeatable
lr_step_gamma = 0.3  # StepLR multiplies the learning rate by this after every epoch

torch.manual_seed(seed)

model = build_model(None)
# Move the model to the GPU *before* constructing its optimizer, as the
# torch.optim documentation recommends.
model.cuda()

train_dataset, train_loader, test_dataset, test_loader = build_loader()
optimizer = build_optimizer(model)
# Step decay once per epoch. `verbose` is deprecated in recent PyTorch, and
# last_epoch=-1 is the default, so neither is passed.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=lr_step_gamma)
criterion = nn.CrossEntropyLoss().cuda()

for epoch in range(1, epoch_num + 1):
    # Learning rate in effect for this epoch (read before stepping the scheduler).
    lr = lr_scheduler.get_last_lr()[0]

    loss_train, acc_train = train(model, train_loader, criterion, optimizer)
    loss_test, acc_test   = test(model, test_loader)

    print(f"Epoch:{epoch:3}, lr={lr:.1e}, [Train] Loss:{loss_train:.2f}, Acc:{acc_train:.2f} | [Test] Loss:{loss_test:.2f}, Acc:{acc_test:.2f}", flush=True)

    lr_scheduler.step()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  1, lr=1.0e-01, [Train] Loss:2.23, Acc:23.13 | [Test] Loss:1.70, Acc:37.94


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  2, lr=3.0e-02, [Train] Loss:0.99, Acc:73.08 | [Test] Loss:0.77, Acc:74.50


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  3, lr=9.0e-03, [Train] Loss:0.53, Acc:84.52 | [Test] Loss:0.51, Acc:85.02


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  4, lr=2.7e-03, [Train] Loss:0.47, Acc:86.12 | [Test] Loss:0.45, Acc:86.84


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  5, lr=8.1e-04, [Train] Loss:0.45, Acc:86.61 | [Test] Loss:0.44, Acc:86.98


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=40.0), HTML(value='')))

Epoch:  6, lr=2.4e-04, [Train] Loss:0.45, Acc:86.71 | [Test] Loss:0.43, Acc:87.04
