In [None]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms as T
from torch.utils import data

In [None]:
from tqdm import tqdm
import os, time, sys
import json

In [None]:
import foolbox as fb
import foolbox.attacks as fa

import pickle

In [None]:
# Put the local "Input-Invex-Neural-Network" checkout on the import path.
# The `dtnnlib` / `nflib` imports in the next cells are presumably resolved from it — TODO confirm.
sys.path.append("./Input-Invex-Neural-Network/")

In [None]:
import dtnnlib as dtnn

In [None]:
import nflib
from nflib.flows import SequentialFlow, ActNorm, ActNorm2D, BatchNorm1DFlow, BatchNorm2DFlow
import nflib.res_flow as irf

In [None]:
# Images are converted to tensors and then normalised with mean=0.5 / std=0.5.
mnist_transform = T.Compose([T.ToTensor(), T.Normalize(mean=0.5, std=0.5)])

# Fashion-MNIST train / test splits, stored under data/ (downloaded when missing).
train_dataset, test_dataset = [
    datasets.FashionMNIST(root="data/", train=is_train, download=True, transform=mnist_transform)
    for is_train in (True, False)
]
# To run on plain MNIST instead, replace `datasets.FashionMNIST` with `datasets.MNIST` above.

In [None]:
# Display the shape of the raw training image tensor as the cell output.
train_dataset.data.shape

In [None]:
batch_size = 100
# Both loaders shuffle. The test loader is shuffled as well, so "the first test batch"
# taken by later cells is a different random batch on every run.
train_loader, test_loader = [
    data.DataLoader(dataset=split, num_workers=4, batch_size=batch_size, shuffle=True)
    for split in (train_dataset, test_dataset)
]

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Sanity check: take one training batch, move it to `device`, and print the shapes
# before and after the move. The loop exits after the first batch, and `xx` / `yy`
# stay defined at notebook level afterwards.
for xx, yy in train_loader:
    print(xx.shape)
#     xx, yy = xx.view(-1,28*28).to(device), yy.to(device)
    xx, yy = xx.to(device), yy.to(device)
    print(xx.shape, yy.shape)
    break

## Train Test method

In [None]:
# Folder holding the trained checkpoints; read as f"{model_dir}/{model_name}.pth" by the loaders below.
model_dir = "outputs/15.0_models"

In [None]:
# NOTE(review): `criterion` is not referenced by any other cell visible in this notebook —
# possibly a leftover from the training notebook; confirm before removing.
criterion = nn.CrossEntropyLoss()

## Adversarial Test

In [None]:
# Foolbox attacks used in this notebook, keyed by the attack name that is also the first
# component of the stored experiment config strings (f"{attack_type}_e..._b..._ts...").
# Insertion order matters: a later cell picks an attack by its position in this dict.
attack_dict = {
    "FGSM": fa.FGSM(), ## LinfFastGradientAttack
    "FGM": fa.FGM(), ## L2FastGradientAttack
    "L2PGD": fa.L2PGD(steps=20), ## author note: more steps (>10) seemed to work better — unverified
    "LinfPGD": fa.LinfPGD(steps=20), ## PGD
    "L1AdamPGD": fa.L1AdamPGD(steps=20, adam_beta1=0.8, adam_beta2=0.95),
    "L2AdamPGD": fa.L2AdamPGD(steps=20, adam_beta1=0.8, adam_beta2=0.95),
    "LinfAdamPGD": fa.LinfAdamPGD(steps=20, adam_beta1=0.8, adam_beta2=0.95),
    "L2AdamBasic": fa.L2AdamBasicIterativeAttack(steps=10), ## default steps
}

In [None]:
# Display the available attack names.
attack_dict.keys()

