
xai init #909

Merged: 98 commits, merged on Oct 4, 2022

Commits
784cf13
xai init
dumjax Apr 13, 2022
5797010
small corrections
dumjax Apr 13, 2022
9fb8bb1
corrections and adds
dumjax Apr 19, 2022
efbb2f6
corrections
dumjax Apr 20, 2022
ce9de7a
linting corrections
dumjax Apr 20, 2022
16a1f57
lint_isort correct
dumjax Apr 20, 2022
5c8fd17
Merge branch 'master' into feat/xai
hrzn Jun 21, 2022
0c45389
Merge branch 'master' into feat/xai
hrzn Jun 30, 2022
a6cc229
small improvements
hrzn Jun 30, 2022
2f0fcf0
Explainability class corrections
dumjax Jul 20, 2022
25a6845
small changes + docstrings
dumjax Jul 28, 2022
ef960db
Shap method choice first adds
dumjax Aug 3, 2022
a4ca533
small adds
dumjax Aug 3, 2022
21787c7
explain new version
dumjax Aug 5, 2022
f1f7680
Merge branch 'master' into feat/xai
hrzn Aug 7, 2022
021e09e
plots and checks
dumjax Aug 9, 2022
62f6518
bugs on cov checks + Xy facto step 1
dumjax Aug 31, 2022
f2d471d
tabularization.py add
dumjax Aug 31, 2022
0f569a3
factor Xy part 2
dumjax Sep 1, 2022
7b1b040
Merge branch 'master' into feat/xai
dumjax Sep 6, 2022
f5139c6
lint
dumjax Sep 6, 2022
d6964f4
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0db6258
clean
dumjax Sep 6, 2022
4ff8108
clean
dumjax Sep 6, 2022
e227ce5
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
09f5ac8
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
cf55086
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0549a5e
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
e811731
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
41d4cd3
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
17aac2a
clean
dumjax Sep 6, 2022
86181cf
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
2661d01
Update darts/explainability/shap_explainer.py
dumjax Sep 6, 2022
00b62e0
clean
dumjax Sep 6, 2022
6947e35
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
7dc9488
corrections
dumjax Sep 7, 2022
5d6672a
encoders - facto corrections - others
dumjax Sep 12, 2022
be8e7a0
Merge branch 'master' into feat/xai
dumjax Sep 12, 2022
712eb57
timeseries_generation bugs + unit tests start
dumjax Sep 12, 2022
4827c75
corrections + Unit tests
dumjax Sep 13, 2022
0d4c978
corrections
dumjax Sep 13, 2022
1c60312
corrections tests multivariates
dumjax Sep 14, 2022
96ae80a
Merge branch 'master' into feat/xai
dumjax Sep 14, 2022
a04f779
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
f4078df
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
3ab2dd0
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
be0aabe
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4d85c5c
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
44767bd
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
488476a
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4201117
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
237a7cf
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
373fb12
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
39cb2c9
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a46ddcf
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1fc9ef8
docstrings corrections and others
dumjax Sep 22, 2022
49e4b53
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
e8694a2
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
5b528d8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a3fe9da
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1857a8e
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
19303ef
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
d640c65
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
213e670
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
40e97e8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
17bc7d4
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
ede4969
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1b19e14
corrections
dumjax Sep 22, 2022
db40e80
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
77e9c2a
corrections
dumjax Sep 27, 2022
86c0274
timeseries generation small correction
dumjax Sep 27, 2022
72bd897
corrections
dumjax Sep 27, 2022
e4eafcc
Merge branch 'master' into feat/xai
dumjax Sep 27, 2022
98823c0
create_lagged_data adds | unit testing | encoders
dumjax Sep 28, 2022
d8cf73a
correction bugs no covariate in explain()
dumjax Sep 29, 2022
f1e3c43
Merge branch 'master' into feat/xai
hrzn Sep 29, 2022
4534eed
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
23cf14e
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
babd744
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
e40712a
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4aba547
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4000e81
Update darts/utils/data/encoders.py
dumjax Sep 30, 2022
e9e7052
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
d32c02a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
2c34089
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
520846a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
dd7f200
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
fee5750
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
9a74e28
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
b60dcf4
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
31a8a3d
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
424d794
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
cca3bd8
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
ff4b13a
corrections
dumjax Oct 4, 2022
feb1e70
test comment
dumjax Oct 4, 2022
0dee794
Some doc fixes
hrzn Oct 4, 2022
b29ae68
More small fixes
hrzn Oct 4, 2022
4915d32
Last small fixes
hrzn Oct 4, 2022
7 changes: 7 additions & 0 deletions darts/explainability/__init__.py
@@ -0,0 +1,7 @@
"""
Explainability
--------------
"""

from darts.explainability.explainability_result import ExplainabilityResult
from darts.explainability.shap_explainer import ShapExplainer
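For readers skimming the diff, a minimal usage sketch of the new module is shown below. It assumes a fitted `RegressionModel` on made-up random data; the series, lags, and variable names are illustrative only and not part of this PR.

```python
# Illustrative sketch only; nothing below is part of this PR's diff.
import numpy as np

from darts import TimeSeries
from darts.explainability import ShapExplainer
from darts.models import RegressionModel

# Toy univariate target and one past covariate (random data, purely illustrative).
target = TimeSeries.from_values(np.random.randn(100))
past_cov = TimeSeries.from_values(np.random.randn(100))

# The explainer requires a model that has already been fitted.
model = RegressionModel(lags=3, lags_past_covariates=[-1, -2])
model.fit(target, past_covariates=past_cov)

# With no background passed, the explainer falls back to the series used at fit time.
explainer = ShapExplainer(model)
result = explainer.explain(horizons=[1])

# The explanation for horizon 1; `component` can be omitted for a univariate target.
explanation = result.get_explanation(horizon=1)
```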
299 changes: 299 additions & 0 deletions darts/explainability/explainability.py
@@ -0,0 +1,299 @@
"""
Forecasting Model Explainer Base Class

A forecasting model explainer takes a fitted forecasting model as input and applies an Explainability model
to it. Its purpose is to explain the contribution of each past input to a given model forecast. This 'explanation'
depends on the characteristics of the chosen XAI model (SHAP, LIME, etc.).

"""
from abc import ABC, abstractmethod
from typing import Collection, Optional, Sequence, Union

from darts import TimeSeries
from darts.explainability.explainability_result import ExplainabilityResult
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.utils.statistics import stationarity_tests
from darts.utils.utils import series2seq

logger = get_logger(__name__)

MIN_BACKGROUND_SAMPLE = 10


class ForecastingModelExplainer(ABC):
@abstractmethod
def __init__(
self,
model: ForecastingModel,
background_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
background_past_covariates: Optional[
Union[TimeSeries, Sequence[TimeSeries]]
] = None,
background_future_covariates: Optional[
Union[TimeSeries, Sequence[TimeSeries]]
] = None,
):
"""
The base class for forecasting model explainers. It defines the *minimal* behavior that all
forecasting model explainers support.

Naming:

- A background series is a `TimeSeries` with which to 'train' the `Explainer` model.
- A foreground series is the `TimeSeries` to explain using the fitted `Explainer` model.

Parameters
----------
model
A `ForecastingModel` to be explained. It must be fitted first.
background_series
A series or list of series to *train* the `ForecastingModelExplainer` along with any foreground series.
Consider using a small but well-chosen background to reduce computation time.

- optional if `model` was fit on a single target series. By default, it is the `series` used
at fitting time.
- mandatory if `model` was fit on multiple (sequence of) target series.

background_past_covariates
A past covariates series or list of series that the model needs once fitted.
background_future_covariates
A future covariates series or list of series that the model needs once fitted.
"""
if not model._fit_called:
raise_log(
ValueError(
"The model must be fitted before instantiating a ForecastingModelExplainer."
),
logger,
)

if model._is_probabilistic():
logger.warning(
"The model is probabilistic, but num_samples=1 will be used for explainability."
)

self.model = model

# if `background_series` was not passed, use `training_series` saved in fitted forecasting model.
if background_series is None:

raise_if(
(background_past_covariates is not None)
or (background_future_covariates is not None),
"Supplied background past or future covariates but no background series. Please provide "
"`background_series`.",
)

raise_if(
self.model.training_series is None,
"`background_series` must be provided if `model` was fit on multiple time series.",
)

background_series = self.model.training_series
background_past_covariates = self.model.past_covariate_series
background_future_covariates = self.model.future_covariate_series

else:
if self.model.encoders.encoding_available:
(
background_past_covariates,
background_future_covariates,
) = self.model.generate_predict_encodings(
n=len(background_series) - self.model.min_train_series_length,
[Reviewer comment] Should this be set to output_chunk_length? I'm probably missing something.
series=background_series,
past_covariates=background_past_covariates,
future_covariates=background_future_covariates,
)

self.background_series = series2seq(background_series)
self.background_past_covariates = series2seq(background_past_covariates)
self.background_future_covariates = series2seq(background_future_covariates)

if self.model.uses_past_covariates:
raise_if(
self.model._expect_past_covariates
and self.background_past_covariates is None,
"A background past covariates is not provided, but the model needs past covariates.",
)

if self.model.uses_future_covariates:
raise_if(
self.model._expect_future_covariates
and self.background_future_covariates is None,
"A background future covariates is not provided, but the model needs future covariates.",
)

self.target_components = self.background_series[0].columns.to_list()
self.past_covariates_components = None
if self.background_past_covariates is not None:
self.past_covariates_components = self.background_past_covariates[
0
].columns.to_list()
self.future_covariates_components = None
if self.background_future_covariates is not None:
self.future_covariates_components = self.background_future_covariates[
0
].columns.to_list()

self._check_background_covariates(
self.background_series,
self.background_past_covariates,
self.background_future_covariates,
self.target_components,
self.past_covariates_components,
self.future_covariates_components,
)

if not self._test_stationarity():
logger.warning(
"At least one time series component of the background time series is not stationary."
" Beware of wrong interpretation with chosen explainability."
)

@staticmethod
def _check_background_covariates(
background_series,
background_past_covariates,
background_future_covariates,
target_components,
past_covariates_components,
future_covariates_components,
) -> None:

if background_past_covariates is not None:
raise_if_not(
len(background_series) == len(background_past_covariates),
"The number of background series and past covariates must be the same.",
)

if background_future_covariates is not None:
raise_if_not(
len(background_series) == len(background_future_covariates),
"The number of background series and future covariates must be the same.",
)

# Ensure component names are identical across the TimeSeries lists (if lists were passed);
# this is important to guarantee homogeneity of the explained features.
for idx in range(len(background_series)):
raise_if_not(
all(
[
background_series[idx].columns.to_list() == target_components,
background_past_covariates[idx].columns.to_list()
== past_covariates_components
if background_past_covariates is not None
else True,
background_future_covariates[idx].columns.to_list()
== future_covariates_components
if background_future_covariates is not None
else True,
]
),
"Columns names must be identical between TimeSeries list components (multi-TimeSeries).",
)

@abstractmethod
def explain(
self,
foreground_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
foreground_past_covariates: Optional[
Union[TimeSeries, Sequence[TimeSeries]]
] = None,
foreground_future_covariates: Optional[
Union[TimeSeries, Sequence[TimeSeries]]
] = None,
horizons: Optional[Collection[int]] = None,
target_components: Optional[Collection[str]] = None,
) -> ExplainabilityResult:
"""
Explains a foreground time series and returns an :class:`ExplainabilityResult`.

Results can be retrieved via the method
:func:`ExplainabilityResult.get_explanation(horizon, target_component)`.
The result is a multivariate `TimeSeries` instance containing the 'explanation'
for the (horizon, target_component) forecast at every forecastable timestamp of
the foreground `TimeSeries` input.

The component name convention of this multivariate `TimeSeries` is:
``"{name}_{type_of_cov}_lag_{idx}"``, where:

- ``{name}`` is the component name from the original foreground series (target, past, or future).
- ``{type_of_cov}`` is the covariates type. It can take 3 different values:
``"target"``, ``"past_cov"`` or ``"future_cov"``.
- ``{idx}`` is the lag index.

**Example:**

Say we have a model with 2 target components named ``"T_0"`` and ``"T_1"``,
3 past covariates with default component names ``"0"``, ``"1"``, and ``"2"``,
and one future covariate with default component name ``"0"``.
Also, ``horizons = [1, 2]``.
The model is a regression model, with ``lags = 3``, ``lags_past_covariates=[-1, -3]``,
``lags_future_covariates = [0]``.

We provide `foreground_series`, `foreground_past_covariates`, `foreground_future_covariates` each of length 5.


>>> explain_results = explainer.explain(
>>> foreground_series=foreground_series,
>>> foreground_past_covariates=foreground_past_covariates,
>>> foreground_future_covariates=foreground_future_covariates,
>>> horizons=[1, 2],
>>> target_components=["T_0", "T_1"])
>>> output = explain_results.get_explanation(horizon=1, component="T_1")

The method then returns a multivariate TimeSeries containing the *explanations* produced by
the corresponding `ForecastingModelExplainer`, with the following component names:

- T_0_target_lag-1
- T_0_target_lag-2
- T_0_target_lag-3
- T_1_target_lag-1
- T_1_target_lag-2
- T_1_target_lag-3
- 0_past_cov_lag-1
- 0_past_cov_lag-3
- 1_past_cov_lag-1
- 1_past_cov_lag-3
- 2_past_cov_lag-1
- 2_past_cov_lag-3
- 0_fut_cov_lag_0

This series has length 3, as the model can explain 5 - 3 + 1 = 3 forecasts
(timestamp indexes 4, 5, and 6).

Parameters
----------
foreground_series
Optionally, the target `TimeSeries` to be explained. Can be multivariate.
If not provided, the background `TimeSeries` will be explained instead.
foreground_past_covariates
Optionally, past covariate timeseries if needed by the ForecastingModel.
foreground_future_covariates
Optionally, future covariate timeseries if needed by the ForecastingModel.
horizons
Optionally, a collection of integers representing the future lags to be explained.
Horizon 1 corresponds to the first timestamp being forecasted.
All values must be no larger than `output_chunk_length` of the explained model.
target_components
Optionally, a list of strings naming the target components to be explained.

Returns
-------
ExplainabilityResult
The forecast explanations.

"""
pass

def _test_stationarity(self):
    # Run the stationarity test on every component of every background series. A single
    # flat generator (rather than a list of nested generators) ensures that `all`
    # evaluates the actual test results.
    return all(
        stationarity_tests(background_serie[c])
        for background_serie in self.background_series
        for c in background_serie.components
    )
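To make the abstract contract above concrete, here is a hypothetical, deliberately trivial subclass sketch. `ConstantExplainer` and its zero-valued 'explanations' are invented for illustration; they only show the expected return shape of `explain()` and do not correspond to anything in this PR.

```python
# Hypothetical subclass; it only illustrates the return-shape contract of explain().
from typing import Collection, Optional, Sequence, Union

from darts import TimeSeries
from darts.explainability.explainability import ForecastingModelExplainer
from darts.explainability.explainability_result import ExplainabilityResult


class ConstantExplainer(ForecastingModelExplainer):
    """Assigns a zero contribution to every (horizon, component) pair."""

    def __init__(self, model, background_series=None):
        super().__init__(model, background_series=background_series)

    def explain(
        self,
        foreground_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        foreground_past_covariates=None,
        foreground_future_covariates=None,
        horizons: Optional[Collection[int]] = None,
        target_components: Optional[Collection[str]] = None,
    ) -> ExplainabilityResult:
        # For simplicity, assume a single foreground TimeSeries (not a sequence).
        series = (
            foreground_series
            if foreground_series is not None
            else self.background_series[0]
        )
        horizons = horizons if horizons is not None else [1]
        components = (
            target_components if target_components is not None else self.target_components
        )
        # One placeholder explanation TimeSeries per (horizon, component) pair.
        explained = {
            h: {c: series[c].map(lambda v: v * 0.0) for c in components}
            for h in horizons
        }
        return ExplainabilityResult(explained)
```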
90 changes: 90 additions & 0 deletions darts/explainability/explainability_result.py
@@ -0,0 +1,90 @@
"""
ExplainabilityResult
--------------------

Contains the explainability results obtained from :func:`ForecastingModelExplainer.explain()`.
"""

from abc import ABC
from typing import Dict, Optional, Sequence, Union

from numpy import integer

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not

logger = get_logger(__name__)


class ExplainabilityResult(ABC):
"""
Stores the explainability results of a :class:`ForecastingModelExplainer`
with convenient access to the results.
"""

def __init__(
self,
explained_forecasts: Union[
Dict[integer, Dict[str, TimeSeries]],
Sequence[Dict[integer, Dict[str, TimeSeries]]],
],
):

self.explained_forecasts = explained_forecasts
if isinstance(self.explained_forecasts, list):
self.available_horizons = list(self.explained_forecasts[0].keys())
h_0 = self.available_horizons[0]
self.available_components = list(self.explained_forecasts[0][h_0].keys())
else:
self.available_horizons = list(self.explained_forecasts.keys())
h_0 = self.available_horizons[0]
self.available_components = list(self.explained_forecasts[h_0].keys())

def get_explanation(
self, horizon: int, component: Optional[str] = None
) -> Union[TimeSeries, Sequence[TimeSeries]]:
"""
Returns one or several `TimeSeries` representing the explanations
for a given horizon and component.

Parameters
----------
horizon
The horizon for which to return the explanation.
component
The component for which to return the explanation. Does not
need to be specified for univariate series.
"""

raise_if(
    component is None and len(self.available_components) > 1,
    "The `component` parameter is required when the model has more than one component.",
    logger,
)

if component is None:
component = self.available_components[0]

raise_if_not(
horizon in self.available_horizons,
"Horizon {} is not available. Available horizons are: {}".format(
horizon, self.available_horizons
),
)

raise_if_not(
component in self.available_components,
"Component {} is not available. Available components are: {}".format(
component, self.available_components
),
)

if isinstance(self.explained_forecasts, list):
return [
self.explained_forecasts[i][horizon][component]
for i in range(len(self.explained_forecasts))
]
else:
return self.explained_forecasts[horizon][component]
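As a small usage note (not part of the diff): once an `ExplainabilityResult` is in hand, the per-lag contributions can be summarised with a few lines of NumPy. The helper below is a hypothetical sketch assuming a single explained series, so `get_explanation` returns a plain `TimeSeries`.

```python
# Hypothetical helper, e.g. returning {"T_0_target_lag-1": 0.42, "0_past_cov_lag-1": 0.13, ...}.
import numpy as np


def mean_abs_contributions(result, horizon, component=None):
    """Mean absolute contribution of each explained feature, for one horizon/component."""
    explanation = result.get_explanation(horizon=horizon, component=component)
    values = explanation.values()  # 2-D array of shape (n_timesteps, n_features)
    return {
        str(name): float(np.mean(np.abs(values[:, i])))
        for i, name in enumerate(explanation.components)
    }
```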