Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Problem with keras ResNet50 CAM visualisation #53 #122

Merged
merged 6 commits into from
Aug 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 92 additions & 92 deletions examples/mnist/attention.ipynb

Large diffs are not rendered by default.

742 changes: 742 additions & 0 deletions examples/resnet/attention.ipynb

Large diffs are not rendered by default.

87 changes: 53 additions & 34 deletions examples/vggnet/attention.ipynb

Large diffs are not rendered by default.

125 changes: 86 additions & 39 deletions tests/vis/backend/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from vis.backend import modify_model_backprop
from vis.utils.test_utils import skip_backends

from keras.models import Model, Input
import keras
from keras.models import Model, Input, Sequential
from keras.layers import Dense
from keras.initializers import Constant
from keras import backend as K
Expand All @@ -20,44 +21,90 @@ def _compute_grads(model, input_array):

@skip_backends('theano')
def test_guided_grad_modifier():
# Create a simple linear sequence x -> linear(w.x) with weights w1 = -1, w2 = 1.
inp = Input(shape=(2, ))
out = Dense(1, activation='linear', use_bias=False, kernel_initializer=Constant([-1., 1.]))(inp)
model = Model(inp, out)

# Original model gradient should be [w1, w2]
assert np.array_equal(_compute_grads(model, [1., -1.]), [-1., 1.])

# Original gradient is [-1, 1] but new gradient should be [0, 0]
# First one is clipped because of negative gradient while the second is clipped due to negative input.
modified_model = modify_model_backprop(model, 'guided')
assert np.array_equal(_compute_grads(modified_model, [1., -1.]), [0., 0.])

# Ensure that the original model reference remains unchanged.
assert model.layers[1].activation == get('linear')
assert modified_model.layers[1].activation == get('relu')


@skip_backends('theano')
def test_advanced_activations():
""" Tests that various ways of specifying activations in keras models are handled when replaced with Relu
"""
inp = Input(shape=(2, ))
x = Dense(5, activation='elu')(inp)
x = advanced_activations.LeakyReLU()(x)
x = Activation('elu')(x)
model = Model(inp, x)

# Ensure that layer.activation, Activation and advanced activations are replaced with relu
modified_model = modify_model_backprop(model, 'guided')
assert modified_model.layers[1].activation == get('relu')
assert modified_model.layers[2].activation == get('relu')
assert modified_model.layers[3].activation == get('relu')

# Ensure that original model is unchanged.
assert model.layers[1].activation == get('elu')
assert isinstance(model.layers[2], advanced_activations.LeakyReLU)
assert model.layers[3].activation == get('elu')
# Create a simple 2 dense layer model.
simple_model = Sequential([
Dense(2, activation='relu', use_bias=False, kernel_initializer=Constant([[-1., 1.], [-1., 1.]]), input_shape=(2,)),
Dense(1, activation='linear', use_bias=False, kernel_initializer=Constant([-1., 1.]))
])
simple_model.compile(loss=keras.losses.categorical_crossentropy, optimizer=keras.optimizers.Adam())

# Create a simple 2 dense layer model using Activation.
simple_model_with_activation = Sequential([
Dense(2, activation='linear', use_bias=False, kernel_initializer=Constant([[-1., 1.], [-1., 1.]]), input_shape=(2,)),
Activation('relu'),
Dense(1, activation='linear', use_bias=False, kernel_initializer=Constant([-1., 1.]))
])
simple_model_with_activation.compile(loss=keras.losses.categorical_crossentropy, optimizer=keras.optimizers.Adam())

for i, model in enumerate([simple_model, simple_model_with_activation]):
# Create guided backprop model
modified_model = modify_model_backprop(model, 'guided')

# Gradients are zeros.
input_array = [0., 0.]
assert np.array_equal(_compute_grads(model, input_array), [0., 0.])
assert np.array_equal(_compute_grads(modified_model, input_array), [0., 0.])

# Below 3 cases, GuidedBackprop gradients is the same as Original gradients.
input_array = [1., 0.]
assert np.array_equal(_compute_grads(model, input_array), [1., 1.])
assert np.array_equal(_compute_grads(modified_model, input_array), [1., 1.])

input_array = [0., 1.]
assert np.array_equal(_compute_grads(model, input_array), [1., 1.])
assert np.array_equal(_compute_grads(modified_model, input_array), [1., 1.])

input_array = [1., 1.]
assert np.array_equal(_compute_grads(model, input_array), [1., 1.])
assert np.array_equal(_compute_grads(modified_model, input_array), [1., 1.])

# If inputs contains negative values,
# GuidedBackprop gradients is not the same as Original gradients.
input_array = [-1., 0.]
assert np.array_equal(_compute_grads(model, input_array), [1., 1.])
assert np.array_equal(_compute_grads(modified_model, input_array), [0., 0.])

input_array = [0., -1.]
assert np.array_equal(_compute_grads(model, input_array), [1., 1.])
assert np.array_equal(_compute_grads(modified_model, input_array), [0., 0.])

