# BIGRoC: Boosting Image Generation via a Robust Classifier

--------------------------------------------------

## Boosting Guided Diffusion on ImageNet $128\times128$

This Colab notebook contains the code needed to experiment with our proposed algorithm for boosting guided diffusion on ImageNet 128x128 (Section 5.2 of the paper).

**How to use**

1.   Upload the notebook to Colab
2.   Make sure that Colab uses a GPU (Edit $\rightarrow$ Notebook settings $\rightarrow$ Hardware accelerator)

**Setup:**

This notebook mounts Google Drive and assumes that [this directory](https://drive.google.com/drive/folders/1yN6WjMmc-pi3zHylF-I1jri7I2nZpsGJ?usp=sharing) is located in the root folder of your Drive. To set this up, open the link, choose "Add shortcut to Drive", and pick "My Drive" as the location. After that, you are ready to go :)

If you would rather not mount your Google Drive, download the files and edit the relevant paths accordingly.
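
If you work from downloaded files, one way to handle the path edit is a small lookup that prefers the mounted Drive folder and falls back to a local copy. This is a minimal sketch: the helper name `find_script` and the local fallback directory are hypothetical; only the mount point `/content/gdrive` and the `MyDrive/models_and_scripts` folder come from this notebook.

```python
from pathlib import Path

# Folder used by this notebook when Drive is mounted at /content/gdrive
DRIVE_DIR = Path("/content/gdrive/MyDrive/models_and_scripts")


def find_script(name, search_dirs=(DRIVE_DIR, Path("."))):
    """Return the first existing path to `name` among `search_dirs` (hypothetical helper)."""
    for directory in search_dirs:
        candidate = Path(directory) / name
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(f"{name} not found in: {[str(d) for d in search_dirs]}")
```

For example, `find_script("evaluator.py")` returns the Drive copy when the shortcut exists, and otherwise looks in the current directory.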

**Credits**

This notebook uses the following packages:

* [robustness](https://github.com/MadryLab/robustness) - for the adversarially robust classifier
* [guided-diffusion](https://github.com/openai/guided-diffusion/tree/main/evaluations) - for the ImageNet quantitative evaluation and the 50K set of generated images + labels

# Setup & Installation

In [None]:
!pip install torch==1.8.2+cu102 torchvision==0.9.2+cu102 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
!wget http://andrewilyas.com/ImageNet.pt

In [None]:
!pip install robustness

In [None]:
from robustness import model_utils, datasets
import torch


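def create_dl_model(DATA='CIFAR', BATCH_SIZE=128, NUM_WORKERS=8):
    '''
    Load a ResNet-50 classifier from the checkpoint ./{DATA}.pt.
    :param DATA: Choices: ['CIFAR', 'ImageNet', 'RestrictedImageNet']
    :param BATCH_SIZE: batch size (not used by this function)
    :param NUM_WORKERS: number of data-loading workers (not used by this function)
    :return: the restored classifier (a torch module) in eval mode
    '''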

    # Load dataset
    dataset_function = getattr(datasets, DATA)
    dataset = dataset_function('data')
    # Load model
    model_kwargs = {
        'arch': 'resnet50',
        'dataset': dataset,
        'resume_path': f'./{DATA}.pt'
    }
    model, _ = model_utils.make_and_restore_model(**model_kwargs)
    classifier = model.model
    classifier.eval()
    return classifier

adv_model = create_dl_model(DATA='ImageNet')

In [2]:
import torch
from torch import nn


class AttackerStep:
    '''
    Generic class for attacker steps, under perturbation constraints
    specified by an "origin input" and a perturbation magnitude.
    Must implement project, step, and random_perturb
    '''

    def __init__(self, orig_input, eps, step_size, use_grad=True):
        '''
        Initialize the attacker step with a given perturbation magnitude.
        Args:
            eps (float): the perturbation magnitude
            orig_input (ch.tensor): the original input
        '''
        self.orig_input = orig_input
        self.eps = eps
        self.step_size = step_size
        self.use_grad = use_grad

    def project(self, x):
        '''
        Given an input x, project it back into the feasible set
        Args:
            ch.tensor x : the input to project back into the feasible set.
        Returns:
            A `ch.tensor` that is the input projected back into
            the feasible set, that is,
        .. math:: \min_{x' \in S} \|x' - x\|_2
        '''
        raise NotImplementedError

    def step(self, x, g):
        '''
        Given a gradient, make the appropriate step according to the
        perturbation constraint (e.g. dual norm maximization for :math:`\ell_p`
        norms).
        Parameters:
            g (ch.tensor): the raw gradient
        Returns:
            The new input, a ch.tensor for the next step.
        '''
        raise NotImplementedError

    def random_perturb(self, x):
        '''
        Given a starting input, take a random step within the feasible set
        '''
        raise NotImplementedError

    def to_image(self, x):
        '''
        Given an input (which may be in an alternative parameterization),
        convert it to a valid image (this is implemented as the identity
        function by default as most of the time we use the pixel
        parameterization, but for alternative parameterizations this function
        must be overridden).
        '''
        return x


# simple Module to normalize an image
class Normalize(nn.Module):
    def __init__(self, mean, std):
        super(Normalize, self).__init__()
        self.mean = torch.Tensor(mean)
        self.std = torch.Tensor(std)
    def forward(self, x):
        return (x - self.mean.type_as(x)[None, :, None, None]) / self.std.type_as(x)[None, :, None, None]

norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])



# L2 threat model
class L2Step(AttackerStep):
    r"""
    Attack step for the :math:`\ell_2` threat model. Given :math:`x_0`
    and :math:`\epsilon`, the constraint set is given by:
    .. math:: S = \{x | \|x - x_0\|_2 \leq \epsilon\}
    """

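    def project(self, x):
        """
        Project x onto the L2 ball of radius eps around the original input,
        then clamp the result to the valid pixel range [0, 1].
        """
        # If no center was given, the first input seen here becomes the center
        if self.orig_input is None: self.orig_input = x.detach()
        self.orig_input = self.orig_input.cuda()
        diff = x - self.orig_input
        # Rescale each sample's perturbation so its L2 norm is at most eps
        diff = diff.renorm(p=2, dim=0, maxnorm=self.eps)
        return torch.clamp(self.orig_input + diff, 0, 1)

    def step(self, x, g):
        """
        Move x by step_size along the per-sample L2-normalized gradient g.
        """
        l = len(x.shape) - 1
        g_norm = torch.norm(g.reshape(g.shape[0], -1), dim=1).view(-1, *([1] * l))
        scaled_g = g / (g_norm + 1e-10)
        return x + scaled_g * self.step_size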

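def targeted_pgd_l2(model, X, y, num_iter, eps, step_size):
    # Targeted L2 PGD: move X toward class y by decreasing the cross-entropy loss.
    # input images are in range [0,1]
    # With orig_input=None, the ball center is set on the first project() call,
    # i.e. to X after the first gradient step.
    stepper = L2Step(eps=eps, orig_input=None, step_size=step_size)
    for t in range(num_iter):
        X = X.clone().detach().requires_grad_(True).cuda()
        loss = nn.CrossEntropyLoss(reduction='none')(model(norm(X)), y)
        loss = torch.mean(loss)
        # Gradient of the negated loss: stepping along it lowers the loss
        grad, = torch.autograd.grad(-1 * loss, [X])
        X = stepper.step(X, grad)
        X = stepper.project(X)
    return X.detach()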

Download reference batch

In [None]:
!wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz

Download the guided diffusion generated images + ground-truth labels

In [None]:
!wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz

In [3]:
import numpy as np

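# mmap_mode='r' is passed so the 50K images are not read eagerly
x = np.load("admnet_guided_imagenet128.npz", mmap_mode='r')
gen_imgs = x['arr_0']   # generated images, channels-last, values in [0, 255]
gt_labels = x['arr_1']  # class labels of the generated images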
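
The keys `arr_0` and `arr_1` are the default names that `np.savez` assigns to arrays passed positionally, which is also why the boosted images saved later in this notebook are read back through `arr_0`. A small sketch with synthetic data (the file name and array sizes here are made up for illustration):

```python
import os
import tempfile

import numpy as np

# Tiny stand-ins for the real arrays: 4 RGB images of size 8x8 and their labels
images = np.zeros((4, 8, 8, 3), dtype=np.uint8)
labels = np.arange(4, dtype=np.int64)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_batch")   # np.savez appends ".npz"
    np.savez(path, images, labels)          # positional arrays -> arr_0, arr_1
    with np.load(path + ".npz") as data:
        keys = sorted(data.files)
        loaded_images = data["arr_0"]
        loaded_labels = data["arr_1"]
```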

# Guided Diffusion baseline quantitative results

ADM-G quantitative evaluation

In [None]:
!python ./gdrive/MyDrive/models_and_scripts/evaluator.py ./VIRTUAL_imagenet128_labeled.npz ./admnet_guided_imagenet128.npz

downloading InceptionV3 model...
11679it [00:02, 3993.89it/s]
2022-01-27 07:27:42.287057: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
warming up TensorFlow...
100% 1/1 [00:09<00:00,  9.08s/it]
computing reference batch activations...
100% 157/157 [00:29<00:00,  5.29it/s]
computing/reading reference batch statistics...
computing sample batch activations...
100% 782/782 [02:23<00:00,  5.46it/s]
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 141.37184143066406
FID: 2.9739931366061114
sFID: 5.09352838087807

# BIGRoC Application

In [4]:
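# BIGRoC args
epsilon, steps = 1.5, 7  # L2 radius and number of PGD iterations
step_size = 1.5 * epsilon / steps
use_pl = True  # True: pseudo-labels from the robust classifier; False: ground-truth labels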
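
With these values, each step has an L2 length of about 0.32, and the seven steps together can cover at most 2.25, i.e. 1.5 times the radius $\epsilon$. The arithmetic, spelled out as a quick check (variable names other than `epsilon`, `steps` and `step_size` are illustrative):

```python
epsilon, steps = 1.5, 7
step_size = 1.5 * epsilon / steps     # L2 length of a single PGD step
max_travel = steps * step_size        # total path length if every step is aligned
ratio = max_travel / epsilon          # path length relative to the radius

print(f"step_size={step_size:.4f}, max_travel={max_travel:.2f}, ratio={ratio:.2f}")
# prints: step_size=0.3214, max_travel=2.25, ratio=1.50
```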

In [5]:
from tqdm.notebook import tqdm

torch.manual_seed(1234)
boosted_gen_imgs = []
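for i in tqdm(range(250)):
  # First half of the generated set: images 0 to 24999, in batches of 100
  x_batch = gen_imgs[i * 100: (i + 1) * 100]
  # Channels-last values in [0, 255] -> channels-first floats in [0, 1]
  x_batch_0_1 = torch.tensor(x_batch / 255.).cuda().permute(0,3,1,2).float()
  with torch.no_grad():
    if use_pl:
      # Pseudo-labels: the robust classifier's own predictions
      labels = torch.argmax(adv_model(norm(x_batch_0_1)), dim=1)
    else:
      # Ground-truth labels that come with the generated images
      labels = torch.tensor(gt_labels[i * 100: (i + 1) * 100]).cuda().long()
  b_imgs = targeted_pgd_l2(model=adv_model, X=x_batch_0_1.data, y=labels.long(), num_iter=steps, eps=epsilon,
                                    step_size=step_size).detach().cpu()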
  boosted_gen_imgs.append(b_imgs)

boosted_gen_imgs = torch.cat(boosted_gen_imgs)
boosted_gen_imgs = (boosted_gen_imgs * 255.).int()
boosted_gen_imgs = boosted_gen_imgs.detach().cpu().permute(0,2,3,1).numpy()
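
The conversion between the stored images (channels-last, values 0 to 255) and the float tensors fed to the classifier (channels-first, values in [0, 1]) can be sketched in NumPy as below. One difference from the cell above is flagged as an assumption: this sketch rounds to the nearest integer on the way back, whereas `.int()` truncates toward zero. The helper names are hypothetical.

```python
import numpy as np


def to_float_nchw(imgs_uint8):
    # (N, H, W, C) uint8 in [0, 255] -> (N, C, H, W) float32 in [0, 1]
    return (imgs_uint8.astype(np.float32) / 255.0).transpose(0, 3, 1, 2)


def to_uint8_nhwc(imgs_float):
    # (N, C, H, W) float in [0, 1] -> (N, H, W, C) uint8 in [0, 255]
    # Assumption: round to nearest instead of truncating
    scaled = np.rint(np.clip(imgs_float, 0.0, 1.0) * 255.0)
    return scaled.astype(np.uint8).transpose(0, 2, 3, 1)


rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(2, 128, 128, 3)).astype(np.uint8)
restored = to_uint8_nhwc(to_float_nchw(batch))
```

With rounding, a round trip through both helpers returns the original pixels unchanged.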

In [6]:
if use_pl:
  np.savez("boosted_diff_1_eps_1_5", boosted_gen_imgs)
else:
  np.savez("boosted_diff_1_eps_1_5_labeled", boosted_gen_imgs)

In [7]:
from tqdm.notebook import tqdm

torch.manual_seed(1234)
boosted_gen_imgs = []
for i in tqdm(range(250)):
  # Second half of the generated set: images 25000 to 49999, in batches of 100
  x_batch = gen_imgs[(25000 + i*100):(25000 + (i + 1) * 100)]
  x_batch_0_1 = torch.tensor(x_batch / 255.).cuda().permute(0,3,1,2).float()
  with torch.no_grad():
    if use_pl:
      labels = torch.argmax(adv_model(norm(x_batch_0_1)), dim=1)
    else:
      labels = torch.tensor(gt_labels[(25000 + i*100):(25000 + (i + 1) * 100)]).cuda().long()
  boosted_gen_imgs.append(targeted_pgd_l2(model=adv_model, X=x_batch_0_1.data, y=labels.long(), num_iter=steps, eps=epsilon,
                                    step_size=step_size).detach().cpu())

boosted_gen_imgs = torch.cat(boosted_gen_imgs)
boosted_gen_imgs = (boosted_gen_imgs * 255.).int()
boosted_gen_imgs = boosted_gen_imgs.detach().cpu().permute(0,2,3,1).numpy()

In [8]:
if use_pl:
  np.savez("boosted_diff_2_eps_1_5", boosted_gen_imgs)
else:
  np.savez("boosted_diff_2_eps_1_5_labeled", boosted_gen_imgs)

Gather the boosted images into a single npz file for evaluation

In [9]:
import numpy as np
from tqdm.notebook import tqdm

# Preallocate the full 50K-image array (uint8, channels-last)
arr = np.zeros(shape=(50000, 128, 128, 3), dtype=np.uint8)

if use_pl:
  x = np.load("boosted_diff_1_eps_1_5.npz")['arr_0']
  arr[:25000] = x
  x = np.load("boosted_diff_2_eps_1_5.npz")['arr_0']
  arr[25000:] = x
else:
  x = np.load("boosted_diff_1_eps_1_5_labeled.npz")['arr_0']
  arr[:25000] = x
  x = np.load("boosted_diff_2_eps_1_5_labeled.npz")['arr_0']
  arr[25000:] = x

In [None]:
arr.shape, arr.min(), arr.max()

((50000, 128, 128, 3), 0, 255)

In [10]:
if use_pl:
  np.savez("boosted_diff_eps_1_5", arr)
else:
  np.savez("boosted_diff_eps_1_5_labeled", arr)

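$\epsilon = 1.5$ BIGRoC$_{GT}$ (ground-truth labels; the evaluated file is produced by running the cells above with `use_pl = False`)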

In [None]:
!python ./gdrive/MyDrive/models_and_scripts/evaluator.py ./VIRTUAL_imagenet128_labeled.npz ./boosted_diff_eps_1_5_labeled.npz

2022-01-27 08:00:46.832763: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
warming up TensorFlow...
100% 1/1 [00:02<00:00,  2.44s/it]
computing reference batch activations...
100% 157/157 [00:29<00:00,  5.28it/s]
computing/reading reference batch statistics...
computing sample batch activations...
100% 782/782 [02:29<00:00,  5.24it/s]
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 169.73379516601562
FID: 2.533804513797861
sFID: 4.886204958273538

$\epsilon = 1.5$ BIGRoC$_{PL}$ (pseudo-labels; `use_pl = True`)

In [11]:
!python ./gdrive/MyDrive/models_and_scripts/evaluator.py ./VIRTUAL_imagenet128_labeled.npz ./boosted_diff_eps_1_5.npz

2022-01-27 09:57:25.725015: W tensorflow/core/framework/op_def_util.cc:371] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
warming up TensorFlow...
100% 1/1 [00:09<00:00,  9.03s/it]
computing reference batch activations...
100% 157/157 [00:30<00:00,  5.15it/s]
computing/reading reference batch statistics...
computing sample batch activations...
100% 782/782 [02:26<00:00,  5.32it/s]
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 150.43173217773438
FID: 2.7753071639148743
sFID: 4.971221621778113
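
For a side-by-side view, the three evaluator runs above can be collected into a small table of relative changes against the ADM-G baseline. The numbers are copied from the outputs in this notebook (rounded to four decimals); the helper `rel_change` is illustrative only.

```python
# Metrics reported by evaluator.py in this notebook, rounded to four decimals
results = {
    "ADM-G (baseline)": {"IS": 141.3718, "FID": 2.9740, "sFID": 5.0935},
    "BIGRoC_GT": {"IS": 169.7338, "FID": 2.5338, "sFID": 4.8862},
    "BIGRoC_PL": {"IS": 150.4317, "FID": 2.7753, "sFID": 4.9712},
}


def rel_change(new, old):
    # Relative change in percent; positive means `new` is larger than `old`
    return 100.0 * (new - old) / old


baseline = results["ADM-G (baseline)"]
for name, metrics in results.items():
    if metrics is baseline:
        continue
    changes = {k: rel_change(v, baseline[k]) for k, v in metrics.items()}
    print(name, {k: f"{c:+.1f}%" for k, c in changes.items()})
```

Higher is better for Inception Score, while lower is better for FID and sFID, so both variants improve on the baseline in all three metrics.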