In [None]:
from tools import utils, config, trainer, parts
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.nn import functional as F
from tools import utils
import torch.nn as nn

plt.style.use('fast')
PLOT_DIR = 'plots'

# Experiment artefacts. Forward slashes are accepted on Windows as well as on
# POSIX, and (unlike "experiments\exp9_net6\...") contain no invalid escape
# sequences such as `\e` / `\c`, which Python flags with a SyntaxWarning.
EXP_DIR = 'experiments/exp9_net6'
CONFIG_PATH = f'{EXP_DIR}/config.yaml'
CHECKPOINT_PATH = f'{EXP_DIR}/checkpoint.pth'

cfg = config.from_yaml(CONFIG_PATH)
dataset = utils.load_dataset_module(**cfg.data_supervised)
dataset.torch_seed()
test_loader = dataset.get_test_loader(**cfg.data_supervised)
test_dataset = dataset.get_test_dataset()

# Trained model
model_trained = utils.load_model(**cfg.model)

part_manager_trained = parts.PartManager(model_trained)
part_manager_trained.enable_all()

i_part = 1  # NOTE(review): not used in the cells shown here
trn_trained = trainer.ModelTrainer(model_trained, cfg.trainer_sup, part_manager_trained)
# map_location lets a checkpoint saved on one device (e.g. a GPU) be restored
# on whatever device the trainer is using.
model_trained.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=trn_trained.device))

In [None]:
from autoattack import AutoAttack

def get_acc_autoattack(model, device, loader, sample_size=1000, bs=25, eps=8/255):
    """Run AutoAttack (APGD-CE + APGD-DLR, L-inf) on a random subset of ``loader``.

    Despite its name, this returns the adversarial inputs, not an accuracy value.

    Args:
        model: network to attack; it is switched to eval mode first.
        device: device that the sampled inputs/labels are moved to, also passed
            to AutoAttack (whose own default device is CUDA).
        loader: iterable of ``(x, y)`` batches. All batches are concatenated
            along dim 0 before sampling.
        sample_size: maximum number of examples to attack. They are drawn
            without replacement from the whole loader, capped at the number
            available.
        bs: batch size passed to ``run_standard_evaluation``.
        eps: L-inf perturbation budget.

    Returns:
        The adversarial examples returned by ``run_standard_evaluation``, one per
        sampled example, in sampled-index order.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # Attack the model that was passed in (not a global), in inference mode.
    model.eval()
    adversary = AutoAttack(model, norm='Linf', eps=eps, version='custom',
                           attacks_to_run=['apgd-ce', 'apgd-dlr'], device=device)

    x_batches, y_batches = [], []
    for x, y in loader:
        x_batches.append(x)
        y_batches.append(y)
    if not x_batches:
        raise ValueError("loader yielded no batches")

    # Keep the full set where the loader produced it; only the sampled subset
    # is moved to `device`.
    x_all = torch.cat(x_batches, dim=0)
    y_all = torch.cat(y_batches, dim=0)
    n_examples = x_all.shape[0]

    # Sample indices from ALL examples. Choosing from range(sample_size) would
    # only shuffle the first `sample_size` examples and would fail when fewer
    # are available.
    n_samples = min(sample_size, n_examples)
    sample_idx = torch.as_tensor(
        np.random.choice(n_examples, n_samples, replace=False), dtype=torch.long)
    x_sample = x_all[sample_idx].to(device)
    y_sample = y_all[sample_idx].to(device)

    return adversary.run_standard_evaluation(x_sample, y_sample, bs=bs)

### Collect sample indices for each class

In [None]:
N_CLASSES = 10
class_size = 20
classes_to_accumulate = list(range(N_CLASSES))
examples = {i: [] for i in range(N_CLASSES)}

# Record the first `class_size` test-set indices of every class. The scan is
# bounded by len(test_dataset), so a class with too few examples ends the scan
# (caught by the assert below) instead of indexing past the end of the dataset.
for i in range(len(test_dataset)):
    x, y = test_dataset[i]
    y = int(y)  # normalise tensor/numpy scalar labels to plain int dict keys
    if y not in classes_to_accumulate:
        continue

    examples[y].append(i)
    if len(examples[y]) == class_size:
        classes_to_accumulate.remove(y)
        if not classes_to_accumulate:
            break

assert not classes_to_accumulate, (
    f"test set has fewer than {class_size} examples for classes {classes_to_accumulate}")

### Regular (clean, non-adversarial) examples

In [None]:
extract_layer = model_trained.se1

part_output_trained_list = [[] for _ in range(N_CLASSES)]

# Inference only:
# - eval() gives BatchNorm/Dropout layers (if any) their inference behaviour.
#   In train mode, a batch-of-one forward pass would normalise with per-sample
#   statistics and overwrite BatchNorm running stats.
# - no_grad() avoids building an autograd graph for every example.
model_trained.eval()
with torch.no_grad():
    for class_i, class_examples in examples.items():
        for example_i in class_examples:
            x, y = test_dataset[example_i]
            x = x.unsqueeze(0).to(trn_trained.device)
            model_trained(x)

            # sigmoid_output is read right after the forward pass (presumably
            # set by it). squeeze drops the batch dim and any other size-1 dims.
            part_output_trained = extract_layer.sigmoid_output
            part_output_trained = torch.squeeze(part_output_trained).detach().cpu().numpy()
            part_output_trained_list[class_i].append(part_output_trained)

activations_trained = np.array(part_output_trained_list)
activations_trained.shape  # (i_class, i_example, ...) — trailing dims are those of the squeezed sigmoid_output

In [None]:
# Show each class's collected activations as an image, one figure per class.
# Rows are that class's examples. Columns are the squeezed activation values.
# NOTE(review): plt.imshow needs a 2-D (or RGB) array, so this assumes each
# squeezed sigmoid_output is 1-D — confirm against activations_trained.shape.
x_plot = activations_trained
for class_i in range(N_CLASSES):
    plt.figure(figsize=(25, 5))
    plt.imshow(x_plot[class_i])