In [1]:
from sam import SAM
from utility.step_lr import StepLR
from wide_res_net import WideResNet
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats
from smooth_cross_entropy import smooth_crossentropy
from utility.log import Log

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.utils
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import numpy as np
import random
import os

import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Load the CIFAR-10 train and test splits from ./data/, downloading them if missing.
# Images are only converted to tensors here; no normalization or augmentation is applied.
to_tensor = transforms.ToTensor()  # stateless, so both splits can share one instance
train_data, test_data = (
    dsets.CIFAR10(root='data/', train=is_train_split, transform=to_tensor, download=True)
    for is_train_split in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Use the first CUDA GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
# Wide ResNet for CIFAR-10 (3 input channels, 10 classes), moved to the selected device.
# NOTE(review): the positional arguments are presumably depth, width factor and dropout
# rate — confirm against wide_res_net.WideResNet.
model = WideResNet(
    16,
    8,
    0.0,
    in_channels=3,
    labels=10,
).to(device)

In [15]:
# SAM wraps a base optimizer class (plain SGD here).
# rho and adaptive configure the weight perturbation applied in first_step.
# lr, momentum and weight_decay are SGD hyperparameters, presumably forwarded to the
# base optimizer — confirm in sam.SAM.
base_optimizer = torch.optim.SGD
optimizer = SAM(
    model.parameters(),
    base_optimizer,
    rho=2.0,
    adaptive=True,
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)

In [17]:
# Step learning-rate schedule starting from the base LR 0.1.
# NOTE(review): the third argument is presumably the total number of training epochs
# the schedule is laid out over — confirm in utility.step_lr.StepLR.
# It was 200 while the training loop only runs 20 epochs. Any decay placed relative to a
# 200-epoch horizon therefore never occurred, and the LR stayed at 0.1 for the whole run.
EPOCHS = 20  # keep equal to the number of epochs the training loop runs
scheduler = StepLR(optimizer, 0.1, EPOCHS)

In [18]:
# Console logger that prints the per-epoch TRAIN / STATS / VALID table seen in the outputs.
# NOTE(review): log_each=10 presumably sets how many training steps pass between progress
# refreshes — confirm in utility.log.Log.
log = Log(log_each=10)

In [20]:
batch_size = 128

# Same batch size for both splits. Only the training split is shuffled, so the evaluation
# order stays fixed between epochs.
train_loader, test_loader = (
    DataLoader(dataset=split, batch_size=batch_size, shuffle=shuffle_split)
    for split, shuffle_split in ((train_data, True), (test_data, False))
)

In [21]:
# Train with SAM, then evaluate on the test split after every epoch.
# NOTE: the LR scheduler's total-epoch argument should agree with this epoch count.
for epoch in range(20):
    # ---- training ----
    model.train()
    # log is stepped once per batch, so size its progress from the number of batches.
    # With len(train_data) (sample count), the recorded progress bar barely moved.
    # NOTE(review): confirm len_dataset semantics in utility.log.Log.
    log.train(len_dataset=len(train_loader))
    # Set this epoch's learning rate once, before any update, so every batch of the epoch
    # uses it. Previously it was reset after each batch, which left the first batch of
    # every epoch on the previous epoch's rate.
    scheduler(epoch)

    for batch_images, batch_labels in train_loader:
        inputs = batch_images.to(device)
        targets = batch_labels.to(device)

        # First forward-backward pass at the current weights w_t.
        # enable_running_stats restores the BatchNorm momentum backed up by
        # disable_running_stats.
        enable_running_stats(model)
        predictions = model(inputs)
        loss = smooth_crossentropy(predictions, targets, smoothing=0.1)
        loss.mean().backward()  # gradient at w_t
        # Move w_t -> w_adv (perturbed weights); the original weights are kept as old_p.
        optimizer.first_step(zero_grad=True)

        # Second forward-backward pass at the perturbed weights w_adv.
        # disable_running_stats backs up BatchNorm momentum and sets module.momentum = 0.
        # With momentum 0, PyTorch BatchNorm leaves its running statistics unchanged on
        # this extra pass.
        disable_running_stats(model)
        smooth_crossentropy(model(inputs), targets, smoothing=0.1).mean().backward()
        # Restore old_p and take one base-optimizer (SGD) step using the w_adv gradient.
        optimizer.second_step(zero_grad=True)

        # Metrics are reported from the first pass (loss/predictions at w_t).
        with torch.no_grad():
            correct = predictions.argmax(dim=1) == targets
            log(model, loss.cpu(), correct.cpu(), scheduler.lr())

    # ---- evaluation ----
    model.eval()
    # Bug fix: this pass iterates test_loader, so size it from test_loader.
    # Previously len(train_data) was passed.
    log.eval(len_dataset=len(test_loader))

    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            inputs = batch_images.to(device)
            targets = batch_labels.to(device)

            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
            correct = predictions.argmax(dim=1) == targets
            log(model, loss.cpu(), correct.cpu())

# log.train() emits the previous epoch's results when the next epoch starts. Flush once
# more so the final epoch's validation row is printed too.
# NOTE(review): assumes utility.log.Log exposes flush() as in the upstream SAM repo.
log.flush()


┏━━━━━━━━━━━━━━┳━━━━━━━╸T╺╸R╺╸A╺╸I╺╸N╺━━━━━━━┳━━━━━━━╸S╺╸T╺╸A╺╸T╺╸S╺━━━━━━━┳━━━━━━━╸V╺╸A╺╸L╺╸I╺╸D╺━━━━━━━┓
┃              ┃              ╷              ┃              ╷              ┃              ╷              ┃
┃       epoch  ┃        loss  │    accuracy  ┃        l.r.  │     elapsed  ┃        loss  │    accuracy  ┃
┠──────────────╂──────────────┼──────────────╂──────────────┼──────────────╂──────────────┼──────────────┨
┃           0  ┃      1.1412  │     45.21 %  ┃   1.000e-01  │   01:19 min  ┃┈░┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┨      0.9705  │     53.03 %  ┃
┃           1  ┃      0.7107  │     67.71 %  ┃   1.000e-01  │   01:19 min  ┃┈░┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┨      0.6642  │     70.14 %  ┃
┃           2  ┃      0.5217  │     77.47 %  ┃   1.000e-01  │   01:20 min  ┃┈░┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┨      0.5884  │     73.68 %  ┃
┃           3  ┃      0.4131  │     82.50 %  ┃   1.000e-01  │   01:18 min  ┃┈░┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┈┨      0.5264  │     77.02 %  ┃
┃           4  ┃      0.