Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix custom module for RNNModel and add tests #2088

Merged
merged 10 commits into from
Dec 10, 2023
Merged

Conversation

dennisbader
Copy link
Collaborator

@dennisbader dennisbader commented Nov 23, 2023

Fixes #1523, fixes #2082.

Summary

  • adds CustomRNNModule and CustomBlockRNNModule for defining custom rnn models.
  • To create a new module, subclass from :class:CustomBlockRNNModule and:
    • Define the architecture in the module constructor (__init__())
    • Add the forward() method and define the logic of your module's forward pass
    • Use the custom module class when creating a new Block/RNNModel with parameter model.
  • fixes error when passing a custom rnn module to RNNModel. The model expected an instantiated module object, but then attempted to instantiate the module from the object when creating an RNNmodel (resulting in calling the forward() pass instead of the module constructor).

@codecov-commenter
Copy link

codecov-commenter commented Nov 23, 2023

Codecov Report

Attention: 3 lines in your changes are missing coverage. Please review.

Comparison is base (7cfdf62) 93.91% compared to head (facbc4e) 93.91%.
Report is 2 commits behind head on master.

Files Patch % Lines
darts/models/forecasting/block_rnn_model.py 96.42% 1 Missing ⚠️
darts/models/forecasting/rnn_model.py 96.96% 1 Missing ⚠️
...arts/models/forecasting/torch_forecasting_model.py 50.00% 1 Missing ⚠️

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #2088   +/-   ##
=======================================
  Coverage   93.91%   93.91%           
=======================================
  Files         135      135           
  Lines       13299    13319   +20     
=======================================
+ Hits        12490    12509   +19     
- Misses        809      810    +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you @dennisbader for taking care of this!

@dennisbader dennisbader merged commit 4170093 into master Dec 10, 2023
9 checks passed
@dennisbader dennisbader deleted the fix/rnn_custom_module branch December 10, 2023 12:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants