copybara-service bot commented Jan 15, 2025

In the SPMD partitioner, preprocess the sharding on singleton dimensions (dimensions whose size is 1).

Partitioning a dimension whose size is 1 is meaningless and may cause redundant padding and unpadding to be inserted. To avoid this, we replicate the sharding on these dimensions as a preprocessing step.

Take the following input as an example:
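The preprocessing idea can be modeled in a few lines of Python. This is only a minimal sketch, not the partitioner's actual code: `replicate_singleton_dims` is a hypothetical helper, and a sharding is reduced here to a tuple of per-dimension partition counts plus a replication factor.

```python
from math import prod


def replicate_singleton_dims(shape, tile_dims):
    """Hypothetical model: move tiling on size-1 dimensions into replication.

    shape:     tensor dimension sizes, e.g. (1, 1)
    tile_dims: number of partitions along each dimension, e.g. (1, 8)
    Returns (new_tile_dims, replication), where replication is the number of
    devices that end up holding an identical copy of each tile.
    """
    if len(shape) != len(tile_dims):
        raise ValueError("shape and tile_dims must have the same rank")
    # A dimension of size 1 cannot be split usefully, so stop tiling it.
    new_tile_dims = tuple(
        1 if size == 1 else tiles for size, tiles in zip(shape, tile_dims)
    )
    # Devices that used to tile a singleton dimension become replicas.
    replication = prod(tile_dims) // prod(new_tile_dims)
    return new_tile_dims, replication


print(replicate_singleton_dims((1, 1), (1, 8)))   # ((1, 1), 8)
print(replicate_singleton_dims((1, 8), (1, 8)))   # ((1, 8), 1)
print(replicate_singleton_dims((1, 16), (2, 4)))  # ((1, 4), 2)
```

In this model, an `f32[1,1]` value tiled `[1,8]` ends up with no tiling and eight replicas, which corresponds to a fully replicated sharding, while an `f32[1,8]` value with the same tiling is left unchanged.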

ENTRY entry {
  %constant.785 = f32[1,8] constant({{0,1,2,3,4,5,6,7}}), sharding={devices=[1,8]<=[8]}
  %slice.62 = f32[1,1] slice(%constant.785), slice={[0:1], [0:1]}, sharding={devices=[1,8]<=[8]}
  ROOT %reshape.779 = f32[] reshape(%slice.62), sharding={replicated}
}

Previous result, with redundant instructions:

ENTRY %entry_spmd () -> f32[] {
  %constant.8 = u32[8]{0} constant({0, 1, 2, 3, 4, 5, 6, 7})
  %partition-id = u32[] partition-id()
  %dynamic-slice.3 = u32[1]{0} dynamic-slice(u32[8]{0} %constant.8, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape.2 = u32[] reshape(u32[1]{0} %dynamic-slice.3)
  %constant.9 = u32[] constant(0)
  %compare = pred[] compare(u32[] %reshape.2, u32[] %constant.9), direction=EQ
  %broadcast = pred[1,1]{1,0} broadcast(pred[] %compare), dimensions={}
  %constant.0 = f32[1,8]{1,0} constant({ { 0, 1, 2, 3, 4, 5, 6, 7 } })
  %constant.1 = s32[] constant(0)
  %constant.2 = s32[8]{0} constant({0, 1, 2, 3, 4, 5, 6, 7})
  %dynamic-slice = s32[1]{0} dynamic-slice(s32[8]{0} %constant.2, u32[] %partition-id), dynamic_slice_sizes={1}
  %reshape = s32[] reshape(s32[1]{0} %dynamic-slice)
  %dynamic-slice.1 = f32[1,1]{1,0} dynamic-slice(f32[1,8]{1,0} %constant.0, s32[] %constant.1, s32[] %reshape), dynamic_slice_sizes={1,1}
  %copy = f32[1,1]{1,0} copy(f32[1,1]{1,0} %dynamic-slice.1)
  %constant.10 = f32[] constant(0)
  %broadcast.1 = f32[1,1]{1,0} broadcast(f32[] %constant.10), dimensions={}
  %select = f32[1,1]{1,0} select(pred[1,1]{1,0} %broadcast, f32[1,1]{1,0} %copy, f32[1,1]{1,0} %broadcast.1)
  %all-reduce = f32[1,1]{1,0} all-reduce(f32[1,1]{1,0} %select), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, use_global_device_ids=true, to_apply=%add.clone
  ROOT %reshape.3 = f32[] reshape(f32[1,1]{1,0} %all-reduce)
}

Result with this improvement:

ENTRY %entry_spmd () -> f32[] {
  %constant.0 = f32[1,8]{1,0} constant({ { 0, 1, 2, 3, 4, 5, 6, 7 } })
  %slice.0 = f32[1,1]{1,0} slice(f32[1,8]{1,0} %constant.0), slice={[0:1], [0:1]}
  ROOT %reshape.1 = f32[] reshape(f32[1,1]{1,0} %slice.0)
}
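To see why the shorter program is equivalent, the two partitioned results can be modeled in plain Python. This is an illustrative sketch only: `previous_result` and `improved_result` are hypothetical names, the per-partition loop stands in for the eight SPMD devices, and `%add.clone` is assumed to be a plain sum.

```python
def previous_result(data, num_partitions=8):
    """Model of the old output: mask every partition but 0, then all-reduce."""
    contributions = []
    for partition_id in range(num_partitions):
        # dynamic-slice: each partition reads its own column of the [1,8] constant.
        local = data[0][partition_id]
        # compare(..., 0), direction=EQ: only partition 0 owns the sliced element.
        keep = partition_id == 0
        # select: every other partition contributes zero.
        contributions.append(local if keep else 0.0)
    # all-reduce (assumed to sum): every partition receives the same total.
    return sum(contributions)


def improved_result(data):
    """Model of the new output: a plain slice [0:1], [0:1] plus a reshape."""
    return data[0][0]


constant = [[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]]
print(previous_result(constant) == improved_result(constant))  # True
```

Both models return element `[0][0]` of the constant; the old form only reaches it through a masked all-reduce across all eight partitions.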

Reverts a7703e7

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#21273 from nvcastet:ncclCommInitRankScalable dd6362af36a1f4d22532ad15b2007527898b5fa1

copybara-service bot force-pushed the exported_pr_715559819 branch from f5cb551 to d76f776 on January 15, 2025 21:04
PiperOrigin-RevId: 715924899
copybara-service bot force-pushed the exported_pr_715559819 branch from d76f776 to 312fe36 on January 15, 2025 22:18
copybara-service bot closed this on Jan 15, 2025
copybara-service bot deleted the exported_pr_715559819 branch on January 15, 2025 22:18
copybara-service bot merged commit 312fe36 into master on Jan 15, 2025
1 check passed