copybara-service[bot]

PR #20808: [GSPMD] Partitions collective permute instructions in manual sharding group.

Imported from GitHub PR openxla/xla#20808

This is a small fix to how GSPMD partitions collective-permute instructions added in a manual sharding group.

In JAX, a `ppermute` instruction can be added inside `shard_map`. When `shard_map` is used with auto axes specified, collective-permuting an operand, even one that keeps the same sharding, ends up as an `all-gather` followed by a collective permute, which is inefficient. The correct and efficient way is to partition the collective permute as an element-wise op.
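To illustrate the claim above, here is a minimal pure-Python sketch of a toy 2x2 device mesh. It is not XLA's partitioner: the mesh sizes, the source-target pairs, and the function names (`make_shards`, `permute_after_all_gather`, `permute_shardwise`) are all hypothetical and chosen only for illustration. It compares the all-gather-then-permute path with a shard-wise permute and shows that both leave the same data on every device.

```python
# Toy model, not XLA code: a 2x2 device mesh where axis "x" is manual
# and axis "y" is auto. All names and sizes here are assumptions.
X, Y = 2, 2
PAIRS = [(0, 1), (1, 0)]  # (source, target) indices along the manual axis "x"


def make_shards():
    # Device (x, y) holds a distinct local shard of two values.
    return {
        (x, y): [10 * x + 2 * y, 10 * x + 2 * y + 1]
        for x in range(X)
        for y in range(Y)
    }


def permute_after_all_gather(shards):
    # Inefficient path: all-gather along "y", permute the full operand
    # along "x", then slice it back into per-device shards.
    n = len(shards[(0, 0)])
    gathered = {
        x: [v for y in range(Y) for v in shards[(x, y)]] for x in range(X)
    }
    permuted = {dst: gathered[src] for src, dst in PAIRS}
    return {
        (x, y): permuted[x][y * n:(y + 1) * n]
        for x in range(X)
        for y in range(Y)
    }


def permute_shardwise(shards):
    # Element-wise style: each device forwards its local shard, unchanged,
    # to the device with the same "y" coordinate and the target "x".
    return {(dst, y): shards[(src, y)] for src, dst in PAIRS for y in range(Y)}


result_gather = permute_after_all_gather(make_shards())
result_shardwise = permute_shardwise(make_shards())
print(result_gather == result_shardwise)
```

In this toy model the shard-wise path moves only local shards and never materializes the gathered operand, which is the saving the fix is after.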

The added unit test provides a repro. The JAX unit test at https://github.com/jax-ml/jax/blob/fa9c7edf736516052df6eab22947bc627d0deca3/tests/shard_map_test.py#L2167 gives a real-world example.

Copybara import of the project:

--
8ee6ecd51f6e4aae8e3d92a6a439a60f53ab02ae by Yunlong Liu <yunlongl@x.ai>:

A hacky fix on partitioning collective permute.

--
e50e87696defb290f7561a7808ee42ebbc11e144 by Yunlong Liu <yunlongl@x.ai>:

Local change.

--
84eb38597c783a4488774823c2c464296a8c54c7 by Yunlong Liu <yunlongl@x.ai>:

Simplifies sharding in tests.

Merging this change closes #20808

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#20808 from yliu120:cp_sharding_2 84eb38597c783a4488774823c2c464296a8c54c7

copybara-service bot force-pushed the exported_pr_714284894 branch from d1419f9 to de85df5 on January 13, 2025 07:47

PiperOrigin-RevId: 714851861
copybara-service bot force-pushed the exported_pr_714284894 branch from de85df5 to b5d22c7 on January 13, 2025 08:26
copybara-service bot closed this on Jan 13, 2025
copybara-service bot deleted the exported_pr_714284894 branch on January 13, 2025 08:26
copybara-service bot merged commit b5d22c7 into master on Jan 13, 2025
1 check passed