Imports

In [None]:
try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

# If in Colab, we need to pull utilities from github
if IN_COLAB:
  !wget https://raw.githubusercontent.com/soberhofer/Importance_based_Adversarial_Examples/main/load_model.py
  !wget https://raw.githubusercontent.com/soberhofer/Importance_based_Adversarial_Examples/main/utils.py
  !wget https://raw.githubusercontent.com/soberhofer/Importance_based_Adversarial_Examples/main/ImagenetteDataset.py

In [None]:
%pip install -q grad-cam

from utils import imshow, imagenette_outputs, multiple_c_o_m, shift
from ImagenetteDataset import ImagenetteDataset
from load_model import load_model
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import random
import cv2
import torchvision, torch, torchvision.transforms as T
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, EigenGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from tqdm import tqdm
from scipy.ndimage import center_of_mass

Configure Size of Imagenette Pictures and PyTorch Device

In [None]:
# Edge length of the Imagenette images; also selects which archive is downloaded.
#160 uses ~8GB RAM, 320 uses ~24GB RAM, Fullsize not tested
size = 160


# Pick the compute device. Both branches produce a torch.device so that
# `device` has the same type regardless of the hardware available.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
# elif torch.backends.mps.is_available():
#    device = torch.device("mps")
#    %env PYTORCH_ENABLE_MPS_FALLBACK=1
else:
  device = torch.device('cpu')

print(device)
# CAM classes that the adversarial example generation iterates over.
#EigenGradCAM ScoreCAM seems not to work with mps
# AblationCAM is funky
cams = [XGradCAM]
#cams = [EigenCAM, XGradCAM, GradCAM, HiResCAM, GradCAMPlusPlus]
#cams = [EigenCAM]
#%env

Download and unpack images

In [None]:
if size in [160, 320]:
  #Download resized images
  if not os.path.isfile(f'imagenette2-{size}.tgz'):
    !wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-{size}.tgz
    !tar -xf imagenette2-{size}.tgz
elif os.path.isdir(f"imagenette2-{size}"):
    print("Data is present, continuing")
else:
  #Download original images
  print("Downloading originals and resizing")
  if not os.path.isfile(f'imagenette2.tgz'):
    !wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    !tar -xf imagenette2.tgz
    # Downscale to chosen size
    folder_dir = f"imagenette2-{size}"
    os.rename("imagenette2",folder_dir)
    for dataset in ["train","val"]:
      for classes in os.listdir(f"{folder_dir}/{dataset}"):
        for image in os.listdir(f"{folder_dir}/{dataset}/{classes}"):
          image_path = f"{folder_dir}/{dataset}/{classes}/{image}"
          img = Image.open(image_path)
          img.thumbnail((size,size))
          os.remove(image_path)
          img.save(image_path)



Load Model and target Layers for GradCam

In [None]:
# Load the network under test together with the layers that are later handed
# to the CAM constructors as `target_layers`.
# NOTE(review): norm_layer=True presumably makes the model normalise its own
# input (the datasets are built with should_normalize=False) - confirm in load_model.py.
model, target_layers = load_model('mobilenet', norm_layer=True)


Build our DataLoaders

In [None]:
# Batch size shared by both loaders (also bounds the pair-sorting retries later on)
bs = 32

# Both splits are requested with should_normalize=False
trainset = ImagenetteDataset(size, should_normalize=False)
valset = ImagenetteDataset(size, validation=True, should_normalize=False)

# Only the training loader shuffles; the validation loader keeps dataset order
DataLoader = torch.utils.data.DataLoader
trainloader = DataLoader(trainset, shuffle=True, batch_size=bs)
valloader = DataLoader(valset, shuffle=False, batch_size=bs)




Get first Batch for Testing

In [None]:
# Pull one batch from the training loader for quick experiments
first_batch = next(iter(trainloader))
data_batch, labels_batch = first_batch
print(data_batch.size(), labels_batch.size(), sep="\n")
# Tile the batch into a single image grid for display
out = torchvision.utils.make_grid(data_batch)

Predict First Batch with our model

