In [1]:
# Automatically reload edited local modules (marginal, plot, utils) before each
# cell executes, so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Restrict TensorFlow to GPU 0; this must happen before tensorflow is imported.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [3]:
import tensorflow as tf
import numpy as np
import pandas as pd
import altair as alt
import shap

from marginal import MarginalExplainer
import plot
import matplotlib.pyplot as plt
import utils

In [4]:
# TF 1.x: eager execution must be switched on at program startup, before any
# graph ops are created.
tf.enable_eager_execution()

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. The training cell previously failed with "Failed to get convolution
# algorithm ... cuDNN failed to initialize"; a common cause of that error is
# GPU memory already being exhausted when cuDNN initializes.
# NOTE(review): confirm this resolves the failure on the target machine.
# Memory growth can only be configured before the GPU is initialized, so a
# RuntimeError (e.g. when re-running this cell) is reported instead of raised.
for _gpu in tf.config.experimental.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except RuntimeError as err:
        print('Could not enable memory growth on {}: {}'.format(_gpu.name, err))

In [5]:
# Dataset size and per-sample input shape (height, width, channels).
n, d = 5000, (8, 8, 3)
# Width (number of units) of the hidden Dense layer -- not a count of layers.
hidden_layers = 64
# Mini-batch size and SGD step size used for training.
batch_size, learning_rate = 50, 0.005

In [6]:
# Synthetic regression data: i.i.d. standard-normal inputs of shape (n, *d).
# The target depends only on index 0 and index 1 along axis 1 of X:
#   y = sum over axis 2 of X[:, 0, :, 0]  +  4 * product over axis 2 of X[:, 1, :, 1]
# i.e. an additive term plus a high-order multiplicative interaction.
# Seed numpy's global RNG so the data (and any later draws from np.random's
# global state) is reproducible under Restart & Run All.
np.random.seed(0)
X = np.random.randn(n, *d)
y = np.sum(X[:, 0, :, 0], axis=-1) + np.prod(X[:, 1, :, 1], axis=-1) * 4

In [7]:
# CNN regressor with the following layers:
#   - two 2x2 'same'-padded ReLU conv layers (8 then 16 filters),
#   - Flatten,
#   - a ReLU Dense layer with `hidden_layers` units,
#   - a linear single-unit output without bias.
#
# The input's batch dimension is deliberately left unspecified (None) rather
# than pinned to `batch_size`. The fixed size only matters for training, while
# the explainers below call the model on batches of other sizes (e.g. the
# `num_shap_samples` explained rows). Keras may warn about, or reject, inputs
# whose batch size differs from the one declared on `tf.keras.Input`.
model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=d))
model.add(tf.keras.layers.Conv2D(filters=8, kernel_size=2, strides=1, padding='same',
                                 activation=tf.keras.activations.relu, use_bias=True))
model.add(tf.keras.layers.Conv2D(filters=16, kernel_size=2, strides=1, padding='same',
                                 activation=tf.keras.activations.relu, use_bias=True))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(hidden_layers, activation=tf.keras.activations.relu, use_bias=True))
model.add(tf.keras.layers.Dense(1, activation=None, use_bias=False))

In [8]:
# Plain SGD minimising mean-squared error; MAE and MSE are reported as metrics.
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
tracked_metrics = [
    tf.keras.metrics.MeanAbsoluteError(),
    tf.keras.metrics.MeanSquaredError(),
]
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.MSE,
    metrics=tracked_metrics,
)

In [9]:
# Train for 50 epochs with the configured mini-batch size, one log line per epoch.
model.fit(x=X, y=y, batch_size=batch_size, epochs=50, verbose=2)

Train on 5000 samples
Epoch 1/50


UnknownError:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[node sequential/conv2d/Conv2D (defined at /homes/gws/psturm/anaconda3/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1748) ]] [Op:__inference_distributed_function_1047]

Function call stack:
distributed_function


In [None]:
# Number of leading rows of X whose attributions are computed below.
num_shap_samples = 10

