Fix num_batches_tracked of BatchNorm when load_state_dict #110850

Closed
wants to merge 1 commit into from

Collaborator

@FFFrog FFFrog commented Oct 9, 2023

Fixes #110361

As the title says.

@pytorch-bot pytorch-bot bot added the release notes: nn release notes category label Oct 9, 2023

pytorch-bot bot commented Oct 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110850

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0dddef8 with merge base 73170b2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@albanD albanD removed their request for review October 9, 2023 15:37
@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Oct 9, 2023
Contributor

@mikaylagawarecki mikaylagawarecki left a comment


Thanks!

@mikaylagawarecki
Contributor

@pytorchbot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 24, 2023
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased mrl_fix_batchnorm onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout mrl_fix_batchnorm && git pull --rebase)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
@FFFrog FFFrog deleted the mrl_fix_batchnorm branch October 30, 2023 06:48
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
mikaylagawarecki added a commit that referenced this pull request Dec 7, 2023
…norm"

I approved #110850 which did the following

Previously:
If `num_batches_tracked` is not in `state_dict` when calling `m.load_state_dict(state_dict)`, `load_from_state_dict` always overwrites the module's `num_batches_tracked` with a CPU tensor holding 0.

Now:
If `num_batches_tracked` is not in the `state_dict` loaded by `m.load_state_dict(state_dict)`, `load_from_state_dict` overwrites the module's `num_batches_tracked` with a CPU tensor holding 0 only if the module does not already have `num_batches_tracked`.
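The two rules can be contrasted with a small plain-Python model. This is only a sketch of the behavior described above, not the PyTorch source: the function names are hypothetical, integers stand in for tensors, and `None` stands in for a module that has no `num_batches_tracked`.

```python
# Plain-Python model of the two rules; no PyTorch required.
# All names here are hypothetical and exist only for illustration.
KEY = "num_batches_tracked"


def fill_missing_before(state_dict, module_value):
    """Old rule: a missing key is always filled with a fresh 0,
    regardless of what the module currently holds."""
    if KEY not in state_dict:
        state_dict[KEY] = 0
    return state_dict[KEY]


def fill_missing_after(state_dict, module_value):
    """New rule: a missing key falls back to the module's own value,
    and to 0 only when the module has no value at all."""
    if KEY not in state_dict:
        state_dict[KEY] = module_value if module_value is not None else 0
    return state_dict[KEY]


# A module that has already tracked 7 batches loads an empty state_dict.
print(fill_missing_before({}, 7))   # old rule resets the counter
print(fill_missing_after({}, 7))    # new rule keeps the counter
print(fill_missing_after({}, None)) # no existing value: default to 0
```

The second function models the behavior introduced by #110850.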

This causes the following issue:

```
with torch.device('meta'):
    m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, the module's `num_batches_tracked` is already present (on the meta device), so it is not overwritten with a 0 CPU tensor. Compiling the module then raises this error:

```
AssertionError: Does not support mixing cuda+meta
```

I am not sure whether the explicit check for the meta device makes sense as a fix; I will add tests if this fix is OK.
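As a rough illustration of what an explicit meta-device check could look like, here is a plain-Python model. It is a sketch under stated assumptions, not the code from the fix: `FakeTensor` and `resolve_num_batches_tracked` are hypothetical names, and a device is modeled as a plain string.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: just a value and a device name."""
    value: int
    device: str


def resolve_num_batches_tracked(
    loaded: Optional[FakeTensor],
    current: Optional[FakeTensor],
) -> FakeTensor:
    """Pick the tensor the module should end up with after loading."""
    # A value present in the state_dict always wins.
    if loaded is not None:
        return loaded
    # Keep the module's own value only if it holds real data,
    # i.e. it exists and does not live on the meta device.
    if current is not None and current.device != "meta":
        return current
    # Otherwise fall back to a fresh zero on the CPU.
    return FakeTensor(0, "cpu")


# Module created under the meta device, key missing from the state_dict.
print(resolve_num_batches_tracked(None, FakeTensor(0, "meta")))
# Module with a real counter, key missing from the state_dict.
print(resolve_num_batches_tracked(None, FakeTensor(5, "cuda")))
```

In this model a module built on the meta device never keeps its placeholder, so no meta tensor is left behind to mix with tensors on a real device.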





mikaylagawarecki added a commit that referenced this pull request Dec 7, 2023
…norm"

mikaylagawarecki added a commit that referenced this pull request Dec 7, 2023
…norm"

pytorchmergebot pushed a commit that referenced this pull request Dec 7, 2023
Pull Request resolved: #115285
Approved by: https://github.com/albanD
hyperfraise pushed a commit to hyperfraise/pytorch that referenced this pull request Dec 21, 2023
…ch#115285)

hyperfraise pushed a commit to hyperfraise/pytorch that referenced this pull request Dec 21, 2023
…ch#115285)

ZhiweiYan-96 pushed a commit to ZhiweiYan-96/pytorch that referenced this pull request Dec 22, 2023
…ch#115285)

Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: nn release notes category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Development

Successfully merging this pull request may close these issues.

BatchNorm layer 'num_batches_tracked' overwritten with default value when loading empty state_dict
4 participants