
SHAP not working with LSTM! #3344

Open
MohamedNedal opened this issue Oct 16, 2023 Discussed in #3342 · 7 comments · May be fixed by #3419
Labels
deep explainer: Relating to DeepExplainer, tensorflow or pytorch

Comments

@MohamedNedal

Discussed in #3342

Originally posted by MohamedNedal October 16, 2023
Hello, I have a trained LSTM model for time-series forecasting and I cannot use SHAP with it.
I have checked the tutorials and the discussions here on similar problems and tried those suggestions, but unfortunately it still didn't work. ChatGPT couldn't help either, as the problem is specific to the SHAP routines.
I also tried running the code in another environment with a downgraded TF version, as some suggested, but that didn't work either and introduced more compatibility problems.
My current versions of shap and tensorflow are

TF ver. 2.3.0
shap ver. 0.41.0

Here is a minimal example code to reproduce the error:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
tf.compat.v1.enable_eager_execution()
import shap
from tqdm import tqdm
shap.initjs()



# Define the start and end datetime
start_datetime = pd.to_datetime('2020-01-01 00:00:00')
end_datetime = pd.to_datetime('2023-12-31 23:00:00')

# Generate a DatetimeIndex with hourly frequency
date_rng = pd.date_range(start=start_datetime, end=end_datetime, freq='H')

# Create a DataFrame with random data for 7 features
num_samples = len(date_rng)
num_features = 7

# Generate random data for the DataFrame
data = np.random.rand(num_samples, num_features)

# Create the DataFrame with a DatetimeIndex
df = pd.DataFrame(data, index=date_rng, columns=[f'X{i}' for i in range(1, num_features+1)])


def windowed_dataset(series=None, in_horizon=None, out_horizon=None, delay=None, batch_size=None):
    '''
    Convert multivariate time-series data into windowed input/output sequences
    and return them as a batched TensorFlow dataset.
    Arguments:
    ===========
    series: a list or array of time-series data.
    in_horizon: an integer representing the size of the input window.
    out_horizon: an integer representing the size of the output window.
    delay: an integer representing the shift (in steps) between consecutive windows.
    batch_size: an integer representing the batch size.
    '''
    total_horizon = in_horizon + out_horizon
    dataset = tf.data.Dataset.from_tensor_slices(series)
    dataset = dataset.window(total_horizon, shift=delay, drop_remainder=True)
    dataset = dataset.flat_map(lambda window: window.batch(total_horizon))
    dataset = dataset.map(lambda window: (window[:-out_horizon,:], window[-out_horizon:,0]))
    dataset = dataset.batch(batch_size).prefetch(1)
    return dataset



# Define the proportions for the splits (70:20:10)%
train_size = 0.7
valid_size = 0.2
test_size = 0.1

# Calculate the split points
train_split = int(len(df)*train_size)
valid_split = int(len(df)*(train_size + valid_size))

# Split the DataFrame
df_train = df.iloc[:train_split]
df_valid = df.iloc[train_split:valid_split]
df_test = df.iloc[valid_split:]

# number of input features and output targets
n_features = df.shape[1]

# split the data into sliding sequential windows
train_dataset = windowed_dataset(series=df_train.values, 
                                 in_horizon=100, 
                                 out_horizon=3, 
                                 delay=1, 
                                 batch_size=32)

valid_dataset = windowed_dataset(series=df_valid.values, 
                                 in_horizon=100, 
                                 out_horizon=3, 
                                 delay=1, 
                                 batch_size=32)

test_dataset = windowed_dataset(series=df_test.values, 
                                in_horizon=100, 
                                out_horizon=3, 
                                delay=1, 
                                batch_size=32)

input_layer = tf.keras.layers.Input(shape=(100, n_features))
lstm_layer1 = tf.keras.layers.LSTM(5, return_sequences=True)(input_layer)
lstm_layer2 = tf.keras.layers.LSTM(5, return_sequences=True)(lstm_layer1)
lstm_layer3 = tf.keras.layers.LSTM(5)(lstm_layer2)
output_layer = tf.keras.layers.Dense(3)(lstm_layer3)
model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss='mse', metrics=['mae'])
history = model.fit(train_dataset, epochs=5, validation_data=valid_dataset, verbose=1)

def tensor_to_arrays(input_obj=None):
    '''
    Convert a "tensorflow.python.data.ops.dataset_ops.PrefetchDataset" object into NumPy arrays.
    This function can be used to extract the arrays from the output of the `windowed_dataset` function.
    '''
    x = list(map(lambda x: x[0], input_obj))
    y = list(map(lambda x: x[1], input_obj))
    
    x_ = [xtmp.numpy() for xtmp in x]
    y_ = [ytmp.numpy() for ytmp in y]
    
    # Stack the arrays vertically
    x = np.vstack(x_)
    y = np.vstack(y_)
    
    return x, y


xarr, yarr = tensor_to_arrays(input_obj=train_dataset)

# Create an explainer object
explainer = shap.DeepExplainer(model, xarr)

# Calculate SHAP values for the data
#shap_values = explainer(xarr)
shap_values = explainer.shap_values(xarr)

The error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-19-ab3ee3b4955f> in <module>
      1 # Calculate SHAP values for the data portion
      2 #shap_values = explainer(xarr)
----> 3 shap_values = explainer.shap_values(xarr)

~/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/__init__.py in shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
    122             were chosen as "top".
    123         """
--> 124         return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)

~/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/deep_tf.py in shap_values(self, X, ranked_outputs, output_rank_order, check_additivity)
    310                 # run attribution computation graph
    311                 feature_ind = model_output_ranks[j,i]
--> 312                 sample_phis = self.run(self.phi_symbolic(feature_ind), self.model_inputs, joint_input)
    313 
    314                 # assign the attributions to the right part of the output arrays

~/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/deep_tf.py in run(self, out, model_inputs, X)
    370 
    371                 return final_out
--> 372             return self.execute_with_overridden_gradients(anon)
    373 
    374     def custom_grad(self, op, *grads):

~/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/deep_tf.py in execute_with_overridden_gradients(self, f)
    406         # define the computation graph for the attribution values using a custom gradient-like computation
    407         try:
--> 408             out = f()
    409         finally:
    410             # reinstate the backpropagatable check

~/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/deep_tf.py in anon()
    363                     v = tf.constant(data, dtype=self.model_inputs[i].dtype)
    364                     inputs.append(v)
--> 365                 final_out = out(inputs)
    366                 try:
    367                     tf_execute.record_gradient = tf_backprop._record_gradient

~/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

~/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    821       # This is the first call of __call__, so we have to initialize.
    822       initializers = []
--> 823       self._initialize(args, kwds, add_initializers_to=initializers)
    824     finally:
    825       # At this point we know that the initialization is complete (or less

~/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    695     self._concrete_stateful_fn = (
    696         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 697             *args, **kwds))
    698 
    699     def invalid_creator_scope(*unused_args, **unused_kwds):

~/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2853       args, kwargs = None, None
   2854     with self._lock:
-> 2855       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2856     return graph_function
   2857 

~/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3211 
   3212       self._function_cache.missed.add(call_context_key)
-> 3213       graph_function = self._create_graph_function(args, kwargs)
   3214       self._function_cache.primary[cache_key] = graph_function
   3215       return graph_function, args, kwargs

~/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3073             arg_names=arg_names,
   3074             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075             capture_by_value=self._capture_by_value),
   3076         self._function_attributes,
   3077         function_spec=self.function_spec,

~/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    984         _, original_func = tf_decorator.unwrap(python_func)
    985 
--> 986       func_outputs = python_func(*func_args, **func_kwargs)
    987 
    988       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    598         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    599         # the function a weak reference to itself to avoid a reference cycle.
--> 600         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    601     weak_wrapped_fn = weakref.ref(wrapped_fn)
    602 

~/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    971           except Exception as e:  # pylint:disable=broad-except
    972             if hasattr(e, "ag_error_metadata"):
--> 973               raise e.ag_error_metadata.to_exception(e)
    974             else:
    975               raise

AttributeError: in user code:

    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/deep_tf.py:247 grad_graph  *
        out = self.model(shap_rAnD)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:985 __call__  **
        outputs = call_fn(inputs, *args, **kwargs)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:386 call
        inputs, training=training, mask=mask)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:508 _run_internal_graph
        outputs = node.layer(*args, **kwargs)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/keras/layers/recurrent.py:663 __call__
        return super(RNN, self).__call__(inputs, **kwargs)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:985 __call__
        outputs = call_fn(inputs, *args, **kwargs)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/keras/layers/recurrent_v2.py:1183 call
        runtime) = lstm_with_backend_selection(**normal_lstm_kwargs)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/keras/layers/recurrent_v2.py:1559 lstm_with_backend_selection
        function.register(defun_gpu_lstm, **params)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/function.py:3241 register
        concrete_func.add_gradient_functions_to_graph()
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/function.py:2063 add_gradient_functions_to_graph
        self._delayed_rewrite_functions.forward_backward())
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/function.py:621 forward_backward
        forward, backward = self._construct_forward_backward(num_doutputs)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/function.py:669 _construct_forward_backward
        func_graph=backwards_graph)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py:986 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/eager/function.py:659 _backprop_function
        src_graph=self._func_graph)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/ops/gradients_util.py:669 _GradientsHelper
        lambda: grad_fn(op, *out_grads))
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/ops/gradients_util.py:336 _MaybeCompile
        return grad_fn()  # Exit early
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/tensorflow/python/ops/gradients_util.py:669 <lambda>
        lambda: grad_fn(op, *out_grads))
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/deep_tf.py:378 custom_grad
        out = op_handlers[type_name](self, op, *grads) # we cut off the shap_ prefex before the lookup
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/deep_tf.py:667 handler
        return linearity_with_excluded_handler(input_inds, explainer, op, *grads)
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/deep_tf.py:674 linearity_with_excluded_handler
        assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
    /home/mnedal/.conda/envs/lstm/lib/python3.6/site-packages/shap/explainers/_deep/deep_tf.py:224 _variable_inputs
        out[i] = t.name in self.between_tensors

    AttributeError: 'TFDeep' object has no attribute 'between_tensors'

I would really appreciate your inputs. Thanks in advance!

@CloseChoice
Collaborator

Thanks for the report and especially for the reproducing example. I know this is a longstanding issue and I'll look into this in the coming weeks.

@lordegeology

Hi @MohamedNedal. I have encountered the exact same issue in my project. Have you attempted to use any other feature importance algorithm that is compatible with an LSTM?

@CloseChoice
Collaborator

So, I have looked into this and am working on a solution. The problem is that with TF version 2 it is not as easy to extract the model's graph (since TensorFlow builds the graph lazily). I found a way to capture the graph by doing a forward pass through the model with some example data, but getting this data reliably (converting Keras tensors into a form that can be used to run lazy TF models) is a bit tricky. I'll keep you updated if I make further progress.
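
Roughly, the idea looks like this (an illustrative sketch only, not the actual code from the PR; the helper name trace_model_graph is made up here):

import tensorflow as tf

def trace_model_graph(model, example_batch):
    # Force TF 2.x to trace the Keras model once so a concrete FuncGraph
    # exists whose operations an explainer could then walk.
    example_batch = tf.convert_to_tensor(example_batch, dtype=tf.float32)
    concrete_fn = tf.function(model).get_concrete_function(example_batch)
    return concrete_fn.graph

# e.g. with the reproduction above:
# graph = trace_model_graph(model, xarr[:1])
# print(len(graph.get_operations()))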

@MohamedNedal
Author

Hi @MohamedNedal. I have encountered the exact same issue in my project. Have you attempted to use any other feature importance algorithm that is compatible with an LSTM?

Hi @lordegeology , I'm actually not familiar with such tools. I used to estimate feature importance via the correlation matrix, but that approach is not very reliable and is limited, as it doesn't show feature importance over time.
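
For reference, one alternative that is sometimes suggested while DeepExplainer is broken for TF 2 LSTMs is shap.GradientExplainer, which only needs gradients through the model. Whether it handles this particular model/TF combination has not been verified here, so treat the following as a sketch (it reuses model and xarr from the reproduction above):

import numpy as np
import shap

# Sketch only: expected-gradients explainer as a possible stand-in for DeepExplainer.
background = xarr[np.random.choice(xarr.shape[0], 100, replace=False)]
explainer = shap.GradientExplainer(model, background)
shap_values = explainer.shap_values(xarr[:10])  # one array per model output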

@castillohair

Been having the same issue for a while, and a solution would be greatly appreciated. Thanks @CloseChoice !

@GaoRuiXiang22

I found that an LSTM built with a lower version of PyTorch (PyTorch==1.11.0) can get results through SHAP, but there is still a warning about an unrecognized model.
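
The kind of setup meant here looks roughly like this (a minimal sketch with made-up shapes mirroring the example above, assuming torch 1.11 and shap's DeepExplainer support for torch modules; not verified against the original data):

import torch
import torch.nn as nn
import shap

class LSTMRegressor(nn.Module):
    def __init__(self, n_features=7, hidden=5, out_horizon=3):
        super().__init__()
        self.lstm = nn.LSTM(n_features, hidden, batch_first=True)
        self.head = nn.Linear(hidden, out_horizon)

    def forward(self, x):
        out, _ = self.lstm(x)          # (batch, time, hidden)
        return self.head(out[:, -1])   # predict from the last time step

model = LSTMRegressor()
background = torch.randn(50, 100, 7)   # background windows for DeepExplainer
explainer = shap.DeepExplainer(model, background)
shap_values = explainer.shap_values(torch.randn(5, 100, 7))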

@CloseChoice linked a pull request on Dec 2, 2023 that will close this issue.
@CloseChoice
Collaborator

I added a draft PR where the code at least runs through. I suspect that there is still something wrong in how we handle new operations (we basically ignore them in the custom gradient calculation, which might be wrong), so use the results with caution. Any test cases for this are highly appreciated, so feel free to start a review and post them there! Together we can do this!
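
One kind of test case that could help is a simple additivity check on the reproduction above (a sketch, assuming explainer.shap_values now runs and returns one array per output, as DeepExplainer usually does for multi-output models):

import numpy as np

# SHAP values summed over the (time, feature) axes plus the base value
# should roughly reconstruct the model's predictions on the same samples.
preds = model.predict(xarr[:10])                                   # (10, 3)
reconstructed = np.stack(
    [sv[:10].sum(axis=(1, 2)) for sv in shap_values], axis=1
) + explainer.expected_value                                       # one base value per output
print(np.max(np.abs(preds - reconstructed)))                       # should be close to zero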
