Skip to content

Commit

Permalink
Only make a shallow copy when loading optimizer state_dict
Browse files Browse the repository at this point in the history
ghstack-source-id: cee33dcf6d531cb942fb9cde6947943dd7d709b2
Pull Request resolved: #106082
  • Loading branch information
janeyx99 committed Jul 28, 2023
1 parent 4d3ea5d commit 2a04b93
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ e3900d2ba5c9f91a24a9ce34520794c8366d5c54
30fb2c4abaaaa966999eab11674f25b18460e609
# 2023-06-06 clang-format on Foreach / Multi-Tensor-Apply
515c4279416f13fcc3c898e560f8ae8f15139a03
# [optim][BE] split test file into logical parts: SWA, LR, optim
a53cda1ddc15336dc1ff0ce1eff2a49cdc5f882e
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
f5edcb2088195db71bcd36d0f8f1b6a5e663afd8
498bce41fccac29a1f1a4310ed7779102057cc78
6 changes: 2 additions & 4 deletions test/optim/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,6 @@ def fn_base(optimizer, weight, bias):
optimizer_c.step(fn_c)
self.assertEqual(weight, weight_c)
self.assertEqual(bias, bias_c)
# Make sure state dict wasn't modified
self.assertEqual(state_dict, state_dict_c)
# Make sure state dict is deterministic with equal but not identical parameters
self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
# Make sure repeated parameters have identical representation in state dict
Expand Down Expand Up @@ -301,7 +299,7 @@ def fn_base(optimizer, weight, bias):
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_cuda.load_state_dict(state_dict_c)

# Make sure state dict wasn't modified
# Make sure state_dict_c isn't modified by merely calling load_state_dict
self.assertEqual(state_dict, state_dict_c)

# Make sure that device of state['step'] is still CPU
Expand All @@ -312,7 +310,7 @@ def fn_base(optimizer, weight, bias):
for state in new_state_dict["state"].values():
self.assertEqual(state["step"].device.type, "cpu")

for _i in range(20):
for _ in range(20):
optimizer.step(fn)
optimizer_cuda.step(fn_cuda)
self.assertEqual(weight, weight_cuda)
Expand Down
8 changes: 5 additions & 3 deletions torch/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,8 @@ def load_state_dict(self, state_dict: StateDict) -> None:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# shallow copy, to be consistent with module API
state_dict = state_dict.copy()

for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
hook_result = pre_hook(self, state_dict)
Expand All @@ -722,7 +722,9 @@ def load_state_dict(self, state_dict: StateDict) -> None:

# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict['param_groups']

# Deepcopy as we write into saved_groups later to update state
saved_groups = deepcopy(state_dict['param_groups'])

if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of "
Expand Down

0 comments on commit 2a04b93

Please sign in to comment.