In [None]:
# Move the model to the selected device and classify the first batch.
model.to(device)
# Inference only: switch layers such as dropout/batch-norm to evaluation
# behaviour (a no-op if load_model already did this).
model.eval()
class_names = trainset.classes
print(class_names)
# No gradients are needed for a plain prediction, so skip building the autograd graph.
with torch.no_grad():
  outputs = model(data_batch.to(device))
print(outputs.shape)
preds = imagenette_outputs(outputs)
print(labels_batch)
#print(preds)

In [None]:
# Display the image grid built from the first training batch.
# denorm=False: presumably because the dataset was created with should_normalize=False.
imshow(out, denorm=False)#, title=[class_names[x] for x in preds])

Run Inference on whole trainset

In [None]:
# Classify the whole training set and keep a running accuracy.
# Note: trainloader shuffles, so the order of `pred` does not follow the
# order of samples in `trainset`.
all_predictions = []
num_correct = 0
num_seen = 0
with torch.no_grad():
  loop = tqdm(trainloader)
  for idx, (data, labels) in enumerate(loop):
    outputs = model(data.to(device))
    preds = imagenette_outputs(outputs)
    all_predictions.extend(preds)
    corrects = torch.sum(preds == labels.to(device))
    num_correct += corrects
    # Count the samples actually processed; the last batch can be smaller
    # than the others, so len(labels)*(idx+1) would give a wrong denominator.
    num_seen += len(labels)
    loop.set_description(f"Processing batch {idx+1}")
    loop.set_postfix(current_accuracy = num_correct.double().item()/num_seen)
    #print(f"Done with batch of size {(len(labels))}")
pred = torch.stack(all_predictions)



In [None]:
# Overall accuracy: correct predictions divided by the size of the training set
print(f"Accuracy: {num_correct.double()/len(trainset):.4f}")

In [None]:
# Shape of the stacked predictions followed by the first 100 entries
print(pred.size(), pred[:100], sep="\n")

Generate Adversarial Examples

In [None]:
def _separate_same_label_pairs(data, labels, max_tries):
  """Reorder a batch in place so paired samples have different labels.

  Sample ``i`` of the first half of the batch is paired with sample
  ``i + half`` of the second half. Whenever such a pair shares a label,
  sample ``i`` is swapped with its neighbour ``i + 1``. The sweep is repeated
  until no swap was necessary or ``max_tries`` sweeps have been done; some
  batches cannot be fully separated this way.

  Args:
    data: batch tensor, modified in place.
    labels: label tensor of the same (even) length, modified in place.
    max_tries: maximum number of sweeps over the batch.

  Returns:
    True if the last sweep needed no swap, False otherwise.
  """
  half = len(labels) // 2
  for _ in range(max_tries):
    swapped = False
    for idx in range(half):
      if labels[idx] == labels[idx + half]:
        # Swap through index lists: indexing with a list yields a copy, while
        # `t[i], t[i+1] = t[i+1], t[i]` operates on views and would leave
        # both rows holding the same sample.
        labels[[idx, idx + 1]] = labels[[idx + 1, idx]]
        data[[idx, idx + 1]] = data[[idx + 1, idx]]
        swapped = True
    if not swapped:
      return True
  return False


