45 changes: 45 additions & 0 deletions torch/distributed/fsdp/_dynamo_utils.py
@@ -0,0 +1,45 @@
from typing import Set

import torch.nn as nn


def _annotate_modules_for_dynamo(
module: nn.Module,
ignored_modules: Set[nn.Module],
use_orig_params: bool,
):
"""
Annotates the submodules in ``module`` 's tree, except those in
``ignored_modules``, indicating that the submodules are FSDP-managed and
saving the ``use_orig_params`` setting passed to the FSDP constructor.
"""
for submodule in module.modules():
if submodule not in ignored_modules:
Collaborator Author: All of the comments below are copied directly.
"""[note: Dynamo treats FSDP wrapped modules as UnspecializedNNModule]

Dynamo doesn't get to see this instance (FullyShardedDataParallel) during tracing, since
it skips tracing all the torch.distributed.fsdp code.
- Why? Running the FSDP code eagerly avoids lots of issues trying to trace complex hooks, and also
gets us graph-breaks on FSDP module boundaries which we want anyway for comm ops.
- However, we _also_ want dynamo to treat the wrapped module inside FSDP 'unspecially' (*),
and we need a way to indicate to dynamo which modules are wrapped by FSDP.

(*) UnspecializedNNModules in dynamo are traced-through without any assumptions, and with thorough
guards. NNModules otherwise are 'specialized', meaning there is less overhead due to assuming
their code is well-behaved.

One particular issue with specialized NNModules for FSDP is that the
views created for orig_params are captured into the compiled graph on the first iteration, and while
they are always going to point to the correct flatparameter and give correct results, their order
of creation influences the order of backward execution, preventing overlap of comm and computation
during backward. We need to _use_ the new parameter views created on each forward iteration, in
order for backward to interleave hooks with compute per layer. UnspecializedNNModule lets us achieve
this by capturing the module code more 'functionally' and passing parameters in as inputs each time.
"""
submodule._is_fsdp_managed_module = True # type: ignore[assignment]

# Dynamo only supports FSDP with use_orig_params=True.
# This is hacky, but I could not think of another way to add an assertion to dynamo
# for this, since Dynamo skips all the FSDP code frames and thus can't inspect the
# FSDP module directly
submodule._fsdp_use_orig_params = use_orig_params # type: ignore[assignment]
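For reference, a minimal sketch (not part of this PR) of how the new helper could be exercised on a toy module to confirm the attributes it sets; the `Net` class and the empty ignored set are illustrative assumptions, and the import path assumes this PR is applied.

# Illustrative sketch only: exercises _annotate_modules_for_dynamo outside of FSDP.
import torch.nn as nn

from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo


class Net(nn.Module):  # hypothetical toy module for illustration
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)


net = Net()
_annotate_modules_for_dynamo(net, ignored_modules=set(), use_orig_params=True)

# Every non-ignored submodule (including the root) now carries both annotations.
assert net._is_fsdp_managed_module and net._fsdp_use_orig_params
assert net.linear._is_fsdp_managed_module and net.linear._fsdp_use_orig_params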
34 changes: 3 additions & 31 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -35,6 +35,7 @@
HandleTrainingState,
TrainingState,
)
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
_check_orig_params_flattened,
_get_default_comm_hook,
@@ -332,37 +333,8 @@ def __init__(
super().__init__()
_init_ignored_module_states(self, module, ignored_modules)

# Add module annotations for Dynamo support
for submodule in module.modules():
if submodule not in self._ignored_modules:
"""[note: Dynamo treats FSDP wrapped modules as UnspecializedNNModule]

Dynamo doesn't get to see this instance (FullyShardedDataParallel) during tracing, since
it skips tracing all the torch.distributed.fsdp code.
- Why? Running the FSDP code eagerly avoids lots of issues trying to trace complex hooks, and also
gets us graph-breaks on FSDP module boundaries which we want anyway for comm ops.
- However, we _also_ want dynamo to treat the wrapped module inside FSDP 'unspecially' (*),
and we need a way to indicate to dynamo which modules are wrapped by FSDP.

(*) UnspecializedNNModules in dynamo are traced-through without any assumptions, and with thorough
guards. NNModules otherwise are 'specialized', meaning there is less overhead due to assuming
their code is well-behaved.

One particular issue with specialized NNModules for FSDP is that the
views created for orig_params are captured into the compiled graph on the first iteration, and while
they are always going to point to the correct flatparameter and give correct results, their order
of creation influences the order of backward execution, preventing overlap of comm and computation
during backward. We need to _use_ the new parameter views created on each forward iteration, in
order for backward to interleave hooks with compute per layer. UnspecializedNNModule lets us achieve
this by capturing the module code more 'functionally' and passing parameters in as inputs each time.
"""
submodule._is_fsdp_managed_module = True

# Dynamo only supports FSDP with use_orig_params=True.
# This is hacky, but I could not think of another way to add an assertion to dynamo
# for this, since Dynamo skips all the FSDP code frames and thus can't inspect the
# FSDP module directly
submodule._fsdp_use_orig_params = use_orig_params
# Add module annotations for Dynamo support (see function for details)
_annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)

if auto_wrap_policy is not None:
auto_wrap_kwargs = {
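On the consumer side, Dynamo never sees the FullyShardedDataParallel instance itself, only the annotated submodules, so it has to rely on these attributes. A hedged sketch of the kind of check the tracer could perform; the helper name `_check_dynamo_fsdp_support` is hypothetical and is not Dynamo's actual code.

import torch.nn as nn


def _check_dynamo_fsdp_support(submodule: nn.Module) -> bool:
    """Hypothetical helper: decide whether a module should be traced as an
    UnspecializedNNModule based on the annotations set by FSDP."""
    if getattr(submodule, "_is_fsdp_managed_module", False):
        # Mirrors the constraint noted in the diff above:
        # Dynamo only supports FSDP with use_orig_params=True.
        assert getattr(submodule, "_fsdp_use_orig_params", False), (
            "Dynamo only supports FSDP with use_orig_params=True"
        )
        return True
    return False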