In [1]:
import numpy as np
import torch
import adversary.cw as cw
from adversary.jsma import SaliencyMapMethod
from adversary.fgsm import Attack
import torchvision
import torch.nn.functional as F
import torch.utils.data as Data
from models.mnist_model import MnistModel, MLP
from torchvision import transforms

%reload_ext autoreload
%autoreload 2

In [2]:
MNIST_UNDERCOVER_CKPT = './checkpoint/mnist_undercover.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
transform_test = transforms.Compose([
    transforms.ToTensor(),
])

mlp = MLP().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(mlp.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)


undercoverNet = MnistModel().to(device)
checkpoint = torch.load(MNIST_UNDERCOVER_CKPT, map_location=torch.device(device))
undercoverNet.load_state_dict(checkpoint['net'])

<All keys matched successfully>

In [4]:
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_test)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True, num_workers=4)
trainiter = iter(trainloader)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=512, shuffle=False, num_workers=4)
testiter = iter(testloader)

# Take BIM attack as an example

In [5]:
undercover_gradient_attacker = Attack(undercoverNet, F.cross_entropy)

In [6]:
# construct bim adversarial samples
# --------------------train---------------------
normal_samples, adversarial_samples = [], []
for x, y in trainloader:
    x, y = x.to(device), y.to(device)
    y_pred = undercoverNet(x).argmax(dim=1)
    
    eps = 0.3
    x_adv = undercover_gradient_attacker.i_fgsm(x, y, eps=eps, alpha=1/255, iteration=int(min(eps*255 + 4, 1.25*eps*255)))
    y_pred_adv = undercoverNet(x_adv).argmax(dim=1)
    selected = (y == y_pred) & (y != y_pred_adv)
    normal_samples.append(x[selected].detach().cpu())
    adversarial_samples.append(x_adv[selected].detach().cpu())
#     break

normal_x = torch.cat(normal_samples, dim=0)
adversarial_x = torch.cat(adversarial_samples, dim=0)
normal_y = torch.zeros(normal_x.shape[0]).long()
adversarial_y = torch.ones(adversarial_x.shape[0]).long()

dba_trainloader = Data.DataLoader(Data.TensorDataset(torch.cat([normal_x, adversarial_x], dim=0),
                                           torch.cat([normal_y, adversarial_y], dim=0)), 
                                  batch_size=512, shuffle=True, num_workers=4)
dba_trainiter = iter(dba_trainloader)

# ----------------test---------------------
normal_samples, adversarial_samples = [], []
for x, y in testloader:
    x, y = x.to(device), y.to(device)
    y_pred = undercoverNet(x).argmax(dim=1)
    
    eps = 0.3
    x_adv = undercover_gradient_attacker.i_fgsm(x, y, eps=eps, alpha=1/255, iteration=int(min(eps*255 + 4, 1.25*eps*255)))
    y_pred_adv = undercoverNet(x_adv).argmax(dim=1)
    selected = (y == y_pred) & (y != y_pred_adv)
    normal_samples.append(x[selected].detach().cpu())
    adversarial_samples.append(x_adv[selected].detach().cpu())
#     break

normal_x = torch.cat(normal_samples, dim=0)
adversarial_x = torch.cat(adversarial_samples, dim=0)
normal_y = torch.zeros(normal_x.shape[0]).long()
adversarial_y = torch.ones(adversarial_x.shape[0]).long()

dba_testloader = Data.DataLoader(Data.TensorDataset(torch.cat([normal_x, adversarial_x], dim=0),
                                           torch.cat([normal_y, adversarial_y], dim=0)), 
                                  batch_size=1024, shuffle=True, num_workers=4)
dba_testiter = iter(dba_testloader)

In [7]:
# train the mlp
epochs = 10
for i in range(epochs):
    train_loss = 0
    correct, total = 0, 0
    for x, y in dba_trainloader:
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)
        _, V1 = undercoverNet(x, dba=True)
        undercover_adv = undercover_gradient_attacker.fgsm(x, y, False, 1/255)
        _, V2 = undercoverNet(undercover_adv, dba=True)
        V = torch.cat([V1, V2, V1 - V2, V1 * V2], axis=-1)
        y_pred = mlp(V)
        total += y.size(0)
        correct += y_pred.argmax(dim=1).eq(y).sum().item()
        
        loss = criterion(y_pred, y)
        train_loss += loss.item()
        #print('epoch: %d, train loss: %.2f, train acc: %.4f' % (epoch, train_loss, acc))
        loss.backward()
        optimizer.step()
    acc = 1.0 * correct / total
    print('epoch: %d, train loss: %.2f, train acc: %.4f' % (i, train_loss, acc))
    #print('epoch: %d, train loss: %.2f' % (i, train_loss))

epoch: 0, train loss: 26.68, train acc: 0.9622
epoch: 1, train loss: 5.93, train acc: 0.9941
epoch: 2, train loss: 3.98, train acc: 0.9959
epoch: 3, train loss: 3.08, train acc: 0.9968
epoch: 4, train loss: 2.49, train acc: 0.9973
epoch: 5, train loss: 2.13, train acc: 0.9978
epoch: 6, train loss: 1.82, train acc: 0.9980
epoch: 7, train loss: 1.66, train acc: 0.9982
epoch: 8, train loss: 1.44, train acc: 0.9985
epoch: 9, train loss: 1.28, train acc: 0.9988


In [8]:
# test
total, correct = 0, 0
for x, y in dba_testloader:
    x, y = x.to(device), y.to(device)
    _, V1 = undercoverNet(x, dba=True)
    undercover_adv = undercover_gradient_attacker.fgsm(x, y, False, 1/255)
    _, V2 = undercoverNet(undercover_adv, dba=True)
    V = torch.cat([V1, V2, V1 - V2, V1 * V2], axis=-1)
    y_pred = mlp(V).argmax(dim=1)
    
    total += y.size(0)
    correct += y_pred.eq(y).sum().item()
print(correct / total)

0.9983099186648358
