
Conversation


@speediedan speediedan commented Aug 14, 2024

Fixes #133499

The issue

Testing a variety of TP requires_grad patterns (validating maximally flexible finetuning) revealed that DTensor sharding propagation of aten.native_layer_norm_backward (default) fails with an IndexError for certain requires_grad patterns (pattern 1, e.g. output_mask [True, False, False]) and an AssertionError for others (pattern 2, e.g. output_mask [False, True, *]). Please see issue #133499 for a full description of the observed failure patterns along with a reproduction.
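
To make the mapping between requires_grad patterns and output_mask values concrete, here is a minimal single-device sketch (illustrative only; see #133499 for the actual DTensor reproduction):

```python
import torch

# Pattern 1: frozen affine params -> autograd only needs grad_input, so
# native_layer_norm_backward is invoked with output_mask [True, False, False].
ln = torch.nn.LayerNorm(8)
ln.weight.requires_grad_(False)
ln.bias.requires_grad_(False)
x = torch.randn(2, 4, 8, requires_grad=True)
ln(x).sum().backward()  # only x.grad is populated

# Pattern 2: input does not require grad but the affine params do, so the
# backward is invoked with output_mask [False, True, True] (the [False, True, *] family).
ln2 = torch.nn.LayerNorm(8)
x2 = torch.randn(2, 4, 8)  # requires_grad=False
ln2(x2).sum().backward()  # only ln2.weight.grad / ln2.bias.grad are populated
```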

Use Cases and Remediation

Failure pattern 1 is potentially problematic for a variety of finetuning scenarios. Though failure pattern 2 is really an xfail right now since it's not fully supported, IMHO there are use cases (especially in mechanistic interpretability research, and potentially certain finetuning scenarios as well) that justify supporting this output_mask, particularly since supporting it is fairly straightforward.

In this PR I propose some modest changes that:

  • Address the aforementioned failure modes.
  • Add a couple of tests that I'm hopeful will help ensure DTensor op dispatch (which is so well implemented and such a pleasure to work with btw! 🚀 🎉) accommodates a wide variety of (potentially unanticipated) requires_grad patterns as it evolves.

To address both failure modes, I'm proposing the following changes:

  1. To torch.distributed._tensor.ops._math_ops.layer_norm_bwd_strategy:
  • Refactor conditional output_mask handling such that the input and output specs in the PlacementStrategys of the returned output_strategy.strategies list remain aligned with the op_schema.args_spec (whose definition does not change at runtime based upon unused optional args).
  2. To torch.distributed._tensor._sharding_prop.propagate_op_sharding_non_cached:
  • When iterating through the active op_schema.args_spec to build the relevant expected_input_specs list, filter out any None desired_specs (a minimal sketch of this filtering appears after this list).
  3. To torch/distributed/_tensor/_op_schema.OpSchema._inplace_rewrap_schema_suggestion:
  • When inputs need a redistribute, for runtime-unrequired arguments (None entries in the aligned suggestion_args_schema), ignore the associated suggestion_args_spec.
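
The filtering in 2 is roughly the following (a simplified sketch with illustrative names, not the verbatim diff; it assumes the strategy's desired input specs are aligned 1:1 with op_schema.args_spec, with None entries for optional args that are unused at runtime):

```python
def build_expected_input_specs(args_spec, desired_input_specs):
    # Skip None desired specs: these correspond to runtime-unrequired optional
    # args, so there is nothing to validate or redistribute for them.
    expected_input_specs = []
    for _arg_spec, desired_spec in zip(args_spec, desired_input_specs):
        if desired_spec is None:
            continue
        expected_input_specs.append(desired_spec)
    return expected_input_specs
```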

Implementation considerations:

  • Regarding 1, to avoid changing the op strategy return args (op_strategy), the change in 1 allows None elements to exist temporarily in PlacementStrategy.input_specs (treating it as Sequence[DTensorSpec | None] | None when it is typed as Sequence[DTensorSpec] | None). This could be addressed in any number of ways but I thought it best to leave that for a subsequent PR since it could have broader ramifications (e.g. allowing op strategies to return an output_strategy.input_specs mask explicitly, explicitly allowing Nones in PlacementStrategy.input_specs, creating a Null DTensorSpec, etc.). That's why I'm using an ignore arg-type directive there for now.
  • Regarding 2 and 3 above, I don't introspect op_schema.op._schema.arguments to verify that any None arguments are torch.OptionalType, leaving adherence to the schema contract the responsibility of the given op. Regarding 2, I assume any desired_spec will be either a DTensorSpec or None, so only None can be falsy in this context. (A simplified sketch of the change-3 rewrap behavior appears after this list.)
  • I considered altering the active args_schema, which could be inspected and aligned with the active output_strategy.input_specs in some cases to avoid the changes in 3, but I think that would rely on one of (among other possibilities):
    • all supported op signatures having optional Tensor (DTensorSpec) args after required tensors (which isn't a planned requirement as far as I know),
    • (somewhat brittle) heuristic-driven arg alignment, or
    • only supporting kwargs, etc.
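
To make the rewrap behavior from 3 concrete, here is a simplified sketch (hypothetical helper and names; the actual method also handles kwargs and rewraps the specs back into the op schema in place):

```python
def rewrap_args(suggestion_args_schema, suggestion_args_spec):
    # Positions whose runtime argument is None (unused optional args) keep None
    # rather than adopting the redistribute suggestion's DTensorSpec for that slot.
    return tuple(
        None if arg is None else spec
        for arg, spec in zip(suggestion_args_schema, suggestion_args_spec)
    )
```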

Added Tests

To facilitate detection of future requires_grad pattern op failure modes as DTensor evolves, I added the following two tests:

  1. test/distributed/_tensor/test_math_ops.py DistMathOpsTest.test_layer_norm_bwd_req_grad

    • Tests native_layer_norm_backward specifically, with 20 subtests that sweep valid output_mask patterns across different LayerNorm dimensionalities and elementwise_affine configurations (a simplified sketch of such a sweep appears after this list).
  2. test/distributed/tensor/parallel/test_tp_examples.py DistTensorParallelExampleTest.test_transformer_req_grad

    • Samples a subset of requires_grad patterns in a more realistic Transformer usage context (relative to the LayerNorm-specific test) with different dtype and is_seq_parallel configurations. Since there was substantial overlap with the existing test_transformer_training test, I took the opportunity to refactor that test to allow relevant code-sharing. I also added an ExpCommCounts NamedTuple to facilitate adding further requires_grad patterns in the future that may result in different comm counts. I created a separate requires_grad test to keep it decoupled from the multi-iteration test_transformer_training test and to allow new requires_grad scenarios to be added as desired while being mindful of resources.
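
For reference, the shape of the requires_grad sweep in the LayerNorm-specific test looks roughly like this (a hypothetical sketch, not the actual test code; the real test also varies LayerNorm dimensionality):

```python
import itertools

def req_grad_cases():
    # Enumerate (input, weight, bias) requires_grad combinations; at least one
    # flag must be True for backward to do anything. With elementwise_affine=False
    # there are no affine params, so only the input flag applies.
    for elementwise_affine in (True, False):
        affine_flags = (True, False) if elementwise_affine else (False,)
        for inp_rg, w_rg, b_rg in itertools.product((True, False), affine_flags, affine_flags):
            if not (inp_rg or w_rg or b_rg):
                continue
            yield elementwise_affine, (inp_rg, w_rg, b_rg)

# Each case then builds matching local and DTensor LayerNorm inputs with these
# flags, runs forward/backward, and compares the resulting gradients.
for case in req_grad_cases():
    print(case)
```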

Thanks again to the PyTorch distributed team for your immensely valuable contributions to the open-source ML community!

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @tianyu-l


pytorch-bot bot commented Aug 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133502

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 1f9cf04 with merge base 255cd75:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed label Aug 14, 2024
@speediedan speediedan marked this pull request as ready for review August 14, 2024 20:18
@bdhirsh bdhirsh added the triaged label Aug 16, 2024
@gnadathur gnadathur requested a review from wz337 August 16, 2024 17:12

@XilunWu XilunWu left a comment

Thanks a lot for the fix! Left some questions and suggestions.

@speediedan speediedan requested a review from XilunWu August 22, 2024 00:55

@XilunWu XilunWu left a comment

Overall looks good to me! Thanks for addressing the review comments. Here are more questions and suggestions after reading the code in more detail.

@speediedan speediedan requested a review from XilunWu August 22, 2024 19:31

@XilunWu XilunWu left a comment

Thanks a lot @speediedan for reporting the issue found in DTensor native_layer_norm_backward and providing a fix along with test coverage. The PR looks good to me and I'll merge it once all CI signals are green.

@XilunWu XilunWu added ciflow/trunk Trigger trunk jobs on your pull request better-engineering Relatively self-contained tasks for better engineering contributors topic: not user facing topic category module: dtensor distributed tensor tag labels Aug 23, 2024

XilunWu commented Aug 24, 2024

@pytorchbot merge -f "lint error is not related to this PR"

@pytorchmergebot

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorch-bot bot pushed a commit that referenced this pull request Sep 13, 2024
…o more fully accommodate optional args (#133502)

Pull Request resolved: #133502
Approved by: https://github.com/XilunWu