In [None]:
def plot_experiment_instance(ax, hidden_units, data_init, center_lr, 
                             bound=1, temp_scale=1.0, attack_type="FGSM", adv_alpha=0.5):
    """Plot the stored rejection curves of one model / attack configuration onto ``ax``.

    The model is identified by ``hidden_units``, ``data_init`` ("cent" when truthy, "rand"
    otherwise) and ``center_lr``, which are formatted verbatim into the checkpoint name.
    The attack configuration (``attack_type``, ``adv_alpha``, ``bound``, ``temp_scale``)
    selects one entry of that model's ``experiments_data.pkl`` dictionary.

    Args:
        ax: matplotlib Axes to draw on.
        hidden_units: value used for the ``h{hidden_units}`` part of the model name.
        data_init: truthy selects the "cent" model, falsy the "rand" model.
        center_lr: value used for the ``clrs{center_lr}`` part of the model name.
        bound: clipping bound; cast to int for the config key and shown as [-bound, bound].
        temp_scale: cast to float; only used to build the config key.
        attack_type: attack name, first component of the config key.
        adv_alpha: attack step size; cast to float for the config key.

    Raises:
        ValueError: if the model has no ``experiments_data.pkl``.
        KeyError: if the requested configuration is not stored in the pickle.
    """
    init = "cent" if data_init else "rand"
    model_name = f"dtesm_identity_I{init}_clrs{center_lr}_h{hidden_units}_mean"

    # Validate the experiment file first so a missing experiment fails fast,
    # before any checkpoint is read.
    observation_dir = f"outputs/15.1_evaluating_models/{model_name}/"
    results_path = f'{observation_dir}/experiments_data.pkl'
    if not os.path.exists(results_path):
        print(observation_dir, "does not exist")
        raise ValueError("Given parameter does not have experiments")

    # Only the scalar accuracy is needed here, so keep the checkpoint on the CPU instead of
    # materialising the whole model on the device it was saved from for every subplot.
    ckpt = torch.load(f"{model_dir}/{model_name}.pth", map_location="cpu")
    accuracy = ckpt["acc"]

    center_lr, bound, temp_scale, adv_alpha = float(center_lr), int(bound), float(temp_scale), float(adv_alpha)

    config = f"{attack_type}_e{adv_alpha}_b{bound}_ts{temp_scale}"
    # NOTE(review): pickle.load executes arbitrary code; only use on files produced by this project.
    with open(results_path, 'rb') as handle:
        actv_data_dict = pickle.load(handle)
    # Named `curves` (not `data`) so the `torch.utils.data` import is not shadowed.
    # Presumably a 2-D numpy array with one row per x-axis value — confirm against the writer notebook.
    curves = actv_data_dict[config]

    ##### plot after each experiment
    # Column 0 is the x axis; column 2 (first row) is the sample count used to turn the
    # failed / rejected / x_rejected counts (columns 3-5) into fractions.
    test_count = curves[0,2]
    ax.plot(curves[:,0], curves[:,1], lw=2, label="measure", marker='.')
    ax.plot(curves[:,0], curves[:,3]/test_count, linestyle="dashed", label="failed")
    ax.plot(curves[:,0], curves[:,4]/test_count, linestyle="dotted", label="rejected")
    ax.plot(curves[:,0], curves[:,5]/test_count, linestyle="dotted", label="x_rejected")

    # NOTE(review): accuracy is shown as `acc*100` here but as plain `acc` in
    # analyze_adversarial_samples — one of the two is mis-scaled; confirm how "acc" is stored.
    _mn = f"init:{init} clr:{center_lr} nh:{hidden_units} acc:{accuracy*100:.1f}"
    _cf = f"{attack_type} "+r"$\alpha$"+f":{int(adv_alpha)} b:[{-int(bound)}, {int(bound)}]"
    ax.set_xlabel(r"$\epsilon$ for "+f"{_cf}\n{_mn}")

    ax.legend()

    # Reference lines: the minimum of the measure (black, dashed) and its final value (blue, dotted).
    ax.hlines(curves[:,1].min(), curves[0,0], curves[-1,0], linestyle='dashed', lw=0.5, color='k')
    ax.hlines(curves[-1,1], curves[0,0], curves[-1,0], linestyle='dotted', lw=0.5, color='b')

    ax.set_ylim(0,1)

In [None]:
# Create the output folder (and its parent outputs/15.2_observation) idempotently:
# re-running the cell no longer reports "File exists", and no shell is required.
os.makedirs("outputs/15.2_observation/measure_comp_bound", exist_ok=True)

In [None]:
# Verifies that increasing the clipping bound also increases rejection:
# every figure shows bound=1 (left) next to bound=10 (right) for the same model and attack.
hidden_units = 100  # loop-invariant, hoisted out of the innermost loop
for data_init in [True, False]:
    print("Data Init = ",data_init)
    init = "cent" if data_init else "rand"
    for attack_type in attack_dict.keys():
        for lr in [1.0, 0.01]:
            _mn = f"I{init}_clr{lr}_h{hidden_units}"
            for alpha in [1.0, 3.0, 9.0, 20.0]:
                _cf = f"{attack_type}_e{alpha}"

                print(f"=========alpha:{alpha}=======")
                fig, axs = plt.subplots(1,2, figsize=(8,3))
                for panel_ax, panel_bound in zip(axs, (1, 10)):
                    plot_experiment_instance(panel_ax, hidden_units=hidden_units, data_init=data_init, center_lr=lr, 
                                             bound=panel_bound, temp_scale=1.0, attack_type=attack_type, adv_alpha=alpha)

                fig.savefig(f"./outputs/15.2_observation/measure_comp_bound/{_mn};{_cf}.pdf", bbox_inches="tight")
                plt.show()
                # Release the figure explicitly: this loop creates 64 figures, which would
                # otherwise stay open on backends that do not close figures on show().
                plt.close(fig)

