-
Notifications
You must be signed in to change notification settings - Fork 880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/torchmetrics #996
Feat/torchmetrics #996
Conversation
…/darts into fix/nbeats-nhits-TODOs
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #996 +/- ##
==========================================
- Coverage 92.92% 92.91% -0.01%
==========================================
Files 76 76
Lines 7628 7647 +19
==========================================
+ Hits 7088 7105 +17
- Misses 540 542 +2 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, that looks like a good idea. Maybe it'd be nice to augment the new early stopping subsection of the User Guide to also include an example showing how to use a custom metric for early stopping? WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks already for another nice idea and the contribution :)
I think this could be a valuable addition for more feedback on model performance during training.
My main suggestion is that we could/should(?) try to base the approach metrics on instances of torchmetrics.Metric. This can potentially reduce complexity on our side such as setting up metrics, need for additional function call parameters, ...
@hrzn @dennisbader The model now accepts a |
I also added an example of using Ray Tune with a torch model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is definitely a very nice addition, thanks a lot!
I just have one last doubt before we can merge. It seems you tested with univariate series only. I guess it might not work out of the box for multivariate series? That is, if y
and y_hat
have some dimension (batch, time, dim, n_params)
with dim > 1
. In this case, my guess would be that we need to call metrics()
for each of the y[:, i, :]
and somehow aggregate/reduce the results. We could take the mean as a default reduction - so what is tracked in those cases is the average "metric" over all components.
It'd be nice to also add a test case for the univariate cases.
The metrics work with multivariate series. You can calculate regression metrics on multidimensional data as long as they have the same shape. The aggregate/reductions happens in the metric calculation, the same as calculating multiple time steps. Hopefully that explanation makes sense. I added test cases for multivariate time series |
OK perfect, thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice thank you :) Looks great!
I left two suggestions two avoid try/except and avoid computing anything that is not required when there is no metric available (such as sampling).
After these it's ready to merge :) 👍
@@ -58,6 +62,9 @@ def __init__( | |||
PyTorch loss function used for training. | |||
This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified. | |||
Default: ``torch.nn.MSELoss()``. | |||
torch_metrics |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you copy paste the docstring into all our TorchForecastingModels (TFTModel, NBEATSModel, ...)? Otherwise this will not be shown in the model documentation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Are we able to merge this branch? |
Yes. Thanks! Will be part of the next release (in a few days). |
Fixes #995.
Summary
Added Torch Metrics to training and validation