Fix/1457 add support for static covariates to shap #1732

@anne-devries anne-devries commented Apr 26, 2023

Previously, ShapExplainer could not be used when the regression model used static covariates. This PR fixes that.

Summary

Static covariates are now added to the features (X) in shap_explainer.

Other Information

Since the latest release, a dedicated function creates the lagged_component_names. shap_explainer now uses that function to retrieve them instead of the old logic.

@anne-devries
Contributor Author

@dennisbader I created the pull request for this change. I'm not sure whether I missed something; if so, just let me know! Also:

This PR contains the unit test test_shapley_multiple_series_with_different_static_covs. In tabularization.py, if the model uses static_covariates, only the static covariates of one target series are used. Shouldn't there be a validation check for whether different target series have different static covariates?
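
To make the question concrete, here is a minimal sketch of what such a validation check might look like. It is not code from darts: the function name `check_consistent_static_covariates` is hypothetical, and each series' static covariates are represented by a plain list of column names instead of a TimeSeries.

```python
from typing import Optional, Sequence


def check_consistent_static_covariates(
    static_cov_columns: Sequence[Optional[Sequence[str]]],
) -> None:
    """Raise ValueError if the series do not share the same static covariate columns.

    Hypothetical helper: each entry holds the static covariate column names
    of one target series, or None if that series has no static covariates.
    """
    if not static_cov_columns:
        return
    reference = static_cov_columns[0]
    for idx, columns in enumerate(static_cov_columns[1:], start=1):
        # Mismatch if exactly one side is missing, or if the names differ.
        one_side_missing = (columns is None) != (reference is None)
        names_differ = columns is not None and list(columns) != list(reference)
        if one_side_missing or names_differ:
            raise ValueError(
                f"Static covariates of series {idx} ({columns}) differ from "
                f"those of series 0 ({reference})."
            )
```

Whether a mismatch should raise an error or only emit a warning would be a design decision for the maintainers.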

anne-devries and others added 5 commits April 26, 2023 13:44
…covariates_to_shap' into fix/1457_add_support_for_static_covariates_to_shap

# Conflicts:
#	darts/tests/explainability/test_shap_explainer.py
Collaborator

@madtoinou madtoinou left a comment

Hi @anne-devries,

Thank you for contributing to darts! Looks good; just a few minor comments about optimization.

Also, there seem to be some linting issues: could you please follow the instructions to apply the linting automatically?

Thank you for adding tests!

codecov-commenter commented May 1, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: -0.10 ⚠️

Comparison is base (f580e97) 94.18% compared to head (7d423dd) 94.09%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1732      +/-   ##
==========================================
- Coverage   94.18%   94.09%   -0.10%     
==========================================
  Files         125      125              
  Lines       11491    11468      -23     
==========================================
- Hits        10823    10791      -32     
- Misses        668      677       +9     
Impacted Files Coverage Δ
darts/explainability/shap_explainer.py 89.86% <100.00%> (+0.92%) ⬆️

... and 10 files with indirect coverage changes

@dennisbader
Collaborator

Thanks for this @anne-devries, this has long been on our TODO list 🚀 Let me review it next week.

Collaborator

@dennisbader dennisbader left a comment

Hi @anne-devries and thanks for this great PR. 🚀

Just had some minor suggestions after which it'll be ready to get merged!

@@ -5,6 +5,8 @@ We do our best to avoid the introduction of breaking changes,
but cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "🔴".

## [Unreleased](https://github.com/unit8co/darts/tree/master)
- Added static covariates to ShapExplainer - you can now use RegressionModels with static covariates **and** generate shapley values for them

Could you adapt the entry to the same style as previous entries, e.g. by adding the PR and your user reference?

Comment on lines +715 to +717
X, _ = add_static_covariates_to_lagged_data(
    X, target_series, uses_static_covariates=self.model.uses_static_covariates
)

I believe we need to check against the last static covariates shape here, as the forecasting model was already trained.

Suggested change
- X, _ = add_static_covariates_to_lagged_data(
-     X, target_series, uses_static_covariates=self.model.uses_static_covariates
- )
+ X, _ = add_static_covariates_to_lagged_data(
+     X,
+     target_series,
+     uses_static_covariates=self.model.uses_static_covariates,
+     last_shape=self.model._static_covariates_shape,
+ )
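
The intent behind passing `last_shape` can be illustrated with a small stand-alone sketch. This is not the darts implementation: `validate_static_covariates_shape` is a hypothetical helper, and shapes are plain tuples assumed to mean (number of components, number of static covariates).

```python
from typing import Optional, Tuple

Shape = Tuple[int, int]


def validate_static_covariates_shape(
    current_shape: Optional[Shape], last_shape: Optional[Shape]
) -> Optional[Shape]:
    """Compare the static covariates shape of new data with the shape seen at training.

    Hypothetical helper. `last_shape` is None when no shape has been recorded
    yet (for example on the first call during fitting).
    """
    if last_shape is None:
        # Nothing to compare against; the current shape becomes the reference.
        return current_shape
    if current_shape is None:
        raise ValueError(
            f"Model was trained with static covariates of shape {last_shape}, "
            "but the given series has none."
        )
    if current_shape != last_shape:
        raise ValueError(
            f"Static covariates shape {current_shape} does not match the "
            f"shape seen during training {last_shape}."
        )
    return current_shape
```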

    x_1.reshape(-1, 1),
    static_covariates=pd.DataFrame({"type": [0], "state": [1]}),
).with_columns_renamed(["0"], ["price"])
target_ts_with_static_covs_multiple_series = TimeSeries.from_times_and_values(

Nitpicking: we use the "multiple series" naming only for a list of series.

Suggested change
- target_ts_with_static_covs_multiple_series = TimeSeries.from_times_and_values(
+ target_ts_with_multi_component_static_covs = TimeSeries.from_times_and_values(

@@ -670,3 +689,45 @@ def test_shap_explanation_object_validity(self):
),
shap.Explanation,
)

def test_shapley_with_static_cov(self):

Could you also test that an error is raised when static covariates are not used, and when they have the wrong dimensionality?

Comment on lines +713 to +716
assert len(explanation_results.explained_forecasts[1]["price"].columns) == (
    -(min(model.lags["target"])) * model.input_dim["target"]
    + model.input_dim["target"] * model.static_covariates.shape[1]
)

just some additional testing

Suggested change
- assert len(explanation_results.explained_forecasts[1]["price"].columns) == (
-     -(min(model.lags["target"])) * model.input_dim["target"]
-     + model.input_dim["target"] * model.static_covariates.shape[1]
- )
+ for comp in self.target_ts_with_static_covs_multiple_series.components:
+     comps_out = explanation_results.explained_forecasts[1][comp].columns
+     assert len(comps_out) == (
+         -(min(model.lags["target"])) * model.input_dim["target"]
+         + model.input_dim["target"] * model.static_covariates.shape[1]
+     )
+     assert comps_out[-4:].tolist() == [
+         'type_statcov_target_price',
+         'type_statcov_target_power',
+         'state_statcov_target_price',
+         'state_statcov_target_power'
+     ]
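
The feature count and ordering asserted above can be reproduced with a short stand-alone sketch. It does not call darts. The helper `expected_feature_names` is hypothetical, the static covariate name pattern is taken from the suggested assertion, and the lag name pattern `{component}_target_lag{lag}` is an assumption made for illustration.

```python
from typing import List, Sequence


def expected_feature_names(
    components: Sequence[str],
    lags: Sequence[int],
    static_cov_names: Sequence[str],
) -> List[str]:
    """Build the feature names a lagged model with static covariates would expose.

    Hypothetical helper for illustration only.
    """
    # Lagged target features: oldest lag first, every component at each lag.
    # The "{component}_target_lag{lag}" pattern is an assumption.
    lag_names = [
        f"{comp}_target_lag{lag}" for lag in sorted(lags) for comp in components
    ]
    # Static covariate features: one per (static covariate, component) pair,
    # ordered as in the suggested assertion.
    static_names = [
        f"{cov}_statcov_target_{comp}"
        for cov in static_cov_names
        for comp in components
    ]
    return lag_names + static_names


names = expected_feature_names(
    components=["price", "power"],
    lags=[-4, -3, -2, -1],
    static_cov_names=["type", "state"],
)
# 4 lags * 2 components + 2 components * 2 static covariates
print(len(names))
print(names[-4:])
```

With contiguous lags from -4 to -1, `len(lags)` equals `-(min(lags))`, which is why the assertion in the test can use the minimum lag to count the lagged features.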

explanation_results = shap_explain.explain()

self.assertTrue(len(explanation_results.feature_values) == 2)
# test black

Suggested change
- # test black
+ # model trained on multiple series will take column names of first series -> even though
+ # static covs have different names, the output will show the same names
+ for explained_forecast in explanation_results.explained_forecasts:
+     comps_out = explained_forecast[1]["price"].columns.tolist()
+     assert comps_out[-1] == "type_statcov_target_price"

@dennisbader
Collaborator

Hi @anne-devries, thanks for your work! I will close this one as I had to implement the changes in the meantime for PR #1803. I added your contribution to the CHANGELOG of that PR.

@anne-devries
Contributor Author

@dennisbader thanks! Sorry that I didn't address the suggestions you made earlier. Great that this functionality is now merged though! :)
