Commit d57dae5

Update

[ghstack-poisoned]

yf225 committed Jun 19, 2024
1 parent 321a014 commit d57dae5
Showing 4 changed files with 7 additions and 133 deletions.
18 changes: 0 additions & 18 deletions torch/_dynamo/output_graph.py
@@ -760,24 +760,6 @@ def register_attr_or_module(
         assert "source" in options
         source = options["source"]

-        # Dynamic Path 2 - module is dynamic, and is fsdp
-        if is_dynamic_nn_module(target, self.root_tx.export) and getattr(
-            target, "_is_fsdp_managed_module", False
-        ):
-            name = "_".join(map(str, names))
-            base = name
-            for i in itertools.count():
-                if name not in self.nn_modules:
-                    self.nn_modules[name] = target
-                    break
-                name = f"{base}_{i}"
-            vt = variables.nn_module.FSDPManagedNNModuleVariable(
-                target,
-                name,
-                **options,
-            )
-            return self.side_effects.track_object_existing(target, vt)
-
         assert not isinstance(source, ParamBufferSource)

         if isinstance(target, torch.Tensor):
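Note: the deleted "Dynamic Path 2" branch registered each FSDP-managed module in self.nn_modules under a uniquified key before wrapping it. A minimal standalone sketch of that key-uniquing pattern (the unique_key helper and registry dict are illustrative, not Dynamo's):

    import itertools

    def unique_key(registry: dict, base: str) -> str:
        # Probe base, then base_0, base_1, ... until a free key is found,
        # mirroring the itertools.count() loop removed above.
        name = base
        for i in itertools.count():
            if name not in registry:
                return name
            name = f"{base}_{i}"

    registry = {"layer": object(), "layer_0": object()}
    assert unique_key(registry, "layer") == "layer_1"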
11 changes: 3 additions & 8 deletions torch/_dynamo/utils.py
@@ -2691,14 +2691,9 @@ def nn_module_proxy(mod):
     if isinstance(mod, torch.fx.GraphModule):
         # Dynamo-generated GM's shouldn't contain user-created GM's
         return mod
-    from torch.distributed._composable.fsdp.fully_shard import FSDPModule
-    if isinstance(mod, FSDPModule):
-        # TODO(yf225): this is a hacky workaround and is not the right thing to do. Need to think about how to work around FSDP.__new__()
-        return mod
-    else:
-        proxy = mod.__class__.__new__(mod.__class__)
-        proxy.__dict__ = mod.__dict__
-        return proxy
+    proxy = mod.__class__.__new__(mod.__class__)
+    proxy.__dict__ = mod.__dict__
+    return proxy


 class GmWrapper(torch.nn.Module):
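For context, nn_module_proxy now unconditionally builds a proxy by bypassing __init__ and sharing the original instance's state; the removed special case existed because, per the deleted TODO, FSDP overrides __new__. A minimal sketch of the proxy idiom itself (the Net class is illustrative):

    import torch

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

    mod = Net()
    proxy = mod.__class__.__new__(mod.__class__)  # allocate without calling __init__
    proxy.__dict__ = mod.__dict__                 # proxy and mod now share all state

    assert proxy.linear is mod.linear

Because the two objects share one __dict__, attribute reads and writes on the proxy are visible on the original module, which is the point of the idiom.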
14 changes: 1 addition & 13 deletions torch/_dynamo/variables/builder.py
@@ -332,7 +332,6 @@ def _can_lift_attrs_to_inputs(self, vt):
             TensorWithTFOverrideVariable,
             UserDefinedObjectVariable,
             NumpyNdarrayVariable,
-            FSDPManagedNNModuleVariable,
         ]:
             return True
         return False
@@ -1226,19 +1225,8 @@ def wrap_module(self, value: torch.nn.Module):
             #
             # ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
             # them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager)
-            base = self.name
-            name = self.name
-            for i in itertools.count():
-                if name not in self.tx.output.nn_modules:
-                    self.tx.output.nn_modules[name] = value
-                    break
-                name = f"{base}_{i}"
             self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH)
-            return FSDPManagedNNModuleVariable(
-                value,
-                name,
-                source=self.get_source(),
-            )
+            return FSDPManagedNNModuleVariable(value, source=self.get_source())
         else:
             return self.tx.output.register_attr_or_module(
                 value,
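The retained comment above explains why an ID_MATCH guard is installed alongside TYPE_MATCH. Outside Dynamo, the distinction reduces to the following (a minimal sketch, not Dynamo's guard objects):

    import torch

    # Two same-class modules, e.g. wrapped with different FSDP configs:
    m1 = torch.nn.Linear(2, 2)
    m2 = torch.nn.Linear(2, 2)

    assert type(m1) is type(m2)  # a TYPE_MATCH guard alone cannot tell them apart
    assert m1 is not m2          # an ID_MATCH guard keys on object identity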
97 changes: 3 additions & 94 deletions torch/_dynamo/variables/nn_module.py
@@ -222,9 +222,6 @@ def _custom_getattr_fallback(self, base, tx, name, options):
         if not isinstance(getattr_fn, types.FunctionType):
             unimplemented("torch.nn.Module with a non-function custom __getattr__")

-        if getattr(base, "_is_fsdp_managed_module", False):
-            from .builder import VariableBuilder
-            return VariableBuilder(tx, options["source"])(getattr_fn(base, name))
         return variables.UserMethodVariable(getattr_fn, self, **options).call_function(
             tx, [variables.ConstantVariable.create(name)], {}
         )
@@ -320,9 +317,6 @@ def var_getattr(self, tx, name):
             elif is_safe_constant(subobj) or istensor(subobj):
                 # Support possibly common cases of class members
                 return VariableBuilder(tx, NNModuleSource(source))(subobj)
-            elif istype(subobj, types.GetSetDescriptorType):
-                assert source
-                return VariableBuilder(tx, source)(subobj.__get__(base))
             else:
                 unimplemented(
                     f"class property {name} - {typestr(base)} {typestr(subobj)}"
@@ -589,20 +583,6 @@ def gen_source(source, name):
         elif name == "buffers":
             tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
             return wrap_values(module.named_buffers(**get_kwargs("recurse")))
-        elif name == "_named_members":
-            # The get_members_fn fails a const check, but this is a private internal lambda
-            # passed in nn_module, and so can be safely non-const, as it will not execute arbitrary user code
-            return wrap_values(
-                module._named_members(
-                    **get_kwargs(
-                        "get_members_fn",
-                        "prefix",
-                        "recurse",
-                        "remove_duplicates",
-                        assert_const=False,
-                    )
-                )
-            )
         elif name == "keys":
             assert not (args or kwargs)
             result = []
@@ -942,6 +922,8 @@ def call_method(
         ):
             # Handle submodules
             self.is_state_mutated = True
+        if self.is_state_mutated and not tx.output.side_effects.is_attribute_mutation(self):
+            tx.output.side_effects.track_object_existing(self.value, self)

         return super().call_method(tx, name, args, kwargs)

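The two added lines follow a track-once pattern: register the module with the side-effect machinery only if it is not already tracked as attribute-mutated. A minimal standalone sketch of that pattern (this SideEffects class is an illustrative stand-in, not Dynamo's):

    class SideEffects:
        # Illustrative stand-in for tx.output.side_effects.
        def __init__(self):
            self._tracked = {}

        def is_attribute_mutation(self, vt):
            return id(vt) in self._tracked

        def track_object_existing(self, obj, vt):
            # Record a pre-existing (not Dynamo-created) object so its
            # mutations can be replayed after the compiled region runs.
            self._tracked[id(vt)] = obj
            return vt

    side_effects = SideEffects()
    vt, mod = object(), object()
    if not side_effects.is_attribute_mutation(vt):
        side_effects.track_object_existing(mod, vt)
    assert side_effects.is_attribute_mutation(vt)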
@@ -958,86 +940,13 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
     compilation.
     """

-    def __init__(self, value, module_key, **kwargs):
+    def __init__(self, value, **kwargs):
         source = kwargs.get("source", None)
         assert (
             source is not None
         ), "FSDPManagedNNModule depends on having an accurate source to control guarding."

         super().__init__(value=value, **kwargs)
-        self.source = FSDPManagedNNModuleVariable._wrap_source(source)
-        self.module_key = module_key
-        self.module = value
-
-    @staticmethod
-    def _wrap_source(source):
-        if not isinstance(source, (FSDPNNModuleSource, NotNNModuleSource)):
-            if torch._dynamo.config.skip_fsdp_guards:
-                return FSDPNNModuleSource(source)
-            else:
-                # this makes us behave like a usual UnspecializedNNModuleVariable for guarding purposes
-                return NotNNModuleSource(source)
-        else:
-            return source
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name == "source":
-            value = FSDPManagedNNModuleVariable._wrap_source(value)
-
-        return super().__setattr__(name, value)
-
-    def call_method(
-        self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
-    ) -> VariableTracker:
-        key = self.module_key
-
-        named_embed = functools.partial(
-            _named_embed,
-            tx=tx,
-            key=key,
-            source_cls=FSDPNNModuleSource,
-            source=self.source,
-        )
-        wrap_values = functools.partial(
-            _wrap_values,
-            tx=tx,
-            key=key,
-            source_cls=FSDPNNModuleSource,
-            source=self.source,
-        )
-        get_kwargs = functools.partial(
-            _get_kwargs, mod=self.value, name=name, args=args, kwargs=kwargs
-        )
-
-        if name == "buffers":
-            return wrap_values(self.value.named_buffers(**get_kwargs("recurse")))
-        elif name == "named_buffers":
-            result = []
-            for name, buffer in self.value.named_buffers(
-                **get_kwargs("prefix", "recurse", "remove_duplicate")
-            ):
-                result.append(named_embed(name, buffer))
-            return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
-        elif name == "children":
-            assert not (args or kwargs)
-            return wrap_values(self.value.named_children())
-        return super().call_method(tx, name, args, kwargs)
-
-    def call_function(
-        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
-    ) -> "VariableTracker":
-        return super().call_function(tx, args, kwargs)
-
-    def var_getattr(self, tx, name):
-        if name in ["named_buffers", "children", "buffers"]:
-            # Route this to produce a ListIteratorVariable instead of getting the generator
-            return variables.LambdaVariable(
-                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
-            )
-        return super().var_getattr(tx, name)
-
-    def as_python_constant(self):
-        return self.value


 def _gen_source(source, name):
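One of the deleted var_getattr branches above handled C-level getset descriptors by invoking their __get__ directly. For background, a minimal illustration of that descriptor type and call, independent of Dynamo (FunctionType.__code__ is a standard getset descriptor in CPython):

    import types

    def f():
        pass

    # __code__ on the function type is implemented in C as a getset descriptor.
    desc = type(f).__dict__["__code__"]
    assert isinstance(desc, types.GetSetDescriptorType)
    assert desc.__get__(f) is f.__code__  # same shape as the removed subobj.__get__(base)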
