Fix num_batches_tracked of BatchNorm when load_state_dict #110850
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110850
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0dddef8 with merge base 73170b2.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks!
@pytorchbot merge -r
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased 3efee14 to 0dddef8
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…0850) Fixes pytorch#110361, as the title shows. Pull Request resolved: pytorch#110850. Approved by: https://github.com/mikaylagawarecki
…norm" I approved #110850, which did the following.

Previously: when `num_batches_tracked` is not in the state_dict during `m.load_state_dict(state_dict)`, the module's `num_batches_tracked` was always overwritten in `load_from_state_dict` with a zero CPU tensor.

Now: when `num_batches_tracked` is not in the loaded state_dict during `m.load_state_dict(state_dict)`, the module's `num_batches_tracked` is only overwritten in `load_from_state_dict` with a zero CPU tensor if the module does not already have `num_batches_tracked`.

This causes the following issue:

```
with torch.device('meta'):
    m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, then since the module's `num_batches_tracked` is present on the meta device, it is not overwritten with a zero CPU tensor. When compiling, this error is raised:

```
AssertionError: Does not support mixing cuda+meta
```

I am not sure whether the explicit check for the meta device makes sense as a fix; I will add testing if this fix is OK.

[ghstack-poisoned]
I approved #110850, which did the following.

Previously: when `num_batches_tracked` is not in the state_dict during `m.load_state_dict(state_dict)`, the module's `num_batches_tracked` was always overwritten in `load_from_state_dict` with a zero CPU tensor.

Now: when `num_batches_tracked` is not in the loaded state_dict during `m.load_state_dict(state_dict)`, the module's `num_batches_tracked` is only overwritten in `load_from_state_dict` with a zero CPU tensor if the module does not already have `num_batches_tracked`.

This causes the following issue:

```
with torch.device('meta'):
    m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, then since the module's `num_batches_tracked` is present on the meta device, it is not overwritten with a zero CPU tensor. When compiling, this error is raised:

```
AssertionError: Does not support mixing cuda+meta
```

I am not sure whether the explicit check for the meta device makes sense as a fix; I will add testing if this fix is OK.

Pull Request resolved: #115285. Approved by: https://github.com/albanD
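The behavior change described above can be modeled with a small torch-free sketch (the function names and the dict-based stand-ins for module buffers are hypothetical, for illustration only, and are not PyTorch APIs): the old loading path always reset a missing `num_batches_tracked` to a zero tensor, while the new path skips the reset whenever the module already has the buffer, including when that buffer lives on the meta device.

```python
def reset_old(module_buffers, state_dict):
    """Old behavior (before #110850): if num_batches_tracked is missing
    from the state_dict, always overwrite the module's buffer with zero."""
    if "num_batches_tracked" not in state_dict:
        module_buffers["num_batches_tracked"] = 0  # stand-in for a zero CPU tensor
    return module_buffers


def reset_new(module_buffers, state_dict):
    """New behavior (after #110850): only create a zero buffer when the
    module itself has no num_batches_tracked buffer at all."""
    if ("num_batches_tracked" not in state_dict
            and "num_batches_tracked" not in module_buffers):
        module_buffers["num_batches_tracked"] = 0
    return module_buffers


# A module materialized under torch.device('meta') already has the buffer,
# so the new path leaves the meta-device value in place -- which is what
# later trips "Does not support mixing cuda+meta" during compilation.
meta_module = {"num_batches_tracked": "meta-tensor"}
print(reset_old(dict(meta_module), {}))  # {'num_batches_tracked': 0}
print(reset_new(dict(meta_module), {}))  # {'num_batches_tracked': 'meta-tensor'}
```

This is only a sketch of the control flow; the real logic lives in the `_NormBase` loading hook and operates on actual tensors and devices.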
Fixes #110361, as the title shows.