From 75cfaa017b2c03562ff9ee6bc5ba352463737678 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Thu, 24 Aug 2023 22:39:57 -0700 Subject: [PATCH] [Inductor CUTLASS backend] Step 4: CUDA (template) kernels [ghstack-poisoned] --- torch/_inductor/codegen/common.py | 98 ++++++- torch/_inductor/codegen/cuda/cuda_kernel.py | 272 ++++++++++++++++++ .../_inductor/codegen/cuda/cuda_scheduling.py | 46 +++ torch/_inductor/codegen/cuda/cuda_template.py | 186 ++++++++++++ torch/_inductor/codegen/triton.py | 15 +- torch/_inductor/codegen/wrapper.py | 29 +- torch/_inductor/ir.py | 38 +++ torch/_inductor/scheduler.py | 6 +- torch/_inductor/select_algorithm.py | 114 ++------ 9 files changed, 690 insertions(+), 114 deletions(-) create mode 100644 torch/_inductor/codegen/cuda/cuda_kernel.py create mode 100644 torch/_inductor/codegen/cuda/cuda_scheduling.py create mode 100644 torch/_inductor/codegen/cuda/cuda_template.py diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index a4611610f7b0..bbd83eca71a7 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -19,6 +19,7 @@ from .. import metrics from ..utils import ( DeferredLineBase, + do_bench_using_profiling, free_symbol_startswith, get_sympy_Expr_dtype, IndentedBuffer, @@ -555,17 +556,6 @@ def wrap_size_arg(self, size): def cpp_argdefs(self): from .cpp import DTYPE_TO_CPP, INDEX_TYPE - # TODO(jansel): replace this with data from scheduler - buffer_types = {x.get_name(): x.get_dtype() for x in V.graph.buffers} - for name, val in V.graph.graph_inputs.items(): - if isinstance(val, sympy.Expr): - buffer_types[name] = get_sympy_Expr_dtype(val) - else: - buffer_types[name] = val.get_dtype() - buffer_types.update( - {name: val.dtype for name, val in V.graph.constants.items()} - ) - call_args = [] arg_defs = [] arg_types = [] @@ -574,7 +564,7 @@ def cpp_argdefs(self): continue outer = inplaced.other_names[-1] inner = inplaced.inner_name - dtype = buffer_types[outer] + dtype = V.graph.get_dtype(outer) cpp_dtype = DTYPE_TO_CPP[dtype] arg_defs.append(f"{cpp_dtype}* {inner}") call_args.append(self.wrap_ptr_arg(outer, dtype)) @@ -582,7 +572,7 @@ def cpp_argdefs(self): for outer, inner in self.input_buffers.items(): if outer in self.inplace_buffers: continue - dtype = buffer_types[outer] + dtype = V.graph.get_dtype(outer) cpp_dtype = DTYPE_TO_CPP[dtype] arg_defs.append(f"const {cpp_dtype}* {inner}") call_args.append(self.wrap_ptr_arg(outer, dtype)) @@ -590,7 +580,7 @@ def cpp_argdefs(self): for outer, inner in self.output_buffers.items(): if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner): continue - dtype = buffer_types[outer] + dtype = V.graph.get_dtype(outer) cpp_dtype = DTYPE_TO_CPP[dtype] arg_defs.append(f"{cpp_dtype}* {inner}") call_args.append(self.wrap_ptr_arg(outer, dtype)) @@ -1029,3 +1019,83 @@ class OptimizationContext: # Load uint8 value as float32 is_load_uint8_as_float: bool = False + + +@functools.lru_cache(None) +def jinja2_env(): + try: + import jinja2 + + return jinja2.Environment( + undefined=jinja2.StrictUndefined, + ) + except ImportError: + return None + + +class ChoiceCaller: + def __init__(self, name, input_nodes, layout): + super().__init__() + self.name = name + self.layout = layout + self.input_nodes = input_nodes + + def benchmark(self, *args, out): + algo = self.to_callable() + return do_bench_using_profiling(lambda: algo(*args, out=out)) + + def call_name(self): + raise NotImplementedError() + + def to_callable(self): + raise NotImplementedError() + + 
def hash_key(self): + raise NotImplementedError() + + def output_node(self): + raise NotImplementedError() + + +class KernelTemplate: + """ + Base class for defining kernel templates. + """ + + @staticmethod + def _template_from_string(source): + env = jinja2_env() + if env is not None: + return env.from_string(source) + return None + + @staticmethod + def fake_get_dtype(fake_out): + _get_dtype_real = V.graph.get_dtype + + def get_dtype(name): + if name == fake_out.get_name(): + return fake_out.get_dtype() + return _get_dtype_real(name) + + return get_dtype + + def __init__(self, name: str): + self.name = name + + def maybe_append_choice(self, choices, **kwargs): + """ + Maybe generates a ChoiceCaller and appends it into choices. + """ + + try: + choices.append(self.generate(**kwargs)) + except NotImplementedError: + pass + + def generate(self, **kwargs) -> ChoiceCaller: + """ + Generates a ChoiceCaller instance from the given arguments. + """ + + raise NotImplementedError() diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py new file mode 100644 index 000000000000..e2268db6f25f --- /dev/null +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -0,0 +1,272 @@ +from typing import List, Optional + +import sympy + +from ...autotune_process import CUDABenchmarkRequest +from ...ir import Callable, CUDATemplateBuffer, IRNode, Layout, TensorBox +from ...select_algorithm import ChoiceCaller +from ...utils import sympy_product +from ...virtualized import V + +from ..common import IndentedBuffer, Kernel, OpOverrides +from ..cpp import CppPrinter, DTYPE_TO_CPP + + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +class CUDAKernel(Kernel): + """ + Kernels defined by the CUDA language. + """ + overrides = OpOverrides + + +class CUDATemplateKernel(CUDAKernel): + """ + Template kernels defined by the CUDA language. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" + + def __init__( + self, + kernel_name, + ): + super().__init__() + self.kernel_name = kernel_name + self.named_nodes = {} + + def arg_name(self, node: IRNode) -> Optional[str]: + if node is None: + return None + return {**self.args.input_buffers, **self.args.output_buffers}.get( + node.get_name(), None + ) + + def check_not_null(self, node: IRNode) -> str: + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + """ + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """.format( + name_str=name_str, size_str=size_str + ) + ) + return res.getvalue() + + def def_kernel( + self, + inputs: List[IRNode], + outputs: List[IRNode], + names_str: str = "", + input_reorder: List[int] = None, + ) -> str: + """ + Hook called from template code to generate function def and + needed args. 
+ """ + + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + arg_defs, *_ = self.args.cpp_argdefs() + return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})" + + def call_kernel(self, name: str, node: CUDATemplateBuffer) -> None: + wrapper = V.graph.wrapper_code + _, call_args, _ = self.args.python_argdefs() + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + else: + call_args[i] = f"c_void_p({call_args[i]}.data_ptr())" + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + call_args.append("None") + + if node.get_workspace_size() > 0: + call_args.append(f"c_void_p({node.get_name()}_workspace.data_ptr())") + else: + call_args.append("None") + + wrapper.generate_kernel_call( + name, + call_args, + V.graph.scheduler.current_device.index, + cuda=True, + triton=False, + ) + + def dtype(self, node: IRNode) -> str: + if node is None: + return "void" + return DTYPE_TO_CPP.get(node.get_layout().dtype) + + def offset(self, node: IRNode) -> str: + if node is None: + return "0" + return str(node.get_layout().offset) + + def ptr(self, node: IRNode, default_node: IRNode = None) -> str: + if node is None: + if default_node is not None: + node = default_node + else: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + offset = self.offset(node) + return arg_name if offset == "0" else f"{arg_name} + {offset}" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. + Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + + sizes = node.get_size()[start_index : end_index + 1] + if len(sizes) == 0: + return str(default_value) + + val = sympy_product(sizes) + return cexpr(self.rename_indexing(val)) + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. + Will add needed args to pass it in if it is dynamic. 
+ """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + return cexpr(self.rename_indexing(stride)) + + def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the row or column stride of an arg. + Will add needed args to pass it in if it is dynamic. + If the node is in row_major, it returns stride[-2]. + If the node is in column_major, it returns stride[-1]. + """ + + if node is None or len(node.get_stride()) < 2: + return str(default_value) + + stride0 = node.get_stride()[-1] + stride1 = node.get_stride()[-2] + if stride0 == 1: + return cexpr(self.rename_indexing(stride1)) + elif stride1 == 1: + return cexpr(self.rename_indexing(stride0)) + else: + raise RuntimeError( + f"At least 1 stride should be 1. Strides: {node.get_stride()=}" + ) + + +class CUDATemplateCaller(ChoiceCaller): + def __init__( + self, + name: str, + category: str, + input_nodes: List[IRNode], + layout: Layout, + make_kernel_render: Callable[[str], str], + bmreq: CUDABenchmarkRequest, + ): + super().__init__(name, input_nodes, layout) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + + def benchmark(self, *args, out): + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + def __str__(self): + return f"CUDATemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"cuda_template_kernels.{self.name}" + + def hash_key(self): + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def output_node(self): + return TensorBox.create( + CUDATemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + ) + ) diff --git a/torch/_inductor/codegen/cuda/cuda_scheduling.py b/torch/_inductor/codegen/cuda/cuda_scheduling.py new file mode 100644 index 000000000000..5811f11973dd --- /dev/null +++ b/torch/_inductor/codegen/cuda/cuda_scheduling.py @@ -0,0 +1,46 @@ +from typing import List + +from ... import config +from ...codecache import code_hash, get_path +from ...scheduler import BaseSchedulerNode +from ...utils import get_fused_kernel_name, get_kernel_metadata +from ...virtualized import V + +from ..common import IndentedBuffer +from ..triton import TritonScheduling + + +class CUDAScheduling(TritonScheduling): + """ + Final codegen for CUDAKernels. 
+ """ + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + basename, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.cuda('so', r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py new file mode 100644 index 000000000000..1aac2335b6df --- /dev/null +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -0,0 +1,186 @@ +import functools +import itertools +from copy import copy +from typing import List +from unittest.mock import patch + +import sympy +import torch + +# import cutlass libs +import scripts as cutlass_lib + +from ...autotune_process import CUDABenchmarkRequest, TensorMeta +from ...ir import Buffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import jinja2_env, KernelTemplate + +from . 
import cutlass_utils +from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel + +log = logging.getLogger(__name__) + + +class CUDATemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, name: str, input_nodes: List[IRNode], layout: Layout, + input_reorder: List[int]=None, + ): + super().__init__(name) + self.input_nodes = input_nodes + self.output_node = Buffer("buf_out", layout) + self.input_reorder = input_reorder + + def generate(self, **kwargs) -> CUDATemplateCaller: + kernel_name = f"cuda_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), CUDATemplateKernel( + kernel_name=kernel_name, + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _ = kernel.args.python_argdefs() + log.debug(f"Generated Code:\n{code}") + log.debug(f"Args: {kernel.args.cpp_argdefs()}, {kernel.args.python_argdefs()}") + + input_reorder = self.input_reorder if self.input_reorder is not None else list(range(len(self.input_nodes))) + expected_args = list(unique(self.input_nodes[idx].get_name() for idx in input_reorder)) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + + kernel_hash_name = f"cuda_{self.name}_{next(self.index_counter)}" + + # create the BenchmarkRequest + bmreq = CUDABenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render(output_node): + kernel = CUDATemplateKernel( + kernel_name="KERNEL_NAME", + ) + render = functools.partial( + self.render, + kernel=kernel, + output_node=output_node, + **kwargs, + ) + return kernel, render + + return CUDATemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. 
+ #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + using bfloat16 = nv_bfloat16; + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + +class CUTLASSTemplate(CUDATemplate): + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cutlass/cutlass.h" + #include "cutlass/numeric_types.h" + #include "cutlass/util/host_tensor.h" + #include "cutlass/util/reference/host/tensor_fill.h" + #include "cutlass/util/reference/device/tensor_fill.h" + #include "cutlass/util/device_memory.h" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + #define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + throw std::runtime_error(msg); \\ + } \\ + } + """ + ) + return res + + def cute_int(self, int_str: str, var_name: str) -> str: + res = "" + if int_str in {"1", "1L"}: + res = "cute::Int<1>{}" + else: + res = int_str + + return f"{res} /* {var_name} */" + + _DTYPE_TO_CUTLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "cutlass::half_t", + torch.int32: "int", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "cutlass::bfloat16_t", + } + + def cutlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 0fc8b7f54396..f8de1fb21645 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -20,7 +20,7 @@ from .. import config, ir, scheduler from ..codecache import code_hash, get_path from ..dependencies import MemoryDep, StarDep -from ..ir import ReductionHint +from ..ir import IRNode, ReductionHint from ..optimize_indexing import indexing_dtype_strength_reduction from ..scheduler import BaseScheduling from ..triton_heuristics import AutotuneHint @@ -2012,7 +2012,7 @@ def dense_size_str(self): return f"[{', '.join(sizes)}]" - def call_kernel(self, name: str): + def call_kernel(self, name: str, node: IRNode = None): wrapper = V.graph.wrapper_code _, call_args, _ = self.args.python_argdefs() # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar @@ -2037,6 +2037,8 @@ def call_kernel(self, name: str): call_args, grid, V.graph.scheduler.current_device.index, + cuda=True, + triton=True, ) def warn_mix_layout(self, kernel_name): @@ -2137,7 +2139,9 @@ def can_fuse(self, node1, node2): return False if node1.is_template(): - return True # skip checks for compatible tiling + # Only allow fusion for TritonTemplates for now. + # Fusion for CUDATemplates are not supported. 
+ return isinstance(node1.node, TritonTemplateBuffer) # check for a bad combined tiling tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) @@ -2504,11 +2508,12 @@ def codegen_template(self, template_node, epilogue_nodes): node.codegen(kernel.split_and_set_ranges(node.get_ranges())) # finalize must be called after adding epilogue above - src_code = partial_code.finalize() + # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. + src_code = partial_code if isinstance(partial_code, str) else partial_code.finalize() node_schedule = [template_node, *epilogue_nodes] kernel_name = self.define_kernel(src_code, node_schedule) self.codegen_comment(node_schedule) - kernel.call_kernel(kernel_name) + kernel.call_kernel(kernel_name, template_node.node) self.scheduler.free_buffers() def codegen_sync(self): diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 0208cab2b861..c5229b3dce1e 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -722,8 +722,17 @@ def generate_profiler_mark_wrapper_call(self, stack): stack.enter_context(self.wrapper_call.indent()) def generate_kernel_call( - self, name, call_args, grid=None, device_index=None, cuda=True + self, name, call_args, grid=None, device_index=None, cuda=True, triton=True, ): + """ + Generates kernel call code. + + cuda: Defines whether the backend is GPU. Otherwise the backend is CPU. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ if cuda: call_args_str = ", ".join(pexpr(item) for item in call_args) grid_str = ", ".join(pexpr(item) for item in grid) @@ -733,6 +742,16 @@ def generate_kernel_call( self.writeline( f"{name}.run({call_args_str}, grid=grid({grid_str}), stream={stream_name})" ) + if triton: + grid_str = ", ".join(pexpr(item) for item in grid) + self.writeline( + f"{name}.run({call_args_str}, grid=grid({grid_str}), stream={stream_name})" + ) + else: + stream_ptr = f"c_void_p({stream_name})" + self.writeline( + f"{name}.{name}({call_args_str}, {stream_ptr})" + ) else: self.writeline(self.wrap_kernel_call(name, call_args)) @@ -809,6 +828,8 @@ def use_preallocated_ouput(self, buffer): ) def codegen_allocation(self, buffer): + assert buffer.get_workspace_size() == 0, "Only support zero size workspace size for now!" + name = buffer.get_name() if name in V.graph.removed_buffers or name in self.allocated: @@ -842,6 +863,8 @@ def codegen_allocation(self, buffer): ) def codegen_free(self, buffer): + assert buffer.get_workspace_size() == 0, "Only support zero size workspace size for now!" 
+ name = buffer.get_name() # can be freed but not reused @@ -1433,11 +1456,11 @@ def generate_args_decl(self, call_args): return ", ".join(new_args) def generate_kernel_call( - self, name, call_args, grid=None, device_index=None, cuda=True + self, name, call_args, grid=None, device_index=None, cuda=True, triton=True ): if not cuda: return super().generate_kernel_call( - name, call_args, grid, device_index, cuda + name, call_args, grid, device_index, cuda, triton ) params = CudaKernelParamCache.get(name) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 188ab1028bcf..90ac5fa026db 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -304,6 +304,12 @@ def is_user_of(self, name): def get_read_names(self): return {dep.name for dep in self.get_reads()} + def get_layout(self): + raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!") + + def get_size(self): + raise NotImplementedError(f"get_size() is not implemented by {type(self)}!") + def get_numel(self): return sympy_product(self.get_size()) @@ -1446,6 +1452,9 @@ def loader(idx): def get_dtype(self): return self.data.get_dtype() + def get_layout(self): + return self.data.get_layout() + def get_device(self): return self.data.get_device() @@ -1847,6 +1856,9 @@ def get_origin_node(self): def get_dtype(self): return self.layout.dtype + def get_layout(self): + return self.layout + def get_size(self): return list(self.layout.size) @@ -2422,6 +2434,12 @@ def get_reads(self): def realize(self): pass + def get_workspace_size(self): + """ + Gets extra global memory size needed by this buffer. + Some algorithms (e.g. group gemm) may require extra global memory in the generated code. + """ + return 0 class InputBuffer(Buffer): pass @@ -2755,6 +2773,23 @@ def simplify_and_reorder(self): None, ) +class TritonTemplateBuffer(TemplateBuffer): + pass + +class CUDATemplateBuffer(TemplateBuffer): + def __init__( + self, + layout, + inputs, + make_kernel_render, + workspace_size: int = 0, + ): + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + + def get_workspace_size(self): + return self.workspace_size if self.workspace_size is not None else 0 @dataclasses.dataclass class InputsKernel(Buffer): @@ -4725,6 +4760,9 @@ def realize(self): def layout(self): return self.data.layout + def get_layout(self): + return self.layout + def __str__(self): if isinstance(self.data, MutableBox): line0 = f"{type(self).__name__}({type(self.data).__name__}(" diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 8eaa3728a7b3..736db8bddf6d 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -1644,7 +1644,11 @@ def codegen(self): if node.is_template(): node, *epilogue = node.get_nodes() - self.get_backend(device).codegen_template(node, epilogue) + if isinstance(node.node, ir.CUDATemplateBuffer): + from .codegen.cuda.cuda_scheduling import CUDAScheduling + CUDAScheduling(self).codegen_template(node, epilogue) + else: + self.get_backend(device).codegen_template(node, epilogue) elif node.is_extern(): self.codegen_extern_call(node) elif node.is_foreach(): diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 5b7cbdbc56ff..ca6002f82277 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -18,15 +18,13 @@ from torch._dynamo.utils import counters, identity from . 
import config, ir -from .autotune_process import BenchmarkRequest, TensorMeta +from .autotune_process import TritonBenchmarkRequest, TensorMeta from .codecache import code_hash, PersistentCache, PyCodeCache - -from .codegen.common import IndentedBuffer +from .codegen.common import ChoiceCaller, IndentedBuffer, jinja2_env, KernelTemplate from .codegen.triton import texpr, TritonKernel, TritonPrinter, TritonScheduling - from .codegen.triton_utils import config_of, signature_to_meta - -from .utils import do_bench, sympy_dot, sympy_product, unique +from .exc import CUDACompileError +from .utils import do_bench_using_profiling, sympy_dot, sympy_product, unique from .virtualized import V log = logging.getLogger(__name__) @@ -337,7 +335,8 @@ def initialize_range_tree(self, pid_cache): self.body.clear() self.indexing_code.clear() - def call_kernel(self, name: str): + def call_kernel(self, node: ir.TritonTemplateBuffer): + name = node.get_name() wrapper = V.graph.wrapper_code _, call_args, _ = self.args.python_argdefs() @@ -383,54 +382,18 @@ def _jinja2_env(): return None -class TritonTemplate: +class TritonTemplate(KernelTemplate): index_counter = itertools.count() all_templates: Dict[str, "TritonTemplate"] = dict() - @staticmethod - def _template_from_string(source): - env = _jinja2_env() - if env is not None: - return env.from_string(source) - return None - def __init__(self, name: str, grid: Any, source: str, debug=False): - super().__init__() - self.name = name + super().__init__(name) self.grid = grid self.template = self._template_from_string(source) assert name not in self.all_templates, "duplicate template name" self.all_templates[name] = self self.debug = debug - def maybe_append_choice( - self, - choices, - input_nodes, - layout, - num_stages, - num_warps, - prefix_args=0, - suffix_args=0, - epilogue_fn=identity, - **kwargs, - ): - try: - choices.append( - self.generate( - input_nodes=input_nodes, - layout=layout, - num_stages=num_stages, - num_warps=num_warps, - prefix_args=prefix_args, - suffix_args=suffix_args, - epilogue_fn=epilogue_fn, - **kwargs, - ) - ) - except NotImplementedError: - pass - def generate( self, input_nodes, @@ -530,16 +493,16 @@ def make_kernel_render(out_node): # create the BenchmarkRequest grid = self.grid(*V.graph.sizevars.size_hints(layout.size), kwargs) - bmreq = BenchmarkRequest( + bmreq = TritonBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(layout), module_path=mod.__file__, module_cache_key=mod.key, - kernel_name=kernel_name, grid=grid, extra_args=extra_args, num_stages=num_stages, num_warps=num_warps, - input_tensors=TensorMeta.from_irnodes(input_nodes), - output_tensor=TensorMeta.from_irnodes(layout), ) return TritonTemplateCaller( @@ -551,17 +514,6 @@ def make_kernel_render(out_node): bmreq, ) - @staticmethod - def fake_get_dtype(fake_out): - _get_dtype_real = V.graph.get_dtype - - def get_dtype(name): - if name == fake_out.get_name(): - return fake_out.get_dtype() - return _get_dtype_real(name) - - return get_dtype - class ExternKernelChoice: def __init__( @@ -608,30 +560,6 @@ def bind(self, input_nodes, layout, ordered_kwargs_for_cpp_kernel=(), **kwargs): ) -class ChoiceCaller: - def __init__(self, name, input_nodes, layout): - super().__init__() - self.name = name - self.layout = layout - self.input_nodes = input_nodes - - def benchmark(self, *args, out): - algo = self.to_callable() - return do_bench(lambda: algo(*args, out=out)) - - def 
call_name(self): - raise NotImplementedError() - - def to_callable(self): - raise NotImplementedError() - - def hash_key(self): - raise NotImplementedError() - - def output_node(self): - raise NotImplementedError() - - class TritonTemplateCaller(ChoiceCaller): def __init__( self, name, input_nodes, layout, make_kernel_render, debug_extra, bmreq @@ -661,7 +589,7 @@ def hash_key(self): def output_node(self): return ir.TensorBox.create( - ir.TemplateBuffer( + ir.TritonTemplateBuffer( layout=self.layout, inputs=self.input_nodes, make_kernel_render=self.make_kernel_render, @@ -697,7 +625,7 @@ def benchmark(self, *args, out): out_new, tuple(out.size()), tuple(out.stride()) ) out.copy_(out_new) # for correctness checking - return do_bench(lambda: algo(*args)) + return do_bench_using_profiling(lambda: algo(*args)) def to_callable(self): fn = self.choice.to_callable() @@ -752,9 +680,7 @@ def __call__(self, name, choices: List[ChoiceCaller], input_nodes, layout): "No choices to select, please consider adding ATEN into max_autotune_gemm_backends " "config (defined in torch/_inductor/config.py) to allow at least one choice. " ) - - if len(choices) == 1: - return choices[0].output_node() + log.debug(f"Max autotune selects from {len(choices)} choices.") @functools.lru_cache(None) def make_benchmark_fn(): @@ -766,6 +692,9 @@ def autotune(choice): timing = benchmark_fn( choice, ) + except CUDACompileError as e: + log.warning(f"CUDA compilation error: \n{str(e)}. \nIgnore this choice.") + return float('inf') except RuntimeError as e: msg = str(e) if "invalid argument" in msg: @@ -800,7 +729,9 @@ def autotune(choice): if make_benchmark_fn.cache_info().currsize: counters["inductor"]["select_algorithm_autotune"] += 1 self.log_results(name, input_nodes, timings, autotune_elapse) - return builtins.min(timings, key=timings.__getitem__).output_node() + selected_choice = builtins.min(timings, key=timings.__getitem__).output_node() + log.debug(f"selected choice: {str(selected_choice)}") + return selected_choice @classmethod def make_benchmark_fn( @@ -903,7 +834,8 @@ def log_results(name, input_nodes, timings, elapse): for n in input_nodes ] ) - top_k = sorted(timings, key=timings.__getitem__)[:10] + n = None if log.getEffectiveLevel() == logging.DEBUG else 10 + top_k = sorted(timings, key=timings.__getitem__)[:n] best = top_k[0] best_time = timings[best] sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
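
Note on the templating flow added in this patch: CUDATemplate/CUDATemplateKernel render CUDA C++ source from a Jinja template whose placeholders call back into the kernel object (e.g. {{kernel.def_kernel(...)}}, {{kernel.size(...)}}), so the function signature and shape expressions are resolved at render time and the rendered source is handed to CUDABenchmarkRequest for autotuning. The snippet below is a minimal, self-contained sketch of that hook pattern only; ToyKernel, its shapes dict, and the template string are invented for illustration and are not Inductor APIs (the real hooks take IRNodes, register kernel args, and emit typed C expressions).

    import jinja2


    class ToyKernel:
        """Simplified stand-in for CUDATemplateKernel: exposes the hooks a template calls."""

        def __init__(self, kernel_name, shapes):
            self.kernel_name = kernel_name
            self.shapes = shapes  # e.g. {"X": (M, K), "W": (K, N), "Y": (M, N)}

        def def_kernel(self, names_str):
            # The real def_kernel also records IRNodes and emits typed pointers via cpp_argdefs().
            args = ", ".join(f"const void* {name.strip()}" for name in names_str.split(","))
            return f'extern "C" int {self.kernel_name}({args}, cudaStream_t stream)'

        def size(self, name, index):
            # The real size() returns a C index expression and adds dynamic dims as extra args.
            return str(self.shapes[name][index])


    GEMM_TEMPLATE = jinja2.Environment(undefined=jinja2.StrictUndefined).from_string(
        """
    {{kernel.def_kernel("X, W, Y")}} {
      int64_t M = {{kernel.size("X", 0)}};
      int64_t N = {{kernel.size("W", 1)}};
      // ... configure and launch a CUTLASS GEMM here ...
      return 0;
    }
    """
    )

    if __name__ == "__main__":
        kernel = ToyKernel("cuda_mm_kernel", {"X": (128, 64), "W": (64, 256), "Y": (128, 256)})
        print(GEMM_TEMPLATE.render(kernel=kernel))

Rendering prints a complete C++ function definition whose signature comes from the def_kernel hook and whose sizes come from the size hook, which mirrors how CUDATemplate.generate obtains the source code that is compiled and benchmarked as one autotuning choice.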