
Conversation

@eellison (Contributor) commented Mar 15, 2024

Stack from ghstack (oldest at bottom):

Before this PR we were not precompiling triton templates in parallel. Compilation would occur during benchmarking.

Triton benchmarking templates were emitted as:

```
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```

In order to precompile, we need to provide the full kernel specification, just as we do when emitting the template in the final output code generation.

```
@triton_heuristics.template(
    num_stages=3,
    num_warps=8,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'Placeholder.DESCRIPTIVE_NAME', 'backend_hash': 'cdeecfeccd31ad7810f96b5752194b1c2406d0a81e39a6ca09c8ee150baae183'},
)
@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
```
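
As a rough illustration of why this matters (a minimal sketch with hypothetical helper names, not the actual inductor API): once each candidate template carries its full specification, compilation can be farmed out to a worker pool without materializing any input tensors.

```
# Minimal sketch of parallel precompilation. Assumes each candidate template
# exposes a precompile() hook that compiles from its recorded signature/meta
# (hypothetical name, not necessarily this PR's actual API).
from concurrent.futures import ThreadPoolExecutor

def precompile_choices(choices, max_workers=8):
    # No example tensors are allocated here: the emitted kernel specification
    # (signature, constants, configs) is enough to drive compilation.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(choice.precompile) for choice in choices]
        for future in futures:
            future.result()  # surface any compilation error eagerly
```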

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang

pytorch-bot bot commented Mar 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/121998

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit f4cbd2d with merge base 5891c5b:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eellison changed the title from "WIP precompile triton templates" to "Precompile triton templates" Mar 18, 2024
@eellison mentioned this pull request Mar 19, 2024
@shunting314 (Contributor)

Trying to understand more: why is the full kernel specification required if we want to precompile in parallel?

```
     kernel_name=kernel_name,
     output_node=fake_out,
-    use_jit=True,
+    use_jit=False,
```
Contributor

After this, are there any remaining cases that set use_jit to True? If not, we can probably remove this argument.

@eellison (Contributor, Author)

@shunting314 In order to compile with JIT you need to actually invoke the kernel. There is a "warmup_only" arg in triton we could use; however, we'd still need to provide real tensor inputs.

If you have arbitrarily many mms that you want to compile in parallel, you would potentially be allocating the arguments for all of them at once, which is not feasible memory-wise.

We already generate the full argument specification when we codegen anyway, so it's better to reuse it than the alternative.
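
To make the memory concern concrete, a back-of-the-envelope calculation with illustrative numbers (not taken from the PR):

```
# Input memory needed to compile-by-invoking 100 candidate mm kernels at once,
# each taking two fp32 4096x4096 input matrices. Illustrative numbers only.
n_candidates = 100
bytes_per_matrix = 4096 * 4096 * 4           # one fp32 matrix = 64 MiB
total = n_candidates * 2 * bytes_per_matrix  # inputs only, outputs excluded
print(f"{total / 2**30:.1f} GiB")            # -> 12.5 GiB just for arguments
```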

@shunting314 (Contributor)

I see. You actually want the precompile ability of triton_heuristics.template.

@eellison added the ciflow/trunk label Mar 19, 2024
@huydhn (Contributor) commented Mar 21, 2024

@pytorchbot revert -m 'Sorry for reverting your change but it is causing all ROCm trunk job to fail https://hud.pytorch.org/pytorch/pytorch/commit/b8df2f0ca530ebe01fa079c891c170a1f4b22823' -c nosignal

@huydhn added the ciflow/rocm label Mar 21, 2024
@pytorchmergebot (Collaborator)
@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Mar 21, 2024
This reverts commit b8df2f0.

Reverted #121998 on behalf of https://github.com/huydhn due to: Sorry for reverting your change but it is causing all ROCm trunk job to fail https://hud.pytorch.org/pytorch/pytorch/commit/b8df2f0ca530ebe01fa079c891c170a1f4b22823
@pytorchmergebot (Collaborator)
@eellison your PR has been successfully reverted.

@eellison (Contributor, Author)
@pytorchbot merge

@pytorchmergebot (Collaborator)
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

yoyoyocmu added a commit to yoyoyocmu/pytorch that referenced this pull request Apr 3, 2024
pytorchmergebot pushed a commit that referenced this pull request Apr 4, 2024
Summary: We are reverting #121998 because the change plus search-autotune-cache led to a significant compilation-time increase, causing the stuck-job detector to trigger and then kill the training job.

Test Plan:
CI tests

Reviewed By: nmacchioni

Differential Revision: D55712203

Pull Request resolved: #123305
Approved by: https://github.com/eellison, https://github.com/nmacchioni, https://github.com/xw285cornell
pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024
Pull Request resolved: #121998
Approved by: https://github.com/jansel
ghstack dependencies: #121996, #120275, #121997
pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024
sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024
@github-actions bot deleted the gh/eellison/614/head branch April 25, 2024 01:57

Labels

ciflow/inductor, ciflow/rocm, ciflow/trunk, Merged, module: inductor, Reverted, topic: not user facing


5 participants