-
Notifications
You must be signed in to change notification settings - Fork 879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TFTExplainer #1392
Add TFTExplainer #1392
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov ReportPatch coverage:
❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more. Additional details and impacted files@@ Coverage Diff @@
## master #1392 +/- ##
========================================
Coverage 93.84% 93.85%
========================================
Files 126 128 +2
Lines 12174 12416 +242
========================================
+ Hits 11425 11653 +228
- Misses 749 763 +14
☔ View full report in Codecov by Sentry. |
Thanks a lot for this PR @Cattes ! Please bear with us... we are a bit slow to review at the moment (busy preparing the release of 0.23), but we'll get at it, and we feel very enthusiastic about adding TFT explainability! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks very good @Cattes , I think it'll be a nice addition to Darts. Thanks a lot!
I have some comments, mainly related to the docstrings, but also one question about how we handle the horizons in the returned ExplainabilityResult
- are you sure we should always consider horizon 0 only there? What about the case where the actual forecast horizon is larger?
When n > output_chunk_length
, forward()
will be called multiple times auto-regressively, so we have to be careful there. Maybe it'd make sense to disregard n
in explain()
, and return some result for each horizon in the output_chunk_length
? I might also be missing something.
self._model = model | ||
|
||
@property | ||
def encoder_importance(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add docstrings explaining what this and decoder_importance
are returning? They can be quite useful I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added docstrings to the module and properties. I wasn't 100% sure on the details of the model so if you could have a look at it that would be great? If everything is fine you can resolve this conversation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, I was trying to call this function and ran into an error. it says no attribute 'encoder_sparse_weights'. I went to the tft_model.py and uncommented this code chunk :
return self.to_network_output(
prediction=self.transform_output(out, target_scale=x["target_scale"]),
attention=attn_out_weights,
static_variables=static_covariate_var,
encoder_variables=encoder_sparse_weights,
decoder_variables=decoder_sparse_weights,
decoder_lengths=decoder_lengths,
encoder_lengths=encoder_lengths,
)
It now says TFTModule has no attribute called to_network_output.
Can I get some help regarding how to call the explainer and use it in my code?
"decoder_importance": self.decoder_importance, | ||
} | ||
|
||
def explain(self, **kwargs) -> ExplainabilityResult: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about taking some of the predict()
parameters explicitly? At least series
, past_covariates
, future_covariates
and n
would make sense IMO. It will produce more comprehensible API documentation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if that is relevant here at all.
I do not understand why predict has to be called to get the proper attention heads of the time series. The learned autoregressive connections should depend on how predict
is called. But if predict
is not called at all the attention_heads
saved in self._model.model._attn_out_weights
do not have the right format. I assume they are still in a state of training and the predict()
call changes that.
If that is the case I would rather remove the **kwargs
completely from the explain
method here and call predict once with self._model.model.output_chunk_length
to get the correct attention heads.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I agree with you we need to call predict()
here. However predict()
takes a lot more arguments than just n
. It takes series
(the series to predict), as well as covariates arguments and other arguments: see the API doc.
I think we should probably change the signature of explain()
to something like
def explain(self, series, past_covariates, future_covariates, **kwargs) -> ExplainabilityResult
This way in the docstring you can list series
, past_covariates
and future_covariates
, and explain that those are passed down to predict()
. You can also say that n
will always be set to output_chunk_length
(unless I'm wrong I think that's always what's needed), and that **kwargs
can contain extra arguments for the predict method and link to the API documentation of TFTModel.predict()
.
I hope it makes sense.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think calling predict()
is just a technicality to get to the correct attention weights. I don't think the way we call predict
matters at all for the result, its just important that it was called (for whatever reason).
If I understand it correctly the attention weights are learned during training and are not impacted by the data used in the predict
call.
They don't have a similar logic behind them like shapley values but are learned during the training and are a fixed part of the trained model.
Maybe I am wrong, but if I am right I would rather remove all parameter passed to explain()
and have the predict()
call happen without the user needing to know about it at all.
# return the explainer result to be used in other methods | ||
return ExplainabilityResult( | ||
{ | ||
0: { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this always relating to horizon 0 only? How about the cases where predict()
is called with n > 1
above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to set the 0
here to be compatible with the ForecastingModelExplainer
base class. To get the attention_heads
the predict
method of the TFT class has to be called or the attention_heads will not show the correct values. I am not sure why yet. Placing this logic into the explain() method as the ExplainabilityResult
felt like a sensible choice.
We could deviate from the ForecastingModelExplainer
class or add a note to the docstring that the 0
is irrelevant in this context.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So if I follow well, here the explanation is for all forecasted horizons at once, right?
I would then propose the following. We can adapt the class ExplainabilityResult
in order to make it a little bit more flexible:
- It could be used with one explanation per horizon (as now), or
- with one single explanation for all horizons (as required in this case for the TFT).
To accommodate the second case, we could make it possible to build ExplainabilityResult
with only a Dict[str, TimeSeries]
(in addition to Dict[integer, Dict[str, TimeSeries]]
), so we avoid specifying the horizon. We can also adapt ExplainabilityResult.get_explanation()
to make specifying the horizon
optional, and not supported if the underlying explanation is not split by horizon.
WDYT? I would find this cleaner than "hacking" the class by using a fake horizon 0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Cattes any thoughts on this? ^
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think its a good idea to change the class to handle the TFT explainer. I didn't want to do it before discussing it. Having the hack with horizon=0 was just to conform with the given api. It was not an intuitive solution.
I have added Dict[str, TimeSeries]
to the valid type for class initialization and made the horizon
optional.
I also added a few more validations to deal with the different input types explicitly.
@Cattes it seems there's a linting issue preventing the tests from being run. |
Yes, we did some refactoring of the Explainability module a couple of weeks back.
I can implement these things :) |
Hi @Cattes I finished the adaptions for TFTExplainer now. The refactor got quite big, because I took the time to refactor the I also caught one or two bugs on the way that I fixed with it. I updated the PR description with the points I added/adapted. Let me know if you want to go over the changes and/or if you're okay with the new version :) Thanks for this great PR and sorry again for the time it took us to get this through 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice refactoring of the explainability and really cool feature (awaited by a lot of users)!
Some minors comments, this is almost ready for merge 🚀
Thank you @dennisbader and @madtoinou for finishing up the PR! Sorry I could not have a look at it before the merge because I was on holidays. |
Implements explainability for the TFT model as discussed in #675
Summary
I have added a
TFTExplainer
class similar to theShapExplainer
class to get the explainability insights for theTFT
model.The class contains the
encoder_importance
anddecoder_importance
of the trainedTFT
Model (grouped together with aplot
option in theget_variable_selection_weight
method).The
TFTExplainer.explain()
method calls thepredict
method of the trainedTFT
Model to get the attention over time.I have also provided a
plot_attention_heads
method to plot the attention over time either as the average attention, as a heatmap or a detailed plot of all available attention heads.I have added a section on how to use the class to the
13-TFT-example.ipynb
notebook.Other Information
I have oriented myself to the suggestions in the Issue made by @hrzn
The inital code on how to get the details from the TFT class was provided by @MagMueller in the Issue.
Edit (@dennisbader, 27-07-2023):
ForecastingModelExplainer
andExplainabilityResult
to simplify/unify implementation of new explainersexplain()
outputTFTModel
attention mask whenfull_attention=True
TorchForecastingModel
in case of training on a single seriesadd_relative_index