Refactor/lagged data static covs #1803

Merged · 6 commits · May 28, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -17,11 +17,14 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Improvements to `EnsembleModel`:
- Model creation parameter `forecasting_models` now supports a mix of `LocalForecastingModel` and `GlobalForecastingModel` (single `TimeSeries` training/inference only, due to the local models). [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
- Future and past covariates can now be used even if `forecasting_models` have different covariates support. The covariates passed to `fit()`/`predict()` are used only by models that support it. [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
- Improvements to `ShapExplainer`:
- Added static covariates support to `ShapExplainer`. [#1803](https://github.com/unit8co/darts/pull/1803) by [Anne de Vries](https://github.com/anne-devries) and [Dennis Bader](https://github.com/dennisbader).

**Fixed**
- Fixed an issue not considering original component names for `TimeSeries.plot()` when providing a label prefix. [#1783](https://github.com/unit8co/darts/pull/1783) by [Simon Sudrich](https://github.com/sudrich).
- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug when loading the weights of a `TorchForecastingModel` trained with encoders or a Likelihood. [#1744](https://github.com/unit8co/darts/pull/1744) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug when using selected `target_components` with `ShapExplainer`. [#1803](https://github.com/unit8co/darts/pull/1803) by [Dennis Bader](https://github.com/dennisbader).

## [0.24.0](https://github.com/unit8co/darts/tree/0.24.0) (2023-04-12)
### For users of the library:
21 changes: 9 additions & 12 deletions darts/explainability/explainability_result.py
@@ -9,7 +9,6 @@
from typing import Any, Dict, Optional, Sequence, Union

import shap
-from numpy import integer

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not
@@ -26,8 +25,8 @@ class ExplainabilityResult(ABC):
def __init__(
self,
explained_forecasts: Union[
-Dict[integer, Dict[str, TimeSeries]],
-Sequence[Dict[integer, Dict[str, TimeSeries]]],
Dict[int, Dict[str, TimeSeries]],
Sequence[Dict[int, Dict[str, TimeSeries]]],
],
):
self.explained_forecasts = explained_forecasts
@@ -61,9 +60,7 @@ def get_explanation(

def _query_explainability_result(
self,
-attr: Union[
-Dict[integer, Dict[str, Any]], Sequence[Dict[integer, Dict[str, Any]]]
-],
attr: Union[Dict[int, Dict[str, Any]], Sequence[Dict[int, Dict[str, Any]]]],
horizon: int,
component: Optional[str] = None,
) -> Any:
@@ -141,16 +138,16 @@ class ShapExplainabilityResult(ExplainabilityResult):
def __init__(
self,
explained_forecasts: Union[
-Dict[integer, Dict[str, TimeSeries]],
-Sequence[Dict[integer, Dict[str, TimeSeries]]],
Dict[int, Dict[str, TimeSeries]],
Sequence[Dict[int, Dict[str, TimeSeries]]],
],
feature_values: Union[
-Dict[integer, Dict[str, TimeSeries]],
-Sequence[Dict[integer, Dict[str, TimeSeries]]],
Dict[int, Dict[str, TimeSeries]],
Sequence[Dict[int, Dict[str, TimeSeries]]],
],
shap_explanation_object: Union[
-Dict[integer, Dict[str, shap.Explanation]],
-Sequence[Dict[integer, Dict[str, shap.Explanation]]],
Dict[int, Dict[str, shap.Explanation]],
Sequence[Dict[int, Dict[str, shap.Explanation]]],
],
):
super().__init__(explained_forecasts)
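The `Dict[int, Dict[str, TimeSeries]]` structure above maps forecast horizon → target component → result. How `_query_explainability_result` resolves a query against such a mapping can be sketched in pure Python (a simplified illustration; plain values stand in for `TimeSeries`, and the function name is illustrative, not the darts API):

```python
from typing import Any, Dict, Optional

def query_result(
    attr: Dict[int, Dict[str, Any]],
    horizon: int,
    component: Optional[str] = None,
) -> Any:
    """Return the stored value for a given horizon and target component.

    If the horizon holds exactly one component, it may be omitted.
    """
    if horizon not in attr:
        raise KeyError(f"horizon {horizon} was not explained")
    per_component = attr[horizon]
    if component is None:
        if len(per_component) != 1:
            raise ValueError("component must be given for multivariate results")
        component = next(iter(per_component))
    if component not in per_component:
        raise KeyError(f"component {component!r} was not explained")
    return per_component[component]

# horizon 1 explains two components; strings stand in for TimeSeries
explained = {1: {"price": "price_expl", "power": "power_expl"}}
assert query_result(explained, 1, "price") == "price_expl"
```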
32 changes: 11 additions & 21 deletions darts/explainability/shap_explainer.py
@@ -25,7 +25,6 @@
import matplotlib.pyplot as plt
import pandas as pd
import shap
-from numpy import integer
from sklearn.multioutput import MultiOutputRegressor

from darts import TimeSeries
@@ -563,7 +562,7 @@ def shap_explanations(
foreground_X,
horizons: Optional[Sequence[int]] = None,
target_components: Optional[Sequence[str]] = None,
-) -> Dict[integer, Dict[str, shap.Explanation]]:
) -> Dict[int, Dict[str, shap.Explanation]]:

"""
Return a dictionary of dictionaries of shap.Explanation instances:
@@ -577,7 +576,7 @@
Optionally, a list of integers representing which points/steps in the future we want to explain,
starting from the first prediction step at 1. Currently, only forecasting models are supported which
provide an `output_chunk_length` parameter. `horizons` must not be larger than `output_chunk_length`.
-target_names
target_components
Optionally, a list of strings with the target components we want to explain.

"""
@@ -589,7 +588,9 @@

for h in horizons:
tmp_n = {}
-for t_idx, t in enumerate(target_components):
for t_idx, t in enumerate(self.target_components):
if t not in target_components:
continue
explainer = self.explainers[h - 1][t_idx](foreground_X)
explainer.base_values = explainer.base_values.ravel()
explainer.time_index = foreground_X.index
@@ -601,6 +602,8 @@
for h in horizons:
tmp_n = {}
for t_idx, t in enumerate(target_components):
if t not in target_components:
continue
if not self.single_output:
tmp_t = shap.Explanation(
shap_explanation_tmp.values[
@@ -702,6 +705,8 @@ def _create_regression_model_shap_X(
lags_future_covariates=lags_future_covariates_list
if future_covariates
else None,
uses_static_covariates=self.model.uses_static_covariates,
last_static_covariates_shape=self.model._static_covariates_shape,
)
# Remove sample axis:
X = X[:, :, 0]
@@ -720,26 +725,11 @@
if n_samples:
X = shap.utils.sample(X, n_samples)

-# We keep the creation order of the different lags/features in create_lagged_data
-lags_names_list = []
-if lags_list:
-for lag in lags_list:
-for t_name in self.target_components:
-lags_names_list.append(t_name + "_target_lag" + str(lag))
-if lags_past_covariates_list:
-for lag in lags_past_covariates_list:
-for t_name in self.past_covariates_components:
-lags_names_list.append(t_name + "_past_cov_lag" + str(lag))
-if lags_future_covariates_list:
-for lag in lags_future_covariates_list:
-for t_name in self.future_covariates_components:
-lags_names_list.append(t_name + "_fut_cov_lag" + str(lag))

# rename output columns to the matching lagged features names
X = X.rename(
columns={
-name: lags_names_list[idx]
name: self.model.lagged_feature_names[idx]
for idx, name in enumerate(X.columns.to_list())
}
)

return X
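The removed block above hand-built the lagged column names; the PR delegates this to the model's `lagged_feature_names`. The underlying naming scheme can be sketched as follows (a simplified illustration; the static covariate names that the new property also covers are omitted, and the helper name is illustrative):

```python
from typing import Dict, List, Optional, Sequence

def lagged_feature_names(
    components: Dict[str, Sequence[str]],
    lags: Dict[str, Optional[Sequence[int]]],
) -> List[str]:
    """Rebuild the creation order used by create_lagged_data:
    target lags first, then past- and future-covariate lags."""
    suffix = {"target": "_target_lag", "past": "_past_cov_lag", "future": "_fut_cov_lag"}
    names = []
    for kind in ("target", "past", "future"):
        for lag in lags.get(kind) or []:        # lags are negative ints
            for comp in components.get(kind, []):
                names.append(f"{comp}{suffix[kind]}{lag}")
    return names

names = lagged_feature_names(
    components={"target": ["price"], "past": ["temp"]},
    lags={"target": [-2, -1], "past": [-1], "future": None},
)
# -> ["price_target_lag-2", "price_target_lag-1", "temp_past_cov_lag-1"]
```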
17 changes: 8 additions & 9 deletions darts/models/forecasting/regression_model.py
@@ -353,14 +353,21 @@ def _create_lagged_data(
lags_past_covariates = self.lags.get("past")
lags_future_covariates = self.lags.get("future")

-features, labels, _ = create_lagged_training_data(
(
features,
labels,
_,
self._static_covariates_shape,
) = create_lagged_training_data(
target_series=target_series,
output_chunk_length=self.output_chunk_length,
past_covariates=past_covariates,
future_covariates=future_covariates,
lags=lags,
lags_past_covariates=lags_past_covariates,
lags_future_covariates=lags_future_covariates,
uses_static_covariates=self.uses_static_covariates,
last_static_covariates_shape=None,
max_samples_per_ts=max_samples_per_ts,
multi_models=self.multi_models,
check_inputs=False,
@@ -371,14 +378,6 @@
features[i] = X_i[:, :, 0]
labels[i] = y_i[:, :, 0]

-features, static_covariates_shape = add_static_covariates_to_lagged_data(
-features,
-target_series,
-uses_static_covariates=self.uses_static_covariates,
-last_shape=None,
-)
-self._static_covariates_shape = static_covariates_shape

training_samples = np.concatenate(features, axis=0)
training_labels = np.concatenate(labels, axis=0)

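With this change, `create_lagged_training_data` appends the static covariate values to each training sample itself and returns their shape, which the model stores in `self._static_covariates_shape` for later consistency checks. The appending step can be sketched with numpy (a simplified stand-in for what `add_static_covariates_to_lagged_data` did in the removed block; the function name here is illustrative):

```python
import numpy as np

def append_static_covariates(X: np.ndarray, static_covs: np.ndarray):
    """Append flattened static covariates to every lagged-feature row.

    X           : (n_samples, n_lagged_features)
    static_covs : (n_components, n_static_covs), constant over time
    Returns the widened matrix and the shape used to validate later inputs.
    """
    flat = static_covs.reshape(1, -1)            # one constant row of covariates
    tiled = np.repeat(flat, X.shape[0], axis=0)  # same row for every sample
    return np.hstack([X, tiled]), static_covs.shape

X = np.zeros((5, 3))                  # 5 samples, 3 lagged features
static_covs = np.array([[0.0, 1.0]])  # 1 component, 2 static covariates
X_aug, shape = append_static_covariates(X, static_covs)
assert X_aug.shape == (5, 5) and shape == (1, 2)
```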
123 changes: 123 additions & 0 deletions darts/tests/explainability/test_shap_explainer.py
@@ -1,8 +1,10 @@
import copy
from datetime import date, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
import shap
import sklearn
from dateutil.relativedelta import relativedelta
@@ -90,6 +92,25 @@ class ShapExplainerTestCase(DartsBaseTestClass):
days, np.concatenate([x_1.reshape(-1, 1), x_2.reshape(-1, 1)], axis=1)
).with_columns_renamed(["0", "1"], ["price", "power"])

target_ts_with_static_covs = TimeSeries.from_times_and_values(
days,
x_1.reshape(-1, 1),
static_covariates=pd.DataFrame({"type": [0], "state": [1]}),
).with_columns_renamed(["0"], ["price"])
target_ts_with_multi_component_static_covs = TimeSeries.from_times_and_values(
days,
np.concatenate([x_1.reshape(-1, 1), x_2.reshape(-1, 1)], axis=1),
static_covariates=pd.DataFrame({"type": [0, 1], "state": [2, 3]}),
).with_columns_renamed(["0", "1"], ["price", "power"])
target_ts_multiple_series_with_different_static_covs = [
TimeSeries.from_times_and_values(
days, x_1.reshape(-1, 1), static_covariates=pd.DataFrame({"type": [0]})
).with_columns_renamed(["0"], ["price"]),
TimeSeries.from_times_and_values(
days, x_2.reshape(-1, 1), static_covariates=pd.DataFrame({"state": [1]})
).with_columns_renamed(["0"], ["price"]),
]

past_cov_ts = TimeSeries.from_times_and_values(
days_past_cov,
np.concatenate(
@@ -670,3 +691,105 @@ def test_shap_explanation_object_validity(self):
),
shap.Explanation,
)

def test_shap_selected_components(self):
model = LightGBMModel(
lags=4,
lags_past_covariates=2,
lags_future_covariates=[1],
output_chunk_length=1,
)
model.fit(
series=self.target_ts,
past_covariates=self.past_cov_ts,
future_covariates=self.fut_cov_ts,
)
shap_explain = ShapExplainer(model)
explanation_results = shap_explain.explain()
# check that explain() with selected components gives identical results
for comp in self.target_ts.components:
explanation_comp = shap_explain.explain(target_components=[comp])
assert explanation_comp.available_components == [comp]
assert explanation_comp.available_horizons == [1]
# explained forecasts
fc_res_tmp = copy.deepcopy(explanation_results.explained_forecasts)
fc_res_tmp[1] = {str(comp): fc_res_tmp[1][comp]}
assert explanation_comp.explained_forecasts == fc_res_tmp

# feature values
fv_res_tmp = copy.deepcopy(explanation_results.feature_values)
fv_res_tmp[1] = {str(comp): fv_res_tmp[1][comp]}
assert explanation_comp.feature_values == fv_res_tmp

# shap objects
assert (
len(explanation_comp.shap_explanation_object[1]) == 1
and comp in explanation_comp.shap_explanation_object[1]
)

def test_shapley_with_static_cov(self):
ts = self.target_ts_with_static_covs
model = LightGBMModel(
lags=4,
output_chunk_length=1,
)
model.fit(
series=ts,
)
shap_explain = ShapExplainer(model)

# different static covariates dimensions should raise an error
with pytest.raises(ValueError):
shap_explain.explain(
ts.with_static_covariates(ts.static_covariates["state"])
)

# without static covariates should raise an error
with pytest.raises(ValueError):
shap_explain.explain(ts.with_static_covariates(None))

explanation_results = shap_explain.explain(ts)
assert len(explanation_results.explained_forecasts[1]["price"].columns) == (
-(min(model.lags["target"])) + model.static_covariates.shape[1]
)

model.fit(
series=self.target_ts_with_multi_component_static_covs,
)
shap_explain = ShapExplainer(model)
explanation_results = shap_explain.explain()
assert len(explanation_results.feature_values[1]) == 2
for comp in self.target_ts_with_multi_component_static_covs.components:
comps_out = explanation_results.explained_forecasts[1][comp].columns
assert len(comps_out) == (
-(min(model.lags["target"])) * model.input_dim["target"]
+ model.input_dim["target"] * model.static_covariates.shape[1]
)
assert comps_out[-4:].tolist() == [
"type_statcov_target_price",
"type_statcov_target_power",
"state_statcov_target_price",
"state_statcov_target_power",
]

def test_shapley_multiple_series_with_different_static_covs(self):
model = LightGBMModel(
lags=4,
output_chunk_length=1,
)
model.fit(
series=self.target_ts_multiple_series_with_different_static_covs,
)
shap_explain = ShapExplainer(
model,
background_series=self.target_ts_multiple_series_with_different_static_covs,
)
explanation_results = shap_explain.explain()

assert len(explanation_results.feature_values) == 2

# model trained on multiple series will take column names of first series -> even though
# static covs have different names, the output will show the same names
for explained_forecast in explanation_results.explained_forecasts:
comps_out = explained_forecast[1]["price"].columns.tolist()
assert comps_out[-1] == "type_statcov_target_price"
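The assertions above pin down the naming convention for the new static covariate features: one column per (static covariate, target component) pair, named `<covariate>_statcov_target_<component>` and ordered covariate-major. That convention can be sketched in pure Python (derived from the expected names in the tests; the helper name is illustrative):

```python
from typing import List, Sequence

def statcov_feature_names(
    static_cov_names: Sequence[str], target_components: Sequence[str]
) -> List[str]:
    """One feature per (static covariate, target component) pair,
    ordered covariate-major as the tests above expect."""
    return [
        f"{cov}_statcov_target_{comp}"
        for cov in static_cov_names
        for comp in target_components
    ]

names = statcov_feature_names(["type", "state"], ["price", "power"])
assert names == [
    "type_statcov_target_price",
    "type_statcov_target_power",
    "state_statcov_target_price",
    "state_statcov_target_power",
]
```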