Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xai init #909

Merged
merged 98 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
784cf13
xai init
dumjax Apr 13, 2022
5797010
small corrections
dumjax Apr 13, 2022
9fb8bb1
corrections and adds
dumjax Apr 19, 2022
efbb2f6
corrections
dumjax Apr 20, 2022
ce9de7a
linting corrections
dumjax Apr 20, 2022
16a1f57
lint_isort correct
dumjax Apr 20, 2022
5c8fd17
Merge branch 'master' into feat/xai
hrzn Jun 21, 2022
0c45389
Merge branch 'master' into feat/xai
hrzn Jun 30, 2022
a6cc229
small improvements
hrzn Jun 30, 2022
2f0fcf0
Explainability class corrections
dumjax Jul 20, 2022
25a6845
small changes + docstrings
dumjax Jul 28, 2022
ef960db
Shap method choice first adds
dumjax Aug 3, 2022
a4ca533
small adds
dumjax Aug 3, 2022
21787c7
explain new version
dumjax Aug 5, 2022
f1f7680
Merge branch 'master' into feat/xai
hrzn Aug 7, 2022
021e09e
plots and checks
dumjax Aug 9, 2022
62f6518
bugs on cov checks + Xy facto step 1
dumjax Aug 31, 2022
f2d471d
tabularization.py add
dumjax Aug 31, 2022
0f569a3
factor Xy part 2
dumjax Sep 1, 2022
7b1b040
Merge branch 'master' into feat/xai
dumjax Sep 6, 2022
f5139c6
lint
dumjax Sep 6, 2022
d6964f4
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0db6258
clean
dumjax Sep 6, 2022
4ff8108
clean
dumjax Sep 6, 2022
e227ce5
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
09f5ac8
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
cf55086
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0549a5e
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
e811731
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
41d4cd3
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
17aac2a
clean
dumjax Sep 6, 2022
86181cf
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
2661d01
Update darts/explainability/shap_explainer.py
dumjax Sep 6, 2022
00b62e0
clean
dumjax Sep 6, 2022
6947e35
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
7dc9488
corrections
dumjax Sep 7, 2022
5d6672a
encoders - facto corrections - others
dumjax Sep 12, 2022
be8e7a0
Merge branch 'master' into feat/xai
dumjax Sep 12, 2022
712eb57
timeseries_generation bugs + unit tests start
dumjax Sep 12, 2022
4827c75
corrections + Unit tests
dumjax Sep 13, 2022
0d4c978
corrections
dumjax Sep 13, 2022
1c60312
corrections tests multivariates
dumjax Sep 14, 2022
96ae80a
Merge branch 'master' into feat/xai
dumjax Sep 14, 2022
a04f779
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
f4078df
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
3ab2dd0
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
be0aabe
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4d85c5c
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
44767bd
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
488476a
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4201117
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
237a7cf
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
373fb12
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
39cb2c9
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a46ddcf
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1fc9ef8
docstrings corrections and others
dumjax Sep 22, 2022
49e4b53
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
e8694a2
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
5b528d8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a3fe9da
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1857a8e
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
19303ef
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
d640c65
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
213e670
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
40e97e8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
17bc7d4
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
ede4969
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1b19e14
corrections
dumjax Sep 22, 2022
db40e80
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
77e9c2a
corrections
dumjax Sep 27, 2022
86c0274
timeseries generation small correction
dumjax Sep 27, 2022
72bd897
corrections
dumjax Sep 27, 2022
e4eafcc
Merge branch 'master' into feat/xai
dumjax Sep 27, 2022
98823c0
create_lagged_data adds | unit testing | encoders
dumjax Sep 28, 2022
d8cf73a
correction bugs no covariate in explain()
dumjax Sep 29, 2022
f1e3c43
Merge branch 'master' into feat/xai
hrzn Sep 29, 2022
4534eed
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
23cf14e
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
babd744
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
e40712a
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4aba547
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4000e81
Update darts/utils/data/encoders.py
dumjax Sep 30, 2022
e9e7052
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
d32c02a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
2c34089
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
520846a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
dd7f200
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
fee5750
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
9a74e28
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
b60dcf4
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
31a8a3d
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
424d794
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
cca3bd8
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
ff4b13a
corrections
dumjax Oct 4, 2022
feb1e70
test comment
dumjax Oct 4, 2022
0dee794
Some doc fixes
hrzn Oct 4, 2022
b29ae68
More small fixes
hrzn Oct 4, 2022
4915d32
Last small fixes
hrzn Oct 4, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions darts/explainability/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from darts.explainability.shap_explainer import ShapExplainer
dumjax marked this conversation as resolved.
Show resolved Hide resolved
222 changes: 222 additions & 0 deletions darts/explainability/explainability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
"""
Explainability Base Class
------------------------------
dumjax marked this conversation as resolved.
Show resolved Hide resolved

TODO
"""
from abc import ABC, abstractmethod
from typing import Optional, Sequence, Union

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_log
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.utils import retain_period_common_to_all
from darts.utils.statistics import stationarity_tests

