In [1]:
%%capture
# for plotting (not really necessary)
!pip install --upgrade plotly
!pip install -U kaleido
!pip install asciichartpy

# main adversarial attack toolbox 
!pip install foolbox


In [2]:
# Colab-only: mount Google Drive so files stored under MyDrive become reachable.
from google.colab import drive
drive.mount("/content/drive")
# Change the working directory to the project folder; the relative paths
# ('data/...', 'figures/...') and the `src.*` imports used by later cells are
# presumably resolved from here.
# NOTE(review): hard-coded, user-specific Drive path -- adjust when running elsewhere.
%cd "/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust"

Mounted at /content/drive
/content/drive/MyDrive/Courses/Fall 2021/dlsys/bnn-cf-vs-robust


In [155]:
import os, glob 
import pandas as pd 
import numpy as np 
import yaml 
from pathlib import Path
from tqdm.notebook import tqdm 
import time

import asciichartpy
from asciichartpy import plot as ascii_plt


In [4]:
# Root folders, all relative to the current working directory:
# attack inputs, experiment outputs, and figures.
input_root = Path('data', 'input', 'pmnist_robustness')
data_root = Path('data', 'output', 'exp1-pmnist-robustness')
fig_root = Path('figures', 'exp1-pmnist_robustness')

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import numpy as np

import foolbox as fb 
import foolbox.attacks as fa
import eagerpy as ep

In [7]:
from src.models_utils import BNN
from src.pmnist_robustness_data_utils import TaskDataSet

# Preprocessing pipeline: convert to a tensor, then Normalize with mean 0 / std 1.
preprocessing_steps = [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,)),
]
transform = torchvision.transforms.Compose(preprocessing_steps)

# Keyword arguments shared by the data loaders built from this configuration.
common_dload_args = {'batch_size': 1}

# Test split of task 01 (original variant), read from below `input_root`.
test_dataset = TaskDataSet(input_root / 'task-01/original/test', transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=True,
    **common_dload_args,
)

In [10]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [237]:
# Locate the checkpoint written by the experiment run named in `data_dir`.
data_dir = 'pmnist_robustness_[2048x2048]_[meta=1.35]'
model_dir = data_root / data_dir / 'models'

# map_location=device remaps the stored tensors onto the device selected for this
# session, so a checkpoint saved on a GPU also loads on a CPU-only runtime.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
ckpt = torch.load(model_dir / 'task-03.pt', map_location=device)

# Rebuild the network from the constructor arguments stored in the checkpoint,
# restore its weights, and switch to inference mode. `model.eval()` is kept as the
# last expression so the cell displays the module summary.
model = BNN(**ckpt['model_args']).to(device)
model.load_state_dict(ckpt['model_states'])
model.eval()

BNN(
  (layers): ModuleDict(
    (fc1): BinarizeLinear(in_features=784, out_features=2048, bias=False)
    (bn1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc2): BinarizeLinear(in_features=2048, out_features=2048, bias=False)
    (bn2): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc3): BinarizeLinear(in_features=2048, out_features=10, bias=False)
    (bn3): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [239]:
# Wrap the network for foolbox, declaring (0, 1) as the input bounds.
fmodel = fb.PyTorchModel(
    model,
    bounds=(0, 1),
    device=device,
)

In [238]:
# Rebuild the loader without shuffling and with a batch size of 10000, then keep
# only the first batch it yields.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=10000,
    shuffle=False,
)
raw_images, raw_labels = next(iter(test_loader))

# Move the batch to the target device and wrap both tensors as eagerpy tensors.
images, labels = ep.astensors(raw_images.to(device), raw_labels.to(device))

In [190]:
# Accuracy of the wrapped model on the loaded batch, computed with foolbox's
# accuracy helper (no attack applied).
# NOTE(review): this cell's execution count (190) is lower than those of the cells
# that build `fmodel` (239) and `images`/`labels` (238); re-run it after them so the
# value reflects the current model and batch.
clean_acc = fb.accuracy(fmodel, images, labels)

