Skip to content

Commit

Permalink
[Traceable FSDP2] [Dynamo] Fix OptimizedModule._initialize to allow t…
Browse files Browse the repository at this point in the history
…racing into FSDP2 module hooks for module from user-defined module class (#129046)

This is a workaround to allow inplace fully-sharded module to still go into this branch:
https://github.com/pytorch/pytorch/blob/3a185778edb18abfbad155a87ff3b2d716e4c220/torch/_dynamo/eval_frame.py#L163
instead of the second branch:
https://github.com/pytorch/pytorch/blob/3a185778edb18abfbad155a87ff3b2d716e4c220/torch/_dynamo/eval_frame.py#L166

If we don't do this, `torch.compile(fully_shard(module_from_user_defined_module_class))` will ignore all module hooks which will break FSDP tracing.

Pull Request resolved: #129046
Approved by: https://github.com/anijain2305
  • Loading branch information
yf225 authored and pytorchmergebot committed Jun 20, 2024
1 parent 859fa18 commit d8db074
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ def _initialize(self):
if isinstance(self.dynamo_ctx, DisableContext):
# No need to check trace rules
self.forward = self.dynamo_ctx(self._orig_mod.__call__)
elif isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check(
self._orig_mod.forward
elif isinstance(self._orig_mod.forward, types.MethodType) and (
trace_rules.check(self._orig_mod.forward)
or getattr(self._orig_mod, "_is_fsdp_managed_module", False)
):
# This may be a torch.nn.* instance in trace_rules.py which
# won't trigger a frame evaluation workaround to add an extra
Expand Down

0 comments on commit d8db074

Please sign in to comment.