# In [None]:
import os
import sys
# Pin GPU selection via environment variables. They are set here, before
# torch (imported below) initialises CUDA, so that they take effect.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # enumerate GPUs by PCI bus ID
    "CUDA_VISIBLE_DEVICES": "0",        # expose only device 0 to this process
})
#from pudb import set_trace
import numpy as np
import torch
from torchvision import models
from PIL import Image
from skimage.io import imsave
import torch.backends.cudnn as cudnn
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from utils_rb import *
from ramboattack import RamBoAtt
from HSJA_rb import HSJA
from SignOPT_rb import OPT_attack_sign_SGD
def imshow(img):
    """Show the first image of a batched tensor with matplotlib.

    Only ``img[0]`` is displayed. It is moved to the CPU and converted to a
    NumPy array. Its first axis is moved last (channel-first -> channel-last,
    assuming a CHW layout - confirm against the dataset loader). Values are
    clipped to [0, 1] before plotting.
    """
    first = img[0].cpu().numpy()
    channels_last = np.clip(first.transpose(1, 2, 0), 0, 1)
    plt.imshow(channels_last)
    plt.show()
def perturbation_heat_map(xo, xa):
    """Plot where two images differ, as a seaborn heat map.

    Computes ``|xo - xa|`` and sums it over dimension 1 (presumably the
    channel axis of a (batch, C, H, W) tensor - confirm). The first batch
    entry is then drawn on a 5x5-inch figure, without tick labels or a
    colour bar.
    """
    _, axis = plt.subplots(figsize=(5, 5))
    diff_map = (xo - xa).abs().sum(dim=1).cpu()[0]
    sns.heatmap(diff_map, ax=axis, xticklabels=False, yticklabels=False, cbar=False)
    plt.show()
# a. Load dataset
# load_data (from utils_rb) returns a data loader plus the dataset object
# it wraps; one sample is drawn per batch.
dataset = 'cifar10'
datapath = '../datasets/cifar10'
batch_size = 1
testloader, testset = load_data(dataset, data_path=datapath, batch_size=batch_size)
# b. Load pre-trained model

# 'resnet50' for a pre-trained model from PyTorch; 'cifar10' for the
# pre-trained CIFAR-10 model.
arch = 'cifar10'

# None means use the PyTorch pre-trained weights or the default path.
# Otherwise set model_path = '...'.
model_path = None

# Per-dataset settings. unnorm=True means the pre-trained model did NOT
# normalize its inputs during training, so no un-normalization is needed at
# inference time (this is the case for the CIFAR-10 model).
if dataset == 'cifar10':
    num_classes = 10
    unnorm = True
else:
    # Without explicit settings num_classes/unnorm would be unbound and the
    # model wrappers below would fail with an unhelpful NameError.
    raise ValueError(
        f"Unsupported dataset {dataset!r}: define num_classes and unnorm for it here."
    )

net = load_model(arch, model_path)
model_rb = PretrainedModel(net, dataset, unnorm)

bounds = [0, 1]  # presumably the valid pixel-value range of model inputs
model_ex = PytorchModel_ex(net, bounds, num_classes, dataset, unnorm)
# c. Load evaluation set
targeted = True  # True -> targeted attack, False -> untargeted attack.

# Available evaluation sets, by dataset:
#   'balance', 'easyset'                  -> imagenet or cifar10
#   'hardset'                             -> imagenet
#   'hardset_A', 'hardset_B', 'hardset_D' -> cifar10
eval_set = 'hardset_B'

ID_set = get_evalset(dataset, targeted, eval_set)


# successful_attacks = 0 # count the number of successful attacks
# query_list=[]
# for _ in range(100): # run the attack 100 times
#     i = 2 # 0,1,2,10,20,50,123: the sample i-th in the evaluation set
#     # Define the list of options
#     options = [0, 1, 2, 10, 20, 50, 123]

#     # Choose a random value from the list
#     i = random.choice(options)
#     query_limit = 50000
#     D = np.zeros(query_limit+2000)
#     nquery = 0
#     o = ID_set[i,1] # oID
#     # 0. select original image
#     oimg, olabel = testset[o]
#     oimg = torch.unsqueeze(oimg, 0).cuda()

