In [None]:
from tools import utils, config, trainer, parts
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import functional as F
from tools import utils

plt.style.use('fast')
PLOT_DIR = 'plots'

# Forward slashes are accepted by Python's file APIs on every platform.
# The previous literals used backslashes: "\e" and "\c" are not valid escape
# sequences (newer Python versions warn about them) and the resulting paths
# only resolve on Windows.
EXPERIMENT_DIR = "experiments/exp1_cmpe597_regular_mnist"

cfg = config.from_yaml(f"{EXPERIMENT_DIR}/config.yaml")

# Dataset module, seeded before the loader is built.
dataset = utils.load_dataset_module(**cfg.data_supervised)
dataset.torch_seed()
test_loader = dataset.get_test_loader(**cfg.data_supervised)
test_dataset = dataset.get_test_dataset()

# Trained model
model = utils.load_model(**cfg.model)
# map_location="cpu" lets the checkpoint be read on machines that lack the
# device it was saved from; load_state_dict then copies the values into the
# model's existing parameters.
model.load_state_dict(
    torch.load(f"{EXPERIMENT_DIR}/checkpoint.pth", map_location="cpu")
)

part_manager = parts.PartManager(model)
part_manager.enable_all()

# The trainer is used below for its `device` attribute.
trn = trainer.ModelTrainer(model, cfg.trainer_sup, part_manager)

# Trailing semicolon suppresses the module repr in the cell output.
model.eval();

### Get Adversarial Examples

In [None]:
# Build the adversarial inputs used by the cells below; the result is later
# indexed by test-set position (`attack_examples[i]`).
from attack import do_attack

attack_examples = do_attack(
    model,
    trn.device,
    test_loader,
    fast=True,
)

### Collect misclassified adversarial examples from each class

In [None]:
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
n_classes = len(classes)
class_size = 27  # number of misclassified adversarial examples kept per class

# Classes that still need examples; a class is removed once it is full.
classes_to_accumulate = list(classes)
# class label -> test-set indices whose adversarial version is misclassified
examples = {c: [] for c in classes}
# class label -> label predicted for each index stored in `examples`
predictions = {c: [] for c in classes}

for i in range(len(test_dataset)):
    y = test_dataset[i][1]
    # Pending classes are always a subset of `classes`, so this single test
    # also skips labels outside `classes`.
    if y not in classes_to_accumulate:
        continue

    # Classify the adversarial counterpart of test sample i.
    x = attack_examples[i].unsqueeze(0).to(trn.device)
    y_pred = np.argmax(model(x).cpu().detach().numpy())

    if y_pred != y:
        examples[y].append(i)
        predictions[y].append(y_pred)
    if len(examples[y]) == class_size:
        classes_to_accumulate.remove(y)
        # Stop scanning as soon as every class has `class_size` examples.
        if len(classes_to_accumulate) == 0:
            break

In [None]:
# Show how many indices were collected for each class label.
print("Number of perturbed examples per class:")
[f'{label}:{len(indices)}' for label, indices in examples.items()]

### Get Outputs for Adversarial and Regular Examples

In [None]:
part_i = 0


def _capture_part_output(sample):
    """Forward one sample and return the output stored on part `part_i`'s end layer.

    The forward pass is run only for its side effect: afterwards the attribute
    named by `trainer.SAVED_OUTPUT_NAME` is read from the layer returned by
    `get_loss_end_layer()`, squeezed, and converted to a NumPy array.
    (Presumably that attribute is refreshed on every forward pass - the code
    that sets it is not visible in this notebook.)
    """
    model(sample.unsqueeze(0).to(trn.device))
    end_layer = part_manager.parts[part_i].get_loss_end_layer()
    saved_output = getattr(end_layer, trainer.SAVED_OUTPUT_NAME)
    return torch.squeeze(saved_output).cpu().detach().numpy()


# One list per class, in the order given by `classes`.
part_output_list_adv = [[] for _ in range(n_classes)]
part_output_list = [[] for _ in range(n_classes)]

for class_i, class_examples in examples.items():
    class_order = classes.index(class_i)
    for example_i in class_examples:
        # Regular input first, then its adversarial counterpart.
        regular_input = test_dataset[example_i][0]
        part_output_list[class_order].append(_capture_part_output(regular_input))
        part_output_list_adv[class_order].append(
            _capture_part_output(attack_examples[example_i])
        )

