Skip to content

Conversation

AlnisM
Copy link
Contributor

@AlnisM AlnisM commented Jun 7, 2024

This PR introduces a heuristic for tuned_mixed_mm. The heuristic is only enabled on an A100, because it has only been tested on an A100, and it is only enabled if force_mixed_mm="heuristic".

I compared the heuristic to the aten fallback implementation and triton+autotune:
Geometric mean speedup: 2.51

m     n     k  triton + autotune (GB/s)  aten (GB/s)  heuristic (GB/s)  used_heuristic  speedup (heuristic/aten)
 1  4096  4096                    456.95       134.59            459.37            True                      3.41
 1  4096  8192                    523.93       138.29            553.50            True                      4.00
 1  4096 16394                    233.70       161.62            234.14            True                      1.45
 1  8192  4096                    633.25       140.64            574.86            True                      4.09
 1  8192  8192                    737.54       147.41            690.26            True                      4.68
 1  8192 16394                    413.67       175.88            408.68            True                      2.32
 1 16394  4096                    717.22       167.22            665.36            True                      3.98
 1 16394  8192                    812.69       177.17            815.90            True                      4.61
 1 16394 16394                    473.17       178.58            435.11            True                      2.44
 4  4096  4096                    479.46       134.80            486.74            True                      3.61
 4  4096  6333                    174.27       106.74            171.64            True                      1.61
 4  4096  8192                    567.14       138.32            571.09            True                      4.13
 4  4096 12313                    179.65       105.91            180.03            True                      1.70
 4  4096 16394                    222.96       145.54            222.81            True                      1.53
 4  6333  4096                    491.78       126.37            473.20            True                      3.74
 4  6333  6333                    268.79       143.40            269.75            True                      1.88
 4  6333  8192                    783.80       135.12            796.23            True                      5.89
 4  6333 12313                    286.35       142.37            287.30            True                      2.02
 4  6333 16394                    362.47       139.66            361.47            True                      2.59
 4  8192  4096                    642.73       140.53            641.88            True                      4.57
 4  8192  6333                    287.65       137.63            287.38            True                      2.09
 4  8192  8192                    738.42       150.16            721.59            True                      4.81
 4  8192 12313                    301.27       146.18            302.31            True                      2.07
 4  8192 16394                    415.37       167.66            393.41            True                      2.35
 4 12313  4096                    823.66       141.81            745.40            True                      5.26
 4 12313  6333                    433.92       148.17            429.83            True                      2.90
 4 12313  8192                    984.60       149.30            988.95            True                      6.62
 4 12313 12313                    452.00       150.87            452.50            True                      3.00
 4 12313 16394                    609.88       159.20            609.71            True                      3.83
 4 16394  4096                    779.44       157.46            777.10            True                      4.94
 4 16394  6333                    402.93       139.50            309.47            True                      2.22
 4 16394  8192                    950.38       175.49            949.67            True                      5.41
 4 16394 12313                    414.62       153.99            315.95            True                      2.05
 4 16394 16394                    497.56       174.97            461.77            True                      2.64
