Precompile triton templates #121998
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/121998
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 Unrelated Failures) As of commit f4cbd2d with merge base 5891c5b. FLAKY: the following jobs failed but were likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Before this PR we were not precompiling triton templates in parallel. Compilation happened as part of benchmarking, so it was not parallelized. Triton benchmarking templates were emitted as:

```
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```

In order to precompile we need to give the full kernel specification, as we do when we emit the template in the final output code generation:

```
@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'cdeecfeccd31ad7810f96b5752194b1c2406d0a81e39a6ca09c8ee150baae183'},
)
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```
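For intuition, here is a minimal sketch of the parallel precompilation this enables; the `choices` list and its `precompile()` method are hypothetical stand-ins, not Inductor's actual API:

```
# Hypothetical sketch: compile every candidate template before benchmarking.
# Because a kernel emitted with the full @triton_heuristics.template(...)
# header already carries its signature and configs, compiling it needs no
# example inputs, so the work can be farmed out to worker threads.
from concurrent.futures import ThreadPoolExecutor

def parallel_precompile(choices, max_workers=8):
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(choice.precompile) for choice in choices]
        for future in futures:
            future.result()  # surface any compilation errors eagerly
```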
Trying to understand more: why is the full kernel specification required if we want to precompile in parallel?
  kernel_name=kernel_name,
  output_node=fake_out,
- use_jit=True,
+ use_jit=False,
After this, are there any remaining cases that set `use_jit` to True? If not, we can probably remove this argument.
@shunting314 in order to compile with JIT you need to actually invoke it. There is a "warmup_only" arg in triton we could use, however we'd still need to provide real tensor inputs. In the case that you have arbitrarily many mms you want to compile in parallel, you would potentially be allocating the arguments for all of them at once, which is not feasible for memory. We already generate the full argument specification when we codegen anyway, so it's better to reuse this than the alternative.
I see. You actually want the precompile ability of `triton_heuristics.template`.
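To make the memory point above concrete, here is a rough sketch contrasting the two approaches; the kernel objects, `problem_sizes`, and the `precompile()` method are illustrative assumptions, not the real Inductor/Triton API:

```
import torch

def warmup_via_jit(mm_kernels, problem_sizes, device="cuda"):
    # A bare @triton.jit kernel is only compiled when it is launched, so
    # warming it up requires real A/B/out buffers per candidate mm; with
    # arbitrarily many candidates this does not fit in GPU memory.
    for kernel, (m, k, n) in zip(mm_kernels, problem_sizes):
        a = torch.randn(m, k, device=device)
        b = torch.randn(k, n, device=device)
        out = torch.empty(m, n, device=device)
        kernel[(1,)](a, b, out)  # first launch triggers compilation

def precompile_from_spec(mm_kernels):
    # A kernel emitted with the full triton_heuristics.template(...) header
    # records its signature and configs, so it can be compiled without
    # allocating any tensors at all.
    for kernel in mm_kernels:
        kernel.precompile()  # hypothetical: compile from recorded metadata
```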
@pytorchbot revert -m 'Sorry for reverting your change but it is causing all ROCm trunk job to fail https://hud.pytorch.org/pytorch/pytorch/commit/b8df2f0ca530ebe01fa079c891c170a1f4b22823' -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
This reverts commit b8df2f0. Reverted #121998 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is causing all ROCm trunk job to fail https://hud.pytorch.org/pytorch/pytorch/commit/b8df2f0ca530ebe01fa079c891c170a1f4b22823
@eellison your PR has been successfully reverted.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Summary: We are reverting #121998 because the change plus `search-autotune-cache` led to a significant compilation time increase, causing the stuck job detector to trigger and then kill the training job.
Test Plan: CI tests
Reviewed By: nmacchioni
Differential Revision: D55712203
Pull Request resolved: #123305
Approved by: https://github.com/eellison, https://github.com/nmacchioni, https://github.com/xw285cornell
Before this PR we were not precompiling triton templates in parallel; compilation would occur during benchmarking. In order to precompile we need to give the full kernel specification (the `@triton_heuristics.template(...)` header shown in the description above), as we do when we emit the template in the final output code generation.
Pull Request resolved: #121998
Approved by: https://github.com/jansel
ghstack dependencies: #121996, #120275, #121997
Stack from ghstack (oldest at bottom):
Before this PR we were not precompiling triton templates in parallel; compilation would occur during benchmarking. Triton benchmarking templates were emitted with only a bare `@triton.jit` decorator. In order to precompile we need to give the full kernel specification (the `@triton_heuristics.template(...)` header shown above), as we do when we emit the template in the final output code generation.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang