-
Notifications
You must be signed in to change notification settings - Fork 21.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Inductor CUTLASS backend] Step 4: CUDA (template) kernels #107931
Changes from 27 commits
75cfaa0
0c6a52d
c0b885d
a96e751
5ad9fee
51fa22c
701bbee
9949e3f
46da92b
d4008e5
44bee1c
c10e787
c75608d
8530cde
0f879a0
43912cb
c638a77
0b54e9c
e4246c5
67d8bb9
052e685
97015e0
558ee89
6d6b0da
836ae1f
04ea177
9183a36
ea84a0d
49c919d
e6ec8d5
e62717f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,8 +19,8 @@ | |
from .. import metrics | ||
from ..utils import ( | ||
DeferredLineBase, | ||
do_bench_using_profiling, | ||
free_symbol_startswith, | ||
get_sympy_Expr_dtype, | ||
IndentedBuffer, | ||
sympy_dot, | ||
sympy_subs, | ||
|
@@ -560,17 +560,6 @@ def wrap_size_arg(self, size): | |
def cpp_argdefs(self): | ||
from .cpp import DTYPE_TO_CPP, INDEX_TYPE | ||
|
||
# TODO(jansel): replace this with data from scheduler | ||
buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers} | ||
for name, val in V.graph.graph_inputs.items(): | ||
if isinstance(val, sympy.Expr): | ||
buffer_types[name] = get_sympy_Expr_dtype(val) | ||
else: | ||
buffer_types[name] = val.get_dtype() | ||
buffer_types.update( | ||
{name: val.dtype for name, val in V.graph.constants.items()} | ||
) | ||
|
||
call_args = [] | ||
arg_defs = [] | ||
arg_types = [] | ||
|
@@ -579,23 +568,23 @@ def cpp_argdefs(self): | |
continue | ||
outer = inplaced.other_names[-1] | ||
inner = inplaced.inner_name | ||
dtype = buffer_types[outer] | ||
dtype = V.graph.get_dtype(outer) | ||
cpp_dtype = DTYPE_TO_CPP[dtype] | ||
arg_defs.append(f"{cpp_dtype}* {inner}") | ||
call_args.append(self.wrap_ptr_arg(outer, dtype)) | ||
arg_types.append(f"{cpp_dtype}*") | ||
for outer, inner in self.input_buffers.items(): | ||
if outer in self.inplace_buffers: | ||
continue | ||
dtype = buffer_types[outer] | ||
dtype = V.graph.get_dtype(outer) | ||
cpp_dtype = DTYPE_TO_CPP[dtype] | ||
arg_defs.append(f"const {cpp_dtype}* {inner}") | ||
call_args.append(self.wrap_ptr_arg(outer, dtype)) | ||
arg_types.append(f"const {cpp_dtype}*") | ||
for outer, inner in self.output_buffers.items(): | ||
if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): | ||
continue | ||
dtype = buffer_types[outer] | ||
dtype = V.graph.get_dtype(outer) | ||
cpp_dtype = DTYPE_TO_CPP[dtype] | ||
arg_defs.append(f"{cpp_dtype}* {inner}") | ||
call_args.append(self.wrap_ptr_arg(outer, dtype)) | ||
|
@@ -1045,3 +1034,96 @@ class OptimizationContext: | |
|
||
# Load uint8 value as float32 | ||
is_load_uint8_as_float: bool = False | ||
|
||
|
||
# Builds the Jinja2 environment used for kernel templates; functools.lru_cache(None) memoizes the result, so the body runs once. | ||
# Returns a jinja2.Environment created with undefined=jinja2.StrictUndefined, or None when `import jinja2` raises ImportError. | ||
@functools.lru_cache(None) | ||
def jinja2_env(): | ||
try: | ||
import jinja2 | ||
|
||
return jinja2.Environment( | ||
undefined=jinja2.StrictUndefined, | ||
) | ||
except ImportError: | ||
return None | ||
|
||
|
||
class ChoiceCaller: | ||
""" | ||
Represents a possible choice used in autotune_process.py. | ||
During autotuning, self.benchmark() is first called to get benchmark result, | ||
and if this choice is selected, self.output_node() is called to get the output_node. | ||
|
||
Children classes: TritonTemplateCaller, CUDATemplateCaller. | ||
""" | ||
|
||
# Stores the choice's name, layout and input nodes on the instance. | ||
def __init__(self, name, input_nodes, layout): | ||
super().__init__() | ||
self.name = name | ||
self.layout = layout | ||
self.input_nodes = input_nodes | ||
|
||
# Gets a callable from self.to_callable() and returns whatever do_bench_using_profiling reports for invoking it as algo(*args, out=out). | ||
# `out` is keyword-only (it follows *args in the signature). | ||
def benchmark(self, *args, out) -> float: | ||
algo = self.to_callable() | ||
return do_bench_using_profiling(lambda: algo(*args, out=out)) | ||
|
||
# Not implemented in this base class: always raises NotImplementedError (presumably overridden by the child classes - confirm). | ||
def call_name(self) -> str: | ||
raise NotImplementedError() | ||
|
||
# Not implemented in this base class: always raises NotImplementedError. benchmark() depends on this method. | ||
def to_callable(self): | ||
raise NotImplementedError() | ||
|
||
# Not implemented in this base class: always raises NotImplementedError. | ||
def hash_key(self) -> str: | ||
raise NotImplementedError() | ||
|
||
# Not implemented in this base class: always raises NotImplementedError. The return annotation is a string forward reference to "TensorBox". | ||
def output_node(self) -> "TensorBox": # type: ignore[name-defined] | ||
raise NotImplementedError() | ||
|
||
|
||
class KernelTemplate: | ||
""" | ||
Base class for defining kernel templates. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Which kind of kernel templates (e.g. Triton / C++ / CUTLASS / any involving Jinja templates)? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added some comments. |
||
|
||
Children classes: TritonTemplate, CUDATemplate | ||
""" | ||
|
||
# Turns `source` into a template via env.from_string(source), using the environment returned by jinja2_env(). | ||
# Returns None when jinja2_env() returns None (it does so when jinja2 cannot be imported). | ||
@staticmethod | ||
def _template_from_string(source): | ||
env = jinja2_env() | ||
if env is not None: | ||
return env.from_string(source) | ||
return None | ||
|
||
# Returns a get_dtype(name) closure. For the name equal to fake_out.get_name() it answers with fake_out.get_dtype(); | ||
# every other name is forwarded to the V.graph.get_dtype that was captured when this method was called. | ||
@staticmethod | ||
def _fake_get_dtype(fake_out): | ||
_get_dtype_real = V.graph.get_dtype | ||
|
||
def get_dtype(name): | ||
if name == fake_out.get_name(): | ||
return fake_out.get_dtype() | ||
return _get_dtype_real(name) | ||
|
||
return get_dtype | ||
|
||
# Records the template's name on the instance. | ||
def __init__(self, name: str): | ||
self.name = name | ||
|
||
def maybe_append_choice(self, choices, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the "choices" argument here? ( e.g. datatype and intended usage). A small doc comment would help clarify I think. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added some comments. |
||
""" | ||
Maybe generates a new ChoiceCaller and appends it into existing choices. | ||
|
||
choices: A list of ChoiceCallers. | ||
kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller. | ||
""" | ||
|
||
# Best-effort: a NotImplementedError raised inside the try is swallowed, so when self.generate() raises it nothing is appended. | ||
# Any other exception type propagates to the caller. | ||
try: | ||
choices.append(self.generate(**kwargs)) | ||
except NotImplementedError: | ||
pass | ||
|
||
def generate(self, **kwargs) -> ChoiceCaller: | ||
""" | ||
Generates a ChoiceCaller instance from the given arguments. | ||
""" | ||
|
||
# The base implementation only raises NotImplementedError; maybe_append_choice() catches that error and appends nothing. | ||
raise NotImplementedError() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A small doc comment would be great: What's the purpose of this class / which problem does it solve? Is it supposed to have subclasses?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the original code moved from select_algorithm.py. Let me add some comments.