In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
# Pick the compute device. GPU index 3 is preferred (unchanged behaviour on
# machines that have it); hosts with fewer than four GPUs fall back to GPU 0
# instead of selecting a non-existent device, and CUDA-less hosts use the CPU.
if not torch.cuda.is_available():
    device = 'cpu'
elif torch.cuda.device_count() > 3:
    device = 'cuda:3'
else:
    device = 'cuda:0'

In [4]:
# torch.cuda.set_device only accepts CUDA devices; when `device` is the CPU
# fallback the call would raise, so it is skipped in that case.
if torch.device(device).type == 'cuda':
    torch.cuda.set_device(device)

In [5]:
from utils import load_mnist, load_emnist_letters

In [6]:
# Build the MNIST train/test loaders from ../data/mnist; both use batches of
# 128 and train_shuffle is enabled for the training loader.
trainloader, testloader = load_mnist(
    data_dir="../data/mnist",
    batch_size=128,
    test_batch=128,
    train_shuffle=True,
)

In [7]:
from utils import train_valid_split

In [8]:
# Build the EMNIST-letters loaders that serve as the "steal" data source.
# NOTE(review): data_dir points at the MNIST folder — confirm this is where
# load_emnist_letters expects the EMNIST files.
stealloader, st_testloader = load_emnist_letters(
    data_dir="../data/mnist",
    batch_size=128,
    train_shuffle=False,
)

In [9]:
# Derive train/validation loaders from the steal loader with a fixed seed.
# 10000 is presumably the number of samples kept (the call reports
# "total data: 10000") — confirm against utils.train_valid_split.
st_trloader, st_valloader = train_valid_split(
    stealloader,
    10000,
    datatype="mnist",
    seed=1228,
)

total data: 10000


In [50]:
class args:
    """Static experiment configuration: checkpoint directory, file names and tau."""

    # Value passed to CombNet; it is also embedded in the attack-net file names.
    tau = 0.4

    # Directory that holds every checkpoint referenced below.
    save_dir = "../results"

    # Original network plus the two fake networks (plain and SWD variants).
    orig_model = "mnist_orig_net.pth"
    fake_model = "mnist_fake_taylor.pth"
    fake_model_swd = "mnist_fake_taylor_swd.pth"
    # Alternative fake-model checkpoints (disabled):
    #   "mnist_fake_mpc_minloss_swd.pth"
    #   "mnist_fake_mpc_minloss.pth"

    # Attack-network checkpoints; each file name carries the tau value.
    att_model = "mnist_attacknet_hard_taylor_ml_tau%s.pth" % tau
    att_model_swd = "mnist_attacknet_hard_taylor_swd_tau%s.pth" % tau
    # Alternative attack-model checkpoint (disabled):
    #   "mnist_attacknet_mpc_ml_swd_tau%s.pth" % tau

    # Disabled attack-training settings, kept for reference:
    #   nb_stolen = 10000  # number of samples used when training the attack model
    #   st_trloader = "attnet_trloader.dl"
    #   st_valloader = "attnet_valloader.dl"
    #   att_epochs = 300

In [51]:
# Display the attack-model checkpoint name to confirm the tau suffix.
args.att_model

'mnist_attacknet_hard_taylor_ml_tau0.4.pth'

In [52]:
# from cifar_models import Net, Net_logit, AttackNet
from mnist_models import small_NetMNIST, NetMNIST
from utils import CombNet, CombNet_logit, Net_softmax

In [55]:
import os


def _restore_mnist_net(checkpoint_name):
    """Create a NetMNIST and load the CPU-mapped weights found under args.save_dir."""
    model = NetMNIST()
    weights = torch.load(os.path.join(args.save_dir, checkpoint_name), map_location='cpu')
    model.load_state_dict(weights)
    return model


# Original network and the two fake networks (plain and SWD variants).
net = _restore_mnist_net(args.orig_model)
net_fake = _restore_mnist_net(args.fake_model)
net_fake_swd = _restore_mnist_net(args.fake_model_swd)

