Update on "Fix _load_from_state_dict for num_batches_tracked in batch…
Browse files Browse the repository at this point in the history
…norm"

I approved #110850, which did the following:

Previously:
When `num_batches_tracked` was not in the `state_dict` passed to `m.load_state_dict(state_dict)`, the module's `num_batches_tracked` was always overwritten in `_load_from_state_dict` with a 0 CPU tensor.

Now:
When `num_batches_tracked` is not in the `state_dict` passed to `m.load_state_dict(state_dict)`, the module's `num_batches_tracked` is only overwritten in `_load_from_state_dict` with a 0 CPU tensor if the module does not already have `num_batches_tracked`.
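
Roughly, the relevant logic changed like this (a paraphrased sketch, not the actual `_load_from_state_dict` source; the helper name and arguments are mine):

```python
import torch

# Paraphrased sketch of the #110850 behavior change, not the actual PyTorch code.
# `module_counter` stands in for the module's own `num_batches_tracked` buffer
# (it is None when track_running_stats=False).
def fill_missing_counter(state_dict, prefix, module_counter, new_behavior=True):
    key = prefix + "num_batches_tracked"
    if key not in state_dict:
        if new_behavior and module_counter is not None:
            # After #110850: keep whatever counter the module already has.
            state_dict[key] = module_counter
        else:
            # Before #110850: always fall back to a fresh 0 CPU tensor.
            state_dict[key] = torch.tensor(0, dtype=torch.long)
    return state_dict
```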

This causes the following issue:

```
with torch.device('meta'):
    m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, then since the module's `num_batches_tracked` is present on the meta device, it is not overwritten with a 0 CPU tensor. When compiling, this error is raised:

```
AssertionError: Does not support mixing cuda+meta
```     
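
For concreteness, a fuller standalone version of the repro along these lines (the concrete module, shapes, and the final device check are illustrative choices, not taken verbatim from the failing case):

```python
import torch
import torch.nn as nn

# A checkpoint produced on CPU that lacks `num_batches_tracked`,
# e.g. one saved by an older version of the module.
cpu_bn = nn.BatchNorm2d(3)
state_dict = cpu_bn.state_dict()
state_dict.pop("num_batches_tracked")

# Module skeleton built on the meta device, then loaded with assign=True.
with torch.device("meta"):
    meta_bn = nn.BatchNorm2d(3)
meta_bn.load_state_dict(state_dict, assign=True, strict=False)

# Without the fix, `num_batches_tracked` is left on the meta device while every
# other buffer is now a real tensor; compiling the module in that mixed state
# is what triggers the assert above.
print(meta_bn.num_batches_tracked.device)
```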

I am not sure whether the explicit check for the meta device makes sense as a fix; I will add tests if this fix is ok.
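
For illustration, a hedged sketch of what that explicit check could look like, continuing the sketch above (again, the helper name is mine and this is not the actual patch):

```python
import torch

# Hypothetical sketch of the proposed fix: only reuse the module's existing
# counter when it is a real, materialized tensor; a meta counter is treated
# as "missing" and replaced with a fresh 0 CPU tensor.
def fill_missing_counter_with_meta_check(state_dict, prefix, module_counter):
    key = prefix + "num_batches_tracked"
    if key not in state_dict:
        has_materialized_counter = (
            module_counter is not None
            and module_counter.device != torch.device("meta")
        )
        state_dict[key] = (
            module_counter
            if has_materialized_counter
            else torch.tensor(0, dtype=torch.long)
        )
    return state_dict
```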

[ghstack-poisoned]
mikaylagawarecki committed Dec 7, 2023
1 parent 25affad commit 87ada87
Showing 1 changed file with 11 additions and 0 deletions: test/test_nn.py
```diff
@@ -5312,12 +5312,23 @@ def test_batchnorm_load_state_dict(self):
         self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(0))
 
         bn.num_batches_tracked = torch.tensor(10)
+        state_dict = bn.state_dict()
         self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))
 
         empty_dict = OrderedDict()
         bn.load_state_dict(empty_dict, strict=False)
         self.assertEqual(bn.state_dict()["num_batches_tracked"], torch.tensor(10))
+
+        # test that when `num_batches_tracked` is not in the loaded state_dict,
+        # a meta `num_batches_tracked` is still replaced with a singleton 0 tensor
+        with torch.device('meta'):
+            meta_bn = torch.nn.BatchNorm2d(3)
+        self.assertTrue(meta_bn.num_batches_tracked.device == torch.device('meta'))
+        state_dict.pop("num_batches_tracked")
+        meta_bn.load_state_dict(state_dict, assign=True, strict=False)
+        self.assertEqual(meta_bn.state_dict()["num_batches_tracked"], torch.tensor(0))
+
 
     def test_pairwise_distance(self):
         input1 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
         input2 = torch.randn(4, 4, requires_grad=True, dtype=torch.double)
```
