Fixes #187017
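Moves `median` and `nanmedian` to Metal.

Config shorthand used below: `4096^2` and `256^3` are square and cubic shapes, `.T` is a transposed view, `glob` is a global reduction, `d0`/`d1` mean `dim=0`/`dim=1`, `nan10%` marks an input in which 10% of the elements are NaN, and `clean` marks a NaN-free input. Speedup is before / after; `OOM` means the previous implementation ran out of memory.

Perf: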
### Global reduction (`torch.median(x)` / `torch.nanmedian(x)`)
| config | dtype | before (µs) | after (µs) | speedup |
|---|---|---:|---:|---:|
| global 1e6 | f32 | 376.1 | 104.9 | 3.59x |
| global 1e7 | f32 | 3,403.7 | 725.0 | 4.69x |
| global 1e6 | bf16 | 416.3 | 51.5 | 8.08x |
| global 1e7 | bf16 | 3,961.4 | 238.6 | 16.61x |
| global 1e6 | i32 | 427.4 | 114.5 | 3.73x |
| global 1e7 | i32 | 3,994.5 | 647.5 | 6.17x |
| global 1e6 | i64 | 1,114.7 | 206.9 | 5.39x |
| global 1e7 | i64 | 16,138.2 | 2,511.3 | 6.43x |
| nan10% 1e7 nanmed glob | f32 | 18,441.6 | 660.2 | 27.93x |
| nan10% 1e7 nanmed glob | bf16 | OOM | 239.8 | n/a (OOM before) |
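
For reference, the semantics being benchmarked can be pinned down with a small pure-Python sketch. This is not the Metal kernel. It follows the documented `torch.median` behavior of returning the lower of the two middle values for even-length input and returning NaN whenever a NaN is present, while `torch.nanmedian` ignores NaNs and only returns NaN for an all-NaN input. The helper names `ref_median` and `ref_nanmedian` are hypothetical, and rejecting empty input is a choice made by this sketch only.

```python
import math
from typing import List, Sequence


def _lower_median(values: List[float]) -> float:
    # For even lengths, take the lower middle element: sorted index (n - 1) // 2.
    ordered = sorted(values)
    return ordered[(len(ordered) - 1) // 2]


def ref_median(xs: Sequence[float]) -> float:
    """Sketch of torch.median(x): any NaN makes the result NaN."""
    if not xs:
        raise ValueError("this sketch rejects empty input")
    if any(math.isnan(v) for v in xs):
        return math.nan
    return _lower_median(list(xs))


def ref_nanmedian(xs: Sequence[float]) -> float:
    """Sketch of torch.nanmedian(x): NaNs are skipped; all-NaN input gives NaN."""
    if not xs:
        raise ValueError("this sketch rejects empty input")
    kept = [v for v in xs if not math.isnan(v)]
    if not kept:
        return math.nan
    return _lower_median(kept)
```

For example, `ref_median([3.0, 1.0, 4.0, 2.0])` returns `2.0` (the lower of 2 and 3), and `ref_nanmedian([1.0, float("nan"), 3.0, 2.0])` returns `2.0`. Sorting is used purely for readability here.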
### Reduction along a dim (`torch.median(x, dim)`)
| config | dtype | before (µs) | after (µs) | speedup |
|---|---|---:|---:|---:|
| 4096x4096 dim=1 | f32 | 10,186.3 | 8,234.8 | 1.24x |
| 4096x4096 dim=0 | f32 | 18,222.7 | 10,010.9 | 1.82x |
| 64x1e6 dim=1 | f32 | 131,392.2 | 44,488.6 | 2.95x |
| 1e6x64 dim=1 | f32 | 799,542.7 | 53,340.6 | 14.99x |
| 1024x65536 dim=1 | f32 | 165,552.8 | 36,321.2 | 4.56x |
| 512x131072 dim=1 | f32 | 192,594.4 | 50,550.8 | 3.81x |
| 256^3 dim=2 | f32 | 85,459.8 | 8,062.8 | 10.60x |
| 256^3 dim=0 | f32 | 99,853.3 | 7,937.4 | 12.58x |
| 128x1024 dim=1 | f32 | 263.4 | 65.8 | 4.00x |
| 2x3x4x5x6 dim=1 | f32 | 210.5 | 25.9 | 8.13x |
| 4096x4096 dim=1 | bf16 | 8,259.1 | 4,460.5 | 1.85x |
| 4096x4096 dim=0 | bf16 | 14,593.3 | 5,263.7 | 2.77x |
| 64x1e6 dim=1 | bf16 | 71,337.7 | 16,727.9 | 4.26x |
| 1e6x64 dim=1 | bf16 | 718,870.6 | 63,828.1 | 11.26x |
| 1024x65536 dim=1 | bf16 | 99,392.7 | 31,771.3 | 3.13x |
| 512x131072 dim=1 | bf16 | 69,957.3 | 26,466.0 | 2.64x |
| 256^3 dim=2 | bf16 | 61,144.0 | 8,276.0 | 7.39x |
| 256^3 dim=0 | bf16 | 73,448.6 | 8,124.6 | 9.04x |
| 128x1024 dim=1 | bf16 | 192.0 | 72.1 | 2.66x |
| 2x3x4x5x6 dim=1 | bf16 | 214.8 | 28.5 | 7.54x |
| 4096x4096 dim=1 | i32 | 7,865.9 | 3,397.0 | 2.32x |
| 4096x4096 dim=0 | i32 | 14,520.8 | 4,709.3 | 3.08x |
| 64x1e6 dim=1 | i32 | 125,680.6 | 36,122.3 | 3.48x |
| 1e6x64 dim=1 | i32 | 722,882.5 | 10,845.9 | 66.65x |
| 1024x65536 dim=1 | i32 | 159,814.7 | 32,887.0 | 4.86x |
| 512x131072 dim=1 | i32 | 147,411.6 | 40,269.4 | 3.66x |
| 256^3 dim=2 | i32 | 58,854.0 | 1,665.9 | 35.33x |
| 256^3 dim=0 | i32 | 71,037.4 | 2,494.1 | 28.48x |
| 128x1024 dim=1 | i32 | 188.2 | 23.9 | 7.88x |
| 2x3x4x5x6 dim=1 | i32 | 210.3 | 9.8 | 21.49x |
| 4096x4096 dim=1 | i64 | 15,432.2 | 6,173.9 | 2.50x |
| 4096x4096 dim=0 | i64 | 24,024.7 | 7,869.9 | 3.05x |
| 64x1e6 dim=1 | i64 | 384,666.6 | 76,788.9 | 5.01x |
| 1e6x64 dim=1 | i64 | 1,401,332.0 | 17,755.0 | 78.93x |
| 1024x65536 dim=1 | i64 | 306,227.5 | 49,172.7 | 6.23x |
| 512x131072 dim=1 | i64 | 383,091.9 | 61,487.8 | 6.23x |
| 256^3 dim=2 | i64 | 109,296.3 | 2,737.0 | 39.93x |
| 256^3 dim=0 | i64 | 125,894.7 | 3,872.0 | 32.51x |
| 128x1024 dim=1 | i64 | 317.5 | 34.5 | 9.21x |
| 2x3x4x5x6 dim=1 | i64 | 397.5 | 12.6 | 31.57x |
| 512^3 dim=1 | f32 | 654,396.9 | 65,719.6 | 9.96x |
| 512^3 dim=1 | bf16 | 594,299.7 | 73,546.9 | 8.08x |
| 512^3 dim=1 | i32 | 333,311.3 | 25,239.7 | 13.21x |
| 512^3 dim=1 | i64 | OOM | 36,270.9 | n/a (OOM before) |
| 1024x1024x512 dim=1 | f32 | OOM | 278,978.0 | n/a (OOM before) |
| 1024x1024x512 dim=1 | bf16 | OOM | 311,786.9 | n/a (OOM before) |
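
With a `dim` argument, `torch.median(x, dim)` returns both the median values and their indices along `dim`. A minimal pure-Python sketch of that contract for a 2-D, NaN-free nested list is shown below; `median_along_dim` is a hypothetical helper and says nothing about how the Metal kernels are structured.

```python
from typing import List, Sequence, Tuple


def median_along_dim(
    x: Sequence[Sequence[float]], dim: int
) -> Tuple[List[float], List[int]]:
    """Sketch of torch.median(x, dim) for a NaN-free 2-D nested list.

    Returns (values, indices), where indices[i] is a position along `dim`
    holding the lower median of the i-th slice.
    """
    if not x or not x[0]:
        raise ValueError("expected a non-empty 2-D input")
    if dim < 0:
        dim += 2  # accept dim=-1 / dim=-2 like PyTorch
    rows, cols = len(x), len(x[0])
    if dim == 1:
        slices = [list(row) for row in x]  # one slice per row
    elif dim == 0:
        slices = [[x[r][c] for r in range(rows)] for c in range(cols)]  # one per column
    else:
        raise ValueError("dim must be in [-2, 1] for a 2-D input")
    values: List[float] = []
    indices: List[int] = []
    for s in slices:
        # Stable argsort: positions ordered by value, ties keep their original order.
        order = sorted(range(len(s)), key=s.__getitem__)
        k = order[(len(s) - 1) // 2]  # lower median for even lengths
        values.append(s[k])
        indices.append(k)
    return values, indices
```

For `[[3, 1, 2], [9, 7, 8]]`, reducing along `dim=1` gives values `[2, 8]` at indices `[2, 2]`. Which index this sketch reports for tied values is arbitrary; the PyTorch docs only promise that `indices` is the first occurrence when the median value is unique.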
### Strided / sliced layouts
| config | dtype | before (µs) | after (µs) | speedup |
|---|---|---:|---:|---:|
| 4096^2.T dim=0 | f32 | 12,089.2 | 8,482.2 | 1.43x |
| 4096^2.T dim=1 | f32 | 16,644.8 | 10,942.0 | 1.52x |
| 4096^2[::2] dim=1 | f32 | 5,158.8 | 4,526.9 | 1.14x |
| 4096^2[:,1k:3k] dim=0 | f32 | 9,151.6 | 5,053.1 | 1.81x |
| 4096^2.T dim=0 | bf16 | 14,767.4 | 4,472.3 | 3.30x |
| 4096^2.T dim=1 | bf16 | 11,008.9 | 5,410.5 | 2.03x |
| 4096^2[::2] dim=1 | bf16 | 4,127.8 | 2,515.7 | 1.64x |
| 4096^2[:,1k:3k] dim=0 | bf16 | 7,348.2 | 2,566.0 | 2.86x |
| 4096^2.T dim=0 | i32 | 9,637.6 | 3,409.0 | 2.83x |
| 4096^2.T dim=1 | i32 | 12,800.6 | 4,815.9 | 2.66x |
| 4096^2[::2] dim=1 | i32 | 3,874.0 | 1,711.5 | 2.26x |
| 4096^2[:,1k:3k] dim=0 | i32 | 7,214.9 | 2,360.5 | 3.06x |
| 4096^2.T dim=0 | i64 | 17,010.4 | 6,161.7 | 2.76x |
| 4096^2.T dim=1 | i64 | 22,469.8 | 6,447.1 | 3.49x |
| 4096^2[::2] dim=1 | i64 | 7,788.3 | 3,129.9 | 2.49x |
| 4096^2[:,1k:3k] dim=0 | i64 | 11,912.6 | 3,966.6 | 3.00x |
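
The strided rows differ from the contiguous ones only in how elements are addressed: a view is a flat storage plus sizes, strides, and a storage offset. The sketch below reads each reduction slice through that addressing. It uses a hypothetical helper, `median_of_strided_2d`, and is plain Python rather than a description of how the PR's kernels handle strides. For a contiguous 4x4 tensor `x`, PyTorch reports `x.stride() == (4, 1)`, `x.T.stride() == (1, 4)`, `x[::2].stride() == (8, 1)`, and `x[:, 1:3].storage_offset() == 1`, which are the metadata values plugged in below.

```python
from typing import List, Sequence


def median_of_strided_2d(
    storage: Sequence[float],
    sizes: Sequence[int],
    strides: Sequence[int],
    offset: int,
    dim: int,
) -> List[float]:
    """Lower median along `dim` of a 2-D view given by (sizes, strides, offset)."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1")
    other = 1 - dim
    out: List[float] = []
    for i in range(sizes[other]):
        # Element (i along `other`, j along `dim`) sits at
        # offset + i * strides[other] + j * strides[dim] in flat storage.
        base = offset + i * strides[other]
        slice_ = [storage[base + j * strides[dim]] for j in range(sizes[dim])]
        ordered = sorted(slice_)
        out.append(ordered[(len(ordered) - 1) // 2])
    return out
```

With `storage = list(range(16))`, reducing the transposed view `(sizes=(4, 4), strides=(1, 4))` along `dim=1` reads the same elements as reducing the contiguous view along `dim=0`, so both calls return `[4, 5, 6, 7]`.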
### nanmedian on inputs with and without NaNs
| config | dtype | before (µs) | after (µs) | speedup |
|---|---|---:|---:|---:|
| nan10% 4096^2 nanmed d1 | f32 | 10,439.2 | 9,199.2 | 1.13x |
| nan10% 4096^2 nanmed d0 | f32 | 18,530.9 | 10,941.3 | 1.69x |
| clean 4096^2 nanmed d1 | f32 | 10,470.4 | 8,214.9 | 1.27x |
| nan10% 4096^2 nanmed d1 | bf16 | 18,870.5 | 4,368.2 | 4.32x |
| nan10% 4096^2 nanmed d0 | bf16 | 14,820.6 | 5,037.6 | 2.94x |
| clean 4096^2 nanmed d1 | bf16 | 18,981.2 | 4,390.5 | 4.32x |
| nan10% 512^3 nanmed d1 | f32 | 415,150.8 | 72,076.8 | 5.76x |
| nan10% 512^3 nanmed d1 | bf16 | 633,870.6 | 87,350.0 | 7.26x |
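
The speedup column in these tables is the ratio of the old timing to the new one. A minimal sketch of that arithmetic follows; the `speedup` helper is hypothetical, and it treats an `OOM` baseline (passed as `None`) as having no finite ratio.

```python
from typing import Optional


def speedup(before_us: Optional[float], after_us: float) -> Optional[float]:
    """Return before/after rounded to two decimals, or None when the baseline OOMed."""
    if before_us is None:
        # An OOM baseline has no timing, so no finite ratio exists.
        return None
    if after_us <= 0:
        raise ValueError("after_us must be positive")
    return round(before_us / after_us, 2)
```

For example, `speedup(376.1, 104.9)` gives `3.59` and `speedup(722_882.5, 10_845.9)` gives `66.65`, matching the global f32 1e6 row and the `1e6x64 dim=1` i32 row.
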
Pull Request resolved: https://github.com/pytorch/pytorch/pull/187060
Approved by: https://github.com/malfet