In [1]:
from sam import SAM
from utility.step_lr import StepLR
from wide_res_net import WideResNet
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats
from smooth_cross_entropy import smooth_crossentropy
from utility.log import Log

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.utils
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import numpy as np
import random
import os

import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
# Load both CIFAR-10 splits from data/ (downloading them if they are missing) and turn
# every image into a tensor. The generator is unpacked in order: train split, then test split.
train_data, test_data = (
    dsets.CIFAR10(root='data/',
                  train=is_train,
                  transform=transforms.ToTensor(),
                  download=True)
    for is_train in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Seed every RNG this notebook relies on (Python, NumPy, PyTorch CPU and all CUDA devices).
# Doing this before the model cell makes weight initialisation and DataLoader shuffling
# repeatable under Restart Kernel -> Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [25]:
# NOTE(review): the positional args presumably mean depth=16, width factor=8, dropout=0.0 —
# confirm against wide_res_net.WideResNet. in_channels=3 / labels=10 match RGB CIFAR-10.
# The optimizer cell captures model.parameters(); if this cell is re-run on its own,
# re-run the optimizer/scheduler cells too, or they keep updating the old model's weights.
model = WideResNet(
    16, 8, 0.0,
    in_channels=3,
    labels=10,
).to(device)

In [6]:
# SAM is handed the optimizer *class* (not an instance) along with the model parameters.
# NOTE(review): lr / momentum / weight_decay are presumably forwarded to the SGD base
# optimizer, and rho / adaptive configure the SAM perturbation — confirm against sam.SAM.
base_optimizer = torch.optim.SGD
optimizer = SAM(
    model.parameters(),
    base_optimizer,
    rho=2.0,
    adaptive=True,
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)

In [7]:
# Learning-rate schedule from utility.step_lr, driving the SAM optimizer.
# NOTE(review): 0.1 is presumably the initial learning rate and 200 the total epoch count
# the schedule is laid out over. The training loop runs only 20 epochs, so if the decay
# points are fractions of 200 the LR may never decay — confirm against utility/step_lr.py.
scheduler = StepLR(
    optimizer,
    0.1,
    200,
)

In [8]:
# Progress logger from utility.log, shared by the train and eval phases below.
# log_each=10 is passed through unchanged; its exact meaning is defined in utility/log.py.
log = Log(
    log_each=10,
)

In [9]:
batch_size = 128

# Reshuffle the training split every epoch; iterate the test split in a fixed order
# so evaluation batches are the same on every run.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Train with SAM for 20 epochs; after every epoch, evaluate on the test split.
for epoch in range(20):
    model.train()
    log.train(len_dataset=len(train_data))
    for batch_images, batch_labels in train_loader:
        inputs = batch_images.to(device)
        targets = batch_labels.to(device)

        # First forward-backward pass at the current weights. enable_running_stats
        # presumably lets BatchNorm update its running statistics on this pass only
        # (see utility.bypass_bn).
        enable_running_stats(model)
        predictions = model(inputs)
        loss = smooth_crossentropy(predictions, targets, smoothing=0.1)
        loss.mean().backward()
        # SAM first step: presumably moves the weights to the perturbed point; confirm in sam.SAM.
        optimizer.first_step(zero_grad=True)

        # Second forward-backward pass, this time with running stats disabled so the
        # extra forward does not update them. second_step applies the actual update.
        disable_running_stats(model)
        smooth_crossentropy(model(inputs), targets, smoothing=0.1).mean().backward()
        optimizer.second_step(zero_grad=True)

        with torch.no_grad():
            # Accuracy is taken from the first-pass predictions computed above.
            correct = torch.argmax(predictions.data, 1) == targets
            log(model, loss.cpu(), correct.cpu(), scheduler.lr())
            scheduler(epoch)

    model.eval()
    # The eval phase iterates test_loader, so size the log by the test split
    # (previously len(train_data), which misreports the dataset size).
    log.eval(len_dataset=len(test_data))

    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            inputs = batch_images.to(device)
            targets = batch_labels.to(device)

            predictions = model(inputs)
            # NOTE(review): no smoothing kwarg here, so smooth_crossentropy's default applies.
            loss = smooth_crossentropy(predictions, targets)
            correct = torch.argmax(predictions, 1) == targets
            log(model, loss.cpu(), correct.cpu())

# Final top-1 accuracy of the trained model on the CIFAR-10 test split.
# Set eval mode explicitly so this cell does not depend on the training cell's last state.
model.eval()
correct = 0
total = 0

# No gradients are needed for evaluation; skipping autograd saves memory and time.
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        inputs = batch_images.to(device)
        targets = batch_labels.to(device)
        outputs = model(inputs)
        predicted = torch.argmax(outputs, 1)
        # Compare against this batch's targets (already on `device`); the old code used
        # an undefined `labels` name and a hard-coded .cuda() that fails on CPU.
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

print('Accuracy of test images: %f %%' % (100 * float(correct) / total))