In [None]:
# Create the output folder idempotently (parents included); safe to re-run.
os.makedirs("outputs/15.2_observation/measure_comp_init", exist_ok=True)

In [None]:
# Verifies that data ("cent") initialisation is better for rejection than random initialisation:
# every figure shows data_init=True (left) next to data_init=False (right).
for hidden_units in [100, 500]:
    print(f"Hidden Units: {hidden_units}")
    for center_lr in [1.0, 0.01]:
        print(f"Center LR: {center_lr}")
        # Bug fix: the file name is built from this loop's `center_lr`. The original used the
        # stale global `lr` left over from an earlier cell, so the center_lr=1.0 figures were
        # written under the clr0.01 name and then overwritten by the center_lr=0.01 figures.
        # The unused `init` (computed from a stale `data_init`) is dropped as well.
        _mn = f"clr{center_lr}_h{hidden_units}"
        for adv_alpha in [1.0, 3.0, 9.0, 20.0]:
            for attack_type in attack_dict.keys():
                for bounds in [1, 10]:
                    _cf = f"{attack_type}_e{adv_alpha}_b{bounds}"

                    fig, axs = plt.subplots(1,2, figsize=(8,3))
                    plot_experiment_instance(axs[0], hidden_units=hidden_units, data_init=True, center_lr=center_lr, 
                                             bound=bounds, temp_scale=1.0, attack_type=attack_type, adv_alpha=adv_alpha)
                    plot_experiment_instance(axs[1], hidden_units=hidden_units, data_init=False, center_lr=center_lr, 
                                             bound=bounds, temp_scale=1.0, attack_type=attack_type, adv_alpha=adv_alpha)

                    fig.savefig(f"./outputs/15.2_observation/measure_comp_init/{_mn};{_cf}.pdf", bbox_inches="tight")
                    plt.show()
                    plt.close(fig)  # many figures are created in this loop; release each one

In [None]:
# Create the output folder idempotently (parents included); safe to re-run.
os.makedirs("outputs/15.2_observation/measure_hidden_units", exist_ok=True)

In [None]:
# Verifies that more hidden units are better for rejection:
# every figure shows hidden_units=100 (left) next to hidden_units=500 (right).
for data_init in [True, False]:
    # Bug fix: report the variable this loop iterates over. The original printed a stale
    # `hidden_units` left over from an earlier cell under a "Hidden Units" label.
    print(f"Data Init: {data_init}")
    init = "cent" if data_init else "rand"
    for center_lr in [1.0, 0.01]:
        print(f"Center LR: {center_lr}")
        # Bug fix: use this loop's `center_lr` in the file name. The original used the stale
        # global `lr`, so the center_lr=1.0 figures were overwritten by the 0.01 ones.
        _mn = f"clr{center_lr}_I{init}"
        for adv_alpha in [1.0, 3.0, 9.0, 20.0]:
            for attack_type in attack_dict.keys():
                for bounds in [1, 10]:
                    _cf = f"{attack_type}_e{adv_alpha}_b{bounds}"

                    fig, axs = plt.subplots(1,2, figsize=(8,3))
                    plot_experiment_instance(axs[0], hidden_units=100, data_init=data_init, center_lr=center_lr, 
                                             bound=bounds, temp_scale=1.0, attack_type=attack_type, adv_alpha=adv_alpha)
                    plot_experiment_instance(axs[1], hidden_units=500, data_init=data_init, center_lr=center_lr, 
                                             bound=bounds, temp_scale=1.0, attack_type=attack_type, adv_alpha=adv_alpha)

                    fig.savefig(f"./outputs/15.2_observation/measure_hidden_units/{_mn};{_cf}.pdf", bbox_inches="tight")
                    plt.show()
                    plt.close(fig)  # many figures are created in this loop; release each one

In [None]:
# NOTE(review): `asdasdasdasd` is an undefined name, so evaluating it raises NameError.
# Presumably a deliberate stop so that "Run All" does not continue into the cells below —
# confirm, and consider an explicit, documented stop instead.
asdasdasdasd

### Plot samples of adv examples 

