In [2]:
import os
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from advertorch.context import ctx_noparamgrad_and_eval
from advertorch.test_utils import LeNet5
from advertorch_examples.utils import get_mnist_train_loader
from advertorch_examples.utils import get_mnist_test_loader
from advertorch_examples.utils import TRAINED_MODEL_PATH

In [7]:
# Run configuration: the knobs a reader is most likely to tweak.
seed = 0            # passed to torch.manual_seed in the next cell
mode = 'adv'        # training regime: 'cln' (clean) or 'adv' (adversarial)
train_bs, test_bs = 50, 1000   # mini-batch sizes for training / evaluation
log_interval = 200  # print a training progress line every this many batches

In [8]:
# Seed torch's RNG and pick the compute device.
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Select the training regime from `mode`. Both branches must define the same
# three names (flag_advtrain, nb_epoch, model_filename), because later cells
# read them unconditionally; previously the 'cln' branch misspelled the flag
# as `flag_adv_train`, so clean mode crashed with a NameError later on.
if mode == 'cln':
    flag_advtrain = False
    nb_epoch = 10
    model_filename = 'mnist_lenet5_clntrained.pth'
elif mode == 'adv':
    flag_advtrain = True
    nb_epoch = 90
    model_filename = 'mnist_lenet5_advtrained.pth'
else:
    raise RuntimeError('mode must be "cln" or "adv"')

In [10]:
# MNIST pipelines: shuffled mini-batches for training, fixed order for testing.
train_loader = get_mnist_train_loader(batch_size=train_bs, shuffle=True)
test_loader = get_mnist_test_loader(batch_size=test_bs, shuffle=False)

# LeNet-5 classifier on the selected device, optimised with Adam.
# (nn.Module.to moves the parameters in place and returns the module itself.)
model = LeNet5().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