In [None]:
# Main ("primal") effects from the project's MarginalExplainer.
# Background data is every row of X except the ones being explained.
primal_background = X[num_shap_samples:]
explained_inputs = X[:num_shap_samples]
primal_explainer = MarginalExplainer(
    model,
    primal_background,
    nsamples=200,
    representation='mobius',
)
primal_effects = primal_explainer.explain(explained_inputs, verbose=True)

In [None]:
# Baseline: Shapley values from SHAP's model-agnostic SamplingExplainer.
# SHAP works on flat feature vectors, so inputs are flattened to
# (batch, prod(d)) for SHAP and reshaped back to (batch, *d) for the model.
# All shapes are derived from `d` and `num_shap_samples` (previously hard-coded
# as 8, 8, 3 and 10), so editing the config cells does not break the reshapes.
def model_func(x):
    """Evaluate `model` on flattened inputs `x` of shape (batch, prod(d))."""
    return model(x.reshape((x.shape[0],) + tuple(d))).numpy()

# 200 background rows; disjoint from X[:num_shap_samples] while num_shap_samples <= 100.
shap_background = X[100:300]
sample_explainer = shap.SamplingExplainer(model_func, shap_background.reshape(len(shap_background), -1))
shap_values = sample_explainer.shap_values(X[:num_shap_samples].reshape(num_shap_samples, -1))
# np.reshape accepts either a single array or a one-element list of arrays
# (which SHAP may return for a model with a (batch, 1) output).
shap_values = np.reshape(shap_values, (num_shap_samples,) + tuple(d))

In [None]:
# Mean |primal effect| per input position, averaged over the explained samples.
# NOTE(review): assumes primal_effects has shape (num_shap_samples, *d), so the
# mean has shape d and imshow renders it as an RGB image (one colour per input
# channel) -- confirm against MarginalExplainer.explain.
# utils.normalize presumably maps [0, max] linearly onto [0, 1]; verify.
primal_mean_abs = np.mean(np.abs(primal_effects), axis=0)
plt.imshow(utils.normalize(primal_mean_abs, _range=[0.0, 1.0], _domain=[0.0, primal_mean_abs.max()]))
plt.title('Mean |primal effect| (MarginalExplainer)')
plt.ylabel('input axis 1 (row)')
plt.xlabel('input axis 2 (column)')
plt.show()

In [None]:
# Mean |SHAP value| for each index along axis 1 of X, averaged over samples,
# axis 2 and channels. The target only reads indices 0 and 1 along that axis,
# so those entries are expected to dominate.
np.abs(shap_values).mean(axis=(0, 2, 3))

In [None]:
# Mean |SHAP value| per input position, averaged over the explained samples.
# Averaging a (num_shap_samples, *d) array over axis 0 gives an array of shape
# d, which imshow renders as an RGB image (one colour per input channel).
# utils.normalize presumably maps [0, max] linearly onto [0, 1]; verify.
shap_values_mean_abs = np.mean(np.abs(shap_values), axis=0)
plt.imshow(utils.normalize(shap_values_mean_abs, _range=[0.0, 1.0], _domain=[0.0, shap_values_mean_abs.max()]))
plt.title('Mean |SHAP value| (SamplingExplainer)')
plt.ylabel('input axis 1 (row)')
plt.xlabel('input axis 2 (column)')
plt.show()

In [None]:
# Mean |SHAP value - primal effect| per input position: the part of each
# Shapley value not accounted for by the primal (main) effect.
# NOTE(review): assumes primal_effects has the same shape as shap_values,
# (num_shap_samples, *d) -- confirm against MarginalExplainer.explain.
# utils.normalize presumably maps [0, max] linearly onto [0, 1]; verify.
interactions_mean_abs = np.mean(np.abs(shap_values - primal_effects), axis=0)
plt.imshow(utils.normalize(interactions_mean_abs, _range=[0.0, 1.0], _domain=[0.0, interactions_mean_abs.max()]))
plt.title('Mean |SHAP value - primal effect| (interaction residual)')
plt.ylabel('input axis 1 (row)')
plt.xlabel('input axis 2 (column)')
plt.show()