16  4096  4096                    475.92       134.45            478.57            True                      3.56
16  4096  6333                    146.36       112.50            145.35            True                      1.29
16  4096  8192                    560.00       138.22            557.19            True                      4.03
16  4096 12313                    152.02       105.06            151.27            True                      1.44
16  4096 16394                    222.48       156.72            222.88            True                      1.42
16  6333  4096                    692.41       122.14            696.88            True                      5.71
16  6333  6333                    220.74       140.90            225.41            True                      1.60
16  6333  8192                    813.56       140.21            820.28            True                      5.85
16  6333 12313                    232.48       131.19            232.55            True                      1.77
16  6333 16394                    367.39       134.93            361.87            True                      2.68
16  8192  4096                    665.54       140.29            266.24            True                      1.90
16  8192  6333                    254.77       136.65            240.12            True                      1.76
16  8192  8192                    750.63       146.26            736.93            True                      5.04
16  8192 12313                    266.61       127.13            251.81            True                      1.98
16  8192 16394                    397.25       160.42            390.76            True                      2.44
16 12313  4096                    857.48       141.36            851.36            True                      6.02
16 12313  6333                    423.21       132.40            357.55            True                      2.70
16 12313  8192                   1021.24       145.68           1024.60            True                      7.03
16 12313 12313                    370.12       143.94            383.52            True                      2.66
16 12313 16394                    608.52       141.03            608.48            True                      4.31
16 16394  4096                    826.48       155.94            826.74            True                      5.30
16 16394  6333                    420.38       144.09            265.23            True                      1.84
16 16394  8192                    988.07       156.21            984.63            True                      6.30
16 16394 12313                    431.40       146.92            265.49            True                      1.81
16 16394 16394                    497.39       167.86            461.79            True                      2.75
23  4096  4096                    344.43       132.84            338.64            True                      2.55
23  4096  6333                    195.34       118.48            195.31            True                      1.65
23  4096  8192                    389.83       140.02            376.62            True                      2.69
23  4096 12313                    204.49       137.96            204.80            True                      1.48
23  4096 16394                    242.48       148.99            242.74            True                      1.63
23  6333  4096                    429.25       126.52            517.75            True                      4.09
23  6333  6333                    295.56       133.51            296.14            True                      2.22
23  6333  8192                    594.88       137.05            581.78            True                      4.25
23  6333 12313                    315.18       131.67            314.64            True                      2.39
23  6333 16394                    386.46       141.45            386.54            True                      2.73
23  8192  4096                    553.52       142.05            568.35            True                      4.00
23  8192  6333                    215.58       139.01            210.86            True                      1.52
23  8192  8192                    609.21       154.85            528.76            True                      3.41
23  8192 12313                    220.38       142.93            233.54            True                      1.63
23  8192 16394                    402.63       158.39            403.21            True                      2.55
23 12313  4096                    723.54       131.58            581.94            True                      4.42
23 12313  6333                    307.90       131.58            307.90            True                      2.34
23 12313  8192                    893.36       129.97            623.72            True                      4.80
23 12313 12313                    322.40       134.84            317.80            True                      2.36
23 12313 16394                    512.97       142.31            409.45            True                      2.88
23 16394  4096                    703.66       154.54            643.53            True                      4.16
23 16394  6333                    305.55       127.55            293.17            True                      2.30
23 16394  8192                    768.12       154.60            681.53            True                      4.41
23 16394 12313                    311.61       140.92            307.01            True                      2.18
23 16394 16394                    467.24       171.07            467.29            True                      2.73
32  4096  4096                    344.71       132.30            338.62            True                      2.56
32  4096  6333                    206.48       107.59            205.55            True                      1.91
32  4096  8192                    387.24       137.82            353.12            True                      2.56
32  4096 12313                    216.35       120.61            214.50            True                      1.78
32  4096 16394                    242.05       149.92            241.94            True                      1.61
32  6333  4096                    525.50       127.12            518.02            True                      4.08
32  6333  6333                    300.50       118.41            296.55            True                      2.50
32  6333  8192                    600.92       136.99            601.94            True                      4.39
32  6333 12313                    316.13       136.45            316.03            True                      2.32
32  6333 16394                    386.11       141.34            386.10            True                      2.73
32  8192  4096                    546.18       140.18            341.14            True                      2.43
32  8192  6333                    218.40       130.65            263.42            True                      2.02
32  8192  8192                    608.29       147.16            542.12            True                      3.68
32  8192 12313                    225.60       135.04            225.23            True                      1.67
32  8192 16394                    434.75       160.42            401.28            True                      2.50
32 12313  4096                    787.80       136.28            583.60            True                      4.28
32 12313  6333                    316.66       125.76            323.35            True                      2.57
32 12313  8192                    891.38       128.88            639.50            True                      4.96
32 12313 12313                    326.11       132.37            325.88            True                      2.46
32 12313 16394                    521.64       139.47            395.69            True                      2.84
32 16394  4096                    625.55       158.46            651.16            True                      4.11
32 16394  6333                    304.14       131.13            284.55            True                      2.17
32 16394  8192                    767.79       162.95            704.34            True                      4.32
32 16394 12313                    310.74       137.68            303.39            True                      2.20
32 16394 16394                    465.92       171.43            465.37            True                      2.71
43  4096  4096                    345.05       133.87            196.47            True                      1.47
43  4096  6333                    148.64        99.92            148.97            True                      1.49
43  4096  8192                    386.50       135.39            214.00            True                      1.58
43  4096 12313                    190.39       109.36            156.27            True                      1.43
43  4096 16394                    203.63       150.24            204.05            True                      1.36
43  6333  4096                    421.35       106.04            132.25            True                      1.25
43  6333  6333                    224.75       113.01            224.97            True                      1.99
43  6333  8192                    471.11       117.61            327.39            True                      2.78
43  6333 12313                    234.55       115.61            234.74            True                      2.03
43  6333 16394                    311.56       132.24            312.01            True                      2.36
43  8192  4096                    400.73       140.12            269.11            True                      1.92
43  8192  6333                    167.32       119.13            168.84            True                      1.42
43  8192  8192                    435.45       146.98            286.21            True                      1.95
43  8192 12313                    161.05       127.82            162.78            True                      1.27
43  8192 16394                    207.16       156.40            208.90            True                      1.34
43 12313  4096                    484.01       120.10            313.35            True                      2.61
43 12313  6333                    234.54       106.63            232.85            True                      2.18
43 12313  8192                    515.34       130.23            411.70            True                      3.16
43 12313 12313                    239.39       130.04            239.03            True                      1.84
43 12313 16394                    316.02       137.39            316.29            True                      2.30
43 16394  4096                    475.60       152.57            340.97            True                      2.23
43 16394  6333                    241.21       132.49            208.59            True                      1.57
43 16394  8192                    499.34       157.43            361.61            True                      2.30
43 16394 12313                    246.25       132.31            211.68            True                      1.60
43 16394 16394                    302.90       158.56            277.05            True                      1.75
64  4096  4096                    280.48       126.82            195.97            True                      1.55
64  4096  6333                    150.94       101.63            150.48            True                      1.48
64  4096  8192                    305.47       135.06            211.03            True                      1.56
64  4096 12313                    158.12       110.06            158.15            True                      1.44
64  4096 16394                    206.68       136.21            201.28            True                      1.48
64  6333  4096                    409.11       105.10            296.07            True                      2.82
64  6333  6333                    229.98       108.46            230.59            True                      2.13
64  6333  8192                    469.32       112.24            330.58            True                      2.95
64  6333 12313                    245.02       117.16            244.84            True                      2.09
64  6333 16394                    317.78       125.80            318.37            True                      2.53
64  8192  4096                    323.42       139.92            267.31            True                      1.91
64  8192  6333                    167.51       118.45            167.56            True                      1.41
64  8192  8192                    341.13       146.71            284.88            True                      1.94
64  8192 12313                    172.21       123.42            171.97            True                      1.39
64  8192 16394                    217.22       153.18            216.99            True                      1.42
64 12313  4096                    482.19       123.32            311.82            True                      2.53
64 12313  6333                    238.73       123.88            238.66            True                      1.93
64 12313  8192                    516.32       122.11            330.50            True                      2.71
64 12313 12313                    248.73       125.32            296.82            True                      2.37
64 12313 16394                    314.98       134.06            320.31            True                      2.39
64 16394  4096                    476.59       154.58            340.84            True                      2.20
64 16394  6333                    240.54       119.60            214.82            True                      1.80
64 16394  8192                    501.36       149.02            359.45            True                      2.41
64 16394 12313                    244.65       126.01            222.47            True                      1.77
64 16394 16394                    302.48       160.36            283.66            True                      1.77

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@AlnisM AlnisM self-assigned this Jun 7, 2024
Copy link

