Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/torchmetrics #996

Merged
merged 35 commits into from
Jun 15, 2022
Merged

Feat/torchmetrics #996

merged 35 commits into from
Jun 15, 2022

Conversation

gdevos010
Copy link
Contributor

Fixes #995.

Summary

Added Torch Metrics to training and validation

@codecov-commenter
Copy link

codecov-commenter commented Jun 7, 2022

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.91%. Comparing base (abf12da) to head (d6a0e24).
Report is 485 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #996      +/-   ##
==========================================
- Coverage   92.92%   92.91%   -0.01%     
==========================================
  Files          76       76              
  Lines        7628     7647      +19     
==========================================
+ Hits         7088     7105      +17     
- Misses        540      542       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@hrzn hrzn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that looks like a good idea. Maybe it'd be nice to augment the new early stopping subsection of the User Guide to also include an example showing how to use a custom metric for early stopping? WDYT?

darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks already for another nice idea and the contribution :)

I think this could be a valuable addition for more feedback on model performance during training.

My main suggestion is that we could/should(?) try to base the approach metrics on instances of torchmetrics.Metric. This can potentially reduce complexity on our side such as setting up metrics, need for additional function call parameters, ...

darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
@gdevos010
Copy link
Contributor Author

@hrzn @dennisbader The model now accepts a TorchMetric or a MetricCollection. I also added the likelihood test and updated the early stop example

@gdevos010
Copy link
Contributor Author

I also added an example of using Ray Tune with a torch model

Copy link
Contributor

@hrzn hrzn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is definitely a very nice addition, thanks a lot!
I just have one last doubt before we can merge. It seems you tested with univariate series only. I guess it might not work out of the box for multivariate series? That is, if y and y_hat have some dimension (batch, time, dim, n_params) with dim > 1. In this case, my guess would be that we need to call metrics() for each of the y[:, i, :] and somehow aggregate/reduce the results. We could take the mean as a default reduction - so what is tracked in those cases is the average "metric" over all components.
It'd be nice to also add a test case for the univariate cases.

@gdevos010
Copy link
Contributor Author

The metrics work with multivariate series. You can calculate regression metrics on multidimensional data as long as they have the same shape. The aggregate/reductions happens in the metric calculation, the same as calculating multiple time steps. Hopefully that explanation makes sense.

I added test cases for multivariate time series

@hrzn
Copy link
Contributor

hrzn commented Jun 13, 2022

The metrics work with multivariate series. You can calculate regression metrics on multidimensional data as long as they have the same shape. The aggregate/reductions happens in the metric calculation, the same as calculating multiple time steps. Hopefully that explanation makes sense.

I added test cases for multivariate time series

OK perfect, thanks!

Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice thank you :) Looks great!

I left two suggestions two avoid try/except and avoid computing anything that is not required when there is no metric available (such as sampling).

After these it's ready to merge :) 👍

darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
@@ -58,6 +62,9 @@ def __init__(
PyTorch loss function used for training.
This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
Default: ``torch.nn.MSELoss()``.
torch_metrics
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you copy paste the docstring into all our TorchForecastingModels (TFTModel, NBEATSModel, ...)? Otherwise this will not be shown in the model documentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

@gdevos010
Copy link
Contributor Author

Are we able to merge this branch?

@hrzn
Copy link
Contributor

hrzn commented Jun 15, 2022

Are we able to merge this branch?

Yes. Thanks! Will be part of the next release (in a few days).

@hrzn hrzn merged commit 2c43352 into unit8co:master Jun 15, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Torch Metrics
4 participants