In [1]:
# !apt update && apt install -y git
# !pip install foolbox

In [2]:
import torch
# import torchvision
import foolbox
import numpy as np
from onnx2torch import convert
import os

In [3]:
# Experiment selection: `task_name` picks the ../<task_name>/ data directory and
# checkpoint sub-folder; `model_type` is the student size passed to get_model().
task_name, model_type = 'fashion_mnist', 'large'

In [4]:
# Load the test split as tensors. The `with` block closes the NpzFile handle
# that np.load opens for .npz archives.
with np.load(f'../{task_name}/ds_test.npz') as d:
    # transpose(0, 3, 1, 2) turns the stored channels-last (N, H, W, C) layout
    # into the NCHW layout PyTorch convolutions expect; make it contiguous.
    test_x = torch.from_numpy(
        np.ascontiguousarray(d['x'].astype(np.float32).transpose(0, 3, 1, 2)))
    # Integer class labels as int64 (what DataLoader/argmax comparisons use),
    # without the float32 round-trip of torch.Tensor(...).
    test_t = torch.from_numpy(d['t'].astype(np.int64))

In [5]:
def _conv3x3(in_channels, out_channels=16):
    """Return a 3x3, stride-1, padding-1 convolution (keeps H and W unchanged)."""
    return torch.nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        )


def _conv_relu_pool(in_channels, out_channels=16):
    """Return [Conv3x3, ReLU, MaxPool2d(2)] as a list; the pool halves H and W."""
    return [
        _conv3x3(in_channels, out_channels),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(kernel_size=2),
        ]


def get_model(model_type):
    """Build a freshly initialised student CNN of the requested size.

    The layer order inside each ``torch.nn.Sequential`` must stay exactly as
    is: checkpoints are restored with ``load_state_dict``, whose keys are the
    positional indices of the parameterised layers (``'0.weight'``, ...).

    The ``Linear`` input sizes assume single-channel 28x28 inputs (two 2x2
    pools give 7x7, one gives 14x14). Every variant outputs 10 logits.

    Args:
        model_type: ``'large'``, ``'medium'`` or ``'small'``.

    Returns:
        An untrained ``torch.nn.Sequential``.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    if model_type == 'large':
        layers = [
            *_conv_relu_pool(1),      # 28x28 -> 14x14
            *_conv_relu_pool(16),     # 14x14 -> 7x7
            # NOTE(review): no ReLU follows this conv. Inserting one would shift
            # the Sequential indices and break loading of existing checkpoints.
            _conv3x3(16),
            torch.nn.Flatten(),       # 7*7*16 = 784
            torch.nn.Linear(7*7*16, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10),
            ]
    elif model_type == 'medium':
        layers = [
            *_conv_relu_pool(1),      # 28x28 -> 14x14
            *_conv_relu_pool(16),     # 14x14 -> 7x7
            torch.nn.Flatten(),       # 7*7*16 = 784
            torch.nn.Linear(7*7*16, 10),
            ]
    elif model_type == 'small':
        layers = [
            *_conv_relu_pool(1),      # 28x28 -> 14x14
            torch.nn.Flatten(),       # 14*14*16 = 3136
            torch.nn.Linear(14*14*16, 10),
            ]
    else:
        raise ValueError(model_type)
    return torch.nn.Sequential(*layers)

In [6]:
def create_AEs(model, ds_test, attack_method, epsilons, device=None):
    """Run a foolbox attack over a dataset and collect the adversarial examples.

    Args:
        model: PyTorch classifier to attack. It is wrapped with input bounds
            (0, 1) and switched to eval mode.
        ds_test: iterable of ``(x, t)`` batches, e.g. a ``DataLoader``. Must
            yield at least one batch.
        attack_method: foolbox attack, called as
            ``attack_method(fmodel, x, t, epsilons=epsilons)``. It must return
            ``(raw, clipped, is_adv)``, each a list with one entry per epsilon.
        epsilons: sequence of perturbation budgets.
        device: device for the model and the batches. Defaults to ``'cuda'``
            when available, otherwise ``'cpu'``.

    Returns:
        A list with one dict per epsilon, in the order of ``epsilons``. Each
        dict has the keys ``'raw'`` (clean inputs), ``'clipped'`` (adversarial
        inputs), ``'is_adv'`` (attack success on ``model``) and ``'label'``.
        Every value is a NumPy array concatenated over batches along axis 0.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.eval()
    # Give foolbox the same device the batches are sent to below, so the model
    # and the inputs always agree (also works on CPU-only hosts).
    fmodel = foolbox.PyTorchModel(model, bounds=(0, 1), device=device)

    res = {'raw': [], 'clipped': [], 'is_adv': [], 'label': []}
    for x, t in ds_test:
        _, clipped, is_adv = attack_method(
            fmodel, x.to(device), t.to(device), epsilons=epsilons)
        # Clean inputs and labels are the same for every epsilon.
        res['raw'].append([x] * len(epsilons))
        res['clipped'].append([v.to('cpu') for v in clipped])
        res['is_adv'].append([v.to('cpu') for v in is_adv])
        res['label'].append([t] * len(epsilons))

    # Regroup the per-batch lists of per-epsilon tensors into one dict per
    # epsilon, concatenating the batches.
    return [
        {key: np.concatenate([per_batch[i].numpy() for per_batch in batches], axis=0)
         for key, batches in res.items()}
        for i in range(len(epsilons))
    ]

In [7]:
import pickle

# Victim: the ONNX network that the student-crafted AEs are transferred to.
victim_type = 'medium'
onnx_model_path = f'../{task_name}/{victim_type}.onnx'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The strings double as keys of the saved result dicts; their numeric values
# are the perturbation budgets passed to the attack.
eps = ['4/255.0', '8/255.0', '16/255.0', '24/255.0']
num_models = 111
batch_size = 512
# NOTE(review): assumed victim output width (the students emit 10 logits);
# only used to shape the prediction array when no AE fooled the student.
num_classes = 10
res_dir = os.path.join('res', f'{task_name}', f'{model_type}')
os.makedirs(res_dir, exist_ok=True)


def _parse_eps(expr):
    """Evaluate a 'numerator/denominator' string such as '8/255.0' without eval()."""
    num, den = expr.split('/')
    return float(num) / float(den)


def _victim_predict(victim, x, device, batch_size=512):
    """Return the victim's outputs for `x` as one NumPy array, computed batch by batch.

    Runs under torch.no_grad() so no autograd graph is built. An empty `x`
    yields an empty (0, num_classes) array instead of failing in np.concatenate.
    """
    if len(x) == 0:
        return np.empty((0, num_classes), dtype=np.float32)
    with torch.no_grad():
        outputs = [victim(xb.to(device)).cpu().numpy()
                   for xb in torch.split(x, batch_size)]
    return np.concatenate(outputs, axis=0)


# Loop-invariant setup, done once instead of once per student.
torch_model = convert(onnx_model_path)
# eval(): inference mode for any dropout / batch-norm in the converted graph.
victim = torch_model.to(device).eval()
attack_method = foolbox.attacks.PGD(abs_stepsize=1/255.0, random_start=True)
eps_values = [_parse_eps(e) for e in eps]
test_ds = torch.utils.data.TensorDataset(test_x, test_t)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False)

