
Conversation

malfet (Contributor) commented Apr 2, 2025

Stack from ghstack (oldest at bottom):

Use cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also significantly reduces the amount of shared memory needed to perform those reductions; a minimal sketch of the pattern follows the table below.

Using such a reduction increases `torch.compile` performance for gpt-fast with `stories110M` from 29 tokens/sec to 630 tokens/sec on M4, and changes the performance of `torch.rand` as follows:

| size      | before | after  |
|-----------|--------|--------|
| 512x512   | 202.1  | 131.8  |
| 1024x1024 | 780.6  | 176.9  |
| 2048x2048 | 1423.4 | 339.9  |
| 4096x4097 | 2982.2 | 1047.2 |
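
For illustration, here is a minimal standalone sketch of the cooperative-reduction pattern through `torch.mps.compile_shader`, in the same style as the repro below; the kernel name and launch shape are invented for this example, and it is not the actual Inductor-generated kernel:

```python
import torch

# Hypothetical example kernel: a simdgroup-wide cooperative sum.
lib = torch.mps.compile_shader("""
kernel void simd_reduce(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
  // Cooperative reduction: every active thread in the simdgroup receives the
  // full sum from a single call, with no explicit loop and no threadgroup memory.
  out[idx] = metal::simd_sum(in[idx]);
}
""")
x = torch.arange(32, device='mps', dtype=torch.int32)
y = torch.empty_like(x)
lib.simd_reduce(y, x)
print(y)  # expected: every element equals 496 (= 0 + 1 + ... + 31), assuming
          # all 32 threads land in a single, fully active simdgroup
```

A C-style loop would instead stage partial sums in threadgroup memory and barrier between steps, which is where the shared-memory savings come from.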

Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using `simd_shuffle_down` of 64-bit values represented as `int2` types, which yields a reduction in $\log_2(\text{threadgroup\_size})$ steps. [`mlx/kernels/reduction/ops.h`](https://github.com/ml-explore/mlx/blob/86389bf9707f46101af45d90510e8e97c8a90b93/mlx/backend/metal/kernels/reduction/ops.h#L15-L18) contains an implementation of such an algorithm, but alas it yields wrong results on M1/M2 (and possibly M3) machines if not all threads in the simdgroup are active, which can be observed by running:

```python
import torch
lib=torch.mps.compile_shader("""
kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) {
  out[idx] = metal::simd_shuffle_down(in[idx], 8);
}
""")
x=torch.arange(22, device='mps', dtype=torch.int32)
y=torch.empty_like(x)
lib.do_sum(y, x)
print(y)
```

which returns the following on M4:

```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,  0,  0,  0,  0, 0,  0,  0,  0], device='mps:0', dtype=torch.int32)
```

but the same kernel running on M1 returns:

```
tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32)
```

This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernel that uses `simd_shuffle_and_fill_down` causes an internal compiler error on macOS 13.2. Since that OS will reach EOL soon, the offending tests are skipped.
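
As a rough sketch of the `int2`-based 64-bit reduction described above, the kernel below uses `simd_shuffle_and_fill_down`; it assumes a 32-wide simdgroup, the kernel name and launch shape are invented for this example, and (per the note above) it would hit the internal compiler error on macOS 13.2:

```python
import torch

# Hypothetical example: sum 64-bit integers across one simdgroup by shuffling
# each value as an int2 pair and filling shifted-in lanes with zero.
lib = torch.mps.compile_shader("""
kernel void sum64(device long* out, constant long* in, uint idx [[thread_position_in_grid]]) {
  long val = in[idx];
  // log2(32) = 5 shuffle steps; each step adds the value held 'delta' lanes down.
  for (ushort delta = 16; delta > 0; delta >>= 1) {
    int2 other = metal::simd_shuffle_and_fill_down(as_type<int2>(val), int2(0), delta);
    val += as_type<long>(other);
  }
  // Lane 0 of the simdgroup now holds the full 64-bit sum.
  out[idx] = val;
}
""")
x = torch.arange(32, device='mps', dtype=torch.int64)
y = torch.empty_like(x)
lib.sum64(y, x)
print(y[0])  # expected: 496, assuming all 32 threads land in one simdgroup
```

Per the description above, the fill variant sidesteps the inconsistent values that plain `simd_shuffle_down` returns on M1/M2 when not all lanes are active.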

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

[ghstack-poisoned]

pytorch-bot bot commented Apr 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150566

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 22 Pending, 1 Unrelated Failure

As of commit 7525368 with merge base 15dbad2:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the ciflow/inductor, ciflow/mps (Run MPS tests, subset of trunk), and module: inductor labels Apr 2, 2025
malfet added the topic: performance and release notes: mps labels Apr 2, 2025
malfet and others added 5 commits April 2, 2025 16:21
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@malfet malfet requested a review from kulinseth as a code owner April 5, 2025 00:20
malfet added a commit that referenced this pull request Apr 5, 2025
By using cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions

TODO:
 - Unhardcode max_threadgroup / simdgroup_size and query them at runtime instead

ghstack-source-id: fab1d5a
Pull Request resolved: #150566
malfet (Contributor, Author) commented Apr 5, 2025

@pytorchbot merge -f "Lint + MPS are finally green"

pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

timocafe pushed a commit to timocafe/pytorch that referenced this pull request Apr 16, 2025
Pull Request resolved: pytorch#150566
Approved by: https://github.com/manuelcandales
ghstack dependencies: pytorch#150452, pytorch#150457
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
@github-actions github-actions bot deleted the gh/malfet/265/head branch May 15, 2025 02:15