# np.array needs the same number of examples in every class; fail with a
# readable message instead of NumPy's inhomogeneous-shape error.
examples_per_class = [len(outputs) for outputs in part_output_list]
if len(set(examples_per_class)) > 1:
    raise ValueError(
        f"Unequal number of examples per class: {examples_per_class}; "
        "cannot stack activations into one array."
    )

activations = np.array(part_output_list)
activations_adv = np.array(part_output_list_adv)

activations.shape  # i_class, i_example, i_kernel, h, w

In [None]:
def get_flattened_activations(activations, class_i, channel_i):
    """Tile one channel's feature maps for every example of a class side by side.

    Args:
        activations: array indexed as [class, example, channel, row, column].
        class_i: index of the class whose examples are laid out.
        channel_i: index of the channel (kernel) to extract.

    Returns:
        A float64 array of shape (rows, n_examples * columns) in which example
        k occupies columns [k * columns, (k + 1) * columns).
    """
    # Sizes come from the array itself rather than from notebook globals, so
    # the function works for any number of examples and for non-square maps.
    n_examples = activations.shape[1]
    h, w = activations.shape[-2:]

    flattened_kernel = np.zeros((h, n_examples * w))
    for example_i in range(n_examples):
        flattened_kernel[:, example_i * w:(example_i + 1) * w] = activations[class_i, example_i, channel_i]
    return flattened_kernel

# Flatten the same (class, channel) pair for the regular and adversarial runs.
class_i, channel_i = 0, 0
flattened_activations, flattened_activations_adv = (
    get_flattened_activations(source, class_i, channel_i)
    for source in (activations, activations_adv)
)

### Flattened activations for a single kernel - Regular

In [None]:
# Regular-example activations of the selected class/channel as one wide image.
fig, ax = plt.subplots(figsize=(25, 5))
ax.imshow(flattened_activations)

### Flattened activations for a single kernel - Adversarial

In [None]:
# Adversarial-example activations of the selected class/channel as one wide image.
fig, ax = plt.subplots(figsize=(25, 5))
ax.imshow(flattened_activations_adv)

In [None]:
# Channel shown in the per-class figures of this cell and the next one.
KERNEL_I = 1
# One figure per class: regular activations of channel KERNEL_I.
for class_i in range(n_classes):
    fig, ax = plt.subplots(figsize=(30, 3))
    ax.imshow(get_flattened_activations(activations, class_i, KERNEL_I))

In [None]:
# One figure per class: adversarial activations of channel KERNEL_I.
for class_i in range(n_classes):
    fig, ax = plt.subplots(figsize=(30, 5))
    ax.imshow(get_flattened_activations(activations_adv, class_i, KERNEL_I))

In [None]:
def _standardize_class_means(class_activations):
    """Average over examples, flatten the features, and z-score them across classes.

    Args:
        class_activations: array whose axis 0 is the class and axis 1 the
            example; all remaining axes are flattened into one feature axis.

    Returns:
        Array of shape (n_classes, n_features) where every feature column has
        its across-class mean removed and is divided by its across-class
        standard deviation. Columns whose deviation is zero or NaN are divided
        by 1.0 instead, so they stay finite.
    """
    n_rows = class_activations.shape[0]
    class_means = np.mean(class_activations, axis=1).reshape((n_rows, -1))

    feature_std = np.nan_to_num(class_means.std(axis=0), nan=1.0)
    feature_std[feature_std == 0] = 1.0

    return (class_means - class_means.mean(axis=0)) / feature_std


# NOTE: despite the "_median" suffix these hold standardised per-class MEANS;
# the names are kept because later cells refer to them.
activations_median = _standardize_class_means(activations)
activations_adv_median = _standardize_class_means(activations_adv)

In [None]:
exclude_i = 8  # class compared against all the others

