diff --git a/test/test_zero1.py b/test/test_zero1.py index 13588d779348..5ed8d742c9fb 100644 --- a/test/test_zero1.py +++ b/test/test_zero1.py @@ -88,6 +88,54 @@ def test_zero1(self): opt2.step() xm.mark_step() + def test_zero1_load(self): + device = xm.xla_device() + + model = nn.Linear(32, 32) + x = torch.ones((32, 32)) + x.requires_grad = True + model = model.to(device) + x = x.to(device) + y = model(x).sum() + y.backward() + xm.mark_step() + + #original optimizer + opt = ZeroRedundancyOptimizer( + model.parameters(), + torch.optim.SGD, + lr=0.5, + momentum=0.5, + grad_clipping=True) + + opt.step() + + #creating a dummy to confirm reload is correct + dummy_model = nn.Linear(32, 32) + dummy_model = dummy_model.to(device) + reloaded_opt = ZeroRedundancyOptimizer( + dummy_model.parameters(), + torch.optim.SGD, + lr=0.1, + momentum=0.1, + grad_clipping=True) + + orig_opt_state = opt.state_dict() + + #reloading the state dict, not performing torch.save + # as it is unnecessary here, the output of torch.load + # is same as what is directly used here. + reloaded_opt.load_state_dict(orig_opt_state) + + self.assertEqual(reloaded_opt['param_groups'], + orig_opt_state['param_groups']) + + self.assertEqual(reloaded_opt['state'], orig_opt_state['state']) + + self.assertEqual(reloaded_opt['base_state'], orig_opt_state['base_state']) + + self.assertEqual(reloaded_opt['shape_info'], orig_opt_state['shape_info']) + def _mp_fn(index): device = xm.xla_device() diff --git a/torch_xla/distributed/zero_redundancy_optimizer.py b/torch_xla/distributed/zero_redundancy_optimizer.py index b76b53ee42c8..fbbe7d33aa0e 100644 --- a/torch_xla/distributed/zero_redundancy_optimizer.py +++ b/torch_xla/distributed/zero_redundancy_optimizer.py @@ -519,6 +519,7 @@ def load_state_dict(self, state_dict): tmp = self.base_optimizer.state_dict() tmp['state'] = base_state + tmp['param_groups'] = state_dict['param_groups'] self.base_optimizer.load_state_dict(tmp) if 'sharded_master_weights' in state_dict: master_weights = state_dict['sharded_master_weights']