
Fix/ loading metrics and loss in load_from_checkpoint #1759

Merged
merged 10 commits into master from fix/load_loss_ckpt on May 23, 2023

Conversation

madtoinou (Collaborator)

Fixes #1758.

Summary

Since loss_fn and torch_metrics are not saved in PLForecastingModule checkpoints, they must be re-created from the model.model_params values so that training resumes with the proper loss (and continues to report the desired torch metrics).
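
A minimal sketch of that re-creation step (hypothetical helper; the model_params and criterion attribute names are assumptions about the darts wrapper/module API, not quoted from this PR):

import torch

def restore_loss_and_metrics(model):
    # model.model_params is assumed to hold the constructor arguments darts stores with the model
    params = model.model_params
    loss_fn = params.get("loss_fn") or torch.nn.MSELoss()
    torch_metrics = params.get("torch_metrics")
    # model.model is assumed to be the underlying PLForecastingModule
    model.model.criterion = loss_fn
    return loss_fn, torch_metrics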

Other Information

Added the corresponding unit tests.
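
For context, a sketch of the end-to-end scenario the fix targets (the darts calls are the public API as I understand it; the model name and epoch counts are illustrative):

import torch
from darts.models import NBEATSModel
from darts.utils.timeseries_generation import sine_timeseries

series = sine_timeseries(length=100)
model = NBEATSModel(
    input_chunk_length=12,
    output_chunk_length=6,
    loss_fn=torch.nn.L1Loss(),          # custom loss
    save_checkpoints=True,
    model_name="nbeats_custom_loss",    # illustrative name
)
model.fit(series, epochs=2)

# resume training from the checkpoint; with this fix, the reloaded model
# keeps L1Loss (and any configured torch metrics) instead of silently
# falling back to the defaults
loaded = NBEATSModel.load_from_checkpoint(model_name="nbeats_custom_loss", best=False)
loaded.fit(series, epochs=2)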

madtoinou added this to In review in darts via automation, May 11, 2023
dennisbader (Collaborator) left a comment:

Thanks @madtoinou for this, looks good :)

I would be interested to see if we can let PL handle the saving/loading of these parameters by adapting PLForecastingModule.on_save_checkpoint and PLForecastingModule.on_load_checkpoint.
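
As a rough illustration of that suggestion (a sketch only, not the darts implementation; the criterion and torch_metrics attribute names are assumptions):

from typing import Any, Dict
import pytorch_lightning as pl

class MyForecastingModule(pl.LightningModule):
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # persist the live objects next to the weights
        checkpoint["loss_fn"] = self.criterion
        checkpoint["torch_metrics"] = self.torch_metrics

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # restore them before training resumes
        self.criterion = checkpoint["loss_fn"]
        self.torch_metrics = checkpoint["torch_metrics"]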

codecov-commenter commented May 15, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: -0.13 ⚠️

Comparison is base (1efb1f8) 94.19% compared to head (37956ac) 94.06%.


Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1759      +/-   ##
==========================================
- Coverage   94.19%   94.06%   -0.13%     
==========================================
  Files         125      125              
  Lines       11505    11495      -10     
==========================================
- Hits        10837    10813      -24     
- Misses        668      682      +14     
Impacted Files Coverage Δ
...arts/models/forecasting/torch_forecasting_model.py 90.15% <ø> (-0.21%) ⬇️
darts/models/forecasting/pl_forecasting_module.py 93.98% <100.00%> (+0.09%) ⬆️

... and 10 files with indirect coverage changes


dennisbader (Collaborator) left a comment:

Awesome that it worked with the checkpointing :)
Actually, when you mentioned that we ignore loss_fn and torch_metrics when saving the hyperparameters, I tested whether we could achieve the same thing by removing the ignore, and it works :) I left a comment.

After this change we can merge 🚀
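
A standalone sketch of what removing the ignore amounts to in PyTorch Lightning (not the darts code; MyModule is a hypothetical module):

import torch
import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    def __init__(self, loss_fn=None, torch_metrics=None):
        super().__init__()
        # no ignore=["loss_fn", "torch_metrics"] here: PL pickles both into
        # the checkpoint's hyper_parameters and passes them back to __init__
        # when the module is restored from a checkpoint
        self.save_hyperparameters()
        self.criterion = loss_fn or torch.nn.MSELoss()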


Inline diff context (PLForecastingModule.on_load_checkpoint):

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    # by default, our models are initialized as float32; for other dtypes we
    # need to cast to the correct precision before the parameters are loaded
    # by PyTorch Lightning
    dtype = checkpoint["model_dtype"]
    self.to_dtype(dtype)

    # restoring attributes necessary to resume from training properly
Collaborator (inline review comment):

btw I just saw that we don't load the "train_sample_shape" from checkpoint. I think we should add this here as well, right?

madtoinou (Collaborator, Author):

I checked, it's already loaded when calling load_weights_from_checkpoint(). My guess is that since it's one of the constructor arguments and does not require any processing, the deserialization of the checkpoint by PyTorch Lightning does the job.
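
One way to confirm this (the checkpoint path is hypothetical; the "hyper_parameters" key is standard in Lightning checkpoints):

import torch

ckpt = torch.load("darts_logs/my_model/checkpoints/last.ckpt", map_location="cpu")
# constructor arguments captured by save_hyperparameters(), including
# train_sample_shape, are stored here and restored automatically on load
print(ckpt["hyper_parameters"].keys())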

dennisbader (Collaborator) left a comment:

Very nice, looks great! Thanks a lot @madtoinou 💯 🚀

dennisbader merged commit 31da6d3 into master on May 23, 2023
7 of 9 checks passed
darts automation moved this from In review to Done on May 23, 2023
dennisbader deleted the fix/load_loss_ckpt branch on May 23, 2023 09:17
alexcolpitts96 pushed a commit to alexcolpitts96/darts that referenced this pull request May 31, 2023
* fix: loss_fn and torch_metrics are properly restored when calling load_from_checkpoint()

* fix: moved fix to the PL on_save/on_load methods instead of load_from_checkpoint()

* fix: address reviewer comments, loss and metrics objects are saved in the constructor

* update changelog

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
dennisbader moved this from Done to Released in darts on Aug 5, 2023
Labels: none yet
Projects: darts (Released)
Development: successfully merging this pull request may close "Unable to resume training from check point with custom loss function in NBEATS"
3 participants