Initialize optimizer in dynamo to avoid graph break and tracing slowness #102640

Closed
wants to merge 13 commits
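
For context, here is a hedged sketch of the kind of workload this PR targets: compiling an optimizer step with `torch.compile`, where tracing `Optimizer._init_group` previously caused graph breaks and long trace times. The model, optimizer choice, and step function below are illustrative assumptions, not code from the PR.

```python
import torch

# Illustrative setup (assumed for this sketch, not taken from the PR):
# a small model whose optimizer step is compiled together with the backward pass.
model = torch.nn.Linear(8, 8)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

@torch.compile
def train_step(x):
    loss = model(x).sum()
    loss.backward()
    opt.step()        # tracing Adam._init_group here was previously slow / graph-breaking
    opt.zero_grad()
    return loss

train_step(torch.randn(4, 8))
```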
3 changes: 0 additions & 3 deletions torch/_dynamo/eval_frame.py
@@ -1272,9 +1272,6 @@ def patch():
)
opt.zero_grad = disable(opt.zero_grad)

if hasattr(opt, "_init_group"):
opt._init_group = disable(opt._init_group)

# disable any currently set hooks
# Note: we only want to disable the profiling hook
# which is the *last* hook applied, we want to keep the no_grad hook
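
As a hedged aside on what the removed lines did: wrapping a method with `torch._dynamo.disable` forces that call to run eagerly and introduces a graph break, which is what the new `OptimizerVariable` path below avoids for `_init_group`. The toy function here is an assumption for illustration only.

```python
import torch

def slow_setup(x):
    # stands in for Optimizer._init_group, which the removed lines used to disable
    return x * 2

# disable() returns a wrapper that dynamo refuses to trace; calling it from
# compiled code falls back to eager execution and splits the graph there.
slow_setup = torch._dynamo.disable(slow_setup)

@torch.compile
def fn(x):
    y = slow_setup(x)  # graph break at this call
    return y + 1

fn(torch.randn(3))
```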
8 changes: 6 additions & 2 deletions torch/_dynamo/utils.py
@@ -819,11 +819,15 @@ def enum_repr(value, local):


def dict_param_key_ids(value):
return {id(k) for k in value.keys() if isinstance(k, torch.nn.Parameter)}
return {
id(k) for k in value.keys() if isinstance(k, (torch.nn.Parameter, torch.Tensor))
}


def dict_const_keys(value):
return {k for k in value.keys() if not isinstance(k, torch.nn.Parameter)}
return {
k for k in value.keys() if not isinstance(k, (torch.nn.Parameter, torch.Tensor))
}


def dict_const_keys_repr(const_keys, *, local):
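
A hedged illustration of why the key check above is widened from `nn.Parameter` to any `Tensor`: optimizers accept plain tensors as parameters, and their `state` dict is keyed by those tensors, so dynamo's dict handling has to recognize tensor keys too. The SGD-with-momentum setup is an assumption for illustration.

```python
import torch

# A plain tensor (not an nn.Parameter) used as an optimizer parameter.
t = torch.randn(4, requires_grad=True)
opt = torch.optim.SGD([t], lr=0.1, momentum=0.9)

t.sum().backward()
opt.step()

# The optimizer state dict is keyed by the tensor itself.
key = next(iter(opt.state.keys()))
print(type(key), isinstance(key, torch.nn.Parameter))  # a Tensor key, not a Parameter
```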
8 changes: 8 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -98,7 +98,9 @@
SkipFilesVariable,
TypingVariable,
)

from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .tensor import (
SymNodeVariable,
TensorVariable,
@@ -578,6 +580,12 @@ def index_source(key):
source=self.source,
guards=make_guards(GuardBuilder.FUNCTION_MATCH),
)
elif isinstance(value, torch.optim.Optimizer):
return OptimizerVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.TYPE_MATCH),
)
elif ProcessGroupVariable.is_process_group(value):
return ProcessGroupVariable(
value,
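
A hedged, simplified sketch of what routing optimizers through a `TYPE_MATCH` guard implies: the compiled code is reused as long as the optimizer object at that source has the same concrete type, rather than requiring the exact same function object as a `FUNCTION_MATCH` would. The helper below is a hypothetical simplification, not dynamo's actual guard machinery.

```python
import torch

def type_match_guard(expected_type):
    """Hypothetical stand-in for a TYPE_MATCH guard: only the concrete type is checked."""
    def check(value):
        return type(value) is expected_type
    return check

params = [torch.nn.Parameter(torch.randn(2))]
guard = type_match_guard(torch.optim.Adam)

print(guard(torch.optim.Adam(params)))          # True: another Adam instance passes
print(guard(torch.optim.SGD(params, lr=0.1)))   # False: a different optimizer type would recompile
```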
120 changes: 120 additions & 0 deletions torch/_dynamo/variables/optimizer.py
@@ -0,0 +1,120 @@
from typing import Dict, List

import torch
from ..source import AttrSource, GetItemSource, GlobalWeakRefSource
from ..utils import global_key_name

from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable


class ArgMappingException(Exception):
pass


class OptimizerVariable(UserDefinedObjectVariable):
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
"""This is an optimization to avoid tracing the very slow intialization of the optimizer"""
if name == "_init_group":
try:
py_args, py_kwargs = self.get_python_args(*args, **kwargs)
self.value._init_group(*py_args, **py_kwargs)
self.install_guards(tx)
self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
return ConstantVariable(None)
except ArgMappingException:
# trace normally if we can't map args
pass

return super().call_method(tx, name, args, kwargs)

def map_grads_to_sources(self):
"""Map the optimizer's grads to their sources"""
self.grad_to_source = {}
for g_ind, group in enumerate(self.value.param_groups):
group_source = GetItemSource(AttrSource(self.source, "param_groups"), g_ind)
for p_ind, p in enumerate(group["params"]):
if p.grad is not None:
self.grad_to_source[p.grad] = AttrSource(
GetItemSource(GetItemSource(group_source, "params"), p_ind),
"grad",
)

def var_getattr(self, tx, name):
if name == "_init_group":
return GetAttrVariable(self, name)

return super().var_getattr(tx, name)

def get_python_args(self, *args, **kwargs):
"""Get python values equivalent to the variable tracker args"""

def map_arg(arg):
if isinstance(arg, ConstantVariable):
return arg.as_python_constant()
elif isinstance(arg, ListVariable) and not arg.items:
return []
elif (
isinstance(arg, ConstDictVariable)
and isinstance(arg.source, GetItemSource)
and isinstance(arg.source.base, AttrSource)
and arg.source.base.member == "param_groups"
):
return self.value.param_groups[arg.source.index]

raise ArgMappingException()

new_args = [map_arg(arg) for arg in args]
new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

return new_args, new_kwargs

def install_guards(self, tx):
from .builder import VariableBuilder

state_dict_var = VariableBuilder(tx, AttrSource(self.source, "state"))(
self.value.state
)
tx.output.guards.update(state_dict_var.guards)

group_guards = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
self.value.param_groups
)
tx.output.guards.update(group_guards.guards)

def wrap_tensor(self, tx, tensor_value):
"""Wrap state tensor in a TensorVariable"""
from .builder import VariableBuilder

# don't add weakref guards for grads, they will possibly change on
# each iteration
if tensor_value in self.grad_to_source:
return VariableBuilder(tx, self.grad_to_source[tensor_value])(tensor_value)
else:
tx.store_dict_key(global_key_name(tensor_value), tensor_value)
return VariableBuilder(
tx, GlobalWeakRefSource(global_key_name(tensor_value))
)(tensor_value)

def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
"""Update the args and kwargs to the traced optimizer call"""
self.map_grads_to_sources()
for arg, py_arg in zip(args, py_args):
if isinstance(arg, ListVariable) and all(
isinstance(t, torch.Tensor) for t in py_arg
):
tensor_vars = ListVariable(
[self.wrap_tensor(tx, t) for t in py_arg],
mutable_local=MutableLocal(),
)
arg.call_method(tx, "extend", (tensor_vars,), {})
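
To close, a hedged sketch of how the new path behaves end to end: when the compiled step reaches `optimizer._init_group`, `OptimizerVariable.call_method` runs it eagerly, installs guards on `state` and `param_groups`, and fills the traced list arguments with wrapped tensors, so tracing continues without a graph break. The setup below is an assumption for illustration, not a test from the PR.

```python
import torch

model = torch.nn.Linear(64, 64)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Give every parameter a gradient so _init_group has real state to set up.
for p in model.parameters():
    p.grad = torch.randn_like(p)

@torch.compile
def compiled_step():
    opt.step()

compiled_step()  # _init_group runs eagerly via OptimizerVariable, then tracing continues
compiled_step()  # later calls reuse the compiled step under the installed guards
```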