In [None]:
from tools import utils, config, trainer, parts
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.nn import functional as F
from tools import utils
import torch.nn as nn

# Disabled: the AuxNet wrapper below is wrapped in a string literal, so it is never
# executed. It reuses the trained model's submodules and its forward() inserts
# `aux_conv` -> `aux_reduce_conv` between `pool1` and `bn2`.
# NOTE(review): `model.aux_conv` / `model.aux_reduce_conv` presumably come from
# `trn.set_aux_layer(...)` (commented out in the setup code) — confirm before
# re-enabling together with the commented-out `AuxNet(model_trained)` line.
"""class AuxNet(nn.Module):
    def __init__(self, model):
        super(AuxNet, self).__init__()
        self.conv1 = model.conv1
        self.bn1 = model.bn1
        self.relu1 = model.relu1
        
        self.pool1 = model.pool1
        
        self.aux_conv = model.aux_conv
        self.aux_reduce_conv = model.aux_reduce_conv
        
        self.bn2 = model.bn2
        self.relu2 = model.relu2
        
        self.pool2 = model.pool2
        self.conv3 = model.conv3
        self.bn3 = model.bn3
        self.relu3 = model.relu3
        
        self.pool3 = model.pool3
        self.flatten1 = model.flatten1
        self.fc1 = model.fc1
        self.relu_out = model.relu_out
        self.fc2 = model.fc2
        self.softmax = model.softmax
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        
        x = self.pool1(x)
        x = self.aux_conv(x)
        x = self.aux_reduce_conv(x)
        x = self.bn2(x)
        x = self.relu2(x)
        
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        
        x = self.pool3(x)
        fc1_in = self.flatten1(x)
        fc1_out = self.fc1(fc1_in)
        relu_out = self.relu_out(fc1_out)
        fc2_out = self.fc2(relu_out)
        logsoftmax_output = self.softmax(fc2_out)
            
        return logsoftmax_output"""

plt.style.use('fast')
PLOT_DIR = 'plots'

# Forward slashes work on both Windows and POSIX. The previous backslash literals only
# worked because "\e" is not an escape sequence (SyntaxWarning on Python 3.12+), and on
# Linux/macOS a backslash is an ordinary filename character, so the paths never resolved.
CONFIG_PATH = "experiments/exp5_net5/config.yaml"
# NOTE(review): the config comes from exp5_net5 but the weights from exp4_net5 —
# confirm this pairing is intentional (the architectures must match for load_state_dict).
CHECKPOINT_PATH = "experiments/exp4_net5/checkpoint.pth"

cfg = config.from_yaml(CONFIG_PATH)
dataset = utils.load_dataset_module(**cfg.data_supervised)
dataset.torch_seed()
test_loader = dataset.get_test_loader(**cfg.data_supervised)
test_dataset = dataset.get_test_dataset()

# Trained model
model_trained = utils.load_model(**cfg.model)

part_manager_trained = parts.PartManager(model_trained)
part_manager_trained.enable_all()

i_part = 1
trn_trained = trainer.ModelTrainer(model_trained, cfg.trainer_sup, part_manager_trained)
#trn_trained.set_aux_layer(part_manager_trained.parts[i_part], size_multiplier=1)
# Load onto CPU so a GPU-saved checkpoint also loads on CPU-only machines;
# load_state_dict copies the values into the model's existing parameters.
model_trained.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
#model_trained = AuxNet(model_trained)

# This notebook only evaluates: use inference behaviour for BatchNorm/Dropout.
model_trained.eval()

In [None]:
import torchattacks
from tqdm import tqdm

