Enable DTensor sharding propagation of native_layer_norm_backward to more fully accommodate optional args #133502
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133502
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 1f9cf04 with merge base 255cd75.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks a lot for the fix! Left some questions and suggestions.
Overall looks good to me! Thanks for addressing the review comments. Here are more questions and suggestions after reading the code in more detail.
Thanks a lot @speediedan for reporting the issue found in DTensor native_layer_norm_backward
and providing a fix along with test coverage. The PR looks good to me and I'll merge it once all CI signals are green.
@pytorchbot merge -f "lint error is not related to this PR"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #133502 · Approved by: https://github.com/XilunWu
Fixes #133499

### The issue

Testing a variety of TP `requires_grad` patterns (validating maximally flexible finetuning) revealed that `DTensor` sharding propagation of `aten.native_layer_norm_backward` (default) fails with an `IndexError` for certain `requires_grad` patterns (pattern 1, e.g. `output_mask` `[True, False, False]`) and an `AssertionError` for others (pattern 2, e.g. `output_mask` `[False, True, *]`). Please see issue #133499 for a full description of the observed failure patterns along with a reproduction.
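For context, a minimal sketch (assumed shapes, mesh size, and device type; not the verbatim reproduction from #133499) of a `requires_grad` pattern that produces `output_mask` `[True, False, False]`, i.e. failure pattern 1:

```python
# Sketch only: assumes torch.distributed is initialized and four devices are available.
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Replicate, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (4,))  # placeholder 1-D mesh

# Shard the batch dim of the input; replicate the (frozen) affine params.
dx = distribute_tensor(torch.randn(8, 16), mesh, [Shard(0)]).requires_grad_()
dw = distribute_tensor(torch.ones(16), mesh, [Replicate()])   # requires_grad=False
db = distribute_tensor(torch.zeros(16), mesh, [Replicate()])  # requires_grad=False

out = F.layer_norm(dx, (16,), weight=dw, bias=db)
# Only grad_input is requested, so aten.native_layer_norm_backward sees
# output_mask [True, False, False]; before this fix, sharding propagation
# of that op could raise an IndexError here.
out.sum().backward()
```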
### Use Cases and Remediation

Failure pattern 1 is potentially problematic for a variety of finetuning scenarios. Though failure pattern 2 is really an xfail right now since it's not fully supported, IMHO there are use cases (especially w.r.t. mechanistic interpretability research, but potentially certain finetuning scenarios too) that justify supporting this output mask, especially since supporting it is fairly straightforward.

In this PR I propose some modest changes that:

* Address the aforementioned failure modes.
* Add a couple of tests that I hope will help ensure `DTensor` op dispatch (which is so well implemented and such a pleasure to work with btw! 🚀 🎉) accommodates a wide variety of (potentially unanticipated) `requires_grad` patterns as it evolves.

To address both failure modes, I'm proposing the following changes:
1. To [`torch.distributed._tensor.ops._math_ops.layer_norm_bwd_strategy`](https://github.com/pytorch/pytorch/blob/7b269cc48434c94b038a749c3c3d8b7d586d0ca7/torch/distributed/_tensor/ops/_math_ops.py#L873): Refactor the conditional `output_mask` handling such that the input and output specs in the `PlacementStrategy`s of the returned `output_strategy.strategies` list remain aligned with the `op_schema.args_spec` (whose definition does not change at runtime based upon unused optional args).
2. To [`torch.distributed._tensor._sharding_prop.propagate_op_sharding_non_cached`](https://github.com/pytorch/pytorch/blob/7b269cc48434c94b038a749c3c3d8b7d586d0ca7/torch/distributed/_tensor/_sharding_prop.py#L256-L262): When iterating through the active `op_schema.args_spec` to build the relevant `expected_input_specs` list, filter out any `None` `desired_specs` (see the sketch after this list).
3. To [`torch/distributed/_tensor/_op_schema.OpSchema._inplace_rewrap_schema_suggestion`](https://github.com/pytorch/pytorch/blob/7b269cc48434c94b038a749c3c3d8b7d586d0ca7/torch/distributed/_tensor/_op_schema.py#L418): When inputs need a redistribute, ignore the `suggestion_args_spec` associated with runtime-unrequired arguments (`None` arguments in the aligned `suggestion_args_schema`).
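As a rough illustration of the filtering described in changes `2` and `3` (a simplified sketch with assumed names, not the actual upstream diff), `None` entries in the spec sequence aligned with `op_schema.args_spec` are skipped rather than indexed into:

```python
from typing import Optional, Sequence

DTensorSpec = object  # stand-in for the real DTensorSpec type in this sketch


def expected_input_specs(
    desired_specs: Sequence[Optional[DTensorSpec]],  # aligned with op_schema.args_spec
) -> list[DTensorSpec]:
    # A None desired spec marks a runtime-unused optional arg (e.g. a frozen
    # weight/bias whose gradient is masked out); it contributes nothing to the
    # expected input specs or to any redistribute suggestion.
    return [spec for spec in desired_specs if spec is not None]
```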
### Implementation considerations

- Regarding `1`, to avoid changing the op strategy return args ([`op_strategy`](https://github.com/pytorch/pytorch/blob/cf8118000784653b14e4757690aa1d02ba8216fc/torch/distributed/_tensor/_sharding_prop.py#L234)), the change in `1` allows `None` elements to exist temporarily in `PlacementStrategy.input_specs` (treating it as `Sequence[DTensorSpec | None] | None` when it is annotated as `Sequence[DTensorSpec] | None`). This could be addressed in any number of ways, but I thought it best to leave that for a subsequent PR since it could have broader ramifications (e.g. allowing op strategies to return an `output_strategy.input_specs` mask explicitly, explicitly allowing `None`s in `PlacementStrategy.input_specs`, creating a `Null` `DTensorSpec`, etc.). That's why I'm using an ignore arg-type directive there for now (see the sketch after this list).
- Regarding `2` and `3` above, I don't introspect `op_schema.op._schema.arguments` to verify that any `None` arguments are `torch.OptionalType`, leaving adherence to the schema contract the responsibility of the given op. Regarding `2`, I assume any `desired_spec` will be either a `DTensorSpec` or `None`, so only `None` can be falsy in this context.
- I considered altering the active `args_schema`, which could be inspected and aligned with the active `output_strategy.input_specs` in some cases and avoid the changes in `3`, but I think that would rely on one of (among other possibilities):
  - all supported op signatures having optional tensor (`DTensorSpec`) args after required tensors (which isn't a planned requirement as far as I know),
  - (somewhat brittle) heuristic-driven arg alignment, or
  - only supporting kwargs, etc.
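To make the first consideration above concrete, a small sketch (assumed names and tensor-arg ordering) of how `None` placeholders keep `input_specs` positionally aligned with the op schema for `native_layer_norm_backward`:

```python
from typing import Optional, Tuple

DTensorSpec = object  # stand-in for the real DTensorSpec type in this sketch

# PlacementStrategy.input_specs is annotated as Sequence[DTensorSpec] | None, but the
# strategy temporarily carries Sequence[DTensorSpec | None]: positions for unused
# optional args (e.g. weight/bias when their grads are masked out) hold None rather
# than being dropped, preserving alignment with op_schema.args_spec.
def aligned_input_specs(
    grad_out: DTensorSpec,
    inp: DTensorSpec,
    mean: DTensorSpec,
    rstd: DTensorSpec,
    weight: Optional[DTensorSpec],
    bias: Optional[DTensorSpec],
) -> Tuple[Optional[DTensorSpec], ...]:
    return (grad_out, inp, mean, rstd, weight, bias)
```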
### Added Tests

To facilitate detection of future `requires_grad` pattern op failure modes as `DTensor` evolves, I added the following two tests:

1. `test/distributed/_tensor/test_math_ops.py DistMathOpsTest.test_layer_norm_bwd_req_grad`: Tests `native_layer_norm_backward` specifically, with 20 subtests that sweep valid `output_mask` patterns along with different LayerNorm dimensionalities and `elementwise_affine` configurations (a condensed sketch of this style of sweep follows the list).
2. `test/distributed/tensor/parallel/test_tp_examples.py DistTensorParallelExampleTest.test_transformer_req_grad`: Samples a subset of `requires_grad` patterns in a more realistic (relative to the `LayerNorm`-specific test) Transformer usage context with different `dtype` and `is_seq_parallel` configurations. Note that since there was substantial overlap with the existing `test_transformer_training` test, I took the opportunity to refactor that test to allow relevant code-sharing. I also added an `ExpCommCounts` `NamedTuple` to facilitate the addition of further `requires_grad` patterns we may want to test in the future, which may result in different comm counts. I created the separate `requires_grad` test to decouple it from the multi-iteration `test_transformer_training` test and allow the addition of new `requires_grad` scenarios as desired while being mindful of resources.
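For reference, a condensed sketch of the kind of `output_mask` sweep the first test performs (hypothetical structure and helper name; the real test lives in `test/distributed/_tensor/test_math_ops.py`):

```python
import itertools

# Sweep requires_grad combinations for (input, weight, bias); each maps to a distinct
# native_layer_norm_backward output_mask. The all-False case is skipped since no
# backward runs. `run_sharded_layer_norm` is a hypothetical harness helper expected to
# build DTensor inputs with the given flags, run forward/backward, and compare the
# sharded gradients against a single-device reference.
REQ_GRAD_PATTERNS = [m for m in itertools.product([True, False], repeat=3) if any(m)]


def sweep_layer_norm_bwd(run_sharded_layer_norm, elementwise_affine: bool) -> None:
    for inp_rg, w_rg, b_rg in REQ_GRAD_PATTERNS:
        if not elementwise_affine and (w_rg or b_rg):
            continue  # no affine params exist to require grad for
        run_sharded_layer_norm(
            inp_requires_grad=inp_rg,
            weight_requires_grad=w_rg,
            bias_requires_grad=b_rg,
        )
```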
Thanks again to the PyTorch distributed team for your immensely valuable contributions to the open-source ML community!

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l