In [None]:
# Automatically reload edited project modules (e.g. train, attack) before each cell executes,
# so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

For testing: the next cells load a validation batch, the CLIP classifier checkpoint, and the trained patch and universal-perturbation attacks.

In [None]:
import torch

from train import patch_loader

# Small, non-streaming validation loader (100 samples, batches of 4). Only its first
# batch is used below; later cells index batch['pixel_values'][0..3].
loader = patch_loader(
    split='validation',
    num_samples=100,
    batch_size=4,
    streaming=False,
)
batch_iter = iter(loader)
batch = next(batch_iter)

In [None]:
import torch
from train import CLIPClassifier

# Build the classifier and restore its trained weights from checkpoints/imnet_1k.pt.
# NOTE(review): `deep=1024` is forwarded to CLIPClassifier as-is; its meaning lives in train.py.
model = CLIPClassifier(deep=1024)
# Distinct name so the perturbation-loading cell doesn't silently overwrite it.
classifier_checkpoint = torch.load('checkpoints/imnet_1k.pt', map_location='cpu')
model.load_state_dict(classifier_checkpoint['model'])
# This notebook only runs inference: put the classifier in eval mode so any dropout /
# batch-norm layers behave deterministically. `_ =` suppresses the module repr.
_ = model.eval()

In [None]:
import torch
from attack import Patch, UniversalPerturbation

# Target class index handed to both attacks (presumably the class they steer predictions toward).
TARGET_LABEL = 965

# Adversarial patch: restore trained parameters, then switch to eval mode.
patch = Patch(model=model, target_label=TARGET_LABEL, patch_r=0.2, init_size=1024)
patch_pt = torch.load('checkpoints/attack_v5_unbounded_large_9999.pt', map_location='cpu')
patch.load_params(patch_pt['params'])
_ = patch.eval()

# Universal perturbation: restore trained parameters, then switch to eval mode.
# NOTE(review): the leading 4 in `shape` presumably has to match the loader's batch_size=4 — confirm.
perturbation = UniversalPerturbation(
    model=model, target_label=TARGET_LABEL, shape=(4, 3, 224, 224), epsilon=0.1
)
# Own name (parallel to `patch_pt`) instead of reusing `checkpoint`.
perturbation_pt = torch.load('checkpoints/up_checkpoint.pt', map_location='cpu')
perturbation.load_params(perturbation_pt['params'])
_ = perturbation.eval()

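This prints the model's predicted class index for each image in the batch, first under the patch attack and then under the universal perturbation. It gets 3/4 correct here.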

In [None]:
# Inference only: skip building autograd graphs, matching the visualization cells
# further down, which also run under torch.no_grad().
with torch.no_grad():
    patch_logits = patch.forward(batch)
    perturbation_logits = perturbation.forward(batch)

# Predicted class index per image (argmax over the last dimension of the logits).
print('patch:       ', torch.argmax(patch_logits, dim=-1))
print('perturbation:', torch.argmax(perturbation_logits, dim=-1))

Here are the clean images we're attacking (saved to `figures/baseline.png`).

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path

# Show the first four clean images of the batch in a 2x2 grid and save the figure.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for ax, image in zip(axes.flat, batch['pixel_values']):
    ax.imshow(image)
for ax in axes.flat:
    ax.axis('off')

fig.tight_layout()
# savefig raises if the output directory is missing (e.g. on a fresh checkout).
Path('figures').mkdir(parents=True, exist_ok=True)
fig.savefig('figures/baseline.png')

Here are the same input images with the adversarial patch applied (saved to `figures/patch_large.png`).

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path

# Apply the patch to the batch (no gradients needed for visualization).
with torch.no_grad():
    patched = patch._apply_patch(batch['pixel_values'])

# Show the first four patched images in a 2x2 grid and save the figure.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for ax, image in zip(axes.flat, patched):
    ax.imshow(image)
for ax in axes.flat:
    ax.axis('off')

fig.tight_layout()
# savefig raises if the output directory is missing (e.g. on a fresh checkout).
Path('figures').mkdir(parents=True, exist_ok=True)
fig.savefig('figures/patch_large.png')

Here are the same images with the universal perturbation applied (saved to `figures/perturbed_weak.png`).

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path

# Apply the universal perturbation to the batch (no gradients needed for visualization).
with torch.no_grad():
    perturbed = perturbation.apply_attack(batch['pixel_values'])

# Show the first four perturbed images in a 2x2 grid and save the figure.
# apply_attack's output is permuted (1, 2, 0) before imshow, i.e. channel-first -> channel-last.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
for ax, image in zip(axes.flat, perturbed):
    ax.imshow(image.permute(1, 2, 0))
for ax in axes.flat:
    ax.axis('off')

fig.tight_layout()
# savefig raises if the output directory is missing (e.g. on a fresh checkout).
Path('figures').mkdir(parents=True, exist_ok=True)
fig.savefig('figures/perturbed_weak.png')

## Regressions

Re-run the training and evaluation scripts on CPU as regression checks.

In [None]:
# Run the training script on CPU with the toy patch config (configs/patch/toy.yml).
!python train.py --config=configs/patch/toy.yml --device=cpu

In [None]:
# Run the training script on CPU with the weak universal-perturbation config.
!python train.py --config=configs/universal/perturbation_weak.yml --device=cpu

In [None]:
# Run the evaluation script on CPU with the weak-perturbation eval config.
!python eval.py --config=configs/eval/perturbation_weak.yml --device=cpu

## Investigating the universal perturbation

In [None]:
import torch
from attack import UniversalPerturbation

# Scratch: read the saved universal-perturbation checkpoint, then restore its params
# into a bare UniversalPerturbation. None is passed for the first two positional
# arguments (presumably model / target_label, which the earlier cell passes by keyword — confirm).
checkpoint = torch.load('checkpoints/up_checkpoint.pt', map_location='cpu')
up = UniversalPerturbation(None, None, shape=(4, 3, 224, 224))
up.load_params(checkpoint['params'])