[Inductor CUTLASS backend] Step 4: CUDA (template) kernels #107931

Closed · wants to merge 31 commits

Changes from 27 commits
31 commits by ipiszy, Aug 25 – Sep 12, 2023: initial commit 75cfaa0 ("[Inductor CUTLASS backend] Step 4: CUDA (template) kernels", Aug 25, 2023), followed by 30 commits titled "Update on \"[Inductor CUTLASS backend] Step 4: CUDA (template) kernels\"" (latest e62717f, Sep 12, 2023).
112 changes: 97 additions & 15 deletions torch/_inductor/codegen/common.py
@@ -19,8 +19,8 @@
 from .. import metrics
 from ..utils import (
     DeferredLineBase,
+    do_bench_using_profiling,
     free_symbol_startswith,
-    get_sympy_Expr_dtype,
     IndentedBuffer,
     sympy_dot,
     sympy_subs,
@@ -560,17 +560,6 @@ def wrap_size_arg(self, size):
     def cpp_argdefs(self):
         from .cpp import DTYPE_TO_CPP, INDEX_TYPE
 
-        # TODO(jansel): replace this with data from scheduler
-        buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers}
-        for name, val in V.graph.graph_inputs.items():
-            if isinstance(val, sympy.Expr):
-                buffer_types[name] = get_sympy_Expr_dtype(val)
-            else:
-                buffer_types[name] = val.get_dtype()
-        buffer_types.update(
-            {name: val.dtype for name, val in V.graph.constants.items()}
-        )
-
         call_args = []
         arg_defs = []
         arg_types = []
@@ -579,23 +568,23 @@ def cpp_argdefs(self):
                 continue
             outer = inplaced.other_names[-1]
             inner = inplaced.inner_name
-            dtype = buffer_types[outer]
+            dtype = V.graph.get_dtype(outer)
             cpp_dtype = DTYPE_TO_CPP[dtype]
             arg_defs.append(f"{cpp_dtype}* {inner}")
             call_args.append(self.wrap_ptr_arg(outer, dtype))
             arg_types.append(f"{cpp_dtype}*")
         for outer, inner in self.input_buffers.items():
             if outer in self.inplace_buffers:
                 continue
-            dtype = buffer_types[outer]
+            dtype = V.graph.get_dtype(outer)
             cpp_dtype = DTYPE_TO_CPP[dtype]
             arg_defs.append(f"const {cpp_dtype}* {inner}")
             call_args.append(self.wrap_ptr_arg(outer, dtype))
             arg_types.append(f"const {cpp_dtype}*")
         for outer, inner in self.output_buffers.items():
             if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                 continue
-            dtype = buffer_types[outer]
+            dtype = V.graph.get_dtype(outer)
             cpp_dtype = DTYPE_TO_CPP[dtype]
             arg_defs.append(f"{cpp_dtype}* {inner}")
             call_args.append(self.wrap_ptr_arg(outer, dtype))
@@ -1045,3 +1034,96 @@ class OptimizationContext:

    # Load uint8 value as float32
    is_load_uint8_as_float: bool = False


@functools.lru_cache(None)
def jinja2_env():
    # Jinja2 is an optional dependency: cache and return an Environment if it
    # is installed, otherwise return None.
    try:
        import jinja2

        return jinja2.Environment(
            # StrictUndefined makes rendering fail loudly on undefined
            # template variables instead of silently emitting empty strings.
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        return None

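For illustration (standalone sketch, not part of the diff), StrictUndefined changes how missing template variables are handled:

    import jinja2

    env = jinja2.Environment(undefined=jinja2.StrictUndefined)
    template = env.from_string("GEMM_{{ m }}x{{ n }}")
    print(template.render(m=128, n=256))  # GEMM_128x256
    # template.render(m=128) raises jinja2.exceptions.UndefinedError because
    # `n` is undefined; the default Undefined would render "GEMM_128x".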

Contributor: A small doc comment would be great: What's the purpose of this class / which problem does it solve? Is it supposed to have subclasses?

Author (ipiszy): This is the original code moved from select_algorithm.py. Let me add some comments.

class ChoiceCaller:
"""
Represents a possible choice used in autotune_process.py.
During autotuning, self.benchmark() is first called to get benchmark result,
and if this choice is selected, self.output_node() is called to get the output_node.

Children classes: TritonTemplateCaller, CUDATemplateCaller.
"""

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes

    def benchmark(self, *args, out) -> float:
        algo = self.to_callable()
        return do_bench_using_profiling(lambda: algo(*args, out=out))

    def call_name(self) -> str:
        raise NotImplementedError()

    def to_callable(self):
        raise NotImplementedError()

    def hash_key(self) -> str:
        raise NotImplementedError()

    def output_node(self) -> "TensorBox":  # type: ignore[name-defined]
        raise NotImplementedError()

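To make the interface concrete, here is a minimal hypothetical subclass (illustration only, not part of this PR; the class name and the wrapped callable are invented, and the real subclasses are TritonTemplateCaller and CUDATemplateCaller):

    # Hypothetical: wraps a plain Python callable (e.g. a precompiled GEMM)
    # as an autotuning choice. The inherited benchmark() times it via
    # do_bench_using_profiling.
    class CallableChoiceCaller(ChoiceCaller):
        def __init__(self, name, input_nodes, layout, algo):
            super().__init__(name, input_nodes, layout)
            self.algo = algo

        def call_name(self) -> str:
            return self.name

        def to_callable(self):
            return self.algo

        def hash_key(self) -> str:
            # Any string that uniquely identifies this choice works here.
            return f"callable-{self.name}"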

Contributor: Which kind of kernel templates? (e.g. Triton / C++ / Cutlass / any involving Jinja templates)

Author (ipiszy): Added some comments.

class KernelTemplate:
    """
    Base class for defining kernel templates.

    Child classes: TritonTemplate, CUDATemplate.
    """

    @staticmethod
    def _template_from_string(source):
        env = jinja2_env()
        if env is not None:
            return env.from_string(source)
        return None

    @staticmethod
    def _fake_get_dtype(fake_out):
        # Return a get_dtype lookup that also knows about `fake_out`, an
        # output buffer that is not (yet) registered in the graph; any other
        # name is delegated to the real V.graph.get_dtype.
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype

    def __init__(self, name: str):
        self.name = name

Contributor: What is the "choices" argument here? (e.g. datatype and intended usage). A small doc comment would help clarify, I think.

Author (ipiszy): Added some comments.

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it to the existing
        choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs passed to self.generate() to generate a new
            ChoiceCaller.
        """
        try:
            choices.append(self.generate(**kwargs))
        except NotImplementedError:
            pass

    def generate(self, **kwargs) -> ChoiceCaller:
        """
        Generates a ChoiceCaller instance from the given arguments.
        """
        raise NotImplementedError()
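A sketch of how a template subclass and maybe_append_choice() fit together, reusing the hypothetical CallableChoiceCaller from above (all names and arguments here are placeholders, not part of the PR):

    # Illustration only: real subclasses (TritonTemplate, CUDATemplate) render
    # a Jinja2 kernel source via _template_from_string() and compile it.
    class MyGemmTemplate(KernelTemplate):
        def generate(self, input_nodes, layout, algo, **kwargs) -> ChoiceCaller:
            if jinja2_env() is None:
                # Opting out: maybe_append_choice() swallows this and skips us.
                raise NotImplementedError("jinja2 is required to render templates")
            return CallableChoiceCaller(self.name, input_nodes, layout, algo)

    # Intended call pattern during algorithm selection: each template
    # contributes zero or one ChoiceCaller; the autotuner then benchmarks
    # every entry in `choices` and keeps the fastest.
    choices = []
    MyGemmTemplate("my_gemm").maybe_append_choice(
        choices, input_nodes=[], layout=None, algo=lambda *args, out: out
    )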
1 change: 1 addition & 0 deletions torch/_inductor/codegen/cpp.py
@@ -382,6 +382,7 @@ def _print_Max(self, expr):
return f"std::max({il})"


# A function to print, useful for printing sympy symbols.
cexpr = CppPrinter().doprint
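For illustration (hypothetical standalone usage, not part of the diff), cexpr renders a sympy expression as a C++ expression string; inside Inductor it is used when emitting C++ for sizes, strides, and index expressions:

    import sympy

    s0, s1 = sympy.symbols("s0 s1")
    print(cexpr(s0 * s1 + 1))  # e.g. "s0*s1 + 1"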

