In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from path_explain import PathExplainerTF, summary_plot, scatter_plot, set_up_environment, softplus_activation

In [2]:
# Synthetic regression task: y is the product of the first two features only,
# so features 2..d-1 are pure distractors and a faithful explanation should
# attribute y to features 0 and 1 and show a strong 0x1 interaction.
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so Restart & Run All reproduces the data

n = 5000      # number of samples
d = 5         # number of input features
noise = 0.5   # NOTE(review): not added to y -- the target below is noise-free,
              # consistent with the recorded val_loss (~0.03 << 0.5**2).
X = rng.standard_normal((n, d)).astype(np.float32)
y = np.prod(X[:, 0:2], axis=-1)

assert X.shape == (n, d) and y.shape == (n,)

In [3]:
# Positional 80/20 split into train and held-out rows; the rows of X are
# independent random draws, so no shuffling is needed before slicing.
threshold = int(n * 0.8)
x_train, x_test = X[:threshold], X[threshold:]
y_train, y_test = y[:threshold], y[threshold:]

In [4]:
# Small fully connected regressor: d -> 10 (ReLU) -> 5 (ReLU) -> 1 (linear, no bias).
# Seed TensorFlow's global RNG first so the initial weights are identical on
# every Restart & Run All; re-running this cell re-seeds, keeping it idempotent.
tf.random.set_seed(0)
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(d,)),
    tf.keras.layers.Dense(units=10,
                          use_bias=True,
                          activation=tf.keras.activations.relu),
    tf.keras.layers.Dense(units=5,
                          use_bias=True,
                          activation=tf.keras.activations.relu),
    tf.keras.layers.Dense(units=1,
                          use_bias=False,
                          activation=None),
])

In [5]:
# Print the layer-by-layer architecture and parameter counts as a sanity check
# (for d = 5: 60 + 55 + 5 = 120 trainable parameters).
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 10)                60        
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 55        
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 5         
Total params: 120
Trainable params: 120
Non-trainable params: 0
_________________________________________________________________


In [6]:
# Plain SGD with a fixed step size, minimizing mean squared error on the
# real-valued regression target.
learning_rate = 0.1
sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
mse_loss = tf.keras.losses.MeanSquaredError()
model.compile(loss=mse_loss, optimizer=sgd_optimizer)

In [7]:
# Train for a fixed 20 epochs. The held-out split is only monitored as
# validation data (no callbacks / model selection use it). Keep the History
# object (per-epoch loss and val_loss) instead of letting the notebook echo its
# repr, so the curves can be inspected or plotted later.
history = model.fit(
    x_train,
    y_train,
    batch_size=50,
    epochs=20,
    verbose=2,
    validation_data=(x_test, y_test),
)

Train on 4000 samples, validate on 1000 samples
Epoch 1/20
4000/4000 - 0s - loss: 0.3718 - val_loss: 0.1381
Epoch 2/20
4000/4000 - 0s - loss: 0.1303 - val_loss: 0.0999
Epoch 3/20
4000/4000 - 0s - loss: 0.1020 - val_loss: 0.0775
Epoch 4/20
4000/4000 - 0s - loss: 0.0777 - val_loss: 0.0745
Epoch 5/20
4000/4000 - 0s - loss: 0.0583 - val_loss: 0.0474
Epoch 6/20
4000/4000 - 0s - loss: 0.0456 - val_loss: 0.0426
Epoch 7/20
4000/4000 - 0s - loss: 0.0395 - val_loss: 0.0303
Epoch 8/20
4000/4000 - 0s - loss: 0.0382 - val_loss: 0.0279
Epoch 9/20
4000/4000 - 0s - loss: 0.0359 - val_loss: 0.0283
Epoch 10/20
4000/4000 - 0s - loss: 0.0353 - val_loss: 0.0284
Epoch 11/20
4000/4000 - 0s - loss: 0.0347 - val_loss: 0.0374
Epoch 12/20
4000/4000 - 0s - loss: 0.0342 - val_loss: 0.0277
Epoch 13/20
4000/4000 - 0s - loss: 0.0337 - val_loss: 0.0275
Epoch 14/20
4000/4000 - 0s - loss: 0.0328 - val_loss: 0.0283
Epoch 15/20
4000/4000 - 0s - loss: 0.0334 - val_loss: 0.0255
Epoch 16/20
4000/4000 - 0s - loss: 0.0328 - va

<tensorflow.python.keras.callbacks.History at 0x7fd54009bc10>

In [8]:
# Explain a smoothed copy of the trained network rather than `model` itself:
# ReLU has zero second derivative almost everywhere, which would make
# Hessian-based interaction scores degenerate, so the hidden ReLUs are swapped
# for softplus_activation(beta=10.0) (presumably a sharp softplus approaching
# ReLU as beta grows -- confirm against path_explain).
interpret_model = tf.keras.models.clone_model(model)
# clone_model builds fresh layers with freshly *initialized* weights; without
# this copy every explanation below would describe an untrained network.
interpret_model.set_weights(model.get_weights())
# Every layer except the linear output head is a hidden ReLU layer.
for hidden_layer in interpret_model.layers[:-1]:
    hidden_layer.activation = softplus_activation(beta=10.0)

In [None]:
path_explainer = PathExplainerTF(interpret_model)

# Pairwise path interactions for the first 200 held-out points, integrating
# from an all-zero baseline (use_expectation=False). interaction_index=None
# presumably requests every feature pair -- confirm in path_explain docs.
ih_interactions = path_explainer.interactions(
    inputs=x_test[:200],
    baseline=np.zeros((1, x_test.shape[1]), dtype=np.float32),
    batch_size=50,
    num_samples=50,
    use_expectation=False,
    output_indices=0,
    verbose=True,
    interaction_index=None,
)

 84%|████████▎ | 167/200 [01:59<00:23,  1.40it/s]

In [None]:
# Per-feature path attributions for the same 200 held-out points and the same
# all-zero baseline used for the interactions, so both can be shown together.
zero_baseline = np.zeros((1, x_test.shape[1]), dtype=np.float32)
attributions = path_explainer.attributions(
    inputs=x_test[:200],
    baseline=zero_baseline,
    batch_size=50,
    num_samples=50,
    use_expectation=False,
    output_indices=0,
    verbose=True,
)

In [None]:
# Global overview of the attributions alongside the explained inputs;
# plot_top_k=5 covers all d = 5 features.
summary_plot(
    attributions,
    feature_values=x_test[:200],
    plot_top_k=5,
)

In [None]:
# Dependence view for feature 0, coloured by feature 1. With `interactions`
# supplied and plot_main=False this presumably shows the 0x1 interaction term
# rather than the main effect -- confirm against path_explain's scatter_plot.
scatter_plot(
    attributions,
    feature_values=x_test[:200],
    interactions=ih_interactions,
    feature_index=0,
    color_by=1,
    plot_main=False,
    scale_x_ind=True,
    scale_y_ind=True,
)