# Per-feature max / min over every class except `exclude_i`.
# The excluded row is overwritten with -inf before taking the max and +inf
# before taking the min, so it can never be selected; finite sentinels such
# as 0 or 999 could win the comparison and corrupt the result.
activations_median_copy = activations_median.copy()
activations_median_copy[exclude_i] = -np.inf
max_activations = np.max(activations_median_copy, axis=0)
activations_median_copy[exclude_i] = np.inf
min_activations = np.min(activations_median_copy, axis=0)

In [None]:
# Row of standardised activations belonging to the excluded class.
excluded_row = activations_median[exclude_i]

# Features where the excluded class lies strictly above / below every other class.
excluded_class_most_activated = np.greater(excluded_row - max_activations, 0)
excluded_class_least_activated = np.less(excluded_row - min_activations, 0)
excluded_feature_selected = excluded_class_most_activated | excluded_class_least_activated

# Feature indices ordered by the excluded class's activation (ascending).
diff_order = excluded_row.argsort()

print(f"Excluded class is least activated in these features: {excluded_class_least_activated.sum()}")
print(f"Excluded class is most activated in these features: {excluded_class_most_activated.sum()}")

In [None]:
# Last 200 features in `diff_order`, i.e. those with the highest activation
# for the excluded class on regular examples.
top_sorted = diff_order[-200:]

fig, ax = plt.subplots(figsize=(35, 5))
ax.imshow(activations_median[:, top_sorted])

fig, ax = plt.subplots(figsize=(35, 5))
ax.imshow(activations_adv_median[:, top_sorted])
# The caption follows `exclude_i` instead of hard-coding the class label.
print(f"Activations sorted by their values for y={exclude_i} on regular examples")

In [None]:
# Keep only the features where the excluded class is the min or max, and order
# them by the excluded class's value on regular examples; the same order is
# reused for the adversarial plot so columns line up.
act = activations_median[:, excluded_feature_selected]
act_order = np.argsort(act[exclude_i, :])
act_adv = activations_adv_median[:, excluded_feature_selected]

fig, ax = plt.subplots(figsize=(35, 5))
ax.imshow(act[:, act_order])

fig, ax = plt.subplots(figsize=(35, 5))
ax.imshow(act_adv[:, act_order])
# The caption follows `exclude_i` instead of hard-coding the class label.
print(f"Activations where value for y={exclude_i} is the min or max among classes")

In [None]:
n_features = activations_median.shape[1]
usefulness = np.zeros(n_features)
robustness = np.zeros(n_features)

# Walk over feature columns of the regular and adversarial matrices in step.
for i, (regular_col, adversarial_col) in enumerate(
    zip(activations_median.T, activations_adv_median.T)
):
    # usefulness: how far the strongest class sits above the across-class
    # mean, in units of the across-class standard deviation.
    spread = regular_col.max() - regular_col.mean()
    usefulness[i] = spread / regular_col.std()
    # robustness: correlation between the regular and adversarial class profiles.
    robustness[i] = np.corrcoef(regular_col, adversarial_col)[0, 1]

# Constant columns yield NaN in both measures; map those to 0.
robustness = np.nan_to_num(robustness)
usefulness = np.nan_to_num(usefulness)

In [None]:
feature_order = np.argsort(usefulness)
# The 150 features with the highest usefulness, in ascending order.
most_useful = feature_order[-150:]

fig, ax = plt.subplots(figsize=(35, 7))
ax.imshow(activations_median[:, most_useful])
ax.set_title("Activations for regular examples")

fig, ax = plt.subplots(figsize=(35, 7))
ax.imshow(activations_adv_median[:, most_useful])
ax.set_title("Activations for adversarial examples")
print("Features sorted by usefulness")

In [None]:
feature_order = np.argsort(robustness)
robustness_sorted = np.sort(robustness)

# (matrix, column window in robustness order, title) for each figure, in display order.
top_window = slice(-75, None)
bottom_window = slice(None, 75)
panels = [
    (activations_median, top_window, "Activations for regular examples (top 75)"),
    (activations_adv_median, top_window, "Activations for adversarial examples (top 75)"),
    (activations_median, bottom_window, "Activations for regular examples (bottom 75)"),
    (activations_adv_median, bottom_window, "Activations for adversarial examples (bottom 75)"),
]
for panel_data, panel_window, panel_title in panels:
    fig, ax = plt.subplots(figsize=(35, 7))
    ax.imshow(panel_data[:, feature_order][:, panel_window])
    ax.set_title(panel_title)

