xai init #909

Merged
merged 98 commits into from
Oct 4, 2022
Changes from 16 commits
Commits
784cf13
xai init
dumjax Apr 13, 2022
5797010
small corrections
dumjax Apr 13, 2022
9fb8bb1
corrections and adds
dumjax Apr 19, 2022
efbb2f6
corrections
dumjax Apr 20, 2022
ce9de7a
linting corrections
dumjax Apr 20, 2022
16a1f57
lint_isort correct
dumjax Apr 20, 2022
5c8fd17
Merge branch 'master' into feat/xai
hrzn Jun 21, 2022
0c45389
Merge branch 'master' into feat/xai
hrzn Jun 30, 2022
a6cc229
small improvements
hrzn Jun 30, 2022
2f0fcf0
Explainability class corrections
dumjax Jul 20, 2022
25a6845
small changes + docstrings
dumjax Jul 28, 2022
ef960db
Shap method choice first adds
dumjax Aug 3, 2022
a4ca533
small adds
dumjax Aug 3, 2022
21787c7
explain new version
dumjax Aug 5, 2022
f1f7680
Merge branch 'master' into feat/xai
hrzn Aug 7, 2022
021e09e
plots and checks
dumjax Aug 9, 2022
62f6518
bugs on cov checks + Xy facto step 1
dumjax Aug 31, 2022
f2d471d
tabularization.py add
dumjax Aug 31, 2022
0f569a3
factor Xy part 2
dumjax Sep 1, 2022
7b1b040
Merge branch 'master' into feat/xai
dumjax Sep 6, 2022
f5139c6
lint
dumjax Sep 6, 2022
d6964f4
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0db6258
clean
dumjax Sep 6, 2022
4ff8108
clean
dumjax Sep 6, 2022
e227ce5
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
09f5ac8
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
cf55086
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
0549a5e
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
e811731
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
41d4cd3
Update darts/explainability/explainability.py
dumjax Sep 6, 2022
17aac2a
clean
dumjax Sep 6, 2022
86181cf
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
2661d01
Update darts/explainability/shap_explainer.py
dumjax Sep 6, 2022
00b62e0
clean
dumjax Sep 6, 2022
6947e35
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 6, 2022
7dc9488
corrections
dumjax Sep 7, 2022
5d6672a
encoders - facto corrections - others
dumjax Sep 12, 2022
be8e7a0
Merge branch 'master' into feat/xai
dumjax Sep 12, 2022
712eb57
timeseries_generation bugs + unit tests start
dumjax Sep 12, 2022
4827c75
corrections + Unit tests
dumjax Sep 13, 2022
0d4c978
corrections
dumjax Sep 13, 2022
1c60312
corrections tests multivariates
dumjax Sep 14, 2022
96ae80a
Merge branch 'master' into feat/xai
dumjax Sep 14, 2022
a04f779
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
f4078df
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
3ab2dd0
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
be0aabe
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4d85c5c
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
44767bd
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
488476a
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
4201117
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
237a7cf
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
373fb12
Update darts/explainability/explainability.py
dumjax Sep 22, 2022
39cb2c9
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a46ddcf
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1fc9ef8
docstrings corrections and others
dumjax Sep 22, 2022
49e4b53
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
e8694a2
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
5b528d8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
a3fe9da
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1857a8e
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
19303ef
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
d640c65
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
213e670
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
40e97e8
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
17bc7d4
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
ede4969
Update darts/explainability/shap_explainer.py
dumjax Sep 22, 2022
1b19e14
corrections
dumjax Sep 22, 2022
db40e80
Merge branch 'feat/xai' of github.com:unit8co/darts into feat/xai
dumjax Sep 22, 2022
77e9c2a
corrections
dumjax Sep 27, 2022
86c0274
timeseries generation small correction
dumjax Sep 27, 2022
72bd897
corrections
dumjax Sep 27, 2022
e4eafcc
Merge branch 'master' into feat/xai
dumjax Sep 27, 2022
98823c0
create_lagged_data adds | unit testing | encoders
dumjax Sep 28, 2022
d8cf73a
correction bugs no covariate in explain()
dumjax Sep 29, 2022
f1e3c43
Merge branch 'master' into feat/xai
hrzn Sep 29, 2022
4534eed
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
23cf14e
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
babd744
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
e40712a
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4aba547
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
4000e81
Update darts/utils/data/encoders.py
dumjax Sep 30, 2022
e9e7052
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
d32c02a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
2c34089
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
520846a
Update darts/explainability/explainability.py
dumjax Sep 30, 2022
dd7f200
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
fee5750
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
9a74e28
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
b60dcf4
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
31a8a3d
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
424d794
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
cca3bd8
Update darts/explainability/shap_explainer.py
dumjax Sep 30, 2022
ff4b13a
corrections
dumjax Oct 4, 2022
feb1e70
test comment
dumjax Oct 4, 2022
0dee794
Some doc fixes
hrzn Oct 4, 2022
b29ae68
More small fixes
hrzn Oct 4, 2022
4915d32
Last small fixes
hrzn Oct 4, 2022
1 change: 1 addition & 0 deletions darts/explainability/__init__.py
@@ -0,0 +1 @@
from darts.explainability.shap_explainer import ShapExplainer
305 changes: 305 additions & 0 deletions darts/explainability/explainability.py
@@ -0,0 +1,305 @@
"""
Forecasting Model Explainer Base Class
--------------------------------------
A forecasting model explainer takes an already fitted forecasting model and applies an explainability model
to it. Its purpose is to explain the contribution of each past input to a given model forecast.
This 'explanation' depends on the characteristics of the chosen XAI model (SHAP, LIME, etc.).

"""
from abc import ABC, abstractmethod
from math import inf
from typing import Dict, Optional, Sequence, Union

from numpy import integer

from darts import TimeSeries
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.models.forecasting.torch_forecasting_model import TorchForecastingModel
from darts.utils.statistics import stationarity_tests

logger = get_logger(__name__)

MIN_BACKGROUND_SAMPLE = 10


class ForecastingModelExplainer(ABC):
    @abstractmethod
    def __init__(
        self,
        model: ForecastingModel,
        background_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        background_past_covariates: Optional[
            Union[TimeSeries, Sequence[TimeSeries]]
        ] = None,
        background_future_covariates: Optional[
            Union[TimeSeries, Sequence[TimeSeries]]
        ] = None,
    ):
        """The base class for forecasting model explainers. It defines the *minimal* behavior that all
        forecasting model explainers support.

        Naming:

        - A background time series is a time series with which we 'train' the Explainer model.
        - A foreground time series is the time series we will explain according to the fitted Explainer model.

        Parameters
        ----------
        model
            A ForecastingModel we want to explain. It has to be fitted first.
        background_series
            A TimeSeries or a list of TimeSeries used to 'train' the Explainer, whatever foreground we later
            want to explain.
            This is optional, for 2 reasons:

            - In general we want to keep the training series of the model, and this is the default one,
              but when the model is trained on multiple time series (global or meta learning) the
              ForecastingModel doesn't save them. In this case we need to feed a background time series.
            - We might want to consider a reduced, well-chosen background in order to reduce computation
              time.
        background_past_covariates
            A past covariates TimeSeries or list of TimeSeries that the model needs once fitted.
        background_future_covariates
            A future covariates TimeSeries or list of TimeSeries that the model needs once fitted.
        """
        if not model._fit_called:
            raise_log(
                ValueError(
                    "The model must be fitted before instantiating a ForecastingModelExplainer."
                ),
                logger,
            )

        if model._is_probabilistic():
            logger.warning(
                "The model is probabilistic, but num_samples=1 will be used for explainability."
            )

        self.model = model

        # If no specific background time series is given to fit the Explainer, we use the one
        # already stored in the fitted model.
        if background_series is None:

            raise_if(
                (background_past_covariates is not None)
                or (background_future_covariates is not None),
                "Background past or future covariates were provided without a background series. "
                "Please provide one.",
            )

            raise_if(
                self.model.training_series is None,
                "A background time series has to be provided for a model fitted on multiple time series, as "
                "no training series has been saved by the model.",
            )

            self.background_series = self.model.training_series
            self.background_past_covariates = self.model.past_covariate_series
            self.background_future_covariates = self.model.future_covariate_series

        else:

            self.background_series = background_series
            self.background_past_covariates = background_past_covariates
            self.background_future_covariates = background_future_covariates

        # ensure list of TimeSeries format
        if isinstance(self.background_series, TimeSeries):
            self.background_series = [self.background_series]
            self.background_past_covariates = (
                [self.background_past_covariates]
                if self.background_past_covariates is not None
                else None
            )
            self.background_future_covariates = (
                [self.background_future_covariates]
                if self.background_future_covariates is not None
                else None
            )

        raise_if(
            self.model.uses_past_covariates and self.background_past_covariates is None,
            "Background past covariates were not provided, but the model needs past covariates.",
        )

        raise_if(
            self.model.uses_future_covariates
            and self.background_future_covariates is None,
            "Background future covariates were not provided, but the model needs future covariates.",
        )

        self.target_names = self.background_series[0].columns.to_list()
        if self.background_past_covariates is not None:
            self.past_covariates_names = self.background_past_covariates[
                0
            ].columns.to_list()
        if self.background_future_covariates is not None:
            self.future_covariates_names = self.background_future_covariates[
                0
            ].columns.to_list()

        self._check_background_covariates()

        if not self._test_stationarity():
            logger.warning(
                "At least one component of the background time series is not stationary."
                " Beware of wrong interpretation with chosen explainability."
            )

    def _check_background_covariates(self):

        if isinstance(self.model, RegressionModel):
            len_target_min = (
                len(self.model.lags["target"]) if self.model.lags["target"] else 0
            )
            len_past_min = (
                len(self.model.lags["past"]) if self.model.lags["past"] else 0
            )
            len_future_min = (
                len(self.model.lags["future"]) if self.model.lags["future"] else 0
            )
            min_length = max(len_target_min, len_past_min, len_future_min)

        elif isinstance(self.model, TorchForecastingModel):
            min_length = self.model.input_chunk_length
        else:
            min_length = inf

        # ensure we have the same names between TimeSeries (if list of). Important to ensure homogeneity
        # for explained features.
        for idx in range(len(self.background_series)):
            raise_if_not(
                all(
                    [
                        self.background_series[idx].columns.to_list()
                        == self.target_names,
                        self.background_past_covariates[idx].columns.to_list()
                        == self.past_covariates_names
                        if self.background_past_covariates is not None
                        else True,
                        self.background_future_covariates[idx].columns.to_list()
                        == self.future_covariates_names
                        if self.background_future_covariates is not None
                        else True,
                    ]
                ),
                "Column names must be identical between TimeSeries list components (multi-TimeSeries).",
            )

        # the number of samples we will build for explanation is:
        # sum(len(intersection(target, fut_cov, past_cov)) - min_length + 1). We compare this to a fixed constant min.
Contributor:

I don't really understand this computation for the number of samples.
Random idea: wouldn't it be simpler to first try and create the dataset, and then only if it is too small raise an error?

Contributor Author:

The first goal was to check that the time indexes of each tuple (series, past covariates, future covariates) roughly correspond, since the tabularization function doesn't perform such checks on the time index. On top of that, I introduced MIN_BACKGROUND_SAMPLE at this level, independently of the kind of model (regression model, torch) or type of explainability we use. Otherwise I would have to check after each of the different transformations required by each model.

        nb_background_samples = 0
        for idx in range(len(self.background_series)):
            nb_background_samples += max(
                len(
                    self.background_series[idx].time_index.intersection(
                        self.background_past_covariates[idx].time_index.intersection(

Contributor:

I guess this would crash too if we don't have past or future covariates?

                            self.background_future_covariates[idx].time_index
                        )
                    )
                )
                - min_length
                + 1,
                0,
            )
        raise_if(
            nb_background_samples <= MIN_BACKGROUND_SAMPLE,
            "The number of samples for the background series is too small.",
        )
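
The two review comments above concern how the background sample count is derived and what happens when covariates are absent. A minimal sketch of a count that tolerates missing covariates follows. It uses plain Python sets of timestamps as stand-ins for `TimeSeries.time_index`, and the helper name `count_background_samples` and its arguments are hypothetical, not part of this PR:

```python
from typing import Iterable, Optional, Sequence


def count_background_samples(
    targets: Sequence[Iterable[int]],
    past_covs: Optional[Sequence[Iterable[int]]],
    future_covs: Optional[Sequence[Iterable[int]]],
    min_length: int,
) -> int:
    """Sum, over all series, of max(len(common time index) - min_length + 1, 0).

    Each series is represented only by its timestamps. Covariates that are
    None are skipped instead of being intersected.
    """
    total = 0
    for idx, target in enumerate(targets):
        common = set(target)
        for covs in (past_covs, future_covs):
            if covs is not None:
                common &= set(covs[idx])
        total += max(len(common) - min_length + 1, 0)
    return total


# One target series with timestamps 0..19, past covariates 5..24, no future covariates.
n_samples = count_background_samples([range(20)], [range(5, 25)], None, min_length=3)
print(n_samples)  # 15 shared timestamps (5..19), so 15 - 3 + 1 = 13
```

The same guard could be expressed on the real time indexes by only calling `intersection` for covariates that are not `None`; whether that is how the PR eventually resolved the comment is not shown here.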

    @abstractmethod
    def explain(
        self,
        foreground_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        foreground_past_covariates: Optional[
            Union[TimeSeries, Sequence[TimeSeries]]
        ] = None,
        foreground_future_covariates: Optional[
            Union[TimeSeries, Sequence[TimeSeries]]
        ] = None,
        horizons: Optional[Sequence[int]] = None,
        target_names: Optional[Sequence[str]] = None,
    ) -> Union[
        Dict[integer, Dict[str, TimeSeries]],
        Sequence[Dict[integer, Dict[str, TimeSeries]]],
    ]:
        """
        Main method of the ForecastingModelExplainer class.
        Returns a dictionary of dictionaries of (multivariate) TimeSeries instances
        (or a list of such dictionaries of dictionaries, if a list of TimeSeries is provided):

        - the first dimension corresponds to the horizons we want to explain.
        - the second dimension corresponds to the components of the target time series we want to explain.

        The value of the second dimension dictionary is a (multivariate) TimeSeries instance giving the
        'explanation' for a given forecast (horizon, target) at every forecastable timestamp of the foreground
        TimeSeries input.

        The naming convention for each component of this multivariate TimeSeries is:
        `name`_`type_of_cov`_lag_`int` where:

        - `name` is the existing name of the component in the original foreground TimeSeries (target, past
          or future).
        - `type_of_cov` is the type of covariates. It can take 3 different values: `target`, `past`, `future`.
        - `int` is the lag index.

        Example:
        Let's say we have a model with 2 targets (multivariate) named "T_0" and "T_1", three past covariates we
        didn't name and one future covariate we didn't name. Also, horizons = [0, 1].
        The model is a regression model, with lags = 3, lags_past_covariates=[-1, -3], lags_future_covariates = [0]

        We provide a foreground_series (not a list), past covariates, future covariates, of length 5.

        Then the output will be the following:

        output[0]['T_1'] a multivariate TimeSeries containing the 'explanations' of the chosen Explainer, with
        component names:
            - T_0_target_lag-1
            - T_0_target_lag-2
            - T_0_target_lag-3
            - T_1_target_lag-1
            - T_1_target_lag-2
            - T_1_target_lag-3
            - 0_past_cov_lag-1 (we didn't name the past covariate so it took the default name 0)
            - 0_past_cov_lag-3 (we didn't name the past covariate so it took the default name 0)
            - 1_past_cov_lag-1 (we didn't name the past covariate so it took the default name 1)
            - 1_past_cov_lag-3 (we didn't name the past covariate so it took the default name 1)
            - 2_past_cov_lag-1 (we didn't name the past covariate so it took the default name 2)
            - 2_past_cov_lag-3 (we didn't name the past covariate so it took the default name 2)
            - 0_fut_cov_lag_0 (we didn't name the future covariate so it took the default name 0)

        of length 3, as we can explain 5-3+1 forecasts (basically timestamp indexes 4, 5, and 6)


        Parameters
        ----------
        foreground_series
            Optionally, the target TimeSeries we want to explain. Can be multivariate.
            If none is provided, explain will automatically explain the whole background TimeSeries.
        foreground_past_covariates
            Optionally, past covariate TimeSeries if needed by the ForecastingModel.
        foreground_future_covariates
            Optionally, future covariate TimeSeries if needed by the ForecastingModel.
        horizons
            Optionally, a list of integer values representing which elements in the future
            we want to explain, starting from the first timestamp prediction at 0.
            For now, only models with an output_chunk_length are supported, and a horizon cannot be bigger
            than output_chunk_length.
        target_names
            Optionally, a list of strings naming the target components we want to explain.

        Returns
        -------
        A dictionary of dictionaries of TimeSeries (or a list of such) containing the explanation values:

        - each element of the first dimension dictionary corresponds to a forecast horizon
        - each element of the second dimension dictionary corresponds to a target name
        """
        pass

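    def _test_stationarity(self):
        # Flatten the loop: a list of generator objects is always truthy, so the test must be
        # evaluated per component of each background series.
        return all(
            stationarity_tests(background_serie[c])
            for background_serie in self.background_series
            for c in background_serie.components
        )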
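
The worked example in the `explain` docstring (2 targets, 3 past covariates, 1 future covariate, a foreground of length 5) can be checked with a little arithmetic. The sketch below assumes one explanation component per (component, lag) pair and mirrors the docstring's own `5 - 3 + 1` count. The function `explanation_shape` is hypothetical and not part of darts:

```python
from typing import Sequence, Tuple


def explanation_shape(
    series_length: int,
    n_targets: int,
    target_lags: Sequence[int],
    n_past: int,
    past_lags: Sequence[int],
    n_future: int,
    future_lags: Sequence[int],
) -> Tuple[int, int]:
    """Return (number of explainable forecasts, number of explanation components)."""
    # Deepest look-back among the negative target and past-covariate lags.
    history = max(
        (abs(lag) for lag in [*target_lags, *past_lags] if lag < 0), default=0
    )
    n_forecasts = max(series_length - history + 1, 0)
    # One component per (series component, lag) pair.
    n_components = (
        n_targets * len(target_lags)
        + n_past * len(past_lags)
        + n_future * len(future_lags)
    )
    return n_forecasts, n_components


# lags = 3 expands to target lags -1, -2, -3 in the docstring example.
shape = explanation_shape(5, 2, [-1, -2, -3], 3, [-1, -3], 1, [0])
print(shape)  # (3, 13)
```

The 13 components match the 13 names listed in the docstring (6 target, 6 past covariate, 1 future covariate), and the 3 rows match the stated length of the explanation series.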