#     # 1. select starting image
#     if targeted:
#         x = random.randint(1, 4)
#         y =random.randint(1,4)
#         t = ID_set[i,3] # tID, 3 is index across dataset - 4 is sample index in a class (not across dataset)
#         tlabel = ID_set[i,2]
#         timg, _ = testset[t]
#         timg = torch.unsqueeze(timg, 0).cuda()
#     else:
#         tlabel = None
#     attack_mode = 'RBS' # 'RBH' means RamBoAttack(HSJA) while 'RBS' means RamBoAttack(Sign-OPT) -> see our paper
#     seed = 0
#     query_limit = 50000
#     module = RamBoAtt(model_rb, model_ex, testset, seed, targeted, dataset)
#     adv, nqry, Dt = module.hybrid_attack(oimg, olabel, timg, tlabel, query_limit, attack_mode)
#     print(nqry)
#     query_list.append(nqry)
#     if nqry > 1400: # check if the attack succeeded with more than 1000 queries
#         successful_attacks += 1

# print("Number of successful attacks with more than 1000 queries:", successful_attacks)

# Run one attack on a single sample of the evaluation set.
# ID_set columns as used here:
#   1 = original image index (oID)
#   2 = target label
#   3 = target image index across the dataset (tID)
#   4 = sample index within its class (not across the dataset)
i = 2  # sample index in the evaluation set, e.g. 0, 1, 2, 10, 20, 50, 123
query_limit = 10000
D = np.zeros(query_limit + 2000)  # not read by the visible cells; kept for later use
nquery = 0
o = ID_set[i, 1]  # oID -- column 2 is the target label, not an image index

# 0. select original image
oimg, olabel = testset[o]
oimg = torch.unsqueeze(oimg, 0).cuda()

# 1. select starting image
if targeted:
    t = ID_set[i, 3]  # tID
    tlabel = ID_set[i, 2]
    timg, _ = testset[t]
    timg = torch.unsqueeze(timg, 0).cuda()
else:
    # hybrid_attack below always receives timg, so it must be bound here too.
    # Assumes None is accepted for untargeted attacks - confirm in RamBoAtt.
    timg = None
    tlabel = None

attack_mode = 'RBS'  # 'RBH' = RamBoAttack(HSJA), 'RBS' = RamBoAttack(Sign-OPT) -> see our paper
seed = 0
module = RamBoAtt(model_rb, model_ex, testset, seed, targeted, dataset)
adv, nqry, Dt = module.hybrid_attack(oimg, olabel, timg, tlabel, query_limit, attack_mode)


def save_img(img, filename):
    """Save the first image of a batched tensor to ``filename`` using PIL.

    Processing of ``img[0]``:
      - detach it and move it to the CPU;
      - transpose it from channel-first to channel-last (assumes a CHW layout);
      - clip it to [0, 1] and scale it to 8-bit.

    PIL picks the file format from the extension. Use a lossless format such
    as PNG when exact pixel values matter (e.g. adversarial examples), because
    JPEG compression alters them.
    """
    # detach() so tensors that require grad can still be converted to NumPy.
    npimg = img[0].detach().cpu().numpy()
    npimg = np.transpose(npimg, (1, 2, 0))
    npimg = np.clip(npimg, 0, 1)
    # Round instead of truncating, so e.g. 0.999 maps to 255 rather than 254.
    pil_img = Image.fromarray(np.round(npimg * 255).astype(np.uint8))
    pil_img.save(filename)
    
# Show the source image, the starting image (targeted attacks only), the
# adversarial example and the perturbation heat map.
print('Source image:')
imshow(oimg)

if targeted:
    # Only targeted attacks have a starting (target-class) image; `t` is
    # undefined otherwise. The image is re-loaded from the dataset rather than
    # reusing timg, in case the attack changed that tensor in place
    # (unverified - check RamBoAtt.hybrid_attack).
    print('Starting image:')
    timg, _ = testset[t]
    timg = torch.unsqueeze(timg, 0).cuda()
    imshow(timg)

print('Adversarial Example:')
imshow(adv)
# Save losslessly: JPEG compression changes pixel values and can undo the
# small adversarial perturbation.
save_img(adv, 'adv.png')

print('Perturbation Heat Map:')
perturbation_heat_map(oimg, adv)