In [None]:
import numpy as np
import copy
import matplotlib
import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms as T
from torch.utils import data

In [None]:
from tqdm import tqdm
import os, time, sys, random
import json

In [None]:
import foolbox as fb
import foolbox.attacks as fa

import pickle

In [None]:
# Make the local "Input-Invex-Neural-Network" checkout importable.
# Presumably this is where the `dtnnlib` / `nflib` imports below come from — TODO confirm.
sys.path.append("./Input-Invex-Neural-Network/")

In [None]:
import dtnnlib as dtnn

In [None]:
import nflib
from nflib.flows import SequentialFlow, ActNorm, ActNorm2D, BatchNorm1DFlow, BatchNorm2DFlow
import nflib.res_flow as irf

In [None]:
# Preprocessing shared by both splits: convert to a tensor, then normalize
# with mean 0.5 / std 0.5 (i.e. (x - 0.5) / 0.5).
mnist_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(
        mean=0.5,
        std=0.5,
    ),
])

# FashionMNIST train / test splits, stored under data/ (downloaded if missing).
train_dataset = datasets.FashionMNIST(root="data/", train=True, download=True, transform=mnist_transform)
test_dataset = datasets.FashionMNIST(root="data/", train=False, download=True, transform=mnist_transform)
# Alternative dataset (plain MNIST) using the same transform:
# train_dataset = datasets.MNIST(root="data/", train=True, download=True, transform=mnist_transform)
# test_dataset = datasets.MNIST(root="data/", train=False, download=True, transform=mnist_transform)

In [None]:
train_dataset.data.shape

In [None]:
# Batched loaders for both splits. Only the training loader is shuffled; the
# test loader keeps a fixed order, so the first test batch is always the same.
batch_size = 1000
train_loader = data.DataLoader(dataset=train_dataset, num_workers=4, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, num_workers=4, batch_size=batch_size, shuffle=False)

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (more slowly) on a machine without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Sanity check: take one training batch, move it to `device`, and print the
# tensor shapes. The loop exits after the first batch.
for xx, yy in train_loader:
    print(xx.shape)
    # Flattening variant, disabled here (the models below start with a Flatten layer):
#     xx, yy = xx.view(-1,28*28).to(device), yy.to(device)
    xx, yy = xx.to(device), yy.to(device)
    print(xx.shape, yy.shape)
    break

## Train Test method

In [None]:
# Directory containing the trained checkpoints (`<model_name>.pth`) that are loaded below.
model_dir = "outputs/15.0_models"

In [None]:
# Cross-entropy loss. NOTE(review): not referenced by any cell visible in this
# notebook — possibly left over from the training notebook; confirm before removing.
criterion = nn.CrossEntropyLoss()

## Adversarial Test

In [None]:
# Foolbox attacks keyed by a short name. The name is used to select an attack
# and also ends up in the result/figure file names further down.
attack_dict = {
    "FGSM": fa.FGSM(), ## LinfFastGradientAttack
    "FGM": fa.FGM(), ## L2FastGradientAttack
    "L2PGD": fa.L2PGD(steps=20), ## author's note: more steps (>10) seemed to work better (unverified)
    "LinfPGD": fa.LinfPGD(steps=20), ## PGD
    "L1AdamPGD": fa.L1AdamPGD(steps=20, adam_beta1=0.8, adam_beta2=0.95),
    "L2AdamPGD": fa.L2AdamPGD(steps=20, adam_beta1=0.8, adam_beta2=0.95),
    "LinfAdamPGD": fa.LinfAdamPGD(steps=20, adam_beta1=0.8, adam_beta2=0.95),
    "L2AdamBasic": fa.L2AdamBasicIterativeAttack(steps=10), ## default steps
}

In [None]:
attack_dict.keys()

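## Model
Model definitions copied from the previous step (the run that produced the checkpoints in `outputs/15.0_models`). They must match those checkpoints for `load_state_dict` to succeed.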

In [None]:
class DistanceTransform_Epsilon(dtnn.DistanceTransformBase):
    """Distance-to-centers layer with an optional extra "epsilon" output column.

    The base class computes the distances (presumably shape
    ``(batch, num_centers)`` — TODO confirm against ``dtnn.DistanceTransformBase``).
    When ``epsilon`` is not ``None`` one additional column holding the tracked
    value ``self.epsilon.mu`` is appended, so the output has ``num_centers + 1``
    columns. The result is mapped to ``1 - dists * exp(scaler)`` and an
    optional learnable bias is added.

    Args:
        input_dim: forwarded to the base class.
        num_centers: forwarded to the base class.
        p: norm order forwarded to the base class (default 2).
        bias: if True, add a learnable per-output bias initialised to 0.
        epsilon: initial value of the extra column (wrapped in ``dtnn.EMA``),
            or ``None`` to disable the extra column altogether.
        itemp: initial inverse temperature; stored as ``log(itemp)`` in
            ``self.scaler`` and applied as ``exp(self.scaler)``.
    """

    def __init__(self, input_dim, num_centers, p=2, bias=False, epsilon=0.1, itemp=1):
        # Forward the caller's `p`; it used to be hard-coded to 2, which
        # silently ignored the argument.
        super().__init__(input_dim, num_centers, p=p)

        # One extra output column when the epsilon slot is enabled.
        nc = num_centers if epsilon is None else num_centers + 1

        # Stored in log-space; forward() applies exp() to it.
        self.scaler = nn.Parameter(torch.log(torch.ones(1, 1)*itemp))
        self.bias = nn.Parameter(torch.ones(1, nc)*0) if bias else None

        # Presumably an exponential-moving-average tracker exposing `.mu`
        # (see dtnn.EMA) — TODO confirm.
        self.epsilon = None if epsilon is None else dtnn.EMA(mu=epsilon)

    def forward(self, x):
        dists = super().forward(x)

        if self.epsilon is not None:
            if self.training:
                # Feed the mean distance of this batch to the tracker
                # (min / max were alternatives the author tried).
                self.epsilon(dists.mean().data)
            # Append the tracked value as a constant extra column, matching
            # the dtype/device of the input.
            eps_column = torch.ones(len(x), 1).to(x)*self.epsilon.mu
            dists = torch.cat([dists, eps_column], dim=1)

        ## scale the dists
        dists = 1-dists*torch.exp(self.scaler)

        if self.bias is not None:
            dists = dists+self.bias
        return dists

In [None]:
class LocalMLP_epsilonsoftmax(nn.Module):
    """Two-layer classifier: distance layer -> softmax -> bias-free linear layer.

    ``layer0`` is a ``DistanceTransform_Epsilon``. When ``epsilon`` is not
    ``None`` it emits one extra "epsilon" unit (the last column) in addition
    to the ``hidden_dim`` center units. The output weights of that unit are
    forced to zero on every forward pass, so it can take softmax mass without
    contributing to any class logit. Other code in this notebook treats
    "arg-max of the softmax == ``hidden_dim``" as a rejection.

    Args:
        input_dim: flattened input size.
        hidden_dim: number of centers in ``layer0``.
        output_dim: number of classes.
        epsilon: initial epsilon value for ``layer0``; ``None`` disables the
            extra unit.
        itemp: initial inverse temperature for ``layer0``.

    Attributes set in ``forward``:
        temp_maximum: detached softmax activations of the most recent batch.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, epsilon=1.0, itemp=1):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.new_hidden_dim = 0
        self.output_dim = output_dim

        self.layer0 = DistanceTransform_Epsilon(self.input_dim, self.hidden_dim, bias=False, epsilon=epsilon, itemp=itemp)

        # The linear layer sees one extra input when the epsilon unit exists.
        hdim = self.hidden_dim
        if epsilon is not None:
            hdim += 1

        self.softmax = nn.Softmax(dim=-1)
        self.layer1 = nn.Linear(hdim, self.output_dim, bias=False)

        self.temp_maximum = None

    def forward(self, x):
        xo = self.layer0(x)
        ## author's note: dropout here creates 0 activations (relatively high),
        ## so it acts as noise and does not work for high values.
        xo = self.softmax(xo)
        self.temp_maximum = xo.data

        # Only the epsilon unit is silenced. Without this guard a model built
        # with epsilon=None would have the weights of its last *real* center
        # wiped on every call. `.data` keeps the write invisible to autograd
        # (as in the original); zero_() also clears NaN/inf, unlike `*= 0.`.
        if self.layer0.epsilon is not None:
            self.layer1.weight.data[:, -1].zero_()
        xo = self.layer1(xo)
        return xo

In [None]:
def load_model(hidden_units, data_init, center_lr):
    """Rebuild a trained classifier and load its checkpoint from ``model_dir``.

    Args:
        hidden_units: number of centers in the classifier's distance layer.
        data_init: selects the "cent" (True) or "rand" (False) checkpoint variant.
        center_lr: center learning-rate scale, used verbatim in the file name.

    Returns:
        ``nn.Sequential(backbone, classifier)`` on ``device`` with the weights
        from ``ckpt["model"]`` loaded. The module is not switched to eval mode here.
    """
    init = "cent" if data_init else "rand"
    model_name = f"dtesm_identity_I{init}_clrs{center_lr}_h{hidden_units}_mean"

    # map_location lets a checkpoint saved on another device (e.g. a different
    # GPU index) load onto whatever `device` is configured here.
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(f"{model_dir}/{model_name}.pth", map_location=device)

    # The backbone only flattens 1x28x28 images.
    backbone = nn.Sequential(irf.Flatten(img_size=[1, 28, 28])).to(device)
    classifier = LocalMLP_epsilonsoftmax(784, hidden_units, 10).to(device)
    model = nn.Sequential(backbone, classifier)
    model.load_state_dict(ckpt["model"])

    return model

In [None]:
# Load one trained configuration (100 centers, "cent" initialization, center
# learning-rate scale 0.01) for the interactive inspection cells below.
model = load_model(hidden_units=100, data_init=True, center_lr=0.01)

In [None]:
model[1].layer0.bias ## not used

In [None]:
# Wrap the backbone layers in a SequentialFlow so that `.inverse` is available.
# Presumably this maps the flattened centers back to image layout — TODO confirm.
invbackbone = SequentialFlow([*model[0]]).to(device)
            
# Show the first 25 centers as 28x28 images on a 5x5 grid.
_, axs = plt.subplots(5, 5, figsize=(10, 10))
axs = axs.flatten()
with torch.no_grad():
    # `c` (all centers as 28x28 numpy images) is reused by the next plotting cell.
    c = invbackbone.inverse(model[1].layer0.centers.data).data.cpu().numpy().reshape(-1, 28,28)
imgs = c[:len(axs)]
for img, ax in zip(imgs, axs):
    im = ax.imshow(img)
    ax.set_axis_off()
    plt.colorbar(im)

# plt.savefig(f"{observation_dir}/centers_sample.jpg", bbox_inches='tight')
# plt.close()

In [None]:
# For each of the 10 classes, plot the mean image of the centers assigned to it.
# Relies on `c` (centers as 28x28 images) computed in the previous cell.
_, axs = plt.subplots(2, 5, figsize=(10, 4))
axs = axs.reshape(-1)
# Assign every center (last column, the epsilon unit, excluded) to the class
# with the largest output weight for that center.
cls_rep = model[1].layer1.weight[:,:-1].argmax(dim=0)
for i in range(10):
    idx = torch.nonzero(cls_rep == i).cpu()
    imgs = c[idx].reshape(-1, 1, 28, 28)
    # NOTE: if no center maps to class i, `imgs` is empty and the mean is NaN.
    img = imgs.mean(axis=(0,1))
    if imgs.shape[0]==1: print("single center at:",i)
    im = axs[i].imshow(img)
    axs[i].set_axis_off()
    plt.colorbar(im)

In [None]:
for i, k in enumerate(attack_dict.keys()):
    print(i, k)

In [None]:
# Interactive single-batch attack: craft adversarials for the first test batch
# with one attack, then rescale every perturbation to a fixed L2 norm `epsilon`.
atk_idx = 1 ## ^^ index into the attack list printed by the previous cell
epsilon = 9.0

#######################################
fmodel = fb.PyTorchModel(model.eval(), bounds=(-10, 10), device=device)
attack = attack_dict[list(attack_dict.keys())[atk_idx]]


# NOTE(review): these counters are initialised but only used by the
# commented-out bookkeeping at the bottom of this cell.
count = 0
failed = 0
rejected = 0
x_rejected = 0
# Grab only the first batch (the test loader is not shuffled).
for i, (xx, yy) in enumerate(test_loader): ## randomize test_loader ^^
    xx = xx.to(device)
    yy = yy.to(device)
    break
    
### without adversarial
# yout = model(xx)
# reject_hid = model[1].temp_maximum.max(dim=1)[1] == model[1].hidden_dim
# reject = reject_hid
# x_rejected += int(reject.type(torch.float32).sum())

### with adversarial
# The attack call returns three values, named here as the unclipped
# adversarials, the clipped adversarials and the success flags (see the
# foolbox attack API). Only the unclipped ones are used below.
unbound_advs, advs, success = attack(fmodel, xx, yy, epsilons=1.0)   
# Per-sample perturbation direction, normalized to unit L2 norm.
# NOTE: a sample the attack left unchanged has zero norm and yields NaN here.
grad = xx-unbound_advs
grad = grad/(torch.norm(grad.view(xx.shape[0], -1), dim=1)[:, None, None, None])
# Apply the direction with magnitude `epsilon`, clipped to the model bounds.
advs = (xx - grad*epsilon).clip(-10, 10)
yout = model(advs)
# success[i] is True when the rescaled adversarial is misclassified.
success = ~(yout.argmax(dim=1) == yy)

# reject = model[1].temp_maximum.max(dim=1)[1] == model[1].hidden_dim
# rejected += int(reject.type(torch.float32).sum())

# fail = torch.bitwise_and(success, ~reject).type(torch.float32).sum()
# failed += int(fail)    
# count += len(xx)

In [None]:
advs.shape, advs.device

In [None]:
# Keep index 0 along dim 1 of every tensor and move the results to the CPU for plotting.
# NOTE: this overwrites xx / advs / grad / success in place, so the cell is not
# re-runnable on its own — re-run the attack cell first.
xx, advs, grad = xx[:,0].cpu(), advs[:,0].cpu(), grad[:,0].cpu()
success = success.cpu().numpy()

In [None]:
# 8x6 grid = 16 samples, three panels per sample:
# clean input | normalized perturbation direction | rescaled adversarial.
fig, axs = plt.subplots(8, 6, figsize=(12, 14))
for i, axx in enumerate(axs.reshape(-1, 3)):
    # (image, label builder) for each of the three panels of sample i.
    panels = (
        (xx[i], lambda nrm: f"x > norm:{nrm:.2f}"),
        (grad[i], lambda nrm: f"g > succ:{success[i]}"),
        (advs[i], lambda nrm: f"ad > norm:{nrm:.2f}"),
    )
    for panel_ax, (panel_img, make_label) in zip(axx, panels):
        im = panel_ax.imshow(panel_img)
        panel_ax.tick_params(which = 'both', size = 0, labelsize = 0)
        plt.colorbar(im)
        _norm = torch.norm(panel_img).item()
        panel_ax.set_xlabel(make_label(_norm))

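### Modified adversarial attack
- Rescale each adversarial perturbation to a fixed L2 norm (`epsilon`) before applying it, so the attack strength is set by the normalized perturbation magnitude.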

In [None]:
attack_dict.keys()

In [None]:
def get_adverserial_rejection(model, epsilon, bounds, attack_str):
    """Evaluate clean accuracy, rejection and attack failure over ``test_loader``.

    For each test batch the attack named ``attack_str`` is run once
    (``epsilons=1.0``). The resulting perturbation is normalized to unit L2 norm
    per sample, rescaled to magnitude ``epsilon`` and applied to the clean input,
    and the result is clipped to ``bounds``. A sample is "rejected" when the
    arg-max of the classifier's softmax layer is the extra epsilon unit
    (index ``model[1].hidden_dim``).

    Args:
        model: ``nn.Sequential(backbone, classifier)``; ``model[1]`` must expose
            ``temp_maximum`` and ``hidden_dim``.
        epsilon: L2 magnitude of the applied perturbation.
        bounds: ``(low, high)`` clip range for the adversarial inputs
            (the author used values in [1, 10]).
        attack_str: key into ``attack_dict``.

    Returns:
        ``(count, failed, rejected, x_rejected, correct)`` where
        ``count`` is the number of test samples, ``failed`` the adversarials
        that are misclassified and not rejected, ``rejected`` the rejected
        adversarials, ``x_rejected`` the rejected clean inputs and ``correct``
        the correctly classified clean inputs.
    """
    # Wide foolbox bounds: clipping is done manually with `bounds` below.
    fmodel = fb.PyTorchModel(model.eval(), bounds=[-100, 100], device=device)
    attack = attack_dict[attack_str]

    classifier = model[1]
    reject_index = classifier.hidden_dim

    count = 0
    failed = 0
    rejected = 0
    x_rejected = 0
    correct = 0
    for xx, yy in test_loader:
        xx = xx.to(device)
        yy = yy.to(device)

        # Clean pass: evaluation only, so no autograd graph is needed.
        with torch.no_grad():
            yout = model(xx)
        _, predicted = yout.max(1)
        correct += predicted.eq(yy).sum().item()
        reject = classifier.temp_maximum.max(dim=1)[1] == reject_index
        x_rejected += int(reject.sum())

        # Only the unclipped adversarials are needed to get the direction.
        unbound_advs, _, _ = attack(fmodel, xx, yy, epsilons=1.0)
        grad = xx - unbound_advs
        # Unit-normalize per sample. The clamp keeps samples the attack left
        # unchanged (zero perturbation) from turning into NaN inputs.
        norms = torch.norm(grad.reshape(xx.shape[0], -1), dim=1).clamp_min(1e-12)
        grad = grad / norms.view(-1, *([1] * (xx.dim() - 1)))
        advs = (xx - grad*epsilon).clip(*bounds)

        with torch.no_grad():
            yout = model(advs)
        success = ~(yout.argmax(dim=1) == yy)

        reject = classifier.temp_maximum.max(dim=1)[1] == reject_index
        rejected += int(reject.sum())

        # A defence failure is a successful attack that was not rejected.
        failed += int(torch.bitwise_and(success, ~reject).sum())
        count += len(xx)

    return count, failed, rejected, x_rejected, correct

In [None]:
def search_minimal_adverserial(model, attack_str, adv_epsilon, bounds, training_epsilon, inner_search_iter=2):
    """Sweep the classifier's epsilon value and record robustness statistics.

    Each candidate ``mu`` is written into ``model[1].layer0.epsilon.mu[0]`` and
    scored with ``get_adverserial_rejection``. The tracked measure is
    ``(failed + x_rejected) / count``: unrejected successful attacks plus
    rejected clean inputs, as a fraction of the test set.

    The search runs in two stages:
      1. Coarse: 10 values ``training_epsilon * 2**t`` with ``t`` in [-2, 1].
         The ``mu`` with the most rejected clean inputs (last one on ties) and
         the ``mu`` with the most failures (first one on ties) bracket stage 2.
      2. Fine, ``inner_search_iter`` times: 20 interior points of the bracket
         are evaluated, then the bracket shrinks to one grid step either side
         of the point with the *smallest* measure.

    The model's original ``mu`` is restored before returning, also on error.

    Args:
        model: ``nn.Sequential(backbone, classifier)``.
        attack_str: key into ``attack_dict``.
        adv_epsilon: perturbation magnitude passed to the evaluation.
        bounds: ``(low, high)`` clip range passed to the evaluation.
        training_epsilon: epsilon value the coarse grid is centred on.
        inner_search_iter: number of refinement rounds.

    Returns:
        ``(all_data, column_names)``: ``all_data`` is a 2-D numpy array sorted by
        ``mu`` whose rows are ``[mu, measure, count, failed, rejected,
        x_rejected, accuracy]``; ``column_names`` names the columns after ``mu``.
    """
    eps_state = model[1].layer0.epsilon
    original_mu = eps_state.mu.clone()
    eps_measure_dict = {}

    def _evaluate(mu):
        # Score one candidate and record it; returns (measure, failed, x_rejected).
        eps_state.mu[0] = mu
        count, failed, rejected, x_rejected, correct = get_adverserial_rejection(model, adv_epsilon, bounds, attack_str)
        measure = (failed+x_rejected)/count
        accuracy = correct/count
        eps_measure_dict[float(mu)] = [measure, count, failed, rejected, x_rejected, accuracy]
        return measure, failed, x_rejected

    try:
        ################ coarse stage ################
        max_xrej, mxe = -1, None
        max_failed, mfe = -1, None
        for mu in training_epsilon*(2**torch.linspace(-2, 1, steps=10)):
            _, failed, x_rejected = _evaluate(mu)
            if x_rejected >= max_xrej:
                max_xrej = x_rejected
                mxe = mu
            if failed > max_failed:
                max_failed = failed
                mfe = mu
        # Plain floats keep torch.linspace happy on every torch version. The
        # bracket may be in descending order, which linspace handles.
        lowval, highval = float(mxe), float(mfe)

        ################ fine stage ################
        for _ in range(inner_search_iter):
            mus = torch.linspace(lowval, highval, 22)[1:-1]
            min_measure, idx = 9e9, None
            for i, mu in enumerate(mus):
                measure, _, _ = _evaluate(mu)
                if measure < min_measure:
                    min_measure = measure
                    idx = i

            gap = mus[1]-mus[0]
            lowval, highval = float(mus[idx]-gap), float(mus[idx]+gap)
    finally:
        # Do not leak the last swept value into the caller's model.
        with torch.no_grad():
            eps_state.mu.copy_(original_mu)

    all_data = np.array([[k, *v] for k, v in sorted(eps_measure_dict.items())])
    print("Search Finished\n")
    return all_data, ["measure", "count", "failed", "rejected", "x_rejected", "accuracy"]

## Benchmark: Evaluating the Trained Models

In [None]:
# Evaluate every trained configuration (center learning-rate scale x hidden
# width x center initialization): save center visualizations, run the epsilon
# search for a set of attacks / magnitudes / bounds, plot each result and
# pickle all collected data. Configurations whose pickle already exists are skipped.
for center_lr_scaler in [1.0, 0.01]:
    for hidden_units in [100, 500]:
#     for hidden_units in [100]:
        for data_init in [True, False]:
            init = "cent" if data_init else "rand"
            model_name = f"dtesm_identity_I{init}_clrs{center_lr_scaler}_h{hidden_units}_mean"
            ########################################
            print(model_name)

            flows = [
                irf.Flatten(img_size=[1, 28, 28]),
                    ]
            backbone = nn.Sequential(*flows).to(device)

            classifier = LocalMLP_epsilonsoftmax(784, hidden_units, 10).to(device)
            model = nn.Sequential(backbone, classifier)
            print("num_parameters", sum([p.numel() for p in model.parameters()]))

            # map_location lets checkpoints saved on another device load here.
            ckpt = torch.load(f"{model_dir}/{model_name}.pth", map_location=device)
            model.load_state_dict(ckpt["model"])

            invbackbone = SequentialFlow([*backbone]).to(device)
            #######################################
            training_epsilon = model[1].layer0.epsilon.mu.item()
            # Clone so the backup can never alias the live parameter storage.
            backup_scaler = model[1].layer0.scaler.data.clone()

            observation_dir = f"outputs/15.1_evaluating_models/{model_name}/"
            os.makedirs(observation_dir, exist_ok=True)

            if os.path.exists(f'{observation_dir}/experiments_data.pkl'):
                print("Experiment Alreay Finished... \n NEXT \n")
                continue

            ####################################### first 25 centers
            _, axs = plt.subplots(5, 5, figsize=(10, 10))
            axs = axs.flatten()
            with torch.no_grad():
                c = invbackbone.inverse(model[1].layer0.centers.data).data.cpu().numpy().reshape(-1, 28,28)
            imgs = c[:len(axs)]
            for img, ax in zip(imgs, axs):
                im = ax.imshow(img)
                ax.set_axis_off()
                plt.colorbar(im)

            plt.savefig(f"{observation_dir}/centers_sample.jpg", bbox_inches='tight')
            plt.close()
            ####################################### mean center per class
            _, axs = plt.subplots(2, 5, figsize=(10, 4))
            axs = axs.reshape(-1)
            # Class with the largest output weight for each center (epsilon unit excluded).
            cls_rep = model[1].layer1.weight[:,:-1].argmax(dim=0)
            for i in range(10):
                idx = torch.nonzero(cls_rep == i).cpu()
                imgs = c[idx].reshape(-1, 1, 28, 28)
                img = imgs.mean(axis=(0,1))
                if imgs.shape[0]==1: print("single center at:",i)
                im = axs[i].imshow(img)
                axs[i].set_axis_off()
                plt.colorbar(im)
                # set_axis_off() also hides axis labels, so an xlabel would never
                # be drawn; a title keeps the class name visible.
                axs[i].set_title(test_dataset.classes[i])

            plt.savefig(f"{observation_dir}/centers_mean.jpg", bbox_inches='tight')
            plt.close()
            #######################################
            print("INITIATING ADVERSARIAL ATTACK")
            adv_data_dict = {"metadata":
                    {"training_epsilon": training_epsilon,
                     "learned_scaler": backup_scaler[0,0].item()}
                }
            #######################################
#             for temp_scale in [1.0, 0.25, 4.0]:
            for temp_scale in [1.0]:
                for bound in [10, 1]:
#                     for atk_str in attack_dict.keys():
                    # NOTE(review): three attacks are drawn at random without a seed,
                    # so the evaluated attacks differ from run to run.
                    for atk_str in random.sample(list(attack_dict.keys()), 3):
                        for adv in [0.5, 1.0, 3.0, 9.0, 20.0]:
                            config = f"{atk_str}_e{adv}_b{bound}_ts{temp_scale}"
                            print(config)
                            # NOTE(review): `scaler` is stored in log-space, so this
                            # multiplies the exponent (itemp**temp_scale rather than
                            # itemp*temp_scale). Irrelevant while temp_scale == 1.0 — confirm intent.
                            model[1].layer0.scaler.data = backup_scaler*temp_scale
                            # Named `adv_data`, not `data`, so the `torch.utils.data`
                            # module imported at the top is not shadowed.
                            adv_data, _keys = search_minimal_adverserial(model, atk_str, adv, (-bound, bound), training_epsilon, 2)
                            adv_data_dict[config] = adv_data

                            ##### plot after each experiment
                            # Columns: 0 mu, 1 measure, 2 count, 3 failed, 4 rejected,
                            # 5 x_rejected, 6 accuracy.
                            test_count = adv_data[0,2]
                            plt.plot(adv_data[:,0], adv_data[:,1], lw=2, label="measure", marker='.')
                            plt.plot(adv_data[:,0], adv_data[:,3]/test_count, linestyle="dashed", label="failed")
                            plt.plot(adv_data[:,0], adv_data[:,4]/test_count, linestyle="dotted", label="rejected")
                            plt.plot(adv_data[:,0], adv_data[:,5]/test_count, linestyle="dotted", label="x_rejected")
                            plt.plot(adv_data[:,0], adv_data[:,6], linestyle="dashdot", label="x_accuracy")

                            _mn = f"init:{init} clr:{center_lr_scaler} nh:{hidden_units} acc:{adv_data[-1,6]:.1f}"
                            # Uses the loop variables `atk_str` / `adv`; the names used
                            # before (`attack_type`, `adv_alpha`) were never defined.
                            # `:g` keeps fractional magnitudes such as 0.5 readable.
                            _cf = f"{atk_str} "+r"$\alpha$"+f":{adv:g} b:[{-int(bound)}, {int(bound)}]"
                            plt.xlabel(r"$\epsilon$ for "+f"{_cf}\n{_mn}")
                            plt.legend()
                            plt.savefig(f"{observation_dir}/obs_{config}.png", bbox_inches='tight')
                            plt.close()

#                         break
#                     break
#                 break
            with open(f'{observation_dir}/experiments_data.pkl', 'wb') as f:
                pickle.dump(adv_data_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
            #######################################
            
            #######################################
            