#Iterate over all cams
for ourcam in cams:
  folder = f"./adv_examples_{ourcam.__name__}_{size}/"
  # exist_ok avoids the check-then-create race and makes re-runs safe
  os.makedirs(folder, exist_ok=True)
  # NOTE(review): whether use_cuda/use_mps are accepted depends on the installed grad-cam version - confirm.
  cam = ourcam(model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available(), use_mps=True)

  torch.manual_seed(42)
  # To avoid bias, we only use images which have been used as test set during training
  loop = tqdm(valloader)
  # Collected as (image, base_label, attack_label, flipped_label, threshold)
  examples = []
  found = 0
  thirdlabel = 0
  same = 0
  invalid = 0
  bad_ex = 0

  for batch, (data, labels) in enumerate(loop):
    #stop after 10% of the dataset
    #if batch > len(valloader)//5:
    #  break
    #make sure we have even number of samples, if not, remove the last one. Use even block size to avoid this
    if len(labels) % 2 != 0:
      data = data[:-1,:,:,:]
      labels = labels[:-1]
    # A batch that held a single sample is empty now; nothing to pair up
    if len(labels) == 0:
      continue

    # Sort the batch so that the base and attack image do not have the same label
    # we try it for bs^2 times and then stop, some batches are not sortable in this way
    # we should get almost all of them sorted nicely though
    data, labels = data.to(device), labels.to(device)
    loop.set_description(f"Sorting batch...")
    _separate_same_label_pairs(data, labels, bs**2)

    # get the CAMs for the batch
    grayscale_cam = cam(input_tensor=data, targets=None)
    # First half of the batch provides the base images, second half the attack images
    cams_base, cams_attack = np.array_split(grayscale_cam, 2, axis=0)
    imgs_base, imgs_attack = np.array_split(data.cpu().numpy(), 2, axis=0)
    labels_base, labels_attack = np.array_split(labels.cpu().numpy(), 2, axis=0)
    #iterate over each batch
    for base_img, attack_img, base_cam, attack_cam, base_label, attack_label in zip(imgs_base, imgs_attack, cams_base, cams_attack, labels_base, labels_attack):
      # ignore pairs with same label (should not happen too often now)
      if (attack_label == base_label):
        same += 1
        continue
      #start with a 1% mask
      current_threshold = 0.99
      # Reset per pair so an image of a previous pair can never be stored
      # as the example for this one.
      last_img = None
      invariance_adv = None

      # Look for the adversarial Example
      while True:
        loop.set_description(f"Found: {found}, 3rdlabel: {thirdlabel} same label: {same}, invalid: {invalid}, bad_ex: {bad_ex}, using {ourcam.__name__}")
        # Binary masks of the CAM values above the current quantile
        base_threshold = np.quantile(base_cam.flatten(), current_threshold)
        attack_threshold = np.quantile(attack_cam.flatten(), current_threshold)
        base_mask = np.where(base_cam>base_threshold, np.ones_like(base_cam), np.zeros_like(base_cam))
        attack_mask = np.where(attack_cam>attack_threshold, np.ones_like(attack_cam), np.zeros_like(attack_cam))
        # Offset between the two mask centres, passed to shift() for the attack image
        # NOTE(review): an all-zero mask yields a NaN centre of mass; not handled here.
        c_o_m_base = np.array(center_of_mass(base_mask))
        c_o_m_attack = np.array(center_of_mass(attack_mask))
        offset = c_o_m_base - c_o_m_attack

        # Remember the last image we produced, in case this is the adversarial example
        # (np.where below builds a new array each time, so no copy is required)
        last_img = invariance_adv

        #Produce the example
        invariance_adv = np.where(base_mask==True, shift(attack_img, offset), base_img)

        #Check output of Model
        output = imagenette_outputs(model(torch.from_numpy(invariance_adv).unsqueeze(0).to(device)))

        if output.item() == base_label:
          # threshold <= 0.01 means we have a mask of 99% -> we can't find an adversarial example
          if current_threshold <= 0.01:
            invalid +=1
            break
          #Model still predicts base label -> make mask bigger
          # Rounding keeps the value on the 0.01 grid, so the comparisons
          # against 0.99 / 0.3 / 0.01 are not affected by float drift.
          current_threshold = round(current_threshold - 0.01, 2)
        else:
          # threshold >= 0.99 means we have a mask of 1% and the model already flips label. We can't find an adversarial example
          if current_threshold >= 0.99:
            invalid +=1
            break
          #model flips early, we look for a better example
          if current_threshold >= 0.3:
            bad_ex += 1
            break
          #We found the example. Write it to disk
          # The stored image is the one from the previous step, i.e. the last
          # one the model still assigned to the base label.
          img = Image.fromarray((last_img*255).astype(np.uint8).transpose(1,2,0))
          #Format of image name: base_label_attack_label_intermediate_label_threshold.jpg
          img.save(os.path.join(folder, f"{base_label}_{attack_label}_{output.item()}_{current_threshold:.2f}.jpg"))
          examples.append((last_img, base_label, attack_label, output.item(), current_threshold))
          if output.item() != attack_label:
            thirdlabel += 1
          else:
            found += 1
          break
  # Persist the counters of this CAM next to its images
  with open(os.path.join(folder, "results.txt"), "w") as f:
    f.write(f"Found: {found}, 3rdlabel: {thirdlabel} same label: {same}, invalid: {invalid}, bad: {bad_ex} using {ourcam.__name__} and {size}x{size} images")

In [None]:
# Sanity check: every stored example should still be classified as its base label
for example in examples:
  img, base_label, attack_label, _, threshold = example
  batch_of_one = torch.from_numpy(img).unsqueeze(0).to(device)
  output = imagenette_outputs(model(batch_of_one))
  if output.item() == base_label:
    continue
  # Report every example whose prediction differs from the base label
  print(f"Wrong output for {base_label}_{attack_label}: {output.item()} with {threshold:.2f}")


#### Results

##### XGradCam valset 320px 
Found: 135, 3rdlabel: 26 same label: 18, invalid: 109, bad_ex: 1673, cutoff: 0.3,  Median 0.25, best: 0.13. Time 38:53

In [None]:
# Order the examples by the threshold at which the model flipped (ascending)
# and summarise the distribution of those thresholds.
examples.sort(key=lambda x: x[4], reverse=False)
#examples.sort(key=lambda x: (x[4]-0.5)**2, reverse=True)
idx = 0
if examples:
  #for idx in range(30):
  #  print(f"{examples[idx][4]:.2f}")
  # `imagenette_labels` is not defined or imported anywhere in this notebook;
  # use `class_names` (trainset.classes) from the prediction cell instead.
  # NOTE(review): assumes class_names is indexable by the label index - confirm.
  print(f"{examples[idx][4]:.2f}, {class_names[examples[idx][1]]}, {class_names[examples[idx][2]]}, {class_names[examples[idx][3]]}")
  #plt.imshow(examples[idx][0].transpose(1,2,0))
  thresholds = [x[4] for x in examples]
  #print median
  print(f"Median: {np.median(thresholds):.2f}")
  plt.hist(thresholds);
else:
  # Nothing was collected: indexing examples[0] would raise IndexError
  print("No adversarial examples were collected")


Export Pictures to disk

In [None]:
# for idx, img in enumerate(examples):
#   img = Image.fromarray((img*255).astype(np.uint8).transpose(1,2,0))
#   img.save(f"/content/drive/MyDrive/adv_examples_320/{idx}.jpg")

Plot some of the Pictures

In [None]:

# f, xarr = plt.subplots(2,2, figsize=(15,15))
# xarr.flatten()
# for idx, ax in enumerate(xarr.flatten()):
#   ax.imshow(examples[idx][0].transpose(1,2,0))

In [None]:
# idx = 4
# com_b = c_o_m_base[idx]
# com_a = c_o_m_attack[idx]
# offset = offsets[idx]
# base_image = imgs_base[idx]
# attack_image = imgs_attack[idx]
# print (base_image.shape)
# print (attack_image.shape)
# fig, ax = plt.subplots(1, 2, figsize=(20, 20))
# ax = ax.flatten()
# ax[0].imshow(base_image.transpose(1,2,0))
# ax[0].scatter(com_b[0], com_b[1], s=size, c='C0', marker='+')
# ax[1].imshow(attack_image.transpose(1,2,0))
# ax[1].scatter(com_a[0], com_a[1], s=size, c='C1', marker='+')
# ax[1].scatter(com_b[0], com_b[1], s=size, c='C0', marker='+')


In [None]:
# print(attack_image.shape)
# shifted = shift(attack_image, offset)
# #print(offset[::-1])
# print (shifted.shape)
# plt.imshow(shifted.transpose(1,2,0))

# # print(com_b, com_a, offset)
# # attack_image_cropped = attack_image[:,39:,11:]
# # #plt.imshow(attack_image_cropped.transpose(1,2,0))
# # print(attack_image_cropped.shape)
# # empty = np.zeros_like(attack_image)
# # empty[:,0:121,0:149] = attack_image_cropped
# # print(empty.shape)



In [None]:
# invariance_adv = np.where(masks_base[0]==True, shifted, base_image)
# plt.imshow(invariance_adv.transpose(1,2,0))

In [None]:
# from scipy.ndimage import shift
# print(offsets[11])
# print(masked_base[11].shape)
# print(offsets[0,0])
# print (offsets[:,0])
# one_image = imgs_attack[:,:,offsets[:,0]:,offsets[:,1]:]

