In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from models import *

import numpy as np
import matplotlib.pyplot as plt

import os
import datetime

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [3]:
cifar_train = datasets.CIFAR10('./data', train=True, download=True, transform=transforms.ToTensor())
cifar_test = datasets.CIFAR10('./data', train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(cifar_train, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(cifar_test, batch_size=128, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def imshow(img_tensor):
    img = torchvision.utils.make_grid(img_tensor)
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.show

def show_labels(labels):
    print(' '.join('%5s' % classes[labels[j]] for j in range(labels.size()[0])))

Files already downloaded and verified
Files already downloaded and verified


In [4]:
net = ResNet50()
net.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (

In [5]:
def fgsm(model, images, labels, epsilon):
    training =  model.training
    model.eval()
    delta = torch.zeros_like(images, requires_grad=True)
    loss = nn.CrossEntropyLoss()(model(images + delta), labels)
    loss.backward()
    perturbed_images = images + epsilon * delta.grad.detach().sign()
    perturbed_images = torch.clamp(perturbed_images, 0, 1)
    if training:
        model.train()
    return perturbed_images


def pgd(model, images, labels, epsilon, alpha=0.005, num_iter=20):
    training = model.training
    model.eval()
    delta = torch.zeros_like(images, requires_grad=True)
    for t in range(num_iter):
        loss = nn.CrossEntropyLoss()(model(images + delta), labels)
        loss.backward()
        delta.data = delta + alpha * delta.grad.detach().sign()
        delta.data = torch.clamp(delta, -epsilon, epsilon)
        delta.data = torch.max(torch.min(delta, 1 - images), -images)
        delta.grad.zero_()
    if training:
        model.train()
    return images + delta

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9, weight_decay=5e-4)
epsilon = 0.05

In [None]:
n_epoch = 350

for epoch in range(n_epoch):
    
    if epoch == 150:
        print('Change learning rate to 1e-2')
        optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4)
    elif epoch == 250:
        print('Change learning rate to 1e-3')
        optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
    
    running_loss = 0.0
    correct = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        
        adv_inputs = pgd(net, inputs, labels, epsilon).detach()
        
        optimizer.zero_grad()
        
        outputs = net(adv_inputs)
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted == labels).sum().item()
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    progress_text = '[%d] loss: %.9f accuracy: %.3f' % (epoch + 1, running_loss / len(cifar_train), 100 * correct / len(cifar_train))
    
    print(progress_text)
    
    with open('adv_training.txt', 'a') as file:
        file.write(str(datetime.datetime.now().time()) + '\n')
        file.write(progress_text + '\n')

    
    PATH = './resnet50_adversarial_training' + str(epoch + 1) + '.pth'
    torch.save(net.state_dict(), PATH)
    if epoch > 0:
        os.remove('resnet50_adversarial_training' + str(epoch) + '.pth')

print('Finished Training')

[1] loss: 0.024812463 accuracy: 10.320
[2] loss: 0.017727672 accuracy: 13.054
[3] loss: 0.016644739 accuracy: 20.956
[4] loss: 0.015772856 accuracy: 25.684
[5] loss: 0.015040252 accuracy: 28.576
[6] loss: 0.014415308 accuracy: 31.328
[7] loss: 0.013935585 accuracy: 33.522
[8] loss: 0.013568316 accuracy: 34.752
[9] loss: 0.013156182 accuracy: 37.050
[10] loss: 0.012769820 accuracy: 38.838
[11] loss: 0.012334563 accuracy: 40.554
[12] loss: 0.011889038 accuracy: 42.844
[13] loss: 0.011517749 accuracy: 44.706
[14] loss: 0.011129611 accuracy: 46.522
[15] loss: 0.010789480 accuracy: 48.176
[16] loss: 0.010475808 accuracy: 49.420
[17] loss: 0.010199490 accuracy: 50.780
[18] loss: 0.009963608 accuracy: 51.852
[19] loss: 0.009682876 accuracy: 53.230
[20] loss: 0.009491107 accuracy: 54.234


In [None]:
PATH = './resnet50_adversarial_training.pth'
torch.save(net.state_dict(), PATH)