Disable dynamo on some opt methods and differentiable optimizer tests #103066

Closed
wants to merge 5 commits into from
Changes from 3 commits
28 changes: 26 additions & 2 deletions test/optim/test_optim.py
@@ -27,6 +27,9 @@
skipIfRocm,
skipIfTorchDynamo
)

from torch._dynamo import disable

from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_device_type import largeTensorTest
from typing import Dict, Any, Tuple
@@ -191,12 +194,23 @@ def fn():
else:
self.assertLess(fn().item(), initial_value)

# Note: disable dynamo on this function
# This lets us keep running the actual optimizer test logic under dynamo
# without tracing this test harness code, which has a lot of unsupported
# behavior
@disable(recursive=False)
def _test_state_dict(self, weight, bias, input, constructor, atol=None, rtol=None):
weight = Parameter(weight)
bias = Parameter(bias)
with torch.no_grad():
input = input.clone().detach().requires_grad_()

# Note: Disable dynamo on this function
# This avoids a bug where input_cuda is not detected in the environment
# because it is currently not defined in the local environment. We have been
# unable to repro this anywhere else, and this is test code that we don't need
# to spend time getting dynamo to trace unless the issue repros in real models.
@disable(recursive=False)
def fn_base(optimizer, weight, bias):
optimizer.zero_grad()
i = input_cuda if weight.is_cuda else input
@@ -220,7 +234,7 @@ def fn_base(optimizer, weight, bias):
state_dict = deepcopy(optimizer.state_dict())
state_dict_c = deepcopy(optimizer.state_dict())
optimizer_c.load_state_dict(state_dict_c)
# Run both optimizations in parallel
# Run both optimizers in parallel
for _ in range(20):
optimizer.step(fn)
optimizer_c.step(fn_c)
@@ -1073,6 +1087,7 @@ def test_sparse_adam(self):
optim.SparseAdam([{"params": [torch.zeros(3, layout=torch.sparse_coo)]}])

# ROCm precision is too low to pass this test
@skipIfTorchDynamo("Unsupported mutation of step")
def test_adadelta(self):
# Handles https://github.com/pytorch/pytorch/issues/69698
self.rel_tol = 4e-3
@@ -1115,6 +1130,7 @@ def test_adadelta(self):
with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
optim.Adadelta(None, lr=1e-2, rho=1.1)

@skipIfTorchDynamo("Unsupported mutation of step")
def test_adadelta_complex(self):
# Handles https://github.com/pytorch/pytorch/issues/69698
self.rel_tol = 2e-2
@@ -1323,6 +1339,7 @@ def test_radam(self):
with self.assertRaisesRegex(ValueError, "Invalid weight_decay value: -1"):
optim.RAdam(None, lr=1e-2, weight_decay=-1)

@skipIfTorchDynamo("Unsupported mutation of step")
def test_rmsprop(self):
for foreach in (False, True):
self._test_basic_cases(
@@ -1783,7 +1800,9 @@ def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
)


@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):

def test_sgd(self):
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1801,6 +1820,7 @@ def test_sgd(self):
),
)


def test_adam(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1826,6 +1846,7 @@ def test_adam(self):
),
)


def test_rmsprop(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1858,6 +1879,7 @@ def test_rmsprop(self):
),
)


def test_adadelta(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1879,6 +1901,7 @@ def test_adadelta(self):
),
)


def test_adagrad(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1899,6 +1922,7 @@ def test_adagrad(self):
),
)


def test_adamax(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
@@ -1920,6 +1944,7 @@ def test_adamax(self):
),
)


@skipIfTorchDynamo("The inplace mu update fails with dynamo, "
"since this is only happening when differentiable is enabled, skipping for now")
def test_asgd(self):
@@ -1945,7 +1970,6 @@ def test_asgd(self):
),
)

@skipIfTorchDynamo()
def test_rprop(self):
state = {}
p = torch.rand(10, requires_grad=True, dtype=torch.float64)
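The test changes above rely on two different mechanisms: `@disable(recursive=False)` skips dynamo for that one frame only, so code called from inside it can still be traced, while `@skipIfTorchDynamo(...)` skips a test (or every test in a decorated class) entirely when the suite runs under dynamo. The sketch below is illustrative only and is not part of this PR; the helper and test-class names are hypothetical.

```python
# Illustrative sketch (not part of this PR) of the two mechanisms used above.
# `helper`, `inner`, and `ToyOptimTests` are hypothetical names.
import torch
from torch._dynamo import disable
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo


@disable(recursive=False)
def helper(x):
    # This frame is skipped by dynamo, but `inner` can still be traced
    # because recursive=False only disables the decorated frame itself.
    return inner(x) + 1


def inner(x):
    return x * 2


@skipIfTorchDynamo("Differentiable optimizers not supported")
class ToyOptimTests(TestCase):
    # Every test in this class is skipped when the suite runs under dynamo.

    @skipIfTorchDynamo("Unsupported mutation of step")
    def test_step(self):
        # A method-level skip works the same way for a single test.
        self.assertEqual(helper(torch.ones(2)).sum().item(), 6.0)


if __name__ == "__main__":
    run_tests()
```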
3 changes: 3 additions & 0 deletions torch/_dynamo/eval_frame.py
@@ -1274,6 +1274,9 @@ def patch():
opt.step = disable(opt.step)

opt.zero_grad = disable(opt.zero_grad)
opt.state_dict = disable(opt.state_dict)
opt.load_state_dict = disable(opt.load_state_dict)
opt.add_param_group = disable(opt.add_param_group)

# disable any currently set hooks
# Note: we only want to disable the profiling hook
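For reference, the eval_frame.py hunk extends the existing optimizer patching so that `state_dict`, `load_state_dict`, and `add_param_group` are wrapped in `disable`, just as `step` and `zero_grad` already were. The snippet below is a simplified, standalone sketch of that wrapping idea, not the actual `patch()` implementation; `patch_optimizer_methods` is a hypothetical helper name.

```python
# Simplified sketch of the wrapping idea from the hunk above; this is not the
# real torch/_dynamo/eval_frame.py code. `patch_optimizer_methods` is hypothetical.
import torch
from torch._dynamo import disable


def patch_optimizer_methods(opt: torch.optim.Optimizer) -> None:
    # Once wrapped, calling these methods from compiled code graph-breaks and
    # runs them eagerly instead of having dynamo trace their Python bodies.
    opt.step = disable(opt.step)
    opt.zero_grad = disable(opt.zero_grad)
    opt.state_dict = disable(opt.state_dict)
    opt.load_state_dict = disable(opt.load_state_dict)
    opt.add_param_group = disable(opt.add_param_group)


model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
patch_optimizer_methods(optimizer)

# These calls now run outside of any dynamo tracing.
sd = optimizer.state_dict()
optimizer.load_state_dict(sd)
optimizer.add_param_group({"params": [torch.nn.Parameter(torch.zeros(2))]})
```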