In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from torch.autograd import Variable
import numpy as np


# Universal import block
# Put the parent of the current working directory on sys.path so that the
# project-local modules imported below can be resolved.
import os
import sys

module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    # Guarded so that re-running this cell does not add a duplicate entry.
    sys.path.append(module_path)

import config
import matplotlib.pyplot as plt 
import prebuilt_loss_functions as plf
import loss_functions as lf 
import utils.pytorch_utils as utils
import utils.image_utils as img_utils
import cifar10.cifar_loader as cifar_loader
import cifar10.cifar_resnets as cifar_resnets
import adversarial_attacks as aa
import adversarial_training as advtrain
import adversarial_evaluation as adveval
import utils.checkpoints as checkpoints
import adversarial_perturbations as ap
import adversarial_attacks_refactor as aar 
import spatial_transformers as st 
# Re-execute adversarial_perturbations right after importing it so that edits
# made to the module on disk are picked up without restarting the kernel.
# NOTE(review): bare `reload` is only a builtin on Python 2, which is what
# the print statements elsewhere in this notebook also require.
reload(ap)

In [None]:
""" Goal here is to make sure adversarial training works under the refactored model. It _should_, but it might need 
    a few tweaks 
"""

# Shared setup used by the later cells: classifier, data loaders, normalizer.
use_gpu = torch.cuda.is_available()

# Pretrained CIFAR ResNet (flavor=32) from cifar_loader, put into eval mode.
classifier_net = cifar_loader.load_pretrained_cifar_resnet(
    flavor=32,
    use_gpu=use_gpu)
classifier_net.eval()

# Both loaders are requested with normalize=False; a separate normalizer
# object (cifar_normer) is built below from the config means/stds.
train_loader = cifar_loader.load_cifar_data(
    'train',
    normalize=False,
    batch_size=16,
    use_gpu=use_gpu)
val_loader = cifar_loader.load_cifar_data(
    'val',
    normalize=False,
    batch_size=4,
    use_gpu=use_gpu)

cifar_normer = utils.DifferentiableNormalize(
    mean=config.CIFAR10_MEANS,
    std=config.CIFAR10_STDS)

# Grab the first minibatch from the validation loader.
examples, labels = next(iter(val_loader))


In [None]:
# Build two (threat model, attack, attack-parameters) triples:
#   1. an additive L-inf perturbation attacked with FGSM
#   2. a parameterized rotation transform attacked with PGD
# Reload first so the objects below come from the current module code.
reload(advtrain)
reload(aar)

# --- 1. Additive delta, L-inf bound of 8/255, FGSM -------------------------
delta_threat = ap.ThreatModel(
    ap.DeltaAddition,
    ap.PerturbationParameters(lp_style='inf', lp_bound=8.0 / 255.0))
loss_fxn = plf.VanillaXentropy(classifier_net, normalizer=cifar_normer)  # USE A PLF LOSS FXN
fgsm_attack = aar.FGSM(classifier_net, cifar_normer, delta_threat, loss_fxn)
attack_params = advtrain.AdversarialAttackParameters(
    fgsm_attack,
    1.0,
    attack_specific_params={'attack_kwargs': {'step_size': 0.1}})
print(attack_params)

# --- 2. Rotation transform, L-inf bound of 10/360, PGD ---------------------
rot_threat = ap.ThreatModel(
    ap.ParameterizedXformAdv,
    ap.PerturbationParameters(lp_style='inf',
                              lp_bound=10.0 / 360,
                              xform_class=st.RotationTransform))
loss_fxn_rot = plf.VanillaXentropy(classifier_net, normalizer=cifar_normer)  # USE A PLF LOSS FXN
rot_attack = aar.PGD(classifier_net, cifar_normer, rot_threat, loss_fxn_rot)
rot_attack_params = advtrain.AdversarialAttackParameters(
    rot_attack,
    1.0,
    attack_specific_params={'attack_kwargs': {'step_size': 0.01}})



In [None]:
# Sanity check: is attack_params an instance of the AdversarialAttackParameters
# class that `advtrain` currently exposes?
print(isinstance(attack_params, advtrain.AdversarialAttackParameters))

In [None]:
# Adversarially train the classifier against both attacks built earlier
# (attack_params for FGSM, rot_attack_params for the rotation PGD).
#
# NOTE: `advtrain` is deliberately NOT reloaded in this cell. Reloading a
# module rebinds its classes to brand-new class objects, while instances
# created beforehand keep pointing at the old ones. attack_params and
# rot_attack_params were constructed from the previously loaded `advtrain`,
# so a reload here would make them fail any
# isinstance(..., AdversarialAttackParameters) check performed against the
# reloaded module. If advtrain.py was edited, re-run the cell that builds the
# attack parameters (it reloads advtrain before constructing them) and then
# run this cell.

# Switch the network to training mode before handing it to the trainer.
classifier_net.train()
train_obj = advtrain.AdversarialTraining(classifier_net, cifar_normer,
                                         'refactor_test', 'resnet32')
train_loss = nn.CrossEntropyLoss()  # USE A PREBUILT TRAIN LOSS FXN

# NOTE(review): the positional `2` is presumably the number of epochs --
# confirm against AdversarialTraining.train's signature.
train_obj.train(train_loader, 2, train_loss,
                attack_parameters=[attack_params, rot_attack_params],
                verbosity="snoop",
                adversarial_save_dir='foobar')