pytorch-bot bot commented Jun 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128232

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9db041a with merge base 7b9c5e0 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@AlnisM AlnisM requested a review from Chillee June 7, 2024 17:50
@AlnisM AlnisM marked this pull request as ready for review June 8, 2024 00:52
@AlnisM AlnisM marked this pull request as draft June 10, 2024 15:47
@AlnisM AlnisM force-pushed the mixed-mm-heuristic branch from e56e62d to 6ed50d1 Compare June 10, 2024 17:31
@AlnisM AlnisM marked this pull request as ready for review June 12, 2024 00:26
Copy link
Collaborator

@Chillee Chillee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@AlnisM
Copy link
Contributor Author

AlnisM commented Jun 13, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 13, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this pull request Jun 14, 2024
This PR introduces a heuristic for tuned_mixed_mm. The heuristic is only enabled on an A100, because it has only been tested on an A100, and it is only enabled if force_mixed_mm="heuristic".

I compared the heuristic to the aten fallback implementation and triton+autotune:
 Geometric mean speedup: 2.51
 ```
 m     n     k  triton + autotune (GB/s)  aten (GB/s)  heuristic (GB/s)  used_heuristic  speedup (heuristic/aten)
  1  4096  4096                    456.95       134.59            459.37            True                      3.41
  1  4096  8192                    523.93       138.29            553.50            True                      4.00
  1  4096 16394                    233.70       161.62            234.14            True                      1.45
  1  8192  4096                    633.25       140.64            574.86            True                      4.09
  1  8192  8192                    737.54       147.41            690.26            True                      4.68
  1  8192 16394                    413.67       175.88            408.68            True                      2.32
  1 16394  4096                    717.22       167.22            665.36            True                      3.98
  1 16394  8192                    812.69       177.17            815.90            True                      4.61
  1 16394 16394                    473.17       178.58            435.11            True                      2.44
  4  4096  4096                    479.46       134.80            486.74            True                      3.61
  4  4096  6333                    174.27       106.74            171.64            True                      1.61
  4  4096  8192                    567.14       138.32            571.09            True                      4.13
  4  4096 12313                    179.65       105.91            180.03            True                      1.70
  4  4096 16394                    222.96       145.54            222.81            True                      1.53
  4  6333  4096                    491.78       126.37            473.20            True                      3.74
  4  6333  6333                    268.79       143.40            269.75            True                      1.88
  4  6333  8192                    783.80       135.12            796.23            True                      5.89
  4  6333 12313                    286.35       142.37            287.30            True                      2.02
  4  6333 16394                    362.47       139.66            361.47            True                      2.59
  4  8192  4096                    642.73       140.53            641.88            True                      4.57
  4  8192  6333                    287.65       137.63            287.38            True                      2.09
  4  8192  8192                    738.42       150.16            721.59            True                      4.81
  4  8192 12313                    301.27       146.18            302.31            True                      2.07
  4  8192 16394                    415.37       167.66            393.41            True                      2.35
  4 12313  4096                    823.66       141.81            745.40            True                      5.26
  4 12313  6333                    433.92       148.17            429.83            True                      2.90
  4 12313  8192                    984.60       149.30            988.95            True                      6.62
  4 12313 12313                    452.00       150.87            452.50            True                      3.00
  4 12313 16394                    609.88       159.20            609.71            True                      3.83
  4 16394  4096                    779.44       157.46            777.10            True                      4.94
  4 16394  6333                    402.93       139.50            309.47            True                      2.22
  4 16394  8192                    950.38       175.49            949.67            True                      5.41
  4 16394 12313                    414.62       153.99            315.95            True                      2.05
  4 16394 16394                    497.56       174.97            461.77            True                      2.64
16  4096  4096                    475.92       134.45            478.57            True                      3.56
16  4096  6333                    146.36       112.50            145.35            True                      1.29
16  4096  8192                    560.00       138.22            557.19            True                      4.03
16  4096 12313                    152.02       105.06            151.27            True                      1.44
16  4096 16394                    222.48       156.72            222.88            True                      1.42
16  6333  4096                    692.41       122.14            696.88            True                      5.71
16  6333  6333                    220.74       140.90            225.41            True                      1.60
16  6333  8192                    813.56       140.21            820.28            True                      5.85
16  6333 12313                    232.48       131.19            232.55            True                      1.77
16  6333 16394                    367.39       134.93            361.87            True                      2.68
16  8192  4096                    665.54       140.29            266.24            True                      1.90
16  8192  6333                    254.77       136.65            240.12            True                      1.76
16  8192  8192                    750.63       146.26            736.93            True                      5.04
16  8192 12313                    266.61       127.13            251.81            True                      1.98
16  8192 16394                    397.25       160.42            390.76            True                      2.44
16 12313  4096                    857.48       141.36            851.36            True                      6.02
16 12313  6333                    423.21       132.40            357.55            True                      2.70
16 12313  8192                   1021.24       145.68           1024.60            True                      7.03
16 12313 12313                    370.12       143.94            383.52            True                      2.66
16 12313 16394                    608.52       141.03            608.48            True                      4.31
16 16394  4096                    826.48       155.94            826.74            True                      5.30
16 16394  6333                    420.38       144.09            265.23            True                      1.84
16 16394  8192                    988.07       156.21            984.63            True                      6.30
16 16394 12313                    431.40       146.92            265.49            True                      1.81
16 16394 16394                    497.39       167.86            461.79            True                      2.75
23  4096  4096                    344.43       132.84            338.64            True                      2.55
23  4096  6333                    195.34       118.48            195.31            True                      1.65
23  4096  8192                    389.83       140.02            376.62            True                      2.69
23  4096 12313                    204.49       137.96            204.80            True                      1.48
23  4096 16394                    242.48       148.99            242.74            True                      1.63
23  6333  4096                    429.25       126.52            517.75            True                      4.09
23  6333  6333                    295.56       133.51            296.14            True                      2.22
23  6333  8192                    594.88       137.05            581.78            True                      4.25
23  6333 12313                    315.18       131.67            314.64            True                      2.39
23  6333 16394                    386.46       141.45            386.54            True                      2.73
23  8192  4096                    553.52       142.05            568.35            True                      4.00
23  8192  6333                    215.58       139.01            210.86            True                      1.52
23  8192  8192                    609.21       154.85            528.76            True                      3.41
23  8192 12313                    220.38       142.93            233.54            True                      1.63
23  8192 16394                    402.63       158.39            403.21            True                      2.55
23 12313  4096                    723.54       131.58            581.94            True                      4.42
23 12313  6333                    307.90       131.58            307.90            True                      2.34
23 12313  8192                    893.36       129.97            623.72            True                      4.80
23 12313 12313                    322.40       134.84            317.80            True                      2.36
23 12313 16394                    512.97       142.31            409.45            True                      2.88
23 16394  4096                    703.66       154.54            643.53            True                      4.16
23 16394  6333                    305.55       127.55            293.17            True                      2.30
23 16394  8192                    768.12       154.60            681.53            True                      4.41
23 16394 12313                    311.61       140.92            307.01            True                      2.18
23 16394 16394                    467.24       171.07            467.29            True                      2.73
32  4096  4096                    344.71       132.30            338.62            True                      2.56
32  4096  6333                    206.48       107.59            205.55            True                      1.91
32  4096  8192                    387.24       137.82            353.12            True                      2.56
32  4096 12313                    216.35       120.61            214.50            True                      1.78
32  4096 16394                    242.05       149.92            241.94            True                      1.61
32  6333  4096                    525.50       127.12            518.02            True                      4.08
32  6333  6333                    300.50       118.41            296.55            True                      2.50
32  6333  8192                    600.92       136.99            601.94            True                      4.39
32  6333 12313                    316.13       136.45            316.03            True                      2.32
32  6333 16394                    386.11       141.34            386.10            True                      2.73
32  8192  4096                    546.18       140.18            341.14            True                      2.43
32  8192  6333                    218.40       130.65            263.42            True                      2.02
32  8192  8192                    608.29       147.16            542.12            True                      3.68
32  8192 12313                    225.60       135.04            225.23            True                      1.67
32  8192 16394                    434.75       160.42            401.28            True                      2.50
32 12313  4096                    787.80       136.28            583.60            True                      4.28
32 12313  6333                    316.66       125.76            323.35            True                      2.57
32 12313  8192                    891.38       128.88            639.50            True                      4.96
32 12313 12313                    326.11       132.37            325.88            True                      2.46
32 12313 16394                    521.64       139.47            395.69            True                      2.84
32 16394  4096                    625.55       158.46            651.16            True                      4.11
32 16394  6333                    304.14       131.13            284.55            True                      2.17
32 16394  8192                    767.79       162.95            704.34            True                      4.32
32 16394 12313                    310.74       137.68            303.39            True                      2.20
32 16394 16394                    465.92       171.43            465.37            True                      2.71
43  4096  4096                    345.05       133.87            196.47            True                      1.47
43  4096  6333                    148.64        99.92            148.97            True                      1.49
43  4096  8192                    386.50       135.39            214.00            True                      1.58
43  4096 12313                    190.39       109.36            156.27            True                      1.43
43  4096 16394                    203.63       150.24            204.05            True                      1.36
43  6333  4096                    421.35       106.04            132.25            True                      1.25
43  6333  6333                    224.75       113.01            224.97            True                      1.99
43  6333  8192                    471.11       117.61            327.39            True                      2.78
43  6333 12313                    234.55       115.61            234.74            True                      2.03
43  6333 16394                    311.56       132.24            312.01            True                      2.36
43  8192  4096                    400.73       140.12            269.11            True                      1.92
43  8192  6333                    167.32       119.13            168.84            True                      1.42
43  8192  8192                    435.45       146.98            286.21            True                      1.95
43  8192 12313                    161.05       127.82            162.78            True                      1.27
43  8192 16394                    207.16       156.40            208.90            True                      1.34
43 12313  4096                    484.01       120.10            313.35            True                      2.61
43 12313  6333                    234.54       106.63            232.85            True                      2.18
43 12313  8192                    515.34       130.23            411.70            True                      3.16
43 12313 12313                    239.39       130.04            239.03            True                      1.84
43 12313 16394                    316.02       137.39            316.29            True                      2.30
43 16394  4096                    475.60       152.57            340.97            True                      2.23
43 16394  6333                    241.21       132.49            208.59            True                      1.57
43 16394  8192                    499.34       157.43            361.61            True                      2.30
43 16394 12313                    246.25       132.31            211.68            True                      1.60
43 16394 16394                    302.90       158.56            277.05            True                      1.75
64  4096  4096                    280.48       126.82            195.97            True                      1.55
64  4096  6333                    150.94       101.63            150.48            True                      1.48
64  4096  8192                    305.47       135.06            211.03            True                      1.56
64  4096 12313                    158.12       110.06            158.15            True                      1.44
64  4096 16394                    206.68       136.21            201.28            True                      1.48
64  6333  4096                    409.11       105.10            296.07            True                      2.82
64  6333  6333                    229.98       108.46            230.59            True                      2.13
64  6333  8192                    469.32       112.24            330.58            True                      2.95
64  6333 12313                    245.02       117.16            244.84            True                      2.09
64  6333 16394                    317.78       125.80            318.37            True                      2.53
64  8192  4096                    323.42       139.92            267.31            True                      1.91
64  8192  6333                    167.51       118.45            167.56            True                      1.41
64  8192  8192                    341.13       146.71            284.88            True                      1.94
64  8192 12313                    172.21       123.42            171.97            True                      1.39
64  8192 16394                    217.22       153.18            216.99            True                      1.42
64 12313  4096                    482.19       123.32            311.82            True                      2.53
64 12313  6333                    238.73       123.88            238.66            True                      1.93
64 12313  8192                    516.32       122.11            330.50            True                      2.71
64 12313 12313                    248.73       125.32            296.82            True                      2.37
64 12313 16394                    314.98       134.06            320.31            True                      2.39
64 16394  4096                    476.59       154.58            340.84            True                      2.20
64 16394  6333                    240.54       119.60            214.82            True                      1.80
64 16394  8192                    501.36       149.02            359.45            True                      2.41
64 16394 12313                    244.65       126.01            222.47            True                      1.77
64 16394 16394                    302.48       160.36            283.66            True                      1.77
```