success_rate = {}     # eps -> per-student fraction of test samples whose AE fools the student
transferability = {}  # eps -> per-student fraction of those AEs that also fool the victim
for i in range(num_models):
    print("##### {}/{} #####".format(i, num_models))
    ckpt_path = os.path.join('.', 'checkpoint', f'{task_name}', f'student_{model_type}', f'{i}.pt')
    student = get_model(model_type)
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts.
    student.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    aes = create_AEs(student, test_dl, attack_method, eps_values)

    for e, ae in zip(eps, aes):
        # Only AEs that fooled the student are transferred to the victim.
        mask = ae['is_adv']
        adv_x = torch.as_tensor(ae['clipped'][mask], dtype=torch.float32)
        predicts = _victim_predict(victim, adv_x, device, batch_size)
        ae['victim_prediction'] = predicts
        ae['is_attack_success'] = np.argmax(predicts, axis=1) != ae['label'][mask]

        success = ae['is_attack_success']
        # NaN (without a RuntimeWarning) when no AE fooled the student.
        transferability.setdefault(e, []).append(np.mean(success) if success.size else np.nan)
        success_rate.setdefault(e, []).append(np.mean(mask))

    with open(os.path.join(res_dir, f'{i}.pkl'), 'wb') as f:
        pickle.dump(dict(zip(eps, aes)), f)

# np.savez stores each dict as a 0-d object array; read it back with
# np.load(path, allow_pickle=True)[key].item().
np.savez(os.path.join(res_dir, 'result.npz'), transferability=transferability, success_rate=success_rate)

##### 0/111 #####




##### 1/111 #####




##### 2/111 #####




##### 3/111 #####




##### 4/111 #####




##### 5/111 #####




##### 6/111 #####




##### 7/111 #####




##### 8/111 #####




##### 9/111 #####




##### 10/111 #####




##### 11/111 #####




##### 12/111 #####




##### 13/111 #####




##### 14/111 #####




##### 15/111 #####




##### 16/111 #####




##### 17/111 #####




##### 18/111 #####




##### 19/111 #####




##### 20/111 #####




##### 21/111 #####




##### 22/111 #####




##### 23/111 #####




##### 24/111 #####




##### 25/111 #####




##### 26/111 #####




##### 27/111 #####




##### 28/111 #####




##### 29/111 #####




##### 30/111 #####




##### 31/111 #####




##### 32/111 #####




##### 33/111 #####




##### 34/111 #####




##### 35/111 #####




##### 36/111 #####




##### 37/111 #####




##### 38/111 #####




##### 39/111 #####




##### 40/111 #####




##### 41/111 #####




##### 42/111 #####




##### 43/111 #####




##### 44/111 #####




##### 45/111 #####




##### 46/111 #####




##### 47/111 #####




##### 48/111 #####




##### 49/111 #####




##### 50/111 #####




##### 51/111 #####




##### 52/111 #####




##### 53/111 #####




##### 54/111 #####




##### 55/111 #####




##### 56/111 #####




##### 57/111 #####




##### 58/111 #####




##### 59/111 #####




##### 60/111 #####




##### 61/111 #####




##### 62/111 #####




##### 63/111 #####




##### 64/111 #####




##### 65/111 #####




##### 66/111 #####




##### 67/111 #####




##### 68/111 #####




##### 69/111 #####




##### 70/111 #####




##### 71/111 #####




##### 72/111 #####




##### 73/111 #####




##### 74/111 #####




##### 75/111 #####




##### 76/111 #####




##### 77/111 #####




##### 78/111 #####




##### 79/111 #####




##### 80/111 #####




##### 81/111 #####




##### 82/111 #####




##### 83/111 #####




##### 84/111 #####




##### 85/111 #####




##### 86/111 #####




##### 87/111 #####




##### 88/111 #####




##### 89/111 #####




##### 90/111 #####




##### 91/111 #####




##### 92/111 #####




##### 93/111 #####




##### 94/111 #####




##### 95/111 #####




##### 96/111 #####




##### 97/111 #####




##### 98/111 #####




##### 99/111 #####




##### 100/111 #####




##### 101/111 #####




##### 102/111 #####




##### 103/111 #####




##### 104/111 #####




##### 105/111 #####




##### 106/111 #####




##### 107/111 #####




##### 108/111 #####




##### 109/111 #####




##### 110/111 #####


