Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement TSMixer Model #2293

Merged
merged 38 commits into from
Apr 8, 2024
Merged

Conversation

cristof-r
Copy link
Contributor

@cristof-r cristof-r commented Mar 21, 2024

Add TSMixer model #1807 with several unit tests.
Adopted PyTorch implementation from this repository: https://github.com/ditschuk/pytorch-tsmixer/
The paper can be found here: https://arxiv.org/pdf/2303.06053.pdf

@cristof-r cristof-r marked this pull request as draft March 21, 2024 12:46
@cristof-r
Copy link
Contributor Author

Any feedback is very welcome, for me it seems good so far.
I can add it to the different markdowns and make an example notebook if you like the implementation.

@cristof-r cristof-r marked this pull request as ready for review March 21, 2024 14:31
Copy link

@VascoSch92 VascoSch92 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoa... nice job

darts/models/forecasting/tsmixer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/tsmixer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/tsmixer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/tsmixer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/tsmixer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/tsmixer_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/tsmixer_model.py Outdated Show resolved Hide resolved
darts/tests/models/forecasting/test_tsmixer.py Outdated Show resolved Hide resolved
darts/tests/models/forecasting/test_tsmixer.py Outdated Show resolved Hide resolved
cristof-r and others added 10 commits March 22, 2024 08:59
Co-authored-by: Vasco Schiavo <115561717+VascoSch92@users.noreply.github.com>
Co-authored-by: Vasco Schiavo <115561717+VascoSch92@users.noreply.github.com>
Co-authored-by: Vasco Schiavo <115561717+VascoSch92@users.noreply.github.com>
Co-authored-by: Vasco Schiavo <115561717+VascoSch92@users.noreply.github.com>
Co-authored-by: Vasco Schiavo <115561717+VascoSch92@users.noreply.github.com>
@codecov-commenter
Copy link

codecov-commenter commented Mar 22, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 94.01%. Comparing base (91c7087) to head (cef5678).

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2293      +/-   ##
==========================================
+ Coverage   93.95%   94.01%   +0.05%     
==========================================
  Files         136      137       +1     
  Lines       13687    13857     +170     
==========================================
+ Hits        12860    13027     +167     
- Misses        827      830       +3     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@dennisbader
Copy link
Collaborator

Hi @cristof-r, wow, this indeed looks amazing from a first glance!
Thanks a lot for this great PR, just give us some time to review it :) 🚀 🚀 🚀

@leoniewgnr
Copy link

If there is anything I can help with, please let me know! Looking forward to this:)

@cristof-r
Copy link
Contributor Author

@leoniewgnr If there is anything I can help with, please let me know! Looking forward to this:)

It would be great if you have an interesting idea for a small notebook example to demonstrate the TSMixer.
I was thinking of comparing it against the TFT model, showing its (hopefully) higher performance, like it was demonstrated in the original paper.
Unfortunately using a bigger dataset like the "ETTm1" (which was also used in the paper) takes too long to train and evaluate.

@dennisbader
Copy link
Collaborator

@cristof-r and @leoniewgnr, I'm currently reviewing the PR. There were a couple of things to change, so I started working on a new branch with a couple of adaptions to this PR. I'll soon open a PR to merge into this one.

Among other things it will improve the performance and reduce training time drastically.

While working on it I also made a little notebook for testing with the ETTh1Dataset. The model works pretty well, very close to TiDEModel.

I'll keep you updated.

@cristof-r
Copy link
Contributor Author

@dennisbader sounds great! Thank you very much

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this really nice PR @cristof-r. I took the freedom to push the changes already into this branch. I will comment on the most important changes:

tests

  • removed some of the tests that were either already handled in other test files, or tests that were taking a "long" time to complete (e.g. model performance (accuracy) tests which took ~30 seconds to complete)

model implementation

  • norms:
    • adapted TimeBatchNorm2d implementation to use actual 2d batch norm
    • removed support for RINorm since it should be used with model parameter use_reversible_instance_norm
  • modules
    • made all module classes private to hide them in the rendered documentation
    • removed ConditionalFeatureMixing and instead added the logic to _ConditionalMixerLayer
  • model parameters
    • lowered the default parameters values (e.g. blocks, hidden_size, ...) to make a lighter default version
  • main things that were fixed:
    • before, output_dim for all modules was set to hidden_output_size=hidden_size * output_dim * nr_params, whereas it should just be hdden_size. This was why the model was getting really slow for multivariate target series or probabilistic models.
    • multi-component static covariates were not properly handled. We need to flatten the static covariates, and have static_cov_dim=n components * n static features
    • I believe the static mixing was not handled correctly before. It looks like it was done as described in the paper. However, I'm not sure if the paper described it correctly.
      • before, static covariates were project to hidden_size with a linear layer, and then concatenated with x. In the first block x has only the actual number of input features. So then for the concatenation, x has much lower dimensionalty than x_static.
      • Now at the beginning we apply feature mixing to historical, future, and static covariates separately, and then feed them together to the mixing layers.

model example notebook

  • added an example notebook comparing a probabilistic TSMixer with TiDEmodel on a multivariate dataset, including future covariates (encoders) and static information

Let me know if you agree with the changes :) And again, thanks a lot for this great contribution, really appreciated!

@cristof-r
Copy link
Contributor Author

@dennisbader thank you very much for the improvements, I learned a lot.
Also thank you very much for the darts library in general, it is really great.
Do you already know when the next release will be?

@dennisbader
Copy link
Collaborator

@cristof-r, we're aiming to release within the next two weeks.

@dennisbader dennisbader merged commit 0d5c722 into unit8co:master Apr 8, 2024
7 of 9 checks passed
@cristof-r cristof-r deleted the feature/ts_mixer_model branch April 9, 2024 08:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants