
[Inductor CUTLASS backend] Step 4: CUDA (template) kernels #107931

Closed
wants to merge 31 commits

Conversation


pytorch-bot bot commented Aug 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107931

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit e62717f with merge base f9a250c:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

return None


class ChoiceCaller:
Contributor

A small doc comment would be great: What's the purpose of this class / which problem does it solve? Is it supposed to have subclasses?

Contributor Author

This is the original code moved from select_algorithm.py. Let me add some comments.
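A minimal sketch of what such a documented base class could look like, purely for illustration — the constructor signature and the benchmark/output_node methods below are assumptions, not the actual select_algorithm.py code:

```python
# Illustrative sketch only; not the real select_algorithm.py definition.
class ChoiceCaller:
    """One autotuning candidate (e.g. a template instantiation or an extern
    kernel call) that can be benchmarked and, if it wins, emitted into the
    graph. Backends subclass this and fill in the methods below."""

    def __init__(self, name: str, input_nodes: list, layout) -> None:
        self.name = name
        self.input_nodes = input_nodes
        self.layout = layout

    def benchmark(self, *args, out) -> float:
        """Run this candidate on example inputs and return its runtime."""
        raise NotImplementedError

    def output_node(self):
        """Return the IR node to splice into the graph if this choice wins."""
        raise NotImplementedError
```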


class KernelTemplate:
"""
Base class for defining kernel templates.
Contributor

Which kind of kernel templates? ( e.g. Triton / C++ / Cutlass / any involving Jinja templates )

Contributor Author

Added some comments.

def __init__(self, name: str):
self.name = name

def maybe_append_choice(self, choices, **kwargs):
Contributor

What is the "choices" argument here (e.g. its data type and intended usage)? A small doc comment would help clarify, I think.

Contributor Author

Added some comments.
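To make the intended usage of the `choices` argument concrete, here is a self-contained toy of the "maybe append" pattern; the class and names below are illustrative, not the actual Inductor API:

```python
# Toy illustration of maybe_append_choice(); names are made up for the example.
class ToyTemplate:
    def __init__(self, name, supported_dtypes):
        self.name = name
        self.supported_dtypes = supported_dtypes

    def maybe_append_choice(self, choices, dtype):
        # Append a candidate only when this template can handle the input;
        # otherwise do nothing (hence the "maybe").
        if dtype in self.supported_dtypes:
            choices.append(f"{self.name}[{dtype}]")

choices = []  # the mutable list of candidates handed to the autotuner
for tmpl in (ToyTemplate("cutlass_gemm", {"float16", "bfloat16"}),
             ToyTemplate("triton_mm", {"float16", "float32"})):
    tmpl.maybe_append_choice(choices, dtype="float32")

print(choices)  # ['triton_mm[float32]'] -- only the template that supports float32
```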

from ..cpp import CppPrinter, DTYPE_TO_CPP


cexpr = CppPrinter().doprint
Contributor

This might benefit from a type annotation and/or a small comment explaining what it is, so we don't need to follow the link to CppPrinter when reading the code.

Contributor Author

Added a comment; basically it's a print function.
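For readers who don't want to chase the link: the printer turns sympy index expressions into C++ source strings. A rough stand-in using sympy's stock C printer (Inductor's CppPrinter additionally handles its own sympy functions, so this is only an approximation):

```python
import sympy

x, y = sympy.symbols("x y", integer=True)
# Plays the same role as cexpr: render a symbolic index expression as C/C++
# source text for the generated kernel.
print(sympy.ccode(32 * x + y))  # 32*x + y
```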

"""
Kernels defined by the CUDA language.
"""
overrides = OpOverrides
Contributor

What consequences does this line have? It's a bit hard to tell from the Inductor codegen how exactly these overrides fields are used.

Contributor Author

This is related to Inductor's "define-by-run" IR design, I think. If you check lowering.py, you'll find that the lowered Loops contain code like ops.some_method(). These ops.some_method() calls are overridden by the OpOverrides of each kernel, e.g. TritonKernel has its Triton overrides and CppKernel has its C++ overrides.

The CUDAKernel here is not a general backend since it only supports templates for now, so OpOverrides doesn't matter yet. However, with flexible epilogue fusion it could be used to generate CUTLASS epilogue visitor tree code, so it could become relevant here.
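A toy sketch of that dispatch mechanism, with made-up class names (the real OpOverrides subclasses live in the Triton and C++ codegen modules):

```python
# Made-up classes illustrating how per-backend overrides turn virtual ops.*
# calls into backend-specific source text.
class ToyTritonOverrides:
    @staticmethod
    def exp(x):
        return f"tl.exp({x})"

class ToyCppOverrides:
    @staticmethod
    def exp(x):
        return f"std::exp({x})"

def lower_pointwise(overrides, inp):
    # The lowered IR would call ops.exp(...); the kernel class decides which
    # overrides implement it.
    return overrides.exp(inp)

print(lower_pointwise(ToyTritonOverrides, "tmp0"))  # tl.exp(tmp0)
print(lower_pointwise(ToyCppOverrides, "tmp0"))     # std::exp(tmp0)
```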

return "0"
return str(node.get_layout().offset)

def ptr(self, node: IRNode, default_node: IRNode = None) -> str:
Contributor

@kadeng Aug 25, 2023

A small doc comment maybe? Same for other methods here, like dtype, offset, call_kernel etc.

Contributor Author

Added some comments.

from typing import List

from ... import config
from ...codecache import code_hash, get_path
Contributor

It would be better to use absolute imports here, as that helps code navigation tools (both in IDEs and GitHub code views, for example).

Contributor Author

Most existing Inductor code uses relative imports; I feel it's easier to just follow the existing coding style.

# import cutlass libs
import scripts as cutlass_lib

from ...autotune_process import CUDABenchmarkRequest, TensorMeta
Contributor

Same here; absolute imports would be preferable for most tools.

Contributor Author

@ipiszy left a comment

Thanks @kadeng !


@ipiszy ipiszy marked this pull request as ready for review August 27, 2023 01:39
@ipiszy ipiszy requested a review from jansel August 27, 2023 01:39
This is step 4 of adding CUTLASS as an alternative Inductor backend.
Full tests can be found in the last PR in the stack.

Feature request: #106991.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ngimel yf225 chenyang78 kadeng muchulee8 aakhundov
Contributor Author

@ipiszy left a comment

Thanks @aakhundov! Updated the PR, PTAL.

Comment on lines +118 to +119
self.named_nodes[name] = node
self.args.input_buffers[node.get_name()] = name
Contributor Author

Yes, dict order is useful when generating func args.
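Concretely, this relies on Python dicts preserving insertion order (guaranteed since Python 3.7), so buffers come out in registration order when the function arguments are emitted; a tiny illustration:

```python
# Buffer names registered in the order the kernel encounters them.
input_buffers = {}
input_buffers["buf2"] = "X"
input_buffers["buf0"] = "W"
# Iteration follows insertion order, not key order, so the emitted argument
# list is stable and matches registration order.
print(list(input_buffers.items()))  # [('buf2', 'X'), ('buf0', 'W')]
```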

else list(range(len(self.input_nodes)))
)
expected_args = list(
unique(self.input_nodes[idx].get_name() for idx in input_reorder)
Contributor Author

I think it's fine. We want to dedup args in the function definition; it doesn't affect function implementation codegen. Let me add a unit test to verify.
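A tiny illustration of the deduplication being discussed — when the same buffer is passed twice (e.g. mm(x, x)), the generated signature should list it only once; `unique` below is a stand-in for the helper used in the real code:

```python
def unique(it):
    # Order-preserving dedup, standing in for the helper imported in Inductor.
    return dict.fromkeys(it).keys()

input_names = ["buf0", "buf1", "buf0"]  # buf0 passed twice, e.g. mm(x, x)
expected_args = list(unique(input_names))
print(expected_args)  # ['buf0', 'buf1'] -- deduped for the function definition
```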


class CUDATemplateKernel(CUDAKernel):
"""
Template kernels defined by the CUDA language.
Contributor Author

Well, I think it depends on how you define this. Maybe "language extension" is more accurate. I'd like to distinguish it from the "CUDA platform", which also contains things like PTX, cubin, etc. Let me change it to "C++ CUDA".

if isinstance(node.node, ir.CUDATemplateBuffer):
from .codegen.cuda.cuda_scheduling import CUDAScheduling

CUDAScheduling(self).codegen_template(node, epilogue)
Contributor Author

can_fuse() always returns False for CUDATemplateBuffer (in triton.py). It invokes Triton codegen_template() here to codegen only the non-epilogue part.
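A toy sketch of that routing, with made-up names (not the actual scheduler API): template buffers are sent to a dedicated CUDA scheduling path with an empty epilogue list, while everything else stays on the default path:

```python
# Made-up classes illustrating the type-based dispatch described above.
class ToyCUDAScheduling:
    def codegen_template(self, node, epilogue):
        print(f"CUDA template codegen for {node}, epilogue nodes: {epilogue}")

class ToyDefaultScheduling:
    def codegen(self, node):
        print(f"default codegen for {node}")

def codegen_node(node, is_cuda_template):
    if is_cuda_template:
        # can_fuse() returning False means no epilogue nodes get attached here.
        ToyCUDAScheduling().codegen_template(node, epilogue=[])
    else:
        ToyDefaultScheduling().codegen(node)

codegen_node("buf0 (CUDATemplateBuffer)", is_cuda_template=True)
codegen_node("buf1 (pointwise)", is_cuda_template=False)
```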


if node is None:
return None
return {**self.args.input_buffers, **self.args.output_buffers}.get(
Contributor Author

The OSS linter formats it this way.

Contributor

@aakhundov left a comment

@ipiszy there are some CI jobs failing that are worth checking, as they are not filtered out as flaky / broken trunk. Since the 3 PRs below this one in the stack are green and the one above is also red, my guess is that the root cause may be in this one.

Contributor Author

ipiszy commented Sep 7, 2023

> @ipiszy there are some CI jobs failing, worth checking as not filtered out as flaky / broken trunk. As the 3 PRs below this one in the stack are green and the one above is also red, my guess is that the root cause may be in this one.

Yes, on it. There have been a bunch of test failures; I've fixed some, and now there's something new. It's an iterative process.

Contributor Author

ipiszy commented Sep 9, 2023

@pytorchbot label "topic: not user facing"

pytorch-bot bot added the "topic: not user facing" label Sep 9, 2023
input_reorder: The actual order of input nodes.
e.g. The template might have input argument defined as [X, W, Bias],
and the actual input passed into this template could be [Bias, X, W].
In this case, the `input_reorder` would be [2, 0, 1].
Contributor

Wondering why we allow different orders for the actual input and the template declaration in the first place? Could we enforce the actual input to have the same order as specified by the template?

Contributor

I assume, for template front-end similarity with the existing Triton templates?
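A small sketch that reproduces the docstring's numbers (identifiers are illustrative, and the exact indexing convention in the implementation may differ): the template declares its arguments as [X, W, Bias], the caller passes [Bias, X, W], and input_reorder records where each actual input sits in the template's declaration order:

```python
template_args = ["X", "W", "Bias"]      # order declared by the template
actual_inputs = ["Bias", "X", "W"]      # order passed by the caller
input_reorder = [template_args.index(name) for name in actual_inputs]
print(input_reorder)  # [2, 0, 1] -- matches the docstring example

# Recover the template-order view of the actual inputs:
in_template_order = [None] * len(template_args)
for actual_idx, tmpl_idx in enumerate(input_reorder):
    in_template_order[tmpl_idx] = actual_inputs[actual_idx]
print(in_template_order)  # ['X', 'W', 'Bias']
```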

pytorchmergebot pushed a commit that referenced this pull request Sep 12, 2023
This is step 5 of adding CUTLASS as an alternative Inductor backend.

Feature request: #106991.

Pull Request resolved: #108015
Approved by: https://github.com/kadeng, https://github.com/jansel, https://github.com/aakhundov
ghstack dependencies: #107802, #107847, #107901, #107931
@facebook-github-bot facebook-github-bot deleted the gh/ipiszy@gmail.com/4/head branch September 16, 2023 14:23