Skip to content

Commit

Permalink
Integrated gradients (#70)
Browse files Browse the repository at this point in the history
* Extract normalized grayscale from SmoothGrad

* Add Integrated Gradients method

* Add Integrated Gradients example

* Add Integrated Gradients callback

* Add Integrated Gradients to README

* Add Integrated Gradients in docs
  • Loading branch information
Raphael Meudec committed Aug 23, 2019
1 parent 0ee8db4 commit 58554e2
Show file tree
Hide file tree
Showing 15 changed files with 391 additions and 38 deletions.
31 changes: 29 additions & 2 deletions README.md
Expand Up @@ -35,6 +35,7 @@ pip install tensorflow-gpu==2.0.0-beta1
2. [Occlusion Sensitivity](#occlusion-sensitivity)
3. [Grad CAM (Class Activation Maps)](#grad-cam)
4. [SmoothGrad](#smoothgrad)
5. [Integrated Gradients](#integrated-gradients)

### Activations Visualization

Expand Down Expand Up @@ -144,21 +145,47 @@ model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
<img src="./docs/assets/smoothgrad.png" width="200" />
</p>

### Integrated Gradients

> Visualize an average of the gradients along the construction of the input towards the decision
From [Axiomatic Attribution for Deep Networks](https://arxiv.org/pdf/1703.01365.pdf)

```python
from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback

model = [...]

callbacks = [
IntegratedGradientsCallback(
validation_data=(x_val, y_val),
class_index=0,
n_steps=20,
output_dir=output_dir,
)
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
```

<p align="center">
<img src="./docs/assets/integrated_gradients.png" width="200" />
</p>


## Visualizing the results

When you use the callbacks, the output files are created in the `logs` directory.

You can see them in tensorboard with the following command: `tensorboard --logdir logs`
You can see them in Tensorboard with the following command: `tensorboard --logdir logs`


## Roadmap

- [ ] Subclassing API Support
- [ ] Additional Methods
- [ ] [GradCAM++](https://arxiv.org/abs/1710.11063)
- [ ] [Integrated Gradients](https://arxiv.org/abs/1703.01365)
- [x] [Integrated Gradients](https://arxiv.org/abs/1703.01365)
- [ ] [Guided SmoothGrad](https://arxiv.org/abs/1706.03825)
- [ ] [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140)
- [ ] Auto-generated API Documentation & Documentation Testing
Binary file added docs/assets/integrated_gradients.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions docs/source/methods.rst
Expand Up @@ -108,3 +108,31 @@ From `SmoothGrad: removing noise by adding noise <https://arxiv.org/abs/1706.038
:alt: SmoothGrad
:width: 200px
:align: center


Integrated Gradients
********************

Visualize an average of the gradients along the construction of the input towards the decision.

From `Axiomatic Attribution for Deep Networks <https://arxiv.org/pdf/1703.01365.pdf>`_
::
from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback

model = [...]

callbacks = [
IntegratedGradientsCallback(
validation_data=(x_val, y_val),
class_index=0,
n_steps=20,
output_dir=output_dir,
)
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

.. image:: ../assets/integrated_gradients.png
:alt: IntegratedGradients
:width: 200px
:align: center
1 change: 1 addition & 0 deletions examples/callbacks/mnist.py
Expand Up @@ -61,6 +61,7 @@
tf_explain.callbacks.GradCAMCallback(validation_class_fours, 'target_layer', class_index=4),
tf_explain.callbacks.ActivationsVisualizationCallback(validation_class_zero, layers_name=['target_layer']),
tf_explain.callbacks.SmoothGradCallback(validation_class_zero, class_index=0, num_samples=15, noise=1.),
tf_explain.callbacks.IntegratedGradientsCallback(validation_class_zero, class_index=0, n_steps=10),
]

# Start training
Expand Down
20 changes: 20 additions & 0 deletions examples/core/integrated_gradients.py
@@ -0,0 +1,20 @@
import tensorflow as tf

from tf_explain.core.integrated_gradients import IntegratedGradients

IMAGE_PATH = './cat.jpg'

if __name__ == '__main__':
model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=True)

img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
img = tf.keras.preprocessing.image.img_to_array(img)

model.summary()
data = ([img], None)

tabby_cat_class_index = 281
explainer = IntegratedGradients()
# Compute SmoothGrad on VGG16
grid = explainer.explain(data, model, tabby_cat_class_index, n_steps=15)
explainer.save(grid, '.', 'integrated_gradients.png')
32 changes: 32 additions & 0 deletions tests/callbacks/test_integrated_gradients.py
@@ -0,0 +1,32 @@
import numpy as np
from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback


def test_should_call_integrated_gradients_callback(
random_data, convolutional_model, output_dir, mocker
):
mock_explainer = mocker.MagicMock(explain=mocker.MagicMock(return_value=0))
mocker.patch(
"tf_explain.callbacks.integrated_gradients.IntegratedGradients",
return_value=mock_explainer,
)
mock_image_summary = mocker.patch(
"tf_explain.callbacks.integrated_gradients.tf.summary.image"
)

images, labels = random_data

callbacks = [
IntegratedGradientsCallback(
validation_data=random_data, class_index=0, output_dir=output_dir, n_steps=3
)
]

convolutional_model.fit(images, labels, batch_size=2, epochs=1, callbacks=callbacks)

mock_explainer.explain.assert_called_once_with(
random_data, convolutional_model, 0, 3
)
mock_image_summary.assert_called_once_with(
"IntegratedGradients", np.array([0]), step=0
)
4 changes: 3 additions & 1 deletion tests/callbacks/test_smoothgrad.py
Expand Up @@ -9,7 +9,9 @@ def test_should_call_smoothgrad_callback(
mocker.patch(
"tf_explain.callbacks.smoothgrad.SmoothGrad", return_value=mock_explainer
)
mock_image_summary = mocker.patch("tf_explain.callbacks.grad_cam.tf.summary.image")
mock_image_summary = mocker.patch(
"tf_explain.callbacks.smoothgrad.tf.summary.image"
)

images, labels = random_data

Expand Down
40 changes: 40 additions & 0 deletions tests/core/test_integrated_gradients.py
@@ -0,0 +1,40 @@
import numpy as np

from tf_explain.core.integrated_gradients import IntegratedGradients


def test_should_explain_output(convolutional_model, random_data, mocker):
mocker.patch(
"tf_explain.core.integrated_gradients.grid_display", side_effect=lambda x: x
)
images, labels = random_data
explainer = IntegratedGradients()
grid = explainer.explain((images, labels), convolutional_model, 0)

# Outputs is in grayscale format
assert grid.shape == images.shape[:-1]


def test_generate_linear_path():
input_shape = (28, 28, 1)
target = np.ones(input_shape)
baseline = np.zeros(input_shape)
n_steps = 3

expected_output = [baseline, 1 / 2 * (target - baseline), target]

output = IntegratedGradients.generate_linear_path(baseline, target, n_steps)

np.testing.assert_almost_equal(output, expected_output)


def test_get_integrated_gradients(random_data, convolutional_model):
images, _ = random_data
n_steps = 4
gradients = IntegratedGradients.get_integrated_gradients(
images, convolutional_model, 0, n_steps=n_steps
)

expected_output_shape = (images.shape[0] / n_steps, *images.shape[1:])

assert gradients.shape == expected_output_shape
10 changes: 0 additions & 10 deletions tests/core/test_smoothgrad.py
@@ -1,5 +1,4 @@
import numpy as np
import tensorflow as tf

from tf_explain.core.smoothgrad import SmoothGrad

Expand Down Expand Up @@ -28,15 +27,6 @@ def test_generate_noisy_images(mocker):
np.testing.assert_array_equal(output, 2 * np.ones((30, 28, 28, 1)))


def test_should_transform_gradients_to_grayscale():
gradients = tf.random.uniform((4, 28, 28, 3))

grayscale_gradients = SmoothGrad.transform_to_grayscale(gradients)
expected_output_shape = (4, 28, 28)

assert grayscale_gradients.shape == expected_output_shape


def test_get_averaged_gradients(random_data, convolutional_model):
images, _ = random_data
num_samples = 2
Expand Down
12 changes: 11 additions & 1 deletion tests/utils/test_image.py
@@ -1,6 +1,7 @@
import numpy as np
import tensorflow as tf

from tf_explain.utils.image import apply_grey_patch
from tf_explain.utils.image import apply_grey_patch, transform_to_normalized_grayscale


def test_should_apply_grey_patch_on_image():
Expand All @@ -13,3 +14,12 @@ def test_should_apply_grey_patch_on_image():
)

np.testing.assert_almost_equal(output, expected_output)


def test_should_transform_gradients_to_grayscale():
gradients = tf.random.uniform((4, 28, 28, 3))

grayscale_gradients = transform_to_normalized_grayscale(gradients)
expected_output_shape = (4, 28, 28)

assert grayscale_gradients.shape == expected_output_shape
2 changes: 2 additions & 0 deletions tf_explain/callbacks/__init__.py
@@ -1,12 +1,14 @@
from .activations_visualization import ActivationsVisualizationCallback
from .grad_cam import GradCAMCallback
from .integrated_gradients import IntegratedGradientsCallback
from .occlusion_sensitivity import OcclusionSensitivityCallback
from .smoothgrad import SmoothGradCallback


__all__ = [
"ActivationsVisualizationCallback",
"GradCAMCallback",
"IntegratedGradientsCallback",
"OcclusionSensitivityCallback",
"SmoothGradCallback",
]
65 changes: 65 additions & 0 deletions tf_explain/callbacks/integrated_gradients.py
@@ -0,0 +1,65 @@
"""
Callback Module for Integrated Gradients
"""
from datetime import datetime
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

from tf_explain.core.integrated_gradients import IntegratedGradients


class IntegratedGradientsCallback(Callback):

"""
Perform Integrated Gradients algorithm for a given input
Paper: [Axiomatic Attribution for Deep Networks](https://arxiv.org/pdf/1703.01365.pdf)
"""

def __init__(
self,
validation_data,
class_index,
n_steps=5,
output_dir=Path("./logs/integrated_gradients"),
):
"""
Constructor.
Args:
validation_data (Tuple[np.ndarray, Optional[np.ndarray]]): Validation data
to perform the method on. Tuple containing (x, y).
class_index (int): Index of targeted class
n_steps (int): Number of steps in the path
output_dir (str): Output directory path
"""
super(IntegratedGradientsCallback, self).__init__()
self.validation_data = validation_data
self.class_index = class_index
self.n_steps = n_steps
self.output_dir = Path(output_dir) / datetime.now().strftime("%Y%m%d-%H%M%S.%f")
Path.mkdir(Path(self.output_dir), parents=True, exist_ok=True)

self.file_writer = tf.summary.create_file_writer(str(self.output_dir))

def on_epoch_end(self, epoch, logs=None):
"""
Draw Integrated Gradients outputs at each epoch end to Tensorboard.
Args:
epoch (int): Epoch index
logs (dict): Additional information on epoch
"""
explainer = IntegratedGradients()
grid = explainer.explain(
self.validation_data, self.model, self.class_index, self.n_steps
)

# Using the file writer, log the reshaped image.
with self.file_writer.as_default():
tf.summary.image(
"IntegratedGradients", np.expand_dims([grid], axis=-1), step=epoch
)

0 comments on commit 58554e2

Please sign in to comment.