print("Features sorted by robustness")
print(f"Top 25: {list(reversed(robustness_sorted[-25:]))}")
print(f"Bottom 25: {robustness_sorted[:25]}")

In [None]:
# Per-feature usefulness against robustness, plus their correlation.
fig, ax = plt.subplots()
ax.scatter(usefulness, robustness)
ax.set_xlim([1.0, 3.1])
ax.set_xlabel("usefulness")
ax.set_ylabel("robustness")
print(f"Correlation: {np.corrcoef(usefulness, robustness)[1, 0]}")

### Robustness for kernels

In [None]:
# Per-feature scores were computed on activations flattened from
# (kernel, h, w); restore that layout from the activation array itself rather
# than hard-coding it, then average each kernel's spatial positions.
kernel_shape = activations.shape[2:]
kernel_robustness = robustness.reshape(kernel_shape).mean(axis=(1, 2))
kernel_usefulness = usefulness.reshape(kernel_shape).mean(axis=(1, 2))

In [None]:
# Weights of conv1 as a NumPy array; input channel 0 of each kernel is drawn.
first_layer_kernels = model.conv1.weight.detach().cpu().numpy()

# Kernels from least to most robust.
for i in np.argsort(kernel_robustness):
    fig, ax = plt.subplots(figsize=(3, 3))
    ax.imshow(first_layer_kernels[i, 0], cmap="gray")
    ax.set_title(f"{i}: robustness={kernel_robustness[i]:.3f}, usefulness={kernel_usefulness[i]:.3f}")

In [None]:
# Per-kernel usefulness against robustness, plus their correlation.
fig, ax = plt.subplots()
ax.scatter(kernel_usefulness, kernel_robustness)
ax.set_xlabel("kernel_usefulness")
ax.set_ylabel("kernel_robustness")
print(f"Correlation: {np.corrcoef(kernel_usefulness, kernel_robustness)[1, 0]}")

### Remove kernels

In [None]:
# The two kernels with the lowest mean robustness.
kernels_to_zero = np.argsort(kernel_robustness)[:2]

# Large negative bias written to the removed kernels.
# NOTE(review): presumably meant to keep a following non-linearity switched
# off for these kernels - confirm against the model definition.
SUPPRESSED_BIAS = -100000.0

# detach() first so the copies carry no autograd history and the in-place
# writes below are not recorded in a graph.
new_w = model.conv1.weight.detach().clone()
new_b = model.conv1.bias.detach().clone()

for i_kernel in kernels_to_zero:
    new_w[i_kernel] = 0.0
    new_b[i_kernel] = SUPPRESSED_BIAS

In [None]:
# Replace conv1's parameters with the edited copies (this mutates `model`).
conv1 = model.conv1
conv1.weight = torch.nn.Parameter(new_w)
conv1.bias = torch.nn.Parameter(new_b)
model.eval();

In [None]:
# Re-run the attack against the model whose conv1 kernels were edited above.
# This rebinds `attack_examples`.
from attack import do_attack

attack_examples = do_attack(
    model,
    trn.device,
    test_loader,
    fast=False,
    first_n=5000,
)

### Remove features

In [None]:
# The 10 features with the lowest usefulness score.
features_to_zero = np.argsort(usefulness)[:10]

# Build a flat 0/1 mask over the (kernel, h, w) features, then restore the
# feature-map layout with a leading axis of size 1.
feature_map_shape = activations.shape[2:]
feature_mask = np.ones(int(np.prod(feature_map_shape)))
feature_mask[features_to_zero] = 0.0
feature_mask = feature_mask.reshape((1, *feature_map_shape))

# Place the gate on the device that the notebook sends model inputs to,
# instead of assuming a CUDA device is present.
model.feature_gate = torch.nn.Parameter(
    torch.tensor(feature_mask, dtype=torch.float32).to(trn.device)
)
model.eval();

In [None]:
# Re-run the attack against the model with the feature gate installed.
# This rebinds `attack_examples`.
from attack import do_attack

attack_examples = do_attack(
    model,
    trn.device,
    test_loader,
    fast=False,
)