93 changes: 93 additions & 0 deletions test/inductor/test_compiled_optimizers.py
@@ -0,0 +1,93 @@
# Owner(s): ["module: inductor"]

import sys
import unittest

from copy import deepcopy

import torch

import torch._inductor

from torch.testing._internal.common_utils import TestCase


aten = torch.ops.aten

try:
try:
from .test_torchinductor import check_model, check_model_cuda, requires_cuda
except ImportError:
from test_torchinductor import check_model, check_model_cuda, requires_cuda
except (unittest.SkipTest, ImportError) as e:
sys.stderr.write(f"{type(e)}: {e}\n")
if __name__ == "__main__":
sys.exit(0)
raise


def make_test(optim_cls, closure=None, **kwargs):
@requires_cuda()
def test_fn(self):
input = torch.ones([10, 10], device="cuda:0")
model_eager = torch.nn.Sequential(
*[torch.nn.Linear(10, 10, device="cuda:0") for _ in range(2)]
)
model_eager(input).sum().backward()

input = torch.ones([10, 10], device="cuda:0")
model_compiled = deepcopy(model_eager)
model_compiled(input).sum().backward()

opt_eager = optim_cls(model_eager.parameters(), **kwargs)
opt_compiled = optim_cls(model_compiled.parameters(), **kwargs)
# run the patcher so that step has the expected structure
torch._dynamo.eval_frame.TorchPatcher.patch()

# unwrap step to avoid a deliberate graph break due to
# a limitation of functionalization/no_grad detection
# see the [Note on graph break] in optimizer.py
# This ignores the outer _use_grad_if_differentiable wrapper
# and instead manually disables grad before calling step, which is fine
# for now as dynamo does not support differentiable optimizers anyway
step_fn = opt_compiled.step.__wrapped__
if closure is not None:

def fn():
step_fn(opt_compiled, closure)

else:

def fn():
step_fn(opt_compiled)

with torch.set_grad_enabled(False):
torch.compile(fn, backend="inductor", fullgraph=True)()
opt_eager.step()

self.assertEqual(
list(model_eager.parameters()), list(model_compiled.parameters())
)
if self.check_kernel_count:
# currently, we compile the step and the rest of the computation
# separately because the step is a single element tensor
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

return test_fn


class CompiledOptimizerTests(TestCase):
check_model_cuda = check_model_cuda
check_model_cpu = check_model
check_kernel_count = True

def setUp(self):
super().setUp()
torch._inductor.metrics.reset()

def tearDown(self):
super().tearDown()
torch._inductor.metrics.reset()

test_adam = make_test(torch.optim.Adam, lr=0.01)
test_adam_weight_decay = make_test(torch.optim.Adam, lr=0.01, weight_decay=0.01)
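
The test above boils down to the following pattern, shown here as a condensed sketch (toy model and hyperparameters are illustrative; like the test, it assumes a CUDA device):

import torch

# One backward pass so the optimizer has gradients to consume.
model = torch.nn.Linear(10, 10, device="cuda")
model(torch.ones(10, 10, device="cuda")).sum().backward()

opt = torch.optim.Adam(model.parameters(), lr=0.01)

# As in the test, run the patcher and unwrap step to bypass the
# _use_grad_if_differentiable wrapper, which would otherwise introduce
# a deliberate graph break.
torch._dynamo.eval_frame.TorchPatcher.patch()
step_fn = opt.step.__wrapped__

with torch.no_grad():
    torch.compile(lambda: step_fn(opt), backend="inductor", fullgraph=True)()
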
2 changes: 0 additions & 2 deletions torch/_dynamo/eval_frame.py
@@ -1242,7 +1242,6 @@ def patch():
from ..optim import (
adadelta,
adagrad,
adam,
adamax,
adamw,
asgd,
@@ -1255,7 +1254,6 @@ def patch():
for opt_mod in (
adadelta,
adagrad,
adam,
adamax,
adamw,
asgd,
5 changes: 5 additions & 0 deletions torch/_dynamo/variables/optimizer.py
@@ -25,6 +25,11 @@ class GuardInstallException(Exception):
class OptimizerVariable(UserDefinedObjectVariable):
def __init__(self, value, grad_to_source=None, **kwargs):
super().__init__(value, **kwargs)

for group in self.value.param_groups:
if "capturable" in group:
group["capturable"] = True

if grad_to_source is None:
self.grad_to_source = {}

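
Forcing capturable=True while tracing matters because the flag changes where Adam keeps its step counter. A small eager-mode illustration of that effect (not part of the PR; assumes a CUDA device):

import torch

model = torch.nn.Linear(4, 4, device="cuda")
model(torch.randn(2, 4, device="cuda")).sum().backward()

# With capturable=True, Adam stores `step` as a float32 tensor on the same
# device as the parameters instead of a CPU scalar, so the whole update,
# including the step increment, can stay inside one captured/compiled graph.
opt = torch.optim.Adam(model.parameters(), lr=0.01, capturable=True)
opt.step()

step = opt.state[next(model.parameters())]["step"]
print(step.device)  # cuda:0 here; it would be cpu with the default capturable=False
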
10 changes: 7 additions & 3 deletions torch/optim/adam.py
@@ -279,7 +279,9 @@ def adam(params: List[Tensor],
if foreach is None:
foreach = False

if not all(isinstance(t, torch.Tensor) for t in state_steps):
# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

if foreach and torch.jit.is_scripting():
@@ -339,7 +341,8 @@ def _single_tensor_adam(params: List[Tensor],
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]

if capturable:
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors."

# update step
@@ -428,7 +431,8 @@ def _multi_tensor_adam(params: List[Tensor],
if len(params) == 0:
return

if capturable:
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
"If capturable=True, params and state_steps must be CUDA tensors."

8 changes: 5 additions & 3 deletions torch/optim/adamw.py
@@ -304,7 +304,7 @@ def adamw(
See :class:`~torch.optim.AdamW` for details.
"""

if not all(isinstance(t, torch.Tensor) for t in state_steps):
if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
@@ -382,7 +382,8 @@ def _single_tensor_adamw(
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]

if capturable:
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
assert (
param.is_cuda and step_t.is_cuda
), "If capturable=True, params and state_steps must be CUDA tensors."
@@ -479,7 +480,8 @@ def _multi_tensor_adamw(
if len(params) == 0:
return

if capturable:
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if not torch._utils.is_compiling() and capturable:
assert all(
p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
), "If capturable=True, params and state_steps must be CUDA tensors."
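
The adam.py and adamw.py hunks above all apply the same pattern: eager-only validation is gated on torch._utils.is_compiling() so dynamo does not spend time tracing it, and the compiler stack takes over that responsibility. A standalone sketch of the pattern on a made-up function:

import torch

def scale(xs, capturable=False):
    # Eager-only sanity check, mirroring the optimizer changes: skipped while
    # dynamo is tracing because is_compiling() is True at trace time.
    if not torch._utils.is_compiling() and capturable:
        assert all(x.is_cuda for x in xs), "capturable=True requires CUDA tensors"
    return [x * 2 for x in xs]

xs = [torch.ones(2), torch.ones(2)]
scale(xs)                                  # eager call: the guard is evaluated
torch.compile(scale, fullgraph=True)(xs)   # compiled call: the guard folds away
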
23 changes: 17 additions & 6 deletions torch/optim/optimizer.py
@@ -252,9 +252,15 @@ def __repr__(self):

# Currently needed by Adam and AdamW
def _cuda_graph_capture_health_check(self):
# If we are compiling, we take the capturable path automatically
# One caveat here is that if we are compiling, we *permit* step/param tensors to be on CPU
# so we do not explicitly enable the capturable flag. Inductor will decide whether cudagraphs
# Note [torch.compile x capturable]
# If we are compiling, we try to take the capturable path automatically by
# setting the flag to True during tracing. Due to this, we skip all the checks
# normally required for determining whether we can use CUDA graphs and
# shunt the responsibility to torch.inductor. This saves time during tracing
# since the checks are slow without sacrificing UX since inductor will warn
# later if CUDA graphs cannot be enabled, e.g.,
# https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
# Thus, when compiling, inductor will determine if cudagraphs
# can be enabled based on whether there is input mutation or CPU tensors.
if not is_compiling() and torch.backends.cuda.is_built() and torch.cuda.is_available():
capturing = torch.cuda.is_current_stream_capturing()
@@ -422,11 +428,16 @@ def _process_value_according_to_param_policy(param: Tensor, value: Tensor, param
capturable = pg["capturable"] if "capturable" in pg else False
break

if key != "step" or capturable or fused:
if key == 'step':
if capturable or fused:
return value.to(dtype=torch.float32, device=param.device)
else:
return value
else:
if param.is_floating_point():
return value.to(dtype=param.dtype, device=param.device)
return value.to(device=param.device)
return value
else:
return value.to(device=param.device)

def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
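
Finally, because the interleaved old and new lines in the last hunk are hard to read without diff markers, here is a paraphrase of what the rewritten step-casting branch in _process_value_according_to_param_policy amounts to after this change (hypothetical helper name, simplified signature):

import torch

def _cast_state_value(param, value, key, capturable=False, fused=False):
    if key == "step":
        if capturable or fused:
            # capturable/fused host the step counter as float32 on the param's device
            return value.to(dtype=torch.float32, device=param.device)
        return value  # otherwise `step` is left untouched (a CPU scalar tensor)
    # every other state entry follows the param's device, and its dtype too
    # when the entry is floating point
    if param.is_floating_point():
        return value.to(dtype=param.dtype, device=param.device)
    return value.to(device=param.device)
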