Update on "Fix _load_from_state_dict for num_batches_tracked in batchnorm"

I approved #110850, which did the following:

Previously: `num_batches_tracked` not in `state_dict` when doing `m.load_state_dict(state_dict)` --> always overwrite the module's `num_batches_tracked` in `_load_from_state_dict` with a 0 cpu tensor

Now: `num_batches_tracked` not in `state_dict` when doing `m.load_state_dict(state_dict)` --> only overwrite the module's `num_batches_tracked` in `_load_from_state_dict` with a 0 cpu tensor if the module does not already have `num_batches_tracked`

This causes the following issue:

```
with torch.device('meta'):
    m = BatchNorm(...)
m.load_state_dict(state_dict, assign=True)
```

If `num_batches_tracked` is not in `state_dict`, then since the module's `num_batches_tracked` is present on the meta device, it is not overwritten with a 0 cpu tensor. When compiling, this error is raised:

```
AssertionError: Does not support mixing cuda+meta
```

I am not sure whether an explicit check for the meta device makes sense as a fix; I will add testing if this fix is ok.

[ghstack-poisoned]
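To make the before/after behavior concrete, here is a minimal Python sketch of the decision logic described above. This is not the actual PyTorch `_load_from_state_dict` implementation; `load_num_batches_tracked` is a hypothetical helper, and plain ints stand in for tensors (a meta-device buffer would be one such existing value).

```python
def load_num_batches_tracked(module_buffers, state_dict, new_behavior, zero=0):
    """Return the value num_batches_tracked ends up with after loading.

    module_buffers: dict of the module's current buffers
    state_dict:     dict being loaded
    new_behavior:   True for the post-#110850 logic
    """
    key = "num_batches_tracked"
    if key in state_dict:
        # Normal case: the value comes from the state_dict.
        return state_dict[key]
    if new_behavior and key in module_buffers:
        # New behavior: keep the module's existing buffer -- which is the
        # problem when that buffer lives on the meta device.
        return module_buffers[key]
    # Old behavior: always reset to a 0 cpu tensor.
    return zero

# Old behavior resets even if the module already has a value:
assert load_num_batches_tracked({"num_batches_tracked": 7}, {}, new_behavior=False) == 0
# New behavior preserves the module's buffer:
assert load_num_batches_tracked({"num_batches_tracked": 7}, {}, new_behavior=True) == 7
```

Under the new logic, a module built under `torch.device('meta')` keeps its meta `num_batches_tracked`, which later mixes with cuda tensors at compile time.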