0.00B [00:00, ?B/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /Users/tanimu/.advertorch/data/data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


 99%|█████████▉| 9.80M/9.91M [00:15<00:00, 1.83MB/s]

Extracting /Users/tanimu/.advertorch/data/data/mnist/MNIST/raw/train-images-idx3-ubyte.gz



0.00B [00:00, ?B/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /Users/tanimu/.advertorch/data/data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz



  0%|          | 0.00/28.9k [00:00<?, ?B/s][A
 57%|█████▋    | 16.4k/28.9k [00:00<00:00, 89.2kB/s][A
32.8kB [00:00, 57.9kB/s]                            [A
0.00B [00:00, ?B/s][A

Extracting /Users/tanimu/.advertorch/data/data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /Users/tanimu/.advertorch/data/data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz



  0%|          | 0.00/1.65M [00:00<?, ?B/s][A
  1%|          | 16.4k/1.65M [00:00<00:21, 77.4kB/s][A
  2%|▏         | 41.0k/1.65M [00:00<00:18, 89.1kB/s][A
  6%|▌         | 98.3k/1.65M [00:00<00:13, 113kB/s] [A
 11%|█▏        | 188k/1.65M [00:01<00:09, 149kB/s] [A
 18%|█▊        | 303k/1.65M [00:01<00:06, 193kB/s][A
 25%|██▍       | 410k/1.65M [00:01<00:05, 231kB/s][A
 39%|███▉      | 639k/1.65M [00:01<00:03, 306kB/s][A
 43%|████▎     | 713k/1.65M [00:01<00:02, 324kB/s][A
 47%|████▋     | 778k/1.65M [00:02<00:02, 304kB/s][A
 58%|█████▊    | 950k/1.65M [00:02<00:01, 398kB/s][A
 65%|██████▍   | 1.06M/1.65M [00:02<00:01, 486kB/s][A
 70%|███████   | 1.16M/1.65M [00:02<00:01, 491kB/s][A
 75%|███████▌  | 1.24M/1.65M [00:02<00:00, 521kB/s][A
 79%|███████▉  | 1.31M/1.65M [00:02<00:00, 532kB/s][A
 85%|████████▌ | 1.41M/1.65M [00:03<00:00, 542kB/s][A
 92%|█████████▏| 1.52M/1.65M [00:03<00:00, 635kB/s][A
9.92MB [00:20, 496kB/s]                             [A
Exception in thread 

Extracting /Users/tanimu/.advertorch/data/data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /Users/tanimu/.advertorch/data/data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


8.19kB [00:00, 17.7kB/s]                   

Extracting /Users/tanimu/.advertorch/data/data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!





In [12]:
if flag_advtrain:
    # Adversary used to perturb inputs both for training batches and for the
    # adversarial test-set evaluation; only built in adversarial mode.
    from advertorch.attacks import GradientSignAttack
    adversary = GradientSignAttack(
        model,
        loss_fn=nn.CrossEntropyLoss(reduction="sum"),
        eps=0.3,
        clip_min=0.0,
        clip_max=1.0,
        targeted=False,
    )

In [14]:
def _train_one_epoch(epoch):
    """Train `model` for one pass over `train_loader`.

    When `flag_advtrain` is set, every batch is replaced by its adversarial
    counterpart produced by `adversary` before the optimizer step. A progress
    line is printed every `log_interval` batches.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        if flag_advtrain:
            # When performing the attack the model must be in eval mode and
            # the parameters should NOT accumulate gradients; the context
            # manager switches both temporarily and restores them afterwards.
            with ctx_noparamgrad_and_eval(model):
                data = adversary.perturb(data, target)

        optimizer.zero_grad()
        output = model(data)
        # 'mean' replaces the deprecated 'elementwise_mean' alias.
        loss = F.cross_entropy(output, target, reduction='mean')
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def _batch_stats(inputs, target):
    """Return (summed cross-entropy loss, number of correct predictions) of
    `model` on one batch, computed without tracking gradients."""
    with torch.no_grad():
        output = model(inputs)
    loss_sum = F.cross_entropy(output, target, reduction='sum').item()
    correct = output.argmax(dim=1).eq(target).sum().item()
    return loss_sum, correct


def _evaluate():
    """Evaluate `model` on `test_loader` and print average loss and accuracy.

    Clean metrics are always reported; adversarial metrics (test inputs
    perturbed by `adversary`) are reported only when `flag_advtrain` is set.
    """
    model.eval()
    n_test = len(test_loader.dataset)
    test_clnloss, clncorrect = 0, 0
    test_advloss, advcorrect = 0, 0

    for clndata, target in test_loader:
        clndata, target = clndata.to(device), target.to(device)
        loss_sum, correct = _batch_stats(clndata, target)
        test_clnloss += loss_sum
        clncorrect += correct

        if flag_advtrain:
            # Same context as in training, so crafting the attack does not
            # leave gradient buffers on the model parameters.
            with ctx_noparamgrad_and_eval(model):
                advdata = adversary.perturb(clndata, target)
            loss_sum, correct = _batch_stats(advdata, target)
            test_advloss += loss_sum
            advcorrect += correct

    test_clnloss /= n_test
    print('\nTest set: avg cln loss: {:.4f},'
          ' cln acc: {}/{} ({:.0f}%)\n'.format(
              test_clnloss, clncorrect, n_test,
              100. * clncorrect / n_test))
    if flag_advtrain:
        test_advloss /= n_test
        print('Test set: avg adv loss: {:.4f},'
              ' adv acc: {}/{} ({:.0f}%)\n'.format(
                  test_advloss, advcorrect, n_test,
                  100. * advcorrect / n_test))


for epoch in range(nb_epoch):
    _train_one_epoch(epoch)
    _evaluate()


Test set: avg cln loss: 0.3822, cln acc: 9310/10000 (93%)

Test set: avg adv loss: 1.1265, adv acc: 6273/10000 (63%)


Test set: avg cln loss: 0.2200, cln acc: 9499/10000 (95%)

Test set: avg adv loss: 1.0391, adv acc: 6314/10000 (63%)


Test set: avg cln loss: 0.1895, cln acc: 9427/10000 (94%)

Test set: avg adv loss: 0.6429, adv acc: 8011/10000 (80%)


Test set: avg cln loss: 0.2426, cln acc: 9237/10000 (92%)

Test set: avg adv loss: 0.4250, adv acc: 8797/10000 (88%)


Test set: avg cln loss: 0.4287, cln acc: 8458/10000 (85%)

Test set: avg adv loss: 0.2796, adv acc: 9216/10000 (92%)


Test set: avg cln loss: 0.7200, cln acc: 7772/10000 (78%)

Test set: avg adv loss: 0.2011, adv acc: 9425/10000 (94%)


Test set: avg cln loss: 0.8167, cln acc: 7713/10000 (77%)

Test set: avg adv loss: 0.1621, adv acc: 9543/10000 (95%)


Test set: avg cln loss: 0.6945, cln acc: 8060/10000 (81%)

Test set: avg adv loss: 0.1285, adv acc: 9649/10000 (96%)


Test set: avg cln loss: 0.7529, cln acc: 7798/1

KeyboardInterrupt: 

In [None]:
# Persist the trained weights. torch.save does not create missing parent
# directories, so make sure the target directory exists first.
os.makedirs(TRAINED_MODEL_PATH, exist_ok=True)
torch.save(
    model.state_dict(),
    os.path.join(TRAINED_MODEL_PATH, model_filename))