# Asymmetrical Adversarial Training (Integrated classifier)

In [None]:
import numpy as np
import tensorflow as tf

import sys
sys.path.append("defense")

from defense import cifar10_input
from defense.model import Model, BayesClassifier
from defense.eval_utils import *
from defense.pgd_attack import PGDAttackCombined, PGDAttack

## load data

In [2]:
cifar = cifar10_input.CIFAR10Data('defense/cifar10_data')
eval_data = cifar.eval_data

# Only the first `num_eval_examples` evaluation examples are used below.
num_eval_examples = 1000
# Slice *before* casting: casting the whole split first allocates a float32
# copy of every evaluation image, and the resulting slice (a view) would keep
# that full copy alive. Casting the slice copies only the rows we use.
x_test = eval_data.xs[:num_eval_examples].astype(np.float32)
y_test = eval_data.ys[:num_eval_examples].astype(np.int32)

## load model

In [None]:
# Fix NumPy's global seed so any NumPy-based randomness (presumably including
# PGD random starts — confirm in PGDAttack) is reproducible across runs.
np.random.seed(123)
sess = tf.Session()

# Build the naturally trained classifier under the 'classifier' variable scope
# and create a Saver restricted to exactly those variables, so its checkpoint
# can be restored independently of the detectors' variables.
classifier = Model(mode='eval', var_scope='classifier')
classifier_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope='classifier')
classifier_saver = tf.train.Saver(var_list=classifier_vars)
classifier_checkpoint = 'models/naturally_trained_prefixed_classifier/checkpoint-70000'

# NOTE(review): BaseDetectorFactory is not imported by name here; presumably it
# comes from `from defense.eval_utils import *` — confirm. The factory is
# created before the restores, presumably because it builds the detector
# graphs whose variables restore_base_detectors() then loads — verify.
factory = BaseDetectorFactory()
classifier_saver.restore(sess, classifier_checkpoint)
factory.restore_base_detectors(sess)

base_detectors = factory.get_base_detectors()
# NOTE(review): bayes_classifier is not used by any of the cells shown here.
bayes_classifier = BayesClassifier(base_detectors)

# Evaluate (presumably natural / clean) accuracy at each detection threshold in
# `logit_threshs`, on the same test examples that are attacked below. The
# result is later used to choose the operating threshold `tau`.
# NOTE(review): `logit_threshs` is not defined in this notebook; presumably it
# is provided by `from defense.eval_utils import *` — confirm.
nat_accs = get_nat_accs(x_test, y_test, logit_threshs, classifier, base_detectors, sess)

## our targeted PGD attack

In [4]:
# L-infinity PGD settings: epsilon 8.0, 100 iterations, a random start, and a
# per-step size of 2.5 * epsilon / num_steps (= 0.2).
eps8_attack_config = dict(
    epsilon=8.0,
    num_steps=100,
    step_size=2.5 * 8.0 / 100,
    random_start=True,
    norm='Linf',
)

class PGDAttackOpt(PGDAttack):
    """Targeted PGD attack against a classifier plus one base detector.

    Per example, the loss first drives the naive classifier towards
    ``base_detector.target_class``. Once that class beats every other class by
    a 0.01 logit margin, the loss switches to the detector's logit for that
    class. The original comments say the loss is to be *maximized*, so this
    assumes PGDAttack performs gradient ascent on ``self.loss`` — confirm.
    """
    def __init__(self, naive_classifier, base_detector, **kwargs):
        """Build the attack graph.

        Args:
            naive_classifier: model exposing ``forward(x)`` that returns class
                logits (presumably shape [batch, 10] — TODO confirm).
            base_detector: detector exposing ``forward(x)`` and
                ``target_class`` (the class index this attack targets).
            **kwargs: forwarded unchanged to ``PGDAttack.__init__``.
        """
        super().__init__(**kwargs)

        # NOTE(review): these attributes are (re)assigned after
        # super().__init__, so they override any placeholders/loss/grad the
        # parent may define — confirm against PGDAttack.
        self.x_input = tf.placeholder(dtype=tf.float32, shape=[None, 32, 32, 3], name='x_input')
        # y_input does not enter the loss; presumably kept so batched_perturb
        # can feed labels uniformly — confirm.
        self.y_input = tf.placeholder(tf.int64, shape=[None], name='y_input')
        clf_logits = naive_classifier.forward(self.x_input)
        det_logits = base_detector.forward(self.x_input)

        # One-hot vector of length 10 for the target class; it broadcasts
        # across the batch dimension.
        label_mask = tf.one_hot(base_detector.target_class, 10, dtype=tf.float32)

        clf_target_logit = tf.reduce_sum(label_mask * clf_logits, axis=1)
        # Largest non-target logit: the target entry is zeroed and then pushed
        # down by 1e4 so it cannot win the max (assumes logits are > -1e4).
        clf_other_logit = tf.reduce_max((1 - label_mask) * clf_logits - 1e4 * label_mask, axis=1)

        # reduce_sum over axis=1 requires det_logits to be rank >= 2
        # (presumably [batch, 10] — TODO confirm).
        det_target_logit = tf.reduce_sum(label_mask * det_logits, axis=1)

        # maximize target logit and minimize 2nd best logit until we have a targeted misclassification
        # mask == 1 where the target already wins by the 0.01 margin; it comes
        # from a comparison, so no gradient flows through the mask itself.
        mask = tf.cast(tf.greater(clf_target_logit - 0.01, clf_other_logit), tf.float32)
        clf_loss = (1-mask) * (clf_target_logit - clf_other_logit)

        # just maximize the target logit for the detector once we have a misclassification
        det_loss = mask * det_target_logit

        # Per-example loss: exactly one of the two terms is active per example.
        self.loss = clf_loss + det_loss
        # tf.gradients differentiates the sum of the per-example losses.
        self.grad = tf.gradients(self.loss, self.x_input)[0]