In [74]:
%%time
_, _, attack_success = fa.FGSM()(fmodel, images, labels, epsilons=epsilons)
robust_accuracy = 1 - attack_success.float32().mean(axis=-1)
print(robust_accuracy)

PyTorchTensor(tensor([9.8030e-01, 9.8000e-01, 9.7870e-01, 9.7880e-01, 9.7640e-01, 9.7560e-01,
                      9.6380e-01, 2.0220e-01, 4.3100e-02, 7.8000e-03, 1.5000e-03, 3.0005e-04],
                     device='cuda:0'))
CPU times: user 2.71 s, sys: 157 ms, total: 2.87 s
Wall time: 2.86 s


``` python
# use these first
attacks = [
    fa.FGSM(),
    fa.LinfPGD(),
    fa.LinfBasicIterativeAttack(),
    fa.LinfAdditiveUniformNoiseAttack(),
    fa.LinfDeepFoolAttack(),
    fa.DDNAttack(),
]
# hold off on these
unused_attacks = [
    fa.VirtualAdversarialAttack(), # cannot use, need to revisit
    fa.InversionAttack(), # cannot use, need to revisit
    fa.L2CarliniWagnerAttack(), # takes too long, need to revisit
    fa.NewtonFoolAttack(), # eps=0 is already corrupted, need to revisit
    fa.EADAttack(), # takes too long and doesn't do much, need to revisit params
    fa.SaltAndPepperNoiseAttack(), # doesn't do much, need to revisit
    fa.BinarizationRefinementAttack(), # cannot use
    fa.BoundaryAttack(), # black box, need to revisit
    fa.LinfinityBrendelBethgeAttack(), # takes too long
]
```


In [240]:
# Attacks run by the evaluation loop in this cell, in list order.
# DDN and PGD are given steps=20 explicitly; the others use their constructor defaults.
# Results are keyed by the attack's class name, so fa.FGSM() is reported as
# "LinfFastGradientAttack" in the recorded output.
attacks = [
    fa.FGSM(),
    fa.LinfBasicIterativeAttack(),
    fa.LinfAdditiveUniformNoiseAttack(),
    fa.LinfDeepFoolAttack(),
    fa.DDNAttack(steps=20),
    fa.LinfPGD(steps=20),
]

# Perturbation budgets swept by every attack, in ascending order:
#   * 0.0 (no perturbation)
#   * {1, 2, 4, 8} x 1e-4 followed by {1, 2, 4, 8} x 1e-3
#   * 15 log-spaced values 2**0 .. 2**3.5, rounded to 3 decimals, scaled by 1e-2
#   * a coarse tail from 0.15 up to 1.0
fine_multipliers = np.array([[1, 2, 4, 8]])
fine_scales = np.array([1e-4, 1e-3])[:, None]
fine_eps = (fine_multipliers * fine_scales).ravel().tolist()

log_eps = (np.logspace(0, 3.5, 15, base=2.0).round(3) * 1e-2).tolist()

