rework metrics logic to support states #2391
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff             @@
##           master    #2391      +/-   ##
==========================================
- Coverage   93.75%   93.74%   -0.01%
==========================================
  Files         138      138
  Lines       14343    14338       -5
==========================================
- Hits        13447    13441       -6
- Misses        896      897       +1

View full report in Codecov by Sentry.
def update(self, preds, target) -> None:
    if preds.shape != target.shape:
        raise ValueError("preds and target must have the same shape")
Suggested change:
- raise ValueError("preds and target must have the same shape")
+ raise ValueError(f"preds and target must have the same shape, but got {preds.shape} for preds and {target.shape} for target.")
Thanks, updated. Note that this is just a test class; this error isn't raised in the training loop itself.
Thanks for this great PR @tRosenflanz 🚀
It looks really good, just had some minor suggestions.
After those have been addressed, we can merge :)
    return loss

def _compute_metrics(self, metrics):
    res = metrics.compute()
Can you move this method to where _update_metrics() is defined?
Also, we can skip the computation when no metrics are registered, as done in _update_metrics():
Suggested change:
- res = metrics.compute()
+ if not len(metrics):
+     return
+ res = metrics.compute()
Moved and "skipped" when needed. For some reason, the linter doesn't add an empty line by default, by the way.
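The early-return guard discussed above can be sketched as follows. This is an illustrative stand-in, not the actual darts or torchmetrics code: `MetricCollectionStub` and `compute_metrics` are hypothetical names used only to show the pattern of skipping `compute()` when the collection is empty.

```python
# Hypothetical sketch of the guard suggested in review: skip computation
# entirely when the metric collection is empty, mirroring _update_metrics().
# The collection class is a minimal stand-in, not the torchmetrics API.

class MetricCollectionStub(dict):
    """Maps metric name -> zero-arg callable returning its computed value."""

    def compute(self):
        return {name: fn() for name, fn in self.items()}


def compute_metrics(metrics):
    # Nothing registered: avoid a needless compute() call.
    if not len(metrics):
        return None
    return metrics.compute()


print(compute_metrics(MetricCollectionStub()))                   # empty: skipped
print(compute_metrics(MetricCollectionStub(mae=lambda: 0.5)))    # computed
```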
if preds.shape != target.shape:
    raise ValueError(
        "preds and target must have the same shape "
        f"but got {preds.shape} for preds and {target.shape} for target."
    )
Not required for this test
Suggested change (remove these lines):
- if preds.shape != target.shape:
-     raise ValueError(
-         "preds and target must have the same shape "
-         f"but got {preds.shape} for preds and {target.shape} for target."
-     )
removed
def __init__(self, **kwargs):
    super().__init__(**kwargs)
Not required for this test
Suggested change:
- def __init__(self, **kwargs):
-     super().__init__(**kwargs)
+ def __init__(self):
+     super().__init__()
Removed. Note that this might fail in the future if you start adding arguments (e.g. for devices) to the base class.
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
    **tfm_kwargs,
)
model.fit(self.series)
assert model.model.trainer.logged_metrics["train_NumsCalled"] != 1
Sorry, just saw this now. We should probably check that it is > 1, since 0 would also pass this assertion but would be incorrect.
Suggested change:
- assert model.model.trainer.logged_metrics["train_NumsCalled"] != 1
+ assert model.model.trainer.logged_metrics["train_NumsCalled"] > 1
Good call, fixed. This tests that update_metrics works correctly as well
Looks great, thanks a lot @tRosenflanz, ready to merge 🚀 💯
Checklist before merging this PR:
Fixes #2390, fixes #2389
Summary
Instead of running a full forward/log on the metrics at each step, the metrics now only update their internal state at every step; the computation, logging, and reset happen once at the end of the epoch.
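The stateful pattern described in the summary can be sketched without any framework dependencies. This is a minimal illustration of the idea, not the actual darts implementation: `RunningMAE` is a hypothetical metric with cheap per-step `update()` calls and a single deferred `compute()`/`reset()` at epoch end.

```python
# Hypothetical sketch of the per-step update / end-of-epoch compute pattern.
# update() accumulates sufficient statistics only; compute() runs once per epoch.

class RunningMAE:
    """Stateful mean-absolute-error metric with deferred computation."""

    def __init__(self):
        self.reset()

    def update(self, preds, targets):
        # Per-step: accumulate state, do no logging or full computation.
        self.abs_error += sum(abs(p - t) for p, t in zip(preds, targets))
        self.count += len(preds)

    def compute(self):
        return self.abs_error / self.count if self.count else 0.0

    def reset(self):
        self.abs_error = 0.0
        self.count = 0


metric = RunningMAE()
# During the epoch: state updates only.
metric.update([1.0, 2.0], [1.5, 2.5])
metric.update([3.0], [2.0])
# End of epoch: compute once, log, then reset.
epoch_mae = metric.compute()  # (0.5 + 0.5 + 1.0) / 3 ≈ 0.667
metric.reset()
```

The design trade-off this captures: per-step work stays O(batch) state accumulation, while the (potentially expensive) aggregation and logging run once per epoch.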