In [None]:
class DistanceTransform_Epsilon(dtnn.DistanceTransformBase):
    """Distance layer with an optional extra constant "epsilon" column.

    ``forward`` takes the output of ``dtnn.DistanceTransformBase.forward`` (presumably the
    distances from each input to every center, one column per center — confirm in dtnnlib),
    optionally appends one column holding the current epsilon value, and returns
    ``1 - dists * exp(scaler)`` (plus ``bias`` when enabled).

    Args:
        input_dim: forwarded to the base class.
        num_centers: forwarded to the base class; the output has ``num_centers`` columns,
            plus one when ``epsilon`` is not None.
        p: forwarded to the base class (presumably the norm order — confirm in dtnnlib).
        bias: when truthy, a learnable zero-initialised bias (one value per output column) is added.
        epsilon: initial value handed to ``dtnn.EMA(mu=...)``; ``None`` disables the extra column.
        itemp: the learnable ``scaler`` is initialised to ``log(itemp)``.
    """

    def __init__(self, input_dim, num_centers, p=2, bias=False, epsilon=0.1, itemp=1):
        # Fix: forward the caller's `p`. The original hard-coded p=2 here, so the constructor
        # argument was silently ignored. Behaviour with the default p=2 is unchanged.
        super().__init__(input_dim, num_centers, p=p)

        nc = num_centers
        if epsilon is not None:
            nc += 1  # one extra output column for the epsilon unit
        # Stored in log-space; forward() multiplies the distances by exp(scaler).
        self.scaler = nn.Parameter(torch.log(torch.ones(1, 1)*itemp))
        self.bias = nn.Parameter(torch.ones(1, nc)*0) if bias else None

        if epsilon is None:
            self.epsilon = None
        else:
            self.epsilon = dtnn.EMA(mu=epsilon)

    def forward(self, x):
        """Return ``1 - dists * exp(scaler)`` (+ bias), with the epsilon column appended last."""
        dists = super().forward(x)

        if self.epsilon is not None:
            if self.training:
                # In training mode, feed the mean distance of the batch to the EMA module
                # (alternatives tried by the author: dists.min() / dists.max()).
                mdist = dists.mean().data
                self.epsilon(mdist)
            # Append the current epsilon value as a constant last column, matching x's dtype/device.
            dists = torch.cat([dists, torch.ones(len(x), 1).to(x)*self.epsilon.mu], dim=1)

        ## scale the dists
        dists = 1-dists*torch.exp(self.scaler)

        if self.bias is not None: dists = dists+self.bias
        return dists

