
Commit

Update
[ghstack-poisoned]
yf225 committed Jun 19, 2024
2 parents bd76da8 + 77a90c2 commit 1bfdcac
Showing 52 changed files with 165 additions and 187 deletions.
10 changes: 6 additions & 4 deletions test/distributed/test_dynamo_distributed.py
@@ -1084,12 +1084,14 @@ def _(ctx):
# far from an exhaustive check of all the expected guards, just check a couple of them.
FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
"""local "L['self']" ID_MATCH"""
).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check(
f"""{expected_guard_source} "L['self'].net" ID_MATCH"""
).check(
f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH"""
f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
).check(
f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH"""
f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH"""
).run(
GUARDS_FILE.getvalue()
)
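Note on the guard strings above: the expected guards move from attribute syntax (L['self'].net, L['self'].net[0]) to the _modules dict form, which matches the unspecialized nn.Module handling elsewhere in this commit. A minimal eager-mode sketch of why _modules is the underlying storage (assumes only that torch is installed; Outer is a made-up module, not part of the test):

```python
import torch

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU())

m = Outer()
assert "net" not in m.__dict__                       # submodules are not plain instance attributes
assert m.net is m._modules["net"]                    # attribute access resolves through _modules
assert m.net[0] is m._modules["net"]._modules["0"]   # nested modules keep nested _modules dicts
```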
16 changes: 8 additions & 8 deletions test/dynamo/test_higher_order_ops.py
@@ -5111,10 +5111,10 @@ def wrapper_fn(x):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"):
l_self_tensor_constant0 = L_self_tensor_constant0
def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"):
l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_
alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None
alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None
sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default)
@@ -5133,16 +5133,16 @@ def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_
getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_
def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_
l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_
l_flat_tangents_1_ = L_flat_tangents_1_
_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None
_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None
copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None
mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None
mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None
return (mul_tensor,)
""",
)
33 empty files (file names not rendered in this view).
1 change: 1 addition & 0 deletions test/profiler/test_profiler.py
@@ -2411,6 +2411,7 @@ def test_profiler_matmul_dim_fp16_pattern(self):
num_matched.append(len(pattern.matched_events()))
self.assertEqual(num_matched, [i for i, _ in cases])

@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_profiler_pattern_matcher_json_report(self):
x = torch.ones((100, 100))
model = nn.Sequential(
20 changes: 20 additions & 0 deletions torch/_dynamo/create_parameter_op.py
@@ -1,4 +1,7 @@
# mypy: allow-untyped-defs
import threading
from contextlib import contextmanager

import torch

doc = """
@@ -56,3 +59,20 @@ def new_parameter_placeholder(size, dtype, device, requires_grad):
# Allocating a zero tensor would cause assert failures in autograd.
result.untyped_storage().resize_(0)
return result


_TLS = threading.local()


@contextmanager
def do_not_convert_to_tracable_parameter():
old_flag = getattr(_TLS, "convert_tracable_parameter", True)
_TLS.convert_tracable_parameter = False
try:
yield False
finally:
_TLS.convert_tracable_parameter = old_flag


def can_convert_to_tracable_parameter():
return getattr(_TLS, "convert_tracable_parameter", True)
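
The helpers added above form a thread-local switch that other parts of this commit use to temporarily disable conversion to traceable parameters (see do_not_convert_to_tracable_parameter() in variables/misc.py and the can_convert_to_tracable_parameter() check in variables/torch.py below). A minimal standalone sketch of the pattern, with the two helpers copied in shape from the diff so it runs without importing torch._dynamo:

```python
import threading
from contextlib import contextmanager

_TLS = threading.local()

@contextmanager
def do_not_convert_to_tracable_parameter():
    # Save, clear, and always restore the per-thread flag.
    old_flag = getattr(_TLS, "convert_tracable_parameter", True)
    _TLS.convert_tracable_parameter = False
    try:
        yield False
    finally:
        _TLS.convert_tracable_parameter = old_flag

def can_convert_to_tracable_parameter():
    # Defaults to True when the flag has never been set on this thread.
    return getattr(_TLS, "convert_tracable_parameter", True)

# The flag is scoped to the `with` block and to the current thread.
assert can_convert_to_tracable_parameter()
with do_not_convert_to_tracable_parameter():
    assert not can_convert_to_tracable_parameter()
assert can_convert_to_tracable_parameter()
```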
3 changes: 3 additions & 0 deletions torch/_dynamo/mutation_guard.py
@@ -11,6 +11,9 @@
from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks


unpatched_nn_module_init = torch.nn.Module.__init__


class MutationTracker:
db = ExactWeakKeyDictionary()

32 changes: 20 additions & 12 deletions torch/_dynamo/side_effects.py
@@ -387,13 +387,7 @@ def codegen_save_tempvars(self, cg: PyCodegen):
elif isinstance(var.mutable_local, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
if "__call_nn_module_init" in self.store_attr_mutations.get(
var.mutable_local, {}
):
assert isinstance(var, variables.UnspecializedNNModuleVariable)
cg.load_import_from(utils.__name__, "nn_module_new")
else:
cg.load_import_from(utils.__name__, "object_new")
cg.load_import_from(utils.__name__, "object_new")
cg(var.mutable_local.cls_source)
cg.extend_output(create_call_function(1, True))
cg.add_cache(var)
@@ -562,18 +556,32 @@ def codegen_update_mutated(self, cg: PyCodegen):
]
)
elif self.is_attribute_mutation(var):
for name, value in self.store_attr_mutations.get(
var.mutable_local, {}
).items():
# Applying mutations involves two steps: 1) Push all
# reconstructed objects onto the stack. 2) Call STORE_ATTR to
# apply the mutations.
#
# Dynamo must ensure that mutations are applied in the same
# order as in the original program. Therefore, two reverse
# operations occur below.
#
# The first reverse operation concerns `suffixes`. We apply
# suffixes in reverse order due to the way Python handles the
# stack. In Step 1, we push all reconstructed objects onto the
# stack, but the item at the top of the stack refers to the last
# attribute in the mutation order. If not fixed, this will apply
# the mutations of attributes in the reverse order. To account
# for this reversal, we iterate through the mutable attributes
# in reverse order.
for name, value in reversed(
self.store_attr_mutations.get(var.mutable_local, {}).items()
):
if isinstance(var, variables.NewGlobalVariable):
cg.tx.output.update_co_names(name)
cg(value)
assert isinstance(var.mutable_local.source, GlobalSource) # type: ignore[attr-defined]
suffixes.append(
[create_instruction("STORE_GLOBAL", argval=name)]
)
elif name == "__call_nn_module_init":
pass # handled in codegen_save_tempvars
elif isinstance(value, variables.DeletedVariable):
if isinstance(
var.mutable_local, AttributeMutationExisting
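The new comment above spells out only the first of the two reversals it mentions; the essential point is that values are pushed in iteration order but consumed last-in-first-out, so iterating the recorded mutations in reverse makes them replay in program order. A toy model of that argument (plain Python, not Dynamo's codegen):

```python
mutations = [("a", 1), ("b", 2), ("c", 3)]   # attribute mutations in traced-program order

stack = []
for name, value in reversed(mutations):      # Step 1: push reconstructed values in reverse order
    stack.append((name, value))

applied = []
while stack:                                 # Step 2: each STORE pops from the top (LIFO)
    applied.append(stack.pop())

assert applied == mutations                  # the two reversals cancel: program order is preserved
```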
19 changes: 14 additions & 5 deletions torch/_dynamo/symbolic_convert.py
@@ -102,7 +102,7 @@
PythonModuleVariable,
UnknownVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable
from .variables.user_defined import (
RemovableHandleVariable,
@@ -415,13 +415,22 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
if push:
self.push(value)
self.jump(inst)
elif isinstance(value, UnspecializedNNModuleVariable):
mod = value.value
if truth_fn(mod):
if push:
self.push(value)
self.jump(inst)
elif isinstance(value, UserDefinedObjectVariable):
x = None
if hasattr(value, "__bool__"):
try:
x = value.var_getattr(self, "__bool__")
# if __bool__ is missing, trying __len__ to infer a truth value.
if (x is None or isinstance(x, GetAttrVariable)) and hasattr(value, "__len__"):
except exc.ObservedException:
# if __bool__ is missing, trying __len__ to infer a truth value.
x = value.var_getattr(self, "__len__")
else:
if isinstance(x, GetAttrVariable):
# if __bool__ is missing, trying __len__ to infer a truth value.
x = value.var_getattr(self, "__len__")

# __bool__ or __len__ is function
if isinstance(x, UserMethodVariable):
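For context on the new UnspecializedNNModuleVariable branch above: it evaluates the jump's truth function directly on the underlying module object, whose truthiness is well defined in eager mode (for example, an empty nn.ModuleList is falsy because it defines __len__). A small eager-only illustration (assumes torch is installed; this is not Dynamo code):

```python
import torch

layers = torch.nn.ModuleList()
assert not layers                  # empty container -> len() == 0 -> falsy
layers.append(torch.nn.Linear(2, 2))
assert layers                      # non-empty -> truthy

def forward(x, layers):
    if layers:                     # the kind of module-dependent branch resolved at trace time
        for layer in layers:
            x = layer(x)
    return x

print(forward(torch.randn(1, 2), layers).shape)  # torch.Size([1, 2])
```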
4 changes: 2 additions & 2 deletions torch/_dynamo/utils.py
@@ -2020,12 +2020,12 @@ def object_has_getattribute(value: Any):
return False


def get_custom_getattr(value: Any):
def get_custom_getattr(value: Any, ignore_nn_module_getattr: bool = False):
try:
getattr_fn = inspect.getattr_static(type(value), "__getattr__")
except AttributeError:
getattr_fn = None
if getattr_fn is torch.nn.Module.__getattr__:
if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__:
# ignore this case of getattr
getattr_fn = None
return getattr_fn
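The change above flips the default: nn.Module.__getattr__ is now reported unless the caller opts out with ignore_nn_module_getattr=True. A small check of that behavior, with the function copied in shape from the diff so the example does not depend on torch._dynamo internals (assumes torch is installed):

```python
import inspect
import torch

def get_custom_getattr(value, ignore_nn_module_getattr: bool = False):
    try:
        getattr_fn = inspect.getattr_static(type(value), "__getattr__")
    except AttributeError:
        getattr_fn = None
    if ignore_nn_module_getattr and getattr_fn is torch.nn.Module.__getattr__:
        # ignore this case of getattr
        getattr_fn = None
    return getattr_fn

mod = torch.nn.Linear(2, 2)
# New default: the inherited nn.Module.__getattr__ is no longer filtered out.
assert get_custom_getattr(mod) is torch.nn.Module.__getattr__
# Callers that want the previous behavior pass the new flag explicitly.
assert get_custom_getattr(mod, ignore_nn_module_getattr=True) is None
```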
6 changes: 5 additions & 1 deletion torch/_dynamo/variables/dicts.py
@@ -174,7 +174,11 @@ def python_type(self):
def __contains__(self, vt):
assert isinstance(vt, VariableTracker)
Hashable = ConstDictVariable._HashableTracker
return is_hashable(vt) and Hashable(vt) in self.items
return (
is_hashable(vt)
and Hashable(vt) in self.items
and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable)
)

def reconstruct(self, codegen):
# instructions to load collections.OrderedDict if necessary
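The __contains__ change above makes a key whose value has been replaced by a deletion sentinel report as absent. A standalone sketch of that rule (Deleted and TrackedDict are illustrative stand-ins, not Dynamo's DeletedVariable/ConstDictVariable):

```python
class Deleted:
    """Illustrative stand-in for variables.DeletedVariable."""

class TrackedDict:
    def __init__(self):
        self.items = {}

    def __contains__(self, key):
        # Present only if stored and not marked as deleted.
        return key in self.items and not isinstance(self.items[key], Deleted)

d = TrackedDict()
d.items["x"] = 1
assert "x" in d
d.items["x"] = Deleted()    # a traced `del d["x"]` recorded as a sentinel value
assert "x" not in d         # membership now reflects the pending deletion
```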
26 changes: 19 additions & 7 deletions torch/_dynamo/variables/misc.py
@@ -14,8 +14,10 @@
import torch.utils._pytree as pytree
from .. import config, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..mutation_guard import unpatched_nn_module_init
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
from ..utils import (
check_unspec_or_constant_args,
@@ -121,7 +123,6 @@ def call_method(
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
inner_fn, source = self._resolved_getattr_and_source(self, name)

if inner_fn is object.__init__:
return LambdaVariable(identity)
elif inner_fn is torch.nn.Module.__init__:
@@ -133,12 +134,10 @@
and isinstance(objvar.mutable_local, AttributeMutationNew)
and not (args or kwargs)
):
tx.output.side_effects.store_attr(
objvar,
"__call_nn_module_init",
variables.ConstantVariable.create(True),
)
return variables.ConstantVariable.create(None)
with do_not_convert_to_tracable_parameter():
return variables.UserFunctionVariable(
unpatched_nn_module_init, source=source
).call_function(tx, [self.objvar] + args, kwargs)
else:
unimplemented("super() nn.Module.__init__")
elif isinstance(inner_fn, types.FunctionType):
@@ -181,6 +180,19 @@ def call_method(
self.objvar, UserDefinedObjectVariable
):
return self.objvar.method_setattr_standard(tx, *args, **kwargs)
elif inner_fn is object.__delattr__:
attr = args[0]
try:
attr = attr.as_python_constant()
except NotImplementedError:
unimplemented(f"non-const delattr attr: {attr}")
if not tx.output.side_effects.is_attribute_mutation(self.objvar):
unimplemented(f"delattr({self.objvar}, {attr}, ...)")

tx.output.side_effects.store_attr(
self.objvar, attr, variables.DeletedVariable()
)
return variables.ConstantVariable(None)

unimplemented(f"non-function or method super: {inner_fn}")

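The super().__init__ branch above now traces the original (unpatched) nn.Module.__init__ under the no-conversion flag instead of recording a __call_nn_module_init marker. An eager-mode paraphrase of what that amounts to (a sketch of the intent only; the real change goes through UserFunctionVariable, and the simple global below stands in for the thread-local flag in create_parameter_op.py):

```python
import torch
from contextlib import contextmanager

_convert_enabled = True  # stand-in for the thread-local flag

@contextmanager
def do_not_convert_to_tracable_parameter():
    global _convert_enabled
    old, _convert_enabled = _convert_enabled, False
    try:
        yield
    finally:
        _convert_enabled = old

unpatched_nn_module_init = torch.nn.Module.__init__  # captured before any patching

class MyModule(torch.nn.Module):
    def __init__(self):
        # Roughly what Dynamo traces for `super().__init__()` on a freshly created module:
        with do_not_convert_to_tracable_parameter():
            unpatched_nn_module_init(self)
        self.linear = torch.nn.Linear(2, 2)

m = MyModule()
assert isinstance(m.linear, torch.nn.Linear)
```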
11 changes: 11 additions & 0 deletions torch/_dynamo/variables/nn_module.py
@@ -959,6 +959,17 @@ def call_method(
kwargs,
)

if method is torch.nn.Module.__setattr__ and isinstance(
args[1], variables.DeletedVariable
):
# Trace through __delattr__ to track mutations on the module
# members like ``_modules``.
return tx.inline_user_function_return(
variables.UserFunctionVariable(torch.nn.Module.__delattr__),
[self, args[0]],
kwargs,
)

return super().call_method(tx, name, args, kwargs)


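A short eager-mode demonstration (no Dynamo involved) of why the branch above routes deletions through nn.Module.__delattr__: that is where submodule bookkeeping such as _modules is actually updated, which is exactly the mutation Dynamo needs to track:

```python
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(2, 2)

m = Net()
assert "net" in m._modules
del m.net                       # dispatches to torch.nn.Module.__delattr__
assert "net" not in m._modules  # the _modules mutation that now gets traced and replayed
```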
9 changes: 8 additions & 1 deletion torch/_dynamo/variables/torch.py
@@ -18,7 +18,11 @@
from ..._guards import TracingContext
from .. import config, polyfill, variables
from ..codegen import PyCodegen
from ..create_parameter_op import new_parameter_placeholder, tracable_create_parameter
from ..create_parameter_op import (
can_convert_to_tracable_parameter,
new_parameter_placeholder,
tracable_create_parameter,
)
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
@@ -910,6 +914,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True):
if data.source:
return cls._nn_param_via_prefix_insert(tx, data, requires_grad)

if not can_convert_to_tracable_parameter():
unimplemented("Workaround for issues with nn_parameter construction")

try:
shape = tuple(data.var_getattr(tx, "shape").as_python_constant())
dtype = data.var_getattr(tx, "dtype").as_python_constant()
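The two lines added above graph-break when the thread-local flag from create_parameter_op.py is off. A paraphrased sketch of that control flow (not the real VariableTracker code; Unsupported and the return value are illustrative stand-ins):

```python
class Unsupported(Exception):
    pass

def unimplemented(msg: str):
    # Stand-in for the graph-break signal raised by Dynamo's `unimplemented`.
    raise Unsupported(msg)

def call_nn_parameter_sketch(data, requires_grad=True, *, can_convert=True):
    if not can_convert:
        unimplemented("Workaround for issues with nn_parameter construction")
    # ... the placeholder / tracable-parameter path would follow here ...
    return ("traced_parameter", data, requires_grad)

try:
    call_nn_parameter_sketch([1.0, 2.0], can_convert=False)
except Unsupported as e:
    print("graph break:", e)
```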
