# Robust Feature-Level Adversaries are Interpretability Tools

https://arxiv.org/abs/2110.03605

Stephen Casper (scasper@csail.mit.edu)

Max Nadeau (mnadeau@college.harvard.edu)

Dylan Hadfield-Menell

Gabriel Kreiman

```
@article{casper2021robust,
  title={Robust Feature-Level Adversaries are Interpretability Tools},
  author={Casper, Stephen and Nadeau, Max and Hadfield-Menell, Dylan and Kreiman, Gabriel},
  journal={arXiv preprint arXiv:2110.03605},
  year={2022}
}
```

### Tips
- [Code folding](https://stackoverflow.com/questions/54819026/codefolding-on-google-colab) will make this notebook more pleasant to work with. 
- If you run into CUDA out-of-memory errors, try restarting the runtime or reducing the ```batch_size``` or ```sub_batch_size``` arguments to the attack functions. These errors are more likely if a cell is stopped mid-execution and then rerun. 
- If you want things to run more quickly, you can often get away with reducing ```n_batches```. 
- Targeted, universal attacks tend to have variable success, especially when optimizing a complex objective like ours, so always run multiple trials. You can also increase the patch/perturbation size if you want attacks that succeed more often on average. 
- Attacks tend to be the easiest to produce when using semantically-related source/target class pairs such as bee/fly or pufferfish/lionfish. 
- You can modify the ```latent_i``` param to change which block of the generator the perturbation is trained in. Using the very last one (```latent_i=13```) will result in a standard pixel-space attack. 
- You can play around with the loss hyperparameters to get attacks optimized more or less for different parts of the objective. 
- This code should be fairly easy to modify for your own experiments. The key functions to play with will be ```patch_adversary```, ```region_generalized_patch_adversary```, ```custom_loss_patch_adv```, and ```custom_loss_region_gen_patch_adv```. 
- If you have questions or feedback, just email us. We're friendly and excited about interpretable adversarial features. We can also discuss ideas for future work. 
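To make the ```latent_i``` tip concrete, here is a toy NumPy sketch. It is not the notebook's BigGAN code: `WEIGHTS`, `block`, and `generate` are hypothetical stand-ins for generator blocks, and the perturbation is assumed to be simply added to one block's output. Inserting at the final block reduces to an ordinary additive pixel-space perturbation, whereas inserting earlier lets the remaining blocks transform it into a feature-level change.

```python
import numpy as np

# Hypothetical toy "generator": four dense blocks with tanh activations.
# The real BigGAN blocks are convolutional and class-conditioned.
rng = np.random.default_rng(0)
WEIGHTS = [0.5 * rng.normal(size=(8, 8)) for _ in range(4)]

def block(h, w):
    return np.tanh(h @ w)

def generate(z, insertion_layer=None, delta=None):
    """Run the toy generator, adding `delta` to the output of block `insertion_layer`."""
    h = z
    for i, w in enumerate(WEIGHTS):
        h = block(h, w)
        if i == insertion_layer:
            h = h + delta  # later blocks (if any) still process the perturbed activations
    return h

z = rng.normal(size=(1, 8))
delta = 0.1 * np.ones((1, 8))
clean = generate(z)
pixel_adv = generate(z, insertion_layer=len(WEIGHTS) - 1, delta=delta)  # analogous to the last block
feature_adv = generate(z, insertion_layer=1, delta=delta)               # an earlier block

print(np.allclose(pixel_adv, clean + delta))    # True: last-block insertion is purely additive
print(np.allclose(feature_adv, clean + delta))  # False: earlier insertion is transformed nonlinearly
```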

## First things first
This cell MIGHT be needed. As of November 2022, it sometimes helps with problems downloading data. If it asks you to, run it and then restart the runtime.

In [None]:
%%bash
pip install --upgrade gdown

### Installing Packages and Downloading Data
This may take a couple of minutes to run the first time. It downloads some data, including images, labels, and models. 


In [None]:
%%capture
%%bash 
pip install -q pytorch-pretrained-biggan
pip install -q git+https://github.com/S-aiueo32/lpips-pytorch.git
pip install -q pytorch_pretrained_vit

In [None]:
%%bash
# make a directory called data
if ! [ -d ./data/ ] ; then
    mkdir data/
    echo 'data dir successfully created :)'
fi

data dir successfully created :)


In [None]:
import sys
import os
import gdown

# Download a set of 2k imagenet validation images
if not os.path.isfile('./data/imagenet2k.pkl'):
    gdown.download('https://drive.google.com/uc?id=1eksXWRHvv3qhCKOHQg90-F6tEifgZ67o', 
                    './data/imagenet2k.pkl', quiet=True)
    
# Download labels for a set of 2k imagenet validation images
if not os.path.isfile('./data/imagenet2k_labels.pkl'):
    gdown.download('https://drive.google.com/uc?id=1loxsvOBkD9-C3u7j-mIaYuT6G86dzZj-', 
                    './data/imagenet2k_labels.pkl', quiet=True)
    
# Download a dict of imagenet class labels
if not os.path.isfile('./data/imagenet_classes.pkl'):
    gdown.download('https://drive.google.com/uc?id=1AnniTzpmPHumxCDdfLTeCblvom3bWYt9', 
                    './data/imagenet_classes.pkl', quiet=True)

# Download a couple of images
if not os.path.isfile('./data/traffic_light.png'):
    gdown.download('https://drive.google.com/uc?id=1ycDA2zusMs_-upmN3T7xR5-M7nWGPL08', 
                    './data/traffic_light.png', quiet=True)
if not os.path.isfile('./data/bee.png'):
    gdown.download('https://drive.google.com/uc?id=14Y07EF0JmANV53Bgkh40acSHeEXkR6RB', 
                    './data/bee.png', quiet=True)
    
# Download a zipped folder with various model weights
if not os.path.isfile('./fla_models.zip'):
    gdown.download('https://drive.google.com/uc?export=download&id=13Uta4vNU-YYWrb2r8hf59W_J0VNPJwyP', 
                    './fla_models.zip', quiet=True)
    
print('Files successfully downloaded :)')


Files successfully downloaded :)


In [None]:
%%bash
# unzipping
if ! [ -d ./fla_models/ ] ; then
    unzip -q ./fla_models.zip -d .
    echo 'fla_models successfully unzipped :)'
fi

fla_models successfully unzipped :)


### Imports



In [None]:
import pickle
import copy
import random
from pathlib import Path
from time import time
from tqdm import tqdm
from collections import OrderedDict
from IPython.utils import io
import numpy as np
from scipy import ndimage
import cv2
import imageio
import matplotlib.pyplot as plt
from matplotlib import image
import torch
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as T
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from lpips_pytorch import LPIPS
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_int, truncated_noise_sample,
                                       save_as_images, display_in_terminal)
from pytorch_pretrained_vit import ViT
from fla_models.biggan_disc import Discriminator
from fla_models.pytorch_pretrained_gans import make_gan

assert torch.cuda.is_available(), 'In Colab, select [Runtime -> Change Runtime Type -> Hardware Accelerator -> GPU]'
device = 'cuda'

### Constants, Transforms, and Data

In [None]:
# constants
N_CLASSES = 1000
PATCH_SIDE = 64
IMAGE_SIDE = 256
N_ROUND = 4
GAUSS_SIGMA = 0.1
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

# transforms
resize64 = T.Resize((64, 64))
resize128 = T.Resize((128, 128))
resize256 = T.Resize((256, 256))
normalize = T.Normalize(mean=MEAN, std=STD)
unnormalize = T.Normalize(mean=-MEAN/STD, std=1/STD)
to_tensor = T.ToTensor()
def gaussian_noise(tens, sigma=GAUSS_SIGMA):
    noise = torch.randn_like(tens) * sigma
    return tens + noise.to(device)
cjitter = T.ColorJitter(0.25, 0.25, 0.25, 0.05)
def custom_colorjitter(tens):
    tens = unnormalize(tens)
    tens = cjitter(tens)
    tens = normalize(tens)
    return tens

# for patch attacks
transforms_patch = T.Compose([custom_colorjitter, T.GaussianBlur(3, (.1, 1)), gaussian_noise,
                            T.RandomPerspective(distortion_scale=0.25, p=0.66), 
                            T.RandomRotation(degrees=(-10, 10))]) 
# for region and generalized patch attacks
transforms_im = T.Compose([T.GaussianBlur(3, (.1, .5)), T.RandomHorizontalFlip()]) 

# get data: imagenet classes, 2000 imagenet validation images, and their labels (2 per class)
with open('data/imagenet_classes.pkl', 'rb') as f:
    class_dict = pickle.load(f)
with open('data/imagenet2k.pkl', 'rb') as f:
    imagenet2k = pickle.load(f)
with open('data/imagenet2k_labels.pkl', 'rb') as f:
    imagenet2k_labels = pickle.load(f)
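`unnormalize` reuses `T.Normalize` with `mean=-MEAN/STD` and `std=1/STD` because (y + MEAN/STD) * STD = y * STD + MEAN, which exactly inverts y = (x - MEAN) / STD. `custom_colorjitter` relies on this inverse, presumably because color jitter is meant to act on float images in [0, 1] rather than on normalized ones. A small NumPy check of the identity, where `normalize_chw` is a hypothetical stand-in for `T.Normalize` on a single CHW image:

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def normalize_chw(x, mean, std):
    """Per-channel (x - mean) / std on a C x H x W array, mirroring what T.Normalize computes."""
    return (x - mean[:, None, None]) / std[:, None, None]

img = np.random.default_rng(0).random((3, 4, 4))  # fake image with values in [0, 1)
normed = normalize_chw(img, MEAN, STD)
restored = normalize_chw(normed, -MEAN / STD, 1 / STD)  # same parameters as `unnormalize`

print(np.allclose(restored, img))  # True
```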

### Load Models
This will take a minute to run the first time. 
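The `Ensemble` class below averages the softmax outputs of its classifiers rather than their logits. Each averaged row is still a probability distribution, so passing `torch.log(output)` to `nn.NLLLoss` in the loss functions gives a cross-entropy against the ensemble's averaged distribution. A minimal NumPy sketch of the same rule (`softmax` and `ensemble_probs` are illustrative names, not part of the notebook):

```python
import numpy as np

def softmax(logits, axis=1):
    shifted = logits - logits.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def ensemble_probs(logit_list):
    """Average the softmax outputs of several classifiers, each given as N x C logits."""
    return sum(softmax(logits) for logits in logit_list) / len(logit_list)

a = np.array([[2.0, 0.0, 0.0]])  # classifier A favors class 0
b = np.array([[0.0, 2.0, 0.0]])  # classifier B favors class 1
p = ensemble_probs([a, b])
print(p.round(3))  # classes 0 and 1 tie, and the row still sums to 1
```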

In [None]:
%%capture
# load models including ensembles of classifiers, a BigGAN generator, and a BigGAN discriminator

class Ensemble:

    """
    Ensembles together a set of classifiers, combining them by averaging their softmax outputs.
    """
    
    def __init__(self, classifiers):
        self.cfs = [self.get_classifier(cf) for cf in classifiers]
        self.n_cfs = len(self.cfs)

    def get_classifier(self, name):
        if name == 'vit':
            C = ViT('B_16_imagenet1k', pretrained=True, image_size=(256, 256)).eval().to(device)
        elif 'robust' in name:
            C = models.resnet50(pretrained=False).eval().to(device)
            model_dict = C.state_dict()
            if name == 'resnet50_robust_l2':
                load_dict = torch.load('fla_models/imagenet_l2_3_0.pt')['model']
            elif name == 'resnet50_robust_linf':
                load_dict = torch.load('fla_models/imagenet_linf_4.pt')['model']
            else:
                raise ValueError('invalid robust model name')
            new_state_dict = OrderedDict()
            for mk in model_dict.keys():
                for lk in load_dict.keys():
                    if lk[13:] == mk:
                        new_state_dict[mk] = load_dict[lk]
            C.load_state_dict(new_state_dict)
            del model_dict
            del load_dict
        else:
            # look up the torchvision constructor by name (e.g. models.alexnet)
            C = getattr(models, name)(pretrained=True).eval().to(device)
        return C

    def __call__(self, inpt):
        outpts = [F.softmax(cf(inpt), 1) for cf in self.cfs]
        return sum(outpts) / self.n_cfs

ALL_CLASSIFIERS = ['alexnet', 'resnet50', 'vgg19', 'inception_v3', 'densenet121', 'resnet50_robust_l2', 'resnet50_robust_linf', 'vit']
TRAIN_CLASSIFIERS = ['resnet50']
REG_CLASSIFIERS = ['resnet50_robust_l2', 'resnet50_robust_linf']
INTERP_CLASSIFIERS = ['inception_v3']

E_attack = Ensemble(TRAIN_CLASSIFIERS)  # for attacking
E_reg = Ensemble(REG_CLASSIFIERS)  # for disguise and interpretability regularization 
E_interp = Ensemble(INTERP_CLASSIFIERS)  # for analyzing the interpretability of an adversarial feature
G = make_gan(gan_type='biggan', model_name='biggan-deep-256').to(device)
D = Discriminator()  # for interpretability regularization
D.load_state_dict(torch.load('fla_models/D.pth'))
D.to(device)
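The robust ResNet-50 checkpoints are loaded by comparing each checkpoint key, minus its first 13 characters, against every torchvision key name. Assuming the relevant keys carry a `module.model.` prefix (13 characters; this is inferred from `lk[13:]`, not checked against the checkpoint files), the same remapping can be done in a single pass. `strip_prefix` is a hypothetical helper:

```python
from collections import OrderedDict

def strip_prefix(state_dict, prefix="module.model."):
    """Keep only keys starting with `prefix`, with the prefix removed (assumed checkpoint layout)."""
    return OrderedDict(
        (key[len(prefix):], value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )

checkpoint = {
    "module.model.conv1.weight": "w1",
    "module.model.fc.bias": "b",
    "module.attacker.model.fc.bias": "dup",  # keys under other prefixes are dropped
}
print(dict(strip_prefix(checkpoint)))  # {'conv1.weight': 'w1', 'fc.bias': 'b'}
```

With real tensors, `C.load_state_dict(strip_prefix(load_dict))` would raise an error if any expected key were missing, since `load_state_dict` is strict by default.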

### Helper Functions
These functions mostly handle loss computation and image processing. 
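Two regularizers recur in both loss functions: total variation, which penalizes differences between neighboring pixels and so favors smooth features, and the entropy of a classifier's softmax output, which is low when the classifier is confident. The NumPy versions below mirror `total_variation` and `entropy` from the next cell and only illustrate the values they produce on simple inputs:

```python
import numpy as np

def total_variation(images):
    """Summed absolute difference between vertically and horizontally adjacent pixels (NCHW)."""
    h_var = np.abs(images[:, :, :-1, :] - images[:, :, 1:, :]).sum()
    w_var = np.abs(images[:, :, :, :-1] - images[:, :, :, 1:]).sum()
    return h_var + w_var

def entropy(probs, epsilon=1e-10):
    """Entropy (in nats) of each row of an N x C matrix of probabilities."""
    return -(probs * np.log(probs + epsilon)).sum(axis=1)

flat = np.zeros((1, 3, 4, 4))
stripes = np.zeros((1, 1, 2, 2))
stripes[:, :, 0, :] = 1.0  # top row 1, bottom row 0

print(total_variation(flat))     # 0.0
print(total_variation(stripes))  # 2.0 (two vertical neighbor pairs differ by 1)
print(entropy(np.array([[0.25, 0.25, 0.25, 0.25], [1.0, 0.0, 0.0, 0.0]])))  # about [1.386, 0]
```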

In [None]:
%%capture

def tensor_to_numpy_image(tensor, unnormalize_img=True):
    """
    Takes a tensor and turns it into an imshowable np.ndarray
    """
    image = tensor
    if unnormalize_img:
        image = unnormalize(image)
    image = image.detach().cpu().numpy()
    image = np.squeeze(image)
    image = np.transpose(image, axes=(1, 2, 0))
    image = np.clip(image, 0, 1)
    return image

def tensor_to_0_1(tensor):
    """
    Rescales a tensor so that 0 maps to 0.5 and all values fall on [0,1] (divides by twice the max absolute value, then adds 0.5)
    """
    return tensor / torch.max(torch.abs(tensor)) / 2 + 0.5

nll_loss = nn.NLLLoss()  # negative log likelihood

class LPIPS_Device(LPIPS): 
    """
    Calculates perceptual distance between images. Used for regularization in region and generalized patch attacks. 
    """
    def __init__(self, net_type: str='alex', version: str='0.1'):
        super().__init__(net_type, version)
        # put the weights on device
        self.net.to(device)
        self.lin.to(device)
lpips_dist = LPIPS_Device(net_type='vgg', version='0.1')  # ['alex', 'squeeze', 'vgg']

def total_variation(images):
    """
    Calculates the summed L1 variation of images in tensor NCHW form
    """
    if len(images.size()) == 4:
        h_var = torch.sum(torch.abs(images[:, :, :-1, :] - images[:, : ,1:, :]))
        w_var = torch.sum(torch.abs(images[:, :, :, :-1] - images[:, :, :, 1:]))
    else:  # if 3 (CHW)
        h_var = torch.sum(torch.abs(images[:, :-1, :] - images[: ,1:, :]))
        w_var = torch.sum(torch.abs(images[:, :, :-1] - images[:, :, 1:]))
    return h_var + w_var

def entropy(sm_tensor, epsilon=1e-10):
    """
    Returns a N length vector of entropies from an NxC tensor.
    """
    log_sm_tensor = torch.log(sm_tensor+epsilon)
    h = -torch.sum(sm_tensor * log_sm_tensor, dim=1)  # formula for entropy
    return h

def custom_loss_patch_adv(output, target, patch, lam_xent=3.0, lam_tvar=1e-3, 
                          lam_disc=0.005, lam_patch_xent=0.2, lam_ent=0.2, quant=0.5, patch_bs=16):
    """
    Calculates the targeted misclassification crossentropy loss with regularization based on 
    total variation, discriminator realisticness confidence, classifier patch non-target confidence, 
    and classifier patch entropy.
    """
    avg_xent = nll_loss(torch.log(output), target)  # crossentropy (minimize)
    avg_tvar = total_variation(patch) / output.shape[0]  # avg total variation (minimize)
    loss = lam_xent*avg_xent + lam_tvar*avg_tvar

    if lam_disc != 0:
        y = torch.tensor(list(range(N_CLASSES))).to(device)  # y for all classes
        disc_out = D(patch, y)[:, 0]  # class conditioned output for all 1000 classes
        disc_q = torch.quantile(disc_out, quant)  # quantile marking the k highest
        disc = torch.mean(disc_out[disc_out > disc_q]) # discriminator conf, mean is over top k (maximize)
        loss -= lam_disc*disc

    if lam_patch_xent != 0 or lam_ent != 0:
        patch256 = resize256(patch)
        classifiers_out = E_reg(torch.cat([transforms_patch(patch256) for i in range(patch_bs)], axis=0)) # what the classifiers think of the patch
        patch_xent = nll_loss(torch.log(classifiers_out), target[:patch_bs])  # cross entropy loss for target (maximize)
        ent = torch.mean(entropy(classifiers_out)) # entropy for softmax outputs (minimize)
        loss -= lam_patch_xent*patch_xent
        loss += lam_ent*ent 

    return loss

def custom_loss_region_gen_patch_adv(output, target, perturbation, adv_img, orig_img, lam_tvar=1e-5, lam_lpips=4, 
                                     lam_disc=0.1, lam_wd = 0.0001, lam_patch_xent=0.1, lam_ent=0.2):
    """
    For region and generalized patch adversaries. 
    Calculates the targeted misclassification crossentropy loss with regularization based on 
    total variation, LPIPS perceptual distance, discriminator realisticness confidence, 
    perturbation norm, classifier patch non-target confidence, and classifier patch entropy.
    """
    avg_x_ent = nll_loss(torch.log(output), target) # crossentropy (minimize)
    n_imgs = output.shape[0]
    avg_t_var = total_variation(orig_img-adv_img) / n_imgs  # avg total variation (minimize)
    wd = torch.mean(torch.linalg.norm(torch.flatten(perturbation, start_dim=1), dim=1))  # L2 perturbation norm (minimize)
    loss = avg_x_ent + lam_tvar*avg_t_var + lam_wd * wd
    
    if lam_lpips != 0:
        avg_lpips = lpips_dist(adv_img, orig_img) / n_imgs  # lpips perceptual distance (minimize)
        loss += lam_lpips * torch.squeeze(avg_lpips)

    if lam_disc != 0:
        y = torch.tensor([list(range(N_CLASSES))]*adv_img.shape[0]).to(device) 
        avg_disc = torch.mean(torch.topk(D(resize128(adv_img), y), 5, dim=0)[0])  # avg across top-5 disc values and across minibatch (maximize)
        loss -= lam_disc * avg_disc

    if lam_patch_xent != 0 or lam_ent != 0:
        patches = [crop_to_square(get_gen_patch(*pair)) for pair in zip(orig_img, adv_img)]
        patches256 = torch.cat([resize256(patch) for patch in patches])  # get patches as full ims
        classifiers_out = E_reg(normalize(patches256))
        avg_patch_xent = nll_loss(torch.log(classifiers_out), target)  # classifier xent for *target* class (maximize)
        avg_ent = torch.mean(entropy(classifiers_out))  # classifier softmax entropy (minimize); kept as a tensor so it contributes gradients
        loss -= lam_patch_xent * avg_patch_xent 
        loss += lam_ent * avg_ent

    return loss

def insert_patch(patch, batch_size, prop_lower=0.2, prop_upper=0.8, side_radius=10, transform=True, from_generator=False, y=None):
    """
    For universal patch attacks, this randomly tiles images and inserts patches into them.
    """
    if from_generator:  # if generating your own patches
        with torch.no_grad():
            ys = torch.cat([y]*batch_size, 0)
            rand_noises = G.sample_latent(batch_size=batch_size, device=device)
            images = normalize(G(rand_noises, ys))
            orig_images = copy.deepcopy(images).to(device)
    else:  # if using ImageNet validation set images
        rand_is = np.random.randint(0, imagenet2k_labels.shape[0], size=batch_size)
        images = normalize(torch.stack([to_tensor(imagenet2k[rand_i]) for rand_i in rand_is])).to(device)
        orig_images = copy.deepcopy(images).to(device)
    mid = (IMAGE_SIDE-PATCH_SIDE) // 2
    for i in range(batch_size): 
        if transform:  # randomly transform and insert
            side = np.random.randint(PATCH_SIDE-side_radius, PATCH_SIDE+side_radius+1)
            rand_x = np.random.randint(int((IMAGE_SIDE-side)*prop_lower), 
                                    int((IMAGE_SIDE-side)*prop_upper)+1)
            rand_y = np.random.randint(int((IMAGE_SIDE-side)*prop_lower), 
                                    int((IMAGE_SIDE-side)*prop_upper)+1)
            to_insert = transforms_patch(T.functional.resize(patch, [side, side]))
            mask = to_insert != 0.0  # pixels that are exactly 0 (e.g., fill from the perspective/rotation transforms) are not inserted
            images[i, :, rand_x: rand_x+side, rand_y: rand_y+side] *= torch.logical_not(mask)
            images[i, :, rand_x: rand_x+side, rand_y: rand_y+side] += mask * to_insert
        else:  # randomly insert
            rand_x = np.random.randint(int((IMAGE_SIDE-PATCH_SIDE)*prop_lower), 
                                       int((IMAGE_SIDE-PATCH_SIDE)*prop_upper)+1)
            rand_y = np.random.randint(int((IMAGE_SIDE-PATCH_SIDE)*prop_lower), 
                                       int((IMAGE_SIDE-PATCH_SIDE)*prop_upper)+1)
            images[i, :, rand_x: rand_x+PATCH_SIDE, rand_y: rand_y+PATCH_SIDE] = resize64(patch)
    return images, orig_images

def get_mask(orig, adv, quant_threshold=0.9):
    """
    For generalized patch attacks. Takes in two tensors, produces a bool mask tensor of their differences
    """
    diff = tensor_to_numpy_image(tensor_to_0_1(adv-orig), False)
    smooth_absdiff = ndimage.gaussian_filter(np.abs(diff-0.5), 12)
    mask =  smooth_absdiff > np.quantile(smooth_absdiff, quant_threshold)
    mask = np.any(mask, axis=-1) # differences on each color channel merged
    mask = ndimage.binary_opening(mask, iterations=4)
    mask = ndimage.binary_closing(mask, iterations=4, border_value=1)
    return torch.tensor(mask, device=device) 

def crop_to_square(patch):
    """
    Takes a patch over a grey background and condenses it to a minimal bounding square. 
    """
    mask = patch[0,0] != 0.5  # just use the R channel as a heuristic
    adv_region = patch[np.ix_([True], [True, True, True], torch.any(mask, dim=1).cpu().numpy(), torch.any(mask, dim=0).cpu().numpy())]
    sh = adv_region.shape
    square = torch.ones((sh[0], sh[1], max([sh[2], sh[3]]), max([sh[2], sh[3]])), device=device) * 0.5
    square[:, :, 0:sh[2], 0:sh[3]] = adv_region
    return square

def get_gen_patch(orig, adv):
    """
    For generalized patch attacks. Returns the adv image pixels in the region where adv and orig differ, over a gray (0.5) background
    """
    mask = get_mask(orig, adv)
    patch = adv * mask[None, None, :, :]
    patch += 0.5 * torch.ones_like(adv) * torch.logical_not(mask[None, None, :, :])
    return patch   
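`insert_patch` pastes each transformed patch with a mask rather than plain assignment: patch pixels that are exactly 0 are skipped, so (presumably) the zero fill that `RandomPerspective` and `RandomRotation` leave around the patch does not overwrite the underlying image. Note that in normalized space 0 is the dataset-mean color, not black. A minimal NumPy version of the same multiply-then-add step, with `paste_with_mask` as a hypothetical name:

```python
import numpy as np

def paste_with_mask(image, patch, top, left):
    """Paste `patch` (C x s x s) into a copy of `image` (C x H x W) at (top, left),
    skipping pixels where the patch is exactly 0."""
    out = image.copy()
    s = patch.shape[-1]
    region = out[:, top:top + s, left:left + s]  # a view, so in-place edits land in `out`
    mask = patch != 0.0
    region *= np.logical_not(mask)  # clear the pixels that will be replaced
    region += mask * patch          # write the nonzero patch pixels
    return out

image = np.full((1, 4, 4), 0.7)
patch = np.array([[[1.0, 0.0], [0.0, 2.0]]])
out = paste_with_mask(image, patch, 1, 1)
print(out[0, 1, 1], out[0, 1, 2], out[0, 2, 2])  # 1.0 0.7 2.0
```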


### Attack Training and Evaluation Functions
Where the magic happens. These functions perform patch, region, generalized-patch, and copy-paste attacks.
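The region attack trains an `insertion` that overwrites a square of one generator block's activations. Judging from how `modify_fn` is used, the returned `lp` appears to be added to the activations at `insertion_layer` (an assumption about the `make_gan` generator, not verified here), so it first cancels the original values in the square and then adds the insertion. Below is a NumPy sketch of that arithmetic, with `replace_region` as a hypothetical helper, plus a reminder that slicing with `[:]` returns a view rather than a copy in both NumPy and PyTorch; use `.copy()` or `.clone()` to keep a snapshot.

```python
import numpy as np

def replace_region(latent, insertion, x, y):
    """Return (modified latent, additive perturbation) where the square starting at (x, y)
    in every N x C x H x W latent is overwritten by `insertion` (C x s x s)."""
    s = insertion.shape[-1]
    perturbation = np.zeros_like(latent)
    # cancel the original values in the square and add the insertion (broadcast over N)
    perturbation[:, :, x:x + s, y:y + s] = insertion - latent[:, :, x:x + s, y:y + s]
    return latent + perturbation, perturbation

rng = np.random.default_rng(0)
latent = rng.normal(size=(2, 3, 6, 6))
insertion = np.ones((3, 2, 2))
modified, perturbation = replace_region(latent, insertion, 1, 3)

outside = np.ones(latent.shape, dtype=bool)
outside[:, :, 1:3, 3:5] = False
print(np.allclose(modified[:, :, 1:3, 3:5], 1.0))          # True: the square now holds the insertion
print(np.array_equal(modified[outside], latent[outside]))  # True: everything else is untouched

# Slicing pitfall: `[:]` gives a view that follows later in-place edits.
base = np.zeros(3)
alias = base[:]         # view, shares memory with `base`
snapshot = base.copy()  # independent copy
base += 1
print(alias, snapshot)  # [1. 1. 1.] [0. 0. 0.]
```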

In [None]:
def patch_adversary(n_batches=64, batch_size=32, lr=0.01, latent_i=8, 
                    source_class=None, target_class=None, train_noise=True, 
                    train_class_vector=True, input_lr_factor=0.025, loss_hypers={}):
    """
    This function trains an adversarial patch that is targeted, universal, interpretable, and 
    physically-realizable. The success rate is variable for random choices of target classes, 
    so try running it multiple times. 
    """
    # get target class
    if target_class is None:
        target_class = np.random.randint(N_CLASSES)
    target_tensor = torch.tensor([target_class]*batch_size, dtype=torch.long).to(device)

    # if a class universal adversary
    if source_class is not None:
        source_tmp = G.sample_class(batch_size=1, device=device) * 0
        source_tmp[0][source_class] += 1
        source = source_tmp

    # get latents from the patch generator
    with torch.no_grad():
        cv = torch.ones(1, 1000).to(device) / 999
        cv[:, target_class] = 0.0
        cvp = nn.Parameter(torch.zeros_like(cv)).to(device).requires_grad_()
        nv = G.sample_latent(batch_size=1, device=device)
        nvp = nn.Parameter(torch.zeros_like(nv)).requires_grad_()
        lp = nn.Parameter(torch.zeros_like(G(nv, cv, return_latents=True)[latent_i]))
        params = [{'params': lp}]
        if train_class_vector:
            params.append({'params': cvp, 'lr': lr * input_lr_factor})
        if train_noise:
            params.append({'params': nvp, 'lr': lr * input_lr_factor})
        optimizer = optim.Adam(params, lr)

    # generate patch, insert into images, and train
    for _ in tqdm(range(n_batches)):
        patch = normalize(G(nv, cv, nvp, cvp, lp, insertion_layer=latent_i))
        if source_class is None:  # if a universal attack
            patched_images, orig_images = insert_patch(patch[0], batch_size)
        else:  # if a class_universal attack
            patched_images, orig_images = insert_patch(patch[0], batch_size, from_generator=True, y=source)
        predictions = E_attack(patched_images)
        optimizer.zero_grad()
        loss = custom_loss_patch_adv(predictions, target_tensor, patch, **loss_hypers)
        loss.backward()
        if train_class_vector:
            cvp.grad[:, target_class] *= 0.0
        optimizer.step()

    # evaluate
    with torch.no_grad():
        patch = normalize(G(nv, cv, nvp, cvp, lp, insertion_layer=latent_i))
        if source_class is None:  # if a universal attack
            patched_images, _ = insert_patch(patch[0], batch_size, transform=False) 
        else:
            patched_images, _ = insert_patch(patch[0], batch_size, from_generator=True, y=source, transform=False)
        adv_sm_out = E_attack(patched_images)
        mean_conf = round(np.mean(np.array([float(aso[target_class]) for aso in adv_sm_out])), N_ROUND)
        i_sm_out = E_interp(resize256(patch))
        i_class = int(torch.argmax(i_sm_out))
        i_conf = round(float(torch.max(i_sm_out)), N_ROUND)
        
    # show results
    plt.imshow(tensor_to_numpy_image(patch[0]))
    if source_class is None:
        plt.title(f'Universal Patch Adversary\nlatent: {latent_i}\ntarget: {class_dict[target_class]}, mean conf: {mean_conf}\ndisguise: {class_dict[i_class]}, conf: {i_conf}'.title())
    else: 
        plt.title(f'Class Universal Patch Adversary\nlatent: {latent_i}\nsource={class_dict[source_class]}\ntarget: {class_dict[target_class]}, mean conf: {mean_conf}\ndisguise: {class_dict[i_class]}, conf: {i_conf}'.title())
    plt.xticks([])
    plt.yticks([])
    plt.show()    

def assess_gp(patches, target_int, n_test=3, n_display=3, source_class=None):
    """
    This function is called from inside assess_rgp and evaluates and displays the generalized patches.
    """
    with torch.no_grad():
        # Classes/noise for GAN (very finicky, change at your peril)
        cv = G.sample_class(batch_size=n_test, device=device)
        if source_class is not None:
            cv *= 0
            cv[:,source_class] += 1
        cv_int = torch.argmax(cv, -1).detach().cpu().numpy()
        orig_noise = G.sample_latent(batch_size=n_test, device=device)
        orig_target =  torch.zeros(n_test)

        # Set up fig and some stats tensors
        n_display = min([n_display, n_test])
        fig, axes = plt.subplots(1 + len(patches), 1 + n_display, figsize=(4*(1 + n_display), 5*(1 + len(patches))))
        gp_target = torch.zeros((len(patches), n_test))
        gp_mean_conf = torch.zeros(len(patches))
        gp_std_conf = torch.zeros(len(patches))

        # Fills the first row with the original generated images
        orig_imgs = []
        for j in range(n_test):
            orig_img = G(orig_noise[[j]], cv[[j]])
            orig_imgs.append(orig_img)
            if j < n_display:
                orig_sm_out = E_attack(normalize(orig_img))[0]
                orig_target[j] = round(float(orig_sm_out[target_int]), N_ROUND)
                axes[0, j+1].imshow(tensor_to_numpy_image(orig_img, False))
                axes[0, j+1].set_title(f'{class_dict[cv_int[j]]}: {round(float(orig_sm_out[cv_int[j]]), N_ROUND)}\n {class_dict[target_int]}: {round(orig_target[j].item(), N_ROUND)}'.title(), fontweight="bold")

        # Fill out each successive row with each patch's results
        for i, patch in enumerate(patches):
            mask = patch != 0.5
            for j, orig_img in enumerate(orig_imgs):
                gp_img = orig_img * torch.logical_not(mask) + patch * mask
                gp_sm_out = E_attack(normalize(transforms_im(gp_img)))[0]
                gp_target[i, j] = round(float(gp_sm_out[target_int]), N_ROUND)
                if j < n_display:
                    axes[i+1, j+1].imshow(tensor_to_numpy_image(gp_img, False))
                    axes[i+1, j+1].set_title(f'{class_dict[cv_int[j]]}: {round(float(gp_sm_out[cv_int[j]]), N_ROUND)}\n'.title() +
                                             f'{class_dict[target_int]}: {round(float(gp_sm_out[target_int]), N_ROUND)}'.title(), fontweight="bold")
            
            gp_mean_conf[i] = torch.mean(gp_target[i,:])
            gp_std_conf[i] = torch.std(gp_target[i,:])
            square = crop_to_square(patch)
            
            # Handles fully grey composite patches (these occur when the generated patches have no overlap)
            if torch.numel(square) <= 1:
                square = patch

            # Patches are evaluated on their own with the interp classifier
            reg_out = E_interp(normalize(resize256(square))).squeeze(0)
            axes[i+1, 0].imshow(tensor_to_numpy_image(resize256(square), False))
            axes[i+1, 0].set_title(f'Disguise conf ({class_dict[torch.argmax(reg_out).item()]}): {round(float(torch.max(reg_out)), N_ROUND)}\n'.title() +
                                   f'Mean target conf: {round(float(gp_mean_conf[i]), N_ROUND)}\n'.title() + 
                                   f'Std target conf: {round(float(gp_std_conf[i]), N_ROUND)}'.title(), fontweight="bold")

        axes[0,0].axis('off')
        for ax in axes.flatten():
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()

def assess_rgp(insertion, layer, modify_fn, target_int, n_test=20, n_display=3, metadata=None, source_class=None):
    """
    This function assesses and displays the results of a region attack and calls assess_gp to do 
    the same for the corresponding generalized patch attack. 
    """
    with torch.no_grad():
        # Set up fig
        n_display = min([n_display, n_test])
        fig, axes = plt.subplots(n_display, 3, figsize=(15, 5*n_display))

        # Classes/noise for GAN (very finicky, change at your peril)
        cv = G.sample_class(batch_size=n_test, device=device)
        if source_class is not None:
            cv *= 0
            cv[:,source_class] += 1
        cv_int = torch.argmax(cv, -1).detach().cpu().numpy()
        nv = G.sample_latent(batch_size=n_test, device=device)
        orig_target, adv_target, gps = [], [], []  # lists to save target confidences and generalized patches

        # Generate and display some images; save ram by running the GAN with one image at a time
        for i in range(n_test): 
            orig_latents = G(nv[[i]], cv[[i]], return_latents=True)
            orig_img = orig_latents[-1]
            lp, op = modify_fn(torch.clone(orig_latents[layer]), insertion)
            adv_img = G(nv[[i]], cv[[i]], lp=lp, insertion_layer=layer)
            adv_sm_out = E_attack(normalize(adv_img))[0]
            adv_target.append(round(float(adv_sm_out[target_int]), N_ROUND))
            gps.append(get_gen_patch(orig_img, adv_img))   

            # display the first n_display examples
            if i < n_display:
                orig_sm_out = E_attack(normalize(orig_img))[0]
                orig_target.append(round(float(orig_sm_out[target_int]), N_ROUND))
                axes[i, 0].imshow(tensor_to_numpy_image(orig_img, False))
                axes[i, 0].set_title(f'{class_dict[cv_int[i]].title()}: {round(float(orig_sm_out[cv_int[i]]), N_ROUND)}\n {class_dict[target_int]}: {orig_target[i]}'.title(), fontweight = "bold")
                axes[i, 1].imshow(tensor_to_numpy_image(adv_img.squeeze(0), False))
                axes[i, 1].set_title(f'{class_dict[cv_int[i]].title()}: {round(float(adv_sm_out[cv_int[i]]), N_ROUND)}\n {class_dict[target_int]}: {adv_target[i]}'.title(), fontweight = "bold")
                axes[i, 2].imshow(tensor_to_numpy_image(tensor_to_0_1(adv_img.squeeze(0)-orig_img), False))
                axes[i, 2].set_title(f'Normalized pixel-level diff'.title(), fontweight = "bold") 

        fig.suptitle('Latent ' + str(layer) + 
                     '\nMean target confidence: ' + str(round(np.mean(adv_target), N_ROUND)) +
                     '\nStd target confidence: ' + str(round(np.std(adv_target), N_ROUND)) + 
                     (f'\nLoss hyperparameters: {metadata["loss_hypers"]}' if metadata is not None else ""), fontweight="bold")

        for ax in axes.flatten():
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()

        # Make a composite patch by finding the regions that are perturbed in >80% of the patches
        mask_most = torch.sum(torch.stack([(gp[0,0] != 0.5) for gp in gps]), dim=0) > (0.8 * len(gps))
        patch_avg = torch.mean(torch.cat([gp for gp in gps]), dim=0) 
        patch_comp = patch_avg * mask_most[None, None, :, :] + \
                     0.5 * torch.ones_like(patch_avg) * torch.logical_not(mask_most[None, None, :, :])
        gps.append(patch_comp)  # add it to patch list

        # Next, assess the generalized patches
        assess_gp(gps[(len(gps)-n_display):], target_int, source_class=source_class)

def region_generalized_patch_adversary(prop_modified=1/8, latent_i=6, n_batches=128, batch_size=32, sub_batch_size=8, 
                                       lr=0.05, source_class=None, target_class=None, loss_hypers={}):
    """
    This function trains region and generalized patch attacks that are targeted, universal, 
    and interpretable. The success rate is variable for random choices 
    of target classes, so try running it multiple times. 
    """

    # Fix batch size if needed
    batch_size -= batch_size % sub_batch_size

    # Get target class for the attack
    if target_class is None:
        target_class = np.random.randint(N_CLASSES)
    target = torch.tensor(target_class, dtype=torch.long, device=device).unsqueeze(0)
    
    # If class-universal, set y (the source class) permanently
    if source_class is not None:
        cv_int = torch.tensor(sub_batch_size * [source_class], device=device)
        cv = torch.tensor(one_hot_from_int(sub_batch_size * [source_class], batch_size=sub_batch_size), device=device)
  
    # Get a sample pass through the generator and get params for the attack
    nv_init = G.sample_latent(batch_size=2, device=device)
    cv_init = G.sample_class(batch_size=2, device=device)
    latent = G(nv_init, cv_init, return_latents=True)[latent_i]
    region_side = int(np.sqrt(prop_modified) * latent.shape[-1])
    reg_x = np.random.randint(latent.shape[-1] - region_side + 1)
    reg_y = np.random.randint(latent.shape[-1] - region_side + 1)

    # The insertion parameterizes the perturbation
    insertion = nn.Parameter(torch.zeros((latent.shape[1], region_side, region_side), device=device))
    optimizer = optim.Adam([insertion], lr=lr)

    # This function applies the perturbation 
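    def modify_fn(latent, insertion):
        perturbation = torch.zeros_like(latent)
        # cancel the original latent values inside the region (for every image in the batch)
        perturbation[:, :, reg_x:(reg_x+region_side), reg_y:(reg_y+region_side)] -= latent[:, :, reg_x:(reg_x+region_side), reg_y:(reg_y+region_side)]
        # keep a real copy; perturbation[:] would only be a view and would change along with it
        orig_perturbation = perturbation.clone()
        # write the trainable insertion into the region (broadcast across the batch)
        perturbation[:, :, reg_x:(reg_x+region_side), reg_y:(reg_y+region_side)] += insertion
        return perturbation, orig_perturbation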
    
    # Train the insertion
    for step in tqdm(range(n_batches), position=0, leave=True):
        for batch_i in range(batch_size//sub_batch_size):  # Avoids overtaxing GPUs
            with torch.no_grad():

                # Sample random source classes for each sub-batch if the attack is universal (not class-universal)
                if source_class is None:
                    cv = G.sample_class(batch_size=sub_batch_size, device=device)
                    cv_int = torch.argmax(cv, -1)

                # Generate a sub-batch of images and their latents
                nv = G.sample_latent(batch_size=sub_batch_size, device=device)
                orig_latents = G(nv, cv, return_latents=True)
                orig_latent = orig_latents[latent_i]
                orig_imgs = orig_latents[-1]

            # Compute the loss and backpropagate; gradients accumulate across sub-batches until the optimizer step
            lp, op = modify_fn(torch.clone(orig_latent), insertion) 
            adv_imgs = G(nv, cv, lp=lp, insertion_layer=latent_i)
            adv_prediction = E_attack(normalize(transforms_im(adv_imgs)))
            loss = custom_loss_region_gen_patch_adv(adv_prediction, torch.tile(target, (sub_batch_size,)), 
                                                    lp-op, adv_imgs, orig_imgs, **loss_hypers)
            loss.backward()
        
        # optimize
        optimizer.step()
        optimizer.zero_grad() 

    assess_rgp(insertion, latent_i, modify_fn, target.item(), source_class=source_class) 

def copy_paste_attack(source_file, patch_file, patch_side=85, prop_lower=0.1, prop_upper=0.9):
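    """
    This function executes a copy/paste attack: it pastes a resized copy of the patch image
    at a random location in the source image and shows the classifier's top predicted class
    and its score for both images. Both files are read from the data/ directory.
    """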

    source_path = Path(f'data/{source_file}')
    source_im = resize256(normalize(to_tensor(image.imread(source_path)[:, :, :3]))).to(device)
    patch_im = copy.deepcopy(source_im)

    patch_path = Path(f'data/{patch_file}')
    patch = normalize(to_tensor(image.imread(patch_path)[:, :, :3])).to(device)
    diff = IMAGE_SIDE-patch_side
    rand_x = np.random.randint(int(diff*prop_lower), int(diff*prop_upper)+1)
    rand_y = np.random.randint(int(diff*prop_lower), int(diff*prop_upper)+1)
    patch_im[:, rand_x: rand_x+patch_side, rand_y: rand_y+patch_side] = T.functional.resize(patch, [patch_side, patch_side])
    
    orig_sm_out = E_attack(torch.unsqueeze(source_im, 0))[0]
    orig_label = torch.argmax(orig_sm_out)
    orig_conf = torch.max(orig_sm_out)
    patch_sm_out = E_attack(torch.unsqueeze(patch_im, 0))[0]
    patch_label = torch.argmax(patch_sm_out)
    patch_conf = torch.max(patch_sm_out)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    axes[0].imshow(tensor_to_numpy_image(source_im))
    axes[0].set_title(f'{class_dict[orig_label.item()]}: {round(orig_conf.item(), N_ROUND)}'.title(), fontweight='bold')
    axes[1].imshow(tensor_to_numpy_image(patch_im))
    axes[1].set_title(f'{class_dict[patch_label.item()]}: {round(patch_conf.item(), N_ROUND)}'.title(), fontweight='bold')
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
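Inside `region_generalized_patch_adversary`, `modify_fn` builds a perturbation such that adding it to the latent *replaces* the chosen region with the insertion: the region's original values are subtracted first, then the insertion is added. The sketch below restates that idea with NumPy arrays standing in for the PyTorch tensors, an assumption made only so it runs without a GPU or the generator; `modify_fn_sketch` is a hypothetical name. It also shows why the intermediate "erase-only" perturbation has to be a copy rather than a slice: in both NumPy and PyTorch, `x[:]` is a view that shares storage, so a later in-place `+=` would change it too and the difference `lp - op` would be all zeros.

```python
import numpy as np


def modify_fn_sketch(latent, insertion, reg_x, reg_y):
    """Hypothetical NumPy stand-in for the notebook's modify_fn.

    latent:    (B, C, H, W) batch of generator latents
    insertion: (C, s, s) learned patch, broadcast over the batch
    Returns (perturbation, erase_only); latent + perturbation has the
    region replaced by `insertion`.
    """
    s = insertion.shape[-1]
    region = (slice(None), slice(None), slice(reg_x, reg_x + s), slice(reg_y, reg_y + s))
    perturbation = np.zeros_like(latent)
    # Step 1: cancel out the original latent values inside the region
    perturbation[region] -= latent[region]
    # Snapshot with .copy(); perturbation[:] would be a view and would
    # pick up the in-place addition below
    erase_only = perturbation.copy()
    # Step 2: write the insertion into the (now zeroed) region
    perturbation[region] += insertion
    return perturbation, erase_only


rng = np.random.default_rng(0)
latent = rng.normal(size=(2, 3, 8, 8))
insertion = np.ones((3, 4, 4))
lp, op = modify_fn_sketch(latent, insertion, reg_x=2, reg_y=1)

modified = latent + lp
print(np.allclose(modified[:, :, 2:6, 1:5], 1.0))               # True: region equals the insertion
print(np.allclose(modified[:, :, 6:, :], latent[:, :, 6:, :]))  # True: rows outside are untouched
print(np.allclose((lp - op)[:, :, 2:6, 1:5], 1.0))              # True: lp - op is the placed insertion

# The aliasing pitfall: a slice is a view, so it tracks later in-place edits
a = np.zeros(3)
view = a[:]
a += 1
print(view)  # [1. 1. 1.]
```

In PyTorch the equivalent of `.copy()` is `.clone()`.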

### Demo

In [None]:
# Generate universal patch attacks. 
# They have variable success for random target classes, so run multiple times. 
for _ in range(3):
    patch_adversary()

In [None]:
# Generate class-universal patch attacks from bee (309) to fly (308), using generated source images rather than real ones.
for _ in range(3):
    patch_adversary(source_class=309, target_class=308)

In [None]:
# Generate universal region and generalized patch attacks. 
# They have variable success for random target classes, so run multiple times. 
for _ in range(3):
    region_generalized_patch_adversary()
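As a worked example of how `region_generalized_patch_adversary` sizes and places its region: the side length is scaled by the square root of `prop_modified` so that the region covers roughly that proportion of the latent's spatial area, and the top-left corner is drawn uniformly from every position that keeps the region inside the map. The pure-Python sketch below assumes a square latent (the notebook uses `latent.shape[-1]` for both axes); the helper `region_params` and the 16×16 size are hypothetical and chosen only for illustration.

```python
import math
import random


def region_params(latent_side, prop_modified, rng=random):
    """Hypothetical helper mirroring the notebook's region arithmetic.

    Assumes a square latent map of spatial size latent_side x latent_side.
    Returns the region side length and a random top-left corner.
    """
    # Scale the side by sqrt(prop) so the *area* is about prop of the map
    region_side = int(math.sqrt(prop_modified) * latent_side)
    # Valid corners are 0 .. latent_side - region_side (inclusive),
    # matching np.random.randint(latent_side - region_side + 1)
    n_positions = latent_side - region_side + 1
    reg_x = rng.randrange(n_positions)
    reg_y = rng.randrange(n_positions)
    return region_side, reg_x, reg_y


side, x, y = region_params(latent_side=16, prop_modified=0.25)
print(side, side * side / 16 ** 2)  # 8 0.25
assert 0 <= x <= 16 - side and 0 <= y <= 16 - side
```

Because `int()` truncates, the covered area is at most `prop_modified`; for example, `prop_modified=0.1` on a 16×16 map gives a 5×5 region (about 9.8% of the area).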

In [None]:
# Generate class-universal region and generalized patch attacks from pufferfish (397) to lionfish (396).
for _ in range(3):
    region_generalized_patch_adversary(source_class=397, target_class=396)
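The training loop in `region_generalized_patch_adversary` rounds `batch_size` down to a multiple of `sub_batch_size`, calls `loss.backward()` once per sub-batch without zeroing gradients, and only then takes an optimizer step. Each step therefore uses gradients accumulated over the whole batch while only one sub-batch sits in GPU memory at a time, which is why lowering `sub_batch_size` helps with out-of-memory errors. The pure-Python sketch below checks this on a toy 1-D least-squares loss. It assumes the per-sub-batch loss is a mean over its examples (not verified for `custom_loss_region_gen_patch_adv`); under that assumption the accumulated gradient equals the full-batch mean gradient times the number of sub-batches. `grad_mse` and `accumulated_grad` are hypothetical helpers.

```python
def grad_mse(w, xs, ys):
    """Gradient d/dw of mean((w*x - y)^2) over one (sub-)batch."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, batch_size, sub_batch_size):
    # Round down to a multiple of sub_batch_size, as the notebook does
    batch_size -= batch_size % sub_batch_size
    total = 0.0
    for start in range(0, batch_size, sub_batch_size):
        stop = start + sub_batch_size
        # Each "backward()" adds this sub-batch's gradient to the running total
        total += grad_mse(w, xs[start:stop], ys[start:stop])
    return total, batch_size


xs = [float(i) for i in range(10)]
ys = [2.0 * x + 1.0 for x in xs]
acc, used = accumulated_grad(0.5, xs, ys, batch_size=10, sub_batch_size=4)
full = grad_mse(0.5, xs[:used], ys[:used])
print(used)                                  # 8: two sub-batches of 4
print(abs(acc - (used // 4) * full) < 1e-9)  # True
```

The constant factor matters little with Adam, whose update is approximately invariant to rescaling the gradient (up to the `eps` term).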

In [None]:
# Make copy/paste attacks. Upload your own images to data/ to create new ones.
# The patch location is random, so each run differs.
for _ in range(3): 
    copy_paste_attack('bee.png', 'traffic_light.png')
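`copy_paste_attack` draws the patch's top-left corner independently along each axis from a band of the free space `IMAGE_SIDE - patch_side`, which keeps the patch away from the image borders when `prop_lower > 0` and `prop_upper < 1`. The pure-Python sketch below reproduces that arithmetic and pastes into a nested-list "image" standing in for the tensor. It assumes `IMAGE_SIDE = 256` (consistent with the `resize256` call, but the constant is not shown in this section); `corner_range` and `paste` are hypothetical helpers.

```python
import random


def corner_range(image_side, patch_side, prop_lower=0.1, prop_upper=0.9):
    """Hypothetical helper: inclusive bounds for the patch's top-left offset.

    Mirrors copy_paste_attack, where np.random.randint(lo, hi + 1) can
    return any integer from lo to hi inclusive.
    """
    diff = image_side - patch_side  # free space along each axis
    return int(diff * prop_lower), int(diff * prop_upper)


def paste(image, patch, top, left):
    """Hypothetical helper: copy a 2-D list `image` and write `patch` at (top, left)."""
    out = [row[:] for row in image]  # leave the source untouched, like copy.deepcopy
    for i, patch_row in enumerate(patch):
        out[top + i][left:left + len(patch_row)] = patch_row
    return out


lo, hi = corner_range(image_side=256, patch_side=85)
print(lo, hi)  # 17 153, so the patch's last row/column is at most 153 + 85 - 1 = 237

rng = random.Random(0)
top = rng.randint(lo, hi)  # random.randint is inclusive at both ends
print(lo <= top <= hi)  # True

small = paste([[0] * 6 for _ in range(6)], [[1, 1], [1, 1]], top=2, left=3)
print(sum(map(sum, small)))  # 4
```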