logger = get_logger(__name__)


class ForecastingModelExplainer(ABC):
@abstractmethod
def __init__(
self,
model: ForecastingModel,
background_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
background_past_covariates: Optional[
Union[TimeSeries, Sequence[TimeSeries]]
] = None,
background_future_covariates: Optional[
Union[TimeSeries, Sequence[TimeSeries]]
] = None,
):
dumjax marked this conversation as resolved.
Show resolved Hide resolved

if not model._fit_called:
raise_log(
ValueError(
"The model must be fitted before instantiating a ForecastingModelExplainer."
),
logger,
)

if model._is_probabilistic():
dumjax marked this conversation as resolved.
Show resolved Hide resolved
# TODO: We can probably add explainability to probabilistic models, by taking the mean output.
raise_log(
ValueError(
"Explainability is only available for non-probabilistic models."
),
logger,
)

self.model = model

if background_series is None:

raise_if(
(background_past_covariates is not None)
or (background_future_covariates is not None),
"There is background past or future covariates but no background series. Please provide one.",
dumjax marked this conversation as resolved.
Show resolved Hide resolved
)

raise_if(
self.model.training_series is None,
"A background time series has to be provided for a model fitted on multiple time series, as"
"no training series has been saved by the model.",
)

self.background_series = self.model.training_series
hrzn marked this conversation as resolved.
Show resolved Hide resolved
self.background_past_covariates = self.model.past_covariate_series
hrzn marked this conversation as resolved.
Show resolved Hide resolved
self.background_future_covariates = self.model.future_covariate_series

else:

# ensure list of TimeSeries format
if isinstance(background_series, TimeSeries):
background_series = [background_series]
background_past_covariates = (
dumjax marked this conversation as resolved.
Show resolved Hide resolved
[background_past_covariates] if background_past_covariates else None
)
background_future_covariates = (
[background_future_covariates]
if background_future_covariates
else None
)
for idx in range(len(background_series)):
if not all(
dumjax marked this conversation as resolved.
Show resolved Hide resolved
len(
background_series[idx].time_index.difference(
background_past_covariates[idx].time_index
)
)
== 0,
len(
background_past_covariates[idx].time_index.difference(
background_future_covariates[idx].time_index
)
)
== 0,
len(
background_future_covariates[idx].time_index.difference(
background_series[idx].time_index
)
)
== 0,
):
logger.warning(
dumjax marked this conversation as resolved.
Show resolved Hide resolved
"Some series and their covariates don't share the same time index. We will take "
"the time index common to all."
)

(
background_series[idx],
dumjax marked this conversation as resolved.
Show resolved Hide resolved
background_past_covariates[idx],
background_future_covariates[idx],
) = retain_period_common_to_all(
[
background_series[idx],
background_past_covariates[idx],
background_future_covariates[idx],
]
)