In [None]:
class LocalMLP_epsilonsoftmax(nn.Module):
    """Two-layer classifier: DistanceTransform_Epsilon -> softmax -> bias-free linear layer.

    When ``epsilon`` is not None the hidden layer has one extra (last) unit, the epsilon unit.
    Its outgoing weights are forced to zero on every forward pass, so it contributes nothing
    to the logits. Other cells treat "argmax of the hidden softmax == hidden_dim" (i.e. the
    epsilon unit wins) as a rejection.

    Args:
        input_dim: input feature size.
        hidden_dim: number of centers in the distance layer.
        output_dim: number of output logits.
        epsilon: forwarded to DistanceTransform_Epsilon; ``None`` disables the extra unit.
        itemp: forwarded to DistanceTransform_Epsilon.

    Attributes set during forward:
        temp_maximum: detached hidden softmax activations of the most recent forward pass.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, epsilon=1.0, itemp=1):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.new_hidden_dim = 0
        self.output_dim = output_dim

        self.layer0 = DistanceTransform_Epsilon(self.input_dim, self.hidden_dim, bias=False, epsilon=epsilon, itemp=itemp)

        hdim = self.hidden_dim
        if epsilon is not None:
            hdim += 1  # account for the appended epsilon unit

        self.softmax = nn.Softmax(dim=-1)
        self.layer1 = nn.Linear(hdim, self.output_dim, bias=False)

        self.temp_maximum = None

    def forward(self, x):
        """Return the logits for ``x`` and cache the hidden softmax in ``temp_maximum``."""
        xo = self.layer0(x)
        ## author note: dropout here creates 0 activations (relatively high), which act as
        ## noise --> did not work for high values, so it is not applied.
        xo = self.softmax(xo)
        self.temp_maximum = xo.data

        # Fix: only silence the last hidden unit when it really is the epsilon unit. With
        # epsilon=None there is no extra unit, and the original unconditionally zeroed the
        # outgoing weights of the last *real* center on every forward pass.
        if self.layer0.epsilon is not None:
            self.layer1.weight.data[:,-1]*=0.
        xo = self.layer1(xo)
        return xo

In [None]:
def load_model(hidden_units, data_init, center_lr):
    """Rebuild the flatten + LocalMLP_epsilonsoftmax model and load its trained weights.

    Args:
        hidden_units: hidden size of the classifier; also part of the checkpoint name.
        data_init: truthy selects the "cent" checkpoint, falsy the "rand" checkpoint.
        center_lr: formatted verbatim into the checkpoint name (``clrs{center_lr}``).

    Returns:
        ``nn.Sequential(backbone, classifier)`` on ``device`` with the checkpoint's "model"
        state dict loaded. The train/eval mode is not changed here.
    """
    init = "cent" if data_init else "rand"
    model_name = f"dtesm_identity_I{init}_clrs{center_lr}_h{hidden_units}_mean"
    # map_location makes the checkpoint loadable on a different device than the one it was
    # saved from. The unused `accuracy` local of the original has been removed.
    ckpt = torch.load(f"{model_dir}/{model_name}.pth", map_location=device)

    flows = [irf.Flatten(img_size=[1, 28, 28])]
    backbone = nn.Sequential(*flows).to(device)
    classifier = LocalMLP_epsilonsoftmax(784, hidden_units, 10).to(device)
    model = nn.Sequential(backbone, classifier)
    model.load_state_dict(ckpt["model"])

    return model

In [None]:
# Load the 100-hidden-unit, randomly initialised ("rand"), center_lr=0.01 model for the sample plots below.
model = load_model(hidden_units=100, data_init=False, center_lr=0.01)

In [None]:
# Inspect the distance layer's bias; it is None because the classifier builds layer0 with bias=False.
model[1].layer0.bias ## not used

In [None]:
# Print each attack name with its position; `atk_idx` in a later cell indexes into this order.
for i, k in enumerate(attack_dict.keys()):
    print(i, k)

In [None]:
# Manually overwrite (in place) the first element of the epsilon EMA's `mu`, i.e. the value
# that layer0.forward appends as the extra epsilon column.
model[1].layer0.epsilon.mu[0] = 10.

In [None]:
# Scratch experiment: attack one test batch and keep `xx`, `grad`, `advs`, `success` as
# notebook-level variables for the inspection / plotting cells below.
atk_idx = 1  # position in attack_dict; index 1 is "FGM"
epsilon = 9.0 ### alpha (as learning rate)

fmodel = fb.PyTorchModel(model.eval(), bounds=(-10, 10), device=device)
attack = attack_dict[list(attack_dict.keys())[atk_idx]]


# NOTE(review): these counters are only used by the commented-out bookkeeping below.
count = 0
failed = 0
rejected = 0
x_rejected = 0
# Take only the first batch; the test loader is shuffled, so it differs between runs.
for i, (xx, yy) in enumerate(test_loader): ## randomize test_loader ^^
    xx = xx.to(device)
    yy = yy.to(device)
    break
    
### without adversarial
# yout = model(xx)
# reject_hid = model[1].temp_maximum.max(dim=1)[1] == model[1].hidden_dim
# reject = reject_hid
# x_rejected += int(reject.type(torch.float32).sum())

### with adversarial
# The attack is only used to obtain a perturbation direction: the difference to its
# unclipped output is normalised per sample and re-applied with step size `epsilon`.
unbound_advs, advs, success = attack(fmodel, xx, yy, epsilons=1.0)   
grad = xx-unbound_advs
# NOTE(review): a sample whose perturbation is exactly zero gives a 0/0 division here and
# therefore NaN values; a later cell lists the NaN-free samples.
grad = grad/(torch.norm(grad.view(xx.shape[0], -1), dim=1)[:, None, None, None])
advs = (xx - grad*epsilon).clip(-10, 10)
yout = model(advs)
# "success" is recomputed as: the model's prediction on the rebuilt sample differs from the label.
success = ~(yout.argmax(dim=1) == yy)

# reject = model[1].temp_maximum.max(dim=1)[1] == model[1].hidden_dim
# rejected += int(reject.type(torch.float32).sum())

# fail = torch.bitwise_and(success, ~reject).type(torch.float32).sum()
# failed += int(fail)    
# count += len(xx)

In [None]:
# Display the shape and device of the rebuilt adversarial batch.
advs.shape, advs.device

In [None]:
# torch.norm(grad.view(xx.shape[0], -1), dim=1)

In [None]:
# advs[-2]
# Display the shapes of the normalised perturbation and of the clean batch.
grad.shape, xx.shape

In [None]:
# Indices of the samples in `advs` that contain no NaN value (NaN count over dims 1-3 equals 0).
torch.nonzero(torch.isnan(advs).to(torch.float32).sum(dim=(1,2,3)) == 0.).reshape(-1)

In [None]:
# Display the current epsilon value held by the EMA module.
model[1].layer0.epsilon.mu

In [None]:
# xx, advs, grad = xx[:,0].cpu(), advs[:,0].cpu(), grad[:,0].cpu()
# success = success.cpu().numpy()

In [None]:
def plot_adv_atk(xx, grad, advs, success):
    """Plot the first five samples as rows of (input, perturbation, adversarial) images.

    A new 5x3 figure is created and left as the current pyplot figure so that the caller
    can ``plt.savefig`` / ``plt.show`` it. Nothing is returned.

    Args:
        xx: clean inputs; channel 0 of each of the first five samples is shown.
        grad: perturbations, indexed like ``xx``.
        advs: adversarial inputs, indexed like ``xx``.
        success: per-sample flags; ``int(success[i])`` picks the "fooled" mark
            (0 -> ✕, 1 -> ✓).

    Raises:
        IndexError: if fewer than five samples are provided.
    """
    # Keep channel 0 only and move to the CPU. detach() lets tensors that still carry
    # autograd history be converted to arrays by imshow.
    xx, advs, grad = xx[:,0].detach().cpu(), advs[:,0].detach().cpu(), grad[:,0].detach().cpu()
    success = success.detach().cpu().numpy()

    fig, axs = plt.subplots(5, 3, figsize=(6, 10))
    for i, axx in enumerate(axs.reshape(-1, 3)):
        fld = ["✕", "✓"][int(success[i])]
        panels = (
            (xx[i], r"$x$"+f"-norm:{torch.norm(xx[i]).item():.2f}"),
            (grad[i], r"$g$"+f" ;  fooled:{fld}"),
            (advs[i], f"adv-norm:{torch.norm(advs[i]).item():.2f}"),
        )
        for axis, (image, label) in zip(axx, panels):
            im = axis.imshow(image)
            axis.tick_params(which = 'both', size = 0, labelsize = 0)
            # Fix: attach each colorbar to its own axes explicitly. `plt.colorbar(im)` without
            # `ax` depends on the matplotlib version for which axes it takes space from.
            fig.colorbar(im, ax=axis)
            axis.set_xlabel(label)

In [None]:
# Plot five samples of the batch attacked in the scratch cell above.
plot_adv_atk(xx, grad, advs, success)

In [None]:
def get_adverserial_rejection(model, epsilon, bounds, attack_str, xx, yy): ## bounds: (low, high) clip range
    """Attack one batch and count misclassified / rejected samples.

    The attack named ``attack_str`` is run once (foolbox ``epsilons=1.0``) only to obtain a
    perturbation direction. That direction is normalised per sample, re-applied with step
    size ``epsilon`` and clipped to ``bounds``. A sample counts as rejected when the extra
    epsilon unit (index ``hidden_dim``) has the largest hidden softmax activation.

    Prints the mean distance to the nearest center and the mean of the largest non-epsilon
    hidden activation, for the clean and for the adversarial batch.

    Returns:
        ``(count, failed, rejected, x_rejected, dists_sum, actv_sum, xx, raw_step, advs, fail)``
        where ``raw_step`` is the un-normalised ``xx - unbound_advs``, ``fail`` is a float
        mask (1.0 = misclassified and not rejected), and ``dists_sum`` / ``actv_sum`` are
        always 0 (never accumulated).
    """
    classifier = model[1]

    def _rejected_mask():
        # True where the epsilon unit wins the hidden softmax of the latest forward pass.
        return classifier.temp_maximum.max(dim=1)[1] == classifier.hidden_dim

    def _report(samples, dist_label, actv_label):
        # Mean distance to the closest center, then mean of the strongest real hidden unit.
        nearest = torch.cdist(samples.reshape(-1, 784), classifier.layer0.centers.data).min(dim=1)[0]
        print(dist_label, nearest.mean().item())
        strongest = classifier.temp_maximum[:,:-1].max(dim=1)[0]
        print(actv_label, strongest.mean().item())

    fmodel = fb.PyTorchModel(model.eval(), bounds=[-100, 100], device=device) ## effectively unbounded; clipping is done manually
    attack = attack_dict[attack_str]

    # Kept for the return signature; nothing is accumulated into them.
    dists_sum = 0
    actv_sum = 0

    xx, yy = xx.to(device), yy.to(device)

    # --- clean batch ---
    model(xx)
    x_rejected = int(_rejected_mask().type(torch.float32).sum())
    _report(xx, "for x, dist", "for x, neuron_p")

    # --- adversarial batch ---
    unbound_advs, _, _ = attack(fmodel, xx, yy, epsilons=1.0)
    raw_step = xx - unbound_advs
    step_norm = torch.norm(raw_step.view(xx.shape[0], -1), dim=1)
    unit_step = raw_step / step_norm[:, None, None, None]
    advs = (xx - unit_step*epsilon).clip(*bounds)
    success = model(advs).argmax(dim=1) != yy

    reject = _rejected_mask()
    rejected = int(reject.type(torch.float32).sum())

    fail = torch.bitwise_and(success, ~reject).type(torch.float32)
    failed = int(fail.sum())
    count = len(xx)

    _report(advs, "for adv, min_dist", "for adv, max_neuron")

    return count, failed, rejected, x_rejected, dists_sum, actv_sum, xx, raw_step, advs, fail## sent as success

In [None]:
def analyze_adversarial_samples(ax, hidden_units, data_init, center_lr, 
                             bound=10, temp_scale=1.0, attack_type="FGSM", adv_alpha=0.5):
    """Plot one stored rejection curve and sample adversarial examples at three epsilon values.

    First the stored curves of the selected model / attack configuration are drawn on ``ax``
    and that figure is saved. Then the model is rebuilt, one test batch is attacked with the
    epsilon unit set to three values, and a sample grid is plotted and saved for each:

    * "min": the largest x value below the optimum whose measure is still > 0.98,
    * "optimal": the x value where the measure is minimal,
    * "max": the smallest x value above the optimum whose measure exceeds 0.98 * final measure.

    When no point satisfies the "min" / "max" condition, the first / last x value of the
    curve is used instead of raising IndexError.

    Args:
        ax: matplotlib Axes for the curve plot.
        hidden_units, data_init, center_lr: identify the checkpoint (formatted into its name).
        bound: clipping bound; samples are clipped to (-bound, bound).
        temp_scale: only used to build the config key.
        attack_type: key into ``attack_dict`` and first component of the config key.
        adv_alpha: step size handed to ``get_adverserial_rejection``.

    Raises:
        ValueError: if the model has no ``experiments_data.pkl``.
        KeyError: if the requested configuration is not stored in the pickle.
    """
    init = "cent" if data_init else "rand"
    model_name = f"dtesm_identity_I{init}_clrs{center_lr}_h{hidden_units}_mean"

    # Validate the experiment file before the (more expensive) checkpoint load.
    observation_dir = f"outputs/15.1_evaluating_models/{model_name}/"
    results_path = f'{observation_dir}/experiments_data.pkl'
    if not os.path.exists(results_path):
        print(observation_dir, "does not exist")
        raise ValueError("Given parameter does not have experiments")

    ckpt = torch.load(f"{model_dir}/{model_name}.pth", map_location=device)
    accuracy = ckpt["acc"]

    center_lr, bound, temp_scale, adv_alpha = float(center_lr), int(bound), float(temp_scale), float(adv_alpha)
    config = f"{attack_type}_e{adv_alpha}_b{bound}_ts{temp_scale}"

    # NOTE(review): pickle.load executes arbitrary code; only use on files produced by this project.
    with open(results_path, 'rb') as handle:
        actv_data_dict = pickle.load(handle)
    # Presumably a 2-D numpy array with one row per x-axis value — confirm against the writer notebook.
    curves = actv_data_dict[config]
    eps_axis, measure = curves[:,0], curves[:,1]

    opt_eps_at = eps_axis[measure.argmin()]

    ##### plot after each experiment
    test_count = curves[0,2]
    ax.plot(eps_axis, measure, lw=2, label="measure", marker='.')
    ax.plot(eps_axis, curves[:,3]/test_count, linestyle="dashed", label="failed")
    ax.plot(eps_axis, curves[:,4]/test_count, linestyle="dotted", label="rejected")
    ax.plot(eps_axis, curves[:,5]/test_count, linestyle="dotted", label="x_rejected")
    ax.plot(eps_axis, curves[:,6], linestyle="dashdot", label="x_accuracy")

    # NOTE(review): accuracy is shown as plain `acc` here but as `acc*100` in
    # plot_experiment_instance — one of the two is mis-scaled; confirm how "acc" is stored.
    _mn = f"init:{init} clr:{center_lr} nh:{hidden_units} acc:{accuracy:.1f}"
    _cf = f"{attack_type} "+r"$\alpha$"+f":{int(adv_alpha)} b:[{-int(bound)}, {int(bound)}]"
    ax.set_xlabel(r"$\epsilon$ for "+f"{_cf}\n{_mn}")
    ax.legend()

    # Reference lines: minimum of the measure (black, dashed) and its final value (blue, dotted).
    ax.hlines(measure.min(), eps_axis[0], eps_axis[-1], linestyle='dashed', lw=0.5, color='k')
    ax.hlines(measure[-1], eps_axis[0], eps_axis[-1], linestyle='dotted', lw=0.5, color='b')
    ax.set_ylim(0,1)

    os.makedirs("./outputs/15.2_observation", exist_ok=True)
    # model_name[15:] strips the "dtesm_identity_" prefix.
    out_prefix = f"./outputs/15.2_observation/adv_eg_{model_name[15:]}_{config}"
    # Save the figure that owns `ax` rather than whichever figure happens to be current.
    ax.figure.savefig(f"{out_prefix}.pdf")

    flows = [irf.Flatten(img_size=[1, 28, 28])]
    backbone = nn.Sequential(*flows).to(device)
    classifier = LocalMLP_epsilonsoftmax(784, hidden_units, 10).to(device)
    model = nn.Sequential(backbone, classifier)
    model.load_state_dict(ckpt["model"])

    # One (shuffled) test batch, reused for all three epsilon settings.
    _xx, _yy = next(iter(test_loader))

    def _attack_and_plot(label, eps_value):
        # Set the epsilon unit, attack the batch, then plot and save the sample grid.
        print(label, eps_value)
        model[1].layer0.epsilon.mu[0] = eps_value
        *_, xx, grad, advs, success = get_adverserial_rejection(model, 
                                            adv_alpha, (-bound, bound), attack_type, _xx, _yy)
        plot_adv_atk(xx, grad, advs, success)
        # plot_adv_atk leaves its new figure current, so plt.savefig targets the sample grid.
        plt.savefig(f"{out_prefix}_eps{eps_value:.2f}_sample.pdf", bbox_inches="tight")
        plt.show()

    # Largest x below the optimum whose measure is still > 0.98; falls back to the first x value.
    below = np.nonzero(np.logical_and(measure > 0.98, eps_axis < opt_eps_at))[0]
    min_eps_at = eps_axis[below[-1]] if len(below) else eps_axis[0]

    # Smallest x above the optimum whose measure is back above 98% of its final value;
    # falls back to the last x value.
    above = np.nonzero(np.logical_and(measure > (0.98*measure[-1]), eps_axis > opt_eps_at))[0]
    max_eps_at = eps_axis[above[0]] if len(above) else eps_axis[-1]

    _attack_and_plot("min", min_eps_at)
    _attack_and_plot("optimal", opt_eps_at)
    _attack_and_plot("max", max_eps_at)

In [None]:
# len("dtesm_identity_")

In [None]:
# Sample analysis: 100 hidden units, data-initialised ("cent") centers, FGM attack, alpha=20, bound=10.
fig, axs = plt.subplots(1,1, figsize=(4,4))
analyze_adversarial_samples(axs, hidden_units=100, data_init=True, center_lr=0.01, 
                            bound=10, temp_scale=1.0, attack_type="FGM", adv_alpha=20.0)

In [None]:
# Sample analysis: 100 hidden units, randomly initialised ("rand") centers, FGM attack, alpha=20, bound=10.
fig, axs = plt.subplots(1,1, figsize=(4,4))
analyze_adversarial_samples(axs, hidden_units=100, data_init=False, center_lr=0.01, 
                            bound=10, temp_scale=1.0, attack_type="FGM", adv_alpha=20.0)

In [None]:
# Sample analysis: 500 hidden units, data-initialised ("cent") centers, LinfAdamPGD attack, alpha=9, bound=1.
fig, axs = plt.subplots(1,1, figsize=(4,4))
analyze_adversarial_samples(axs, hidden_units=500, data_init=True, center_lr=0.01, 
                            bound=1, temp_scale=1.0, attack_type="LinfAdamPGD", adv_alpha=9.0)

In [None]:
# Sample analysis: 500 hidden units, randomly initialised ("rand") centers, LinfAdamPGD attack, alpha=9, bound=1.
fig, axs = plt.subplots(1,1, figsize=(4,4))
analyze_adversarial_samples(axs, hidden_units=500, data_init=False, center_lr=0.01, 
                            bound=1, temp_scale=1.0, attack_type="LinfAdamPGD", adv_alpha=9.0)