coarse_eps = [0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

epsilons = [0.0, *fine_eps, *log_eps, *coarse_eps]

# Per-attack results, keyed by the attack's class name. Both dicts are re-created each
# time this cell runs and are filled by the loop below.
robust_accuracy = dict()
attack_time = dict()

# asciichartpy settings for the per-attack chart printed inside the loop.
ascii_conf = dict(
    min = 0,
    max = 1.0,
    height = 10,
    colors = [
        asciichartpy.black,
        asciichartpy.blue
    ]
)

for attack in tqdm(attacks):
    print('-'*200)
    print(attack)
    attack_name = type(attack).__name__

    # Time only the attack call itself; perf_counter is monotonic, unlike time.time.
    t0 = time.perf_counter()
    _, _, success = attack(fmodel, images, labels, epsilons=epsilons)
    elapsed = time.perf_counter() - t0

    # Convert the success flags to NumPy and check them before aggregating:
    # one row per epsilon, one column per image, boolean entries.
    # (The original assert read the undefined name `success_` and the `np.bool` alias,
    # which NumPy 1.24 removed; `np.bool_` is the supported spelling.)
    success = success.numpy()
    expected_shape = (len(epsilons), len(images))
    assert success.shape == expected_shape, \
        'unexpected success shape %s, expected %s' % (success.shape, expected_shape)
    assert success.dtype == np.bool_, \
        'unexpected success dtype %s, expected bool' % success.dtype

    # Robust accuracy per epsilon = fraction of images on which the attack failed.
    rob_acc = 1.0 - success.mean(axis=-1)
    robust_accuracy[attack_name] = rob_acc
    attack_time[attack_name] = elapsed

    print(ascii_plt(rob_acc, ascii_conf))

    print('%s - ELAPSED: took %.2f sec' %(attack_name, elapsed))


  0%|          | 0/6 [00:00<?, ?it/s]

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
LinfFastGradientAttack(rel_stepsize=1.0, abs_stepsize=None, steps=1, random_start=False)
    1.00  ┼
    0.90  ┤
    0.80  ┤
    0.70  ┤
    0.60  ┤
    0.50  ┼[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m╮[0m
    0.40  ┤       [30m╰[0m[30m─[0m[30m─[0m[30m─[0m[30m╮[0m
    0.30  ┤           [30m╰[0m[30m─[0m[30m─[0m[30m╮[0m
    0.20  ┤              [30m╰[0m[30m─[0m[30m╮[0m
    0.10  ┤                [30m╰[0m[30m─[0m[30m─[0m[30m╮[0m
    0.00  ┤                   [30m╰[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m
LinfFastGradientAttack - ELAPSED: took 5.64 sec
-------------------------------------------------------------------------

In [248]:
# Colors for the overlay chart, assigned to the series in order.
summary_palette = [
    asciichartpy.black,
    asciichartpy.lightblue,
    asciichartpy.blue,
    asciichartpy.lightcyan,
    asciichartpy.lightmagenta,
    asciichartpy.magenta,
    asciichartpy.lightgreen,
    asciichartpy.green,
]

# Chart settings for the overlay: y-range 0 to 0.6 and a height of 20.
ascii_conf = {
    'min': 0,
    'max': 0.6,
    'height': 20,
    'colors': summary_palette,
}

# First series is the epsilon sweep itself, followed by one robust-accuracy curve
# per attack in the order the results were stored.
summary_series = [epsilons]
summary_series.extend(robust_accuracy.values())
print(ascii_plt(summary_series, ascii_conf))


    0.60  ┼                             [30m╭[0m[30m─[0m[30m─[0m[30m─[0m[30m─[0m
    0.57  ┤                     [96m╭[0m[96m─[0m[96m─[0m[96m─[0m[96m╮[0m   [30m│[0m
    0.54  ┤  [96m╭[0m[96m╮[0m     [96m╭[0m[96m─[0m[96m─[0m[96m─[0m[96m─[0m[96m─[0m[96m─[0m[96m─[0m[96m─[0m[96m─[0m[96m─[0m[96m─[0m[96m╯[0m   [96m╰[0m[96m─[0m[96m╮[0m [30m│[0m
    0.51  ┤[92m─[0m[92m─[0m[92m─[0m[92m─[0m[92m─[0m[92m─[0m[92m╮[0m[35m─[0m[35m─[0m[35m─[0m[35m─[0m[35m─[0m[35m─[0m[35m─[0m[35m─[0m[35m─[0m[35m╮[0m          [96m│[0m[30m╭[0m[30m╯[0m
    0.48  ┤     [95m╰[0m[92m╰[0m[92m╮[0m        [35m╰[0m[35m─[0m[35m─[0m[35m─[0m[35m╮[0m      [96m│[0m[30m│[0m
    0.45  ┤       [92m╰[0m[92m─[0m[92m╮[0m          [35m╰[0m[35m─[0m[35m─[0m[35m╮[0m   [96m╰[0m[96m╮[0m
    0.42  ┤       [95m╰[0m[95m╮[0m[92m╰[0m[92m╮[0m            [35m╰[0m[35m╮[0m   [96m│[0m
    0.39  ┤        [