-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils_cifar.py
137 lines (110 loc) · 4.31 KB
/
utils_cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import yaml
import os
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
def _load_image(path):
    """Load one image stored as a ``.npy`` file and validate it.

    :param path: location of the ``.npy`` file to read.
    :return: the loaded numpy array, guaranteed to be ``uint8`` with shape
        ``(32, 32, 3)``.
    :raises AssertionError: if the stored array has another dtype or shape.
    """
    loaded = np.load(path)
    # Validate before handing the array back: dtype first, then shape.
    assert loaded.dtype == np.uint8
    assert loaded.shape == (32, 32, 3)
    return loaded
def _read_image(file_name, input_folder="test_images_imagenet/"):
    """Load one stored image and return it as a float32 numpy array.

    Only the image is returned; no label is read here.

    :param file_name: name of the ``.npy`` file inside *input_folder*.
    :param input_folder: directory holding the images (defaults to the
        previously hard-coded ``"test_images_imagenet/"``).
    :return: the image converted to ``np.float32``. Pixel values are
        unchanged by the conversion.
    :raises AssertionError: propagated from ``_load_image`` when the stored
        array is not a ``(32, 32, 3)`` ``uint8`` image.
    """
    img_path = os.path.join(input_folder, file_name)
    # _load_image already asserts the uint8 dtype, so only the conversion
    # is needed here.
    image = _load_image(img_path)
    return image.astype(np.float32)
def preprocess_cifar(x):
    """Preprocess CIFAR data: turn a CHW tensor into an HWC float32 array scaled by 255.

    :param x: a torch tensor shaped ``(N, C, H, W)``, of which only the
        first sample is used, or a single ``(C, H, W)`` sample. Values are
        presumably in ``[0, 1]`` as produced by ``transforms.ToTensor``
        (TODO confirm for other callers). The result is that input
        multiplied by 255.
    :return: numpy ``float32`` array shaped ``(H, W, C)``.
    """
    # detach()/cpu() make this work for tensors that require grad or live
    # on a GPU; calling .numpy() directly raises for both.
    sample = x.detach().cpu()
    if sample.dim() == 4:
        # DataLoader batches: keep only the first sample.
        sample = sample[0]
    pro_x = sample.numpy().astype(np.float32)
    pro_x = np.transpose(pro_x, (1, 2, 0))  # CHW -> HWC
    return pro_x * 255.0
def read_images(root='./cifar_pytorch/'):
    """Load the CIFAR-10 test split as a list of ``(key, image, label)`` tuples.

    The test split is downloaded into *root* if it is not already there.
    Each image goes through ``ToTensor`` and an identity ``Normalize``
    (mean 0, std 1). It is then converted by ``preprocess_cifar`` into an
    HWC ``float32`` numpy array scaled by 255.

    Tuple fields:
      * key   -- always ``0``, a placeholder where an earlier file-name
                 based loader returned the image's file name.
                 NOTE(review): every image gets the same key; confirm no
                 caller relies on it being unique (e.g. as a file name).
      * image -- the preprocessed numpy array.
      * label -- the ground-truth CIFAR-10 class index as ``int``. No
                 target labels are produced here, so this loader suits
                 untargeted attacks.

    :param root: dataset directory (defaults to the previously hard-coded
        ``'./cifar_pytorch/'``).
    """
    # The previous ImageNet/YAML-based loader was removed. If it is ever
    # revived, use yaml.safe_load: bare yaml.load is unsafe on untrusted
    # files and requires an explicit Loader in PyYAML >= 6.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0, 0, 0), (1, 1, 1)),
    ])
    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=1, shuffle=False, num_workers=2)
    return [(0, preprocess_cifar(images), int(labels))
            for images, labels in testloader]
def check_image(image):
    """Validate an image and return it as a ``uint8`` array.

    The image must be a 32 x 32 x 3 numpy array. ``float32`` input is
    accepted: values are clipped to ``[0, 255]``, a warning is logged for
    each clipped side, and the result is cast to ``uint8``. Any other dtype
    must already be ``uint8``; such an image is returned unchanged (same
    object).

    :param image: candidate image.
    :return: the ``uint8`` image.
    :raises AssertionError: on wrong type, shape, or a dtype other than
        ``float32``/``uint8``.
    """
    # The module never defines `logger`; without this, any out-of-range
    # float32 input raised NameError instead of being clipped.
    import logging
    logger = logging.getLogger(__name__)

    assert isinstance(image, np.ndarray)
    assert image.shape == (32, 32, 3)
    if image.dtype == np.float32:
        if image.min() < 0:
            logger.warning('clipped value smaller than 0 to 0')
        if image.max() > 255:
            logger.warning('clipped value greater than 255 to 255')
        image = np.clip(image, 0, 255)
        # NOTE(review): astype truncates toward zero, so e.g. 254.99998
        # becomes 254; consider np.rint first if rounding is intended.
        image = image.astype(np.uint8)
    assert image.dtype == np.uint8
    return image
