Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix _load_from_state_dict for num_batches_tracked in batchnorm #115285

Conversation

mikaylagawarecki
Copy link
Contributor

@mikaylagawarecki mikaylagawarecki commented Dec 6, 2023

I approved #110850 which did the following

Previously:
num_batches_tracked not in state_dict when doing m.load_state_dict(state_dict) --> always overwrite module's num_batches_tracked in load_from_state_dict with a 0 cpu tensor

Now:
num_batches_tracked not in state_dict loaded when doing m.load_state_dict(state_dict) --> only overwrite module's num_batches_tracked in load_from_state_dict with a 0 cpu tensor if module does not have num_batches_tracked

This causes the following issue:

with torch.device('meta'):
     m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)

If num_batches_tracked is not in state_dict, since modules's num_batches_tracked is present on meta device, it is not overwritten with a 0 cpu tensor. When compiling, this error is raised

AssertionError: Does not support mixing cuda+meta

I am not sure whether the explicit check for meta device makes sense as a fix, will add testing if this fix is ok

Stack from ghstack (oldest at bottom):

Copy link

pytorch-bot bot commented Dec 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/115285

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit f08f7a1 with merge base 0ced55e (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test? Also I guess the condition should be flipped?

…norm"

I approved #110850 which did the following

Previously:
`num_batches_tracked` not in state_dict when doing `m.load_state_dict(state_dict)` --> always overwrite module's `num_batches_tracked` in `load_from_state_dict` with a 0 cpu tensor

Now:
`num_batches_tracked` not in state_dict loaded when doing `m.load_state_dict(state_dict)` --> only overwrite module's `num_batches_tracked`  in `load_from_state_dict` with a 0 cpu tensor if module does not have `num_batches_tracked`

This causes the following issue:

```
with torch.device('meta'):
     m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, since `modules's` `num_batches_tracked` is present on meta device, it is not overwritten with a 0 cpu tensor. When compiling, this error is raised

```
AssertionError: Does not support mixing cuda+meta
```     

I am not sure whether the explicit check for meta device makes sense as a fix, will add testing if this fix is ok





[ghstack-poisoned]
…norm"

I approved #110850 which did the following

Previously:
`num_batches_tracked` not in state_dict when doing `m.load_state_dict(state_dict)` --> always overwrite module's `num_batches_tracked` in `load_from_state_dict` with a 0 cpu tensor

Now:
`num_batches_tracked` not in state_dict loaded when doing `m.load_state_dict(state_dict)` --> only overwrite module's `num_batches_tracked`  in `load_from_state_dict` with a 0 cpu tensor if module does not have `num_batches_tracked`

This causes the following issue:

```
with torch.device('meta'):
     m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, since `modules's` `num_batches_tracked` is present on meta device, it is not overwritten with a 0 cpu tensor. When compiling, this error is raised

```
AssertionError: Does not support mixing cuda+meta
```     

I am not sure whether the explicit check for meta device makes sense as a fix, will add testing if this fix is ok





[ghstack-poisoned]
…norm"

I approved #110850 which did the following

Previously:
`num_batches_tracked` not in state_dict when doing `m.load_state_dict(state_dict)` --> always overwrite module's `num_batches_tracked` in `load_from_state_dict` with a 0 cpu tensor

Now:
`num_batches_tracked` not in state_dict loaded when doing `m.load_state_dict(state_dict)` --> only overwrite module's `num_batches_tracked`  in `load_from_state_dict` with a 0 cpu tensor if module does not have `num_batches_tracked`

This causes the following issue:

```
with torch.device('meta'):
     m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, since `modules's` `num_batches_tracked` is present on meta device, it is not overwritten with a 0 cpu tensor. When compiling, this error is raised

```
AssertionError: Does not support mixing cuda+meta
```     

I am not sure whether the explicit check for meta device makes sense as a fix, will add testing if this fix is ok





[ghstack-poisoned]
mikaylagawarecki added a commit that referenced this pull request Dec 7, 2023
ghstack-source-id: d2d6fd2897ea9c7180c25b097e80e1f3437d3ab8
Pull Request resolved: #115285
Copy link
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@mikaylagawarecki
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 7, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/mikaylagawarecki/163/head branch December 11, 2023 15:29
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
…ch#115285)

I approved pytorch#110850 which did the following

Previously:
`num_batches_tracked` not in state_dict when doing `m.load_state_dict(state_dict)` --> always overwrite module's `num_batches_tracked` in `load_from_state_dict` with a 0 cpu tensor

Now:
`num_batches_tracked` not in state_dict loaded when doing `m.load_state_dict(state_dict)` --> only overwrite module's `num_batches_tracked`  in `load_from_state_dict` with a 0 cpu tensor if module does not have `num_batches_tracked`

This causes the following issue:

```
with torch.device('meta'):
     m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, since `modules's` `num_batches_tracked` is present on meta device, it is not overwritten with a 0 cpu tensor. When compiling, this error is raised

```
AssertionError: Does not support mixing cuda+meta
```

I am not sure whether the explicit check for meta device makes sense as a fix, will add testing if this fix is ok

Pull Request resolved: pytorch#115285
Approved by: https://github.com/albanD
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged topic: not user facing topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants