Introduce heuristic for mixed_mm on A100 #128232
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128232
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 9db041a with merge base 7b9c5e0:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM!
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
This PR introduces a heuristic for tuned_mixed_mm. The heuristic is only enabled on an A100, because it has only been tested on an A100, and it is only enabled if force_mixed_mm="heuristic".

I compared the heuristic to the aten fallback implementation and triton+autotune:

Geometric mean speedup: 2.51

```
 m      n      k  triton + autotune (GB/s)  aten (GB/s)  heuristic (GB/s)  used_heuristic  speedup (heuristic/aten)
 1   4096   4096   456.95  134.59   459.37  True  3.41
 1   4096   8192   523.93  138.29   553.50  True  4.00
 1   4096  16394   233.70  161.62   234.14  True  1.45
 1   8192   4096   633.25  140.64   574.86  True  4.09
 1   8192   8192   737.54  147.41   690.26  True  4.68
 1   8192  16394   413.67  175.88   408.68  True  2.32
 1  16394   4096   717.22  167.22   665.36  True  3.98
 1  16394   8192   812.69  177.17   815.90  True  4.61
 1  16394  16394   473.17  178.58   435.11  True  2.44
 4   4096   4096   479.46  134.80   486.74  True  3.61
 4   4096   6333   174.27  106.74   171.64  True  1.61
 4   4096   8192   567.14  138.32   571.09  True  4.13
 4   4096  12313   179.65  105.91   180.03  True  1.70
 4   4096  16394   222.96  145.54   222.81  True  1.53
 4   6333   4096   491.78  126.37   473.20  True  3.74
 4   6333   6333   268.79  143.40   269.75  True  1.88
 4   6333   8192   783.80  135.12   796.23  True  5.89
 4   6333  12313   286.35  142.37   287.30  True  2.02
 4   6333  16394   362.47  139.66   361.47  True  2.59
 4   8192   4096   642.73  140.53   641.88  True  4.57
 4   8192   6333   287.65  137.63   287.38  True  2.09
 4   8192   8192   738.42  150.16   721.59  True  4.81
 4   8192  12313   301.27  146.18   302.31  True  2.07
 4   8192  16394   415.37  167.66   393.41  True  2.35
 4  12313   4096   823.66  141.81   745.40  True  5.26
 4  12313   6333   433.92  148.17   429.83  True  2.90
 4  12313   8192   984.60  149.30   988.95  True  6.62
 4  12313  12313   452.00  150.87   452.50  True  3.00
 4  12313  16394   609.88  159.20   609.71  True  3.83
 4  16394   4096   779.44  157.46   777.10  True  4.94
 4  16394   6333   402.93  139.50   309.47  True  2.22
 4  16394   8192   950.38  175.49   949.67  True  5.41
 4  16394  12313   414.62  153.99   315.95  True  2.05
 4  16394  16394   497.56  174.97   461.77  True  2.64
16   4096   4096   475.92  134.45   478.57  True  3.56
16   4096   6333   146.36  112.50   145.35  True  1.29
16   4096   8192   560.00  138.22   557.19  True  4.03
16   4096  12313   152.02  105.06   151.27  True  1.44
16   4096  16394   222.48  156.72   222.88  True  1.42
16   6333   4096   692.41  122.14   696.88  True  5.71
16   6333   6333   220.74  140.90   225.41  True  1.60
16   6333   8192   813.56  140.21   820.28  True  5.85
16   6333  12313   232.48  131.19   232.55  True  1.77
16   6333  16394   367.39  134.93   361.87  True  2.68
16   8192   4096   665.54  140.29   266.24  True  1.90
16   8192   6333   254.77  136.65   240.12  True  1.76
16   8192   8192   750.63  146.26   736.93  True  5.04
16   8192  12313   266.61  127.13   251.81  True  1.98
16   8192  16394   397.25  160.42   390.76  True  2.44
16  12313   4096   857.48  141.36   851.36  True  6.02
16  12313   6333   423.21  132.40   357.55  True  2.70
16  12313   8192  1021.24  145.68  1024.60  True  7.03
16  12313  12313   370.12  143.94   383.52  True  2.66
16  12313  16394   608.52  141.03   608.48  True  4.31
16  16394   4096   826.48  155.94   826.74  True  5.30
16  16394   6333   420.38  144.09   265.23  True  1.84
16  16394   8192   988.07  156.21   984.63  True  6.30
16  16394  12313   431.40  146.92   265.49  True  1.81
16  16394  16394   497.39  167.86   461.79  True  2.75
23   4096   4096   344.43  132.84   338.64  True  2.55
23   4096   6333   195.34  118.48   195.31  True  1.65
23   4096   8192   389.83  140.02   376.62  True  2.69
23   4096  12313   204.49  137.96   204.80  True  1.48
23   4096  16394   242.48  148.99   242.74  True  1.63
23   6333   4096   429.25  126.52   517.75  True  4.09
23   6333   6333   295.56  133.51   296.14  True  2.22
23   6333   8192   594.88  137.05   581.78  True  4.25
23   6333  12313   315.18  131.67   314.64  True  2.39
23   6333  16394   386.46  141.45   386.54  True  2.73
23   8192   4096   553.52  142.05   568.35  True  4.00
23   8192   6333   215.58  139.01   210.86  True  1.52
23   8192   8192   609.21  154.85   528.76  True  3.41
23   8192  12313   220.38  142.93   233.54  True  1.63
23   8192  16394   402.63  158.39   403.21  True  2.55
23  12313   4096   723.54  131.58   581.94  True  4.42
23  12313   6333   307.90  131.58   307.90  True  2.34
23  12313   8192   893.36  129.97   623.72  True  4.80
23  12313  12313   322.40  134.84   317.80  True  2.36
23  12313  16394   512.97  142.31   409.45  True  2.88
23  16394   4096   703.66  154.54   643.53  True  4.16
23  16394   6333   305.55  127.55   293.17  True  2.30
23  16394   8192   768.12  154.60   681.53  True  4.41
23  16394  12313   311.61  140.92   307.01  True  2.18
23  16394  16394   467.24  171.07   467.29  True  2.73
32   4096   4096   344.71  132.30   338.62  True  2.56
32   4096   6333   206.48  107.59   205.55  True  1.91
32   4096   8192   387.24  137.82   353.12  True  2.56
32   4096  12313   216.35  120.61   214.50  True  1.78
32   4096  16394   242.05  149.92   241.94  True  1.61
32   6333   4096   525.50  127.12   518.02  True  4.08
32   6333   6333   300.50  118.41   296.55  True  2.50
32   6333   8192   600.92  136.99   601.94  True  4.39
32   6333  12313   316.13  136.45   316.03  True  2.32
32   6333  16394   386.11  141.34   386.10  True  2.73
32   8192   4096   546.18  140.18   341.14  True  2.43
32   8192   6333   218.40  130.65   263.42  True  2.02
32   8192   8192   608.29  147.16   542.12  True  3.68
32   8192  12313   225.60  135.04   225.23  True  1.67
32   8192  16394   434.75  160.42   401.28  True  2.50
32  12313   4096   787.80  136.28   583.60  True  4.28
32  12313   6333   316.66  125.76   323.35  True  2.57
32  12313   8192   891.38  128.88   639.50  True  4.96
32  12313  12313   326.11  132.37   325.88  True  2.46
32  12313  16394   521.64  139.47   395.69  True  2.84
32  16394   4096   625.55  158.46   651.16  True  4.11
32  16394   6333   304.14  131.13   284.55  True  2.17
32  16394   8192   767.79  162.95   704.34  True  4.32
32  16394  12313   310.74  137.68   303.39  True  2.20
32  16394  16394   465.92  171.43   465.37  True  2.71
43   4096   4096   345.05  133.87   196.47  True  1.47
43   4096   6333   148.64   99.92   148.97  True  1.49
43   4096   8192   386.50  135.39   214.00  True  1.58
43   4096  12313   190.39  109.36   156.27  True  1.43
43   4096  16394   203.63  150.24   204.05  True  1.36
43   6333   4096   421.35  106.04   132.25  True  1.25
43   6333   6333   224.75  113.01   224.97  True  1.99
43   6333   8192   471.11  117.61   327.39  True  2.78
43   6333  12313   234.55  115.61   234.74  True  2.03
43   6333  16394   311.56  132.24   312.01  True  2.36
43   8192   4096   400.73  140.12   269.11  True  1.92
43   8192   6333   167.32  119.13   168.84  True  1.42
43   8192   8192   435.45  146.98   286.21  True  1.95
43   8192  12313   161.05  127.82   162.78  True  1.27
43   8192  16394   207.16  156.40   208.90  True  1.34
43  12313   4096   484.01  120.10   313.35  True  2.61
43  12313   6333   234.54  106.63   232.85  True  2.18
43  12313   8192   515.34  130.23   411.70  True  3.16
43  12313  12313   239.39  130.04   239.03  True  1.84
43  12313  16394   316.02  137.39   316.29  True  2.30
43  16394   4096   475.60  152.57   340.97  True  2.23
43  16394   6333   241.21  132.49   208.59  True  1.57
43  16394   8192   499.34  157.43   361.61  True  2.30
43  16394  12313   246.25  132.31   211.68  True  1.60
43  16394  16394   302.90  158.56   277.05  True  1.75
64   4096   4096   280.48  126.82   195.97  True  1.55
64   4096   6333   150.94  101.63   150.48  True  1.48
64   4096   8192   305.47  135.06   211.03  True  1.56
64   4096  12313   158.12  110.06   158.15  True  1.44
64   4096  16394   206.68  136.21   201.28  True  1.48
64   6333   4096   409.11  105.10   296.07  True  2.82
64   6333   6333   229.98  108.46   230.59  True  2.13
64   6333   8192   469.32  112.24   330.58  True  2.95
64   6333  12313   245.02  117.16   244.84  True  2.09
64   6333  16394   317.78  125.80   318.37  True  2.53
64   8192   4096   323.42  139.92   267.31  True  1.91
64   8192   6333   167.51  118.45   167.56  True  1.41
64   8192   8192   341.13  146.71   284.88  True  1.94
64   8192  12313   172.21  123.42   171.97  True  1.39
64   8192  16394   217.22  153.18   216.99  True  1.42
64  12313   4096   482.19  123.32   311.82  True  2.53
64  12313   6333   238.73  123.88   238.66  True  1.93
64  12313   8192   516.32  122.11   330.50  True  2.71
64  12313  12313   248.73  125.32   296.82  True  2.37
64  12313  16394   314.98  134.06   320.31  True  2.39
64  16394   4096   476.59  154.58   340.84  True  2.20
64  16394   6333   240.54  119.60   214.82  True  1.80
64  16394   8192   501.36  149.02   359.45  True  2.41
64  16394  12313   244.65  126.01   222.47  True  1.77
64  16394  16394   302.48  160.36   283.66  True  1.77
```

Pull Request resolved: pytorch#128232
Approved by: https://github.com/Chillee
This PR caused a breaking change on the nightlies for ao since it specifically deprecated |
This PR introduces a heuristic for tuned_mixed_mm. The heuristic is only enabled on an A100, because it has only been tested on an A100, and it is only enabled if force_mixed_mm="heuristic".
I compared the heuristic to the aten fallback implementation and triton+autotune:
Geometric mean speedup: 2.51
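For readers who want to check how the per-row speedup and the geometric mean relate, here is a minimal sketch. It assumes the speedup column is simply heuristic GB/s divided by aten GB/s (as the column header suggests) and that the reported 2.51 is the geometric mean of that column over every benchmarked shape. The helper names `speedup` and `geometric_mean` are illustrative and not part of the PR; only the first three rows of the benchmark table are copied in, so the printed mean covers that sample, not the whole table.

```python
import math

# (m, n, k, aten GB/s, heuristic GB/s), copied from the first three
# benchmark rows reported in this PR. This is a sample, not the full table.
SAMPLE_ROWS = [
    (1, 4096, 4096, 134.59, 459.37),
    (1, 4096, 8192, 138.29, 553.50),
    (1, 4096, 16394, 161.62, 234.14),
]


def speedup(aten_gbps, heuristic_gbps):
    """Speedup of the heuristic over aten (assumed: ratio of throughputs)."""
    return heuristic_gbps / aten_gbps


def geometric_mean(values):
    """Geometric mean of positive values, computed in log space."""
    values = list(values)
    if not values:
        raise ValueError("geometric_mean requires at least one value")
    return math.exp(sum(math.log(v) for v in values) / len(values))


sample_speedups = [speedup(aten, heur) for (_, _, _, aten, heur) in SAMPLE_ROWS]
sample_geomean = geometric_mean(sample_speedups)

for (m, n, k, _, _), s in zip(SAMPLE_ROWS, sample_speedups):
    print(f"m={m} n={n} k={k}: speedup {s:.2f}")
print(f"geometric mean over this sample: {sample_geomean:.2f}")
```

A geometric mean is the usual choice for averaging ratios such as speedups, because a 2x gain and a 2x loss cancel out to 1.0 rather than averaging to 1.25.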
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang