# BIGRoC: Boosting Image Generation via a Robust Classifier

--------------------------------------------------

## Boosting Guided Diffusion on ImageNet $256\times256$

This Colab notebook contains the code needed to experiment with our proposed algorithm for boosting Guided Diffusion on ImageNet 256x256 (Section 5.2).

**How to use**

1.   Upload the notebook to Colab
2.   Make sure that Colab uses a GPU (Edit $\rightarrow$ Notebook settings $\rightarrow$ Hardware accelerator)

**Setup:**

This notebook mounts Google Drive and assumes that [this directory](https://drive.google.com/drive/folders/1yN6WjMmc-pi3zHylF-I1jri7I2nZpsGJ?usp=sharing) is located in the root folder of your Google Drive. To set this up, open the Google Drive link, choose "Add shortcut to Drive", and pick "My Drive" as the location. After that, you are ready to go :)

If you would rather not mount your Google Drive, download the files and edit the relevant paths accordingly.

**Credits**

This notebook uses the following packages:

* [robustness](https://github.com/MadryLab/robustness) - for the adversarially robust classifier
* [guided-diffusion](https://github.com/openai/guided-diffusion/tree/main/evaluations) - for the ImageNet quantitative evaluation and the 50K set of generated images + labels

# Setup & Installation

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
!pip install torch==1.8.2+cu102 torchvision==0.9.2+cu102 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html

In [None]:
!wget http://andrewilyas.com/ImageNet.pt

In [None]:
!pip install robustness

In [None]:
from robustness import model_utils, datasets
import torch


def create_dl_model(DATA='CIFAR', BATCH_SIZE=128, NUM_WORKERS=8):
    '''
    :param DATA: Choices: ['CIFAR', 'ImageNet', 'RestrictedImageNet']
    :param BATCH_SIZE: batch size (currently unused)
    :param NUM_WORKERS: number of data-loading workers (currently unused)
    :return: the restored classifier, set to eval mode
    '''

    # Load dataset
    dataset_function = getattr(datasets, DATA)
    dataset = dataset_function('data')
    # Load model
    model_kwargs = {
        'arch': 'resnet50',
        'dataset': dataset,
        'resume_path': f'./{DATA}.pt'
    }
    model, _ = model_utils.make_and_restore_model(**model_kwargs)
    classifier = model.model
    classifier.eval()
    return classifier

adv_model = create_dl_model(DATA='ImageNet')

In [None]:
import torch
from torch import nn


class AttackerStep:
    '''
    Generic class for attacker steps, under perturbation constraints
    specified by an "origin input" and a perturbation magnitude.
    Must implement project, step, and random_perturb
    '''

    def __init__(self, orig_input, eps, step_size, use_grad=True):
        '''
        Initialize the attacker step with a given perturbation magnitude.
        Args:
            eps (float): the perturbation magnitude
            orig_input (ch.tensor): the original input
        '''
        self.orig_input = orig_input
        self.eps = eps
        self.step_size = step_size
        self.use_grad = use_grad

    def project(self, x):
        '''
        Given an input x, project it back into the feasible set
        Args:
            ch.tensor x : the input to project back into the feasible set.
        Returns:
            A `ch.tensor` that is the input projected back into
            the feasible set, that is,
        .. math:: \min_{x' \in S} \|x' - x\|_2
        '''
        raise NotImplementedError

    def step(self, x, g):
        '''
        Given a gradient, make the appropriate step according to the
        perturbation constraint (e.g. dual norm maximization for :math:`\ell_p`
        norms).
        Parameters:
            g (ch.tensor): the raw gradient
        Returns:
            The new input, a ch.tensor for the next step.
        '''
        raise NotImplementedError

    def random_perturb(self, x):
        '''
        Given a starting input, take a random step within the feasible set
        '''
        raise NotImplementedError

    def to_image(self, x):
        '''
        Given an input (which may be in an alternative parameterization),
        convert it to a valid image (this is implemented as the identity
        function by default as most of the time we use the pixel
        parameterization, but for alternative parameterizations this function
        must be overridden).
        '''
        return x


from torch import nn

# simple Module to normalize an image
class Normalize(nn.Module):
    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        self.mean = torch.Tensor(mean)
        self.std = torch.Tensor(std)
    def forward(self, x):
        return (x - self.mean.type_as(x)[None, :, None, None]) / self.std.type_as(x)[None, :, None, None]

norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])



# L2 threat model
class L2Step(AttackerStep):
    """
    Attack step for :math:`\ell_2` threat model. Given :math:`x_0`
    and :math:`\epsilon`, the constraint set is given by:
    .. math:: S = \{x | \|x - x_0\|_2 \leq \epsilon\}
    """

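    def project(self, x):
        """
        Rescale the perturbation x - orig_input so that its per-sample L2 norm
        is at most eps, then clamp the result to the pixel range [0, 1].
        If orig_input is None, it is set to x on the first call.
        """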
        if self.orig_input is None: self.orig_input = x.detach()
        self.orig_input = self.orig_input.cuda()
        diff = x - self.orig_input
        diff = diff.renorm(p=2, dim=0, maxnorm=self.eps)
        return torch.clamp(self.orig_input + diff, 0, 1)

    def step(self, x, g):
        """
        Move x by step_size along the gradient g, after normalizing g to
        unit L2 norm per sample.
        """
        l = len(x.shape) - 1
        g_norm = torch.norm(g.reshape(g.shape[0], -1), dim=1).view(-1, *([1] * l))
        scaled_g = g / (g_norm + 1e-10)
        return x + scaled_g * self.step_size

def targeted_pgd_l2(model, X, y, num_iter, eps, step_size):
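    # Input images are in range [0, 1]. Since orig_input=None here, the
    # projection ball is centered on X as it is at the first project() call.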
    steper = L2Step(eps=eps, orig_input=None, step_size=step_size)
    for t in range(num_iter):
        X = X.clone().detach().requires_grad_(True).cuda()
        loss = nn.CrossEntropyLoss(reduction='none')(model(norm(X)), y)
        loss = torch.mean(loss)
        grad, = torch.autograd.grad(-1 * loss, [X])
        X = steper.step(X, grad)
        X = steper.project(X)
    return X.detach()
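
The two geometric operations in `L2Step` can be checked by hand on a single small vector. The following is a framework-free sketch, not the notebook's implementation: the helper names `l2_norm`, `l2_step`, and `l2_project` are hypothetical, and plain Python lists stand in for one flattened image.

```python
import math


def l2_norm(v):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(c * c for c in v))


def l2_step(x, g, step_size):
    """Move x by step_size along the unit-norm direction of g."""
    n = l2_norm(g) + 1e-10  # same small constant as in L2Step.step
    return [xi + step_size * gi / n for xi, gi in zip(x, g)]


def l2_project(x, center, eps):
    """Shrink x - center to L2 norm <= eps, then clamp to [0, 1]."""
    diff = [xi - ci for xi, ci in zip(x, center)]
    n = l2_norm(diff)
    scale = min(1.0, eps / n) if n > 0 else 1.0
    return [min(1.0, max(0.0, ci + scale * d)) for ci, d in zip(center, diff)]


center = [0.5, 0.5, 0.5, 0.5]
g = [3.0, 0.0, 4.0, 0.0]                    # L2 norm is 5
x1 = l2_step(center, g, step_size=0.2)      # moves 0.2 along [0.6, 0, 0.8, 0]
x2 = l2_project(x1, center, eps=0.1)        # perturbation shrunk to norm 0.1
print(x1, x2)
```

Because the clamp to [0, 1] is applied after the rescaling, the final perturbation can end up with a norm smaller than `eps` when pixels saturate.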

In [None]:
!wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz

In [None]:
import numpy as np

x = np.load("admnet_guided_imagenet256.npz", mmap_mode='r') 
arr_orig = x['arr_0']
gt_labels = x['arr_1']

In [None]:
arr_orig.shape, arr_orig.min(), arr_orig.max()

Download Reference Batch

In [None]:
!wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz

--2022-01-16 16:13:29--  https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz
Resolving openaipublic.blob.core.windows.net (openaipublic.blob.core.windows.net)... 20.60.241.33
Connecting to openaipublic.blob.core.windows.net (openaipublic.blob.core.windows.net)|20.60.241.33|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2037122530 (1.9G) [application/octet-stream]
Saving to: ‘VIRTUAL_imagenet256_labeled.npz’


2022-01-16 16:14:33 (30.3 MB/s) - ‘VIRTUAL_imagenet256_labeled.npz’ saved [2037122530/2037122530]



# Guided Diffusion baseline quantitative results

ADM-G eval

In [None]:
!python ./gdrive/MyDrive/models_and_scripts/evaluator.py ./VIRTUAL_imagenet256_labeled.npz ./admnet_guided_imagenet256.npz

downloading InceptionV3 model...
11679it [00:05, 2174.31it/s]
2022-01-13 16:06:56.476211: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
warming up TensorFlow...
100% 1/1 [00:08<00:00,  8.25s/it]
computing reference batch activations...
100% 157/157 [00:32<00:00,  4.77it/s]
computing/reading reference batch statistics...
computing sample batch activations...
100% 782/782 [02:45<00:00,  4.73it/s]
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 186.69859313964844
FID: 4.586654813293535
sFID: 5.24677393057334

# BIGRoC Application

In [None]:
# BIGRoC args
epsilon, steps = 3, 7
step_size = (epsilon * 1.5) / steps
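
As a quick sanity check on these arguments, the arithmetic below works out the step size and how it relates to the ball radius. It is a sketch only; `total_path` and `steps_to_boundary` are names introduced here for illustration.

```python
import math

epsilon, steps = 3, 7
step_size = (epsilon * 1.5) / steps

# Total distance travelled if every step pointed in the same direction.
total_path = step_size * steps

# Number of same-direction steps needed to reach the radius-epsilon boundary.
steps_to_boundary = math.ceil(epsilon / step_size)

print(f"step_size         = {step_size:.4f}")
print(f"total path length = {total_path:.2f}")
print(f"steps to boundary = {steps_to_boundary}")
```

With these values the step size is about 0.643, so the 7 steps can cover up to 1.5 times the radius and the projection can become active before the last iteration.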

In [None]:
from tqdm.notebook import tqdm

torch.manual_seed(1234)
for k in range(50):
  boosted_gen_imgs = []
  for i in tqdm(range(100)):
    x_batch = arr_orig[(1000*k) + i * 10: (1000*k) + (i + 1) * 10]
    x_batch_0_1 = torch.tensor(x_batch / 255.).cuda().permute(0,3,1,2).float()
    with torch.no_grad():
      # labels = torch.argmax(adv_model(norm(x_batch_0_1)), dim=1)
      labels = torch.tensor(gt_labels[(1000*k) + i * 10: (1000*k) + (i + 1) * 10]).cuda().long()
    b_imgs = targeted_pgd_l2(model=adv_model, X=x_batch_0_1.data, y=labels.long(), num_iter=steps, eps=epsilon,
                                      step_size=step_size).detach().cpu()
    boosted_gen_imgs.append(b_imgs)

  boosted_gen_imgs = torch.cat(boosted_gen_imgs)
  boosted_gen_imgs = (boosted_gen_imgs * 255.).int()
  boosted_gen_imgs = boosted_gen_imgs.detach().cpu().permute(0,2,3,1).numpy().astype(np.uint8)
  np.savez(f"./boosted_diff_256_{k}_eps_3_steps_{steps}_labeled", boosted_gen_imgs)

  0%|          | 0/100 [00:00<?, ?it/s]
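
The nested loop above processes the 50K images as 50 chunks of 100 batches with 10 images each. The sketch below reproduces only the index arithmetic, without any model, to confirm that the slices cover every image exactly once; `batch_slices` is a hypothetical helper, not part of the notebook.

```python
def batch_slices(num_chunks=50, batches_per_chunk=100, batch_size=10):
    """Yield (chunk, start, stop) using the same indexing as the loop above."""
    chunk_size = batches_per_chunk * batch_size  # 1000 images per saved file
    for k in range(num_chunks):
        for i in range(batches_per_chunk):
            start = chunk_size * k + i * batch_size
            yield k, start, start + batch_size


slices = list(batch_slices())
covered = [idx for _, start, stop in slices for idx in range(start, stop)]

# Every index from 0 to 49999 appears once, in order.
assert covered == list(range(50000))
print(len(slices), slices[0], slices[-1])
```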

Load the boosted images and merge the 50 chunks into a single array

In [None]:
import numpy as np
from tqdm.notebook import tqdm

arr = np.zeros(shape=(50000, 256, 256, 3), dtype=np.uint8)

for f in tqdm(range(50)):
  x = np.load(f"./boosted_diff_256_{f}_eps_3_steps_{steps}_labeled.npz")['arr_0']
  arr[1000 * f: 1000 * (f+1)] = x
  x = None

np.savez(f"boosted_diff_256_eps_3_steps_{steps}_labeled", arr)

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
arr.shape, arr.min(), arr.max()

((50000, 256, 256, 3), 0, 255)

$\epsilon = 3$, BIGRoC$_{GT}$

In [None]:
!python ./gdrive/MyDrive/models_and_scripts/evaluator.py ./VIRTUAL_imagenet256_labeled.npz ./boosted_diff_256_eps_3_steps_7_labeled.npz

downloading InceptionV3 model...
11679it [00:04, 2878.08it/s]
2022-01-16 18:20:17.612924: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
warming up TensorFlow...
100% 1/1 [00:09<00:00,  9.18s/it]
computing reference batch activations...
100% 157/157 [00:35<00:00,  4.42it/s]
computing/reading reference batch statistics...
computing sample batch activations...
100% 782/782 [02:51<00:00,  4.56it/s]
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 259.0693359375
FID: 3.984601585595442
sFID: 5.0033157549130465

$\epsilon = 3$, BIGRoC$_{PL}$

In [None]:
!python ./gdrive/MyDrive/models_and_scripts/evaluator.py ./VIRTUAL_imagenet256_labeled.npz ./boosted_diff_256_eps_3_steps_7.npz

2022-01-15 06:53:53.197560: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
warming up TensorFlow...
100% 1/1 [00:07<00:00,  7.88s/it]
computing reference batch activations...
100% 157/157 [00:34<00:00,  4.60it/s]
computing/reading reference batch statistics...
computing sample batch activations...
100% 782/782 [02:45<00:00,  4.72it/s]
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 239.62173461914062
FID: 4.072115726267782
sFID: 5.021853337463654

Additional Results

$\epsilon = 1.5$, BIGRoC$_{PL}$

In [None]:
!python ./gdrive/MyDrive/models_and_scripts/evaluator.py ./VIRTUAL_imagenet256_labeled.npz ./boosted_diff_256_eps_1_half_steps_7.npz

downloading InceptionV3 model...
11679it [00:04, 2779.05it/s]
2022-01-15 06:36:24.564140: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
warming up TensorFlow...
100% 1/1 [00:08<00:00,  8.52s/it]
computing reference batch activations...
100% 157/157 [00:32<00:00,  4.80it/s]
computing/reading reference batch statistics...
computing sample batch activations...
100% 782/782 [02:44<00:00,  4.77it/s]
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 214.6321563720703
FID: 4.102903153513239
sFID: 5.112735536421496

$\epsilon = 5$, BIGRoC$_{PL}$

In [None]:
!python ./gdrive/MyDrive/models_and_scripts/evaluator.py ./VIRTUAL_imagenet256_labeled.npz ./boosted_diff_256_eps_5_steps_7.npz

2022-01-14 17:59:40.101140: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
warming up TensorFlow...
100% 1/1 [00:08<00:00,  8.78s/it]
computing reference batch activations...
100% 157/157 [00:35<00:00,  4.44it/s]
computing/reading reference batch statistics...
computing sample batch activations...
100% 782/782 [02:48<00:00,  4.64it/s]
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 275.64508056640625
FID: 4.450068382755262
sFID: 4.9191934585393255
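
To compare the runs side by side, the numbers printed by the evaluator in the cells above can be collected and differenced against the ADM-G baseline. This is a small bookkeeping sketch: the values are copied from the logs in this notebook, and the dictionary layout and the `deltas` helper are assumptions made here for illustration. The run labels follow the cell headings, where GT presumably refers to the ground-truth labels and PL to labels predicted by the classifier (the commented-out line in the boosting loop).

```python
baseline = {"IS": 186.69859313964844, "FID": 4.586654813293535, "sFID": 5.24677393057334}

runs = {
    "eps=3, BIGRoC_GT": {"IS": 259.0693359375, "FID": 3.984601585595442, "sFID": 5.0033157549130465},
    "eps=3, BIGRoC_PL": {"IS": 239.62173461914062, "FID": 4.072115726267782, "sFID": 5.021853337463654},
    "eps=1.5, BIGRoC_PL": {"IS": 214.6321563720703, "FID": 4.102903153513239, "sFID": 5.112735536421496},
    "eps=5, BIGRoC_PL": {"IS": 275.64508056640625, "FID": 4.450068382755262, "sFID": 4.9191934585393255},
}


def deltas(run):
    """Difference of each metric from the baseline (run minus baseline)."""
    return {metric: run[metric] - baseline[metric] for metric in baseline}


for name, run in runs.items():
    d = deltas(run)
    print(f"{name:20s} IS {d['IS']:+8.2f}  FID {d['FID']:+6.3f}  sFID {d['sFID']:+6.3f}")
```

Higher is better for Inception Score, and lower is better for FID and sFID.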