In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6'

In [2]:
import tensorflow as tf
import numpy as np
import pandas as pd
import altair as alt
from path_explain.path_explainer_tf import PathExplainerTF

In [3]:
baseline = np.random.randn(1000, 10)
inputs = np.random.randn(50, 10)

In [4]:
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Input(10, dtype=tf.float32))
model.add(tf.keras.layers.Dense(5, activation=tf.keras.activations.relu, use_bias=True))
model.add(tf.keras.layers.Dense(1, activation=None, use_bias=False))

In [5]:
explainer = PathExplainerTF(model)

In [6]:
for number_to_draw in [1, 100, 1000]:
    sampled_baseline = explainer._sample_baseline(baseline, number_to_draw, True)
    assert sampled_baseline.shape == (number_to_draw, 10), \
                                     "Expected: {}, Received: {}".format((number_to_draw, 10),
                                                                          sampled_baseline.shape)
    sampled_baseline = explainer._sample_baseline(baseline[0:1], number_to_draw, False)
    assert sampled_baseline.shape == (number_to_draw, 10), \
                                     "Expected: {}, Received: {}".format((number_to_draw, 10),
                                                                          sampled_baseline.shape)

In [7]:
current_alphas = explainer._sample_alphas(num_samples=100, use_expectation=True)
explainer._single_attribution(inputs[0], baseline,
                            current_alphas, num_samples=100, batch_size=50,
                            use_expectation=True, output_index=None)



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



array([-0.07984692, -0.11272831,  0.09528565, -0.00645371, -0.11223353,
       -0.28691726,  0.40059981,  0.0028297 , -0.18274488,  0.20822725])

In [8]:
current_alphas

array([0.12279019, 0.67769281, 0.84169407, 0.70511165, 0.00265455,
       0.44090306, 0.09369319, 0.05260025, 0.43436406, 0.17005723,
       0.73707246, 0.98173315, 0.34009972, 0.22936328, 0.57953494,
       0.43395269, 0.25288484, 0.33651185, 0.39387707, 0.82774975,
       0.50674553, 0.06165431, 0.0738939 , 0.91143068, 0.72668137,
       0.58994742, 0.49104549, 0.27581179, 0.38785292, 0.65675242,
       0.01218733, 0.41294979, 0.17313672, 0.00985133, 0.49858228,
       0.36609394, 0.7225146 , 0.82125531, 0.77717948, 0.96296571,
       0.61650889, 0.66459346, 0.7513812 , 0.00788782, 0.38531933,
       0.74815795, 0.39474457, 0.83674689, 0.54913175, 0.36205168,
       0.55060591, 0.37934378, 0.94637996, 0.87050619, 0.79113655,
       0.58980651, 0.5040989 , 0.24805337, 0.74425321, 0.96735304,
       0.53158117, 0.82961234, 0.2008738 , 0.26603568, 0.90905753,
       0.79384321, 0.96605803, 0.66512964, 0.14308583, 0.75248542,
       0.37419097, 0.62967183, 0.55979809, 0.77136192, 0.14769

In [9]:
explainer.attributions(inputs, baseline,
                     batch_size=50, num_samples=100,
                     use_expectation=True, output_indices=None,
                     verbose=False)

array([[[-1.07834415e-01, -1.68414574e-01,  1.28650983e-01,
          2.77651834e-02, -1.42457853e-01, -3.54354974e-01,
          4.39486893e-01, -1.88926609e-04, -1.81691181e-01,
          2.20601897e-01],
        [-3.22946949e-02, -4.68478913e-02, -2.61244609e-02,
         -2.32070177e-02,  9.46304906e-02,  4.91508526e-03,
          8.26230794e-03,  3.08964067e-02, -4.64665141e-02,
          1.73579641e-02],
        [ 1.84936065e-01,  5.44597127e-01, -4.20986472e-01,
         -3.71642429e-01,  1.20786647e-02, -1.33254624e-01,
          2.87803825e-01,  3.48104300e-02,  9.83399491e-01,
         -1.65271455e-01],
        [ 1.67303229e-02,  2.13835263e-04, -1.52581209e-02,
          5.06169615e-02, -1.61871390e-02, -1.72167027e-02,
         -2.95359505e-02,  1.28508805e-02, -1.06878544e-01,
          8.26645384e-02],
        [-2.41410004e-01, -1.16921086e-01, -4.26272775e-01,
         -1.88412565e-02,  2.32495372e-01, -2.71090748e-02,
         -1.56312207e-01, -5.12081947e-02, -6.422287