[BUG] ShapExplainer + static covariates #1450

Closed
alexanderlange53 opened this issue Dec 28, 2022 · 2 comments
Labels
improvement New feature or improvement

Comments

alexanderlange53 commented Dec 28, 2022

Describe the bug
Hi all,
the ShapExplainer raises the following error when the model was trained on series that carry static covariates:
The number of features in data (10) is not the same as it was in training data (11)

To Reproduce

import numpy as np
import pandas as pd
from darts.explainability import ShapExplainer

from darts import TimeSeries
from darts.models import LightGBMModel
from darts.utils import timeseries_generation as tg

period = 20
sine_series = tg.sine_timeseries(
    length=4 * period, value_frequency=1 / period, column_name="series", freq="h"
)

sine_vals = sine_series.values()
linear_vals = np.expand_dims(np.linspace(1, -1, num=19), -1)

sine_vals[21:40] = linear_vals
sine_vals[61:80] = linear_vals
irregular_series = TimeSeries.from_times_and_values(
    values=sine_vals, times=sine_series.time_index, columns=["series"]
)

def get_model_params_2():
    """helper function that generates model parameters"""
    return {
        "lags": int(period / 2),
        "output_chunk_length": int(period / 2),
        "add_encoders": {
            "datetime_attribute": {"future": ["hour"]}
        },
    }

sine_series_st_bin = sine_series.with_static_covariates(
    pd.DataFrame(data={"curve_type": [1]})
)
irregular_series_st_bin = irregular_series.with_static_covariates(
    pd.DataFrame(data={"curve_type": [0]})
)

train_series = [sine_series_st_bin, irregular_series_st_bin]
for series in train_series:
    print(series.static_covariates)

model = LightGBMModel(**get_model_params_2())

model.fit(train_series)
explainer = ShapExplainer(model, background_series=train_series)
explainer.summary_plot()

System:

  • Python version: 3.8.10
  • darts version: 0.23

Additional context
The error message says that the number of features in the data passed to the ShapExplainer does not match the number of features in the data used to train the model, even though the same series are used for both. The example uses LightGBM, but I get the same error with other models.
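A minimal sketch of how the 11 vs 10 counts can arise, under two assumptions: the regression model's tabular input is one column per target lag (10 here, from lags=int(period / 2)) plus one column per static covariate (curve_type), and the hour encoder adds no columns because no lags_future_covariates are set. The helper make_lagged_features is hypothetical and is not the darts implementation; it only illustrates that a table built without the static covariate column is one column short of the training table.

```python
import numpy as np


def make_lagged_features(values, n_lags, static_covs=None):
    """Hypothetical helper (NOT darts code): one row per target step,
    holding the n_lags previous values, optionally followed by the
    series' static covariates (assumed layout, for illustration only)."""
    rows = []
    for t in range(n_lags, len(values)):
        row = list(values[t - n_lags:t])
        if static_covs is not None:
            # Static covariates are constant, so they repeat on every row.
            row.extend(static_covs)
        rows.append(row)
    return np.asarray(rows, dtype=float)


period = 20
values = np.sin(2 * np.pi * np.arange(4 * period) / period)

# Training table: 10 lag columns + 1 static covariate column (curve_type = 1).
train_X = make_lagged_features(values, n_lags=period // 2, static_covs=[1])
# Background table built without static covariates: only the 10 lag columns.
background_X = make_lagged_features(values, n_lags=period // 2)

print(train_X.shape[1], background_X.shape[1])
```

Under these assumptions the two tables differ by exactly the static covariate column, which matches the counts reported in the error.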

alexanderlange53 added the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Dec 28, 2022
dennisbader (Collaborator) commented

@alexanderlange53, thanks for raising that. Unfortunately, ShapExplainer does not yet support static covariates.
As a workaround, removing the static covariates from your input series before training the model should work.

I opened ticket #1457 for that.

hrzn added the improvement (New feature or improvement) label and removed the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Jan 4, 2023
madtoinou (Collaborator) commented

Fixed by #1803.
