/
dknn.py
595 lines (505 loc) · 20 KB
/
dknn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
"""
This code reproduces the MNIST results from the paper
Deep k-Nearest Neighbors: Towards Confident, Interpretable and Robust Deep Learning
https://arxiv.org/abs/1803.04765
The LSH backend used in the paper is FALCONN. This script also demonstrates
how to use an alternative backend called FAISS.
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import copy
import os
from bisect import bisect_left
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from six.moves import xrange
import enum
import tensorflow as tf
from cleverhans.attacks import FastGradientMethod
from cleverhans.loss import CrossEntropy
from cleverhans.dataset import MNIST
from cleverhans.model import Model
from cleverhans.picklable_model import MLP, Conv2D, ReLU, Flatten, Linear, Softmax
from cleverhans.train import train
from cleverhans.utils_tf import batch_eval, model_eval
if 'DISPLAY' not in os.environ:
matplotlib.use('Agg')
FLAGS = tf.flags.FLAGS
def make_basic_picklable_cnn(nb_filters=64, nb_classes=10,
input_shape=(None, 28, 28, 1)):
"""The model for the picklable models tutorial.
"""
layers = [Conv2D(nb_filters, (8, 8), (2, 2), "SAME"),
ReLU(),
Conv2D(nb_filters * 2, (6, 6), (2, 2), "VALID"),
ReLU(),
Conv2D(nb_filters * 2, (5, 5), (1, 1), "VALID"),
ReLU(),
Flatten(),
Linear(nb_classes),
Softmax()]
model = MLP(layers, input_shape)
return model
class NearestNeighbor:
class BACKEND(enum.Enum):
FALCONN = 1
FAISS = 2
def __init__(
self,
backend,
dimension,
neighbors,
number_bits,
nb_tables=None,
):
assert backend in NearestNeighbor.BACKEND
self._NEIGHBORS = neighbors
self._BACKEND = backend
if self._BACKEND is NearestNeighbor.BACKEND.FALCONN:
self._init_falconn(
dimension,
number_bits,
nb_tables
)
elif self._BACKEND is NearestNeighbor.BACKEND.FAISS:
self._init_faiss(
dimension,
)
else:
raise NotImplementedError
def _init_falconn(
self,
dimension,
number_bits,
nb_tables,
):
import falconn
assert nb_tables >= self._NEIGHBORS
# LSH parameters
params_cp = falconn.LSHConstructionParameters()
params_cp.dimension = dimension
params_cp.lsh_family = falconn.LSHFamily.CrossPolytope
params_cp.distance_function = falconn.DistanceFunction.EuclideanSquared
params_cp.l = nb_tables
params_cp.num_rotations = 2 # for dense set it to 1; for sparse data set it to 2
params_cp.seed = 5721840
# we want to use all the available threads to set up
params_cp.num_setup_threads = 0
params_cp.storage_hash_table = falconn.StorageHashTable.BitPackedFlatHashTable
# we build number_bits-bit hashes so that each table has
# 2^number_bits bins; a rule of thumb is to have the number
# of bins be the same order of magnitude as the number of data points
falconn.compute_number_of_hash_functions(number_bits, params_cp)
self._falconn_table = falconn.LSHIndex(params_cp)
self._falconn_query_object = None
self._FALCONN_NB_TABLES = nb_tables
def _init_faiss(
self,
dimension,
):
import faiss
res = faiss.StandardGpuResources()
self._faiss_index = faiss.GpuIndexFlatL2(
res,
dimension,
)
def _find_knns_falconn(self, x, output):
# Late falconn query_object construction
# Since I suppose there might be an error
# if table.setup() will be called after
if self._falconn_query_object is None:
self._falconn_query_object = self._falconn_table.construct_query_object()
self._falconn_query_object.set_num_probes(
self._FALCONN_NB_TABLES
)
missing_indices = np.zeros(output.shape, dtype=np.bool)
for i in range(x.shape[0]):
query_res = self._falconn_query_object.find_k_nearest_neighbors(
x[i],
self._NEIGHBORS
)
try:
output[i, :] = query_res
except: # pylint: disable-msg=W0702
# mark missing indices
missing_indices[i, len(query_res):] = True
output[i, :len(query_res)] = query_res
return missing_indices
def _find_knns_faiss(self, x, output):
neighbor_distance, neighbor_index = self._faiss_index.search(
x,
self._NEIGHBORS
)
missing_indices = neighbor_distance == -1
d1 = neighbor_index.reshape(-1)
output.reshape(-1)[
np.logical_not(missing_indices.flatten())
] = d1[
np.logical_not(missing_indices.flatten())
]
return missing_indices
def add(self, x):
if self._BACKEND is NearestNeighbor.BACKEND.FALCONN:
self._falconn_table.setup(x)
elif self._BACKEND is NearestNeighbor.BACKEND.FAISS:
self._faiss_index.add(x)
else:
raise NotImplementedError
def find_knns(self, x, output):
if self._BACKEND is NearestNeighbor.BACKEND.FALCONN:
return self._find_knns_falconn(x, output)
elif self._BACKEND is NearestNeighbor.BACKEND.FAISS:
return self._find_knns_faiss(x, output)
else:
raise NotImplementedError
class DkNNModel(Model):
def __init__(self, neighbors, layers, get_activations, train_data, train_labels,
nb_classes, scope=None, nb_tables=200, number_bits=17):
"""
Implements the DkNN algorithm. See https://arxiv.org/abs/1803.04765 for more details.
:param neighbors: number of neighbors to find per layer.
:param layers: a list of layer names to include in the DkNN.
:param get_activations: a callable that takes a np array and a layer name and returns its activations on the data.
:param train_data: a np array of training data.
:param train_labels: a np vector of training labels.
:param nb_classes: the number of classes in the task.
:param scope: a TF scope that was used to create the underlying model.
:param nb_tables: number of tables used by FALCONN to perform locality-sensitive hashing.
:param number_bits: number of hash bits used by LSH.
"""
super(DkNNModel, self).__init__(nb_classes=nb_classes, scope=scope)
self.neighbors = neighbors
self.nb_tables = nb_tables
self.layers = layers
self.get_activations = get_activations
self.nb_cali = -1
self.calibrated = False
self.number_bits = number_bits
# Compute training data activations
self.nb_train = train_labels.shape[0]
assert self.nb_train == train_data.shape[0]
self.train_activations = get_activations(train_data)
self.train_labels = train_labels
# Build locality-sensitive hashing tables for training representations
self.train_activations_lsh = copy.copy(self.train_activations)
self.init_lsh()
def init_lsh(self):
"""
Initializes locality-sensitive hashing with FALCONN to find nearest neighbors in training data.
"""
self.query_objects = {
} # contains the object that can be queried to find nearest neighbors at each layer.
# mean of training data representation per layer (that needs to be substracted before
# NearestNeighbor).
self.centers = {}
for layer in self.layers:
# Normalize all the lenghts, since we care about the cosine similarity.
self.train_activations_lsh[layer] /= np.linalg.norm(
self.train_activations_lsh[layer], axis=1).reshape(-1, 1)
# Center the dataset and the queries: this improves the performance of LSH quite a bit.
center = np.mean(self.train_activations_lsh[layer], axis=0)
self.train_activations_lsh[layer] -= center
self.centers[layer] = center
print('Constructing the NearestNeighbor table')
self.query_objects[layer] = NearestNeighbor(
backend=FLAGS.nearest_neighbor_backend,
dimension=self.train_activations_lsh[layer].shape[1],
number_bits=self.number_bits,
neighbors=self.neighbors,
nb_tables=self.nb_tables
)
self.query_objects[layer].add(self.train_activations_lsh[layer])
def find_train_knns(self, data_activations):
"""
Given a data_activation dictionary that contains a np array with activations for each layer,
find the knns in the training data.
"""
knns_ind = {}
knns_labels = {}
for layer in self.layers:
# Pre-process representations of data to normalize and remove training data mean.
data_activations_layer = copy.copy(data_activations[layer])
nb_data = data_activations_layer.shape[0]
data_activations_layer /= np.linalg.norm(
data_activations_layer, axis=1).reshape(-1, 1)
data_activations_layer -= self.centers[layer]
# Use FALCONN to find indices of nearest neighbors in training data.
knns_ind[layer] = np.zeros(
(data_activations_layer.shape[0], self.neighbors), dtype=np.int32)
knn_errors = 0
knn_missing_indices = self.query_objects[layer].find_knns(
data_activations_layer,
knns_ind[layer],
)
knn_errors += knn_missing_indices.flatten().sum()
# Find labels of neighbors found in the training data.
knns_labels[layer] = np.zeros((nb_data, self.neighbors), dtype=np.int32)
knns_labels[layer].reshape(-1)[
np.logical_not(
knn_missing_indices.flatten()
)
] = self.train_labels[
knns_ind[layer].reshape(-1)[
np.logical_not(
knn_missing_indices.flatten()
)
]
]
return knns_ind, knns_labels
def nonconformity(self, knns_labels):
"""
Given an dictionary of nb_data x nb_classes dimension, compute the nonconformity of
each candidate label for each data point: i.e. the number of knns whose label is
different from the candidate label.
"""
nb_data = knns_labels[self.layers[0]].shape[0]
knns_not_in_class = np.zeros((nb_data, self.nb_classes), dtype=np.int32)
for i in range(nb_data):
# Compute number of nearest neighbors per class
knns_in_class = np.zeros(
(len(self.layers), self.nb_classes), dtype=np.int32)
for layer_id, layer in enumerate(self.layers):
knns_in_class[layer_id, :] = np.bincount(
knns_labels[layer][i], minlength=self.nb_classes)
# Compute number of knns in other class than class_id
for class_id in range(self.nb_classes):
knns_not_in_class[i, class_id] = np.sum(
knns_in_class) - np.sum(knns_in_class[:, class_id])
return knns_not_in_class
def preds_conf_cred(self, knns_not_in_class):
"""
Given an array of nb_data x nb_classes dimensions, use conformal prediction to compute
the DkNN's prediction, confidence and credibility.
"""
nb_data = knns_not_in_class.shape[0]
preds_knn = np.zeros(nb_data, dtype=np.int32)
confs = np.zeros((nb_data, self.nb_classes), dtype=np.float32)
creds = np.zeros((nb_data, self.nb_classes), dtype=np.float32)
for i in range(nb_data):
# p-value of test input for each class
p_value = np.zeros(self.nb_classes, dtype=np.float32)
for class_id in range(self.nb_classes):
# p-value of (test point, candidate label)
p_value[class_id] = (float(self.nb_cali) - bisect_left(
self.cali_nonconformity, knns_not_in_class[i, class_id])) / float(self.nb_cali)
preds_knn[i] = np.argmax(p_value)
confs[i, preds_knn[i]] = 1. - np.sort(p_value)[-2]
creds[i, preds_knn[i]] = p_value[preds_knn[i]]
return preds_knn, confs, creds
def fprop_np(self, data_np):
"""
Performs a forward pass through the DkNN on an numpy array of data.
"""
if not self.calibrated:
raise ValueError(
"DkNN needs to be calibrated by calling DkNNModel.calibrate method once before inferring.")
data_activations = self.get_activations(data_np)
_, knns_labels = self.find_train_knns(data_activations)
knns_not_in_class = self.nonconformity(knns_labels)
_, _, creds = self.preds_conf_cred(knns_not_in_class)
return creds
def fprop(self, x):
"""
Performs a forward pass through the DkNN on a TF tensor by wrapping
the fprop_np method.
"""
logits = tf.py_func(self.fprop_np, [x], tf.float32)
return {self.O_LOGITS: logits}
def calibrate(self, cali_data, cali_labels):
"""
Runs the DkNN on holdout data to calibrate the credibility metric.
:param cali_data: np array of calibration data.
:param cali_labels: np vector of calibration labels.
"""
self.nb_cali = cali_labels.shape[0]
self.cali_activations = self.get_activations(cali_data)
self.cali_labels = cali_labels
print("Starting calibration of DkNN.")
cali_knns_ind, cali_knns_labels = self.find_train_knns(
self.cali_activations)
assert all([v.shape == (self.nb_cali, self.neighbors)
for v in cali_knns_ind.values()])
assert all([v.shape == (self.nb_cali, self.neighbors)
for v in cali_knns_labels.values()])
cali_knns_not_in_class = self.nonconformity(cali_knns_labels)
cali_knns_not_in_l = np.zeros(self.nb_cali, dtype=np.int32)
for i in range(self.nb_cali):
cali_knns_not_in_l[i] = cali_knns_not_in_class[i, cali_labels[i]]
cali_knns_not_in_l_sorted = np.sort(cali_knns_not_in_l)
self.cali_nonconformity = np.trim_zeros(cali_knns_not_in_l_sorted, trim='f')
self.nb_cali = self.cali_nonconformity.shape[0]
self.calibrated = True
print("DkNN calibration complete.")
def plot_reliability_diagram(confidence, labels, filepath):
"""
Takes in confidence values for predictions and correct
labels for the data, plots a reliability diagram.
:param confidence: nb_samples x nb_classes (e.g., output of softmax)
:param labels: vector of nb_samples
:param filepath: where to save the diagram
:return:
"""
assert len(confidence.shape) == 2
assert len(labels.shape) == 1
assert confidence.shape[0] == labels.shape[0]
print('Saving reliability diagram at: ' + str(filepath))
if confidence.max() <= 1.:
# confidence array is output of softmax
bins_start = [b / 10. for b in xrange(0, 10)]
bins_end = [b / 10. for b in xrange(1, 11)]
bins_center = [(b + .5) / 10. for b in xrange(0, 10)]
preds_conf = np.max(confidence, axis=1)
preds_l = np.argmax(confidence, axis=1)
else:
raise ValueError('Confidence values go above 1.')
print(preds_conf.shape, preds_l.shape)
# Create var for reliability diagram
# Will contain mean accuracies for each bin
reliability_diag = []
num_points = [] # keeps the number of points in each bar
# Find average accuracy per confidence bin
for bin_start, bin_end in zip(bins_start, bins_end):
above = preds_conf >= bin_start
if bin_end == 1.:
below = preds_conf <= bin_end
else:
below = preds_conf < bin_end
mask = np.multiply(above, below)
num_points.append(np.sum(mask))
bin_mean_acc = max(0, np.mean(preds_l[mask] == labels[mask]))
reliability_diag.append(bin_mean_acc)
# Plot diagram
assert len(reliability_diag) == len(bins_center)
print(reliability_diag)
print(bins_center)
print(num_points)
fig, ax1 = plt.subplots()
_ = ax1.bar(bins_center, reliability_diag, width=.1, alpha=0.8)
plt.xlim([0, 1.])
ax1.set_ylim([0, 1.])
ax2 = ax1.twinx()
print(sum(num_points))
ax2.plot(bins_center, num_points, color='r', linestyle='-', linewidth=7.0)
ax2.set_ylabel('Number of points in the data', fontsize=16, color='r')
if len(np.argwhere(confidence[0] != 0.)) == 1:
# This is a DkNN diagram
ax1.set_xlabel('Prediction Credibility', fontsize=16)
else:
# This is a softmax diagram
ax1.set_xlabel('Prediction Confidence', fontsize=16)
ax1.set_ylabel('Prediction Accuracy', fontsize=16)
ax1.tick_params(axis='both', labelsize=14)
ax2.tick_params(axis='both', labelsize=14, colors='r')
fig.tight_layout()
plt.savefig(filepath, bbox_inches='tight')
def get_tensorflow_session():
gpu_options = tf.GPUOptions()
gpu_options.per_process_gpu_memory_fraction=FLAGS.tensorflow_gpu_memory_fraction
sess = tf.Session(
config=tf.ConfigProto(
gpu_options=gpu_options
)
)
return sess
def dknn_tutorial():
# Get MNIST data.
mnist = MNIST()
x_train, y_train = mnist.get_set('train')
x_test, y_test = mnist.get_set('test')
# Use Image Parameters.
img_rows, img_cols, nchannels = x_train.shape[1:4]
nb_classes = y_train.shape[1]
with get_tensorflow_session() as sess:
with tf.variable_scope('dknn'):
# Define input TF placeholder.
x = tf.placeholder(tf.float32, shape=(
None, img_rows, img_cols, nchannels))
y = tf.placeholder(tf.float32, shape=(None, nb_classes))
# Define a model.
model = make_basic_picklable_cnn()
preds = model.get_logits(x)
loss = CrossEntropy(model, smoothing=0.)
# Define the test set accuracy evaluation.
def evaluate():
acc = model_eval(sess, x, y, preds, x_test, y_test,
args={'batch_size': FLAGS.batch_size})
print('Test accuracy on test examples: %0.4f' % acc)
# Train the model
train_params = {'nb_epochs': FLAGS.nb_epochs,
'batch_size': FLAGS.batch_size, 'learning_rate': FLAGS.lr}
train(sess, loss, x_train, y_train, evaluate=evaluate,
args=train_params, var_list=model.get_params())
# Define callable that returns a dictionary of all activations for a dataset
def get_activations(data):
data_activations = {}
for layer in layers:
layer_sym = tf.layers.flatten(model.get_layer(x, layer))
data_activations[layer] = batch_eval(sess, [x], [layer_sym], [data],
args={'batch_size': FLAGS.batch_size})[0]
return data_activations
# Use a holdout of the test set to simulate calibration data for the DkNN.
train_data = x_train
train_labels = np.argmax(y_train, axis=1)
cali_data = x_test[:FLAGS.nb_cali]
y_cali = y_test[:FLAGS.nb_cali]
cali_labels = np.argmax(y_cali, axis=1)
test_data = x_test[FLAGS.nb_cali:]
y_test = y_test[FLAGS.nb_cali:]
# Extract representations for the training and calibration data at each layer of interest to the DkNN.
layers = ['ReLU1', 'ReLU3', 'ReLU5', 'logits']
# Wrap the model into a DkNNModel
dknn = DkNNModel(
FLAGS.neighbors,
layers,
get_activations,
train_data,
train_labels,
nb_classes,
scope='dknn',
number_bits=FLAGS.number_bits
)
dknn.calibrate(cali_data, cali_labels)
# Generate adversarial examples
fgsm = FastGradientMethod(model, sess=sess)
attack_params = {'eps': .25, 'clip_min': 0., 'clip_max': 1.}
adv = sess.run(fgsm.generate(x, **attack_params),
feed_dict={x: test_data})
# Test the DkNN on clean test data and FGSM test data
for data_in, fname in zip([test_data, adv], ['test', 'adv']):
dknn_preds = dknn.fprop_np(data_in)
print(dknn_preds.shape)
print(np.mean(np.argmax(dknn_preds, axis=1) == np.argmax(y_test, axis=1)))
plot_reliability_diagram(dknn_preds, np.argmax(
y_test, axis=1), '/tmp/dknn_' + fname + '.pdf')
return True
def main(argv=None):
assert dknn_tutorial()
if __name__ == '__main__':
tf.flags.DEFINE_integer(
'number_bits',
17,
'number of hash bits used by LSH Index'
)
tf.flags.DEFINE_float(
'tensorflow_gpu_memory_fraction',
0.25,
'amount of the GPU memory to allocate for a tensorflow Session'
)
tf.flags.DEFINE_enum_class(
'nearest_neighbor_backend',
NearestNeighbor.BACKEND.FALCONN,
NearestNeighbor.BACKEND,
'NearestNeighbor backend'
)
tf.flags.DEFINE_integer('nb_epochs', 6, 'Number of epochs to train model')
tf.flags.DEFINE_integer('batch_size', 500, 'Size of training batches')
tf.flags.DEFINE_float('lr', 0.001, 'Learning rate for training')
tf.flags.DEFINE_integer(
'nb_cali', 750, 'Number of calibration points for the DkNN')
tf.flags.DEFINE_integer(
'neighbors', 75, 'Number of neighbors per layer for the DkNN')
tf.app.run()