## multi-targeted attack

In [5]:
# Multi-targeted attack: run one targeted attack per class (one base detector
# per target) and, for every example, keep the misclassified adversarial input
# with the highest detector logit found so far.
opt_adv = x_test.copy()
best_logit = np.full(len(opt_adv), -np.inf)

for i in range(10):
    attack = PGDAttackOpt(classifier,
                          base_detectors[i],
                          **eps8_attack_config)

    x_test_adv = attack.batched_perturb(x_test, y_test, sess, batch_size=50)

    adv_preds = batched_run(classifier.predictions,
                            classifier.x_input, x_test_adv, sess)
    det_logits = get_det_logits(x_test_adv, adv_preds, base_detectors, sess)

    # Accept a candidate only if it is misclassified AND beats the best
    # detector logit recorded for that example.
    better = (adv_preds != y_test) & (det_logits > best_logit)
    best_logit[better] = det_logits[better]
    opt_adv[better] = x_test_adv[better]

    # Progress: the fraction of examples with a successful misclassification so
    # far, and their mean best detector logit. The mean is guarded because
    # np.mean of an empty selection returns nan and emits a RuntimeWarning.
    found = best_logit > -np.inf
    mean_logit = np.mean(best_logit[found]) if found.any() else float('nan')
    print(i, np.mean(found), mean_logit)

perturbed 0-50
perturbed 50-100
perturbed 100-150
perturbed 150-200
perturbed 200-250
perturbed 250-300
perturbed 300-350
perturbed 350-400
perturbed 400-450
perturbed 450-500
perturbed 500-550
perturbed 550-600
perturbed 600-650
perturbed 650-700
perturbed 700-750
perturbed 750-800
perturbed 800-850
perturbed 850-900
perturbed 900-950
perturbed 950-1000
0 0.754 -17.185676392572944
perturbed 0-50
perturbed 50-100
perturbed 100-150
perturbed 150-200
perturbed 200-250
perturbed 250-300
perturbed 300-350
perturbed 350-400
perturbed 400-450
perturbed 450-500
perturbed 500-550
perturbed 550-600
perturbed 600-650
perturbed 650-700
perturbed 700-750
perturbed 750-800
perturbed 800-850
perturbed 850-900
perturbed 900-950
perturbed 950-1000
1 0.922 -18.29826464208243
perturbed 0-50
perturbed 50-100
perturbed 100-150
perturbed 150-200
perturbed 200-250
perturbed 250-300
perturbed 300-350
perturbed 350-400
perturbed 400-450
perturbed 450-500
perturbed 500-550
perturbed 550-600
perturbed 600-650
p

## accuracy at 5% FPR

In [12]:
opt_adv_errors = get_adv_errors(
    opt_adv, y_test, logit_threshs, classifier, base_detectors, sess
)
# Operating point: the largest threshold index whose natural accuracy stays
# within 0.05 of the best natural accuracy over all thresholds.
tau = np.where(nat_accs >= np.max(nat_accs) - 0.05)[0].max()
print("acc: {:.1f}%".format(100 * (1 - opt_adv_errors[tau])))

acc: 14.1%
