Skip to content

Conversation

tenpercent
Copy link
Collaborator

@tenpercent tenpercent commented May 3, 2024

This PR adds an alternative backend for Inductor, adding Composable Kernel Universal GEMM instances to the autotune instance selection.

The implementation is heavily influenced by the series of PRs which adds CUTLASS backend (#106991). The main differences are
(1) customizing compiler for the ROCm platform
(2) customizing template code generation for Composable Kernel Universal GEMM instances.

We provide config tuning knobs for balancing between instance sources compilation time and finding the best instance.

Testing

Install the ck library

pip install git+https://github.com/rocm/composable_kernel@develop

Run the test

TORCH_LOGS=+torch._inductor \
pytest --capture=tee-sys test/inductor/test_ck_backend.py

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@tenpercent
Copy link
Collaborator Author

@pytorchbot drci

@tenpercent
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Aug 5, 2024
…#130576)

Add functional support for torch.addmm with CK backend. See also #125453

# Implementation details
1. It turns out we can use the same template between addmm and matmul; essentially, matmul is addmm with empty bias
2. The Python generator in CK was updated to generate the shared cpp template. The pip package can be installed from `pip install git+https://github.com/rocm/composable_kernel@add-addmm` and will be merged into `develop` branch after this PR lands to avoid breaking the current matmul

# Testing
`pytest test/inductor/test_ck_backend.py -k addmm`

Pull Request resolved: #130576
Approved by: https://github.com/chenyang78
pytorchmergebot pushed a commit that referenced this pull request Aug 16, 2024
… autotune (#133285)

This PR enables dynamic shapes for the CK backend for gemm max autotune (see #125453).

This is achieved via unhardcoding the problem sizes from the template body and passing them as parameters instead.

We handle passing the problem sizes for the kernel call as well as for the benchmark call.

# Testing

`pytest test/inductor/test_ck_backend.py [-k dynamic]`

Pull Request resolved: #133285
Approved by: https://github.com/ColinPeppler
pytorchmergebot pushed a commit that referenced this pull request Aug 27, 2024
MakeArgument signature was changed in ROCm/composable_kernel#1453 adding splitK argument to universal gemm templates which are used to codegen addmm and matmul

(part of the series started at #125453 )

# Testing
`pytest test/inductor/test_ck_backend.py`

Pull Request resolved: #134483
Approved by: https://github.com/ColinPeppler
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
…34483)

MakeArgument signature was changed in ROCm/composable_kernel#1453 adding splitK argument to universal gemm templates which are used to codegen addmm and matmul

(part of the series started at pytorch#125453 )

# Testing
`pytest test/inductor/test_ck_backend.py`

Pull Request resolved: pytorch#134483
Approved by: https://github.com/ColinPeppler
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 26, 2024
… autotune (pytorch#133285)

This PR enables dynamic shapes for the CK backend for gemm max autotune (see pytorch#125453).

This is achieved via unhardcoding the problem sizes from the template body and passing them as parameters instead.

We handle passing the problem sizes for the kernel call as well as for the benchmark call.

# Testing

`pytest test/inductor/test_ck_backend.py [-k dynamic]`

Pull Request resolved: pytorch#133285
Approved by: https://github.com/ColinPeppler
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/inductor ciflow/rocm Trigger "default" config CI on ROCm ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor module: rocm AMD GPU support for Pytorch open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants