In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
device = 'cuda:3' if torch.cuda.is_available() else 'cpu'

In [3]:
torch.cuda.set_device(device)

In [3]:
from utils import load_mnist, load_emnist_letters

In [4]:
# load cifar10 data
trainloader, testloader = load_mnist(data_dir="../data/mnist", batch_size=128, 
                                       test_batch = 128,train_shuffle=True)
# classes = ('plane', 'car', 'bird', 'cat', 'deer',
#            'dog', 'frog', 'horse', 'ship', 'truck')

In [5]:
from utils import train_valid_split

In [6]:
# load dataloader for steal
stealloader,st_testloader = load_emnist_letters(data_dir="../data/mnist", batch_size=128, train_shuffle=False)

In [9]:
st_trloader, st_valloader = train_valid_split(stealloader,10000, datatype="mnist",seed=1228)

total data: 10000


In [37]:
class args:
    save_dir = "../results"
    orig_model = "mnist_orig_mpc.pth"
#     fake_model = "mnist_fake_mpc_minloss_swd.pth"
#     fake_model = "mnist_fake_mpc_minloss.pth"
    fake_model = "mnist_fake_small_mpc_minloss.pth"
    fake_model_swd = "mnist_fake_small_mpc_minloss_swd.pth"
    tau = 0.9
#     nb_stolen = 10000 # attack model 학습시 사용할 데이터 수
#     st_trloader = "attnet_trloader.dl"
#     st_valloader = "attnet_valloader.dl"
#     att_epochs = 300
    att_model = "mnist_attacknet_small_mpc_ml_tau%s.pth"%tau
    att_model_swd = "mnist_attacknet_small_mpc_ml_swd_tau%s.pth"%tau
#     att_model = "mnist_attacknet_mpc_ml_swd_tau%s.pth"%tau

In [30]:
args.att_model

'mnist_attacknet_mpc_ml_tau0.9.pth'

In [17]:
# from cifar_models import Net, Net_logit, AttackNet
from mnist_models import AttackNetMNIST, small_AttackNetMNIST
from utils import CombNet, CombNet_logit, Net_softmax

In [21]:
import os
net = AttackNetMNIST()
net.load_state_dict(torch.load(os.path.join(args.save_dir,args.orig_model),map_location='cpu'))
net = Net_softmax(net)

net_fake = small_AttackNetMNIST()
net_fake.load_state_dict(torch.load(os.path.join(args.save_dir,args.fake_model),map_location='cpu'))
net_fake = Net_softmax(net_fake)

net_fake_swd = small_AttackNetMNIST()
net_fake_swd.load_state_dict(torch.load(os.path.join(args.save_dir,args.fake_model_swd),map_location='cpu'))
net_fake_swd = Net_softmax(net_fake_swd)


comb_net = CombNet(net, net_fake, args.tau)
comb_net_swd = CombNet(net, net_fake_swd, args.tau)

In [31]:
print(os.path.join(args.save_dir, args.att_model))
print(os.path.join(args.save_dir, args.att_model_swd))

../results/mnist_attacknet_mpc_ml_tau0.9.pth
../results/mnist_attacknet_mpc_ml_swd_tau0.9.pth


In [32]:
criterion = nn.CrossEntropyLoss()
criterion_NLL = nn.NLLLoss()

# Adversarial PGD attacks

In [39]:
att_net = AttackNetMNIST()
att_net.load_state_dict(torch.load(os.path.join(args.save_dir, args.att_model),map_location='cpu'))

att_net_swd = AttackNetMNIST()
att_net_swd.load_state_dict(torch.load(os.path.join(args.save_dir, args.att_model_swd),map_location='cpu'))

<All keys matched successfully>

In [40]:
from torchattacks import PGD

In [52]:
atk_logit_ml = PGD(att_net, eps=8/255, alpha=2/255, steps=7)
atk_logit_ml.save(data_loader=testloader,save_path = os.path.join(args.save_dir, 'mnist_mpc.adv'))

- Save progress: 100.00 % / Robust accuracy: 45.01 % / L2: 18.76702 (0.073 it/s) 	


In [53]:
atk_logit_ml_swd = PGD(att_net_swd, eps=8/255, alpha=2/255, steps=7)
atk_logit_ml_swd.save(data_loader=testloader,save_path = os.path.join(args.save_dir, 'mnist_mpc_swd.adv'))

- Save progress: 100.00 % / Robust accuracy: 69.37 % / L2: 18.96970 (0.071 it/s) 	


In [54]:
adv_images, adv_labels = torch.load(os.path.join(args.save_dir,  'mnist_mpc.adv'))
adv_images_swd, adv_labels_swd = torch.load(os.path.join(args.save_dir, 'mnist_mpc_swd.adv'))

In [55]:
from torch.utils.data import TensorDataset, DataLoader
adv_loader = DataLoader(TensorDataset(adv_images.float(),adv_labels),batch_size=128, shuffle=False)
adv_loader_swd = DataLoader(TensorDataset(adv_images_swd.float(),adv_labels_swd),batch_size=128, shuffle=False)

In [56]:
from utils import test_model

In [46]:
att_net.to(device)
print(test_model(att_net,adv_loader,criterion,device,100.0))
att_net.cpu()
comb_net.to(device)
print(test_model(comb_net,adv_loader,criterion_NLL,device,100.0,pred_prob=True))
comb_net.cpu()

(tensor(1.7176), tensor(44.9800))
(tensor(6.8202), tensor(52.7400))


In [57]:
# comb_net.to(device)
print(test_model(comb_net.to(device),adv_loader,criterion_NLL,device,100.0,pred_prob=True))
# comb_net.cpu()
# comb_net_swd.to(device)
print(test_model(comb_net_swd.to(device),adv_loader_swd,criterion_NLL,device,100.0,pred_prob=True))
# comb_net_swd.cpu()

(tensor(6.8221), tensor(52.7300))
(tensor(5.5966), tensor(53.5400))


In [48]:
att_net_swd.to(device)
print(test_model(att_net_swd,adv_loader_swd,criterion,device,100.0))
att_net_swd.cpu()
comb_net_swd.to(device)
print(test_model(comb_net_swd,adv_loader_swd,criterion_NLL,device,100.0,pred_prob=True))
comb_net_swd.cpu()

(tensor(1.2514), tensor(69.3500))
(tensor(5.5996), tensor(53.5200))


CombNet(
  (net_orig): Net_softmax(
    (model): AttackNet(
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (classifier): Sequential(
        (0): Linear(in_features=2048, out_features=256, bias=True)
        (1): Linear(in_features=256, out_features=10, bias=True)
      )
    )
  )
  (net_fake): Net_softmax(
    (model): AttackNet(
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (classifier): Sequential(
        (0)