def store_adversarial(file_name, adversarial, output_folder="avc_results"):
    """Save an adversarial image as ``<output_folder>/<stem>.npy``.

    Nothing is written when *adversarial* is ``None``. Otherwise the image
    is validated and converted to ``uint8`` by ``check_image``. Any
    extension on *file_name* is replaced, because ``np.save`` appends
    ``.npy`` itself.

    :param file_name: name used for the output file (its extension is
        dropped).
    :param adversarial: candidate image or ``None``.
    :param output_folder: destination directory, created if missing
        (defaults to the previously hard-coded ``"avc_results"``).
    """
    if adversarial is None:
        return
    adversarial = check_image(adversarial)
    # np.save fails with FileNotFoundError if the directory does not exist.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    path = os.path.join(output_folder, file_name)
    path_without_extension = os.path.splitext(path)[0]
    np.save(path_without_extension, adversarial)
def compute_MAD(results_folder='avc_results/', originals_folder='test_images/'):
    """Print and return distances between stored adversarials and originals.

    For every ``.npy`` file in *results_folder*, the file with the same name
    is loaded from *originals_folder*. The L2 distance between the two
    images, with pixels scaled to ``[0, 1]``, is then computed.

    A stored adversarial that fails validation (not a ``(32, 32, 3)``
    ``uint8`` array) is scored with the worst-case distance: every pixel is
    pushed to the farther end of ``[0, 255]``. Images that have no file in
    *results_folder* are not visited at all.

    The median and mean distance are printed.

    :param results_folder: directory of adversarial ``.npy`` files
        (defaults to the previously hard-coded ``'avc_results/'``).
    :param originals_folder: directory of original ``.npy`` images
        (defaults to the previously hard-coded ``'test_images/'``).
    :return: ``float64`` numpy array of per-file distances, ordered by
        sorted file name. It is empty if no ``.npy`` results exist.
    :raises AssertionError: if an *original* image is not a valid
        ``(32, 32, 3)`` ``uint8`` array.
    """
    def load_image(path):
        x = np.load(path)
        assert x.shape == (32, 32, 3)
        assert x.dtype == np.uint8
        return x

    def distance(X, Y):
        # L2 norm of the difference after scaling uint8 pixels to [0, 1].
        assert X.dtype == np.uint8
        assert Y.dtype == np.uint8
        X = X.astype(np.float64) / 255
        Y = Y.astype(np.float64) / 255
        return np.linalg.norm(X - Y)

    def worst_case_distance(X):
        # Distance used if no valid adversarial was found: dark pixels go
        # to 255, bright pixels to 0.
        assert X.dtype == np.uint8
        worst_case = np.zeros_like(X)
        worst_case[X < 128] = 255
        return distance(X, worst_case)

    real_distances = []
    # Sorted for a deterministic order; non-.npy files (store_adversarial
    # only writes .npy) would otherwise crash np.load.
    for file_name in sorted(os.listdir(results_folder)):
        if not file_name.endswith('.npy'):
            continue
        original = load_image(os.path.join(originals_folder, file_name))
        try:
            adversarial = load_image(os.path.join(results_folder, file_name))
        except AssertionError:
            adversarial = None  # invalid adversarial -> worst case
        if adversarial is None:
            _distance = float(worst_case_distance(original))
        else:
            _distance = float(distance(original, adversarial))
        real_distances.append(_distance)

    real_distances = np.array(real_distances, dtype=np.float64)
    if real_distances.size == 0:
        # Avoid printing nan (and a RuntimeWarning) for an empty folder.
        print("\tNo adversarial results found in {}".format(results_folder))
        return real_distances
    print("\tMedian Distance: %.6f" % np.median(real_distances))
    print("\tMean Distance: %.6f" % np.mean(real_distances))
    return real_distances