# Preparation

## Download & Extract Adaptive Exposures

In [None]:
import gdown
# Shared Google Drive folder to fetch (the "adaptive exposures" named in the section heading).
url = "https://drive.google.com/drive/folders/1zXEwDiB1EUhG4-n4x3LYhbffOKYCbiEc?usp=sharing"
# Download the whole folder quietly, without using browser cookies.
# NOTE(review): presumably the folder lands in the current working directory — confirm
# against gdown's default output location.
gdown.download_folder(url, quiet=True, use_cookies=False)

## Download OOD Datasets

In [None]:
%%capture
!mkdir data
# Places
!wget https://dl.dropboxusercontent.com/s/3pwqsyv33f6if3z/val_256.tar
!tar -xf val_256.tar -C ./data
%cd data
!wget https://dl.dropboxusercontent.com/s/gaf1ygpdnkhzyjo/places365_val.txt
!wget https://dl.dropboxusercontent.com/s/enr71zpolzi1xzm/categories_places365.txt
%cd ..
# COIL
!mkdir data/coil
!wget http://www.cs.columbia.edu/CAVE/databases/SLAM_coil-20_coil-100/coil-100/coil-100.zip
!unzip coil-100.zip -d ./data
!mkdir data/coil
!cp -r data/coil-100 data/coil
# LSUN
!wget https://www.dropbox.com/s/moqh2wh8696c3yl/LSUN_resize.tar.gz
!tar -xf LSUN_resize.tar.gz -C ./data
%cd data
# iSUN
!wget https://www.dropbox.com/s/ssz7qxfqae0cca5/iSUN.tar.gz
!tar -xf iSUN.tar.gz
# Birds
!wget https://www.dropbox.com/s/yc6kz6ld56q836c/images.tgz
!tar -xf images.tgz
# Flowers
!wget https://dl.dropboxusercontent.com/s/hbt8e7wjiplryoo/102flowers.tgz
!tar -xf 102flowers.tgz
!mv jpg flowers
!mkdir flowers/fld
import os
import shutil

# Source and destination folder paths
src_folder = './flowers'
dst_folder = './flowers/fld'

# Copy all files from the source folder to the destination folder
for filename in os.listdir(src_folder):
    # Construct the full file paths
    src_file = os.path.join(src_folder, filename)
    dst_file = os.path.join(dst_folder, filename)

    # Copy the file to the destination folder if it's a file (not a folder)
    if os.path.isfile(src_file):
        shutil.copy(src_file, dst_file)
# Tiny Image Net
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip tiny-imagenet-200.zip
%cd ..

# Use CUDA

In [None]:
# Run the NVIDIA status tool so the cell output shows whether a GPU is attached to this runtime.
!nvidia-smi

In [None]:
import torch

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Configurations

In [None]:
from torchvision import transforms

# In-distribution dataset: index 0 selects 'cifar10', index 1 would select 'cifar100'.
in_dataset = ['cifar10', 'cifar100'][0]

# Optimisation settings.
# NOTE(review): presumably consumed by the training cells further down — not visible here.
epochs = 10
optim = "adam"

lr = 0.001

# Names of the datasets that can be chosen as the out-of-distribution set.
# NOTE(review): the variable name is misspelt ("avialable"); it is left as is
# because other cells may reference it.
avialable_datasets = ['cifar10', 'cifar100', 'mnist', 'places', 'coil', 'LSUN', 'iSUN', 'flowers', 'birds', 'tiny_imagenet']
# Index 1 selects 'cifar100' as the out-of-distribution dataset.
out_dataset = avialable_datasets[1]

# Adversarial attack budget: perturbation bound, number of steps, and a
# step size of 2.5 * eps / steps.
attack_eps = 8/255
attack_steps = 10
attack_alpha = 2.5 * attack_eps / attack_steps
# Number of in-distribution classes for the selected dataset.
# NOTE(review): 'cifar100' maps to 20 rather than 100 — presumably the
# 20 CIFAR-100 superclasses are used; confirm against the data-loading code.
num_classes = {
    'cifar10': 10,
    'cifar100': 20
}[in_dataset]
# Alias of num_classes.
all_num_classes = num_classes

batch_size = 128

# Model

## Wide Resnet

In [None]:
import math
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    """Residual unit: BN-ReLU-conv3x3, BN-ReLU-(dropout)-conv3x3, plus a shortcut.

    Arguments:
        in_planes (int): channels of the incoming feature map.
        out_planes (int): channels produced by both 3x3 convolutions.
        stride (int): stride of the first 3x3 convolution and of the 1x1 shortcut.
        dropRate (float): dropout probability applied before the second convolution. (Default: 0.0)
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = in_planes == out_planes
        # A 1x1 projection is only created when the channel count changes;
        # otherwise the input itself is used as the shortcut.
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Apply the residual unit to `x` and return the summed result."""
        activated = self.relu1(self.bn1(x))
        if not self.equalInOut:
            # With a projection shortcut, both branches start from the activated input.
            x = activated
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = torch.nn.functional.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        if self.equalInOut:
            shortcut = x
        else:
            shortcut = self.convShortcut(x)
        return torch.add(shortcut, out)

