
Add TFTExplainer #1392

Merged (50 commits, Jul 31, 2023)
Conversation

@Cattes (Contributor) commented Nov 27, 2022

Implements explainability for the TFT model as discussed in #675

Summary

I have added a TFTExplainer class, similar to the ShapExplainer class, to provide explainability insights for the TFT model.

The class exposes the encoder_importance and decoder_importance of the trained TFT model (grouped together, with a plot option, in the get_variable_selection_weight method).

The TFTExplainer.explain() method calls the predict method of the trained TFT model to get the attention over time.

I have also provided a plot_attention_heads method to plot the attention over time, either as the average attention, as a heatmap, or as a detailed plot of all available attention heads.

I have added a section on how to use the class to the 13-TFT-example.ipynb notebook.
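
For illustration, a minimal sketch of how the pieces described above fit together (the exact signatures here are assumptions based on the description, not verbatim from the implementation; `series` is a placeholder for a training TimeSeries):

from darts.explainability import TFTExplainer
from darts.models import TFTModel

model = TFTModel(input_chunk_length=24, output_chunk_length=12)
model.fit(series)  # `series` is a placeholder

explainer = TFTExplainer(model)
# encoder/decoder variable selection weights, with an optional plot
explainer.get_variable_selection_weight(plot=True)
# attention over time; calls model.predict() internally
results = explainer.explain()
explainer.plot_attention_heads(results, plot_type="time")  # or "heatmap" / "all"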

Other Information

I followed the suggestions made by @hrzn in the issue:

from darts.models import TFTModel
from darts.explainability import TFTExplainer

my_model = TFTModel(...)
my_model.fit(...)

explainer = TFTExplainer(my_model, ...)

# get the explainability results
results = explainer.explain()

# plot variable selection weights
explainer.plot_variable_selection(results)

# plot the attention over time (three plot options)
explainer.plot_attention(results, plot_type="all")
explainer.plot_attention(results, plot_type="time")
explainer.plot_attention(results, plot_type="heatmap")

# get feature importance values
encoder_importance = results.get_encoder_importance()
decoder_importance = results.get_decoder_importance()
static_covariates_importance = results.get_static_covariates_importance()
# get attention `TimeSeries`
attention = results.get_attention()

The initial code for getting the details from the TFT class was provided by @MagMueller in the issue.

Edit (@dennisbader, 27-07-2023):

  • refactored ForecastingModelExplainer and ExplainabilityResult to simplify/unify the implementation of new explainers
  • explain() now calls model.predict() with the passed foreground/background series; the attention and feature importances actually depend on the input to predict/forward (they are not fixed after training). See the sketch after this list.
  • the plot methods now rely on the explain() output
  • added static covariates importance
  • updated the attention plots so that the x-axis zero point is where the prediction starts (the start of the output chunk)
  • added a colorbar to the attention heat map
  • fixed the TFTModel attention mask when full_attention=True
  • fixed an issue where encoded covariates were not properly stored in the TorchForecastingModel when training on a single series
  • adapted the tests to check combinations of input series (univariate, multivariate, multiple multivariate) with covariates, encoders, and the model creation parameter add_relative_index
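
A minimal sketch of the refactored flow (assuming the foreground_* argument names of the refactored ForecastingModelExplainer; `series` and the covariates are placeholders):

# explain() now forwards the foreground inputs to model.predict()
results = explainer.explain(
    foreground_series=series,
    foreground_past_covariates=past_covariates,
    foreground_future_covariates=future_covariates,
)
# the plot methods consume the explain() output
explainer.plot_attention(results, plot_type="heatmap")
# new: static covariates importance
static_importance = results.get_static_covariates_importance()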


codecov-commenter commented Nov 29, 2022

Codecov Report

Patch coverage: 94.54% and no project coverage change.

Comparison is base (d30f163) 93.84% compared to head (7e80088) 93.85%.


Additional details and impacted files
@@           Coverage Diff            @@
##           master    #1392    +/-   ##
========================================
  Coverage   93.84%   93.85%            
========================================
  Files         126      128     +2     
  Lines       12174    12416   +242     
========================================
+ Hits        11425    11653   +228     
- Misses        749      763    +14     
Files Changed Coverage Δ
darts/models/forecasting/forecasting_model.py 95.01% <80.00%> (-0.14%) ⬇️
darts/explainability/explainability_result.py 92.13% <90.00%> (-2.87%) ⬇️
darts/explainability/shap_explainer.py 87.69% <91.66%> (-1.66%) ⬇️
darts/explainability/tft_explainer.py 93.15% <93.15%> (ø)
darts/explainability/utils.py 96.66% <96.66%> (ø)
darts/explainability/__init__.py 100.00% <100.00%> (ø)
darts/explainability/explainability.py 100.00% <100.00%> (+3.57%) ⬆️
darts/models/forecasting/regression_model.py 95.36% <100.00%> (+0.03%) ⬆️
darts/models/forecasting/tft_model.py 97.27% <100.00%> (+0.50%) ⬆️
...arts/models/forecasting/torch_forecasting_model.py 90.64% <100.00%> (-0.05%) ⬇️
... and 1 more

... and 4 files with indirect coverage changes


@hrzn (Contributor) commented Dec 22, 2022

Thanks a lot for this PR @Cattes! Please bear with us... we are a bit slow to review at the moment (busy preparing the 0.23 release), but we'll get to it, and we're very enthusiastic about adding TFT explainability!

@hrzn (Contributor) left a review:

This looks very good @Cattes, I think it'll be a nice addition to Darts. Thanks a lot!
I have some comments, mainly related to the docstrings, but also one question about how we handle the horizons in the returned ExplainabilityResult: are you sure we should always consider only horizon 0 there? What about the case where the actual forecast horizon is larger?
When n > output_chunk_length, forward() will be called multiple times auto-regressively, so we have to be careful there. Maybe it would make sense to disregard n in explain() and return a result for each horizon in the output_chunk_length? I might also be missing something.

Review threads (outdated, resolved): darts/timeseries.py, darts/explainability/tft_explainer.py
        self._model = model

    @property
    def encoder_importance(self):
Contributor commented:

Could you add docstrings explaining what this and decoder_importance are returning? They can be quite useful I think.

Contributor Author commented:

I have added docstrings to the module and properties. I wasn't 100% sure about the details of the model, so it would be great if you could have a look. If everything is fine, you can resolve this conversation.

A user commented:

Hey, I was trying to call this function and ran into an error: it says there is no attribute 'encoder_sparse_weights'. I went to tft_model.py and uncommented this code chunk:

return self.to_network_output(
    prediction=self.transform_output(out, target_scale=x["target_scale"]),
    attention=attn_out_weights,
    static_variables=static_covariate_var,
    encoder_variables=encoder_sparse_weights,
    decoder_variables=decoder_sparse_weights,
    decoder_lengths=decoder_lengths,
    encoder_lengths=encoder_lengths,
)

It now says TFTModule has no attribute called to_network_output.

Can I get some help regarding how to call the explainer and use it in my code?

"decoder_importance": self.decoder_importance,
}

def explain(self, **kwargs) -> ExplainabilityResult:
Contributor commented:

How about taking some of the predict() parameters explicitly? At least series, past_covariates, future_covariates and n would make sense IMO. It will produce more comprehensible API documentation.

Contributor Author commented:

I am not sure that is relevant here at all.

I do not understand why predict has to be called to get the proper attention heads of the time series. The learned autoregressive connections should not depend on how predict is called. But if predict is not called at all, the attention_heads saved in self._model.model._attn_out_weights do not have the right format. I assume they are still in a state from training, and the predict() call changes that.

If that is the case, I would rather remove the **kwargs completely from the explain method here and call predict once with n set to self._model.model.output_chunk_length to get the correct attention heads.
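
A sketch of that alternative (illustrative only, reusing the attribute paths mentioned above; the stub body is not the actual implementation):

def explain(self) -> ExplainabilityResult:
    # call predict() once, purely to populate the attention weights;
    # the user does not need to pass any arguments
    self._model.predict(n=self._model.model.output_chunk_length)
    attention_weights = self._model.model._attn_out_weights
    ...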

@hrzn (Contributor) commented on Feb 26, 2023:

Yes, I agree with you that we need to call predict() here. However, predict() takes many more arguments than just n: it takes series (the series to predict), as well as covariates arguments and others; see the API doc.
I think we should probably change the signature of explain() to something like

def explain(self, series, past_covariates, future_covariates, **kwargs) -> ExplainabilityResult

This way you can list series, past_covariates and future_covariates in the docstring, and explain that they are passed down to predict(). You can also say that n will always be set to output_chunk_length (unless I'm wrong, I think that's always what's needed), and that **kwargs can contain extra arguments for the predict method, linking to the API documentation of TFTModel.predict().
I hope it makes sense.
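
Spelled out, the suggestion could look something like this (a sketch of the proposal, not the final implementation):

def explain(
    self,
    series=None,
    past_covariates=None,
    future_covariates=None,
    **kwargs,  # extra arguments forwarded to TFTModel.predict()
) -> ExplainabilityResult:
    # n is always set to the model's output_chunk_length internally
    ...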

Contributor Author commented:

I think calling predict() is just a technicality to get the correct attention weights. I don't think the way we call predict matters for the result at all; it's just important that it was called (for whatever reason).
If I understand it correctly, the attention weights are learned during training and are not impacted by the data used in the predict call.

They don't have a logic behind them similar to Shapley values; they are learned during training and are a fixed part of the trained model.

Maybe I am wrong, but if I am right, I would rather remove all parameters passed to explain() and have the predict() call happen without the user needing to know about it at all.

Review thread (outdated, resolved): darts/explainability/tft_explainer.py
# return the explainer result to be used in other methods
return ExplainabilityResult(
    {
        0: {
Contributor commented:

Is this always relating to horizon 0 only? How about the cases where predict() is called with n > 1 above?

Contributor Author commented:

I had to set the 0 here to be compatible with the ForecastingModelExplainer base class. To get the attention_heads, the predict method of the TFT class has to be called, otherwise the attention_heads will not show the correct values; I am not sure why yet. Placing this logic into the explain() method and returning it as the ExplainabilityResult felt like a sensible choice.
We could deviate from the ForecastingModelExplainer class, or add a note to the docstring that the 0 is irrelevant in this context.

Contributor commented:

So if I follow well, the explanation here is for all forecasted horizons at once, right?
I would then propose the following: we can adapt the ExplainabilityResult class to make it a little more flexible:

  • it could be used with one explanation per horizon (as now), or
  • with one single explanation for all horizons (as required in this case for the TFT).

To accommodate the second case, we could make it possible to build ExplainabilityResult with only a Dict[str, TimeSeries] (in addition to Dict[int, Dict[str, TimeSeries]]), so we avoid specifying the horizon. We can also adapt ExplainabilityResult.get_explanation() to make specifying the horizon optional, and unsupported if the underlying explanation is not split by horizon (see the sketch below).

WDYT? I would find this cleaner than "hacking" the class by using a fake horizon 0.
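
A minimal sketch of the two construction modes being proposed (the dict keys, the import path, and the optional-horizon call are illustrative of the proposal, not the final code):

import numpy as np
from darts import TimeSeries
from darts.explainability import ExplainabilityResult

ts = TimeSeries.from_values(np.arange(12.0))  # dummy explanation series

# one explanation per horizon (the current behavior)
res = ExplainabilityResult({1: {"target": ts}, 2: {"target": ts}})
res.get_explanation(horizon=1, component="target")

# one single explanation for all horizons (as required for the TFT)
res = ExplainabilityResult({"encoder_importance": ts})
res.get_explanation(component="encoder_importance")  # horizon becomes optional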

Contributor commented:

@Cattes any thoughts on this? ^

Contributor Author commented:

I think it's a good idea to change the class to handle the TFT explainer; I didn't want to do it before discussing it. The hack with horizon=0 was just to conform to the given API; it was not an intuitive solution.
I have added Dict[str, TimeSeries] to the valid types for class initialization and made the horizon optional.
I also added a few more validations to deal with the different input types explicitly.

Review threads (outdated, resolved): darts/explainability/tft_explainer.py, darts/timeseries.py
@hrzn (Contributor) commented Jan 24, 2023

@Cattes it seems there's a linting issue preventing the tests from being run.

@dennisbader (Collaborator) commented Jul 11, 2023

Yes, we did some refactoring of the explainability module a couple of weeks back.
I also noticed some things while experimenting a bit with the TFTExplainer:

  • TFTModel's attention head actually depends on the input. This is why you got different results for the attention head when calling predict() for the first time. So, in my opinion, we do need the explain() foreground arguments.
  • we need to add support for interpretable static covariate variable selection/importance

I can implement these things :)

@dennisbader (Collaborator) commented Jul 27, 2023

Hi @Cattes, I have now finished the adaptations for the TFTExplainer. The refactor got quite big because I took the time to rework the ForecastingModelExplainer and ExplainabilityResult backbone, to make it easier to implement new explainers in the future.

I also caught and fixed one or two bugs along the way.

I updated the PR description with the points I added/adapted.

Let me know if you want to go over the changes and/or if you're okay with the new version :)

Thanks for this great PR and sorry again for the time it took us to get this through 🚀
This will be part of the next release which comes in one to two weeks 💯

@madtoinou (Collaborator) left a review:

Nice refactoring of the explainability module, and a really cool feature (awaited by a lot of users)!

Some minor comments; this is almost ready for merge 🚀

Review threads (all resolved):
darts/explainability/explainability.py (outdated, two threads)
darts/explainability/explainability_result.py (outdated)
darts/explainability/tft_explainer.py (one thread outdated)
darts/explainability/utils.py (outdated, two threads)
darts/models/forecasting/tft_model.py
darts/timeseries.py
@dennisbader merged commit 071c7e8 into unit8co:master on Jul 31, 2023
9 checks passed
@Cattes (Contributor Author) commented Jul 31, 2023

Thank you @dennisbader and @madtoinou for finishing up the PR! Sorry I could not take a look at it before the merge; I was on holiday.