input_array = [-1., -1.]
assert np.array_equal(_compute_grads(model, input_array), [1., 1.])
assert np.array_equal(_compute_grads(modified_model, input_array), [0., 0.])

# Activation is not changed.
if i == 0: # modified first model
modified_model.layers[0].activation == keras.activations.relu
modified_model.layers[1].activation == keras.activations.linear
if i == 1: # modified second model
modified_model.layers[0].activation == keras.activations.linear
modified_model.layers[1].activation == keras.activations.relu
modified_model.layers[2].activation == keras.activations.linear


# Currently, the modify_model_backprop function doesn't support advanced activation.
# Therefore, this test case will temporarily comment out.
#
# @skip_backends('theano')
# def test_advanced_activations():
# """ Tests that various ways of specifying activations in keras models are handled when replaced with Relu
# """
# inp = Input(shape=(2, ))
# x = Dense(5, activation='elu')(inp)
# x = advanced_activations.LeakyReLU()(x)
# x = Activation('elu')(x)
# model = Model(inp, x)
#
# # Ensure that layer.activation, Activation and advanced activations are replaced with relu
# modified_model = modify_model_backprop(model, 'guided')
# assert modified_model.layers[1].activation == get('relu')
# assert modified_model.layers[2].activation == get('relu')
# assert modified_model.layers[3].activation == get('relu')
#
# # Ensure that original model is unchanged.
# assert model.layers[1].activation == get('elu')
# assert isinstance(model.layers[2], advanced_activations.LeakyReLU)
# assert model.layers[3].activation == get('elu')


# @skip_backends('theano')
Expand Down
34 changes: 10 additions & 24 deletions vis/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ..utils import utils
from tensorflow.python.framework import ops
import keras
from keras.models import load_model
from keras.layers import advanced_activations, Activation

Expand All @@ -26,7 +27,7 @@ def _register_guided_gradient(name):
def _guided_backprop(op, grad):
dtype = op.outputs[0].dtype
gate_g = tf.cast(grad > 0., dtype)
gate_y = tf.cast(op.outputs[0] > 0, dtype)
gate_y = tf.cast(op.outputs[0] > 0., dtype)
return gate_y * gate_g * grad


Expand Down Expand Up @@ -60,18 +61,19 @@ def modify_model_backprop(model, backprop_modifier):
A copy of model with modified activations for backwards pass.
"""
# The general strategy is as follows:
# - Clone original model via save/load so that upstream callers don't see unexpected results with their models.
# - Modify all activations in the model as ReLU.
# - Save modified model so that it can be loaded with custom context modifying backprop behavior.
# - Save original model so that upstream callers don't see unexpected results with their models.
# - Call backend specific function that registers the custom op and loads the model under modified context manager.
# - Maintain cache to save this expensive process on subsequent calls.
# - Load model with custom context modifying backprop behavior.
#
# The reason for this round about way is because the graph needs to be rebuild when any of its layer builder
# functions are changed. This is very complicated to do in Keras and makes the implementation very tightly bound
# with keras internals. By saving and loading models, we dont have to worry about future compatibility.
#
# The only exception to this is the way advanced activations are handled which makes use of some keras internal
# knowledge and might break in the future.
# ADD on 22 Jul 2018:
# In fact, it has broken. Currently, advanced activations are not supported.

# 0. Retrieve from cache if previously computed.
modified_model = _MODIFIED_MODEL_CACHE.get((model, backprop_modifier))
Expand All @@ -80,32 +82,16 @@ def modify_model_backprop(model, backprop_modifier):

model_path = os.path.join(tempfile.gettempdir(), next(tempfile._get_candidate_names()) + '.h5')
try:
# 1. Clone original model via save and load.
# 1. Save original model
model.save(model_path)
modified_model = load_model(model_path)

# 2. Replace all possible activations with ReLU.
for i, layer in utils.reverse_enumerate(modified_model.layers):
if hasattr(layer, 'activation'):
layer.activation = tf.nn.relu
if isinstance(layer, _ADVANCED_ACTIVATIONS):
# NOTE: This code is brittle as it makes use of Keras internal serialization knowledge and might
# break in the future.
modified_layer = Activation('relu')
modified_layer.inbound_nodes = layer.inbound_nodes
modified_layer.name = layer.name
modified_model.layers[i] = modified_layer

# 3. Save model with modifications.
modified_model.save(model_path)

# 4. Register modifier and load modified model under custom context.

# 2. Register modifier and load modified model under custom context.
modifier_fn = _BACKPROP_MODIFIERS.get(backprop_modifier)
if modifier_fn is None:
raise ValueError("'{}' modifier is not supported".format(backprop_modifier))
modifier_fn(backprop_modifier)

# 5. Create graph under custom context manager.
# 3. Create graph under custom context manager.
with tf.get_default_graph().gradient_override_map({'Relu': backprop_modifier}):
# This should rebuild graph with modifications.
modified_model = load_model(model_path)
Expand Down