Pull Request resolved: pytorch#128232
Approved by: https://github.com/Chillee
ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this pull request Jun 14, 2024
This PR introduces a heuristic for tuned_mixed_mm. The heuristic is only enabled on an A100, because it has only been tested on an A100, and it is only enabled if force_mixed_mm="heuristic".

I compared the heuristic to the aten fallback implementation and triton+autotune:
 Geometric mean speedup: 2.51
 ```
 m     n     k  triton + autotune (GB/s)  aten (GB/s)  heuristic (GB/s)  used_heuristic  speedup (heuristic/aten)
  1  4096  4096                    456.95       134.59            459.37            True                      3.41
  1  4096  8192                    523.93       138.29            553.50            True                      4.00
  1  4096 16394                    233.70       161.62            234.14            True                      1.45
  1  8192  4096                    633.25       140.64            574.86            True                      4.09
  1  8192  8192                    737.54       147.41            690.26            True                      4.68
  1  8192 16394                    413.67       175.88            408.68            True                      2.32
  1 16394  4096                    717.22       167.22            665.36            True                      3.98
  1 16394  8192                    812.69       177.17            815.90            True                      4.61
  1 16394 16394                    473.17       178.58            435.11            True                      2.44
  4  4096  4096                    479.46       134.80            486.74            True                      3.61
  4  4096  6333                    174.27       106.74            171.64            True                      1.61
  4  4096  8192                    567.14       138.32            571.09            True                      4.13
  4  4096 12313                    179.65       105.91            180.03            True                      1.70
  4  4096 16394                    222.96       145.54            222.81            True                      1.53
  4  6333  4096                    491.78       126.37            473.20            True                      3.74
  4  6333  6333                    268.79       143.40            269.75            True                      1.88
  4  6333  8192                    783.80       135.12            796.23            True                      5.89
  4  6333 12313                    286.35       142.37            287.30            True                      2.02
  4  6333 16394                    362.47       139.66            361.47            True                      2.59
  4  8192  4096                    642.73       140.53            641.88            True                      4.57
  4  8192  6333                    287.65       137.63            287.38            True                      2.09
  4  8192  8192                    738.42       150.16            721.59            True                      4.81
  4  8192 12313                    301.27       146.18            302.31            True                      2.07
  4  8192 16394                    415.37       167.66            393.41            True                      2.35
  4 12313  4096                    823.66       141.81            745.40            True                      5.26
  4 12313  6333                    433.92       148.17            429.83            True                      2.90
  4 12313  8192                    984.60       149.30            988.95            True                      6.62
  4 12313 12313                    452.00       150.87            452.50            True                      3.00
  4 12313 16394                    609.88       159.20            609.71            True                      3.83
  4 16394  4096                    779.44       157.46            777.10            True                      4.94
  4 16394  6333                    402.93       139.50            309.47            True                      2.22
  4 16394  8192                    950.38       175.49            949.67            True                      5.41
  4 16394 12313                    414.62       153.99            315.95            True                      2.05
  4 16394 16394                    497.56       174.97            461.77            True                      2.64
16  4096  4096                    475.92       134.45            478.57            True                      3.56
16  4096  6333                    146.36       112.50            145.35            True                      1.29
16  4096  8192                    560.00       138.22            557.19            True                      4.03
16  4096 12313                    152.02       105.06            151.27            True                      1.44
16  4096 16394                    222.48       156.72            222.88            True                      1.42
16  6333  4096                    692.41       122.14            696.88            True                      5.71
16  6333  6333                    220.74       140.90            225.41            True                      1.60
16  6333  8192                    813.56       140.21            820.28            True                      5.85
16  6333 12313                    232.48       131.19            232.55            True                      1.77
16  6333 16394                    367.39       134.93            361.87            True                      2.68
16  8192  4096                    665.54       140.29            266.24            True                      1.90
16  8192  6333                    254.77       136.65            240.12            True                      1.76
16  8192  8192                    750.63       146.26            736.93            True                      5.04
16  8192 12313                    266.61       127.13            251.81            True                      1.98
16  8192 16394                    397.25       160.42            390.76            True                      2.44
16 12313  4096                    857.48       141.36            851.36            True                      6.02
16 12313  6333                    423.21       132.40            357.55            True                      2.70
16 12313  8192                   1021.24       145.68           1024.60            True                      7.03
16 12313 12313                    370.12       143.94            383.52            True                      2.66
16 12313 16394                    608.52       141.03            608.48            True                      4.31
16 16394  4096                    826.48       155.94            826.74            True                      5.30
16 16394  6333                    420.38       144.09            265.23            True                      1.84
16 16394  8192                    988.07       156.21            984.63            True                      6.30
16 16394 12313                    431.40       146.92            265.49            True                      1.81
16 16394 16394                    497.39       167.86            461.79            True                      2.75
23  4096  4096                    344.43       132.84            338.64            True                      2.55
23  4096  6333                    195.34       118.48            195.31            True                      1.65
23  4096  8192                    389.83       140.02            376.62            True                      2.69
23  4096 12313                    204.49       137.96            204.80            True                      1.48
23  4096 16394                    242.48       148.99            242.74            True                      1.63
23  6333  4096                    429.25       126.52            517.75            True                      4.09
23  6333  6333                    295.56       133.51            296.14            True                      2.22
23  6333  8192                    594.88       137.05            581.78            True                      4.25
23  6333 12313                    315.18       131.67            314.64            True                      2.39
23  6333 16394                    386.46       141.45            386.54            True                      2.73
23  8192  4096                    553.52       142.05            568.35            True                      4.00
23  8192  6333                    215.58       139.01            210.86            True                      1.52
23  8192  8192                    609.21       154.85            528.76            True                      3.41
23  8192 12313                    220.38       142.93            233.54            True                      1.63
23  8192 16394                    402.63       158.39            403.21            True                      2.55
23 12313  4096                    723.54       131.58            581.94            True                      4.42
23 12313  6333                    307.90       131.58            307.90            True                      2.34
23 12313  8192                    893.36       129.97            623.72            True                      4.80
23 12313 12313                    322.40       134.84            317.80            True                      2.36
23 12313 16394                    512.97       142.31            409.45            True                      2.88
23 16394  4096                    703.66       154.54            643.53            True                      4.16
23 16394  6333                    305.55       127.55            293.17            True                      2.30
23 16394  8192                    768.12       154.60            681.53            True                      4.41
23 16394 12313                    311.61       140.92            307.01            True                      2.18
23 16394 16394                    467.24       171.07            467.29            True                      2.73
32  4096  4096                    344.71       132.30            338.62            True                      2.56
32  4096  6333                    206.48       107.59            205.55            True                      1.91
32  4096  8192                    387.24       137.82            353.12            True                      2.56
32  4096 12313                    216.35       120.61            214.50            True                      1.78
32  4096 16394                    242.05       149.92            241.94            True                      1.61
32  6333  4096                    525.50       127.12            518.02            True                      4.08
32  6333  6333                    300.50       118.41            296.55            True                      2.50
32  6333  8192                    600.92       136.99            601.94            True                      4.39
32  6333 12313                    316.13       136.45            316.03            True                      2.32
32  6333 16394                    386.11       141.34            386.10            True                      2.73
32  8192  4096                    546.18       140.18            341.14            True                      2.43
32  8192  6333                    218.40       130.65            263.42            True                      2.02
32  8192  8192                    608.29       147.16            542.12            True                      3.68
32  8192 12313                    225.60       135.04            225.23            True                      1.67
32  8192 16394                    434.75       160.42            401.28            True                      2.50
32 12313  4096                    787.80       136.28            583.60            True                      4.28
32 12313  6333                    316.66       125.76            323.35            True                      2.57
32 12313  8192                    891.38       128.88            639.50            True                      4.96
32 12313 12313                    326.11       132.37            325.88            True                      2.46
32 12313 16394                    521.64       139.47            395.69            True                      2.84
32 16394  4096                    625.55       158.46            651.16            True                      4.11
32 16394  6333                    304.14       131.13            284.55            True                      2.17
32 16394  8192                    767.79       162.95            704.34            True                      4.32
32 16394 12313                    310.74       137.68            303.39            True                      2.20
32 16394 16394                    465.92       171.43            465.37            True                      2.71
43  4096  4096                    345.05       133.87            196.47            True                      1.47
43  4096  6333                    148.64        99.92            148.97            True                      1.49
43  4096  8192                    386.50       135.39            214.00            True                      1.58
43  4096 12313                    190.39       109.36            156.27            True                      1.43
43  4096 16394                    203.63       150.24            204.05            True                      1.36
43  6333  4096                    421.35       106.04            132.25            True                      1.25
43  6333  6333                    224.75       113.01            224.97            True                      1.99
43  6333  8192                    471.11       117.61            327.39            True                      2.78
43  6333 12313                    234.55       115.61            234.74            True                      2.03
43  6333 16394                    311.56       132.24            312.01            True                      2.36
43  8192  4096                    400.73       140.12            269.11            True                      1.92
43  8192  6333                    167.32       119.13            168.84            True                      1.42
43  8192  8192                    435.45       146.98            286.21            True                      1.95
43  8192 12313                    161.05       127.82            162.78            True                      1.27
43  8192 16394                    207.16       156.40            208.90            True                      1.34
43 12313  4096                    484.01       120.10            313.35            True                      2.61
43 12313  6333                    234.54       106.63            232.85            True                      2.18
43 12313  8192                    515.34       130.23            411.70            True                      3.16
43 12313 12313                    239.39       130.04            239.03            True                      1.84
43 12313 16394                    316.02       137.39            316.29            True                      2.30
43 16394  4096                    475.60       152.57            340.97            True                      2.23
43 16394  6333                    241.21       132.49            208.59            True                      1.57
43 16394  8192                    499.34       157.43            361.61            True                      2.30
43 16394 12313                    246.25       132.31            211.68            True                      1.60
43 16394 16394                    302.90       158.56            277.05            True                      1.75
64  4096  4096                    280.48       126.82            195.97            True                      1.55
64  4096  6333                    150.94       101.63            150.48            True                      1.48
64  4096  8192                    305.47       135.06            211.03            True                      1.56
64  4096 12313                    158.12       110.06            158.15            True                      1.44
64  4096 16394                    206.68       136.21            201.28            True                      1.48
64  6333  4096                    409.11       105.10            296.07            True                      2.82
64  6333  6333                    229.98       108.46            230.59            True                      2.13
64  6333  8192                    469.32       112.24            330.58            True                      2.95
64  6333 12313                    245.02       117.16            244.84            True                      2.09
64  6333 16394                    317.78       125.80            318.37            True                      2.53
64  8192  4096                    323.42       139.92            267.31            True                      1.91
64  8192  6333                    167.51       118.45            167.56            True                      1.41
64  8192  8192                    341.13       146.71            284.88            True                      1.94
64  8192 12313                    172.21       123.42            171.97            True                      1.39
64  8192 16394                    217.22       153.18            216.99            True                      1.42
64 12313  4096                    482.19       123.32            311.82            True                      2.53
64 12313  6333                    238.73       123.88            238.66            True                      1.93
64 12313  8192                    516.32       122.11            330.50            True                      2.71
64 12313 12313                    248.73       125.32            296.82            True                      2.37
64 12313 16394                    314.98       134.06            320.31            True                      2.39
64 16394  4096                    476.59       154.58            340.84            True                      2.20
64 16394  6333                    240.54       119.60            214.82            True                      1.80
64 16394  8192                    501.36       149.02            359.45            True                      2.41
64 16394 12313                    244.65       126.01            222.47            True                      1.77
64 16394 16394                    302.48       160.36            283.66            True                      1.77
```

Pull Request resolved: pytorch#128232
Approved by: https://github.com/Chillee
@msaroufim
Copy link
Member

msaroufim commented Jun 15, 2024

This PR caused a breaking change on the nightlies for ao since it specifically deprecated force_mixed_mm = False - We'll forward fix this but we should really start thinking about promoting some inductor configs to a public namespace

@Chillee
Copy link
Collaborator

Chillee commented Jun 17, 2024

Oh hmm... I actually did a search for any uses of force_mixed_mm, but I missed it because the preview for torchao doesn't show it (and it's only in the tests and not any other file).

image

@github-actions github-actions bot deleted the mixed-mm-heuristic branch July 18, 2024 01:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants