diff --git a/README.md b/README.md
index e853cae..9a44932 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,7 @@ pip install tensorflow-gpu==2.0.0-beta1
2. [Occlusion Sensitivity](#occlusion-sensitivity)
3. [Grad CAM (Class Activation Maps)](#grad-cam)
4. [SmoothGrad](#smoothgrad)
+5. [Integrated Gradients](#integrated-gradients)
### Activations Visualization
@@ -144,13 +145,39 @@ model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
+### Integrated Gradients
+
+> Visualize an average of the gradients along the construction of the input towards the decision
+
+From [Axiomatic Attribution for Deep Networks](https://arxiv.org/pdf/1703.01365.pdf)
+
+```python
+from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
+
+model = [...]
+
+callbacks = [
+ IntegratedGradientsCallback(
+ validation_data=(x_val, y_val),
+ class_index=0,
+ n_steps=20,
+ output_dir=output_dir,
+ )
+]
+
+model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
+```
+
+
+
+
## Visualizing the results
When you use the callbacks, the output files are created in the `logs` directory.
-You can see them in tensorboard with the following command: `tensorboard --logdir logs`
+You can see them in Tensorboard with the following command: `tensorboard --logdir logs`
## Roadmap
@@ -158,7 +185,7 @@ You can see them in tensorboard with the following command: `tensorboard --logdi
- [ ] Subclassing API Support
- [ ] Additional Methods
- [ ] [GradCAM++](https://arxiv.org/abs/1710.11063)
- - [ ] [Integrated Gradients](https://arxiv.org/abs/1703.01365)
+ - [x] [Integrated Gradients](https://arxiv.org/abs/1703.01365)
- [ ] [Guided SmoothGrad](https://arxiv.org/abs/1706.03825)
- [ ] [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140)
- [ ] Auto-generated API Documentation & Documentation Testing
diff --git a/docs/assets/integrated_gradients.png b/docs/assets/integrated_gradients.png
new file mode 100644
index 0000000..9e78714
Binary files /dev/null and b/docs/assets/integrated_gradients.png differ
diff --git a/docs/source/methods.rst b/docs/source/methods.rst
index 12b3497..b34bb65 100644
--- a/docs/source/methods.rst
+++ b/docs/source/methods.rst
@@ -108,3 +108,31 @@ From `SmoothGrad: removing noise by adding noise `_
+::
+ from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
+
+ model = [...]
+
+ callbacks = [
+ IntegratedGradientsCallback(
+ validation_data=(x_val, y_val),
+ class_index=0,
+ n_steps=20,
+ output_dir=output_dir,
+ )
+ ]
+
+ model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
+
+.. image:: ../assets/integrated_gradients.png
+ :alt: IntegratedGradients
+ :width: 200px
+ :align: center
diff --git a/examples/callbacks/mnist.py b/examples/callbacks/mnist.py
index addcaee..f7537ca 100644
--- a/examples/callbacks/mnist.py
+++ b/examples/callbacks/mnist.py
@@ -61,6 +61,7 @@
tf_explain.callbacks.GradCAMCallback(validation_class_fours, 'target_layer', class_index=4),
tf_explain.callbacks.ActivationsVisualizationCallback(validation_class_zero, layers_name=['target_layer']),
tf_explain.callbacks.SmoothGradCallback(validation_class_zero, class_index=0, num_samples=15, noise=1.),
+ tf_explain.callbacks.IntegratedGradientsCallback(validation_class_zero, class_index=0, n_steps=10),
]
# Start training
diff --git a/examples/core/integrated_gradients.py b/examples/core/integrated_gradients.py
new file mode 100644
index 0000000..0a61029
--- /dev/null
+++ b/examples/core/integrated_gradients.py
@@ -0,0 +1,20 @@
+import tensorflow as tf
+
+from tf_explain.core.integrated_gradients import IntegratedGradients
+
+IMAGE_PATH = './cat.jpg'
+
+if __name__ == '__main__':
+ model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=True)
+
+ img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
+ img = tf.keras.preprocessing.image.img_to_array(img)
+
+ model.summary()
+ data = ([img], None)
+
+ tabby_cat_class_index = 281
+ explainer = IntegratedGradients()
+ # Compute SmoothGrad on VGG16
+ grid = explainer.explain(data, model, tabby_cat_class_index, n_steps=15)
+ explainer.save(grid, '.', 'integrated_gradients.png')
diff --git a/tests/callbacks/test_integrated_gradients.py b/tests/callbacks/test_integrated_gradients.py
new file mode 100644
index 0000000..7c5e57d
--- /dev/null
+++ b/tests/callbacks/test_integrated_gradients.py
@@ -0,0 +1,32 @@
+import numpy as np
+from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
+
+
+def test_should_call_integrated_gradients_callback(
+ random_data, convolutional_model, output_dir, mocker
+):
+ mock_explainer = mocker.MagicMock(explain=mocker.MagicMock(return_value=0))
+ mocker.patch(
+ "tf_explain.callbacks.integrated_gradients.IntegratedGradients",
+ return_value=mock_explainer,
+ )
+ mock_image_summary = mocker.patch(
+ "tf_explain.callbacks.integrated_gradients.tf.summary.image"
+ )
+
+ images, labels = random_data
+
+ callbacks = [
+ IntegratedGradientsCallback(
+ validation_data=random_data, class_index=0, output_dir=output_dir, n_steps=3
+ )
+ ]
+
+ convolutional_model.fit(images, labels, batch_size=2, epochs=1, callbacks=callbacks)
+
+ mock_explainer.explain.assert_called_once_with(
+ random_data, convolutional_model, 0, 3
+ )
+ mock_image_summary.assert_called_once_with(
+ "IntegratedGradients", np.array([0]), step=0
+ )
diff --git a/tests/callbacks/test_smoothgrad.py b/tests/callbacks/test_smoothgrad.py
index a2e532e..4eab5b8 100644
--- a/tests/callbacks/test_smoothgrad.py
+++ b/tests/callbacks/test_smoothgrad.py
@@ -9,7 +9,9 @@ def test_should_call_smoothgrad_callback(
mocker.patch(
"tf_explain.callbacks.smoothgrad.SmoothGrad", return_value=mock_explainer
)
- mock_image_summary = mocker.patch("tf_explain.callbacks.grad_cam.tf.summary.image")
+ mock_image_summary = mocker.patch(
+ "tf_explain.callbacks.smoothgrad.tf.summary.image"
+ )
images, labels = random_data
diff --git a/tests/core/test_integrated_gradients.py b/tests/core/test_integrated_gradients.py
new file mode 100644
index 0000000..073ed8a
--- /dev/null
+++ b/tests/core/test_integrated_gradients.py
@@ -0,0 +1,40 @@
+import numpy as np
+
+from tf_explain.core.integrated_gradients import IntegratedGradients
+
+
+def test_should_explain_output(convolutional_model, random_data, mocker):
+ mocker.patch(
+ "tf_explain.core.integrated_gradients.grid_display", side_effect=lambda x: x
+ )
+ images, labels = random_data
+ explainer = IntegratedGradients()
+ grid = explainer.explain((images, labels), convolutional_model, 0)
+
+ # Outputs is in grayscale format
+ assert grid.shape == images.shape[:-1]
+
+
+def test_generate_linear_path():
+ input_shape = (28, 28, 1)
+ target = np.ones(input_shape)
+ baseline = np.zeros(input_shape)
+ n_steps = 3
+
+ expected_output = [baseline, 1 / 2 * (target - baseline), target]
+
+ output = IntegratedGradients.generate_linear_path(baseline, target, n_steps)
+
+ np.testing.assert_almost_equal(output, expected_output)
+
+
+def test_get_integrated_gradients(random_data, convolutional_model):
+ images, _ = random_data
+ n_steps = 4
+ gradients = IntegratedGradients.get_integrated_gradients(
+ images, convolutional_model, 0, n_steps=n_steps
+ )
+
+ expected_output_shape = (images.shape[0] / n_steps, *images.shape[1:])
+
+ assert gradients.shape == expected_output_shape
diff --git a/tests/core/test_smoothgrad.py b/tests/core/test_smoothgrad.py
index c9da517..0006a8e 100644
--- a/tests/core/test_smoothgrad.py
+++ b/tests/core/test_smoothgrad.py
@@ -1,5 +1,4 @@
import numpy as np
-import tensorflow as tf
from tf_explain.core.smoothgrad import SmoothGrad
@@ -28,15 +27,6 @@ def test_generate_noisy_images(mocker):
np.testing.assert_array_equal(output, 2 * np.ones((30, 28, 28, 1)))
-def test_should_transform_gradients_to_grayscale():
- gradients = tf.random.uniform((4, 28, 28, 3))
-
- grayscale_gradients = SmoothGrad.transform_to_grayscale(gradients)
- expected_output_shape = (4, 28, 28)
-
- assert grayscale_gradients.shape == expected_output_shape
-
-
def test_get_averaged_gradients(random_data, convolutional_model):
images, _ = random_data
num_samples = 2
diff --git a/tests/utils/test_image.py b/tests/utils/test_image.py
index 3f8b25c..4d76f54 100644
--- a/tests/utils/test_image.py
+++ b/tests/utils/test_image.py
@@ -1,6 +1,7 @@
import numpy as np
+import tensorflow as tf
-from tf_explain.utils.image import apply_grey_patch
+from tf_explain.utils.image import apply_grey_patch, transform_to_normalized_grayscale
def test_should_apply_grey_patch_on_image():
@@ -13,3 +14,12 @@ def test_should_apply_grey_patch_on_image():
)
np.testing.assert_almost_equal(output, expected_output)
+
+
+def test_should_transform_gradients_to_grayscale():
+ gradients = tf.random.uniform((4, 28, 28, 3))
+
+ grayscale_gradients = transform_to_normalized_grayscale(gradients)
+ expected_output_shape = (4, 28, 28)
+
+ assert grayscale_gradients.shape == expected_output_shape
diff --git a/tf_explain/callbacks/__init__.py b/tf_explain/callbacks/__init__.py
index 3cca06c..b562404 100644
--- a/tf_explain/callbacks/__init__.py
+++ b/tf_explain/callbacks/__init__.py
@@ -1,5 +1,6 @@
from .activations_visualization import ActivationsVisualizationCallback
from .grad_cam import GradCAMCallback
+from .integrated_gradients import IntegratedGradientsCallback
from .occlusion_sensitivity import OcclusionSensitivityCallback
from .smoothgrad import SmoothGradCallback
@@ -7,6 +8,7 @@
__all__ = [
"ActivationsVisualizationCallback",
"GradCAMCallback",
+ "IntegratedGradientsCallback",
"OcclusionSensitivityCallback",
"SmoothGradCallback",
]
diff --git a/tf_explain/callbacks/integrated_gradients.py b/tf_explain/callbacks/integrated_gradients.py
new file mode 100644
index 0000000..eb8faf6
--- /dev/null
+++ b/tf_explain/callbacks/integrated_gradients.py
@@ -0,0 +1,65 @@
+"""
+Callback Module for Integrated Gradients
+"""
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.callbacks import Callback
+
+from tf_explain.core.integrated_gradients import IntegratedGradients
+
+
+class IntegratedGradientsCallback(Callback):
+
+ """
+ Perform Integrated Gradients algorithm for a given input
+
+ Paper: [Axiomatic Attribution for Deep Networks](https://arxiv.org/pdf/1703.01365.pdf)
+ """
+
+ def __init__(
+ self,
+ validation_data,
+ class_index,
+ n_steps=5,
+ output_dir=Path("./logs/integrated_gradients"),
+ ):
+ """
+ Constructor.
+
+ Args:
+ validation_data (Tuple[np.ndarray, Optional[np.ndarray]]): Validation data
+ to perform the method on. Tuple containing (x, y).
+ class_index (int): Index of targeted class
+ n_steps (int): Number of steps in the path
+ output_dir (str): Output directory path
+ """
+ super(IntegratedGradientsCallback, self).__init__()
+ self.validation_data = validation_data
+ self.class_index = class_index
+ self.n_steps = n_steps
+ self.output_dir = Path(output_dir) / datetime.now().strftime("%Y%m%d-%H%M%S.%f")
+ Path.mkdir(Path(self.output_dir), parents=True, exist_ok=True)
+
+ self.file_writer = tf.summary.create_file_writer(str(self.output_dir))
+
+ def on_epoch_end(self, epoch, logs=None):
+ """
+ Draw Integrated Gradients outputs at each epoch end to Tensorboard.
+
+ Args:
+ epoch (int): Epoch index
+ logs (dict): Additional information on epoch
+ """
+ explainer = IntegratedGradients()
+ grid = explainer.explain(
+ self.validation_data, self.model, self.class_index, self.n_steps
+ )
+
+ # Using the file writer, log the reshaped image.
+ with self.file_writer.as_default():
+ tf.summary.image(
+ "IntegratedGradients", np.expand_dims([grid], axis=-1), step=epoch
+ )
diff --git a/tf_explain/core/integrated_gradients.py b/tf_explain/core/integrated_gradients.py
new file mode 100644
index 0000000..c040a13
--- /dev/null
+++ b/tf_explain/core/integrated_gradients.py
@@ -0,0 +1,134 @@
+"""
+Core Module for Integrated Gradients Algorithm
+"""
+from pathlib import Path
+
+import cv2
+import numpy as np
+import tensorflow as tf
+
+from tf_explain.utils.display import grid_display
+from tf_explain.utils.image import transform_to_normalized_grayscale
+
+
+class IntegratedGradients:
+
+ """
+ Perform Integrated Gradients algorithm for a given input
+
+ Paper: [Axiomatic Attribution for Deep Networks](https://arxiv.org/pdf/1703.01365.pdf)
+ """
+
+ def explain(self, validation_data, model, class_index, n_steps=10):
+ """
+ Compute Integrated Gradients for a specific class index
+
+ Args:
+ validation_data (Tuple[np.ndarray, Optional[np.ndarray]]): Validation data
+ to perform the method on. Tuple containing (x, y).
+ model (tf.keras.Model): tf.keras model to inspect
+ class_index (int): Index of targeted class
+ n_steps (int): Number of steps in the path
+
+ Returns:
+ np.ndarray: Grid of all the integrated gradients
+ """
+ images, _ = validation_data
+
+ interpolated_images = IntegratedGradients.generate_interpolations(
+ np.array(images), n_steps
+ )
+
+ integrated_gradients = IntegratedGradients.get_integrated_gradients(
+ interpolated_images, model, class_index, n_steps
+ )
+
+ grayscale_integrated_gradients = transform_to_normalized_grayscale(
+ tf.abs(integrated_gradients)
+ ).numpy()
+
+ grid = grid_display(grayscale_integrated_gradients)
+
+ return grid
+
+ @staticmethod
+ @tf.function
+ def get_integrated_gradients(interpolated_images, model, class_index, n_steps):
+ """
+ Perform backpropagation to compute integrated gradients.
+
+ Args:
+ interpolated_images (numpy.ndarray): 4D-Tensor of shape (N * n_steps, H, W, 3)
+ model (tf.keras.Model): tf.keras model to inspect
+ class_index (int): Index of targeted class
+ n_steps (int): Number of steps in the path
+
+ Returns:
+ tf.Tensor: 4D-Tensor of shape (N, H, W, 3) with integrated gradients
+ """
+ with tf.GradientTape() as tape:
+ inputs = tf.cast(interpolated_images, tf.float32)
+ tape.watch(inputs)
+ predictions = model(inputs)
+ loss = predictions[:, class_index]
+
+ grads = tape.gradient(loss, inputs)
+ grads_per_image = tf.reshape(grads, (-1, n_steps, *grads.shape[1:]))
+
+ integrated_gradients = tf.reduce_mean(grads_per_image, axis=1)
+
+ return integrated_gradients
+
+ @staticmethod
+ def generate_interpolations(images, n_steps):
+ """
+ Generate interpolation paths for batch of images.
+
+ Args:
+ images (numpy.ndarray): 4D-Tensor of images with shape (N, H, W, 3)
+ n_steps (int): Number of steps in the path
+
+ Returns:
+ numpy.ndarray: Interpolation paths for each image with shape (N * n_steps, H, W, 3)
+ """
+ baseline = np.zeros(images.shape[1:])
+
+ return np.concatenate(
+ [
+ IntegratedGradients.generate_linear_path(baseline, image, n_steps)
+ for image in images
+ ]
+ )
+
+ @staticmethod
+ def generate_linear_path(baseline, target, n_steps):
+ """
+ Generate the interpolation path between the baseline image and the target image.
+
+ Args:
+ baseline (numpy.ndarray): Reference image
+ target (numpy.ndarray): Target image
+ n_steps (int): Number of steps in the path
+
+ Returns:
+ List(np.ndarray): List of images for each step
+ """
+ return [
+ baseline + (target - baseline) * index / (n_steps - 1)
+ for index in range(n_steps)
+ ]
+
+ def save(self, grid, output_dir, output_name):
+ """
+ Save the output to a specific dir.
+
+ Args:
+ grid (numpy.ndarray): Grid of all the heatmaps
+ output_dir (str): Output directory path
+ output_name (str): Output name
+ """
+ Path.mkdir(Path(output_dir), parents=True, exist_ok=True)
+
+ cv2.imwrite(
+ str(Path(output_dir) / output_name), cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)
+ )
diff --git a/tf_explain/core/smoothgrad.py b/tf_explain/core/smoothgrad.py
index 7b65607..7b28b47 100644
--- a/tf_explain/core/smoothgrad.py
+++ b/tf_explain/core/smoothgrad.py
@@ -8,6 +8,7 @@
import tensorflow as tf
from tf_explain.utils.display import grid_display
+from tf_explain.utils.image import transform_to_normalized_grayscale
class SmoothGrad:
@@ -41,8 +42,8 @@ def explain(self, validation_data, model, class_index, num_samples=5, noise=1.0)
noisy_images, model, class_index, num_samples
)
- grayscale_gradients = SmoothGrad.transform_to_grayscale(
- smoothed_gradients
+ grayscale_gradients = transform_to_normalized_grayscale(
+ tf.abs(smoothed_gradients)
).numpy()
grid = grid_display(grayscale_gradients)
@@ -67,28 +68,6 @@ def generate_noisy_images(images, num_samples, noise):
return repeated_images + noise
- @staticmethod
- @tf.function
- def transform_to_grayscale(gradients):
- """
- Transform gradients over RGB axis to grayscale.
-
- Args:
- gradients (tf.Tensor): 4D-Tensor with shape (batch_size, H, W, 3)
-
- Returns:
- tf.Tensor: 4D-Tensor of grayscale gradients, with shape (batch_size, H, W, 1)
- """
- grayscale_grads = tf.reduce_sum(tf.abs(gradients), axis=-1)
- normalized_grads = tf.cast(
- 255
- * (grayscale_grads - tf.reduce_min(grayscale_grads))
- / (tf.reduce_max(grayscale_grads) - tf.reduce_min(grayscale_grads)),
- tf.uint8,
- )
-
- return normalized_grads
-
@staticmethod
@tf.function
def get_averaged_gradients(noisy_images, model, class_index, num_samples):
diff --git a/tf_explain/utils/image.py b/tf_explain/utils/image.py
index 9168550..177b5f9 100644
--- a/tf_explain/utils/image.py
+++ b/tf_explain/utils/image.py
@@ -1,5 +1,6 @@
""" Module for image operations """
import numpy as np
+import tensorflow as tf
def apply_grey_patch(image, top_left_x, top_left_y, patch_size):
@@ -21,3 +22,25 @@ def apply_grey_patch(image, top_left_x, top_left_y, patch_size):
] = 127.5
return patched_image
+
+
+@tf.function
+def transform_to_normalized_grayscale(tensor):
+ """
+ Transform tensor over RGB axis to grayscale.
+
+ Args:
+ tensor (tf.Tensor): 4D-Tensor with shape (batch_size, H, W, 3)
+
+ Returns:
+ tf.Tensor: 4D-Tensor of grayscale tensor, with shape (batch_size, H, W, 1)
+ """
+ grayscale_tensor = tf.reduce_sum(tensor, axis=-1)
+ normalized_tensor = tf.cast(
+ 255
+ * (grayscale_tensor - tf.reduce_min(grayscale_tensor))
+ / (tf.reduce_max(grayscale_tensor) - tf.reduce_min(grayscale_tensor)),
+ tf.uint8,
+ )
+
+ return normalized_tensor