Fix _load_from_state_dict for num_batches_tracked in batchnorm #115285
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115285

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 unrelated failures) As of commit f08f7a1 with merge base 0ced55e:
- FLAKY: the following job failed but was likely due to flakiness present on trunk.
- UNSTABLE: the following job failed but was likely due to flakiness present on trunk and has been marked as unstable.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Test? Also, I guess the condition should be flipped?
…norm"

I approved #110850, which did the following.

Previously: when `num_batches_tracked` was not in the `state_dict` passed to `m.load_state_dict(state_dict)`, the module's `num_batches_tracked` was always overwritten in `_load_from_state_dict` with a zero CPU tensor.

Now: when `num_batches_tracked` is not in the loaded `state_dict`, the module's `num_batches_tracked` is only overwritten in `_load_from_state_dict` with a zero CPU tensor if the module does not already have `num_batches_tracked`.

This causes the following issue:

```python
with torch.device('meta'):
    m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, then since the module's `num_batches_tracked` is present on the meta device, it is not overwritten with a zero CPU tensor. When compiling, this error is raised:

```
AssertionError: Does not support mixing cuda+meta
```

I am not sure whether the explicit check for the meta device makes sense as a fix; I will add testing if this fix is OK.
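To illustrate the behavior being discussed, here is a minimal sketch of the loading logic, assuming the fix is to also overwrite the buffer when it lives on the meta device. This is not the actual PyTorch `_load_from_state_dict` implementation; `maybe_reset_num_batches_tracked` is a hypothetical helper written only to demonstrate the condition.

```python
import torch
import torch.nn as nn

def maybe_reset_num_batches_tracked(module, state_dict, prefix=""):
    # Hypothetical simplified sketch (not the actual PyTorch source):
    # when num_batches_tracked is missing from the state_dict, reset
    # the module's buffer unless it already holds a real tensor.
    key = prefix + "num_batches_tracked"
    if key not in state_dict:
        buf = getattr(module, "num_batches_tracked", None)
        # Overwrite when the buffer is absent OR lives on the meta
        # device, so assign=True never leaves a meta tensor behind.
        if buf is None or buf.device.type == "meta":
            module.num_batches_tracked = torch.tensor(0, dtype=torch.long)

# Reproduce the scenario from the PR description: a meta-device module
# loading a state_dict that lacks num_batches_tracked.
bn = nn.BatchNorm1d(4)
with torch.device("meta"):
    m = nn.BatchNorm1d(4)
sd = {k: v for k, v in bn.state_dict().items() if k != "num_batches_tracked"}
maybe_reset_num_batches_tracked(m, sd)
print(m.num_batches_tracked.device)  # cpu
```

Without the meta-device branch, `m.num_batches_tracked` would stay on the meta device after loading, which is the mixed cuda+meta state the compile-time assertion complains about.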
ghstack-source-id: d2d6fd2897ea9c7180c25b097e80e1f3437d3ab8 Pull Request resolved: #115285
Thanks!
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#115285. Approved by: https://github.com/albanD
Stack from ghstack (oldest at bottom):