Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xai init #909

Merged
merged 98 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
784cf13
xai init
dumjax Apr 13, 2022
5797010
small corrections
dumjax Apr 13, 2022
9fb8bb1
corrections and adds
dumjax Apr 19, 2022
efbb2f6
corrections
dumjax Apr 20, 2022
ce9de7a
linting corrections
dumjax Apr 20, 2022
16a1f57
lint_isort correct
dumjax Apr 20, 2022
5c8fd17
Merge branch 'master' into feat/xai
hrzn Jun 21, 2022
0c45389
Merge branch 'master' into feat/xai
hrzn Jun 30, 2022
a6cc229
small improvements
hrzn Jun 30, 2022
2f0fcf0
Explainability class corrections
dumjax Jul 20, 2022
25a6845
small changes + docstrings
dumjax Jul 28, 2022
ef960db
Shap method choice first adds
dumjax Aug 3, 2022
a4ca533
small adds
dumjax Aug 3, 2022
21787c7
explain new version
dumjax Aug 5, 2022
f1f7680
Merge branch 'master' into feat/xai
hrzn Aug 7, 2022
021e09e
plots and checks
dumjax Aug 9, 2022
62f6518
bugs on cov checks + Xy facto step 1
dumjax Aug 31, 2022
f2d471d
tabularization.py add
dumjax Aug 31, 2022
0f569a3
factor Xy part 2
dumjax Sep 1, 2022
7b1b040
Merge branch 'master' into feat/xai
dumjax Sep 6, 2022
f5139c6
lint
dumjax Sep 6, 2022
d6964f4
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0db6258
clean
dumjax Sep 6, 2022
4ff8108
clean
dumjax Sep 6, 2022
e227ce5
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
09f5ac8
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
cf55086
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0549a5e
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
e811731
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
41d4cd3
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
17aac2a
clean
dumjax Sep 6, 2022
86181cf
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
2661d01
Update darts/explainability/shap_explainer.py
dumjax Sep 6, 2022
00b62e0
clean
dumjax Sep 6, 2022
6947e35
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
7dc9488
corrections
dumjax Sep 7, 2022
5d6672a
encoders - facto corrections - others
dumjax Sep 12, 2022
be8e7a0
Merge branch 'master' into feat/xai
dumjax Sep 12, 2022
712eb57
timeseries_generation bugs + unit tests start
dumjax Sep 12, 2022
4827c75
corrections + Unit tests
dumjax Sep 13, 2022
0d4c978
corrections
dumjax Sep 13, 2022
1c60312
corrections tests multivariates
dumjax Sep 14, 2022
96ae80a
Merge branch 'master' into feat/xai
dumjax Sep 14, 2022
a04f779
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
f4078df
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
3ab2dd0
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
be0aabe
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4d85c5c
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
44767bd
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
488476a
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4201117
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
237a7cf
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
373fb12
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
39cb2c9
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a46ddcf
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1fc9ef8
docstrings corrections and others
dumjax Sep 22, 2022
49e4b53
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
e8694a2
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
5b528d8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a3fe9da
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1857a8e
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
19303ef
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
d640c65
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
213e670
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
40e97e8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
17bc7d4
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
ede4969
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1b19e14
corrections
dumjax Sep 22, 2022
db40e80
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
77e9c2a
corrections
dumjax Sep 27, 2022
86c0274
timeseries generation small correction
dumjax Sep 27, 2022
72bd897
corrections
dumjax Sep 27, 2022
e4eafcc
Merge branch 'master' into feat/xai
dumjax Sep 27, 2022
98823c0
create_lagged_data adds | unit testing | encoders
dumjax Sep 28, 2022
d8cf73a
correction bugs no covariate in explain()
dumjax Sep 29, 2022
f1e3c43
Merge branch 'master' into feat/xai
hrzn Sep 29, 2022
4534eed
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
23cf14e
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
babd744
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
e40712a
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4aba547
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4000e81
Update darts/utils/data/encoders.py
dumjax Sep 30, 2022
e9e7052
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
d32c02a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
2c34089
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
520846a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
dd7f200
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
fee5750
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
9a74e28
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
b60dcf4
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
31a8a3d
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
424d794
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
cca3bd8
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
ff4b13a
corrections
dumjax Oct 4, 2022
feb1e70
test comment
dumjax Oct 4, 2022
0dee794
Some doc fixes
hrzn Oct 4, 2022
b29ae68
More small fixes
hrzn Oct 4, 2022
4915d32
Last small fixes
hrzn Oct 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
171 changes: 171 additions & 0 deletions darts/explainability/explainability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""
Explainability Base Class
------------------------------
dumjax marked this conversation as resolved.
Show resolved Hide resolved

