In [1]:
# Re-import edited modules automatically before each cell runs, so changes to the
# local helpers imported later (marginal, plot, utils) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Expose only GPU 0 to CUDA; set here, before tensorflow is imported in the next cell.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [3]:
import tensorflow as tf
import numpy as np
import pandas as pd
import altair as alt
import shap

from marginal import MarginalExplainer
import plot
import matplotlib.pyplot as plt
import utils

KeyboardInterrupt: 

In [None]:
# Experiment configuration.
n = 50_000              # total synthetic samples (split 80% train / 20% validation)
lamb = 0.5              # weight of the linear term in the synthetic target
d = (8, 8, 3)           # per-sample input shape
epochs = 50             # training epochs
hidden_layers = 64      # units in EACH hidden Dense layer (not a layer count)
batch_size = 50         # minibatch size for model.fit
learning_rate = 1e-3    # SGD step size

In [None]:
# Synthetic regression data: standard-normal float32 inputs of shape (n, *d).
# The target is a step function of the summed values at index 1 of the first
# spatial axis in channel 1 (contributing 0, 1 or 2), plus lamb times the summed
# values at index 0 of the first spatial axis in channel 0 (a purely linear term).
# A fixed-seed generator makes Restart & Run All reproduce the same dataset.
rng = np.random.default_rng(0)
X = rng.standard_normal((n, *d), dtype=np.float32)

# Each row sum is computed once and reused by both step terms.
step_sum = X[:, 1, :, 1].sum(axis=-1)
linear_sum = X[:, 0, :, 0].sum(axis=-1)
y = ((step_sum > 0).astype(np.float32)
     + (step_sum > 1).astype(np.float32)
     + linear_sum * lamb).astype(np.float32)

# 80% / 20% train/validation split. Rows are drawn independently, so a
# contiguous split needs no shuffling.
n_train = int(n * 0.8)
X_train, y_train = X[:n_train], y[:n_train]
X_val, y_val = X[n_train:], y[n_train:]

In [None]:
# Regressor: flatten the d-shaped input, apply three ReLU hidden layers of
# `hidden_layers` units each, then a single linear output unit without a bias.
# The Input deliberately leaves the batch dimension unspecified. The explainers
# below call the model on batches of 1, num_shap_samples, or arbitrary
# sampling sizes, and a fixed batch_size=50 input would be flagged as an
# incompatible shape on every such call.
model = tf.keras.Sequential([
    tf.keras.Input(shape=d),
    tf.keras.layers.Flatten(),
    *[
        tf.keras.layers.Dense(hidden_layers, activation=tf.keras.activations.relu, use_bias=True)
        for _ in range(3)
    ],
    tf.keras.layers.Dense(1, activation=None, use_bias=False),
])

In [None]:
# Plain SGD on mean squared error; mean absolute error is reported as a metric.
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.MSE,
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
)

In [None]:
# Train on the training split; validation loss/MAE on the held-out split is reported each epoch.
model.fit(
    X_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(X_val, y_val),
    verbose=2,
)

In [None]:
num_shap_samples = 10  # how many leading rows of X the explainers below attribute

In [None]:
# Baseline for the marginal explainer: one all-zero input with the model's shape d.
# np.zeros takes the shape as a single tuple; np.zeros(1, 8, 8, 3) raises TypeError.
primal_baseline = np.zeros((1, *d), dtype=np.float32)
primal_explainer = MarginalExplainer(model, primal_baseline, nsamples=200, representation='mobius')
primal_effects = primal_explainer.explain(X[:num_shap_samples], verbose=True)

In [None]:
def model_func(x):
    """Evaluate `model` on a 2-D (rows, prod(d)) array and return a NumPy array.

    The rows are reshaped back to the model's input shape `d` and cast to float32.
    """
    return model(x.reshape((x.shape[0], *d)).astype(np.float32)).numpy()


# The rows being explained are flattened to 2-D below, so the background must be
# flattened the same way: one all-zero row with prod(d) columns. The original
# np.zeros(1, 8, 8, 3) both raised TypeError and had the wrong (4-D) layout.
sampling_background = np.zeros((1, int(np.prod(d))), dtype=np.float32)
sample_explainer = shap.SamplingExplainer(model_func, sampling_background)
shap_values = sample_explainer.shap_values(X[:num_shap_samples].reshape(num_shap_samples, -1))
# Restore the per-sample image layout. The leading size follows num_shap_samples
# instead of a hard-coded 10. np.reshape also accepts the list-of-arrays form
# shap may return for 2-D model outputs.
shap_values = np.reshape(shap_values, (num_shap_samples, *d))

In [None]:
# Mean |SHAP value| per input position across the explained samples, drawn as an
# (8, 8, 3) RGB image. The value domain [0, max] of this map is also used by the
# primal and interaction plots, so all three share one scale.
# NOTE(review): utils.normalize presumably rescales _domain onto _range — confirm in utils.
shap_values_mean_abs = np.mean(np.abs(shap_values), axis=0)
fig, ax = plt.subplots()
ax.imshow(utils.normalize(shap_values_mean_abs, _range=[0.0, 1.0], _domain=[0.0, shap_values_mean_abs.max()]))
ax.set_title('Mean |SHAP value| (SamplingExplainer)');

In [None]:
# Mean |primal effect| per input position, on the same [0, max |SHAP|] domain as
# the SHAP map so the two images are directly comparable.
# NOTE(review): assumes primal_effects has shape (num_shap_samples, *d) — confirm
# against MarginalExplainer.explain.
primal_mean_abs = np.mean(np.abs(primal_effects), axis=0)
fig, ax = plt.subplots()
ax.imshow(utils.normalize(primal_mean_abs, _range=[0.0, 1.0], _domain=[0.0, shap_values_mean_abs.max()]))
ax.set_title('Mean |primal effect| (MarginalExplainer, mobius)');

In [None]:
# Mean |SHAP value - primal effect| per input position: the part of each SHAP
# attribution not captured by the primal effect. It is plotted on the same
# [0, max |SHAP|] domain as the SHAP map.
# NOTE(review): relies on shap_values and primal_effects having identical shapes — confirm.
interactions_mean_abs = np.mean(np.abs(shap_values - primal_effects), axis=0)
fig, ax = plt.subplots()
ax.imshow(utils.normalize(interactions_mean_abs, _range=[0.0, 1.0], _domain=[0.0, shap_values_mean_abs.max()]))
ax.set_title('Mean |SHAP - primal effect| (interactions)');