/
mnist.py
68 lines (54 loc) · 2.66 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import tensorflow as tf
import tf_explain
# Model geometry: a single grayscale image is (height, width, channels),
# and the classifier predicts one of NUM_CLASSES labels.
NUM_CLASSES = 10
INPUT_SHAPE = (28, 28, 1)

# Keras dataset modules, keyed by the name used to pick one below.
AVAILABLE_DATASETS = {
    'mnist': tf.keras.datasets.mnist,
    'fashion_mnist': tf.keras.datasets.fashion_mnist,
}

# Which entry of AVAILABLE_DATASETS to train on: "mnist" or "fashion_mnist".
DATASET_NAME = 'fashion_mnist'
# Fetch the train/test splits of the dataset selected by DATASET_NAME.
dataset = AVAILABLE_DATASETS[DATASET_NAME]
(train_images, train_labels), (test_images, test_labels) = dataset.load_data()

# Append a trailing channel axis so every image matches INPUT_SHAPE:
# (28, 28) -> (28, 28, 1).
train_images, test_images = (
    images[..., tf.newaxis] for images in (train_images, test_images)
)

# One-hot encode the integer labels, e.g. 2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0].
train_labels, test_labels = (
    tf.keras.utils.to_categorical(labels, num_classes=NUM_CLASSES)
    for labels in (train_labels, test_labels)
)
# Small CNN classifier built with the Keras functional API. The second
# convolution is explicitly named 'target_layer' because the tf_explain
# callbacks further down refer to it by that name.
img_input = tf.keras.Input(INPUT_SHAPE)
layer_stack = [
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', name='target_layer'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),
]

# Thread the input tensor through each layer in order; x ends up as the
# softmax output over NUM_CLASSES.
x = img_input
for stacked_layer in layer_stack:
    x = stacked_layer(x)

model = tf.keras.Model(img_input, x)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
def _first_samples_of_class(images, one_hot_labels, class_index, count=5):
    """Return up to ``count`` images whose one-hot label is ``class_index``.

    Args:
        images: Array of samples indexed along the first axis, aligned
            row-for-row with ``one_hot_labels``.
        one_hot_labels: 2-D array of one-hot label rows.
        class_index: Integer class to select.
        count: Maximum number of images to return (the first matches, in
            dataset order).

    Returns:
        An ``(images_subset, None)`` tuple, the validation-data form handed
        to the tf_explain callbacks.
    """
    labels = np.asarray(one_hot_labels)
    # The one-hot row for class_index. Its width comes from the labels
    # themselves instead of hard-coded zero runs, so this keeps working if
    # NUM_CLASSES changes.
    target_row = np.eye(labels.shape[-1], dtype=labels.dtype)[class_index]
    # Exact row equality, matching a per-row np.all(label == target), but done
    # in one vectorized pass over the whole label array.
    mask = np.all(labels == target_row, axis=1)
    return np.asarray(images)[mask][:count], None


# Select a subset of the validation data to examine:
# up to 5 test images with label "0" == [1, 0, 0, .., 0].
validation_class_zero = _first_samples_of_class(test_images, test_labels, class_index=0)

# Select a subset of the validation data to examine:
# up to 5 test images with label "4" == [0, 0, 0, 0, 1, 0, 0, 0, 0, 0].
validation_class_fours = _first_samples_of_class(test_images, test_labels, class_index=4)
# tf_explain callbacks that produce explanations while the model trains.
# Each class_index must agree with the class its validation subset was drawn
# from: 0 for validation_class_zero, 4 for validation_class_fours.
explained_layer = 'target_layer'

# One Grad-CAM callback per (validation subset, class) pair, in that order.
gradcam_callbacks = [
    tf_explain.callbacks.GradCAMCallback(subset, explained_layer, class_index=cls)
    for subset, cls in ((validation_class_zero, 0), (validation_class_fours, 4))
]

callbacks = gradcam_callbacks + [
    tf_explain.callbacks.ActivationsVisualizationCallback(
        validation_class_zero, layers_name=[explained_layer]
    ),
    tf_explain.callbacks.SmoothGradCallback(
        validation_class_zero, class_index=0, num_samples=15, noise=1.0
    ),
    tf_explain.callbacks.IntegratedGradientsCallback(
        validation_class_zero, class_index=0, n_steps=10
    ),
]

# Train for five epochs with the explanation callbacks attached.
model.fit(x=train_images, y=train_labels, epochs=5, callbacks=callbacks)