
[FSDP] Allow nested FSDP wrapper to use different mixed precision #90523

Closed
wants to merge 3 commits

Conversation

mrshenli
Contributor

@mrshenli mrshenli commented Dec 9, 2022

Stack from ghstack (oldest at bottom):

The main change is to move `args` and `kwargs` dtype conversion
from `_root_pre_forward` to `_pre_forward`, so that every
FSDP instance has a chance to apply its own precision.
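For illustration, a minimal sketch of the behavior this enables (module structure and sizes are made up; assumes a process group is already initialized and a CUDA device is available):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

# The inner (nested) FSDP keeps fp32 for its parameters and inputs...
model[0] = FSDP(model[0], mixed_precision=MixedPrecision(param_dtype=torch.float32))
# ...while the outer (root) FSDP uses bf16. After this change, each instance's
# `_pre_forward` casts `args`/`kwargs` to its own `param_dtype`, instead of the
# root applying a single dtype to everything in `_root_pre_forward`.
model = FSDP(model, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16))

out = model(torch.randn(4, 8, device="cuda"))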

Differential Revision: D41883717

@pytorch-bot

pytorch-bot bot commented Dec 9, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90523

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Failures

As of commit 3f687d6:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Dec 9, 2022
mrshenli added a commit that referenced this pull request Dec 9, 2022
ghstack-source-id: c4afb38c7f795552cfd0f3a1ba6bfb9f23e6846b
Pull Request resolved: #90523
…ecision"

The main change is to move `args` and `kwargs` dtype conversion
from `_root_pre_forward` to `_pre_forward`, so that every
FSDP instance has a chance to apply its own precision.

[ghstack-poisoned]
mrshenli added a commit that referenced this pull request Dec 9, 2022
The main change is to move `args` and `kwargs` dtype conversion
from `_root_pre_forward` to `_pre_forward`, so that every
FSDP instance has a chance to apply its own precision.

ghstack-source-id: c4afb38c7f795552cfd0f3a1ba6bfb9f23e6846b
Pull Request resolved: #90523
Contributor

@awgu awgu left a comment


This PR looks good to me. Feel free to give @rohan-varma a chance to review if he wants, and also let me know what you think about some of the comments I left.

@@ -257,12 +266,14 @@ def _pre_forward(
     handles: List[FlatParamHandle],
     unshard_fn: Callable,
     module: nn.Module,
-    input: Any,
+    args: Tuple[Any],
 ):
Contributor

nit: Does this need to be Tuple[Any, ...]? (also affects docstring)

My understanding is that Tuple[Any] means a singleton tuple with an element of any type and that Tuple[Any, ...] means a tuple with arbitrarily many elements, where each element may be any type.
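For example (illustrative only, as a type checker reads the two annotations):

from typing import Any, Tuple

x: Tuple[Any] = (1,)                # OK: a 1-element tuple whose element may be any type
y: Tuple[Any, ...] = (1, "a", 3.0)  # OK: arbitrarily many elements of any types
# z: Tuple[Any] = (1, 2)            # a type checker such as mypy flags this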

Contributor Author

Oh yep, thanks for the catch!

@@ -284,6 +296,12 @@ def _pre_forward(
     # the `grad_fn` is mutated.
     _register_post_backward_hooks(state, handles)

+    # Recursively convert args and kwargs to specified precision.
+    input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
Contributor

I see that this line only targets the wrapper FSDP code path.

I recommend adding a TODO at this line to indicate that we need to change it to support the non-wrapper code path; alternatively, we can file an issue if you do not want to re-run CI.


The open problem is how fully_shard should enable applying different fsdp_kwargs to different FlatParamHandles (for a select subset of the arguments). The main difference is that fully_shard does not support manual wrapping, which is the solution for FullyShardedDataParallel. Today, one application of fully_shard does not interact with any other application of fully_shard. If we want to migrate to the composable APIs, then we need to iron out the design here.

Given a way to specify different fsdp_kwargs for different FlatParamHandles, the change to this line is not complicated.

For the non-wrapper code path, we change the requirement to be that every handle in handles (passed to this _pre_forward) has the same low precision parameter dtype. In today's design (and probably for the near future), handles is a singleton (i.e. len(handles) == 1). Either way, the general constraint will be on handles.

def _get_handles_param_dtype(handles: Iterable[FlatParamHandle]) -> Optional[torch.dtype]:
    """
    Assumes all ``handles`` have the same low precision parameter dtype
    setting, where ``None`` represents full precision. This returns the single
    low precision parameter dtype setting across ``handles``.
    """
    dtypes: Set[Optional[torch.dtype]] = {handle._config.low_prec_param_dtype for handle in handles}
    p_assert(
        len(dtypes) == 1,
        f"Expects uniform low precision parameter dtype but got {dtypes}",
    )
    return next(iter(dtypes))

input_dtype: Optional[torch.dtype] = _get_handles_param_dtype(handles)

(The question is how to specify handle._config._low_prec_param_dtype to be different across FlatParamHandles of the same fully_shard.)

Contributor Author

Today, one application of fully_shard does not interact with any other application of fully_shard. If we want to migrate to the composable APIs, then we need to iron out the design here.

Can we do the following:

For the short term, we can require users to wrap children modules with fully_shard before wrapping parent ones. When applying fully_shard, recursively check all submodules' states and skip the ones that already have replicate or fully_shard applied to them. More specifically, fully_shard can write a flag into every submodule it covers, indicating that the submodule is already handled by a parent module; if users later apply another fully_shard on such a submodule, error out (see the sketch below).

For the longer term, when we have the high-level wrapper API, that API can sort things out, and we will no longer require users to wrap things in order.
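A rough sketch of the short-term check (the attribute name _composable_owner and the helper names are hypothetical, not existing fully_shard internals):

import torch.nn as nn

def _claim_submodules(root: nn.Module) -> None:
    """Mark `root`'s submodules as managed by this fully_shard application,
    skipping ones already claimed by an earlier fully_shard/replicate call."""
    for submodule in root.modules():
        if getattr(submodule, "_composable_owner", None) is None:
            submodule._composable_owner = root

def _check_not_already_claimed(module: nn.Module) -> None:
    """Error out if a parent's fully_shard already claimed this module."""
    owner = getattr(module, "_composable_owner", None)
    if owner is not None and owner is not module:
        raise RuntimeError(
            "This module is already managed by a parent fully_shard; "
            "apply fully_shard to children before parents."
        )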

@@ -284,6 +296,12 @@ def _pre_forward(
     # the `grad_fn` is mutated.
     _register_post_backward_hooks(state, handles)

+    # Recursively convert args and kwargs to specified precision.
+    input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
+    args, kwargs = _prepare_forward_inputs(
Contributor

The main assumption on the performance side from this PR is that the overhead from _prepare_forward_inputs() / _cast_fp_inputs_to_dtype() is not a concern.

Has anyone confirmed whether t.to(t.dtype) is fully a no-op with effectively zero overhead? If not, I wonder if we should change _cast_fp_inputs_to_dtype() to guard the .to() with an if statement:

def cast_fn(x: torch.Tensor) -> torch.Tensor:
    if not torch.is_floating_point(x) or x.dtype == dtype:
        return x
    y = x.to(dtype)
    # Explicitly copy over `requires_grad` since this runs inside
    # `torch.no_grad()`
    if x.is_leaf:
        y.requires_grad = x.requires_grad
    return y

In other words, we do not perform any of the logic if the tensor's dtype already matches. With this change, I would be more confident that there will be effectively no performance regression. The only additional overhead is a CPU-side recursion over the input data structure via _apply_to_tensors(), incurred for every FSDP instance.
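For reference, the recursion is conceptually along these lines (a simplified sketch, not the actual _apply_to_tensors implementation; it ignores namedtuples, dataclasses, and other container types the real helper handles):

from typing import Any, Callable
import torch

def apply_to_tensors(fn: Callable[[torch.Tensor], torch.Tensor], obj: Any) -> Any:
    # Walk nested containers and apply `fn` to every tensor leaf.
    if isinstance(obj, torch.Tensor):
        return fn(obj)
    if isinstance(obj, dict):
        return {k: apply_to_tensors(fn, v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(apply_to_tensors(fn, v) for v in obj)
    return obj  # non-tensor leaves pass through unchanged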


On a related note, I think we need to be careful about using a side stream to do this downcasting. Today, we do not use such a stream, but we have discussed doing so. There may be some strange memory-freeing behavior and stream complications. We are investigating this for the post-backward, so I think we should hold off on any stream changes in the near term until we get a better understanding.

Contributor Author

The implementation of Tensor.to() is:

TensorBase TensorBase::to(
    at::TensorOptions options,
    bool non_blocking,
    bool copy,
    c10::optional<at::MemoryFormat> memory_format) const {
  Tensor self(*this);
  return at::_ops::to_dtype_layout::call(
      self, optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(), options.device_opt(),
      options.pinned_memory_opt(), non_blocking, copy, memory_format);
}

Then I traced where it dispatches to:

struct TORCH_API to_dtype {
  using schema = at::Tensor (const at::Tensor &, at::ScalarType, bool, bool, c10::optional<at::MemoryFormat>);
  using ptr_schema = schema*;
  // See Note [static constexpr char* members for windows NVCC]
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::to")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "dtype")
  STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)")
  static at::Tensor call(const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format);
  static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format);
};
// aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
at::Tensor to_dtype::redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format) {

    static auto op = create_to_dtype_typed_handle();
    return op.redispatch(dispatchKeySet, self, dtype, non_blocking, copy, memory_format);
}
// aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
static C10_NOINLINE c10::TypedOperatorHandle<to_dtype::schema> create_to_dtype_typed_handle() {
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow(to_dtype::name, to_dtype::overload_name)
      .typed<to_dtype::schema>();
}

// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
  return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
}

// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  // do not use RecordFunction on redispatch
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[redispatch] op=[" << op.operator_name() << "], key=[" << toString(currentDispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}

At this point, I stopped tracing. Even if it does not launch any CUDA kernel, the above is already enough complexity plus a Python/C++ context switch. Let's avoid that by using your suggestion above.

Member

Agreed; it looks like even if tensor.to() is a no-op, we can avoid a Python/C++ switch and a GIL release/reacquire here.

@@ -953,3 +954,26 @@ def perThreadTearDown(self):
     @property
     def world_size(self) -> int:
         raise RuntimeError("world size not implemented")


Contributor

I see: So this CheckPrecisionModule does not check precision itself -- it is used to check the heterogeneous mixed precision case externally.
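For concreteness, a hypothetical sketch of such a helper (the class name and structure are illustrative, not the actual test code): it only records the dtype its forward sees, so the test can assert externally that each nested FSDP instance cast its inputs to its own param_dtype.

import torch
import torch.nn as nn

class _RecordInputDtypeModule(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
        self.seen_dtypes = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.seen_dtypes.append(x.dtype)  # inspected by the test afterwards
        return self.lin(x)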

Contributor Author

Let me update the name to match that.

…ecision"

The main change is to move `args` and `kwargs` dtype conversion
from `_root_pre_forward` to `_pre_forward`, so that every
FSDP instance has a chance to apply its own precision.

[ghstack-poisoned]
mrshenli added a commit that referenced this pull request Dec 9, 2022
The main change is to move `args` and `kwargs` dtype conversion
from `_root_pre_forward` to `_pre_forward`, so that every
FSDP instance has a chance to apply its own precision.

ghstack-source-id: f9f4b994a700503580404b186d72834bf9f898c9
Pull Request resolved: #90523
@mrshenli mrshenli added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 9, 2022
Member

@rohan-varma rohan-varma left a comment

LGTM!

@@ -284,6 +296,12 @@ def _pre_forward(
     # the `grad_fn` is mutated.
     _register_post_backward_hooks(state, handles)

+    # Recursively convert args and kwargs to specified precision.
+    input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
+    args, kwargs = _prepare_forward_inputs(
Member

q: do we skip the conversion if it's already the right type?

Contributor Author

Yep, @awgu also pointed this out. Added a condition to skip when the dtype already matches.

Contributor

@awgu awgu Dec 9, 2022

Yes, that is what we proposed in this PR, and I believe @mrshenli has added that change to _cast_fp_inputs_to_dtype().

Edit: Ignore this comment. My GitHub page was not refreshed, so I did not see @mrshenli's response.

@mrshenli
Contributor Author

mrshenli commented Dec 9, 2022

@pytorchbot merge -g

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@mrshenli
Contributor Author

mrshenli commented Dec 9, 2022

@pytorchbot merge -f "test failure is irrelevant"

@pytorchmergebot
Collaborator

The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@mrshenli
Contributor Author

mrshenli commented Dec 9, 2022

@mrshenli has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

mrshenli added a commit to mrshenli/pytorch that referenced this pull request Dec 12, 2022
The main change is to move `args` and `kwargs` dtype conversion
from `_root_pre_forward` to `_pre_forward`, so that every
FSDP instance has a chance to apply its own precision.

ghstack-source-id: f9f4b994a700503580404b186d72834bf9f898c9
Pull Request resolved: pytorch#90523