[FSDP] Allow nested FSDP wrapper to use different mixed precision #90523
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90523
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 Failures as of commit 3f687d6: The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: c4afb38c7f795552cfd0f3a1ba6bfb9f23e6846b Pull Request resolved: #90523
…ecision" The main change is to move `args` and `kwargs` dtype conversion from `_root_pre_forward` to `_pre_forward`, so that every FSDP has a chance to apply its own precision. [ghstack-poisoned]
The main change is to move `args` and `kwargs` dtype conversion from `_root_pre_forward` to `_pre_forward`, so that every FSDP has a chance to apply its own precision. ghstack-source-id: c4afb38c7f795552cfd0f3a1ba6bfb9f23e6846b Pull Request resolved: #90523
This PR looks good to me. Feel free to give @rohan-varma a chance to review if he wants, and also let me know what you think about some of the comments I left.
@@ -257,12 +266,14 @@ def _pre_forward(
     handles: List[FlatParamHandle],
     unshard_fn: Callable,
     module: nn.Module,
-    input: Any,
+    args: Tuple[Any],
 ):
nit: Does this need to be `Tuple[Any, ...]`? (This also affects the docstring.) My understanding is that `Tuple[Any]` means a singleton tuple with an element of any type, and that `Tuple[Any, ...]` means a tuple with arbitrarily many elements, where each element may be any type.
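A quick, torch-free illustration of the distinction (the function names are made up for this example; both annotations are inert at runtime, and only a static checker such as mypy distinguishes them):

```python
from typing import Any, Tuple

# `Tuple[Any]` annotates a 1-element tuple; `Tuple[Any, ...]` annotates a
# tuple of arbitrary length. At runtime both are plain metadata, but a
# static checker (e.g. mypy) rejects a 3-element tuple where `Tuple[Any]`
# is expected.
def takes_singleton(t: Tuple[Any]) -> int:
    return len(t)

def takes_variadic(t: Tuple[Any, ...]) -> int:
    return len(t)

print(takes_variadic((1, "a", 3.0)))  # 3; mypy accepts any length here
```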
Oh yep, thanks for the catch!
@@ -284,6 +296,12 @@ def _pre_forward(
     # the `grad_fn` is mutated.
     _register_post_backward_hooks(state, handles)

+    # Recursively convert args and kwargs to specified precision.
+    input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
I see that this line only targets the wrapper FSDP code path. I recommend adding a TODO at this line to indicate that we need to change it to support the non-wrapper code path; alternatively, we may file an issue if you do not want to re-run CI.
The open problem is how `fully_shard` should enable applying different `fsdp_kwargs` to different `FlatParamHandle`s (for a select subset of the arguments). The main difference is that `fully_shard` does not support manual wrapping, which is the solution for `FullyShardedDataParallel`. Today, one application of `fully_shard` does not interact with any other application of `fully_shard`. If we want to migrate to the composable APIs, then we need to iron out the design here.

Given a way to specify different `fsdp_kwargs` for different `FlatParamHandle`s, the change to this line is not complicated. For the non-wrapper code path, we change the requirement to be that every handle in `handles` (passed to this `_pre_forward`) has the same low precision parameter dtype. In today's design (and probably for the near future), `handles` is a singleton (i.e. `len(handles) == 1`). Either way, the general constraint will be on `handles`.
def _get_handles_param_dtype(handles: Iterable[FlatParamHandle]) -> Optional[torch.dtype]:
    """
    Assumes all ``handles`` have the same low precision parameter dtype
    setting, where ``None`` represents full precision. This returns the single
    low precision parameter dtype setting across ``handles``.
    """
    dtypes: Set[Optional[torch.dtype]] = {handle._config.low_prec_param_dtype for handle in handles}
    p_assert(
        len(dtypes) == 1,
        f"Expects uniform low precision parameter dtype but got {dtypes}",
    )
    return next(iter(dtypes))

input_dtype: Optional[torch.dtype] = _get_handles_param_dtype(handles)
(The question is how to specify `handle._config.low_prec_param_dtype` to be different across `FlatParamHandle`s of the same `fully_shard`.)
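A torch-free sketch of the uniform-dtype check described above, to show the intended behavior; `SimpleNamespace` stands in for `FlatParamHandle`, plain strings stand in for `torch.dtype`, and only the `_config.low_prec_param_dtype` attribute path mirrors the snippet:

```python
from types import SimpleNamespace
from typing import Iterable, Optional

# Collect the low precision parameter dtype from every handle and require
# that there is exactly one; `None` would represent full precision.
def get_handles_param_dtype(handles: Iterable) -> Optional[str]:
    dtypes = {h._config.low_prec_param_dtype for h in handles}
    assert len(dtypes) == 1, (
        f"Expects uniform low precision parameter dtype but got {dtypes}"
    )
    return next(iter(dtypes))

make = lambda dt: SimpleNamespace(_config=SimpleNamespace(low_prec_param_dtype=dt))
print(get_handles_param_dtype([make("float16"), make("float16")]))  # float16
```

A mixed set of handles (e.g. one `"float16"` and one `None`) trips the assertion, which is the constraint the comment proposes placing on `handles`.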
Today, one application of fully_shard does not interact with any other application of fully_shard. If we want to migrate to the composable APIs, then we need to iron out the design here.
Can we do the following:

For the short term, we require users to wrap children modules with `fully_shard` before wrapping parent ones. When applying `fully_shard`, recursively check all submodules' states and skip the ones that already have `replicate` or `fully_shard` on them. More specifically, `fully_shard` can write a flag in every submodule, indicating that this submodule is already handled by a parent module. If users later apply another `fully_shard` on the submodule, error out.

For the longer term, when we have the high-level wrapper API, that API can sort things out, and we no longer require users to wrap things in order.
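A toy, torch-free sketch of that short-term rule; `ToyModule`, `fully_shard`, and the flag names here are hypothetical stand-ins, not the real composable API:

```python
# Children-first rule sketched above: applying the wrapper marks the
# module as sharded and flags its not-yet-sharded descendants as managed
# by a parent; applying the wrapper again to a flagged submodule errors.
class ToyModule:
    def __init__(self, children=()):
        self.children = list(children)
        self.is_fully_sharded = False
        self.managed_by_parent = False

def fully_shard(module: ToyModule) -> None:
    if module.managed_by_parent:
        raise RuntimeError("submodule is already handled by a parent fully_shard")
    module.is_fully_sharded = True
    def mark(m: ToyModule) -> None:
        for child in m.children:
            if child.is_fully_sharded:
                continue  # skip: wrapped earlier by its own fully_shard
            child.managed_by_parent = True
            mark(child)
    mark(module)

inner = ToyModule()
outer = ToyModule([ToyModule(), inner])
fully_shard(inner)  # children first: fine
fully_shard(outer)  # skips `inner`, flags the other child
```

With this sketch, a later `fully_shard(outer.children[0])` raises, which is the error-out behavior proposed for double application.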
@@ -284,6 +296,12 @@ def _pre_forward(
     # the `grad_fn` is mutated.
     _register_post_backward_hooks(state, handles)

+    # Recursively convert args and kwargs to specified precision.
+    input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
+    args, kwargs = _prepare_forward_inputs(
The main assumption on the performance side from this PR is that the overhead from `_prepare_forward_inputs()` / `_cast_fp_inputs_to_dtype()` is not a concern. Has anyone confirmed whether `t.to(t.dtype)` is fully a no-op with effectively zero overhead? If not, I wonder if we should change `_cast_fp_inputs_to_dtype()` to guard the `.to()` with an `if` statement:
def cast_fn(x: torch.Tensor) -> torch.Tensor:
    if not torch.is_floating_point(x) or x.dtype == dtype:
        return x
    y = x.to(dtype)
    # Explicitly copy over `requires_grad` since this runs inside
    # `torch.no_grad()`
    if x.is_leaf:
        y.requires_grad = x.requires_grad
    return y
In other words, we do not perform any of the logic if the tensor's dtype already matches. I think with this change, I would be more confident that there will be effectively no performance regression from this. The only additional overhead is a CPU-side recursion over the input data structure via `_apply_to_tensors()`, incurred for every FSDP instance.
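A torch-free sketch of what that per-instance recursion looks like; `apply_to_leaves` is a made-up analog of `_apply_to_tensors`, with plain values standing in for tensors and a guarded leaf function standing in for the dtype-checked cast:

```python
# Made-up analog of `_apply_to_tensors`: recurse through lists, tuples,
# and dicts, apply `fn` at the leaves, and rebuild the same container
# structure. This CPU-side walk is the only cost paid per FSDP instance
# when every leaf already has the target dtype.
def apply_to_leaves(fn, x):
    if isinstance(x, (list, tuple)):
        return type(x)(apply_to_leaves(fn, v) for v in x)
    if isinstance(x, dict):
        return {k: apply_to_leaves(fn, v) for k, v in x.items()}
    return fn(x)

# Guarded leaf op: convert ints, leave everything else untouched, mirroring
# the "skip if the dtype already matches" idea.
cast = lambda v: float(v) if isinstance(v, int) else v
print(apply_to_leaves(cast, {"a": [1, "s"], "b": (2,)}))
# {'a': [1.0, 's'], 'b': (2.0,)}
```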
On a related note, I think we need to be careful about using a side stream to do this downcasting. Today, we do not use such a stream, but we have mentioned considering doing so. There may be some strange memory freeing behavior and stream complications. We are investigating this for the post-backward, so I think we should hold off on any stream changes in the near term until we get a better understanding.
The implementation of `Tensor.to()` is (pytorch/aten/src/ATen/core/Tensor.cpp, lines 41 to 51 in eeb3f8a):
TensorBase TensorBase::to(
    at::TensorOptions options,
    bool non_blocking,
    bool copy,
    c10::optional<at::MemoryFormat> memory_format) const {
  Tensor self(*this);
  return at::_ops::to_dtype_layout::call(
      self, optTypeMetaToScalarType(options.dtype_opt()),
      options.layout_opt(), options.device_opt(),
      options.pinned_memory_opt(), non_blocking, copy, memory_format);
}
Then I traced where it dispatches to:
struct TORCH_API to_dtype {
using schema = at::Tensor (const at::Tensor &, at::ScalarType, bool, bool, c10::optional<at::MemoryFormat>);
using ptr_schema = schema*;
// See Note [static constexpr char* members for windows NVCC]
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::to")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "dtype")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, "to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)")
static at::Tensor call(const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format);
static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format);
};
// aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
at::Tensor to_dtype::redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, c10::optional<at::MemoryFormat> memory_format) {
static auto op = create_to_dtype_typed_handle();
return op.redispatch(dispatchKeySet, self, dtype, non_blocking, copy, memory_format);
}
// aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
static C10_NOINLINE c10::TypedOperatorHandle<to_dtype::schema> create_to_dtype_typed_handle() {
return c10::Dispatcher::singleton()
.findSchemaOrThrow(to_dtype::name, to_dtype::overload_name)
.typed<to_dtype::schema>();
}
pytorch/aten/src/ATen/core/dispatch/Dispatcher.h, lines 489 to 492 in eeb3f8a:
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
C10_ALWAYS_INLINE Return redispatch(DispatchKeySet currentDispatchKeySet, Args... args) const {
  return c10::Dispatcher::singleton().redispatch<Return, Args...>(*this, currentDispatchKeySet, std::forward<Args>(args)...);
}
pytorch/aten/src/ATen/core/dispatch/Dispatcher.h, lines 641 to 656 in eeb3f8a:
// See [Note: Argument forwarding in the dispatcher] for why Args doesn't use &&
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  // do not use RecordFunction on redispatch
#ifndef NDEBUG
  DispatchTraceNestingGuard debug_guard;
  if (show_dispatch_trace()) {
    auto nesting_value = dispatch_trace_nesting_value();
    for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
    std::cerr << "[redispatch] op=[" << op.operator_name() << "], key=[" << toString(currentDispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
  }
#endif
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet);
  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}
At this point, I didn't continue tracing. Even if it does not launch any CUDA kernel, the above is already sufficient complexity plus a Python/C++ context switch. Let's avoid that by using your suggestion above.
Agreed, it looks like even if `tensor.to()` is a no-op, we can avoid a Python/C++ switch and a GIL release/reacquire here.
@@ -953,3 +954,26 @@ def perThreadTearDown(self):
     @property
     def world_size(self) -> int:
         raise RuntimeError("world size not implemented")
I see: so this `CheckPrecisionModule` does not check precision itself -- it is used to check the heterogeneous mixed precision case externally.
Let me update the name to match that.
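A minimal, torch-free sketch of that externally-checked pattern; the class name and attributes here are hypothetical, not the test helper in this PR:

```python
# Hypothetical analog of the test helper discussed above: the module does
# no precision checking itself, it only records the type of each input so
# the surrounding test can assert the expected per-wrapper precision from
# the outside.
class RecordInputTypesModule:
    def __init__(self):
        self.seen = []

    def __call__(self, x):
        self.seen.append(type(x).__name__)
        return x

m = RecordInputTypesModule()
m(1.0)
m(2)
print(m.seen)  # ['float', 'int']
```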
…ecision" The main change is to move `args` and `kwargs` dtype conversion from `_root_pre_forward` to `_pre_forward`, so that every FSDP has a chance to apply its own precision. [ghstack-poisoned]
The main change is to move `args` and `kwargs` dtype conversion from `_root_pre_forward` to `_pre_forward`, so that every FSDP has a chance to apply its own precision. ghstack-source-id: f9f4b994a700503580404b186d72834bf9f898c9 Pull Request resolved: #90523
LGTM!
@@ -284,6 +296,12 @@ def _pre_forward(
     # the `grad_fn` is mutated.
     _register_post_backward_hooks(state, handles)

+    # Recursively convert args and kwargs to specified precision.
+    input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
+    args, kwargs = _prepare_forward_inputs(
q: do we skip the conversion if it's already the right type?
Yep, @awgu also pointed this out. Added a condition to skip the cast when the dtype already matches.
@pytorchbot merge -g
Merge started. Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge -f "test failure is irrelevant"
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@mrshenli has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
The main change is to move `args` and `kwargs` dtype conversion from `_root_pre_forward` to `_pre_forward`, so that every FSDP has a chance to apply its own precision. ghstack-source-id: f9f4b994a700503580404b186d72834bf9f898c9 Pull Request resolved: pytorch#90523
Stack from ghstack (oldest at bottom):

The main change is to move `args` and `kwargs` dtype conversion from `_root_pre_forward` to `_pre_forward`, so that every FSDP has a chance to apply its own precision.

Differential Revision: D41883717