TODO
"""

from pyexpat import model
dumjax marked this conversation as resolved.
Show resolved Hide resolved
from xmlrpc.client import Boolean
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.utils.statistics import stationarity_tests

from darts import models
from typing import Optional, Tuple, Union, Any, Callable, Dict, List, Sequence
from abc import ABC, abstractmethod
import numpy as np
import pandas as pd

from ..timeseries import TimeSeries
from ..logging import get_logger, raise_log, raise_if_not, raise_if


logger = get_logger(__name__)


class ForecastingModelExplainer(ABC):

@abstractmethod
def __init__(
self,
model: ForecastingModel,
background_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
background_past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
background_future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
n: Optional[int]=1,
past_steps_explained: Optional[int] = 2
):
dumjax marked this conversation as resolved.
Show resolved Hide resolved

if not issubclass(type(model), RegressionModel):
dumjax marked this conversation as resolved.
Show resolved Hide resolved
raise_log(
ValueError('Invalid model type. For now, only RegressionModel type can be explained.'),
logger
)

if not model._fit_called:
raise_log(
ValueError('The model must be fitted before instantiating a ForecastingModelExplainer.'),
logger
)

if model._is_probabilistic():
dumjax marked this conversation as resolved.
Show resolved Hide resolved
# TODO: We can probably add explainability to probabilistic models, by taking the mean output.
raise_log(
ValueError('Explainability is only available for non-probabilistic models.'),
logger
)

self.model = model

if background_series is None:

raise_if(
(not background_past_covariates is None) or (not background_future_covariates is None),
"There is background past or future covariates but no background series. Please provide one.")

raise_if(
self.model.training_series is None,
"A background time series has to be provided for a model fitted on multiple time series, as"
"no training series has been saved by the model.")


self.background_series = self.model.training_series
hrzn marked this conversation as resolved.
Show resolved Hide resolved
self.background_past_covariates = self.model.past_covariate_series
hrzn marked this conversation as resolved.
Show resolved Hide resolved
self.background_future_covariates = self.model.future_covariate_series

self.target_names = self.background_series.columns
self.past_covariates_names = self.background_past_covariates.columns
dumjax marked this conversation as resolved.
Show resolved Hide resolved
self.future_covariates_names = self.background_future_covariates.columns

raise_if(
self.model._expect_past_covariates and self.background_past_covariates is None,
dumjax marked this conversation as resolved.
Show resolved Hide resolved
"A background past covariates is not provided, but the model needs past covariates.")

raise_if(
self.model._expect_future_covariates and self.background_future_covariates is None,
"A background future covariates is not provided, but the model needs future covariates.")

self.model = model
dumjax marked this conversation as resolved.
Show resolved Hide resolved
self.past_steps_explained = past_steps_explained
self.n = n

# For now we won't consider further time step that output_chunk_length, even though we could in
# theory explain anything in the future by auto regressive process.
if hasattr(self.model, 'output_chunk_length'):
self.n = self.model.output_chunk_length
dumjax marked this conversation as resolved.
Show resolved Hide resolved

if hasattr(self.model, 'input_chunk_length'):
dumjax marked this conversation as resolved.
Show resolved Hide resolved
self.past_steps_explained = self.model.input_chunk_length

if not self.test_stationarity():
logger.warning("One time series component of the background time series is not stationary."
"Beware of wrong interpretation with chosen explainability")

@abstractmethod
def explain_from_input(
foreground_series: TimeSeries,
foreground_past_covariates: Optional[TimeSeries],
dumjax marked this conversation as resolved.
Show resolved Hide resolved
foreground_future_covariates: Optional[TimeSeries],
horizons: Optional[Sequence[int]]=None,
target_names: Optional[Sequence[str]]=None
) -> Union[TimeSeries, Sequence[TimeSeries]]:
dumjax marked this conversation as resolved.
Show resolved Hide resolved
"""
dumjax marked this conversation as resolved.
Show resolved Hide resolved
Return explanations values for each target and covariates lag, in a multivariate TimeSeries format.
Each timestamp of the foreground TimeSeries is explained in the output TimeSeries, with the following
notation:
`name`_`type_of_cov`_lag_`int`

Example:
Let's say we have a model with 2 targets (multivariates) names T_1 and T_2, one past covariate and one
future covariate. Also, n = 2 and past_step_explained = 2.

Then the function is supposed to return a dictionary time series, with for example
dumjax marked this conversation as resolved.
Show resolved Hide resolved

output[0]['T_1'] (but also output[1]['T_1'], output[0]['T_2'] and output[1]['T_2']a TimeSeries
with the following components:
- T_1_target_lag-1
- T_1_target_lag-2
- 0_past_cov_lag-1 (we didn't name the past covariate so it took the default name)
- 0_past_cov_lag-2
- 0_fut_cov_lag_0 (could be also lag_1 if output[1])


Parameters
----------
foreground_series
TimeSeries target we want to explain. Can be multivariate.
foreground_past_covariates
Optionally, past covariate timeseries if needed by model.
foreground_future_covariates
Optionally, future covariate timeseries if needed by model.
horizons
Optionally, a list of integer values representing which elements in the future
we want to explain, starting from the first timestamp prediction at 0.
For now we consider only models with output_chunk_length and it can't be bigger than
output_chunk_length.
If no input, then all elements of output_chunk_length will be explained.
target_names
Optionally, a list of string values naming the targets we want to explain.
dumjax marked this conversation as resolved.
Show resolved Hide resolved
If no input, then all targets will be explained.

Returns
-------
a shap Explanation dictionary of dictionaries of shap Explanation objects:
- each element of the first dictionary is corresponding to an horizon
- each element of the second layer dictionary is corresponding to a target
"""
pass

def test_stationarity(self):
return all(
[stationarity_tests(self.background_series[c]) for c in self.background_series.components]
)