# #shifted = shift(masked_base[11], offsets[11], cval=0)
# plt.imshow(attack_patches[4].transpose(1,2,0))
# #plt.imshow(masked_base[11].transpose(1,2,0))

In [None]:
# print(normalized.shape)
# fig, ax = plt.subplots(2, 4, figsize=(20, 20))
# ax = ax.flatten()
# for i in range(8):
#   idx = random.randint(0, len(masked_images)-1)
#   ax[i].imshow(normalized[idx].transpose(0,1))


In [None]:
# c_o_m = multiple_c_o_m(masked_images)

# print(c_o_m.shape)
# #c_o_m

In [None]:
# fig, ax = plt.subplots(2, 4, figsize=(20, 20))
# ax = ax.flatten()
# for i in range(8):
#   idx = random.randint(0, len(masked_images)-1)
#   ax[i].imshow(masked_images[idx].transpose(1,2,0))
#   ax[i].scatter(c_o_m[idx][0], c_o_m[idx][1], s=size, c='C0', marker='+')
#   print(idx, c_o_m[idx])


# plt.show()

In [None]:
# #not needed
# threshold = np.quantile(gradcam_hm.flatten(), .85)
# b_mask = np.where(gradcam_hm>threshold, np.ones_like(gradcam_hm), np.zeros_like(gradcam_hm))
# print (b_mask.shape)
# img_batch = next(iter(trainloader))[0]
# idx = 4
# plt.imshow((b_mask[idx]*img_batch[idx].detach().cpu().numpy()).transpose(1,2,0))

Explainability with Pytorch Captum

In [None]:
%pip install -q git+https://github.com/pytorch/captum.git

from captum.attr import IntegratedGradients, NoiseTunnel
from captum.attr import visualization as viz
from matplotlib.colors import LinearSegmentedColormap

Integrated Gradients

In [None]:

def _to_hwc(tensor):
  """Squeeze a tensor, move it to the CPU and return it as a numpy array with axes (1, 2, 0)."""
  return np.transpose(tensor.squeeze().cpu().detach().numpy(), (1,2,0))


# Attribute the label of one training sample with Integrated Gradients
ig = IntegratedGradients(model)
data, labels = next(iter(trainloader))
idx = 4
input = data[idx].unsqueeze(0).to(device)
label = labels[idx].to(device)
attributions = ig.attribute(input, target=label, n_steps=100)

# Colour map running from white (0) to black (0.25 and above)
cmap_stops = [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')]
default_cmap = LinearSegmentedColormap.from_list('custom blue', cmap_stops, N=256)

# Heat map of the positive attributions
_ = viz.visualize_image_attr(
  _to_hwc(attributions),
  _to_hwc(data[idx]),
  method='heat_map',
  cmap=default_cmap,
  show_colorbar=True,
  sign='positive',
  outlier_perc=1,
)


In [None]:
# Display sample `idx` of the batch in `data`, i.e. the image the attributions were computed for.
imshow(data[idx], denorm=False)

Noise Tunnel for Smoothing

In [None]:
# Smooth the Integrated Gradients attribution with a noise tunnel (smoothgrad_sq).
# nt_samples <= 7 for 15GB VRAM
# 20 samples are requested, so they are processed in chunks of 5 via
# nt_samples_batch_size to stay below that per-pass limit.
noise_tunnel = NoiseTunnel(ig)

attributions_ig_nt = noise_tunnel.attribute(input, nt_samples=20, nt_samples_batch_size=5, nt_type='smoothgrad_sq', target=label)
# Original image next to the heat map of the positive attributions
_ = viz.visualize_image_attr_multiple(np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1,2,0)),
                                      np.transpose(data[idx].squeeze().cpu().detach().numpy(), (1,2,0)),
                                      ["original_image", "heat_map"],
                                      ["all", "positive"],
                                      cmap=default_cmap,
                                      show_colorbar=True)

In [None]:
#plt.imshow(show_cam_on_image(np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1,2,0)),
 #                                     np.transpose(data[idx].squeeze().cpu().detach().numpy(), (1,2,0)), use_rgb=True))