diff --git a/README.md b/README.md
index e853cae..9a44932 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,7 @@ pip install tensorflow-gpu==2.0.0-beta1
 2. [Occlusion Sensitivity](#occlusion-sensitivity)
 3. [Grad CAM (Class Activation Maps)](#grad-cam)
 4. [SmoothGrad](#smoothgrad)
+5. [Integrated Gradients](#integrated-gradients)

 ### Activations Visualization

@@ -144,13 +145,39 @@ model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

+### Integrated Gradients
+
+> Visualize an average of the gradients along the construction of the input towards the decision
+
+From [Axiomatic Attribution for Deep Networks](https://arxiv.org/pdf/1703.01365.pdf)
+
+```python
+from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
+
+model = [...]
+
+callbacks = [
+    IntegratedGradientsCallback(
+        validation_data=(x_val, y_val),
+        class_index=0,
+        n_steps=20,
+        output_dir=output_dir,
+    )
+]
+
+model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
+```
+
+<p align="center">
+    <img src="./docs/assets/integrated_gradients.png" width="200" />
+</p>
+
 ## Visualizing the results

 When you use the callbacks, the output files are created in the `logs` directory.
-You can see them in tensorboard with the following command: `tensorboard --logdir logs`
+You can see them in TensorBoard with the following command: `tensorboard --logdir logs`

 ## Roadmap
@@ -158,7 +185,7 @@ You can see them in tensorboard with the following command: `tensorboard --logdi
 - [ ] Subclassing API Support
 - [ ] Additional Methods
   - [ ] [GradCAM++](https://arxiv.org/abs/1710.11063)
-  - [ ] [Integrated Gradients](https://arxiv.org/abs/1703.01365)
+  - [x] [Integrated Gradients](https://arxiv.org/abs/1703.01365)
   - [ ] [Guided SmoothGrad](https://arxiv.org/abs/1706.03825)
   - [ ] [LRP](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140)
 - [ ] Auto-generated API Documentation & Documentation Testing
diff --git a/docs/assets/integrated_gradients.png b/docs/assets/integrated_gradients.png
new file mode 100644
index 0000000..9e78714
Binary files /dev/null and b/docs/assets/integrated_gradients.png differ
diff --git a/docs/source/methods.rst b/docs/source/methods.rst
index 12b3497..b34bb65 100644
--- a/docs/source/methods.rst
+++ b/docs/source/methods.rst
@@ -108,3 +108,31 @@ From `SmoothGrad: removing noise by adding noise `_
+::
+
+    from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
+
+    model = [...]
+
+    callbacks = [
+        IntegratedGradientsCallback(
+            validation_data=(x_val, y_val),
+            class_index=0,
+            n_steps=20,
+            output_dir=output_dir,
+        )
+    ]
+
+    model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
+
+.. image:: ../assets/integrated_gradients.png
+   :alt: IntegratedGradients
+   :width: 200px
+   :align: center
diff --git a/examples/callbacks/mnist.py b/examples/callbacks/mnist.py
index addcaee..f7537ca 100644
--- a/examples/callbacks/mnist.py
+++ b/examples/callbacks/mnist.py
@@ -61,6 +61,7 @@
     tf_explain.callbacks.GradCAMCallback(validation_class_fours, 'target_layer', class_index=4),
     tf_explain.callbacks.ActivationsVisualizationCallback(validation_class_zero, layers_name=['target_layer']),
     tf_explain.callbacks.SmoothGradCallback(validation_class_zero, class_index=0, num_samples=15, noise=1.),
+    tf_explain.callbacks.IntegratedGradientsCallback(validation_class_zero, class_index=0, n_steps=10),
 ]

 # Start training
diff --git a/examples/core/integrated_gradients.py b/examples/core/integrated_gradients.py
new file mode 100644
index 0000000..0a61029
--- /dev/null
+++ b/examples/core/integrated_gradients.py
@@ -0,0 +1,20 @@
+import tensorflow as tf
+
+from tf_explain.core.integrated_gradients import IntegratedGradients
+
+IMAGE_PATH = './cat.jpg'
+
+if __name__ == '__main__':
+    model = tf.keras.applications.vgg16.VGG16(weights='imagenet', include_top=True)
+
+    img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
+    img = tf.keras.preprocessing.image.img_to_array(img)
+
+    model.summary()
+    data = ([img], None)
+
+    tabby_cat_class_index = 281
+    explainer = IntegratedGradients()
+    # Compute Integrated Gradients on VGG16
+    grid = explainer.explain(data, model, tabby_cat_class_index, n_steps=15)
+    explainer.save(grid, '.', 'integrated_gradients.png')
diff --git a/tests/callbacks/test_integrated_gradients.py b/tests/callbacks/test_integrated_gradients.py
new file mode 100644
index 0000000..7c5e57d
--- /dev/null
+++ b/tests/callbacks/test_integrated_gradients.py
@@ -0,0 +1,32 @@
+import numpy as np
+from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
+
+
+def test_should_call_integrated_gradients_callback(
+    random_data, convolutional_model, output_dir, mocker
+):
+    mock_explainer = mocker.MagicMock(explain=mocker.MagicMock(return_value=0))
+    mocker.patch(
+        "tf_explain.callbacks.integrated_gradients.IntegratedGradients",
+        return_value=mock_explainer,
+    )
+    mock_image_summary = mocker.patch(
+        "tf_explain.callbacks.integrated_gradients.tf.summary.image"
+    )
+
+    images, labels = random_data
+
+    callbacks = [
+        IntegratedGradientsCallback(
+            validation_data=random_data, class_index=0, output_dir=output_dir, n_steps=3
+        )
+    ]
+
+    convolutional_model.fit(images, labels, batch_size=2, epochs=1, callbacks=callbacks)
+
+    mock_explainer.explain.assert_called_once_with(
+        random_data, convolutional_model, 0, 3
+    )
+    mock_image_summary.assert_called_once_with(
+        "IntegratedGradients", np.array([0]), step=0
+    )
diff --git a/tests/callbacks/test_smoothgrad.py b/tests/callbacks/test_smoothgrad.py
index a2e532e..4eab5b8 100644
--- a/tests/callbacks/test_smoothgrad.py
+++ b/tests/callbacks/test_smoothgrad.py
@@ -9,7 +9,9 @@ def test_should_call_smoothgrad_callback(
     mocker.patch(
         "tf_explain.callbacks.smoothgrad.SmoothGrad", return_value=mock_explainer
     )
-    mock_image_summary = mocker.patch("tf_explain.callbacks.grad_cam.tf.summary.image")
+    mock_image_summary = mocker.patch(
+        "tf_explain.callbacks.smoothgrad.tf.summary.image"
+    )

     images, labels = random_data

diff --git a/tests/core/test_integrated_gradients.py b/tests/core/test_integrated_gradients.py
new file mode 100644
index 0000000..073ed8a
--- /dev/null
+++ b/tests/core/test_integrated_gradients.py
@@ -0,0 +1,40 @@
+import numpy as np
+
+from tf_explain.core.integrated_gradients import IntegratedGradients
+
+
+def test_should_explain_output(convolutional_model, random_data, mocker):
+    mocker.patch(
+        "tf_explain.core.integrated_gradients.grid_display", side_effect=lambda x: x
+    )
+    images, labels = random_data
+    explainer = IntegratedGradients()
+    grid = explainer.explain((images, labels), convolutional_model, 0)
+
+    # Output is in grayscale format
+    assert grid.shape == images.shape[:-1]
+
+
+def test_generate_linear_path():
+    input_shape = (28, 28, 1)
+    target = np.ones(input_shape)
+    baseline = np.zeros(input_shape)
+    n_steps = 3
+
+    expected_output = [baseline, 1 / 2 * (target - baseline), target]
+
+    output = IntegratedGradients.generate_linear_path(baseline, target, n_steps)
+
+    np.testing.assert_almost_equal(output, expected_output)
+
+
+def test_get_integrated_gradients(random_data, convolutional_model):
+    images, _ = random_data
+    n_steps = 4
+    gradients = IntegratedGradients.get_integrated_gradients(
+        images, convolutional_model, 0, n_steps=n_steps
+    )
+
+    expected_output_shape = (images.shape[0] / n_steps, *images.shape[1:])
+
+    assert gradients.shape == expected_output_shape
diff --git a/tests/core/test_smoothgrad.py b/tests/core/test_smoothgrad.py
index c9da517..0006a8e 100644
--- a/tests/core/test_smoothgrad.py
+++ b/tests/core/test_smoothgrad.py
@@ -1,5 +1,4 @@
 import numpy as np
-import tensorflow as tf

 from tf_explain.core.smoothgrad import SmoothGrad

@@ -28,15 +27,6 @@ def test_generate_noisy_images(mocker):
     np.testing.assert_array_equal(output, 2 * np.ones((30, 28, 28, 1)))


-def test_should_transform_gradients_to_grayscale():
-    gradients = tf.random.uniform((4, 28, 28, 3))
-
-    grayscale_gradients = SmoothGrad.transform_to_grayscale(gradients)
-    expected_output_shape = (4, 28, 28)
-
-    assert grayscale_gradients.shape == expected_output_shape
-
-
 def test_get_averaged_gradients(random_data, convolutional_model):
     images, _ = random_data
     num_samples = 2
diff --git a/tests/utils/test_image.py b/tests/utils/test_image.py
index 3f8b25c..4d76f54 100644
--- a/tests/utils/test_image.py
+++ b/tests/utils/test_image.py
@@ -1,6 +1,7 @@
 import numpy as np
+import tensorflow as tf

-from tf_explain.utils.image import apply_grey_patch
+from tf_explain.utils.image import apply_grey_patch, transform_to_normalized_grayscale


 def test_should_apply_grey_patch_on_image():
@@ -13,3 +14,12 @@ def test_should_apply_grey_patch_on_image():
     )

     np.testing.assert_almost_equal(output, expected_output)
+
+
+def test_should_transform_gradients_to_grayscale():
+    gradients = tf.random.uniform((4, 28, 28, 3))
+
+    grayscale_gradients = transform_to_normalized_grayscale(gradients)
+    expected_output_shape = (4, 28, 28)
+
+    assert grayscale_gradients.shape == expected_output_shape
diff --git a/tf_explain/callbacks/__init__.py b/tf_explain/callbacks/__init__.py
index 3cca06c..b562404 100644
--- a/tf_explain/callbacks/__init__.py
+++ b/tf_explain/callbacks/__init__.py
@@ -1,5 +1,6 @@
 from .activations_visualization import ActivationsVisualizationCallback
 from .grad_cam import GradCAMCallback
+from .integrated_gradients import IntegratedGradientsCallback
 from .occlusion_sensitivity import OcclusionSensitivityCallback
 from .smoothgrad import SmoothGradCallback

@@ -7,6 +8,7 @@
 __all__ = [
     "ActivationsVisualizationCallback",
     "GradCAMCallback",
+    "IntegratedGradientsCallback",
     "OcclusionSensitivityCallback",
     "SmoothGradCallback",
 ]
diff --git a/tf_explain/callbacks/integrated_gradients.py b/tf_explain/callbacks/integrated_gradients.py
new file mode 100644
index 0000000..eb8faf6
--- /dev/null
+++ b/tf_explain/callbacks/integrated_gradients.py
@@ -0,0 +1,65 @@
+"""
+Callback Module for Integrated Gradients
+"""
+from datetime import datetime
+from pathlib import Path
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.callbacks import Callback
+
+from tf_explain.core.integrated_gradients import IntegratedGradients
+
+
+class IntegratedGradientsCallback(Callback):
+
+    """
+    Perform Integrated Gradients algorithm for a given input
+
+    Paper: [Axiomatic Attribution for Deep Networks](https://arxiv.org/pdf/1703.01365.pdf)
+    """
+
+    def __init__(
+        self,
+        validation_data,
+        class_index,
+        n_steps=5,
+        output_dir=Path("./logs/integrated_gradients"),
+    ):
+        """
+        Constructor.
+
+        Args:
+            validation_data (Tuple[np.ndarray, Optional[np.ndarray]]): Validation data
+                to perform the method on. Tuple containing (x, y).
+            class_index (int): Index of targeted class
+            n_steps (int): Number of steps in the path
+            output_dir (str): Output directory path
+        """
+        super(IntegratedGradientsCallback, self).__init__()
+        self.validation_data = validation_data
+        self.class_index = class_index
+        self.n_steps = n_steps
+        self.output_dir = Path(output_dir) / datetime.now().strftime("%Y%m%d-%H%M%S.%f")
+        Path.mkdir(Path(self.output_dir), parents=True, exist_ok=True)
+
+        self.file_writer = tf.summary.create_file_writer(str(self.output_dir))
+
+    def on_epoch_end(self, epoch, logs=None):
+        """
+        Draw Integrated Gradients outputs at each epoch end to Tensorboard.
+
+        Args:
+            epoch (int): Epoch index
+            logs (dict): Additional information on epoch
+        """
+        explainer = IntegratedGradients()
+        grid = explainer.explain(
+            self.validation_data, self.model, self.class_index, self.n_steps
+        )
+
+        # Using the file writer, log the reshaped image.
+        with self.file_writer.as_default():
+            tf.summary.image(
+                "IntegratedGradients", np.expand_dims([grid], axis=-1), step=epoch
+            )
diff --git a/tf_explain/core/integrated_gradients.py b/tf_explain/core/integrated_gradients.py
new file mode 100644
index 0000000..c040a13
--- /dev/null
+++ b/tf_explain/core/integrated_gradients.py
@@ -0,0 +1,134 @@
+"""
+Core Module for Integrated Gradients Algorithm
+"""
+from pathlib import Path
+
+import cv2
+import numpy as np
+import tensorflow as tf
+
+from tf_explain.utils.display import grid_display
+from tf_explain.utils.image import transform_to_normalized_grayscale
+
+
+class IntegratedGradients:
+
+    """
+    Perform Integrated Gradients algorithm for a given input
+
+    Paper: [Axiomatic Attribution for Deep Networks](https://arxiv.org/pdf/1703.01365.pdf)
+    """
+
+    def explain(self, validation_data, model, class_index, n_steps=10):
+        """
+        Compute Integrated Gradients for a specific class index
+
+        Args:
+            validation_data (Tuple[np.ndarray, Optional[np.ndarray]]): Validation data
+                to perform the method on. Tuple containing (x, y).
+            model (tf.keras.Model): tf.keras model to inspect
+            class_index (int): Index of targeted class
+            n_steps (int): Number of steps in the path
+
+        Returns:
+            np.ndarray: Grid of all the integrated gradients
+        """
+        images, _ = validation_data
+
+        interpolated_images = IntegratedGradients.generate_interpolations(
+            np.array(images), n_steps
+        )
+
+        integrated_gradients = IntegratedGradients.get_integrated_gradients(
+            interpolated_images, model, class_index, n_steps
+        )
+
+        grayscale_integrated_gradients = transform_to_normalized_grayscale(
+            tf.abs(integrated_gradients)
+        ).numpy()
+
+        grid = grid_display(grayscale_integrated_gradients)
+
+        return grid
+
+    @staticmethod
+    @tf.function
+    def get_integrated_gradients(interpolated_images, model, class_index, n_steps):
+        """
+        Perform backpropagation to compute integrated gradients.
+
+        Args:
+            interpolated_images (numpy.ndarray): 4D-Tensor of shape (N * n_steps, H, W, 3)
+            model (tf.keras.Model): tf.keras model to inspect
+            class_index (int): Index of targeted class
+            n_steps (int): Number of steps in the path
+
+        Returns:
+            tf.Tensor: 4D-Tensor of shape (N, H, W, 3) with integrated gradients
+        """
+        with tf.GradientTape() as tape:
+            inputs = tf.cast(interpolated_images, tf.float32)
+            tape.watch(inputs)
+            predictions = model(inputs)
+            loss = predictions[:, class_index]
+
+        grads = tape.gradient(loss, inputs)
+        grads_per_image = tf.reshape(grads, (-1, n_steps, *grads.shape[1:]))
+
+        integrated_gradients = tf.reduce_mean(grads_per_image, axis=1)
+
+        return integrated_gradients
+
+    @staticmethod
+    def generate_interpolations(images, n_steps):
+        """
+        Generate interpolation paths for batch of images.
+
+        Args:
+            images (numpy.ndarray): 4D-Tensor of images with shape (N, H, W, 3)
+            n_steps (int): Number of steps in the path
+
+        Returns:
+            numpy.ndarray: Interpolation paths for each image with shape (N * n_steps, H, W, 3)
+        """
+        baseline = np.zeros(images.shape[1:])
+
+        return np.concatenate(
+            [
+                IntegratedGradients.generate_linear_path(baseline, image, n_steps)
+                for image in images
+            ]
+        )
+
+    @staticmethod
+    def generate_linear_path(baseline, target, n_steps):
+        """
+        Generate the interpolation path between the baseline image and the target image.
+
+        Args:
+            baseline (numpy.ndarray): Reference image
+            target (numpy.ndarray): Target image
+            n_steps (int): Number of steps in the path
+
+        Returns:
+            List[np.ndarray]: List of images for each step
+        """
+        return [
+            baseline + (target - baseline) * index / (n_steps - 1)
+            for index in range(n_steps)
+        ]
+
+    def save(self, grid, output_dir, output_name):
+        """
+        Save the output to a specific dir.
+
+        Args:
+            grid (numpy.ndarray): Grid of all the heatmaps
+            output_dir (str): Output directory path
+            output_name (str): Output name
+        """
+        Path.mkdir(Path(output_dir), parents=True, exist_ok=True)
+
+        cv2.imwrite(
+            str(Path(output_dir) / output_name), cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)
+        )
diff --git a/tf_explain/core/smoothgrad.py b/tf_explain/core/smoothgrad.py
index 7b65607..7b28b47 100644
--- a/tf_explain/core/smoothgrad.py
+++ b/tf_explain/core/smoothgrad.py
@@ -8,6 +8,7 @@ import tensorflow as tf

 from tf_explain.utils.display import grid_display
+from tf_explain.utils.image import transform_to_normalized_grayscale


 class SmoothGrad:

@@ -41,8 +42,8 @@ def explain(self, validation_data, model, class_index, num_samples=5, noise=1.0)
             noisy_images, model, class_index, num_samples
         )

-        grayscale_gradients = SmoothGrad.transform_to_grayscale(
-            smoothed_gradients
+        grayscale_gradients = transform_to_normalized_grayscale(
+            tf.abs(smoothed_gradients)
         ).numpy()

         grid = grid_display(grayscale_gradients)
@@ -67,28 +68,6 @@ def generate_noisy_images(images, num_samples, noise):

         return repeated_images + noise

-    @staticmethod
-    @tf.function
-    def transform_to_grayscale(gradients):
-        """
-        Transform gradients over RGB axis to grayscale.
-
-        Args:
-            gradients (tf.Tensor): 4D-Tensor with shape (batch_size, H, W, 3)
-
-        Returns:
-            tf.Tensor: 4D-Tensor of grayscale gradients, with shape (batch_size, H, W, 1)
-        """
-        grayscale_grads = tf.reduce_sum(tf.abs(gradients), axis=-1)
-        normalized_grads = tf.cast(
-            255
-            * (grayscale_grads - tf.reduce_min(grayscale_grads))
-            / (tf.reduce_max(grayscale_grads) - tf.reduce_min(grayscale_grads)),
-            tf.uint8,
-        )
-
-        return normalized_grads
-
     @staticmethod
     @tf.function
     def get_averaged_gradients(noisy_images, model, class_index, num_samples):
diff --git a/tf_explain/utils/image.py b/tf_explain/utils/image.py
index 9168550..177b5f9 100644
--- a/tf_explain/utils/image.py
+++ b/tf_explain/utils/image.py
@@ -1,5 +1,6 @@
 """ Module for image operations """
 import numpy as np
+import tensorflow as tf


 def apply_grey_patch(image, top_left_x, top_left_y, patch_size):
@@ -21,3 +22,25 @@
     ] = 127.5

     return patched_image
+
+
+@tf.function
+def transform_to_normalized_grayscale(tensor):
+    """
+    Transform tensor over RGB axis to grayscale.
+
+    Args:
+        tensor (tf.Tensor): 4D-Tensor with shape (batch_size, H, W, 3)
+
+    Returns:
+        tf.Tensor: 3D-Tensor of normalized grayscale values, with shape (batch_size, H, W)
+    """
+    grayscale_tensor = tf.reduce_sum(tensor, axis=-1)
+    normalized_tensor = tf.cast(
+        255
+        * (grayscale_tensor - tf.reduce_min(grayscale_tensor))
+        / (tf.reduce_max(grayscale_tensor) - tf.reduce_min(grayscale_tensor)),
+        tf.uint8,
+    )
+
+    return normalized_tensor