xai init #909

Merged: 98 commits, Oct 4, 2022
Changes from 76 commits
Commits
784cf13
xai init
dumjax Apr 13, 2022
5797010
small corrections
dumjax Apr 13, 2022
9fb8bb1
corrections and adds
dumjax Apr 19, 2022
efbb2f6
corrections
dumjax Apr 20, 2022
ce9de7a
linting corrections
dumjax Apr 20, 2022
16a1f57
lint_isort correct
dumjax Apr 20, 2022
5c8fd17
Merge branch 'master' into feat/xai
hrzn Jun 21, 2022
0c45389
Merge branch 'master' into feat/xai
hrzn Jun 30, 2022
a6cc229
small improvements
hrzn Jun 30, 2022
2f0fcf0
Explainability class corrections
dumjax Jul 20, 2022
25a6845
small changes + docstrings
dumjax Jul 28, 2022
ef960db
Shap method choice first adds
dumjax Aug 3, 2022
a4ca533
small adds
dumjax Aug 3, 2022
21787c7
explain new version
dumjax Aug 5, 2022
f1f7680
Merge branch 'master' into feat/xai
hrzn Aug 7, 2022
021e09e
plots and checks
dumjax Aug 9, 2022
62f6518
bugs on cov checks + Xy facto step 1
dumjax Aug 31, 2022
f2d471d
tabularization.py add
dumjax Aug 31, 2022
0f569a3
factor Xy part 2
dumjax Sep 1, 2022
7b1b040
Merge branch 'master' into feat/xai
dumjax Sep 6, 2022
f5139c6
lint
dumjax Sep 6, 2022
d6964f4
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0db6258
clean
dumjax Sep 6, 2022
4ff8108
clean
dumjax Sep 6, 2022
e227ce5
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
09f5ac8
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
cf55086
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0549a5e
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
e811731
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
41d4cd3
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
17aac2a
clean
dumjax Sep 6, 2022
86181cf
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
2661d01
Update darts/explainability/shap_explainer.py
dumjax Sep 6, 2022
00b62e0
clean
dumjax Sep 6, 2022
6947e35
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
7dc9488
corrections
dumjax Sep 7, 2022
5d6672a
encoders - facto corrections - others
dumjax Sep 12, 2022
be8e7a0
Merge branch 'master' into feat/xai
dumjax Sep 12, 2022
712eb57
timeseries_generation bugs + unit tests start
dumjax Sep 12, 2022
4827c75
corrections + Unit tests
dumjax Sep 13, 2022
0d4c978
corrections
dumjax Sep 13, 2022
1c60312
corrections tests multivariates
dumjax Sep 14, 2022
96ae80a
Merge branch 'master' into feat/xai
dumjax Sep 14, 2022
a04f779
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
f4078df
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
3ab2dd0
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
be0aabe
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4d85c5c
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
44767bd
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
488476a
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4201117
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
237a7cf
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
373fb12
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
39cb2c9
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a46ddcf
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1fc9ef8
docstrings corrections and others
dumjax Sep 22, 2022
49e4b53
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
e8694a2
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
5b528d8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a3fe9da
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1857a8e
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
19303ef
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
d640c65
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
213e670
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
40e97e8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
17bc7d4
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
ede4969
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1b19e14
corrections
dumjax Sep 22, 2022
db40e80
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
77e9c2a
corrections
dumjax Sep 27, 2022
86c0274
timeseries generation small correction
dumjax Sep 27, 2022
72bd897
corrections
dumjax Sep 27, 2022
e4eafcc
Merge branch 'master' into feat/xai
dumjax Sep 27, 2022
98823c0
create_lagged_data adds | unit testing | encoders
dumjax Sep 28, 2022
d8cf73a
correction bugs no covariate in explain()
dumjax Sep 29, 2022
f1e3c43
Merge branch 'master' into feat/xai
hrzn Sep 29, 2022
4534eed
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
23cf14e
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
babd744
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
e40712a
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4aba547
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4000e81
Update darts/utils/data/encoders.py
dumjax Sep 30, 2022
e9e7052
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
d32c02a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
2c34089
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
520846a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
dd7f200
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
fee5750
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
9a74e28
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
b60dcf4
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
31a8a3d
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
424d794
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
cca3bd8
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
ff4b13a
corrections
dumjax Oct 4, 2022
feb1e70
test comment
dumjax Oct 4, 2022
0dee794
Some doc fixes
hrzn Oct 4, 2022
b29ae68
More small fixes
hrzn Oct 4, 2022
4915d32
Last small fixes
hrzn Oct 4, 2022
1 change: 1 addition & 0 deletions darts/explainability/__init__.py
@@ -0,0 +1 @@
from darts.explainability.shap_explainer import ShapExplainer
334 changes: 334 additions & 0 deletions darts/explainability/explainability.py
@@ -0,0 +1,334 @@
"""
Forecasting Model Explainer Base Class
--------------------------------------
A forecasting model explainer takes a fitted forecasting model as input and applies an explainability (XAI)
model to it. Its purpose is to explain the contribution of each past input to a given model forecast. This
'explanation' depends on the characteristics of the chosen XAI model (SHAP, LIME, etc.).

"""
from abc import ABC, abstractmethod
from typing import Dict, Optional, Sequence, Union

from numpy import integer

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.utils.statistics import stationarity_tests
from darts.utils.utils import series2seq

logger = get_logger(__name__)

MIN_BACKGROUND_SAMPLE = 10


class ExplainabilityResult(ABC):
    """
    Class to store the explainability results of a `ForecastingModelExplainer`, and to
    easily access the results.
    """

    def __init__(
        self,
        explained_forecasts: Union[
            Dict[integer, Dict[str, TimeSeries]],
            Sequence[Dict[integer, Dict[str, TimeSeries]]],
        ],
    ):

        self.explained_forecasts = explained_forecasts
        if isinstance(self.explained_forecasts, list):
            self.available_horizons = list(self.explained_forecasts[0].keys())
            h_0 = self.available_horizons[0]
            self.available_components = list(self.explained_forecasts[0][h_0].keys())
        else:
            self.available_horizons = list(self.explained_forecasts.keys())
            h_0 = self.available_horizons[0]
            self.available_components = list(self.explained_forecasts[h_0].keys())

    def get_explanation(
        self, horizon: int, component: str
    ) -> Union[TimeSeries, Sequence[TimeSeries]]:

        raise_if_not(
            horizon in self.available_horizons,
            "Horizon {} is not available. Available horizons are: {}".format(
                horizon, self.available_horizons
            ),
        )

        raise_if_not(
            component in self.available_components,
            "Component {} is not available. Available components are: {}".format(
                component, self.available_components
            ),
        )

        if isinstance(self.explained_forecasts, list):
            return [
                self.explained_forecasts[i][horizon][component]
                for i in range(len(self.explained_forecasts))
            ]
        else:
            return self.explained_forecasts[horizon][component]
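
Review note: the lookup in `get_explanation` can be illustrated without darts. The sketch below is a minimal stand-in, assuming plain strings in place of `TimeSeries` objects; `lookup_explanation` is a hypothetical helper written for this note, not part of the darts API. It only shows how the `horizon -> component` nesting is resolved for a single result and for a list of results.

```python
from typing import Dict, List, Union

# Hypothetical stand-in: plain strings play the role of darts `TimeSeries` objects.
Explained = Dict[int, Dict[str, str]]


def lookup_explanation(
    explained: Union[Explained, List[Explained]], horizon: int, component: str
) -> Union[str, List[str]]:
    """Resolve horizon -> component for a single result or a list of results."""
    # The first element (or the dict itself) defines which keys are available.
    first = explained[0] if isinstance(explained, list) else explained
    if horizon not in first:
        raise ValueError(f"Horizon {horizon} is not available: {list(first)}")
    if component not in first[horizon]:
        raise ValueError(
            f"Component {component} is not available: {list(first[horizon])}"
        )
    if isinstance(explained, list):
        # One explanation per input series, in the same order as the input list.
        return [result[horizon][component] for result in explained]
    return explained[horizon][component]


single = {0: {"T_0": "h0_T0", "T_1": "h0_T1"}, 1: {"T_0": "h1_T0", "T_1": "h1_T1"}}
several = [single, {0: {"T_0": "b_h0_T0", "T_1": "b_h0_T1"}, 1: {"T_0": "b_h1_T0", "T_1": "b_h1_T1"}}]

print(lookup_explanation(single, 0, "T_1"))
print(lookup_explanation(several, 1, "T_0"))
```

As in the class above, the available horizons and components are read from the first element only, so a list of results is assumed to share the same keys.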


class ForecastingModelExplainer(ABC):
    @abstractmethod
    def __init__(
        self,
        model: ForecastingModel,
        background_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        background_past_covariates: Optional[
            Union[TimeSeries, Sequence[TimeSeries]]
        ] = None,
        background_future_covariates: Optional[
            Union[TimeSeries, Sequence[TimeSeries]]
        ] = None,
    ):
        """The base class for forecasting model explainers. It defines the *minimal* behavior that all
        forecasting model explainers support.

        Naming:

        - A background series is a `TimeSeries` with which we 'train' the `Explainer` model.
        - A foreground series is the `TimeSeries` we will explain according to the fitted `Explainer` model.

        Parameters
        ----------
        model
            A `ForecastingModel` we want to explain. It must be fitted first.
        background_series
            A series or list of series to *train* the `ForecastingModelExplainer` along with any foreground series.
            Consider using a reduced, well-chosen background to reduce computation time.

            - optional if `model` was fit on a single target series. By default, it is the `series` used
              at fitting time.
            - mandatory if `model` was fit on multiple (list of) target series.
        background_past_covariates
            A past covariates series or list of series that the model needs once fitted.
        background_future_covariates
            A future covariates series or list of series that the model needs once fitted.
        """
        if not model._fit_called:
            raise_log(
                ValueError(
                    "The model must be fitted before instantiating a ForecastingModelExplainer."
                ),
                logger,
            )

        if model._is_probabilistic():
            logger.warning(
                "The model is probabilistic, but num_samples=1 will be used for explainability."
            )

        self.model = model

        # if `background_series` was not passed, use `training_series` saved in fitted forecasting model.
        if background_series is None:

            raise_if(
                (background_past_covariates is not None)
                or (background_future_covariates is not None),
                "Supplied background past or future covariates but no background series. Please provide "
                "`background_series`.",
            )

            raise_if(
                self.model.training_series is None,
                "`background_series` must be provided if `model` was fit on multiple time series.",
            )

            background_series = self.model.training_series
            background_past_covariates = self.model.past_covariate_series
            background_future_covariates = self.model.future_covariate_series

        else:
            if self.model.encoders.encoding_available:
                # (
                #     background_past_covariates,
                #     background_future_covariates,
                # ) = self.model.encoders.encode_train(
                #     target=background_series,
                #     past_covariate=background_past_covariates,
                #     future_covariate=background_future_covariates,
                # )
                (
                    background_past_covariates,
                    background_future_covariates,
                ) = self.model.generate_predict_encodings(
                    n=len(background_series) - self.model.min_train_series_length,
                    # Review comment (contributor): Should this be set to output_chunk_length?
                    # I'm probably missing something.
                    series=background_series,
                    past_covariates=background_past_covariates,
                    future_covariates=background_future_covariates,
                )

        self.background_series = series2seq(background_series)
        self.background_past_covariates = series2seq(background_past_covariates)
        self.background_future_covariates = series2seq(background_future_covariates)

        if self.model.uses_past_covariates:
            raise_if(
                self.model._expect_past_covariates
                and self.background_past_covariates is None,
                "A background past covariates is not provided, but the model needs past covariates.",
            )

        if self.model.uses_future_covariates:
            raise_if(
                self.model._expect_future_covariates
                and self.background_future_covariates is None,
                "A background future covariates is not provided, but the model needs future covariates.",
            )

        self.target_names = self.background_series[0].columns.to_list()
        self.past_covariates_names = None
        if self.background_past_covariates is not None:
            self.past_covariates_names = self.background_past_covariates[
                0
            ].columns.to_list()
        self.future_covariates_names = None
        if self.background_future_covariates is not None:
            self.future_covariates_names = self.background_future_covariates[
                0
            ].columns.to_list()

        self._check_background_covariates()

        if not self._test_stationarity():
            logger.warning(
                "At least one time series component of the background time series is not stationary."
                " Beware of wrong interpretation with chosen explainability."
            )

    def _check_background_covariates(self):

        if self.background_past_covariates is not None:
            raise_if_not(
                len(self.background_series) == len(self.background_past_covariates),
                "The number of background series and past covariates must be the same.",
            )

        if self.background_future_covariates is not None:
            raise_if_not(
                len(self.background_series) == len(self.background_future_covariates),
                "The number of background series and future covariates must be the same.",
            )

        # ensure we have the same names between TimeSeries (if list of). Important to ensure homogeneity
        # for explained features.
        for idx in range(len(self.background_series)):
            raise_if_not(
                all(
                    [
                        self.background_series[idx].columns.to_list()
                        == self.target_names,
                        self.background_past_covariates[idx].columns.to_list()
                        == self.past_covariates_names
                        if self.background_past_covariates is not None
                        else True,
                        self.background_future_covariates[idx].columns.to_list()
                        == self.future_covariates_names
                        if self.background_future_covariates is not None
                        else True,
                    ]
                ),
                "Columns names must be identical between TimeSeries list components (multi-TimeSeries).",
            )

    @abstractmethod
    def explain(
        self,
        foreground_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        foreground_past_covariates: Optional[
            Union[TimeSeries, Sequence[TimeSeries]]
        ] = None,
        foreground_future_covariates: Optional[
            Union[TimeSeries, Sequence[TimeSeries]]
        ] = None,
        horizons: Optional[Sequence[int]] = None,
        target_names: Optional[Sequence[str]] = None,
    ) -> ExplainabilityResult:
        """
        Main method of the `ForecastingModelExplainer` class.
        Returns an `ExplainabilityResult` instance.

        Results can be retrieved via `ExplainabilityResult.get_explanation(horizon, component)`.
        The result is a multivariate `TimeSeries` instance containing the 'explanation'
        for the (horizon, target) forecast at every forecastable timestamp of the foreground
        `TimeSeries` input.

        The component name convention of this multivariate `TimeSeries` is:
        ``f'{name}_{type_of_cov}_{lag}_{int}'``, where:

        - `name` is the component name from the original foreground series (target, past, or future).
        - `type_of_cov` is the covariates type. It can take 3 different values: ``{"target", "past", "future"}``.
        - `int` is the lag index.

        Example:
        Let's say we have a model with 2 target components named ``"T_0"`` and ``"T_1"``, three past covariates we
        didn't name, and one future covariate we didn't name. Also, ``horizons = [0, 1]``.
        The model is a regression model, with ``lags = 3``, ``lags_past_covariates=[-1, -3]``,
        ``lags_future_covariates = [0]``.

        We provide `foreground_series`, `foreground_past_covariates`, `foreground_future_covariates` each of length 5.

        >>> explain_results = explainer.explain(
        ...     foreground_series=foreground_series,
        ...     foreground_past_covariates=foreground_past_covariates,
        ...     foreground_future_covariates=foreground_future_covariates,
        ...     horizons=[0, 1],
        ...     target_names=["T_0", "T_1"],
        ... )
        >>> output = explain_results.get_explanation(horizon=0, component="T_1")

        Then ``output`` is a multivariate `TimeSeries` containing the *explanations* of the chosen `Explainer`,
        with component names:

        - T_0_target_lag-1
        - T_0_target_lag-2
        - T_0_target_lag-3
        - T_1_target_lag-1
        - T_1_target_lag-2
        - T_1_target_lag-3
        - 0_past_cov_lag-1 (we didn't name the past covariate so it took the default name 0)
        - 0_past_cov_lag-3 (we didn't name the past covariate so it took the default name 0)
        - 1_past_cov_lag-1 (we didn't name the past covariate so it took the default name 1)
        - 1_past_cov_lag-3 (we didn't name the past covariate so it took the default name 1)
        - 2_past_cov_lag-1 (we didn't name the past covariate so it took the default name 2)
        - 2_past_cov_lag-3 (we didn't name the past covariate so it took the default name 2)
        - 0_fut_cov_lag_0 (we didn't name the future covariate so it took the default name 0)

        of length 3, as we can explain 5-3+1 forecasts (basically timestamp indexes 4, 5, and 6).

        Parameters
        ----------
        foreground_series
            Optionally, the target `TimeSeries` we want to explain. Can be multivariate.
            If not provided, the whole background `TimeSeries` is explained.
        foreground_past_covariates
            Optionally, past covariate `TimeSeries` if needed by the `ForecastingModel`.
        foreground_future_covariates
            Optionally, future covariate `TimeSeries` if needed by the `ForecastingModel`.
        horizons
            Optionally, a list of integers indicating which future time steps we want to explain,
            where 0 is the first predicted timestamp.
            For now, only models with an `output_chunk_length` are considered, and a horizon
            cannot be larger than `output_chunk_length`.
        target_names
            Optionally, a list of strings naming the target components we want to explain.

        Returns
        -------
        An `ExplainabilityResult` instance.

        """
        pass

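    def _test_stationarity(self):
        # Flatten over series and components so that every test result is checked. A generator
        # nested inside a list is always truthy, which would make `all()` return True regardless.
        return all(
            stationarity_tests(background_serie[c])
            for background_serie in self.background_series
            for c in background_serie.components
        )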
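
Review note: one pitfall worth keeping in mind for a check of this shape is that a generator expression placed inside a list is a truthy object, so `all()` never looks at the individual test results. The sketch below is a minimal illustration with stand-in data; `is_stationary`, `nested_check`, and `flattened_check` are hypothetical names for this note, and plain dicts replace `TimeSeries` components.

```python
from typing import Dict, List


def is_stationary(label: str) -> bool:
    """Hypothetical stand-in for a per-component stationarity test."""
    return label == "stationary"


def nested_check(series_list: List[Dict[str, str]]) -> bool:
    # Each list element is a generator object, which is truthy whatever it would yield.
    return all([(is_stationary(v) for v in s.values()) for s in series_list])


def flattened_check(series_list: List[Dict[str, str]]) -> bool:
    # A single flat generator lets `all()` consume every individual test result.
    return all(is_stationary(v) for s in series_list for v in s.values())


background = [{"a": "stationary", "b": "trending"}]

print(nested_check(background))     # True, even though component "b" fails the test
print(flattened_check(background))  # False
```

With the flattened form, a non-stationarity warning would fire as soon as any single component fails its test.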