def get_accuracy(model, device, attack_on=True, loader=None,
                 eps=8/255, alpha=2/255, steps=4):
    """Top-1 accuracy of ``model`` on ``loader``, optionally under a PGD attack.

    Args:
        model: classifier whose output's argmax over dim 1 is the predicted class.
        device: device each batch is moved to (e.g. ``trn_trained.device``).
        attack_on: if True, every batch is replaced by a torchattacks PGD
            adversarial example before prediction; if False, clean accuracy.
        loader: iterable of ``(x, y)`` batches. ``None`` falls back to the
            notebook-level ``test_loader`` (the previous hard-wired behaviour).
        eps, alpha, steps: PGD L-inf budget, step size and number of iterations
            (defaults equal the previously hard-coded values).

    Returns:
        float: fraction of examples classified correctly, in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    if loader is None:
        loader = test_loader  # notebook global defined in the setup cell

    # Evaluation must not use training-mode BatchNorm/Dropout (which would also
    # update BatchNorm running statistics on test data).
    model.eval()
    # Only build the attack when it is actually used.
    atk = torchattacks.PGD(model, eps=eps, alpha=alpha, steps=steps) if attack_on else None

    correct = 0
    n_examples = 0
    for x, y in tqdm(loader):
        # Tensor.to() is not in-place: the result has to be reassigned.
        x = x.to(device)
        y = y.to(device)
        inputs = atk(x, y) if attack_on else x
        # The attack needs gradients; scoring the (adversarial) inputs does not.
        with torch.no_grad():
            pred = model(inputs).argmax(dim=1)
        # view_as keeps this correct whether labels arrive as (B,) or (B, 1).
        correct += (pred == y.view_as(pred)).sum().item()
        n_examples += x.shape[0]

    if n_examples == 0:
        raise ValueError("loader yielded no examples")
    return correct / n_examples

In [None]:
from autoattack import AutoAttack

def get_acc_autoattack(model, device, loader, sample_size=1000, bs=25, eps=8/255, seed=None):
    """Robust accuracy of ``model`` under AutoAttack (APGD-CE + APGD-DLR, L-inf).

    A random subset of ``sample_size`` examples is drawn without replacement from
    everything ``loader`` yields, attacked in batches of ``bs``, and the model is
    then scored on the resulting adversarial examples. AutoAttack also prints its
    own report while running.

    Args:
        model: classifier to attack and evaluate.
        device: device the sampled examples (and AutoAttack) run on.
        loader: iterable of ``(x, y)`` batches.
        sample_size: number of examples to attack; capped at the dataset size.
        bs: batch size for AutoAttack and for the final scoring pass.
        eps: L-inf perturbation budget.
        seed: optional seed for the subset draw; ``None`` uses NumPy's global RNG.

    Returns:
        float: fraction of the attacked subset still classified correctly.

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    model.eval()
    # Attack the model that was passed in (previously the global `model_trained`),
    # on the requested device (AutoAttack otherwise defaults to 'cuda').
    adversary = AutoAttack(model, norm='Linf', eps=eps, version='custom',
                           attacks_to_run=['apgd-ce', 'apgd-dlr'], device=device)

    xs, ys = [], []
    for x, y in tqdm(loader):
        xs.append(x)
        ys.append(y)
    if not xs:
        raise ValueError("loader yielded no examples")
    # Keep the full set on its original device; only the sample is moved.
    x_all = torch.cat(xs, dim=0)
    y_all = torch.cat(ys, dim=0)

    n_total = x_all.shape[0]
    if n_total == 0:
        raise ValueError("loader yielded no examples")
    n_sample = min(sample_size, n_total)
    rng = np.random if seed is None else np.random.default_rng(seed)
    # Sample from the whole dataset; the old `range(SAMPLE_SIZE)` only ever
    # shuffled indices 0..999 and failed for datasets smaller than 1000.
    sample_idx = torch.as_tensor(rng.choice(n_total, n_sample, replace=False))
    x_sample = x_all[sample_idx].to(device)
    y_sample = y_all[sample_idx].to(device)

    x_adv = adversary.run_standard_evaluation(x_sample, y_sample, bs=bs)

    correct = 0
    with torch.no_grad():
        for start in range(0, n_sample, bs):
            pred = model(x_adv[start:start + bs]).argmax(dim=1)
            y_batch = y_sample[start:start + bs]
            correct += (pred == y_batch.view_as(pred)).sum().item()
    return correct / n_sample

In [None]:
# Robust accuracy of the trained model under AutoAttack on the test set.
get_acc_autoattack(
    model=model_trained,
    device=trn_trained.device,
    loader=test_loader,
)

In [None]:
# Clean accuracy first, then accuracy under PGD (same order as before).
accuracies = {
    attack_flag: get_accuracy(model_trained, trn_trained.device, attack_on=attack_flag)
    for attack_flag in (False, True)
}
acc_no_attack = accuracies[False]
acc_attack = accuracies[True]

for label, value in (("acc_no_attack", acc_no_attack), ("acc_attack", acc_attack)):
    print(f"{label}: {value}")

In [None]:
# Disabled plotting snippet (wrapped in a string literal, never executed).
# NOTE(review): `adv_images`, `y` and `adv_pred` are local variables inside
# `get_accuracy`, not notebook globals, so this would raise NameError if
# re-enabled as-is — have get_accuracy return/expose a sample batch first.
"""i_example = 0

plt.figure(figsize=(3, 3))
plt.imshow(adv_images.cpu().detach().numpy()[i_example].transpose([1, 2, 0]))
plt.title("label: " + str(y.detach().numpy()[i_example]) + ' pred: ' + str(adv_pred.detach().numpy()[i_example][0]))"""