Update on "Fix _load_from_state_dict for num_batches_tracked in batch…
Browse files Browse the repository at this point in the history
…norm"

I approved #110850, which did the following:

Previously:
`num_batches_tracked` not in the `state_dict` passed to `m.load_state_dict(state_dict)` --> always overwrite the module's `num_batches_tracked` in `_load_from_state_dict` with a 0 CPU tensor

Now:
`num_batches_tracked` not in the `state_dict` passed to `m.load_state_dict(state_dict)` --> only overwrite the module's `num_batches_tracked` in `_load_from_state_dict` with a 0 CPU tensor if the module does not have `num_batches_tracked`
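
Paraphrased as a minimal sketch (the helper name is mine; the real logic lives inside `_BatchNorm._load_from_state_dict`, so this is illustrative, not the verbatim PyTorch source):

```python
import torch

def fill_num_batches_tracked(module, state_dict, prefix):
    # Illustrative paraphrase of the num_batches_tracked handling in
    # _BatchNorm._load_from_state_dict; not the verbatim PyTorch source.
    key = prefix + "num_batches_tracked"
    if key not in state_dict:
        # Previously: always synthesize a fresh 0 CPU counter.
        #   state_dict[key] = torch.tensor(0, dtype=torch.long)
        # Now (#110850): reuse the module's own counter when it has one.
        state_dict[key] = (
            module.num_batches_tracked
            if getattr(module, "num_batches_tracked", None) is not None
            else torch.tensor(0, dtype=torch.long)
        )
```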

This causes the following issue:

```
with torch.device('meta'):
    m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```
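
A slightly fuller repro, as a hedged sketch (the BatchNorm variant, channel count, and how the incomplete `state_dict` is built are illustrative):

```python
import torch
import torch.nn as nn

# Build a state_dict that lacks num_batches_tracked.
src = nn.BatchNorm2d(3)
state_dict = src.state_dict()
del state_dict["num_batches_tracked"]

with torch.device("meta"):
    m = nn.BatchNorm2d(3)
m.load_state_dict(state_dict, assign=True)

# With the #110850 behavior, the meta-device counter is kept rather than
# replaced with a 0 CPU tensor, so the module now mixes devices.
print(m.num_batches_tracked.device)  # meta
```

If the remaining tensors in `state_dict` live on CUDA, compiling the module afterwards trips the cuda+meta assertion shown below.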

If `num_batches_tracked` is not in `state_dict`, then since the module's `num_batches_tracked` is present (on the meta device), it is not overwritten with a 0 CPU tensor. When compiling, this error is raised:

```
AssertionError: Does not support mixing cuda+meta
```

I am not sure whether the explicit check for the meta device makes sense as a fix; I will add testing if this fix is OK.
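
For concreteness, a minimal sketch of what the explicit meta-device check could look like (hypothetical helper; the actual change would live in `_BatchNorm._load_from_state_dict`):

```python
import torch

def fill_num_batches_tracked_with_meta_check(module, state_dict, prefix):
    # Hypothetical sketch of the proposed fix: treat a meta-device counter
    # like a missing one and fall back to a fresh 0 CPU tensor.
    key = prefix + "num_batches_tracked"
    if key not in state_dict:
        counter = getattr(module, "num_batches_tracked", None)
        usable = counter is not None and counter.device.type != "meta"
        state_dict[key] = counter if usable else torch.tensor(0, dtype=torch.long)
```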





[ghstack-poisoned]
mikaylagawarecki committed Dec 7, 2023
1 parent 87ada87 commit d21f7c4
Showing 1 changed file with 0 additions and 2 deletions.
test/test_nn.py

```diff
@@ -5312,7 +5312,6 @@ def test_batchnorm_load_state_dict(self):
         self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(0))
 
         bn.num_batches_tracked = torch.tensor(10)
-        state_dict = bn.state_dict()
         self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))
 
         empty_dict = OrderedDict()
@@ -5328,7 +5327,6 @@ def test_batchnorm_load_state_dict(self):
         meta_bn.load_state_dict(empty_dict, assign=True, strict=False)
         self.assertEqual(meta_bn.state_dict()["num_batches_tracked"], torch.tensor(0))
 
-
     def test_pairwise_distance(self):
         input1 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
         input2 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
```
