In [1]:
from sam import SAM
from utility.step_lr import StepLR
from wide_res_net import WideResNet
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats
from smooth_cross_entropy import smooth_crossentropy

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.utils
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import numpy as np
import random
import os

import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
train_data = dsets.MNIST(root='data/',
                         train=True,
                         transform=transforms.ToTensor(),
                         download=True)

test_data = dsets.MNIST(root='data/',
                        train=False,
                        transform=transforms.ToTensor(),
                        download=True)

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [38]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.conv_layer = nn.Sequential(
            nn.Conv2d(1, 16, 5), #[100,16,24,24]
            nn.ReLU(),
            nn.Conv2d(16, 32, 5),#[100,32,20,20]
            nn.ReLU(),
            nn.MaxPool2d(2, 2),#100,32,10,10
            nn.Conv2d(32, 64, 5),#100,64,6,6
            nn.ReLU(),
            nn.MaxPool2d(2, 2)#100,64,3,3 -> 여기서 batch 빼고 나머지 64*3*3
        )
        
        self.fc_layer = nn.Sequential(
            nn.Linear(64*3*3, 100), #여기서 64*3*3 -> 100
            nn.ReLU(),
            nn.Linear(100, 10)#100 -> 10
        )       
        
    def forward(self,x):
        out = self.conv_layer(x)
        out = out.view(-1,64*3*3)
        out = self.fc_layer(out)

        return out

In [45]:
model = CNN().to(device)

In [46]:
base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, rho=2.0, adaptive=True, lr=0.1, momentum=0.9, weight_decay=0.0005)

In [41]:
optimizer._grad_norm()

tensor(0., device='cuda:0', dtype=torch.float64)

In [42]:
scheduler = StepLR(optimizer,0.1, 200)

In [47]:
batch_size = 100

train_loader = DataLoader(dataset=train_data,
                          batch_size=batch_size,
                          shuffle=True)

test_loader = DataLoader(dataset=test_data,
                         batch_size=batch_size,
                         shuffle=False)

In [51]:
for epoch in range(20):
    model.train()

    #for batch in train_data.train:
    for i, (batch_images, batch_labels) in enumerate(train_loader):
        inputs = batch_images.to(device)
        targets = batch_labels.to(device) 
        
        # first forward-backward step
        enable_running_stats(model)
        predictions = model(inputs)
#         loss = smooth_crossentropy(predictions, targets, smoothing=0.1)
#         loss.mean().backward()
        loss = nn.CrossEntropyLoss()(predictions, targets)
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # second forward-backward step
        disable_running_stats(model)
        #smooth_crossentropy(model(inputs), targets, smoothing=0.1).mean().backward()
        nn.CrossEntropyLoss()(model(inputs), targets).backward()
        optimizer.second_step(zero_grad=True)

        with torch.no_grad():
            correct = torch.argmax(predictions.data, 1) == targets
            #log(model, loss.cpu(), correct.cpu(), scheduler.lr())
            scheduler(epoch)

    model.eval()
  #  log.eval(len_dataset=len(dataset.test))

    with torch.no_grad():
        for i, (batch_images, batch_labels) in enumerate(test_loader):
        #for batch in dataset.test:
            inputs = batch_images.to(device)
            targets = batch_labels.to(device) 

            predictions = model(inputs)
#            loss = smooth_crossentropy(predictions, targets)
            correct = torch.argmax(predictions, 1) == targets
          #  log(model, loss.cpu(), correct.cpu())

correct = 0
total = 0

for i, (batch_images, batch_labels) in enumerate(test_loader):
    #for batch in dataset.test:
        inputs = batch_images.to(device)
        targets = batch_labels.to(device) 
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels.cuda()).sum()
    
print('Accuracy of test images: %f %%' % (100 * float(correct) / total))

Accuracy of test images: 11.000000 %


In [53]:
model = CNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [54]:
for epoch in range(20):
    model.train()

    #for batch in train_data.train:
    for i, (batch_images, batch_labels) in enumerate(train_loader):
        X = batch_images.cuda()
        Y = batch_labels.cuda()

        predictions = model(X)
        #loss = smooth_crossentropy(predictions, targets, smoothing=0.1)
        loss = nn.CrossEntropyLoss()(predictions, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
correct = 0
total = 0

for images, labels in test_loader:
    
    images = images.cuda()
    outputs = model(images)
    
    _, predicted = torch.max(outputs.data, 1)
    
    total += labels.size(0)
    correct += (predicted == labels.cuda()).sum()
    
print('Accuracy of test images: %f %%' % (100 * float(correct) / total))

Accuracy of test images: 16.220000 %
