Rectify test_zero1.py now that optim.load_state_dict doesn't guarantee immutability (#5382)

* [TEST ONLY] print statements for test_zero1.py to debug

* Try fix

* Rectify test_zero1.py to account for state_dict modification

* Fix lint
janeyx99 committed Jul 31, 2023
1 parent fbec7c1 · commit ca5eab8
Showing 1 changed file with 8 additions and 3 deletions.
test/test_zero1.py: 11 changes (8 additions & 3 deletions)
@@ -5,6 +5,7 @@
 from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer
 from torch_xla import runtime as xr
 from torch.testing._internal.common_utils import TestCase
+from copy import deepcopy
 
 import unittest

@@ -34,18 +35,22 @@ def test_zero1(self):

     opt1.step()
     opt2.step()
-    self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])
 
     s1 = opt1.state_dict()
     s2 = opt2.state_dict()
+    self.assertEqual(s1, s2['base'])
+
+    # deepcopy s1 to load later because pytorch optimizers do not guarantee the input
+    # state_dict will not be modified. on the other hand, s2 has this guarantee.
+    s1_clone = deepcopy(s1)
 
     opt1.load_state_dict(s1)
     opt2.load_state_dict(s2)
     self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])
 
     # step still runnable
     opt1.step()
     opt2.step()
-    opt1.load_state_dict(s1)
+    opt1.load_state_dict(s1_clone)
     opt2.load_state_dict(s2)
     self.assertEqual(opt1.state_dict(), opt2.state_dict()['base'])

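The caveat that motivates this change is easy to demonstrate outside the ZeRO-1 test. Below is a minimal sketch of the pattern the test now follows, assuming plain CPU PyTorch and a momentum-bearing SGD; the names (model, opt, snapshot, safe_snapshot) are illustrative, not taken from the test:

    # Minimal sketch: a state_dict you plan to reload later should be deep-copied,
    # because pytorch optimizers do not guarantee that load_state_dict leaves its
    # input unmodified (the loaded dict may end up aliasing live optimizer state).
    from copy import deepcopy

    import torch

    model = torch.nn.Linear(8, 8)
    opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    model(torch.ones(8, 8)).sum().backward()
    opt.step()  # populates momentum buffers in opt's state

    snapshot = opt.state_dict()
    safe_snapshot = deepcopy(snapshot)  # insulated from any later mutation

    opt.load_state_dict(snapshot)  # may modify or alias snapshot's tensors
    opt.step()                     # advances live state; snapshot may see it

    opt.load_state_dict(safe_snapshot)  # reliably restores the original values

This mirrors the diff above: s1 comes from a plain torch.optim.SGD and needs the deepcopy before being reloaded, while the new comment notes that s2, produced by ZeroRedundancyOptimizer, does carry the no-modification guarantee and can be reloaded as-is.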
