
Only make a shallow copy when loading optimizer state_dict #106082

Closed · wants to merge 6 commits
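As a rough illustration of what the title refers to (this snippet is not part of the PR): deepcopying an optimizer state_dict duplicates every state tensor, while a shallow `dict.copy()` duplicates only the top-level mapping and leaves the potentially large tensors shared.

```python
# Illustrative only; not code from this PR. Shows why a shallow copy of an
# optimizer state_dict is cheaper: deepcopy duplicates every state tensor,
# while dict.copy() duplicates only the outer mapping.
from copy import deepcopy

import torch

param = torch.nn.Parameter(torch.randn(1000, 1000))
opt = torch.optim.Adam([param])
param.grad = torch.randn_like(param)
opt.step()  # populates the exp_avg / exp_avg_sq state tensors

sd = opt.state_dict()
deep = deepcopy(sd)   # new storage for every state tensor
shallow = sd.copy()   # one new dict; inner objects are shared

assert shallow["state"] is sd["state"]
assert deep["state"][0]["exp_avg"] is not sd["state"][0]["exp_avg"]
```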
2 changes: 2 additions & 0 deletions .git-blame-ignore-revs
@@ -30,3 +30,5 @@ e3900d2ba5c9f91a24a9ce34520794c8366d5c54
 30fb2c4abaaaa966999eab11674f25b18460e609
 # 2023-06-06 clang-format on Foreach / Multi-Tensor-Apply
 515c4279416f13fcc3c898e560f8ae8f15139a03
+# [optim][BE] split test file into logical parts: SWA, LR, optim
+a53cda1ddc15336dc1ff0ce1eff2a49cdc5f882e
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
-f5edcb2088195db71bcd36d0f8f1b6a5e663afd8
+498bce41fccac29a1f1a4310ed7779102057cc78
6 changes: 2 additions & 4 deletions test/optim/test_optim.py
@@ -238,8 +238,6 @@ def fn_base(optimizer, weight, bias):
         optimizer_c.step(fn_c)
         self.assertEqual(weight, weight_c)
         self.assertEqual(bias, bias_c)
-        # Make sure state dict wasn't modified
-        self.assertEqual(state_dict, state_dict_c)
         # Make sure state dict is deterministic with equal but not identical parameters
         self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict())
         # Make sure repeated parameters have identical representation in state dict
@@ -301,7 +299,7 @@ def fn_base(optimizer, weight, bias):
         state_dict_c = deepcopy(optimizer.state_dict())
         optimizer_cuda.load_state_dict(state_dict_c)

-        # Make sure state dict wasn't modified
+        # Make sure state_dict_c isn't modified by merely calling load_state_dict
         self.assertEqual(state_dict, state_dict_c)

         # Make sure that device of state['step'] is still CPU
@@ -312,7 +310,7 @@ def fn_base(optimizer, weight, bias):
         for state in new_state_dict["state"].values():
             self.assertEqual(state["step"].device.type, "cpu")

-        for _i in range(20):
+        for _ in range(20):
             optimizer.step(fn)
             optimizer_cuda.step(fn_cuda)
             self.assertEqual(weight, weight_cuda)
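The comment change above pins down the contract being tested. A hedged, self-contained sketch of that invariant (assumed setup, not the actual test code): merely calling `load_state_dict` must leave the caller's dict unmodified, even now that it is no longer deep-copied.

```python
# Sketch of the invariant exercised by the test above (illustrative setup):
# load_state_dict must not write into the state_dict the caller passes in.
from copy import deepcopy

import torch

param = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.SGD([param], lr=0.1, momentum=0.9)
param.grad = torch.randn_like(param)
opt.step()  # creates the momentum_buffer state

sd = opt.state_dict()
snapshot = deepcopy(sd)  # independent reference copy for comparison

opt.load_state_dict(sd)  # merely loading must not modify `sd`
assert sd["param_groups"] == snapshot["param_groups"]
assert set(sd["state"].keys()) == set(snapshot["state"].keys())
```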
8 changes: 5 additions & 3 deletions torch/optim/optimizer.py
@@ -712,8 +712,8 @@ def load_state_dict(self, state_dict: StateDict) -> None:
             state_dict (dict): optimizer state. Should be an object returned
                 from a call to :meth:`state_dict`.
         """
-        # deepcopy, to be consistent with module API
-        state_dict = deepcopy(state_dict)
+        # shallow copy, to be consistent with module API
+        state_dict = state_dict.copy()

         for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
             hook_result = pre_hook(self, state_dict)
@@ -722,7 +722,9 @@ def load_state_dict(self, state_dict: StateDict) -> None:

         # Validate the state_dict
         groups = self.param_groups
-        saved_groups = state_dict['param_groups']
+
+        # Deepcopy as we write into saved_groups later to update state
+        saved_groups = deepcopy(state_dict['param_groups'])

         if len(groups) != len(saved_groups):
             raise ValueError("loaded state dict has a different number of "
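To summarize the copying strategy introduced above with a plain-dict sketch (illustrative stand-ins, not the PyTorch source): the outer dict is shallow-copied so pre-hooks can add or replace top-level keys without touching the caller's object, and only `param_groups` is deep-copied because `load_state_dict` later writes into it; the heavy `state` tensors are never duplicated.

```python
# Plain-dict sketch of the copy strategy above (illustrative, not the real
# optimizer code): the outer dict is shallow-copied, and only 'param_groups'
# is deep-copied because the loader writes into it.
from copy import deepcopy

caller_sd = {
    "state": {0: {"momentum_buffer": "imagine a large tensor here"}},
    "param_groups": [{"lr": 0.1, "params": [0]}],
}

sd = caller_sd.copy()                        # top-level writes stay local
saved_groups = deepcopy(sd["param_groups"])  # nested writes stay local too

sd["extra"] = True           # e.g. a pre-hook adding a key
saved_groups[0]["lr"] = 0.5  # e.g. the loader updating a group

assert "extra" not in caller_sd                   # caller's dict untouched
assert caller_sd["param_groups"][0]["lr"] == 0.1  # caller's groups untouched
assert sd["state"] is caller_sd["state"]          # heavy state still shared
```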