Initialize optimizer in dynamo to avoid graph break and tracing slowness #102640

Closed
wants to merge 13 commits
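
For orientation, here is a minimal sketch of the kind of workload this change targets (the toy Linear model, Adam optimizer, and tensor shapes are illustrative and not taken from the PR): compiling an optimizer step so that _init_group no longer has to be skipped with a graph break.

import torch

model = torch.nn.Linear(16, 16)
opt = torch.optim.Adam(model.parameters())

# Produce gradients eagerly, then compile just the optimizer step.
model(torch.randn(8, 16)).sum().backward()

@torch.compile
def step():
    opt.step()

step()  # with this PR, _init_group is handled inside dynamo instead of via disable()
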
6 changes: 0 additions & 6 deletions torch/_dynamo/eval_frame.py
@@ -1280,14 +1280,8 @@ def patch():
if opt in excluded_opts:
opt.step = disable(opt.step)

opt._cuda_graph_capture_health_check = disable(
opt._cuda_graph_capture_health_check
)
opt.zero_grad = disable(opt.zero_grad)

if hasattr(opt, "_init_group"):
opt._init_group = disable(opt._init_group)

# disable any currently set hooks
# Note: we only want to disable the profiling hook
# which is the *last* hook applied, we want to keep the no_grad hook
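
For readers less familiar with the helper used in the lines removed above: torch._dynamo.disable wraps a callable so that dynamo skips compiling it and breaks the graph at its call sites, which is what previously happened around _init_group. A minimal sketch of that behavior, using a made-up helper function:

import torch
import torch._dynamo

def helper(x):
    return x + 1

# Wrapping with disable() tells dynamo to skip this callable entirely,
# introducing a graph break wherever compiled code calls it.
helper = torch._dynamo.disable(helper)

@torch.compile
def f(x):
    return helper(x) * 2  # graph break at the disabled call

f(torch.randn(4))
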
8 changes: 8 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -98,7 +98,9 @@
SkipFilesVariable,
TypingVariable,
)

from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .tensor import (
SymNodeVariable,
TensorVariable,
@@ -578,6 +580,12 @@ def index_source(key):
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif isinstance(value, torch.optim.Optimizer):
return OptimizerVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
)
elif ProcessGroupVariable.is_process_group(value):
return ProcessGroupVariable(
value,
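
As a quick, hypothetical illustration (not code from the PR) of the new branch above: any torch.optim.Optimizer instance the builder encounters is now wrapped as an OptimizerVariable guarded by a type check, rather than (presumably) falling through to the generic object handling further down.

# Simplified stand-in for the dispatch added in builder.py; the real
# VariableBuilder does far more, this only mirrors the isinstance check.
import torch

def wrap(value):
    if isinstance(value, torch.optim.Optimizer):
        return "OptimizerVariable"        # new branch from this PR
    return "UserDefinedObjectVariable"    # assumed generic fallback

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(3))], lr=0.1)
print(wrap(opt))        # OptimizerVariable
print(wrap(object()))   # UserDefinedObjectVariable
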
99 changes: 99 additions & 0 deletions torch/_dynamo/variables/optimizer.py
@@ -0,0 +1,99 @@
from typing import Dict, List

import torch
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name

from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


class ArgMappingException(Exception):
pass


class OptimizerVariable(UserDefinedObjectVariable):
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
"""This is an optimization to avoid tracing the very slow intialization of the optimizer"""
if name == "_init_group":
try:
py_args, py_kwargs = self.get_python_args(*args, **kwargs)
self.value._init_group(*py_args, **py_kwargs)
self.install_guards(tx)
self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
return ConstantVariable(None)
except ArgMappingException:
# trace normally if we can't map args
pass

return super().call_method(tx, name, args, kwargs)

def var_getattr(self, tx, name):
if name == "_init_group":
return GetAttrVariable(self, name)

return super().var_getattr(tx, name)

def get_python_args(self, *args, **kwargs):
"""Get python values equivalent to the variable tracker args"""

def map_arg(arg):
if isinstance(arg, ConstantVariable):
return arg.as_python_constant()
elif isinstance(arg, ListVariable) and not arg.items:
return []
elif isinstance(arg, ConstDictVariable) and isinstance(
arg.source, GetItemSource
):
return self.value.param_groups[arg.source.index]

raise ArgMappingException()

new_args = [map_arg(arg) for arg in args]
new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

return new_args, new_kwargs

def install_guards(self, tx):
from .builder import VariableBuilder

state_dict_var = VariableBuilder(tx, AttrSource(self.source, "state"))(
self.value.state
)
tx.output.guards.update(state_dict_var.guards)

group_guards = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups
)
tx.output.guards.update(group_guards.guards)

def wrap_tensor(self, tx, tensor_value):
"""Wrap state tensor in a TensorVariable"""
from .builder import VariableBuilder

tx.store_dict_key(global_key_name(tensor_value), tensor_value)
return VariableBuilder(tx, GlobalWeakRefSource(global_key_name(tensor_value)))(
tensor_value
)

def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
"""Update the args and kwargs to the traced optimizer call"""
for arg, py_arg in zip(args, py_args):
if isinstance(arg, ListVariable) and all(
isinstance(t, torch.Tensor) for t in py_arg
):
tensor_vars = ListVariable(
[self.wrap_tensor(tx, t) for t in py_arg],
mutable_local=MutableLocal(),
)
arg.call_method(tx, "extend", (tensor_vars,), {})
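
Finally, a hypothetical smoke test (not part of the PR) exercising the fast path defined above: the first compiled call should run _init_group eagerly through OptimizerVariable.call_method, install guards on opt.state and opt.param_groups, and back-fill the traced list arguments with wrapped state tensors; later calls only re-check those guards.

import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
opt = torch.optim.Adam(params, lr=1e-3)

# Give each parameter a gradient so step() has work to do.
for p in params:
    p.grad = torch.randn_like(p)

compiled_step = torch.compile(opt.step)
compiled_step()  # _init_group executed eagerly via the call_method fast path
compiled_step()  # subsequent calls re-use the guards on state/param_groups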