Commit
Parity tests for functional optimizer step_param
Pull Request resolved: #61756

DDP will support running the optimizer as a communication hook with optimizers that implement a per-parameter/gradient step function, `step_param`. Add parity tests as we implement more optimizers that support `step_param`, to ensure they match the behavior of the regular optimizers.

ghstack-source-id: 134272866
Differential Revision: [D29727549](https://our.internmc.facebook.com/intern/diff/D29727549/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D29727549/)!
1 parent ad7a87c · commit 481fdd1
Showing 2 changed files with 97 additions and 0 deletions.
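For context, the `step_param` contract exercised below applies an optimizer update to a single parameter/gradient pair, rather than stepping all parameter groups at once; that per-parameter entry point is what lets DDP drive the optimizer from inside a communication hook as gradients become ready. A minimal sketch of per-parameter stepping with `_FunctionalSGD` (the constructor arguments mirror those used in the test below; this is illustrative usage, not the DDP hook wiring itself):

```python
import torch
import torch.nn as nn
from torch.distributed.optim.functional_sgd import _FunctionalSGD

model = nn.Linear(3, 3, bias=False)
# Empty parameter list: parameters are handed to the optimizer one at a
# time via step_param, exactly as the parity test below does.
functional_optim = _FunctionalSGD([], 1e-2, allow_empty_param_list=True)

loss = model(torch.randn(3, 3)).sum()
loss.backward()
for param in model.parameters():
    # Per-parameter update: the hook-friendly counterpart of optim.step().
    functional_optim.step_param(param, param.grad)
```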
@@ -0,0 +1,96 @@
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS

if not IS_WINDOWS:
    from torch.distributed.optim.functional_sgd import _FunctionalSGD
    # Maps each torch.optim optimizer to its functional counterpart.
    _SUPPORTED_OPTIM_MAPPING = {
        SGD: _FunctionalSGD,
    }


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.lin1 = nn.Linear(3, 3, bias=False)
        self.lin2 = nn.Linear(3, 3, bias=False)

    def forward(self, t1):
        return self.lin2(F.relu(self.lin1(t1)))


class TestFunctionalOptimParity(TestCase):
    def _validate_parameters(self, params_1, params_2):
        for p1, p2 in zip(params_1, params_2):
            self.assertEqual(p1, p2)

    def _test_functional_optim_parity(self, optim_cls, *args, **kwargs):
        module_optim = MyModule()
        module_functional = MyModule()
        optim_params = module_optim.parameters()
        functional_params = module_functional.parameters()
        optim = optim_cls(optim_params, *args, **kwargs)
        functional_optim_cls = _SUPPORTED_OPTIM_MAPPING.get(optim_cls, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
        optim_functional = functional_optim_cls([], *args, allow_empty_param_list=True)
        if not hasattr(optim_functional, "step_param"):
            raise ValueError(
                f"Functional optimizer class {optim_functional} must implement step_param method."
            )

        # Initial weights should match.
        self._validate_parameters(
            module_optim.parameters(), module_functional.parameters()
        )
        # Save old parameters to verify the optimizers modify them.
        old_module_optim_params = [
            param.clone().detach() for param in module_optim.parameters()
        ]
        old_module_functional_params = [
            param.clone().detach() for param in module_functional.parameters()
        ]

        t1 = torch.randn(3, 3)
        for _ in range(10):
            module_optim.zero_grad()
            module_functional.zero_grad()
            # Forward + backward.
            optim_out = module_optim(t1).sum()
            functional_out = module_functional(t1).sum()
            optim_out.backward()
            functional_out.backward()
            # Regular optimizer step.
            optim.step()
            # Functional optimizer step_param, one parameter at a time.
            for param in module_functional.parameters():
                grad = param.grad
                optim_functional.step_param(param, grad)

        # Validate parameters are equal.
        for optim_param, functional_param in zip(
            module_optim.parameters(), module_functional.parameters()
        ):
            self.assertEqual(optim_param, functional_param)
        # Validate parameters were modified.
        for i, (optim_param, functional_param) in enumerate(
            zip(module_optim.parameters(), module_functional.parameters())
        ):
            self.assertNotEqual(old_module_optim_params[i], optim_param)
            self.assertNotEqual(old_module_functional_params[i], functional_param)

    @unittest.skipIf(
        IS_WINDOWS,
        "Functional optimizer not supported on Windows, see https://github.com/pytorch/pytorch/issues/62137",
    )
    def test_functional_optim_parity(self):
        self._test_functional_optim_parity(SGD, 1e-2)


if __name__ == "__main__":
    run_tests()