Disable dynamo on some opt methods and differentiable optimizer tests (#103066)

- Disables dynamo on the differentiable optimizer tests
- Disables dynamo on some test methods which expose a very rare dynamo edge case
- Disables dynamo on the optimizer state export/load methods (state_dict, load_state_dict, add_param_group), since dynamo shouldn't trace those anyway.

I have a draft PR to fix the two tests marked skip due to unsupported mutation of step.
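
For context, a minimal sketch of the test-side skip pattern this PR applies (hypothetical class and test names; it assumes PyTorch's internal test utilities TestCase, run_tests, and skipIfTorchDynamo from torch.testing._internal.common_utils are importable, as they are in test_optim.py):

import torch
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo

# Class-level skip: every test in the class is skipped when the suite runs
# under dynamo, mirroring the decorator added to TestDifferentiableOptimizer below.
@skipIfTorchDynamo("Differentiable optimizers not supported")
class ExampleDifferentiableTests(TestCase):  # hypothetical class
    def test_sgd_like(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        self.assertEqual(p.numel(), 10)

class ExampleOptimTests(TestCase):  # hypothetical class
    # Method-level skip with a reason string, mirroring the per-test skips added below.
    @skipIfTorchDynamo("Unsupported mutation of step")
    def test_step_mutation_like(self):
        self.assertTrue(True)

if __name__ == "__main__":
    run_tests()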

Pull Request resolved: #103066
Approved by: https://github.com/janeyx99, https://github.com/malfet
mlazos authored and pytorchmergebot committed Jun 7, 2023
1 parent f760899 commit 0769a50
Showing 2 changed files with 29 additions and 2 deletions.
28 changes: 26 additions & 2 deletions test/optim/test_optim.py
@@ -26,6 +26,9 @@
skipIfRocm,
skipIfTorchDynamo
)

from torch._dynamo import disable as disable_dynamo

from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_device_type import largeTensorTest
from typing import Dict, Any, Tuple
@@ -190,12 +193,23 @@ def fn():
else:
self.assertLess(fn().item(), initial_value)

# Note: disable dynamo on this function.
# This lets us keep running the actual optimizer logic of these tests under
# dynamo without tracing this test scaffolding, which has a lot of unsupported
# behavior.
@disable_dynamo(recursive=False)
def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None):
weight = Parameter(weight)
bias = Parameter(bias)
with torch.no_grad():
input = input.clone().detach().requires_grad_()

# Note: Disable dynamo on this function.
# This avoids a bug where input_cuda is not detected because it is not yet
# defined in the enclosing local environment. We have been unable to repro this
# anywhere else, and this is test code that we don't need to spend time getting
# dynamo to trace unless the issue repros in real models.
@disable_dynamo(recursive=False)
def fn_base(optimizer, weight, bias):
optimizer.zero_grad()
i = input_cuda if weight.is_cuda else input
@@ -219,7 +233,7 @@ def fn_base(optimizer, weight, bias):
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict_c)
# Run both optimizations in parallel
# Run both optimizers in parallel
for _ in range(20):
optimizer.step(fn)
optimizer_c.step(fn_c)
@@ -1072,6 +1086,7 @@ def test_sparse_adam(self):
optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}])

# ROCm precision is too low to pass this test
@skipIfTorchDynamo("Unsupported mutation of step")
def test_adadelta(self):
# Handles https://github.com/pytorch/pytorch/issues/69698
self.rel_tol = 4e-3
@@ -1114,6 +1129,7 @@ def test_adadelta(self):
with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
optim.Adadelta(None, lr=1e-2, rho=1.1)

@skipIfTorchDynamo("Unsupported mutation of step")
def test_adadelta_complex(self):
# Handles https://github.com/pytorch/pytorch/issues/69698
self.rel_tol = 2e-2
@@ -1322,6 +1338,7 @@ def test_radam(self):
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
optim.RAdam(None, lr=1e-2, weight_decay=-1)

@skipIfTorchDynamo("Unsupported mutation of step")
def test_rmsprop(self):
for foreach in (False, True):
self._test_basic_cases(
@@ -1782,7 +1799,9 @@ def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
)


@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):

def test_sgd(self):
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1800,6 +1819,7 @@ def test_sgd(self):
),
)


def test_adam(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1825,6 +1845,7 @@ def test_adam(self):
),
)


def test_rmsprop(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1857,6 +1878,7 @@ def test_rmsprop(self):
),
)


def test_adadelta(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1878,6 +1900,7 @@ def test_adadelta(self):
),
)


def test_adagrad(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1898,6 +1921,7 @@ def test_adagrad(self):
),
)


def test_adamax(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1919,6 +1943,7 @@ def test_adamax(self):
),
)


@skipIfTorchDynamo("The inplace mu update fails with dynamo, "
"since this is only happening when differentiable is enabled, skipping for now")
def test_asgd(self):
@@ -1944,7 +1969,6 @@ def test_asgd(self):
),
)

@skipIfTorchDynamo()
def test_rprop(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
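
The recursive=False notes in the hunks above rely on a specific behavior: torch._dynamo.disable(recursive=False) skips tracing of only the decorated frame, while functions reached from it remain eligible for compilation. A minimal standalone sketch of that behavior (an illustration under that assumption, not code from this PR):

import torch
from torch._dynamo import disable as disable_dynamo

@torch.compile
def inner(x):
    # Still compiled by dynamo when called, even from a skipped frame.
    return x.sin() + 1

@disable_dynamo(recursive=False)
def outer(x):
    # This frame itself is not traced (it could contain unsupported constructs),
    # but the call into inner below can still be compiled.
    return inner(x) * 2

print(outer(torch.randn(3)))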
3 changes: 3 additions & 0 deletions torch/_dynamo/eval_frame.py
@@ -1274,6 +1274,9 @@ def patch():
opt.step = disable(opt.step)

opt.zero_grad = disable(opt.zero_grad)
opt.state_dict = disable(opt.state_dict)
opt.load_state_dict = disable(opt.load_state_dict)
opt.add_param_group = disable(opt.add_param_group)

# disable any currently set hooks
# Note: we only want to disable the profiling hook
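
To illustrate the eval_frame.py change above: wrapping an optimizer's state export/load methods with torch._dynamo.disable means dynamo never attempts to trace them; they simply run eagerly. A hedged sketch of the same idea applied manually to one optimizer instance (the PR does this centrally in the patch() helper over the optimizer classes; the training snippet here is only an illustration):

import torch
from torch._dynamo import disable

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Mirror the patch on a single instance: these methods should never be traced.
opt.state_dict = disable(opt.state_dict)
opt.load_state_dict = disable(opt.load_state_dict)
opt.add_param_group = disable(opt.add_param_group)

loss = model(torch.randn(8, 4)).sum()
loss.backward()
opt.step()

snapshot = opt.state_dict()    # runs eagerly, never traced by dynamo
opt.load_state_dict(snapshot)  # likewise never traced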
