# tf-explain

[![Pypi Version](https://img.shields.io/pypi/v/tf-explain.svg)](https://pypi.org/project/tf-explain/)
[![Build Status](https://api.travis-ci.org/sicara/tf-explain.svg?branch=master)](https://travis-ci.org/sicara/tf-explain)
[![Documentation Status](https://readthedocs.org/projects/tf-explain/badge/?version=latest)](https://tf-explain.readthedocs.io/en/latest/?badge=latest)
![Python Versions](https://img.shields.io/badge/python-3.6%20|%203.7-%23EBBD68.svg)
![Tensorflow Versions](https://img.shields.io/badge/tensorflow-2.0.0-blue.svg)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)

__tf-explain__ implements interpretability methods as TensorFlow 2.0 callbacks to __make neural networks easier to understand__.  
See [Introducing tf-explain, Interpretability for Tensorflow 2.0](https://blog.sicara.com/tf-explain-interpretability-tensorflow-2-9438b5846e35).

__Documentation__: https://tf-explain.readthedocs.io

## tf-explain example on MNIST / Fashion-MNIST

```python
try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
except Exception:
    pass
import tensorflow as tf
```

```python
# Install tf-explain from PyPI. From a local clone, run `pip install -e .`
# in the repository root (where setup.py lives) instead.
!pip install tf-explain
```
```python
import numpy as np
import tf_explain
```

```python
AVAILABLE_DATASETS = {
    'mnist': tf.keras.datasets.mnist,
    'fashion_mnist': tf.keras.datasets.fashion_mnist,
}
```

```python
# Choose a dataset
DATASET_NAME = 'mnist'  #@param ["fashion_mnist", "mnist"]
print(DATASET_NAME)
```

```python
INPUT_SHAPE = (28, 28, 1)
NUM_CLASSES = 10
```

```python
# Load dataset
dataset = AVAILABLE_DATASETS[DATASET_NAME]
(train_images, train_labels), (test_images, test_labels) = dataset.load_data()

# Add a channel axis: (28, 28) images become (28, 28, 1)
train_images = train_images[..., tf.newaxis].astype('float32')
test_images = test_images[..., tf.newaxis].astype('float32')

# One-hot encode labels 0, 1, ..., 9: class 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=NUM_CLASSES)
test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=NUM_CLASSES)
```
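For readers new to this preprocessing, the sketch below reproduces the two shape changes with plain NumPy, assuming integer labels in `0..NUM_CLASSES-1`. `one_hot` is a hypothetical helper for illustration only; for such labels it should yield the same float32 matrix as `tf.keras.utils.to_categorical`.

```python
import numpy as np

NUM_CLASSES = 10

def one_hot(labels, num_classes=NUM_CLASSES):
    # Hypothetical helper: row i of the identity matrix is the one-hot vector for class i
    return np.eye(num_classes, dtype="float32")[labels]

# Fake batch of three 28x28 grayscale images with uint8 pixels, like load_data() returns
images = np.zeros((3, 28, 28), dtype="uint8")
labels = np.array([0, 4, 9])

# Add a trailing channel axis: (3, 28, 28) -> (3, 28, 28, 1)
images = images[..., np.newaxis].astype("float32")
encoded = one_hot(labels)

print(images.shape)  # (3, 28, 28, 1)
print(encoded[1])    # 1.0 at index 4, zeros elsewhere
```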

```python
# Create model
img_input = tf.keras.Input(INPUT_SHAPE)

x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(img_input)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', name='target_layer')(x)
x = tf.keras.layers.MaxPool2D(pool_size=(2, 2))(x)

x = tf.keras.layers.Dropout(0.25)(x)
x = tf.keras.layers.Flatten()(x)

x = tf.keras.layers.Dense(128, activation='relu')(x)
x = tf.keras.layers.Dropout(0.5)(x)

x = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(x)

model = tf.keras.Model(img_input, x)
```
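With 3×3 kernels and no padding, each `Conv2D` shrinks both spatial dimensions by 2: 28 → 26 → 24. So `target_layer` outputs 24×24×64 feature maps, which `MaxPool2D` halves to 12×12×64 (9,216 values after `Flatten`). Grad-CAM works on that 24×24×64 output. The NumPy sketch below shows the standard Grad-CAM combination step (Selvaraju et al., 2017) on arrays you would obtain from such a layer; `grad_cam` is a hypothetical name and this is not tf-explain's code, which may filter the gradients before averaging and resizes the map to the input size for display.

```python
import numpy as np

def grad_cam(activations, gradients):
    """Standard Grad-CAM combination step for one image (hypothetical helper).

    activations: output of the target conv layer, shape (height, width, channels)
    gradients:   gradient of the class score w.r.t. that output, same shape
    """
    # One importance weight per channel: spatial mean of that channel's gradient
    weights = gradients.mean(axis=(0, 1))
    # Weighted sum over channels, keeping only positive evidence (ReLU)
    cam = np.maximum(activations @ weights, 0.0)
    # Rescale to [0, 1] for display, unless the map is all zeros
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Random stand-ins for the (24, 24, 64) output of 'target_layer' and its gradients
rng = np.random.default_rng(0)
heatmap = grad_cam(rng.random((24, 24, 64)), rng.normal(size=(24, 24, 64)))
print(heatmap.shape)  # (24, 24)
```

The 24×24 heatmap is coarser than the 28×28 input, which is why Grad-CAM maps are usually upsampled before being overlaid on the image.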

```python
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
```
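Because the labels are one-hot, `categorical_crossentropy` reduces to the negative log of the probability the softmax assigns to the true class. A minimal NumPy sketch of the per-example formula −Σᵢ yᵢ log pᵢ, assuming a small clipping epsilon to avoid log(0) (Keras clips similarly, but this is not its exact code):

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip probabilities so that log never sees exactly 0 (assumed epsilon)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    # Per-example loss: -sum_i y_i * log(p_i)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

y_true = np.array([[0.0, 1.0, 0.0]])  # one-hot label for class 1
y_pred = np.array([[0.1, 0.7, 0.2]])  # softmax output
loss = categorical_crossentropy(y_true, y_pred)
print(loss)  # equals -log(0.7)
```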

```python
# Select a subset of the validation data to examine:
# the first 5 test images whose one-hot label is class 0, i.e. [1, 0, 0, ..., 0]
validation_class_zero = (np.array([
    el for el, label in zip(test_images, test_labels)
    if np.argmax(label) == 0
][0:5]), None)
```

```python
# Select a subset of the validation data to examine:
# the first 5 test images whose one-hot label is class 4, i.e. [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
validation_class_fours = (np.array([
    el for el, label in zip(test_images, test_labels)
    if np.argmax(label) == 4
][0:5]), None)
```
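The two list comprehensions above scan the whole test set in Python. An equivalent vectorized selection with a boolean mask might look like the sketch below; `select_class_examples` is a hypothetical helper, not part of tf-explain, and it returns the same `(images, None)` tuple format passed to the callbacks as validation data.

```python
import numpy as np

def select_class_examples(images, one_hot_labels, class_index, count=5):
    # Hypothetical helper. True for every example whose one-hot label points at class_index
    mask = np.argmax(one_hot_labels, axis=1) == class_index
    # First `count` matching images, paired with None like the cells above
    return images[mask][:count], None

# Tiny fake dataset: 6 "images" whose single pixel value is their row index
images = np.arange(6, dtype="float32").reshape(6, 1, 1, 1)
labels = np.eye(10, dtype="float32")[[0, 4, 0, 4, 0, 1]]

zeros, _ = select_class_examples(images, labels, class_index=0, count=2)
print(zeros.ravel())  # [0. 2.]
```

With the notebook's arrays, `select_class_examples(test_images, test_labels, 4)` would play the role of `validation_class_fours`.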

```python
# Instantiate callbacks
# class_index must match the class of the validation images selected above
callbacks = [
    tf_explain.callbacks.GradCAMCallback(validation_class_zero, layer_name='target_layer', class_index=0),
    tf_explain.callbacks.GradCAMCallback(validation_class_fours, layer_name='target_layer', class_index=4),
    tf_explain.callbacks.ActivationsVisualizationCallback(validation_class_zero, layers_name=['target_layer']),
    tf_explain.callbacks.SmoothGradCallback(validation_class_zero, class_index=0, num_samples=15, noise=1.),
    tf_explain.callbacks.IntegratedGradientsCallback(validation_class_zero, class_index=0, n_steps=10),
    tf_explain.callbacks.VanillaGradientsCallback(validation_class_zero, class_index=0),
]
```
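`VanillaGradientsCallback`, `SmoothGradCallback` and `IntegratedGradientsCallback` all start from the gradient of the class score with respect to the input pixels. To show how they differ without TensorFlow, the sketch below applies the three ideas to a hypothetical toy score f(x) = Σ wᵢxᵢ², whose gradient 2wᵢxᵢ is known in closed form. Parameter names mirror the callbacks' `num_samples`, `noise` and `n_steps`, but this illustrates the techniques only and is not tf-explain's implementation (for instance, the exact integration points it uses may differ).

```python
import numpy as np

# Hypothetical toy "model": f(x) = sum(w * x**2), with closed-form gradient 2 * w * x
W = np.array([1.0, -2.0, 0.5])

def model(x):
    return float(np.sum(W * x ** 2))

def gradient(x):
    return 2.0 * W * x

def vanilla_gradients(x):
    # Saliency map: gradient of the score with respect to the input
    return gradient(x)

def smooth_grad(x, num_samples=15, noise=1.0, seed=0):
    # Average the gradients of num_samples copies of x with Gaussian noise (std = noise)
    rng = np.random.default_rng(seed)
    noisy = x + rng.normal(0.0, noise, size=(num_samples,) + x.shape)
    return np.mean([gradient(sample) for sample in noisy], axis=0)

def integrated_gradients(x, baseline=None, n_steps=10):
    # Midpoint Riemann sum of the gradients along the straight path baseline -> x
    if baseline is None:
        baseline = np.zeros_like(x)
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    grads = [gradient(baseline + a * (x - baseline)) for a in alphas]
    return (x - baseline) * np.mean(grads, axis=0)

x = np.array([1.0, 2.0, 3.0])
print(vanilla_gradients(x))        # equals 2 * W * x = [2, -8, 3]
print(smooth_grad(x))              # noisy estimate around the vanilla gradient
print(integrated_gradients(x))     # approximately [1, -8, 4.5]
print(integrated_gradients(x).sum(), model(x) - model(np.zeros_like(x)))  # both about -2.5
```

Summing the integrated-gradients attributions recovers f(x) − f(baseline), the completeness property of the method, while SmoothGrad with `noise=0` collapses to the vanilla gradient.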

```python
%load_ext tensorboard
%tensorboard --logdir logs
```

```python
# Start training
model.fit(train_images, train_labels, epochs=5, callbacks=callbacks)
```
