From d57dae5286be22848e1fd1c78a6c33c76b9408c4 Mon Sep 17 00:00:00 2001
From: Will Feng
Date: Tue, 18 Jun 2024 21:30:47 -0700
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torch/_dynamo/output_graph.py        | 18 ------
 torch/_dynamo/utils.py               | 11 +---
 torch/_dynamo/variables/builder.py   | 14 +---
 torch/_dynamo/variables/nn_module.py | 97 +---------------------------
 4 files changed, 7 insertions(+), 133 deletions(-)

diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 7b1974a50ed3c..6200a529c2c16 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -760,24 +760,6 @@ def register_attr_or_module(
         assert "source" in options
         source = options["source"]
 
-        # Dynamic Path 2 - module is dynamic, and is fsdp
-        if is_dynamic_nn_module(target, self.root_tx.export) and getattr(
-            target, "_is_fsdp_managed_module", False
-        ):
-            name = "_".join(map(str, names))
-            base = name
-            for i in itertools.count():
-                if name not in self.nn_modules:
-                    self.nn_modules[name] = target
-                    break
-                name = f"{base}_{i}"
-            vt = variables.nn_module.FSDPManagedNNModuleVariable(
-                target,
-                name,
-                **options,
-            )
-            return self.side_effects.track_object_existing(target, vt)
-
         assert not isinstance(source, ParamBufferSource)
 
         if isinstance(target, torch.Tensor):
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 3f39aceeb1b6c..0afcde5e59363 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2691,14 +2691,9 @@ def nn_module_proxy(mod):
     if isinstance(mod, torch.fx.GraphModule):
         # Dynamo-generated GM's shouldn't contain user-created GM's
         return mod
-    from torch.distributed._composable.fsdp.fully_shard import FSDPModule
-    if isinstance(mod, FSDPModule):
-        # TODO(yf225): this is a hacky workaround and is not the right thing to do. Need to think about how to work around FSDP.__new__()
-        return mod
-    else:
-        proxy = mod.__class__.__new__(mod.__class__)
-        proxy.__dict__ = mod.__dict__
-        return proxy
+    proxy = mod.__class__.__new__(mod.__class__)
+    proxy.__dict__ = mod.__dict__
+    return proxy
 
 
 class GmWrapper(torch.nn.Module):
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index daad474678b70..58697bb8765c9 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -332,7 +332,6 @@ def _can_lift_attrs_to_inputs(self, vt):
             TensorWithTFOverrideVariable,
             UserDefinedObjectVariable,
             NumpyNdarrayVariable,
-            FSDPManagedNNModuleVariable,
         ]:
             return True
         return False
@@ -1226,19 +1225,8 @@ def wrap_module(self, value: torch.nn.Module):
             #
             # ID_MATCH is required to disambiguate cases as simple as a unit test that constructs 2 models and wraps
             # them differently with different FSDP configs. (test_dynamo_distributed.py -k test_fsdp_aot_eager)
-            base = self.name
-            name = self.name
-            for i in itertools.count():
-                if name not in self.tx.output.nn_modules:
-                    self.tx.output.nn_modules[name] = value
-                    break
-                name = f"{base}_{i}"
             self.install_guards(GuardBuilder.TYPE_MATCH, GuardBuilder.ID_MATCH)
-            return FSDPManagedNNModuleVariable(
-                value,
-                name,
-                source=self.get_source(),
-            )
+            return FSDPManagedNNModuleVariable(value, source=self.get_source())
         else:
             return self.tx.output.register_attr_or_module(
                 value,
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
index cd87cb3a52822..1a35e2e914567 100644
--- a/torch/_dynamo/variables/nn_module.py
+++ b/torch/_dynamo/variables/nn_module.py
@@ -222,9 +222,6 @@ def _custom_getattr_fallback(self, base, tx, name, options):
         if not isinstance(getattr_fn, types.FunctionType):
             unimplemented("torch.nn.Module with a non-function custom __getattr__")
 
-        if getattr(base, "_is_fsdp_managed_module", False):
-            from .builder import VariableBuilder
-            return VariableBuilder(tx, options["source"])(getattr_fn(base, name))
         return variables.UserMethodVariable(getattr_fn, self, **options).call_function(
             tx, [variables.ConstantVariable.create(name)], {}
         )
@@ -320,9 +317,6 @@ def var_getattr(self, tx, name):
         elif is_safe_constant(subobj) or istensor(subobj):
             # Support possibly common cases of class members
             return VariableBuilder(tx, NNModuleSource(source))(subobj)
-        elif istype(subobj, types.GetSetDescriptorType):
-            assert source
-            return VariableBuilder(tx, source)(subobj.__get__(base))
         else:
             unimplemented(
                 f"class property {name} - {typestr(base)} {typestr(subobj)}"
@@ -589,20 +583,6 @@ def gen_source(source, name):
         elif name == "buffers":
             tx.output.guard_on_key_order.add(AttrSource(self.source, "_buffers").name())
             return wrap_values(module.named_buffers(**get_kwargs("recurse")))
-        elif name == "_named_members":
-            # The get_members_fn fails a const check, but this is a private internal lambda
-            # passed in nn_module, and so can be safely non-const, as it will not execute arbitrary user code
-            return wrap_values(
-                module._named_members(
-                    **get_kwargs(
-                        "get_members_fn",
-                        "prefix",
-                        "recurse",
-                        "remove_duplicates",
-                        assert_const=False,
-                    )
-                )
-            )
         elif name == "keys":
             assert not (args or kwargs)
             result = []
@@ -942,6 +922,8 @@ def call_method(
         ):
             # Handle submodules
             self.is_state_mutated = True
+        if self.is_state_mutated and not tx.output.side_effects.is_attribute_mutation(self):
+            tx.output.side_effects.track_object_existing(self.value, self)
 
         return super().call_method(tx, name, args, kwargs)
 
@@ -958,86 +940,13 @@ class FSDPManagedNNModuleVariable(UnspecializedNNModuleVariable):
     compilation.
     """
 
-    def __init__(self, value, module_key, **kwargs):
+    def __init__(self, value, **kwargs):
         source = kwargs.get("source", None)
         assert (
             source is not None
         ), "FSDPManagedNNModule depends on having an accurate source to control guarding."
         super().__init__(value=value, **kwargs)
-        self.source = FSDPManagedNNModuleVariable._wrap_source(source)
-        self.module_key = module_key
-        self.module = value
-
-    @staticmethod
-    def _wrap_source(source):
-        if not isinstance(source, (FSDPNNModuleSource, NotNNModuleSource)):
-            if torch._dynamo.config.skip_fsdp_guards:
-                return FSDPNNModuleSource(source)
-            else:
-                # this makes us behave like a usual UnspecializedNNModuleVariable for guarding purposes
-                return NotNNModuleSource(source)
-        else:
-            return source
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name == "source":
-            value = FSDPManagedNNModuleVariable._wrap_source(value)
-
-        return super().__setattr__(name, value)
-
-    def call_method(
-        self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
-    ) -> VariableTracker:
-        key = self.module_key
-
-        named_embed = functools.partial(
-            _named_embed,
-            tx=tx,
-            key=key,
-            source_cls=FSDPNNModuleSource,
-            source=self.source,
-        )
-        wrap_values = functools.partial(
-            _wrap_values,
-            tx=tx,
-            key=key,
-            source_cls=FSDPNNModuleSource,
-            source=self.source,
-        )
-        get_kwargs = functools.partial(
-            _get_kwargs, mod=self.value, name=name, args=args, kwargs=kwargs
-        )
-
-        if name == "buffers":
-            return wrap_values(self.value.named_buffers(**get_kwargs("recurse")))
-        elif name == "named_buffers":
-            result = []
-            for name, buffer in self.value.named_buffers(
-                **get_kwargs("prefix", "recurse", "remove_duplicate")
-            ):
-                result.append(named_embed(name, buffer))
-            return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
-        elif name == "children":
-            assert not (args or kwargs)
-            return wrap_values(self.value.named_children())
-
-        return super().call_method(tx, name, args, kwargs)
-
-    def call_function(
-        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
-    ) -> "VariableTracker":
-        return super().call_function(tx, args, kwargs)
-
-    def var_getattr(self, tx, name):
-        if name in ["named_buffers", "children", "buffers"]:
-            # Route this to produce a ListIteratorVariable instead of getting the generator
-            return variables.LambdaVariable(
-                lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
-            )
-        return super().var_getattr(tx, name)
-
-    def as_python_constant(self):
-        return self.value
 
 
 def _gen_source(source, name):
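Note (reviewer illustration, not part of the patch): with the FSDPModule
special case deleted from nn_module_proxy() in utils.py, every module now
takes the generic proxy path: allocate an instance with cls.__new__(cls),
skipping __init__, then alias the original module's __dict__ so the proxy
shares its parameters, buffers, and submodules. The removed branch existed
because FSDPModule customizes __new__ (see the deleted TODO above). A
minimal runnable sketch of the kept pattern follows; the Toy module is an
invented example, not code from this patch:

    import torch

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 4)

    mod = Toy()
    # Allocate an instance without running __init__, as nn_module_proxy() does ...
    proxy = mod.__class__.__new__(mod.__class__)
    # ... then share all state: attribute reads and writes hit the same dicts.
    proxy.__dict__ = mod.__dict__
    assert proxy.fc is mod.fc

    x = torch.randn(2, 4)
    assert torch.equal(proxy(x), mod(x))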