class NetworkBlock(nn.Module):
    """A stack of `nb_layers` residual blocks applied one after another.

    Arguments:
        nb_layers (int or float): number of blocks; converted with `int()`.
        in_planes (int): input channels of the first block.
        out_planes (int): output channels of every block.
        block (callable): block factory called as `block(in_planes, out_planes, stride, dropRate)`.
        stride (int): stride of the first block; all later blocks use stride 1.
        dropRate (float): dropout rate forwarded to every block. (Default: 0.0)
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super().__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Build the `nn.Sequential` holding the blocks."""
        layers = []
        for i in range(int(nb_layers)):
            is_first = i == 0
            # Conditional expressions are used instead of the `cond and a or b`
            # idiom, which silently substitutes the fallback when `a` is falsy.
            layers.append(block(in_planes if is_first else out_planes,
                                out_planes,
                                stride if is_first else 1,
                                dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through all blocks in order."""
        return self.layer(x)

class WideResNet(nn.Module):
    """Wide residual network built from three `NetworkBlock` stages of `BasicBlock`s.

    Arguments:
        depth (int): total depth; `(depth - 4)` must be divisible by 6.
        num_classes (int): size of the final linear layer's output.
        widen_factor (int): multiplier for the stage widths 16/32/64. (Default: 1)
        dropRate (float): dropout rate forwarded to every block. (Default: 0.0)
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        super().__init__()
        widths = [16] + [base * widen_factor for base in (16, 32, 64)]
        assert (depth - 4) % 6 == 0
        # Blocks per stage; kept as a float here, NetworkBlock converts it with int().
        blocks_per_stage = (depth - 4) / 6
        unit = BasicBlock
        # Stem convolution before any network block
        self.conv1 = nn.Conv2d(3, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # Three stages: stride 1, then two stride-2 stages
        self.block1 = NetworkBlock(blocks_per_stage, widths[0], widths[1], unit, 1, dropRate)
        self.block2 = NetworkBlock(blocks_per_stage, widths[1], widths[2], unit, 2, dropRate)
        self.block3 = NetworkBlock(blocks_per_stage, widths[2], widths[3], unit, 2, dropRate)
        # Final BN-ReLU, pooling and classifier
        self.bn1 = nn.BatchNorm2d(widths[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(widths[3], num_classes)
        self.nChannels = widths[3]

        # Weight initialisation: Kaiming-normal convolutions, unit-scale/zero-shift
        # batch norms, zero bias for the linear layer.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.bias.data.zero_()

    def forward(self, x):
        """Return the class logits for a batch of images `x`."""
        features = self.conv1(x)
        for stage in (self.block1, self.block2, self.block3):
            features = stage(features)
        features = self.relu(self.bn1(features))
        # Fixed 8x8 average pooling.
        # NOTE(review): assumes the final feature map is 8x8 (e.g. 32x32 inputs
        # after the two stride-2 stages) — confirm for other input sizes.
        pooled = torch.nn.functional.avg_pool2d(features, 8)
        flat = pooled.view(-1, self.nChannels)
        return self.fc(flat)

# Attack

## Base Attack

In [None]:
import time
import logging
from collections import OrderedDict
from collections.abc import Iterable

import torch
from torch.utils.data import DataLoader, TensorDataset


def wrapper_method(func):
    """Decorator that replays a method call on every attack nested inside `self`.

    The wrapped method is first called on `self`; the method of the same name
    is then called with the same arguments on every object stored in
    `self._attacks`. The return value of the call on `self` is returned.

    Arguments:
        func (function): the method to wrap.

    Returns:
        function: the wrapping method, carrying `func`'s name and docstring.
    """
    import functools

    @functools.wraps(func)
    def wrapper_func(self, *args, **kwargs):
        result = func(self, *args, **kwargs)
        for atk in self.__dict__.get('_attacks').values():
            # Plain attribute lookup instead of eval() on a constructed string.
            getattr(atk, func.__name__)(*args, **kwargs)
        return result
    return wrapper_func


class Attack(object):
    r"""
    Base class for all attacks.

    Sub-classes implement `forward`. Calling the instance wraps `forward` with
    an input-range check, optional (de)normalization and a temporary switch of
    the model's train/eval mode.

    .. note::
        It automatically sets device to the device where the given model is.
        By default it changes the model to eval mode during the attack process.
        To change this, please see `set_model_training_mode`.
    """

    def __init__(self, name, model):
        r"""
        Initializes internal attack state.

        Arguments:
            name (str): name of attack.
            model (torch.nn.Module): model to attack.
        """

        self.attack = name
        # Nested attacks found in attribute values (filled by `__setattr__`).
        self._attacks = OrderedDict()

        self.set_model(model)
        self.device = next(model.parameters()).device

        # Controls attack mode.
        self.attack_mode = 'default'
        self.supported_mode = ['default']
        self.targeted = False
        self._target_map_function = None

        # Controls when normalization is used.
        self.normalization_used = {}
        self._normalization_applied = False
        self._set_auto_normalization_used(model)

        # Controls model mode during attack.
        self._model_training = False
        self._batchnorm_training = False
        self._dropout_training = False

    def forward(self, inputs, labels=None, *args, **kwargs):
        r"""
        It defines the computation performed at every call.
        Should be overridden by all subclasses.
        """
        raise NotImplementedError

    def _check_inputs(self, images):
        r"""
        Undo normalization (if applied) and verify that images lie in [0, 1].

        Raises:
            ValueError: if any value is outside [0, 1] by more than 1e-4.
        """
        tol = 1e-4
        if self._normalization_applied:
            images = self.inverse_normalize(images)
        if torch.max(images) > 1+tol or torch.min(images) < 0-tol:
            raise ValueError('Input must have a range [0, 1] (max: {}, min: {})'.format(
                torch.max(images), torch.min(images)))
        return images

    def _check_outputs(self, images):
        r"""
        Re-apply normalization to the adversarial images if it is in use.
        """
        if self._normalization_applied:
            images = self.normalize(images)
        return images

    @wrapper_method
    def set_model(self, model):
        r"""
        Set the model to attack and remember its class name.
        """
        self.model = model
        self.model_name = model.__class__.__name__

    def get_logits(self, inputs, labels=None, *args, **kwargs):
        r"""
        Return the model outputs for `inputs`, normalizing them first if
        normalization is in use.
        """
        if self._normalization_applied:
            inputs = self.normalize(inputs)
        logits = self.model(inputs)
        return logits

    @wrapper_method
    def _set_normalization_applied(self, flag):
        self._normalization_applied = flag

    @wrapper_method
    def set_device(self, device):
        r"""
        Set the device used by the attack.
        """
        self.device = device

    @wrapper_method
    def _set_auto_normalization_used(self, model):
        # Only models whose class is named 'RobModel' are inspected for
        # `mean`/`std` attributes.
        if model.__class__.__name__ == 'RobModel':
            mean = getattr(model, 'mean', None)
            std = getattr(model, 'std', None)
            if (mean is not None) and (std is not None):
                if isinstance(mean, torch.Tensor):
                    mean = mean.cpu().numpy()
                if isinstance(std, torch.Tensor):
                    std = std.cpu().numpy()
                if (mean != 0).all() or (std != 1).all():
                    self.set_normalization_used(mean, std)

    @wrapper_method
    def set_normalization_used(self, mean, std):
        r"""
        Register per-channel normalization statistics.

        Arguments:
            mean (sequence): per-channel mean.
            std (sequence): per-channel standard deviation.
        """
        n_channels = len(mean)
        mean = torch.tensor(mean).reshape(1, n_channels, 1, 1)
        std = torch.tensor(std).reshape(1, n_channels, 1, 1)
        self.normalization_used['mean'] = mean
        self.normalization_used['std'] = std
        self._normalization_applied = True

    def normalize(self, inputs):
        r"""
        Return `(inputs - mean) / std` using the registered statistics.
        """
        mean = self.normalization_used['mean'].to(inputs.device)
        std = self.normalization_used['std'].to(inputs.device)
        return (inputs - mean) / std

    def inverse_normalize(self, inputs):
        r"""
        Return `inputs * std + mean` using the registered statistics.
        """
        mean = self.normalization_used['mean'].to(inputs.device)
        std = self.normalization_used['std'].to(inputs.device)
        return inputs*std + mean

    def get_mode(self):
        r"""
        Get attack mode.

        """
        return self.attack_mode

    @wrapper_method
    def set_mode_default(self):
        r"""
        Set attack mode as default mode.

        """
        self.attack_mode = 'default'
        self.targeted = False
        print("Attack mode is changed to 'default.'")

    @wrapper_method
    def _set_mode_targeted(self, mode, quiet):
        if "targeted" not in self.supported_mode:
            raise ValueError("Targeted mode is not supported.")
        self.targeted = True
        self.attack_mode = mode
        if not quiet:
            print("Attack mode is changed to '%s'." % mode)

    @wrapper_method
    def set_mode_targeted_by_function(self, target_map_function, quiet=False):
        r"""
        Set attack mode as targeted.

        Arguments:
            target_map_function (function): Label mapping function,
                called as `target_map_function(inputs, labels)`.
                e.g. lambda inputs, labels:(labels+1)%10.
            quiet (bool): True to suppress the mode-change message. (Default: False)

        """
        self._set_mode_targeted('targeted(custom)', quiet)
        self._target_map_function = target_map_function

    @wrapper_method
    def set_mode_targeted_random(self, quiet=False):
        r"""
        Set attack mode as targeted with random labels.

        Arguments:
            quiet (bool): True to suppress the mode-change message. (Default: False)

        """
        self._set_mode_targeted('targeted(random)', quiet)
        self._target_map_function = self.get_random_target_label

    @wrapper_method
    def set_mode_targeted_least_likely(self, kth_min=1, quiet=False):
        r"""
        Set attack mode as targeted with least likely labels.

        Arguments:
            kth_min (int): label with the k-th smallest probability used as target labels. (Default: 1)
            quiet (bool): True to suppress the mode-change message. (Default: False)

        """
        self._set_mode_targeted('targeted(least-likely)', quiet)
        assert (kth_min > 0)
        self._kth_min = kth_min
        self._target_map_function = self.get_least_likely_label

    @wrapper_method
    def set_mode_targeted_by_label(self, quiet=False):
        r"""
        Set attack mode as targeted.

        .. note::
            Use user-supplied labels as target labels.
        """
        self._set_mode_targeted('targeted(label)', quiet)
        # Non-None placeholder: `get_target_label` returns the given labels in
        # this mode and only checks that the map function is not None.
        self._target_map_function = 'function is a string'

    @wrapper_method
    def set_model_training_mode(self, model_training=False, batchnorm_training=False, dropout_training=False):
        r"""
        Set training mode during attack process.

        Arguments:
            model_training (bool): True for using training mode for the entire model during attack process.
            batchnorm_training (bool): True for using training mode for batchnorms during attack process.
            dropout_training (bool): True for using training mode for dropouts during attack process.

        .. note::
            For RNN-based models, we cannot calculate gradients with eval mode.
            Thus, it should be changed to the training mode during the attack.
        """
        self._model_training = model_training
        self._batchnorm_training = batchnorm_training
        self._dropout_training = dropout_training

    @wrapper_method
    def _change_model_mode(self, given_training):
        # Put the model in the mode requested via `set_model_training_mode`.
        if self._model_training:
            self.model.train()
            for _, m in self.model.named_modules():
                if not self._batchnorm_training:
                    if 'BatchNorm' in m.__class__.__name__:
                        m.eval()
                if not self._dropout_training:
                    if 'Dropout' in m.__class__.__name__:
                        m.eval()
        else:
            self.model.eval()

    @wrapper_method
    def _recover_model_mode(self, given_training):
        # Restore training mode if the model was training before the attack.
        if given_training:
            self.model.train()

    def save(self, data_loader, save_path=None, verbose=True, return_verbose=False,
             save_predictions=False, save_clean_inputs=False, save_type='float'):
        r"""
        Save adversarial inputs as torch.tensor from given torch.utils.data.DataLoader.

        The file at `save_path` is rewritten after every batch with everything
        accumulated so far.

        Arguments:
            data_loader (torch.utils.data.DataLoader): data loader.
            save_path (str): save_path. Nothing is written when None. (Default: None)
            verbose (bool): True for displaying detailed information. (Default: True)
            return_verbose (bool): True for returning detailed information. (Default: False)
            save_predictions (bool): True for saving predicted labels (Default: False)
            save_clean_inputs (bool): True for saving clean inputs (Default: False)
            save_type (str): 'float' or 'int' storage of the saved images. (Default: 'float')

        Returns:
            (rob_acc, l2, elapsed_time) when `return_verbose` is True, where
            `elapsed_time` is the time spent on the last batch. All three are
            0.0 when the loader is empty. Otherwise None.

        """
        if save_path is not None:
            adv_input_list = []
            label_list = []
            if save_predictions:
                pred_list = []
            if save_clean_inputs:
                input_list = []

        correct = 0
        total = 0
        l2_distance = []

        # Defined up front so an empty loader does not hit undefined names below.
        progress = rob_acc = l2 = elapsed_time = 0.0

        total_batch = len(data_loader)
        given_training = self.model.training

        # Predictions are needed for the statistics and also when they are saved.
        need_outputs = verbose or return_verbose or (
            save_path is not None and save_predictions)

        for step, (inputs, labels) in enumerate(data_loader):
            start = time.time()
            adv_inputs = self.__call__(inputs, labels)
            batch_size = len(inputs)

            if need_outputs:
                with torch.no_grad():
                    outputs = self.get_output_with_eval_nograd(adv_inputs)

                    # Calculate robust accuracy
                    _, pred = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    right_idx = (pred == labels.to(self.device))
                    correct += right_idx.sum()
                    rob_acc = 100 * float(correct) / total

                    # Calculate l2 distance (over misclassified samples only)
                    delta = (adv_inputs - inputs.to(self.device)).view(batch_size, -1)  # nopep8
                    l2_distance.append(torch.norm(delta[~right_idx], p=2, dim=1))  # nopep8
                    l2 = torch.cat(l2_distance).mean().item()

                    # Calculate time computation
                    progress = (step+1)/total_batch*100
                    end = time.time()
                    elapsed_time = end-start

                    if verbose:
                        self._save_print(progress, rob_acc, l2, elapsed_time, end='\r')  # nopep8

            if save_path is not None:
                adv_input_list.append(adv_inputs.detach().cpu())
                label_list.append(labels.detach().cpu())

                adv_input_list_cat = torch.cat(adv_input_list, 0)
                label_list_cat = torch.cat(label_list, 0)

                save_dict = {'adv_inputs': adv_input_list_cat, 'labels': label_list_cat}  # nopep8

                if save_predictions:
                    pred_list.append(pred.detach().cpu())
                    pred_list_cat = torch.cat(pred_list, 0)
                    save_dict['preds'] = pred_list_cat

                if save_clean_inputs:
                    input_list.append(inputs.detach().cpu())
                    input_list_cat = torch.cat(input_list, 0)
                    save_dict['clean_inputs'] = input_list_cat

                # Only undo normalization when it was actually applied.
                # `normalization_used` starts as an empty dict, so an
                # `is not None` test is always true and inverse_normalize would
                # fail with a KeyError when no statistics are registered.
                if self._normalization_applied:
                    save_dict['adv_inputs'] = self.inverse_normalize(save_dict['adv_inputs'])  # nopep8
                    if save_clean_inputs:
                        save_dict['clean_inputs'] = self.inverse_normalize(save_dict['clean_inputs'])  # nopep8

                if save_type == 'int':
                    save_dict['adv_inputs'] = self.to_type(save_dict['adv_inputs'], 'int')  # nopep8
                    if save_clean_inputs:
                        save_dict['clean_inputs'] = self.to_type(save_dict['clean_inputs'], 'int')  # nopep8

                save_dict['save_type'] = save_type
                torch.save(save_dict, save_path)

        # To avoid erasing the printed information.
        if verbose and total_batch > 0:
            self._save_print(progress, rob_acc, l2, elapsed_time, end='\n')

        if given_training:
            self.model.train()

        if return_verbose:
            return rob_acc, l2, elapsed_time

    @staticmethod
    def to_type(inputs, type):
        r"""
        Convert between float images and uint8 images.

        For 'int', float tensors are multiplied by 255 and cast to uint8.
        For 'float', byte tensors are cast to float and divided by 255.
        Tensors of any other type are returned unchanged.

        Raises:
            ValueError: if `type` is neither 'int' nor 'float'.
        """
        if type == 'int':
            if isinstance(inputs, torch.FloatTensor) or isinstance(inputs, torch.cuda.FloatTensor):
                return (inputs*255).type(torch.uint8)
        elif type == 'float':
            if isinstance(inputs, torch.ByteTensor) or isinstance(inputs, torch.cuda.ByteTensor):
                return inputs.float()/255
        else:
            raise ValueError(
                type + " is not a valid type. [Options: float, int]")
        return inputs

    @staticmethod
    def _save_print(progress, rob_acc, l2, elapsed_time, end):
        print('- Save progress: %2.2f %% / Robust accuracy: %2.2f %% / L2: %1.5f (%2.3f it/s) \t'
              % (progress, rob_acc, l2, elapsed_time), end=end)

    @staticmethod
    def load(load_path, batch_size=128, shuffle=False, normalize=None,
             load_predictions=False, load_clean_inputs=False):
        r"""
        Load a file written by `save` and return it as a DataLoader.

        Arguments:
            load_path (str): path of the saved file.
            batch_size (int): batch size of the returned loader. (Default: 128)
            shuffle (bool): whether the returned loader shuffles. (Default: False)
            normalize (dict): optional {'mean': ..., 'std': ...} applied to the loaded images. (Default: None)
            load_predictions (bool): True for also loading saved predictions. (Default: False)
            load_clean_inputs (bool): True for also loading saved clean inputs. (Default: False)

        Returns:
            torch.utils.data.DataLoader yielding tensors in the order
            [adv_inputs, labels, (preds), (clean_inputs)].
        """
        # NOTE(review): torch.load unpickles the file — only load files from a trusted source.
        save_dict = torch.load(load_path)
        keys = ['adv_inputs', 'labels']

        if load_predictions:
            keys.append('preds')
        if load_clean_inputs:
            keys.append('clean_inputs')

        if save_dict['save_type'] == 'int':
            save_dict['adv_inputs'] = save_dict['adv_inputs'].float()/255
            if load_clean_inputs:
                save_dict['clean_inputs'] = save_dict['clean_inputs'].float() / 255  # nopep8

        if normalize is not None:
            n_channels = len(normalize['mean'])
            mean = torch.tensor(normalize['mean']).reshape(1, n_channels, 1, 1)
            std = torch.tensor(normalize['std']).reshape(1, n_channels, 1, 1)
            save_dict['adv_inputs'] = (save_dict['adv_inputs'] - mean) / std
            if load_clean_inputs:
                save_dict['clean_inputs'] = (save_dict['clean_inputs'] - mean) / std  # nopep8

        adv_data = TensorDataset(*[save_dict[key] for key in keys])
        adv_loader = DataLoader(
            adv_data, batch_size=batch_size, shuffle=shuffle)
        print("Data is loaded in the following order: [%s]" % (", ".join(keys)))  # nopep8
        return adv_loader

    @torch.no_grad()
    def get_output_with_eval_nograd(self, inputs):
        r"""
        Return the logits for `inputs` with the model temporarily in eval mode.
        """
        given_training = self.model.training
        if given_training:
            self.model.eval()
        outputs = self.get_logits(inputs)
        if given_training:
            self.model.train()
        return outputs

    def get_target_label(self, inputs, labels=None):
        r"""
        Return the target labels for the current targeted mode.

        In 'targeted(label)' mode the given labels are returned as they are;
        otherwise the registered target map function is applied.

        Raises:
            ValueError: if no targeted mode has been set.
        """
        if self._target_map_function is None:
            raise ValueError(
                'target_map_function is not initialized by set_mode_targeted.')
        if self.attack_mode == 'targeted(label)':
            target_labels = labels
        else:
            target_labels = self._target_map_function(inputs, labels)
        return target_labels

    @torch.no_grad()
    def get_least_likely_label(self, inputs, labels=None):
        r"""
        For every sample, return the class with the k-th smallest logit among
        all classes other than its label (k is `self._kth_min`). When `labels`
        is None the model's own predictions are used as labels.
        """
        outputs = self.get_output_with_eval_nograd(inputs)
        if labels is None:
            _, labels = torch.max(outputs, dim=1)
        n_classses = outputs.shape[-1]

        target_labels = torch.zeros_like(labels)
        for counter in range(labels.shape[0]):
            l = list(range(n_classses))
            l.remove(labels[counter])
            _, t = torch.kthvalue(outputs[counter][l], self._kth_min)
            target_labels[counter] = l[t]

        return target_labels.long().to(self.device)

    @torch.no_grad()
    def get_random_target_label(self, inputs, labels=None):
        r"""
        For every sample, return a uniformly random class other than its
        label. When `labels` is None the model's own predictions are used as
        labels.
        """
        outputs = self.get_output_with_eval_nograd(inputs)
        if labels is None:
            _, labels = torch.max(outputs, dim=1)
        n_classses = outputs.shape[-1]

        target_labels = torch.zeros_like(labels)
        for counter in range(labels.shape[0]):
            l = list(range(n_classses))
            l.remove(labels[counter])
            t = (len(l)*torch.rand([1])).long().to(self.device)
            target_labels[counter] = l[t]

        return target_labels.long().to(self.device)

    def __call__(self, images, labels=None, *args, **kwargs):
        r"""
        Run the attack: switch the model mode, check the inputs, call
        `forward`, post-process the outputs and restore the model mode.
        """
        given_training = self.model.training
        self._change_model_mode(given_training)
        images = self._check_inputs(images)
        adv_images = self.forward(images, labels, *args, **kwargs)
        adv_images = self._check_outputs(adv_images)
        self._recover_model_mode(given_training)
        return adv_images

    def __repr__(self):
        info = self.__dict__.copy()

        # Hide the model, the attack name, the supported modes and every
        # private (underscore-prefixed) attribute.
        del_keys = ['model', 'attack', 'supported_mode']

        for key in info.keys():
            if key[0] == "_":
                del_keys.append(key)

        for key in del_keys:
            del info[key]

        info['attack_mode'] = self.attack_mode
        info['normalization_used'] = True if len(self.normalization_used) > 0 else False  # nopep8

        return self.attack + "(" + ', '.join('{}={}'.format(key, val) for key, val in info.items()) + ")"

    def __setattr__(self, name, value):
        object.__setattr__(self, name, value)

        attacks = self.__dict__.get('_attacks')

        # Get all Attack instances inside (possibly nested) lists and dicts.
        # The function is re-created on every call, so the `stack` default is
        # a fresh list for each attribute assignment.
        def get_all_values(items, stack=[]):
            if (items not in stack):
                stack.append(items)
                if isinstance(items, list) or isinstance(items, dict):
                    if isinstance(items, dict):
                        items = (list(items.keys())+list(items.values()))
                    for item in items:
                        yield from get_all_values(item, stack)
                else:
                    if isinstance(items, Attack):
                        yield items
            else:
                if isinstance(items, Attack):
                    yield items

        # Register every attack found, together with its own nested attacks.
        for num, sub_attack in enumerate(get_all_values(value)):
            attacks[name+"."+str(num)] = sub_attack
            for subname, subvalue in sub_attack.__dict__.get('_attacks').items():
                attacks[name+"."+subname] = subvalue

## PGD

In [None]:

class PGD_CLS(Attack):
    r"""
    Projected Gradient Descent attack on the cross-entropy loss, from
    'Towards Deep Learning Models Resistant to Adversarial Attacks'
    [https://arxiv.org/abs/1706.06083]

    Distance Measure : Linf

    Arguments:
        model (nn.Module): model to attack.
        eps (float): maximum perturbation. (Default: 8/255)
        alpha (float): step size. (Default: 2/255)
        steps (int): number of steps. (Default: 10)
        random_start (bool): using random initialization of delta. (Default: True)

    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = PGD_CLS(model, eps=8/255, alpha=1/255, steps=10, random_start=True)
        >>> adv_images = attack(images, labels)
    """

    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, random_start=True):
        super().__init__("PGD", model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start
        self.supported_mode = ['default', 'targeted']

    def forward(self, images, labels):
        r"""
        Overridden. Returns the adversarial images after `steps` signed-gradient
        steps, each projected back into the eps-ball and the [0, 1] range.
        """
        clean = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self.targeted:
            target_labels = self.get_target_label(clean, labels)

        criterion = nn.CrossEntropyLoss()
        adv_images = clean.clone().detach()

        if self.random_start:
            # Start from a uniformly random point inside the eps-ball
            noise = torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            adv_images = torch.clamp(adv_images + noise, min=0, max=1).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            logits = self.get_logits(adv_images)

            # Targeted mode descends on the loss towards the target labels;
            # default mode ascends on the loss of the true labels.
            if self.targeted:
                cost = -criterion(logits, target_labels)
            else:
                cost = criterion(logits, labels)

            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]

            # Signed-gradient step, then projection onto the eps-ball and [0, 1]
            stepped = adv_images.detach() + self.alpha*grad.sign()
            delta = torch.clamp(stepped - clean, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(clean + delta, min=0, max=1).detach()

        return adv_images

In [None]:

class PGD_TEST(Attack):
    r"""
    PGD variant whose loss is the cross-entropy towards class index
    `num_classes`, with a per-sample sign that depends on the label.
    Based on 'Towards Deep Learning Models Resistant to Adversarial Attacks'
    [https://arxiv.org/abs/1706.06083]

    Distance Measure : Linf

    Arguments:
        model (nn.Module): model to attack.
        eps (float): maximum perturbation. (Default: 8/255)
        alpha (float): step size. (Default: 2/255)
        steps (int): number of steps. (Default: 10)
        random_start (bool): using random initialization of delta. (Default: True)
        num_classes (int): index of the extra class used as the loss target. (Default: 10)

    Shape:
        - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
        - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
        - output: :math:`(N, C, H, W)`.

    Examples::
        >>> attack = PGD_TEST(model, eps=8/255, alpha=1/255, steps=10, random_start=True, num_classes=10)
        >>> adv_images = attack(images, labels)

    """

    def __init__(self, model, eps=8/255, alpha=2/255, steps=10, random_start=True, num_classes=10):
        super().__init__("PGD", model)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start
        self.supported_mode = ['default', 'targeted']
        self.num_classes = num_classes

    def forward(self, images, labels):
        r"""
        Overridden. Samples labelled `num_classes` are pushed away from that
        class (loss maximised); all other samples are pushed towards it
        (loss minimised).
        """
        clean = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        # +1 for samples labelled `num_classes`, -1 for every other sample.
        is_extra_class = (labels == self.num_classes).to(labels.dtype)
        multipliers = 2 * is_extra_class - 1

        # Loop-invariant pieces of the loss.
        target_labels = torch.full_like(labels, self.num_classes)
        criterion = nn.CrossEntropyLoss(reduction='none')

        adv_images = clean.clone().detach()

        if self.random_start:
            # Start from a uniformly random point inside the eps-ball
            noise = torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            adv_images = torch.clamp(adv_images + noise, min=0, max=1).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            logits = self.get_logits(adv_images)

            per_sample = criterion(logits, target_labels)
            cost = torch.mean(per_sample * multipliers)

            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]

            # Signed-gradient step, then projection onto the eps-ball and [0, 1]
            stepped = adv_images.detach() + self.alpha * grad.sign()
            delta = torch.clamp(stepped - clean, min=-self.eps, max=self.eps)
            adv_images = torch.clamp(clean + delta, min=0, max=1).detach()

        return adv_images


# Data

## Datasets

In [None]:
import torch
from torchvision import transforms
from glob import glob
import numpy as np
from PIL import Image
from torchvision.transforms import functional as F
import random

class AdaptiveOutliers(torch.utils.data.Dataset):
    """Dataset of pre-generated outlier images loaded from a `.npy` file.

    Every image is converted to PIL and passed through `transform` once, at
    construction time. All targets are 0.

    Arguments:
        filepath (str): path of the `.npy` file holding the images.
        size (int): accepted but not used by this class.
        transform (callable): transform applied to every image when loading.
    """

    def __init__(self, filepath='./Generated Outliers/cifar10-glide-ood/cifar10_glide_ood.npy', size=32, transform=transforms.Compose([transforms.Resize(32), transforms.ToTensor()])):
        raw_images = torch.from_numpy(np.load(filepath))
        # Transform eagerly so __getitem__ only has to index.
        self.data = [transform(F.to_pil_image(sample)) for sample in raw_images]
        self.targets = np.zeros(len(self.data))

    def __getitem__(self, index):
        """Return the `(image, target)` pair at `index`; the target is an int."""
        return self.data[index], int(self.targets[index])

    def __len__(self):
        """Return the number of images."""
        return len(self.data)
    
class MergedDataset(torch.utils.data.Dataset):
    """Concatenation of an in-distribution dataset and an outlier dataset.

    In-distribution samples keep their targets; every sample taken from
    `out_dataset` gets the target `num_classes`. Images are converted to PIL,
    resized to `size` and turned into tensors on access.

    Arguments:
        in_dataset: dataset exposing `.data` and `.targets`.
        out_dataset: dataset exposing `.data` and `len()`.
        num_classes (int): target assigned to all outlier samples.
        root (str): accepted but not used by this class.
        size (int): size passed to `transforms.Resize`. (Default: 32)
    """

    def __init__(self, in_dataset, out_dataset, num_classes, root='.', size=32):
        self.data = list(in_dataset.data) + list(out_dataset.data)
        outlier_targets = [num_classes] * len(out_dataset)
        self.targets = list(in_dataset.targets) + outlier_targets
        self.transform = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])

    def __getitem__(self, index):
        """Return the transformed image and its int target at `index`."""
        image, target = self.data[index], self.targets[index]
        if self.transform:
            image = self.transform(F.to_pil_image(image))
        return image, int(target)

    def __len__(self):
        """Return the total number of samples."""
        return len(self.data)

## OOD Datasets Loader

In [None]:
def get_out_testing_datasets(out_name):
    """Build the out-of-distribution test dataset identified by ``out_name``.

    Supported names: 'mnist', 'tiny_imagenet', 'places', 'LSUN', 'iSUN',
    'birds', 'flowers' and 'coil'. All datasets are read from below ``data/``;
    only MNIST is downloaded on demand.

    ``torchvision`` must be imported before this is called. The 'birds' and
    'flowers' branches also need ``bird_loader`` / ``flower_loader``, which are
    not defined in this section -- presumably elsewhere in the notebook.

    Args:
        out_name: Name of the OOD dataset.

    Returns:
        The torchvision dataset for ``out_name``.

    Raises:
        ValueError: If ``out_name`` is not one of the supported names.
    """

    def to_tensor_32(*extra):
        # ToTensor -> Resize(32) -> optional extra steps, in that order.
        return transforms.Compose([transforms.ToTensor(), transforms.Resize(32), *extra])

    if out_name == 'mnist':
        # MNIST is single-channel: repeat it three times to get 3-channel images.
        return torchvision.datasets.MNIST(root='./data', train=False, download=True,
                                          transform=to_tensor_32(transforms.Lambda(lambda x: x.repeat(3, 1, 1))))

    elif out_name == 'tiny_imagenet':
        return torchvision.datasets.ImageFolder(root='data/tiny-imagenet-200/test', transform=to_tensor_32())

    elif out_name == 'places':
        return torchvision.datasets.Places365(root='data/', split='val', small=True, download=False,
                                              transform=to_tensor_32())

    elif out_name == 'LSUN':
        return torchvision.datasets.ImageFolder(root='data/LSUN_resize/', transform=transforms.ToTensor())

    elif out_name == 'iSUN':
        return torchvision.datasets.ImageFolder(root='data/iSUN/', transform=transforms.ToTensor())

    elif out_name == 'birds':
        return torchvision.datasets.ImageFolder(root='data/images/', loader=bird_loader,
                                                transform=transforms.ToTensor())

    elif out_name == 'flowers':
        return torchvision.datasets.ImageFolder(root='data/flowers/', loader=flower_loader,
                                                transform=transforms.ToTensor())

    elif out_name == 'coil':
        return torchvision.datasets.ImageFolder(root='data/coil/', transform=to_tensor_32())

    else:
        raise ValueError("Invalid OOD Dataset")
    

## CIFAR100 superclasses

In [None]:
def sparse2coarse(targets):
    """Map CIFAR-100 fine labels to coarse (superclass) labels.

    ``targets`` is used directly as an index into a 100-entry lookup table whose
    values lie in 0..19, so it can be a single label or a sequence/array of labels.

    Usage:
        trainset = torchvision.datasets.CIFAR100(path)
        trainset.targets = sparse2coarse(trainset.targets)
    """
    fine_to_coarse = np.array((
        4, 1, 14, 8, 0, 6, 7, 7, 18, 3,            # fine labels  0- 9
        3, 14, 9, 18, 7, 11, 3, 9, 7, 11,          # fine labels 10-19
        6, 11, 5, 10, 7, 6, 13, 15, 3, 15,         # fine labels 20-29
        0, 11, 1, 10, 12, 14, 16, 9, 11, 5,        # fine labels 30-39
        5, 19, 8, 8, 15, 13, 14, 17, 18, 10,       # fine labels 40-49
        16, 4, 17, 4, 2, 0, 17, 4, 18, 17,         # fine labels 50-59
        10, 3, 2, 12, 12, 16, 12, 1, 9, 19,        # fine labels 60-69
        2, 10, 0, 1, 16, 12, 9, 13, 15, 13,        # fine labels 70-79
        16, 19, 2, 4, 6, 19, 5, 5, 8, 19,          # fine labels 80-89
        18, 1, 2, 15, 6, 0, 17, 8, 14, 13,         # fine labels 90-99
    ))
    return fine_to_coarse[targets]

## DataLoaders

In [None]:
import torchvision
from torchvision.datasets import CIFAR10, CIFAR100, MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import random

# Lookup from a dataset name to the torchvision dataset class that loads it.
dataset_class = dict(cifar10=CIFAR10, cifar100=CIFAR100, mnist=MNIST)

def get_loaders():
    """Build the train and test DataLoaders from the notebook-level configuration.

    Reads the globals ``in_dataset``, ``out_dataset``, ``num_classes`` and
    ``batch_size``.

    * Train set: the ``in_dataset`` training split merged with the generated
      ``AdaptiveOutliers`` images (augmented once when they are loaded).
    * Test set: the ``in_dataset`` test split merged with the ``out_dataset``
      test data.

    In both merged sets the outlier samples are labelled ``num_classes``. When
    ``in_dataset`` is 'cifar100' its targets are mapped through ``sparse2coarse``.

    Returns:
        Tuple ``(trainloader, testloader)``; both loaders shuffle.
    """
    in_train_dataset = dataset_class[in_dataset](root='.', download=True, train=True)

    if in_dataset == 'cifar100':
        in_train_dataset.targets = sparse2coarse(in_train_dataset.targets)

    # Augmentations applied to the generated outlier images.
    exposure_transformations = transforms.Compose([transforms.Resize([32, 32]),
                           transforms.RandomHorizontalFlip(),
                           transforms.RandomGrayscale(),
                           transforms.RandomChoice(
                           [transforms.RandomApply([transforms.RandomAffine(90, translate=(0.15, 0.15), scale=(0.85, 1), shear=None)], p=0.6),
                           transforms.RandomApply([transforms.RandomAffine(0, translate=None, scale=(0.5, 0.75), shear=30)], p=0.6),
                           transforms.RandomApply([transforms.AutoAugment()], p=0.9),]),
                           transforms.ToTensor()])

    out_train_dataset = AdaptiveOutliers(transform=exposure_transformations)

    train_dataset = MergedDataset(in_train_dataset, out_train_dataset, num_classes=num_classes)

    in_test_dataset = dataset_class[in_dataset](root='.', download=True, train=False)

    if in_dataset == 'cifar100':
        in_test_dataset.targets = sparse2coarse(in_test_dataset.targets)

    if out_dataset in ['cifar10', 'cifar100', 'mnist']:
        if out_dataset == 'mnist':
            # MNIST is grey-scale 28x28: expand to 3 channels and resize to 32.
            out_transform = transforms.Compose([transforms.Grayscale(3), transforms.Resize(32), transforms.ToTensor()])
        else:
            out_transform = transforms.ToTensor()
        out_test_dataset = dataset_class[out_dataset](root='.', download=True, train=False, transform=out_transform)
    else:
        # get_out_testing_datasets takes a dataset name and returns one dataset.
        out_test_dataset = get_out_testing_datasets(out_dataset)

    # MergedDataset reads `.data`, so materialise the transformed images into it.
    out_test_dataset.data = [x for x, _ in out_test_dataset]

    test_dataset = MergedDataset(in_test_dataset, out_test_dataset, num_classes=num_classes)

    trainloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    testloader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size)

    print("Length of In train dataset:", len(in_train_dataset))
    print("Length of Out train dataset:", len(out_train_dataset))

    print("Length of In test dataset:", len(in_test_dataset))
    print("Length of Out test dataset:", len(out_test_dataset))

    print(f"Length of train dataset: {len(train_dataset)}")
    print(f"Length of test dataset: {len(test_dataset)}")

    return trainloader, testloader

# Utils

In [None]:
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score

def auc_softmax_adversarial(model, test_loader, test_attack, epoch:int, device, num_classes):
    """Evaluate OOD-detection AUROC and accuracy on adversarially perturbed test data.

    Every batch is first perturbed with ``test_attack(data, target)``. The
    anomaly score of a sample is its softmax probability for class index
    ``num_classes``; a sample counts as an outlier when its label equals
    ``num_classes``.

    Args:
        model: Classifier; its output must have a column at index ``num_classes``.
        test_loader: Iterable of ``(data, target)`` batches.
        test_attack: Callable ``(data, target) -> adversarial data``.
        epoch: Unused; kept for interface compatibility.
        device: Device the batches are moved to.
        num_classes: Index of the outlier class.

    Returns:
        ``(auc, accuracy)``: AUROC of the anomaly score against the outlier
        indicator, and accuracy of the arg-max prediction against the original
        labels (outlier label included).
    """
    was_training = model.training
    model.eval()

    anomaly_scores = []
    preds = []
    class_labels = []
    ood_labels = []

    try:
        with tqdm(test_loader, unit="batch") as tepoch:
            torch.cuda.empty_cache()
            for data, target in tepoch:
                data, target = data.to(device), target.to(device)

                # The attack needs gradients, so it runs outside no_grad;
                # the scoring forward pass does not.
                adv_data = test_attack(data, target)
                with torch.no_grad():
                    output = model(adv_data)

                # No squeeze(): it would drop the batch axis for a batch of one.
                preds += output.argmax(dim=1).cpu().tolist()
                anomaly_scores += torch.softmax(output, dim=1)[:, num_classes].cpu().tolist()

                class_labels += target.cpu().tolist()
                ood_labels += (target == num_classes).long().cpu().tolist()
    finally:
        # Restore the mode the model was in, even if evaluation fails.
        model.train(was_training)

    auc = roc_auc_score(ood_labels, anomaly_scores)
    # Predictions are class indices, so compare them with class labels rather
    # than with the binary outlier indicator.
    accuracy = accuracy_score(class_labels, preds, normalize=True)

    return auc, accuracy

def auc_softmax(model, test_loader, epoch:int, device, num_classes):
    """Evaluate OOD-detection AUROC and accuracy on clean test data.

    The anomaly score of a sample is its softmax probability for class index
    ``num_classes``; a sample counts as an outlier when its label equals
    ``num_classes``.

    Args:
        model: Classifier; its output must have a column at index ``num_classes``.
        test_loader: Iterable of ``(data, target)`` batches.
        epoch: Unused; kept for interface compatibility.
        device: Device the batches are moved to.
        num_classes: Index of the outlier class.

    Returns:
        ``(auc, accuracy)``: AUROC of the anomaly score against the outlier
        indicator, and accuracy of the arg-max prediction against the original
        labels (outlier label included).
    """
    was_training = model.training
    model.eval()

    anomaly_scores = []
    preds = []
    class_labels = []
    ood_labels = []

    try:
        with torch.no_grad():
            with tqdm(test_loader, unit="batch") as tepoch:
                torch.cuda.empty_cache()
                for data, target in tepoch:
                    data, target = data.to(device), target.to(device)
                    output = model(data)

                    # No squeeze(): it would drop the batch axis for a batch of one.
                    preds += output.argmax(dim=1).cpu().tolist()
                    anomaly_scores += torch.softmax(output, dim=1)[:, num_classes].cpu().tolist()

                    class_labels += target.cpu().tolist()
                    ood_labels += (target == num_classes).long().cpu().tolist()
    finally:
        # Restore the mode the model was in, even if evaluation fails.
        model.train(was_training)

    auc = roc_auc_score(ood_labels, anomaly_scores)
    # Predictions are class indices, so compare them with class labels rather
    # than with the binary outlier indicator.
    accuracy = accuracy_score(class_labels, preds, normalize=True)

    return auc, accuracy

def lr_schedule(learning_rate, t, max_epochs):
    """Piecewise-linear learning-rate decay evaluated at (possibly fractional) epoch ``t``.

    The rate is interpolated linearly between these points:
    ``learning_rate`` at 0, ``learning_rate / 10`` at ``max_epochs // 3``,
    ``learning_rate / 100`` at ``max_epochs * 2 // 3`` and 0 at ``max_epochs``.
    """
    breakpoints = [0, max_epochs // 3, max_epochs * 2 // 3, max_epochs]
    rates = [learning_rate, learning_rate / 10, learning_rate / 100, 0]
    return np.interp([t], breakpoints, rates)[0]
    
    
def run(model, train_attack, test_attack, trainloader, testloader, test_step:int, max_epochs:int, device, loss_threshold=1e-3, num_classes=10, lr=0.01, optim=None):
    """Adversarially train ``model`` and periodically evaluate it.

    Each epoch trains via ``train_one_epoch`` and saves a checkpoint to
    ``./{epoch}_model_{in_dataset}.pt``. Every ``test_step`` epochs the clean
    and adversarial AUROC/accuracy are computed on ``testloader`` and printed.
    Training stops early once the epoch's training loss drops below
    ``loss_threshold``. A final checkpoint goes to ``./last_model_{in_dataset}.pt``.
    The checkpoint paths read the global ``in_dataset``.

    Args:
        model: Network to train.
        train_attack: Attack applied to every training batch.
        test_attack: Attack used for the adversarial evaluation.
        trainloader: Training DataLoader.
        testloader: Test DataLoader.
        test_step: Evaluate when ``epoch % test_step == 0``.
        max_epochs: Maximum number of epochs.
        device: Device used for training and evaluation.
        loss_threshold: Early-stopping threshold on the training loss.
        num_classes: Index of the outlier class.
        lr: Base learning rate.
        optim: ``"adam"`` or ``"sgd"``.

    Returns:
        ``(clean_aucs, adv_aucs)``: one entry per evaluated epoch.

    Raises:
        ValueError: If ``optim`` is neither ``"adam"`` nor ``"sgd"``.
    """
    if optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    elif optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    else:
        # Fail here with a clear message instead of an unbound `optimizer` later.
        raise ValueError(f"Unsupported optimizer {optim!r}; expected 'adam' or 'sgd'")

    criterion = torch.nn.CrossEntropyLoss()
    init_epoch = 0

    clean_aucs = []
    adv_aucs = []

    print(f'Starting Run from epoch {init_epoch}')

    train_loss = 0

    for epoch in range(init_epoch, max_epochs):

        torch.cuda.empty_cache()

        print(f'====== Starting Training on epoch {epoch}')
        train_accuracy, train_loss = train_one_epoch(epoch=epoch,
                                                     max_epochs=max_epochs,
                                                     model=model,
                                                     optimizer=optimizer,
                                                     criterion=criterion,
                                                     trainloader=trainloader,
                                                     train_attack=train_attack,
                                                     lr=lr,
                                                     device=device)

        print("train accuracy is ", train_accuracy)
        print("train loss is ", train_loss)

        # A checkpoint is written after every epoch.
        save_model_checkpoint(model, train_loss, f'./{epoch}_model_{in_dataset}.pt', optimizer)

        if epoch % test_step == 0:

            print(f'AUC & Accuracy Vanila - Started...')
            clean_auc, clean_accuracy = auc_softmax(model=model, epoch=epoch, test_loader=testloader, device=device, num_classes=num_classes)
            print(f'AUC Vanila - score on epoch {epoch} is: {clean_auc * 100}')
            print(f'Accuracy Vanila -  score on epoch {epoch} is: {clean_accuracy * 100}')

            attack_name = 'PGD-10'
            print(f'AUC & Accuracy Adversarial - {attack_name} - Started...')
            adv_auc, adv_accuracy = auc_softmax_adversarial(model=model, epoch=epoch, test_loader=testloader, test_attack=test_attack, device=device, num_classes=num_classes)
            print(f'AUC Adversairal {attack_name} - score on epoch {epoch} is: {adv_auc * 100}')
            print(f'Accuracy Adversairal {attack_name} -  score on epoch {epoch} is: {adv_accuracy * 100}')

            # Record results only for epochs that were evaluated, and do it
            # before the early-stop check so the final epoch is not dropped.
            clean_aucs.append(clean_auc)
            adv_aucs.append(adv_auc)

        if train_loss < loss_threshold:
            break

        torch.cuda.empty_cache()

    save_model_checkpoint(model, train_loss, f'./last_model_{in_dataset}.pt', optimizer)

    return clean_aucs, adv_aucs



def train_one_epoch(epoch, max_epochs, model, optimizer, criterion, trainloader, train_attack, lr, device):
    """Run one epoch of adversarial training.

    For every batch the learning rate is set from ``lr_schedule``, the batch is
    perturbed with ``train_attack(data, target)``, and one optimisation step is
    taken on the perturbed batch.

    Args:
        epoch: Zero-based index of the current epoch.
        max_epochs: Total number of epochs, used by the schedule and progress bar.
        model: Network being trained; switched to training mode.
        optimizer: Optimizer whose parameter groups receive the scheduled rate.
        criterion: Loss ``(output, target) -> scalar tensor``.
        trainloader: Iterable of ``(data, target)`` batches that supports ``len()``.
        train_attack: Callable ``(data, target) -> adversarial data``.
        lr: Base learning rate fed to the schedule.
        device: Device the batches are moved to.

    Returns:
        ``(accuracy, mean_loss)`` over all samples seen in the epoch, both
        measured on the adversarial batches.
    """
    running_loss = 0.0
    correct = 0
    seen = 0

    model.train()
    with tqdm(trainloader, unit="batch") as tepoch:
        torch.cuda.empty_cache()
        num_batches = len(tepoch)
        for i, (data, target) in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch + 1}/{max_epochs}")

            # The schedule is evaluated at a fractional epoch, so it changes every batch.
            updated_lr = lr_schedule(learning_rate=lr, t=epoch + (i + 1) / num_batches, max_epochs=max_epochs)
            for group in optimizer.param_groups:
                group['lr'] = updated_lr

            # Use the requested device rather than a hard-coded .cuda().
            data = data.to(device)
            target = target.to(device=device, dtype=torch.long)

            # Adversarial attack on every batch
            data = train_attack(data, target)

            # Zero gradients for every batch
            optimizer.zero_grad()

            output = model(data)

            # Compute the loss and its gradients
            loss = criterion(output, target)
            loss.backward()

            # Adjust learning weights
            optimizer.step()

            # Running counters instead of re-scanning growing lists each batch;
            # argmax without squeeze() also works for a batch of one.
            n_samples = target.size(0)
            correct += (output.argmax(dim=1) == target).sum().item()
            seen += n_samples
            running_loss += loss.item() * n_samples

            tepoch.set_postfix(loss=running_loss / seen, accuracy=100. * correct / seen)

    return correct / seen, running_loss / seen


def save_model_checkpoint(model, loss, path, optimizer, epoch=None):
    """Save the model and optimizer state together with ``loss`` to ``path``.

    The checkpoint is a dict with the keys ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``. An ``'epoch'`` key is added only
    when ``epoch`` is given.

    Args:
        model: Module whose ``state_dict()`` is stored.
        loss: Loss value stored alongside the states.
        path: Destination file.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Optional epoch number to record.

    Raises:
        ValueError: If saving fails; the underlying error is chained as the cause.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    if epoch is not None:
        checkpoint['epoch'] = epoch

    try:
        torch.save(checkpoint, path)
    except Exception as exc:
        # Keep the ValueError callers may expect, but chain the real reason.
        raise ValueError('Saving model checkpoint failed!') from exc

def load_model_checkpoint(model, optimizer, path, map_location=None):
    """Restore ``model`` and ``optimizer`` in place from the checkpoint at ``path``.

    Args:
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer_state_dict']``.
        path: Checkpoint file to read.
        map_location: Forwarded to ``torch.load``.

    Returns:
        ``(model, optimizer, epoch, loss)`` on success, where ``epoch`` is 0 when
        the checkpoint has no ``'epoch'`` entry; ``None`` if the checkpoint
        cannot be loaded.
    """
    try:
        checkpoint = torch.load(path, map_location=map_location)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # A checkpoint may have no 'epoch' entry; a hard lookup would make every
        # such load fail and return None.
        epoch = checkpoint.get('epoch', 0)
        loss = checkpoint['loss']
        return model, optimizer, epoch, loss
    except Exception as exc:
        # Best-effort: still return None, but report why loading failed.
        print(f'Loading model checkpoint from {path} failed: {exc!r}')
        return None


# Visualization

In [None]:
from matplotlib import pyplot as plt

def visualize_samples(dataloader, n, title="Sample"):
    """Show two ``n`` x ``n`` image grids side by side: 'Normal' and 'Abnormal'.

    Samples are taken from ``dataloader`` in iteration order. A sample with
    label 0 goes to the 'Normal' grid and any other label to the 'Abnormal'
    grid, until each grid holds ``n * n`` images. Images whose first dimension
    is 1 are repeated to three channels.

    Args:
        dataloader: Iterable of ``(images, labels)`` batches.
        n: Grid side length.
        title: Figure title.
    """
    # NOTE(review): "normal" here means label == 0 only -- confirm that this is
    # the intended split when the loader carries multi-class labels.
    wanted = n * n
    normal_samples, abnormal_samples = [], []

    def both_full():
        return len(normal_samples) == wanted and len(abnormal_samples) == wanted

    # Collect n x n samples
    for images, labels in dataloader:
        for idx, label in enumerate(labels):
            image = images[idx]
            if image.shape[0] == 1:
                image = image.repeat(3, 1, 1)

            if len(normal_samples) < wanted and label == 0:
                normal_samples.append(image)
            elif len(abnormal_samples) < wanted and label != 0:
                abnormal_samples.append(image)

            if both_full():
                break
        if both_full():
            break

    panels = (
        ('Normal', torchvision.utils.make_grid(normal_samples, nrow=n)),
        ('Abnormal', torchvision.utils.make_grid(abnormal_samples, nrow=n)),
    )

    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(18, 8))
    fig.patch.set_alpha(0)
    fig.suptitle(title, fontsize=16)

    for ax, (caption, grid) in zip(axs, panels):
        # make_grid returns channels-first; imshow wants channels-last.
        ax.imshow(grid.permute(1, 2, 0))
        ax.set_title(caption, fontsize=14)
        ax.axis('off')

    plt.show()

# Training

In [None]:
import torch

# Choose the device once, before anything is moved to it, and fall back to the
# CPU when CUDA is not available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

trainloader, testloader = get_loaders()

# The network has one extra output (index `num_classes`) for the outlier class.
model = WideResNet(40, num_classes+1, 4,  dropRate=0.0).to(device)

train_attack1 = PGD_CLS(model, eps=attack_eps, steps=10, alpha=attack_alpha)
test_attack = PGD_TEST(model, eps=attack_eps, steps=10, alpha=attack_alpha, num_classes=num_classes)

clean_aucs, adv_aucs = run(model, train_attack1, test_attack, trainloader, testloader, 1, epochs, device, loss_threshold=1e-3, num_classes=num_classes,lr=lr, optim=optim)

# Testing

In [None]:
# Final evaluation uses a stronger attack than the one used during training.
attack_steps = 100
attack = PGD_TEST(model, eps=attack_eps, steps=attack_steps, alpha=attack_alpha, num_classes=num_classes)

adv_sm = 0
clean_sm = 0
num_evaluated = 0

for out_dataset in avialable_datasets:

    if out_dataset == in_dataset:
        continue

    # get_loaders() reads the global `out_dataset`, which this loop rebinds.
    trainloader, testloader = get_loaders()

    visualize_samples(testloader, 8, out_dataset)

    clean_auc, clean_accuracy  = auc_softmax(model=model, epoch=epochs, test_loader=testloader, device=device, num_classes=num_classes)
    print(f"Clean AUC for (In={in_dataset}) and (Out={out_dataset}) is {int(clean_auc * 10000)/100}")
    adv_auc, adv_accuracy = auc_softmax_adversarial(model=model, epoch=epochs, test_loader=testloader, test_attack=attack, device=device, num_classes=num_classes)
    # Report the step count actually used by `attack`.
    print(f"PGD-{attack_steps} Adversarial AUC for (In={in_dataset}) and (Out={out_dataset}) is {int(adv_auc * 10000) / 100}")
    adv_sm += adv_auc
    clean_sm += clean_auc
    num_evaluated += 1

# Average over the datasets actually evaluated, rather than assuming that
# exactly one entry of the list was skipped.
if num_evaluated:
    print("Average mean of clean AUC:", clean_sm / num_evaluated)
    print("Average mean of Adversarial AUC:", adv_sm / num_evaluated)
else:
    print("No OOD dataset was evaluated.")