Skip to content

Commit

Permalink
Rename target layer (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Meudec committed Aug 1, 2019
1 parent 67b0f93 commit 5d5b939
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/callbacks/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
img_input = tf.keras.Input(INPUT_SHAPE)

x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(img_input)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', name='grad_cam_target')(x)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', name='target_layer')(x)
x = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(x)

x = tf.keras.layers.Dropout(0.25)(x)
Expand Down Expand Up @@ -57,9 +57,9 @@
# Instantiate callbacks
# class_index value should match the validation_data selected above
callbacks = [
tf_explain.callbacks.GradCAMCallback(validation_class_zero, 'grad_cam_target', class_index=0),
tf_explain.callbacks.GradCAMCallback(validation_class_fours, 'grad_cam_target', class_index=4),
tf_explain.callbacks.ActivationsVisualizationCallback(validation_class_zero, 'grad_cam_target'),
tf_explain.callbacks.GradCAMCallback(validation_class_zero, 'target_layer', class_index=0),
tf_explain.callbacks.GradCAMCallback(validation_class_fours, 'target_layer', class_index=4),
tf_explain.callbacks.ActivationsVisualizationCallback(validation_class_zero, layers_name=['target_layer']),
tf_explain.callbacks.SmoothGradCallback(validation_class_zero, class_index=0, num_samples=15, noise=1.),
]

Expand Down

0 comments on commit 5d5b939

Please sign in to comment.