self.background_series = background_series
self.background_past_covariates = background_past_covariates
self.background_future_covariates = background_past_covariates

self.target_names = self.background_series.columns
self.past_covariates_names = self.background_past_covariates.columns
dumjax marked this conversation as resolved.
Show resolved Hide resolved
self.future_covariates_names = self.background_future_covariates.columns

raise_if(
self.model._expect_past_covariates
and self.background_past_covariates is None,
"A background past covariates is not provided, but the model needs past covariates.",
)

raise_if(
self.model._expect_future_covariates
and self.background_future_covariates is None,
"A background future covariates is not provided, but the model needs future covariates.",
)

self.model = model
dumjax marked this conversation as resolved.
Show resolved Hide resolved

# For now we won't consider further time step that output_chunk_length, even though we could in
# theory explain anything in the future by auto regressive process.
if hasattr(self.model, "output_chunk_length"):
self.n = self.model.output_chunk_length
dumjax marked this conversation as resolved.
Show resolved Hide resolved
else:
self.n = 1

if hasattr(self.model, "input_chunk_length"):
self.past_steps_explained = self.model.input_chunk_length
else:
self.past_steps_explained = 1
dumjax marked this conversation as resolved.
Show resolved Hide resolved

if not self.test_stationarity():
logger.warning(
"One time series component of the background time series is not stationary."
dumjax marked this conversation as resolved.
Show resolved Hide resolved
" Beware of wrong interpretation with chosen explainability."
)

@abstractmethod
def explain_from_input(
foreground_series: TimeSeries,
foreground_past_covariates: Optional[TimeSeries] = None,
foreground_future_covariates: Optional[TimeSeries] = None,
horizons: Optional[Sequence[int]] = None,
dumjax marked this conversation as resolved.
Show resolved Hide resolved
target_names: Optional[Sequence[str]] = None,
dumjax marked this conversation as resolved.
Show resolved Hide resolved
dumjax marked this conversation as resolved.
Show resolved Hide resolved
) -> Sequence[Sequence[TimeSeries]]:
"""
dumjax marked this conversation as resolved.
Show resolved Hide resolved
Return explanations values for each target and covariates lag, in a multivariate TimeSeries format.
Each timestamp of the foreground TimeSeries is explained in the output TimeSeries, with the following
notation:
`name`_`type_of_cov`_lag_`int`

Example:
Let's say we have a model with 2 targets (multivariates) names T_1 and T_2, one past covariate and one
future covariate. Also, n = 2 and past_step_explained = 2.

Then the function is supposed to return a dictionary time series, with for example
dumjax marked this conversation as resolved.
Show resolved Hide resolved

output[0]['T_1'] (but also output[1]['T_1'], output[0]['T_2'] and output[1]['T_2']a TimeSeries
with the following components:
- T_1_target_lag-1
- T_1_target_lag-2
- 0_past_cov_lag-1 (we didn't name the past covariate so it took the default name)
- 0_past_cov_lag-2
- 0_fut_cov_lag_0 (could be also lag_1 if output[1])


Parameters
----------
foreground_series
TimeSeries target we want to explain. Can be multivariate.
foreground_past_covariates
Optionally, past covariate timeseries if needed by model.
foreground_future_covariates
Optionally, future covariate timeseries if needed by model.
horizons
Optionally, a list of integer values representing which elements in the future
we want to explain, starting from the first timestamp prediction at 0.
For now we consider only models with output_chunk_length and it can't be bigger than
output_chunk_length.
If no input, then all elements of output_chunk_length will be explained.
target_names
Optionally, a list of strings naming the components of `foreground_series` we want to explain.
If no input, then all targets will be explained.

Returns
-------
a TimeSeries or dictionary of Timeseries of explaining values :
- each element of the first dictionary is corresponding to an horizon
- each element of the second layer dictionary is corresponding to a target
"""
pass

def test_stationarity(self):
return all(
[
stationarity_tests(self.background_series[c])
for c in self.background_series.components
]
)