Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rin to all torch models #1969

Merged
merged 12 commits into from
Sep 2, 2023
Merged

add rin to all torch models #1969

merged 12 commits into from
Sep 2, 2023

Conversation

dennisbader
Copy link
Collaborator

@dennisbader dennisbader commented Aug 27, 2023

Fixes #1121

Summary

  • adds RINorm to all TorchForecastingModels with model creation parameter use_reversible_instance_norm
  • moves the logic to the PLForecastingModule base level, including a wrapper around the forward method.

All models except RNN and Transformer have better performance with RIN. Have to investigate what's going on for these two models:

image

@alexcolpitts96
Copy link
Contributor

@dennisbader, thanks for implementing this since I didn't have time.

I don't know what dataset or training parameters you used to test the performance change; however, RIN isn't a guarantee to improve performance. In the TiDE paper they showed that it didn't always help (table 8, https://arxiv.org/pdf/2304.08424.pdf) so it could be the case for the Transformer and RNNs.

@dennisbader
Copy link
Collaborator Author

Hi @alexcolpitts96 , I actually just applied to models to your notebook example from TiDE against NHiTS.
I just want to make sure that TCNModel and RNNModel work properly with RIN (or not use it for now if the architecture can't handle it. Since RNN is recursive, I think that might be the issue, but have to investigate more)

@gdevos010
Copy link
Contributor

Thank you for adding this!

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks great, good job @dennisbader 🚀

darts/models/forecasting/block_rnn_model.py Show resolved Hide resolved
darts/models/forecasting/rnn_model.py Show resolved Hide resolved
@alexcolpitts96
Copy link
Contributor

@dennisbader I did some digging into what it might take to improve performance for Transformers and RNNs using RIN. I agree with @madtoinou that there seems to be some underlying recurrent architectural problem. I was able to get Transformers to give less bad performance; however, it was still pretty awful.

From section 4.2 in this paper:

For data with trend, all attention models show inferior generalizability, especially Fourier attention.

I have found this in my own experience as well where attention only improves performance when applied to some alternate representation (like Seq2Seq context vectors) and even then the improvement was nearly negligible.

As for the recurrent problem? I will see if I can find some explanation, but I think pushing out RIN globally without understanding the RNN problem should be fine.

I was also wondering what your thoughts were about moving to a weekly (or at least regular) release schedule? Patch releases (0.25.x) could help push out builds with features and bug fixes sooner.

@dennisbader
Copy link
Collaborator Author

I was mostly concerned by TCNModel and RNNModel (transformers performed at least not worse than vanilla models).

After another test it seems that TCNModel performs okay when predicting with n <= output_chunk_length. So we can keep support for it.

We will ignore user supplied use_reversible_instance_norm for RNNModel and log a warning when they set it. By default it will not use RIN due to its recurrent nature.

@dennisbader
Copy link
Collaborator Author

@alexcolpitts96, we'll try to release more frequently in the future. Next release is planned for next week.
I see a schedule for every 1-2 months as realistic. This is due to some periods where we have lower capacity.
Patch releases to fix severe issues can of course happen more often.

@codecov-commenter
Copy link

codecov-commenter commented Sep 1, 2023

Codecov Report

Patch coverage is 97.56% of modified lines.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the GitHub App Integration for your organization. Read more.

Files Changed Coverage
darts/models/components/statsforecast_utils.py ø
darts/models/forecasting/pl_forecasting_module.py 94.11%
darts/models/forecasting/block_rnn_model.py 100.00%
darts/models/forecasting/dlinear.py 100.00%
darts/models/forecasting/nbeats.py 100.00%
darts/models/forecasting/nhits.py 100.00%
darts/models/forecasting/nlinear.py 100.00%
darts/models/forecasting/rnn_model.py 100.00%
darts/models/forecasting/tcn_model.py 100.00%
darts/models/forecasting/tft_model.py 100.00%
... and 2 more

📢 Thoughts on this report? Let us know!.

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good to me 🚀, can't wait to see the performance gain across all the DL models!

@dennisbader dennisbader merged commit fecb99d into master Sep 2, 2023
9 checks passed
@dennisbader dennisbader deleted the feat/rin_norm_torch_models branch September 2, 2023 10:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Reversible Instance Normalization for Accurate Time-Series Forecasting against Distribution Shift.
5 participants