# Pair the original network with each fake network using the configured tau.
comb_net = CombNet(net, net_fake, args.tau)
comb_net_swd = CombNet(net, net_fake_swd, args.tau)

In [56]:
# Print the full path of both attack-network checkpoints.
for checkpoint_name in (args.att_model, args.att_model_swd):
    print(os.path.join(args.save_dir, checkpoint_name))

../results/mnist_attacknet_hard_taylor_ml_tau0.4.pth
../results/mnist_attacknet_hard_taylor_swd_tau0.4.pth


In [57]:
# Loss objects: cross-entropy and negative log-likelihood, both with default settings.
criterion, criterion_NLL = nn.CrossEntropyLoss(), nn.NLLLoss()

# Adversarial PGD attacks

In [58]:
# AttackNetMNIST is not imported anywhere in the visible notebook, so this cell
# raises NameError after "Restart Kernel -> Run All"; import it explicitly.
# NOTE(review): assumed to be defined in mnist_models alongside NetMNIST — confirm.
from mnist_models import AttackNetMNIST

# Restore both attack networks (plain and SWD variants) from their checkpoints,
# mapping the stored tensors onto the CPU.
att_net = AttackNetMNIST()
att_net.load_state_dict(torch.load(os.path.join(args.save_dir, args.att_model),map_location='cpu'))

att_net_swd = AttackNetMNIST()
att_net_swd.load_state_dict(torch.load(os.path.join(args.save_dir, args.att_model_swd),map_location='cpu'))

<All keys matched successfully>

In [59]:
from torchattacks import PGD

In [60]:
# Run PGD (eps 8/255, step size 2/255, 7 steps) against att_net over the test
# loader and store the result under args.save_dir.
adv_save_path = os.path.join(args.save_dir, 'mnist_HE_adv.pth')
atk_logit_ml = PGD(att_net, eps=8/255, alpha=2/255, steps=7)
atk_logit_ml.save(data_loader=testloader, save_path=adv_save_path)

- Save progress: 100.00 % / Robust accuracy: 77.69 % / L2: 17.57829 (0.072 it/s) 	


In [61]:
# Run PGD (eps 8/255, step size 2/255, 7 steps) against att_net_swd over the
# test loader and store the result under args.save_dir.
adv_swd_save_path = os.path.join(args.save_dir, 'mnist_HE_swd_adv.pth')
atk_logit_ml_swd = PGD(att_net_swd, eps=8/255, alpha=2/255, steps=7)
atk_logit_ml_swd.save(data_loader=testloader, save_path=adv_swd_save_path)

- Save progress: 100.00 % / Robust accuracy: 77.37 % / L2: 17.69803 (0.064 it/s) 	


In [62]:
# Reload the two saved adversarial sets; each file unpacks into (images, labels).
adv_path = os.path.join(args.save_dir, 'mnist_HE_adv.pth')
adv_path_swd = os.path.join(args.save_dir, 'mnist_HE_swd_adv.pth')
adv_images, adv_labels = torch.load(adv_path)
adv_images_swd, adv_labels_swd = torch.load(adv_path_swd)

In [63]:
from torch.utils.data import TensorDataset, DataLoader

# Wrap each adversarial set in an unshuffled DataLoader with batches of 128;
# the images are cast to float before being paired with their labels.
adv_dataset = TensorDataset(adv_images.float(), adv_labels)
adv_dataset_swd = TensorDataset(adv_images_swd.float(), adv_labels_swd)
adv_loader = DataLoader(adv_dataset, batch_size=128, shuffle=False)
adv_loader_swd = DataLoader(adv_dataset_swd, batch_size=128, shuffle=False)

In [64]:
from utils import test_model

In [65]:
# Evaluate each combined network on its own adversarial loader and print what
# test_model returns. Note that .to(device) moves the network as a side effect.
for combined_net, adv_eval_loader in ((comb_net, adv_loader), (comb_net_swd, adv_loader_swd)):
    print(test_model(combined_net.to(device), adv_eval_loader, criterion_NLL, device, 100.0, pred_prob=True))

(tensor(1.2880), tensor(80.7600))
(tensor(1.1626), tensor(80.8100))
