[dtensor][op] Fixed stack op strategy (#129018)
**Summary**
The previous stack op strategy resharded the inputs before the input_specs were created, which triggered a "list index out of range" error. This change delays the resharding until after the input_specs are created, so the new dimension introduced by stack can be inserted first, preventing the error. I also ran all the other test cases to confirm the change does not introduce any new bugs.
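The ordering issue can be illustrated with a simplified sketch: stack inserts a new dimension at `dim`, so any shard placement on a dimension at or after `dim` must shift by one in the output spec, while the input specs must still describe the inputs with their original placements. The helper below mimics the role of `normalize_shard_for_stack` but is a hypothetical stand-in, not the real DTensor API.

```python
# Simplified stand-in for normalize_shard_for_stack: shift shard dims
# that sit at or after the newly inserted stack dimension. Placements
# are modeled as ("shard", dim) tuples or the string "replicate".
def normalize_shard_for_stack(placements, insert_dim):
    out = []
    for p in placements:
        if isinstance(p, tuple) and p[0] == "shard" and p[1] >= insert_dim:
            out.append(("shard", p[1] + 1))  # new dim pushes this shard dim up
        else:
            out.append(p)
    return out

# Two inputs sharded on dim 0, stacked at dim 0 (a new leading dim).
follow_placements = [("shard", 0)]

# Input specs must be built from the ORIGINAL placements first...
input_specs = [tuple(follow_placements) for _ in range(2)]

# ...and only afterwards are the placements normalized for the OUTPUT spec.
output_placements = normalize_shard_for_stack(follow_placements, 0)

print(input_specs)        # input specs keep shard dim 0
print(output_placements)  # output shard dim shifted to 1
```

Normalizing before building the input specs (the old order) would describe the inputs with the shifted placement, which is what led to the out-of-range indexing.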

**Test Plan**
pytest test/distributed/_tensor/test_tensor_ops.py -s -k test_stack

Pull Request resolved: #129018
Approved by: https://github.com/XilunWu
sinhaanshul authored and pytorchmergebot committed Jun 21, 2024
1 parent 6b5fbc5 commit aee512c
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torch/distributed/_tensor/ops/tensor_ops.py
@@ -513,7 +513,6 @@ def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     follow_placements = _derive_follow_placements_from_tuple_strategy(
         input_tuple_strategy
     )
-    follow_placements = normalize_shard_for_stack(follow_placements, dim)
 
     # create op strategy base on the follow placements
     op_strategy = OpStrategy([])
@@ -522,6 +521,9 @@ def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
         DTensorSpec(mesh, tuple(follow_placements))
         for _ in range(len(input_tuple_strategy.childs))
     )
+
+    follow_placements = normalize_shard_for_stack(follow_placements, dim)
+
     op_strategy.strategies.append(
         PlacementStrategy(
             output_specs=DTensorSpec(mesh, tuple(follow_placements)),
