**Imports**

In [1]:
from datetime import datetime
import math
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as trans
import torch.nn.functional as F
from torch.optim import Adam, SGD
from tqdm import tqdm
import random

**Set Device**

In [2]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU. Kept as a plain
# string because later cells pass it to `.to(...)` and to `torch.device(...)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

**Model Architecture**

In [3]:
class MainConv(nn.Module):
    """Thin wrapper around ``nn.Conv2d`` that caches its most recent output.

    The wrapped layer is ``self.conv``; every ``forward`` call stores the tensor it
    returns on ``self.out`` so hidden-layer activations can be inspected afterwards.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(MainConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride = stride, padding = padding)
        # Last forward output; None until the layer has been called once.
        self.out = None

    def forward(self, x):
        self.out = self.conv(x)   # keep the hidden-layer output around
        return self.out

class ResNetBlock(nn.Module):
    """Basic two-convolution residual block (ResNet "basic block").

    Args:
        in_channels: number of input channels.
        downsample: if truthy, the first 3x3 conv uses stride 2 and doubles the
            channel count, and a strided 1x1 conv (``convr``) projects the shortcut
            to the new shape. Otherwise channels/resolution are unchanged and the
            shortcut is the identity (``convr`` is None).
    """

    def __init__(self, in_channels, downsample):
        super(ResNetBlock, self).__init__()
        stride, out_channels = (2, 2 * in_channels) if downsample else (1, in_channels)
        # Attribute names and registration order are part of the state_dict layout
        # and are accessed by name from SGM, so they must not change.
        self.conv1 = MainConv(in_channels, out_channels, 3, stride = stride, padding = 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = MainConv(out_channels, out_channels, 3, stride = 1, padding = 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if downsample:
            self.convr = nn.Conv2d(in_channels, out_channels, 1, stride = stride, padding = 0)
        else:
            self.convr = None
        self.relu = nn.ReLU(inplace = True)

    def forward(self, x):
        # residual branch: conv -> BN -> ReLU -> conv -> BN
        residual = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.convr is None else self.convr(x)
        return self.relu(shortcut + residual)

class ResNet20(nn.Module):
    """ResNet-20-style classifier built as one ``nn.Sequential`` (``self.blocks``).

    Layout: 3x3 stem conv (3 -> 16 channels) + ReLU, three stages of three
    ResNetBlocks at 16, 32 and 64 channels (the first block of stages 2 and 3
    downsamples), then AvgPool2d(8), Flatten and a Linear head to ``n_classes``.
    The pooling size presumably assumes 32x32 inputs (8x8 after two downsamples) —
    TODO confirm. SGM indexes ``blocks[0]``, ``blocks[1]`` and ``blocks[-3:]`` directly,
    so the layer order is fixed.
    """

    def __init__(self, n_classes):
        super(ResNet20, self).__init__()
        width = 16
        layers = [MainConv(3, width, 3, stride = 1, padding = 1), nn.ReLU(inplace = True)]
        for stage in range(3):
            for idx in range(3):
                # only the first block of the second and third stage halves resolution
                shrink = stage > 0 and idx == 0
                layers.append(ResNetBlock(width, shrink))
                if shrink:
                    width *= 2
        layers.extend([nn.AvgPool2d(8), nn.Flatten(), nn.Linear(width, n_classes)])
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)

class PARLModel(nn.Module):
    """Ensemble of independent ``ResNet20`` classifiers.

    Args:
        n_classes: number of output classes of every member.
        n_ensemble: number of classifiers in the ensemble.

    ``forward`` returns a list with one logit tensor per member; it does not
    average or vote.
    """

    def __init__(self, n_classes = 10, n_ensemble = 3):
        super(PARLModel, self).__init__()
        self.nets = nn.ModuleList([ResNet20(n_classes) for _ in range(n_ensemble)])
        self.n_ensemble = n_ensemble

    def forward(self, x):
        return [net(x) for net in self.nets]

    def init_weights(self):
        """Re-initialise parameters of every sub-module in place.

        - Conv2d weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))); conv biases untouched.
        - Linear weights ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in)); bias zeroed.
        - BatchNorm1d/BatchNorm2d scale set to 1 and shift to 0 (when affine).
        """
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.Linear):
                    stdv = 1. / math.sqrt(m.weight.size(1))
                    m.weight.uniform_(-stdv, stdv)
                    if m.bias is not None:
                        m.bias.zero_()
                # The residual blocks use BatchNorm2d, so it must be matched here too.
                elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                    # affine=False batch norm has no weight/bias
                    if m.weight is not None:
                        m.weight.fill_(1)
                    if m.bias is not None:
                        m.bias.zero_()

**Attack Codes and Utils**

In [4]:
def FGM(model, X, y = None, eps = 0.03, loss_fn = nn.CrossEntropyLoss(), clip = [0.0, 1.0]):
    """Single-step fast-gradient-sign attack against an ensemble.

    Args:
        model: ensemble whose forward returns a list of per-member logits.
        X: input batch; moved to ``device``. The caller's tensor is not modified.
        y: target labels. If None, the per-sample majority vote (``torch.mode``)
            of the members' argmax predictions on X is used.
        eps: step size and L-inf budget, in the same units as X.
        loss_fn: loss applied to every member's logits and averaged.
        clip: (low, high) bounds; floats or tensors broadcastable to X.

    Returns:
        Detached tensor ``X + eps * sign(dLoss/dX)`` clamped to ``clip``.
    """
    model.eval()
    X = X.to(device).detach()
    # Accept float or tensor bounds (the float defaults have no `.to`).
    low, high = (torch.as_tensor(c, dtype = X.dtype, device = device) for c in clip)
    if y is None:
        with torch.no_grad():
            votes = torch.stack([torch.argmax(y_pred, -1) for y_pred in model(X)], 1)
        y = torch.mode(votes).values
    else:
        y = y.to(device)

    # Fresh leaf so requires_grad is never set on the caller's tensor.
    X_in = X.clone().requires_grad_(True)
    y_pred_ens = model(X_in)
    loss = sum([loss_fn(y_pred, y) for y_pred in y_pred_ens]) / len(y_pred_ens)
    # autograd.grad returns d(loss)/dX without accumulating .grad on model parameters.
    grad, = torch.autograd.grad(loss, X_in)
    with torch.no_grad():
        X_adv = X + eps * torch.sign(grad)
        X_adv = torch.maximum(torch.minimum(X_adv, X + eps), X - eps)
        X_adv.clamp_(low, high)

    return X_adv

def PGD(model, X, y = None, eps = 0.03, n_iter = 50, eps_iter = None, loss_fn = nn.CrossEntropyLoss(), clip = [0.0, 1.0]):
    """L-inf projected gradient descent with a uniform random start.

    Starts from X plus uniform noise in [-eps, eps] (clamped to ``clip``). It then
    takes ``n_iter`` FGM sign steps of size ``eps_iter``, projecting back onto the
    eps-ball around X after every step.

    Args:
        model: ensemble whose forward returns a list of per-member logits.
        X: clean input batch; moved to ``device``. Not modified.
        y: labels; if None, the members' majority-vote prediction on X.
        eps: L-inf budget in the units of X.
        n_iter: number of FGM steps.
        eps_iter: per-step size; defaults to ``eps / n_iter``.
        loss_fn: loss averaged over ensemble members.
        clip: (low, high) bounds; floats or tensors broadcastable to X.

    Returns:
        Detached adversarial batch.
    """
    model.eval()
    X = X.to(device).detach()
    # Tensor bounds on `device`: float defaults work, and FGM gets tensors it can `.to()`.
    clip = [torch.as_tensor(c, dtype = X.dtype, device = device) for c in clip]
    if y is None:
        with torch.no_grad():
            votes = torch.stack([torch.argmax(y_pred, -1) for y_pred in model(X)], 1)
        y = torch.mode(votes).values
    else:
        y = y.to(device)
    if eps_iter is None:
        eps_iter = eps / n_iter

    # Random start inside the eps-ball, pulled back into the valid range.
    X_adv = X + (torch.rand_like(X) * 2 * eps - eps)
    X_adv.clamp_(clip[0], clip[1])

    for _ in tqdm(range(n_iter), desc = 'Iteration'):
        X_adv = FGM(model, X_adv, y = y, eps = eps_iter, loss_fn = loss_fn, clip = clip)
        with torch.no_grad():
            # project onto the eps-ball around the clean input
            X_adv = torch.maximum(torch.minimum(X_adv, X + eps), X - eps)

    return X_adv

def MDI2_FGSM(model, X, y = None, eps = 0.03, n_iter = 50, eps_iter = None, decay = 0.1, prob = 0.1, resize_frac = 0.9, loss_fn = nn.CrossEntropyLoss(), clip = [0.0, 1.0]):
    """Momentum, diverse-input iterative FGSM against an ensemble.

    Each iteration works in four steps:
    1. Draw a random shrink ``pad_i`` in ``[1, max(floor((1 - resize_frac) * W), 1)]``.
    2. With probability ``prob`` per sample, replace the current adversarial image by
       a version resized down by ``pad_i`` pixels and zero-padded back at a random offset.
    3. Compute the ensemble-averaged loss gradient and L1-normalise it per sample.
       Accumulate it as ``gn = decay * gn + g``.
    4. Step by ``eps_iter * sign(gn)``, project onto the eps-ball around X, and clamp to ``clip``.

    Args:
        model: ensemble whose forward returns a list of per-member logits.
        X: clean input batch; moved to ``device``. Not modified.
        y: labels; if None, the members' majority-vote prediction on X.
        eps: L-inf budget in the units of X.
        n_iter: number of iterations.
        eps_iter: per-step size; defaults to ``eps / n_iter``.
        decay: momentum factor.
        prob: per-sample probability of using the transformed input.
        resize_frac: lower bound on the relative size after shrinking.
        loss_fn: loss averaged over ensemble members.
        clip: (low, high) bounds; floats or tensors broadcastable to X.

    Returns:
        Detached adversarial batch.
    """
    model.eval()
    X = X.to(device).detach()
    # Accept float or tensor bounds (the float defaults have no `.to`).
    low, high = (torch.as_tensor(c, dtype = X.dtype, device = device) for c in clip)
    if y is None:
        with torch.no_grad():
            votes = torch.stack([torch.argmax(y_pred, -1) for y_pred in model(X)], 1)
        y = torch.mode(votes).values
    else:
        y = y.to(device)
    if eps_iter is None:
        eps_iter = eps / n_iter

    X_adv = X.clone()
    gn = 0   # momentum buffer; becomes a tensor after the first update
    # Assumes square images: Resize below targets (img_dims - pad_i, img_dims - pad_i).
    img_dims = X.shape[-1]
    max_pad = max(int(math.floor((1 - resize_frac) * img_dims)), 1)
    possible_pad = list(range(1, max_pad + 1))

    for _ in tqdm(range(n_iter), desc = 'Iteration'):
        X_adv.requires_grad_(True)
        pad_i = random.choice(possible_pad)
        left_pad = random.choice(range(pad_i + 1))
        top_pad = random.choice(range(pad_i + 1))
        right_pad = pad_i - left_pad
        bottom_pad = pad_i - top_pad
        mask = torch.where(torch.rand((len(X), 1, 1, 1)) < prob, 1.0, 0.0).to(device)
        X_r = trans.Resize((img_dims - pad_i, img_dims - pad_i), antialias = True)(X_adv)
        X_T = torch.nn.functional.pad(X_r, (left_pad, right_pad, top_pad, bottom_pad))
        X_T = mask * X_T + (1 - mask) * X_adv
        y_pred_ens = model(X_T)
        loss = sum([loss_fn(y_pred, y) for y_pred in y_pred_ens]) / len(y_pred_ens)
        # Gradient w.r.t. the input only; model parameters get no .grad accumulated.
        grad, = torch.autograd.grad(loss, X_adv)

        with torch.no_grad():
            gn = decay * gn + grad / (torch.norm(grad, p = 1, dim = (1, 2, 3), keepdim = True) + 1e-12)
            X_adv = X_adv + eps_iter * torch.sign(gn)
            X_adv = torch.maximum(torch.minimum(X_adv, X + eps), X - eps)
            X_adv.clamp_(low, high)

    return X_adv

def SGM(model, X, y = None, eps = 0.03, gamma = 0.5, loss_fn = nn.CrossEntropyLoss(), clip = [0.0, 1.0]):
    """Single-step, skip-gradient-style (SGM) sign attack; specific to this file's ResNet20.

    The forward pass of every member is re-run by hand:
    - The stem and ResNetBlocks 1-6 run normally.
    - For blocks 7-9, the shortcut output (``skip_out``) and the residual-branch output
      (``conv_out``) are kept separately, together with the running activation before
      each of those blocks and after the last (``z``).

    Per member, the input gradients of the summed ``z``/``skip_out``/``conv_out`` tensors
    and of the loss are combined elementwise as
    ``dz0 * prod_j((d_skip_j + gamma * d_conv_j) / dz_j) * d_loss / dz3``.
    The result is averaged over members, and one ``eps`` sign step is taken.

    Args:
        model: PARLModel-like ensemble exposing ``nets[i].blocks`` and ``n_ensemble``.
        X: input batch; moved to ``device``. Not modified.
        y: labels; if None, the members' majority-vote prediction on X.
        eps: step size and L-inf budget in the units of X.
        gamma: weight on the residual-branch gradient (gamma < 1 favours the skip path).
        loss_fn: per-member loss.
        clip: (low, high) bounds; floats or tensors broadcastable to X.

    Returns:
        Detached adversarial batch clamped to ``clip``.
    """
    model.eval()
    n_ens = model.n_ensemble
    X = X.to(device).detach()
    # Accept float or tensor bounds (the float defaults have no `.to`).
    low, high = (torch.as_tensor(c, dtype = X.dtype, device = device) for c in clip)
    if y is None:
        with torch.no_grad():
            votes = torch.stack([torch.argmax(y_pred, -1) for y_pred in model(X)], 1)
        y = torch.mode(votes).values
    else:
        y = y.to(device)

    # Fresh leaf so requires_grad is never set on the caller's tensor.
    X_in = X.clone().requires_grad_(True)

    def d_dx(t, grad_out = None):
        # Gradient of t (weighted by grad_out) w.r.t. X_in only; the graph is kept
        # because many outputs share it. Model parameters get no .grad.
        return torch.autograd.grad(t, X_in, grad_outputs = grad_out, retain_graph = True)[0]

    # Stem: MainConv (blocks[0]) + ReLU (blocks[1]) of each member.
    cargo = [model.nets[i].blocks[1](model.nets[i].blocks[0](X_in)) for i in range(n_ens)]
    ensemble = [model.nets[i].modules() for i in range(n_ens)]

    # Walk all members' modules in lock-step; k counts ResNetBlocks (1..9).
    k = 0
    z, skip_out, conv_out = [None for _ in range(4)], [None for _ in range(3)], [None for _ in range(3)]
    for M in zip(*ensemble):
        if isinstance(M[0], ResNetBlock):
            k += 1
            if k < 7:
                # Blocks 1-6: ordinary forward (ends with the block's own ReLU).
                cargo = [M[i](cargo[i]) for i in range(n_ens)]
            else:
                # Blocks 7-9: keep the input, shortcut and residual branch separately.
                z[k - 7] = cargo
                if k > 7:
                    # the ReLU that ends the previous, manually-run block
                    cargo = [nn.functional.relu(c) for c in cargo]
                skip_out[k - 7] = [M[i].convr(cargo[i]) if M[i].convr is not None else cargo[i] for i in range(n_ens)]
                conv_out[k - 7] = [M[i].bn2(M[i].conv2(M[i].relu(M[i].bn1(M[i].conv1(cargo[i]))))) for i in range(n_ens)]
                cargo = [torch.add(skip_out[k - 7][i], conv_out[k - 7][i]) for i in range(n_ens)]

    z[3] = cargo
    cargo = [nn.functional.relu(c) for c in cargo]
    # Head: AvgPool2d, Flatten, Linear (blocks[-3], blocks[-2], blocks[-1]).
    y_pred_ens = [model.nets[i].blocks[-1](model.nets[i].blocks[-2](model.nets[i].blocks[-3](cargo[i]))) for i in range(n_ens)]
    loss = [loss_fn(y_pred, y) for y_pred in y_pred_ens]

    # One combined gradient per member; sized by the ensemble, so any n_ensemble works.
    per_member = []
    for i in range(n_ens):
        dz_dx = [d_dx(z[j][i], torch.ones_like(z[j][i])) for j in range(4)]
        d_skip_out_dx = [d_dx(skip_out[j][i], torch.ones_like(skip_out[j][i])) for j in range(3)]
        d_conv_out_dx = [d_dx(conv_out[j][i], torch.ones_like(conv_out[j][i])) for j in range(3)]
        d_loss_dx = d_dx(loss[i])

        # NOTE(review): elementwise ratios of input gradients; entries where dz_dx == 0
        # produce inf/nan, which torch.sign propagates into X_adv.
        grad1 = (d_skip_out_dx[0] + gamma * d_conv_out_dx[0]) / dz_dx[0]
        grad2 = (d_skip_out_dx[1] + gamma * d_conv_out_dx[1]) / dz_dx[1]
        grad3 = (d_skip_out_dx[2] + gamma * d_conv_out_dx[2]) / dz_dx[2]
        grad4 = d_loss_dx / dz_dx[3]
        per_member.append(dz_dx[0] * grad1 * grad2 * grad3 * grad4)

    overall_grad = sum(per_member) / n_ens
    with torch.no_grad():
        X_adv = X + eps * torch.sign(overall_grad)
        X_adv = torch.maximum(torch.minimum(X_adv, X + eps), X - eps)
        X_adv.clamp_(low, high)

    return X_adv

def multi_attack_eval(model, x, y, show_indiv = False):
    """Print worst-case accuracy of an ensemble over several adversarial batches.

    A sample counts as correct only if it is classified correctly on every batch
    in ``x``. This is done for the per-sample majority vote (``torch.mode``) of the
    members and, when ``show_indiv`` is set, for each member individually.

    Args:
        model: ensemble with ``n_ensemble`` whose forward returns per-member logits.
        x: iterable of adversarial input batches, each aligned with ``y``.
        y: true labels.
        show_indiv: also print every member's accuracy.

    Returns:
        List of accuracies in percent: [majority vote, member 1, ..., member n].
    """
    # BatchNorm must use running statistics. Models that were never passed to an
    # attack function would otherwise still be in training mode.
    model.eval()
    y_true = y.to(device)
    n = len(y_true)
    correct = [torch.ones_like(y_true, dtype = torch.bool) for _ in range(model.n_ensemble + 1)]
    with torch.no_grad():
        for Xi in x:
            # move one batch at a time instead of all batches up front
            member_preds = [torch.argmax(logits, -1) for logits in model(Xi.to(device))]
            vote = torch.mode(torch.stack(member_preds, 1)).values
            correct[0] &= vote == y_true
            for k, pred in enumerate(member_preds, start = 1):
                correct[k] &= pred == y_true
    acc = [100 * c.sum().item() / n for c in correct]
    print(f'Combined Ensemble Accuracy: \033[92m{np.around(acc[0], 2)}%\033[0m')
    if show_indiv:
        for i in range(1, model.n_ensemble + 1):
            print(f'Model-{i} Accuracy: {np.around(acc[i], 2)}%')
    return acc

**Load Clean Data**

In [5]:
dataset = 'CIFAR-10' # CIFAR-10, CIFAR-100

# Per-channel normalisation statistics, shaped (1, 3, 1, 1) to broadcast over NCHW batches.
# avg_std is the mean of the three std values; later cells divide pixel-space eps by it
# because the attacks operate on normalised inputs.
if dataset == 'CIFAR-10':
    mean = torch.tensor([0.4942, 0.4851, 0.4504]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.2467, 0.2429, 0.2616]).view(1, 3, 1, 1).to(device)
    n_classes = 10
    avg_std = 0.2504
elif dataset == 'CIFAR-100':
    mean = torch.tensor([0.5088, 0.4874, 0.4419]).view(1, 3, 1, 1).to(device)
    std = torch.tensor([0.2683, 0.2574, 0.2771]).view(1, 3, 1, 1).to(device)
    n_classes = 100
    avg_std = 0.2676
else:
    # Fail here rather than with a NameError on `mean`/`std` further down.
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'CIFAR-10' or 'CIFAR-100'")

# The valid pixel range [0, 1] expressed in normalised units; used as attack clip bounds.
clip_max = (1.0 - mean) / std
clip_min = (0.0 - mean) / std

# Pre-saved clean evaluation subset, presumably already normalised (the attacks clip to
# normalised bounds) — TODO confirm.
x_clean = torch.from_numpy(np.load(f'{dataset}_x_clean_1000.npy'))
y_clean = torch.from_numpy(np.load(f'{dataset}_y_clean_1000.npy'))

**Load Pre-trained Models**

In [11]:
n_ensemble = 3 # Number of classifiers in the ensemble

def load_ensemble(path):
    """Build a PARLModel, load the state dict at `path` onto `device`, and return it in eval mode."""
    model = PARLModel(n_classes = n_classes, n_ensemble = n_ensemble).to(device)
    # NOTE: torch.load unpickles the checkpoint; only load files from trusted sources.
    model.load_state_dict(torch.load(path, map_location = torch.device(device)))
    # eval(): BatchNorm must use running statistics. The attack functions set this
    # themselves, but surrogate_model_2 and target_model are only ever evaluated.
    return model.eval()

# Ensemble trained with PARL
target_model = load_ensemble(f'PARL5_{dataset}_ResNet20_3_50epochs_gamma_0.25.pth')

# Surrogate ensemble trained with cross entropy loss only, on which the adversarial examples are crafted
surrogate_model_1 = load_ensemble(f'Surrogate1_{dataset}_ResNet20.pth')

# Surrogate ensemble trained with cross entropy loss only, which gives the black box robust accuracy of surrogate model
surrogate_model_2 = load_ensemble(f'Surrogate2_{dataset}_ResNet20.pth')

<All keys matched successfully>

**Perform PGD Attack**

In [None]:
for eps in [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07]:
    print(f'Epsilon {eps}')
    # File-name tag, e.g. 0.03 -> '003'; round() avoids float truncation (int(100 * 0.29) == 28).
    tag = f'00{round(100 * eps)}'
    for r in range(1, 4):
        print(f'Restart-{r}: ', end = '', flush = True)
        torch.manual_seed(r)   # each restart gets its own reproducible random start
        # since the input is in normalized form, we should adjust eps and eps_iter accordingly
        x_adv = PGD(surrogate_model_1, x_clean, eps = eps / avg_std, n_iter = 100, eps_iter = eps / (5 * avg_std), clip = [clip_min, clip_max])
        np.save(f'{dataset}_ResNet20_X_PGD_R{r}_{tag}.npy', x_adv.cpu().numpy())
        del x_adv
        torch.cuda.empty_cache()

**Perform MDI2-FGSM Attack**

In [None]:
for eps in [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07]:
    print(f'Epsilon {eps}')
    # MDI2-FGSM draws pad sizes with `random` and input masks with torch; seed both.
    random.seed(0)
    torch.manual_seed(0)
    # since the input is in normalized form, we should adjust eps and eps_iter accordingly
    x_adv = MDI2_FGSM(surrogate_model_1, x_clean, eps = eps / avg_std, n_iter = 100, eps_iter = eps / (5 * avg_std), clip = [clip_min, clip_max])
    # round() avoids float truncation in the file-name tag (e.g. 0.03 -> '003')
    np.save(f'{dataset}_ResNet20_X_MDI2-FGSM_00{round(100 * eps)}.npy', x_adv.cpu().numpy())
    del x_adv
    torch.cuda.empty_cache()

**Perform SGM Attack**

In [None]:
for eps in [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07]:
    print(f'Epsilon {eps}\n')
    # since the input is in normalized form, we should adjust eps accordingly (SGM is a single step)
    x_adv = SGM(surrogate_model_1, x_clean, eps = eps / avg_std, clip = [clip_min, clip_max])
    # round() avoids float truncation in the file-name tag (e.g. 0.03 -> '003')
    np.save(f'{dataset}_ResNet20_X_SGM_00{round(100 * eps)}.npy', x_adv.cpu().numpy())
    del x_adv
    torch.cuda.empty_cache()

**Evaluate All Attacks**

In [None]:
show_indiv = True

for eps in [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07]:
    print(f'\nEpsilon {eps}\n')
    # Must match the tags written by the attack cells; round() avoids float truncation.
    tag = f'00{round(100 * eps)}'
    sample_list = [f'{dataset}_ResNet20_X_PGD_R{r}_{tag}.npy' for r in range(1, 4)] + [
        f'{dataset}_ResNet20_X_MDI2-FGSM_{tag}.npy',
        f'{dataset}_ResNet20_X_SGM_{tag}.npy',
    ]
    x_adv = [torch.from_numpy(np.load(sample)) for sample in sample_list]
    print()
    for label, ens_model in [('SURROGATE-1: ', surrogate_model_1),
                             ('SURROGATE-2: ', surrogate_model_2),
                             ('TARGET     : ', target_model)]:
        print(label, end = '')
        # BatchNorm must use running statistics; only surrogate_model_1 went through an attack's eval().
        ens_model.eval()
        multi_attack_eval(ens_model, x_adv, y_clean, show_indiv = show_indiv)
        print()
    del x_adv
    torch.cuda.empty_cache()