diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index dc5cb7b722c5e..80da71ccf6e80 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -2,12 +2,13 @@ import os import unittest -from typing import List, Optional +from typing import Callable, List, Optional import torch from torch import multiprocessing as mp from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import reset_rng_state +from torch._dynamo.utils import counters from torch._inductor import config from torch._inductor.autotune_process import ( BenchmarkRequest, @@ -256,6 +257,154 @@ def mm(a, b): Y = mm(a, b) torch.testing.assert_close(Y_compiled, Y) + def _test_max_autotune_cutlass_backend_epilogue_fusion( + self, + dynamic: bool = False, + max_autotune_gemm_backends: str = "CUTLASS", + mixed_precision=False, + fp16=True, + expected_fuse_count=1, + mm: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None, + ): + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = ( + mixed_precision + ) + + # Note: The ops that are available + # also depend on the alignment of the shapes + # so if these shapes don't all align to at least 8 elements + # it can happen that no Cutlass 3.x op is available + # that allows fusions + a = torch.randn(256, 32).cuda() + b = torch.randn(32, 256).cuda() + if fp16: + a = a.half() + b = b.half() + + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": False, + "max_autotune_gemm_backends": max_autotune_gemm_backends, + "cuda.cutlass_dir": _CUTLASS_DIR, + "cuda.cutlass_max_profiling_configs": 4, + "cuda.cutlass_only_evt_capable_ops": True, + "cuda.version": "12.2", # required to enable the Kernels we need + } + ): + counters["inductor"]["cuda_epilogue_fusion_counter"] = 0 + Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b) + Y = mm(a, b) + actual_count = counters["inductor"]["cuda_epilogue_fusion_counter"] + assert ( + actual_count == expected_fuse_count + ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}" + torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_simple_fusion_fp16(self): + def mm(a, b): + return (a @ b) * 3.0 + + # The pointwise ops seem to be pre-fused into a single Pointwise + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self): + def mm(a, b): + return (a @ b) * 3.0 + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_chained_fusion_fp16(self): + def mm(a, b): + return (a @ b) * 3.3 - 1.234 + + # The pointwise ops seem to be pre-fused into a single Pointwise + self._test_max_autotune_cutlass_backend_epilogue_fusion( + 
mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self): + def mm(a, b): + return (a @ b) * 3.3 - 1.234 + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_relu_fusion_fp16(self): + def mm(a, b): + return torch.nn.functional.relu((a @ b) * 3.3 - 1.234) + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=False, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self): + def mm(a, b): + return torch.nn.functional.relu((a @ b) * 3.3 - 1.234) + + # The pointwise ops seem to be pre-fused into a single Pointwise + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self): + def mm(a, b): + return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0) + + # The pointwise ops seem to be pre-fused into a single Pointwise + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self): + def mm(a, b): + # this should not be fused, since the output dtype is different from the matmul dtype + return (a @ b).to(torch.float32) * 0.00001 + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm + ) + + @unittest.skipIf(not SM90OrLater, "need sm_90") + @unittest.skipIf(torch.version.hip, "HIP not supported") + @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") + def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self): + def mm(a, b): + return (a @ b) / b.size(1) + + self._test_max_autotune_cutlass_backend_epilogue_fusion( + mixed_precision=True, fp16=True, expected_fuse_count=1, mm=mm + ) + # TODO: Enable dynamic test cases when dynamic support is added. 
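For reference, the pattern that the helper and tests above exercise can be condensed into the following user-level sketch. It is illustrative only: it assumes an SM90 GPU and a local CUTLASS checkout (the placeholder "/path/to/cutlass" stands in for the _CUTLASS_DIR constant the tests use); the config keys and the counter name are the ones introduced in this patch.

import torch
from torch._dynamo.utils import counters
from torch._inductor import config

def mm(a, b):
    return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)

a = torch.randn(256, 32, device="cuda", dtype=torch.float16)
b = torch.randn(32, 256, device="cuda", dtype=torch.float16)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_dir": "/path/to/cutlass",  # placeholder for _CUTLASS_DIR
        "cuda.cutlass_only_evt_capable_ops": True,
        "cuda.version": "12.2",  # required to enable the kernels we need (per the test)
    }
):
    counters["inductor"]["cuda_epilogue_fusion_counter"] = 0
    y_compiled = torch.compile(mm)(a, b)
    # The pointwise chain is pre-fused into a single Pointwise node, so one fusion is expected.
    assert counters["inductor"]["cuda_epilogue_fusion_counter"] == 1
    torch.testing.assert_close(y_compiled, mm(a, b), atol=1e-2, rtol=1e-2)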
@unittest.skipIf(not SM75OrLater, "need sm_75") @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup") @@ -263,7 +412,7 @@ def mm(a, b): @parametrize("max_autotune_gemm_backends", ("CUTLASS", "ATen,Triton,CUTLASS")) @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_max_autotune_cutlass_backend_mm_bias( - self, dynamic: bool, max_autotune_gemm_backends: str + self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS" ): """ Make sure autotuning mm in sub processes work without crashes. @@ -294,7 +443,7 @@ def mm(a, b, bias): torch.testing.assert_close(Y_compiled, Y, atol=1e-1, rtol=1e-1) @parametrize("dynamic", (False, True)) - def test_max_autotune_addmm(self, dynamic): + def test_max_autotune_addmm(self, dynamic=False): """ Make sure autotuning addmm in sub processes work without crashes. """ diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py new file mode 100644 index 0000000000000..c1e554ed184a1 --- /dev/null +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -0,0 +1,212 @@ +import logging +from typing import cast, List + +from ...._dynamo.utils import counters + +from ... import config, ir +from ...codecache import code_hash, get_path +from ...ir import ComputedBuffer, CUDATemplateBuffer, Pointwise +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import IndentedBuffer + +from .cutlass_epilogue_gen import CUTLASSEVTOpNotImplementedError + +log = logging.getLogger(__name__) + + +class CUDACPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for CUDA C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and CUDA C++ specific template code generation. + """ + + def __init__(self, scheduler: Scheduler): + super().__init__() + self.scheduler = scheduler + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def is_cuda_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, CUDATemplateBuffer + ) + + def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template( + node.get_template_node() + ) + + def _can_fuse_epilogue_impl( + self, + cuda_template_buffer: CUDATemplateBuffer, + epilogue_nodes: List[ir.IRNode], + additional_node: ir.IRNode, + ) -> bool: + """ + Check if the given node can be fused with the epilogue. At the moment, Kernels + support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes. + + Args: + cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer + epilogue_nodes : List[ir.Buffer]: The list of already fused epilogue nodes. + additional_node: The ir.Buffer node to be checked if it can be fused with the epilogue. + Returns: + - bool: True if the given node can be fused with the epilogue, False otherwise. 
+ + """ + if not isinstance(cuda_template_buffer, CUDATemplateBuffer): + return False + if not cuda_template_buffer.template.can_fuse_epilogue: + # The used GEMM op does not support fusing epilogues + return False + if not isinstance(additional_node, ComputedBuffer): + return False + if not isinstance(additional_node.data, Pointwise): + return False + # We can fuse a Pointwise op that depends on the last fused epilogue node + # if any. If there is no epilogue node yet, it needs to depend on the template + # node + node_name = additional_node.get_computed_buffer_name() # type: ignore[attr-defined] + if node_name is None: + return False + + if len(epilogue_nodes) == 0: + if cuda_template_buffer.name not in additional_node.get_read_names(): + return False + else: + last_epilogue_node = epilogue_nodes[-1] + assert isinstance(last_epilogue_node, ir.ComputedBuffer) # for mypy + last_epilogue_name = ( + last_epilogue_node.name + if last_epilogue_node.name is not None + else last_epilogue_node.data.name # type: ignore[attr-defined] + ) + if last_epilogue_name not in additional_node.get_read_names(): + return False + if additional_node.layout != cuda_template_buffer.layout: + return False + try: + from torch._inductor.codegen.cuda.cutlass_epilogue_gen import ( + CutlassEVTEpilogueArgumentFormatter, + CutlassEVTEpilogueTypeFormatter, + ) + + CutlassEVTEpilogueTypeFormatter.ir_to_evt_string( + cast(str, cuda_template_buffer.name), "anything", [additional_node] + ) + CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string( + cast(str, cuda_template_buffer.name), [additional_node] + ) + except CUTLASSEVTOpNotImplementedError as e: + not_implemented_op = str(e) + if not_implemented_op.startswith("_op_"): + not_implemented_op = not_implemented_op[4:] + log.warning( + f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}, likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950 + ) + return False + else: + # Likely due to unsupported dtype. + log.warning( + f"Cannot fuse epilogue node {additional_node} into {cuda_template_buffer.name}. 
Reason: {not_implemented_op}" # noqa: G004, B950 + ) + return False + return True + + @staticmethod + def _unwrap_epilogue_nodes(fused_node: FusedSchedulerNode) -> List[ir.IRNode]: + nodes = fused_node.get_nodes() + template_node = fused_node.get_template_node() + nodes.remove(template_node) + return [n.node for n in nodes] + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self.is_cuda_cpp_template(node1) and isinstance(node2, SchedulerNode): + return self._can_fuse_epilogue_impl( + cast(CUDATemplateBuffer, node1.node), [], node2.node + ) + elif self.is_cuda_cpp_fused_template(node1) and isinstance( + node2, SchedulerNode + ): + fnode1 = cast(FusedSchedulerNode, node1) + return self._can_fuse_epilogue_impl( + fnode1.get_template_node().node, + self._unwrap_epilogue_nodes(fnode1), + node2.node, + ) + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.cuda(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''', 'so')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + """ + Codegen a CUDA template, possibly with fused epilogues + """ + counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cuda_cpp_template( + template_node + ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node) + epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes] + assert all( + isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes + ), "Epilogue nodes must all be instances of ir.ComputedBuffer" + kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule) + kernel.call_kernel(kernel_name, ctb, epilogue_ir_nodes) + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 1c4564ba89eeb..cf5eea9484ae9 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -1,5 +1,7 @@ 
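Before moving on to cuda_kernel.py: the fusion check that _can_fuse_epilogue_impl (added above in cuda_cpp_scheduling.py) performs after its structural checks is essentially a "generate the EVT expression and catch the failure" probe. A condensed sketch of that last step follows; the helper name _evt_fusion_probe is invented here for illustration, while the formatter classes and the exception are the ones added in cutlass_epilogue_gen.py.

from torch._inductor.codegen.cuda.cutlass_epilogue_gen import (
    CUTLASSEVTOpNotImplementedError,
    CutlassEVTEpilogueArgumentFormatter,
    CutlassEVTEpilogueTypeFormatter,
)

def _evt_fusion_probe(template_buffer_name: str, epilogue_node) -> bool:
    # Both formatters walk the Pointwise inner_fn through V.ops; any op without a
    # matching _op_<name> handler raises CUTLASSEVTOpNotImplementedError, in which
    # case the scheduler declines the fusion and logs a warning.
    try:
        CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(
            template_buffer_name, "anything", [epilogue_node]
        )
        CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(
            template_buffer_name, [epilogue_node]
        )
    except CUTLASSEVTOpNotImplementedError:
        return False
    return True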
+import logging from typing import Callable, Dict, List, Optional +from ... import ir from ...autotune_process import CUDABenchmarkRequest from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox from ...select_algorithm import ChoiceCaller @@ -9,6 +11,7 @@ from ..common import IndentedBuffer, Kernel, OpOverrides from ..cpp import CppPrinter, DTYPE_TO_CPP +log = logging.getLogger(__name__) cexpr = CppPrinter().doprint @@ -19,7 +22,7 @@ def _normalize_idx(index: int, total_length: int) -> int: class CUDAKernel(Kernel): """ - Kernels defined by C++ CUDA. + Baseclass for CUDA / Cutlass based Kernels """ overrides = OpOverrides # type: ignore[assignment] @@ -27,15 +30,18 @@ class CUDAKernel(Kernel): class CUDATemplateKernel(CUDAKernel): """ - Template kernels defined by C++ CUDA. + Template kernels defined by CUDA / Cutlass in C++. """ _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" - def __init__( - self, - kernel_name, - ): + def __init__(self, kernel_name): + """ + Initializes a new instance of the CUDATemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ super().__init__() self.kernel_name = kernel_name # Mapping from arg name to IRNode. @@ -45,7 +51,6 @@ def arg_name(self, node: IRNode) -> Optional[str]: """ Returns arg name of a given input or output node. """ - if node is None: return None return {**self.args.input_buffers, **self.args.output_buffers}.get( @@ -89,15 +94,17 @@ def def_kernel( input_reorder: Optional[List[int]] = None, ) -> str: """ - Hook called from template code to generate function def and + Hook called from template code to generate function definition and needed args. - inputs / outputs: List of input / output IRNodes. Note that IRNode can be None for optional arguments. - names_str: Comma separated list of input + output argument names. - input_reorder: The actual order of input nodes. - e.g. The template might have input argument defined as [X, W, Bias], - and the actual input passed into this template could be [Bias, X, W]. - In this case, the `input_reorder` would be [2, 0, 1]. + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. """ names = [x.strip() for x in names_str.strip().split(",")] @@ -126,14 +133,17 @@ def def_kernel( arg_defs, *_ = self.args.cpp_argdefs() return f"PT_EXPORT int {self.kernel_name}({', '.join(arg_defs)}, {self._EXTRA_CPP_ARGS})" - def call_kernel(self, name: str, node: CUDATemplateBuffer) -> None: + def call_kernel( + self, name: str, node: "CUDATemplateBuffer", epilogue_nodes: List[ir.Buffer] # type: ignore[name-defined] + ) -> None: """ Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.WrapperCodeGen name: Name of kernel function. - node: The IRNode which represents the kernel. + node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. 
""" - wrapper = V.graph.wrapper_code _, call_args, _ = self.args.python_argdefs() # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar @@ -266,19 +276,32 @@ def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: class CUDATemplateCaller(ChoiceCaller): + """ + CUDATemplateCaller + + This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CUDABenchmarkRequest): The benchmark request for the caller. + template_buffer (CUDATemplateBuffer): The template buffer for the caller. + """ + def __init__( self, name: str, category: str, input_nodes: List[Buffer], layout: Layout, - make_kernel_render: Callable[[str], str], + make_kernel_render: Callable[[CUDATemplateBuffer, Optional[List[IRNode]]], str], bmreq: CUDABenchmarkRequest, + template: "CUDATemplate", # type: ignore[name-defined] ): super().__init__(name, input_nodes, layout) self.category = category self.make_kernel_render = make_kernel_render self.bmreq = bmreq + self.template = template def benchmark(self, *args, out) -> float: assert self.bmreq is not None @@ -305,5 +328,6 @@ def output_node(self) -> TensorBox: inputs=self.input_nodes, make_kernel_render=self.make_kernel_render, workspace_size=self.bmreq.workspace_size, + template=self.template, ) ) diff --git a/torch/_inductor/codegen/cuda/cuda_scheduling.py b/torch/_inductor/codegen/cuda/cuda_scheduling.py deleted file mode 100644 index a017f9043faef..0000000000000 --- a/torch/_inductor/codegen/cuda/cuda_scheduling.py +++ /dev/null @@ -1,43 +0,0 @@ -from ... import config -from ...codecache import code_hash, get_path -from ...utils import get_fused_kernel_name, get_kernel_metadata -from ...virtualized import V - -from ..common import IndentedBuffer -from ..triton import TritonScheduling - - -class CUDAScheduling(TritonScheduling): - """ - Final codegen for CUDAKernels. 
- """ - - def define_kernel(self, src_code: str, node_schedule) -> str: - wrapper = V.graph.wrapper_code - if src_code in wrapper.src_to_kernel: - kernel_name = wrapper.src_to_kernel[src_code] - else: - fused_name = ( - get_fused_kernel_name(node_schedule, config.triton.descriptive_names) - if config.triton.descriptive_names - else "" - ) - kernel_name = "_".join(["cuda", fused_name, wrapper.next_kernel_suffix()]) - # use the original src_code as the key - wrapper.src_to_kernel[src_code] = kernel_name - src_code = src_code.replace("KERNEL_NAME", kernel_name) - - _, _, kernel_path = get_path(code_hash(src_code), "py") - - compile_wrapper = IndentedBuffer() - compile_wrapper.writeline("async_compile.cuda(r'''") - compile_wrapper.splice(src_code, strip=True) - compile_wrapper.writeline("''', 'so')") - - metadata_comment = f"# kernel path: {kernel_path}" - origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) - metadata_comment += "\n" + origins + "\n" + detailed_origins - wrapper.define_kernel( - kernel_name, compile_wrapper.getvalue(), metadata_comment - ) - return kernel_name diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 944483d09aee9..3e106dad84e4f 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -1,20 +1,18 @@ import functools import itertools import logging - from typing import List, Optional from unittest.mock import patch import sympy import torch - from ...autotune_process import CUDABenchmarkRequest, TensorMeta -from ...ir import Buffer, IRNode, Layout +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout + from ...utils import IndentedBuffer, unique from ...virtualized import V from ..common import KernelTemplate - from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel log = logging.getLogger(__name__) @@ -30,12 +28,37 @@ def __init__( layout: Layout, input_reorder: Optional[List[int]] = None, ): + """ + + Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the CUDATemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ super().__init__(name) self.input_nodes = input_nodes self.output_node: Buffer = Buffer("buf_out", layout) self.input_reorder = input_reorder + self.layout = layout - def generate(self, **kwargs) -> CUDATemplateCaller: + def generate( # type: ignore[override] + self, + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller + may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. 
+ """ kernel_name = f"cuda_{self.name}" with patch.object( V.graph, "get_dtype", self._fake_get_dtype(self.output_node) @@ -79,15 +102,19 @@ def generate(self, **kwargs) -> CUDATemplateCaller: source_code=code, ) - def make_kernel_render(output_node): + def make_kernel_render( + template_node: CUDATemplateBuffer, + epilogue_nodes: Optional[List[IRNode]] = None, + ): kernel = CUDATemplateKernel( kernel_name="KERNEL_NAME", ) render = functools.partial( self.render, kernel=kernel, - output_node=output_node, - **kwargs, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate ) return kernel, render @@ -98,6 +125,7 @@ def make_kernel_render(output_node): self.output_node.get_layout(), make_kernel_render, bmreq, + self, ) def header(self) -> IndentedBuffer: @@ -139,12 +167,19 @@ def render(self, **kwargs) -> str: class CUTLASSTemplate(CUDATemplate): + """ + CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the + CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels. + """ + def header(self) -> IndentedBuffer: res = super().header() res.splice( """ + #include "cute/tensor.hpp" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" + #include "cutlass/tensor_ref.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/reference/device/tensor_fill.h" @@ -157,6 +192,7 @@ def globals(self) -> IndentedBuffer: res = super().globals() res.splice( """ + using namespace cute; #define CUTLASS_CHECK(status) \\ { \\ cutlass::Status error = status; \\ @@ -166,6 +202,14 @@ def globals(self) -> IndentedBuffer: throw std::runtime_error(msg); \\ } \\ } + + // Used as pass-through functor in EVT just for type casting / rounding + template + struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } + }; + """ ) return res diff --git a/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py b/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py new file mode 100644 index 0000000000000..fd751cdd4a2e2 --- /dev/null +++ b/torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py @@ -0,0 +1,360 @@ +from typing import Dict, List +from unittest.mock import patch + +import sympy + +import torch._inductor.virtualized as virtualized +from torch._inductor.ir import ComputedBuffer, FlexibleLayout, IRNode, Pointwise +from torch._inductor.utils import IndentedBuffer, sympy_str + + +# Used as a magic string to indicate an unsupported sympy expression +# became part of generated C++ code. +_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]" + + +def _arg_str(a): + if isinstance(a, sympy.Expr): + # If this return value containting the _MAGIC_SYMPY_ERROR_STRING + # is used as part of the final generated C++ code, + # a CUTLASSEVTOpNotImplementedError is raised to indicate that + # the op could not be converted to a valid EVT expression. + return f"{_MAGIC_SYMPY_ERROR_STRING}('{sympy_str(a)}')" + return str(a) + + +class CUTLASSEVTOpNotImplementedError(NotImplementedError): + pass + + +class CutlassEVTEpilogueTypeFormatter: + """ + Codegen class, which provides an entry point to generate + Cutlass "Epilogue Visitor Tree" (EVT) functor declarations. + + See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder + for more about EVTs and how they are declared and used to generate. 
+ + Notes: + * Used by CUTLASSGemmTemplate. + * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTEpilogueTypeFormatter.ir_to_evt_string(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. + + + """ + + def __init__(self, accumulator_node_name, evt_type_name): + """ + + Initialize an instance of CutlassEVTEpilogueTypeFormatter. + + Parameters: + - accumulator_node_name (str): The name of the output Buffer for the GEMM operation in the original (unfused) + IR graph. + - evt_type_name (str): The output name of the EVT type we are generating. + + """ + self.accumulator_node_name = accumulator_node_name + self.output = IndentedBuffer(0) + self.var_counter = 0 + self.evt_type_name = evt_type_name + self.aliases = dict() + + @staticmethod + def ir_to_evt_string( + template_output_node_name: str, + evt_type_name: str, + epilogue_nodes: List[IRNode], + ): + """ + Formats IR nodes into a string representation compatible with Cutlass EVT format. + + Args: + template_output_node_name (str): The name of the template output node. + evt_type_name (str): The name of the EVT type. + epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be + ComputedBuffer nodes wrapping Pointwise nodes. + + Returns: + A string representation of the IR nodes formatted according to the Cutlass EVT format. + """ + formatter = CutlassEVTEpilogueTypeFormatter( + template_output_node_name, evt_type_name + ) + + with virtualized.V.set_ops_handler(formatter), patch.object( # type: ignore[call-arg] + FlexibleLayout, "allow_indexing", True + ): + for node in epilogue_nodes: + if isinstance(node, ComputedBuffer): + pnode = node.data + else: + raise RuntimeError( + "Epilogue nodes must be Pointwise nodes, wrapped in a named ComputedBuffer" + ) + assert isinstance(pnode, Pointwise) + index = pnode._index(pnode.ranges) + result = pnode.inner_fn(index) + # each epilogue node results in a single "using" statement and may refer to the previous steps by name + formatter.aliases[node.name] = result + res = formatter.getvalue(result) + if _MAGIC_SYMPY_ERROR_STRING in res: + raise CUTLASSEVTOpNotImplementedError( + "sympy / indexing expressions not yet supported in EVT fusion" + ) + else: + return res + + def __getattr__(self, name): + """ + Resolve V.ops. calls, after this instance has been installed as V.ops handler. + """ + + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} + fn = getattr(self, f"_op_{name}") + line = fn(*fargs, **fkwargs) + self.var_counter += 1 + varname = f"EVT_expr_{self.var_counter}" + # replace line with a new variable name + self.output.writeline(f"using {varname} = {line};") + return varname + + if name.startswith("_"): + raise CUTLASSEVTOpNotImplementedError(name) + if hasattr(self, f"_op_{name}"): + return inner + else: + raise CUTLASSEVTOpNotImplementedError(name) + + def _op_load(self, name, index_expr): + # Load an input to an operation. Might be the output of the matmul, the result + # of a previous epilogue node, a constant or (TODO) an auxiliary input. 
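        # Illustrative sketch (not emitted verbatim by this code; template parameters are
        # elided with "..."): for an epilogue computing relu(acc * 3.0) the formatter emits a
        # chain of "using" aliases roughly like
        #
        #   using EVT_expr_1 = cutlass::epilogue::fusion::Sm90ScalarBroadcast<...>;   // constant 3.0
        #   using EVT_expr_2 = cutlass::epilogue::fusion::Sm90EVT<... multiplies ...,
        #                          Sm90AccFetch, EVT_expr_1>;                          // acc * 3.0
        #   using EVT_expr_3 = cutlass::epilogue::fusion::Sm90EVT<... maximum ...,
        #                          EVT_expr_2, /* broadcast of 0.0 */>;                // relu
        #   using <evt_type_name> = Sm90EVT<... identity_op ..., EVT_expr_3>;          // cast to output dtype
        #
        # Each _op_* handler below contributes one node of that tree; getvalue() appends the
        # final dtype-converting wrapper.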
+ if name == self.accumulator_node_name: + return f"cutlass::epilogue::fusion::Sm90AccFetch /* :={name} (matmul output in accumulator) */" + elif name in self.aliases: + return self.aliases[name] + else: + # return f"cutlass::epilogue::fusion::Sm90SrcFetch /* :={name} */" + raise CUTLASSEVTOpNotImplementedError( + f"Operand {name} not found. Auxiliary inputs not supported yet." + ) + + def _op_constant(self, value, dtype): + # Load a constant + if str(dtype) in ("torch.float16", "torch.float32"): + return f"cutlass::epilogue::fusion::Sm90ScalarBroadcast /* value={value}, dtype={dtype} */" + else: + raise CUTLASSEVTOpNotImplementedError( + f"Unsupported dtype for constant: {dtype}" + ) + + def _cutlass_binary_functional_op(self, op, a, b): + # Perform a named operation on two inputs + # see https://github.com/NVIDIA/cutlass/blob/6407bcdf0a24097b7b016ee105937693c62f9923/include/cutlass/functional.h for ops + return f"cutlass::epilogue::fusion::Sm90EVT,{a},{b}>" # noqa: B950 + + def _convert_to_output_dtype(self, a): + # Convert the final output to the dtype of the output buffer + return f"cutlass::epilogue::fusion::Sm90EVT,{a}>" # noqa: B950 + + def _op_to_dtype(self, a, *args, **kwargs): + # no-op in our case, since we convert to the output dtype at the end and convert everything to the accumulator + # dtype. + # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible + # throughout the fusion chain. + return a # noqa: B950 + + def _op_mul(self, a, b): + return self._cutlass_binary_functional_op("multiplies", a, b) + + def _op_div(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_truediv(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_ge(self, a, b): + return self._cutlass_binary_functional_op("greater_equal", a, b) + + def _op_add(self, a, b): + return self._cutlass_binary_functional_op("plus", a, b) + + def _op_sub(self, a, b): + return self._cutlass_binary_functional_op("minus", a, b) + + def _op_minimum(self, a, b): + return self._cutlass_binary_functional_op("minimum", a, b) + + def _op_maximum(self, a, b): + return self._cutlass_binary_functional_op("maximum", a, b) + + def _op_relu(self, a): + const_zero = self._op_constant(0.0, "torch.float32") + return f"cutlass::epilogue::fusion::Sm90EVT,{a}, {const_zero}>" # noqa: B950 + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise CUTLASSEVTOpNotImplementedError() + + # Add more ops here... + def getvalue(self, result) -> str: + # Return final result + dtype_converted_expr = self._convert_to_output_dtype( + f"EVT_expr_{self.var_counter}" + ) + self.output.writeline(f"using {self.evt_type_name} = {dtype_converted_expr};") + return self.output.getvalue() + + +class CutlassEVTEpilogueArgumentFormatter: + """ + Codegen class, which provides an entry point to generate + Cutlass "Epilogue Visitor Tree" (EVT) Argument initializers + + See https://github.com/NVIDIA/cutlass/tree/main/examples/49_hopper_gemm_with_collective_builder + for more about EVTs and how they are declared and used to generate. + + Notes: + * Used by CUTLASSGemmTemplate. + * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. 
+ + + """ + + def __init__(self, accumulator_node_name: str): + """ + + Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly. + Use the CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string static method. + + Args: + accumulator_node_name (str): The name of the accumulator node which should contain + the Matmul result before fusion according to the IR graph. + """ + self.accumulator_node_name: str = accumulator_node_name # + self.output: IndentedBuffer = IndentedBuffer(0) # The output buffer for codegen + self.var_counter: int = ( + 0 # used to generate variable names, incremented for each new variable + ) + self.aliases: Dict[str, str] = dict() # Aliases for subexpression functors + + @staticmethod + def ir_to_evt_argument_string( + template_output_node_name: str, + epilogue_nodes: List[IRNode], + ) -> str: + formatter = CutlassEVTEpilogueArgumentFormatter( + template_output_node_name, + ) + + with virtualized.V.set_ops_handler(formatter), patch.object( # type: ignore[call-arg] + FlexibleLayout, "allow_indexing", True + ): + for node in epilogue_nodes: + assert isinstance(node, ComputedBuffer) + pnode = node.data + assert isinstance(pnode, Pointwise) + index = pnode._index(pnode.ranges) + result = pnode.inner_fn(index) + # each epilogue node results in a single "using" statement and may refer to the previous steps by name + if node.name is not None: + formatter.aliases[node.name] = result + + res: str = formatter.getvalue(result) + if _MAGIC_SYMPY_ERROR_STRING in res: + raise CUTLASSEVTOpNotImplementedError( + "sympy / indexing expressions not yet supported in EVT fusion" + ) + else: + return res + + def __getattr__(self, name): + def inner(*args, **kwargs): + fargs = [_arg_str(a) for a in args] + fkwargs = {key: _arg_str(a) for key, a in kwargs.items()} + fn = getattr(self, f"_op_{name}") + line = fn(*fargs, **fkwargs) + return line + + if name.startswith("_"): + raise CUTLASSEVTOpNotImplementedError(name) + + if hasattr(self, f"_op_{name}"): + return inner + else: + raise CUTLASSEVTOpNotImplementedError(name) + + def _op_load(self, name, index_expr): + if name == self.accumulator_node_name: + return "{}" + elif name in self.aliases: + return self.aliases[name] + else: + raise CUTLASSEVTOpNotImplementedError( + f"Operand {name} not found. Auxiliary inputs not supported yet." 
+ ) + + def _op_constant(self, value, dtype): + if str(dtype) in ("torch.float16", "torch.float32"): + return "{ static_cast(" + str(value) + ") }" + else: + raise CUTLASSEVTOpNotImplementedError( + f"Unsupported dtype for constant: {dtype}" + ) + + def _cutlass_binary_functional_op(self, op, a, b): + return f"{{ /*{op}: */ {a}, {b} }}" + + def _op_mul(self, a, b): + return self._cutlass_binary_functional_op("multiplies", a, b) + + def _op_div(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_truediv(self, a, b): + return self._cutlass_binary_functional_op("divides", a, b) + + def _op_ge(self, a, b): + return self._cutlass_binary_functional_op("greater_equal", a, b) + + def _op_add(self, a, b): + return self._cutlass_binary_functional_op("plus", a, b) + + def _op_sub(self, a, b): + return self._cutlass_binary_functional_op("minus", a, b) + + def _op_minimum(self, a, b): + return self._cutlass_binary_functional_op("minimum", a, b) + + def _op_maximum(self, a, b): + return self._cutlass_binary_functional_op("maximum", a, b) + + def _op_relu(self, a): + const_zero = self._op_constant(0.0, "torch.float32") + return "{" + str(a) + ", " + const_zero + "}" + + def _op_to_dtype(self, a, dtype, src_dtype=None): + # Is is asserted ( and ascertained during can_fuse decision ) that the dtype remains compatible + # throughout the fusion chain. + assert dtype in ( + "torch.float32", + "torch.float16", + ), f"Unsupported dtype: {dtype}" + assert src_dtype in ( + None, + "torch.float32", + "torch.float16", + ), f"Unsupported source dtype: {src_dtype}" + return a + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise CUTLASSEVTOpNotImplementedError() + + def getvalue(self, result) -> str: + return "{" + str(result) + "}" diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py new file mode 100644 index 0000000000000..daa834969beff --- /dev/null +++ b/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -0,0 +1,186 @@ +from ..cutlass_utils import try_import_cutlass + +if try_import_cutlass(): + import enum + + from cutlass_library.library import * # type: ignore[import] # noqa: F401, F403 + from cutlass_library.gemm_operation import * # type: ignore[import] # noqa: F401, F403 + + # copied / modified from original at + # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658 + # to support EVT similar to + # https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L315C69-L315C69 # noqa: B950 + class EmitGemmUniversal3xInstanceWithEVT: + """Responsible for emitting a CUTLASS 3.x template definition""" + + def __init__(self, operation_suffix=""): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + ] + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + 
${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > + """ + self.gemm_template = """ + using EpilogueScheduleType = ${epilogue_schedule}; + static_assert(cute::is_same_v || + cute::is_same_v, + "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue"); + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementAcc = ${element_accumulator}; + using ElementD = ${element_d}; + ${epilogue_functor}; + using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + EpilogueScheduleType, + ${operation_name}_epilogue_functor + >::CollectiveOp; + + using ${operation_name}_mainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape, + ${stages}, + ${kernel_schedule} + >::CollectiveOp; + + // Gemm operator ${operation_name} + using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + cute::Shape, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler}>; + + // Define named type + struct ${operation_name} : + public ${operation_name}_base { }; + + """ + + # + def instance_template(self): + return """ + ${compile_guard_start} + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); + ${compile_guard_end} + """ + + # + def emit(self, operation): + tile_shape = operation.tile_description.tile_shape + warp_count = operation.tile_description.warp_count + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" + else: + stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout" # noqa: B950 + warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)] + + ( + instance_layout_A, + instance_layout_B, + instance_layout_C, + instance_layout_D, + ) = ( + operation.A.layout, + operation.B.layout, + operation.C.layout, + operation.D.layout, + ) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), # type: ignore[name-defined] + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], # type: ignore[name-defined] + } + epilogue_functor = SubstituteTemplate( # type: ignore[name-defined] + self.builtin_epilogue_functor_template, values + ) + + elif callable(operation.epilogue_functor): + epilogue_functor = operation.epilogue_functor( + operation.procedural_name() + "_epilogue_functor" + ) + else: + epilogue_functor = str(operation.epilogue_functor) + # + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], # type: ignore[name-defined] + "layout_a": 
LayoutTag[instance_layout_A], # type: ignore[name-defined] + "element_b": DataTypeTag[operation.B.element], # type: ignore[name-defined] + "layout_b": LayoutTag[instance_layout_B], # type: ignore[name-defined] + "element_c": DataTypeTag[operation.C.element], # type: ignore[name-defined] + "layout_c": LayoutTag[instance_layout_C], # type: ignore[name-defined] + "element_d": DataTypeTag[operation.D.element], # type: ignore[name-defined] + "layout_d": LayoutTag[instance_layout_D], # type: ignore[name-defined] + "element_accumulator": DataTypeTag[operation.accumulator_type()], # type: ignore[name-defined] + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], # type: ignore[name-defined] # noqa: B950 + "arch": "cutlass::arch::Sm%d" % operation.arch, + "tile_shape_m": str(operation.tile_description.tile_shape[0]), + "tile_shape_n": str(operation.tile_description.tile_shape[1]), + "tile_shape_k": str(operation.tile_description.tile_shape[2]), + "cluster_m": str(operation.tile_description.cluster_shape[0]), + "cluster_n": str(operation.tile_description.cluster_shape[1]), + "cluster_k": str(operation.tile_description.cluster_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), # type: ignore[name-defined] + "epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]), # type: ignore[name-defined] + "epilogue_functor": epilogue_functor, + "stages": stage_count_string, + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "align_c": str(operation.C.alignment), + "align_d": str(operation.C.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], # type: ignore[name-defined] + "transform_b": ComplexTransformTag[operation.B.complex_transform], # type: ignore[name-defined] + "math_operation": MathOperationTag[ # type: ignore[name-defined] + operation.tile_description.math_instruction.math_operation + ], + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), # type: ignore[name-defined] + "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]), # type: ignore[name-defined] + } + + return SubstituteTemplate(self.gemm_template, values) # type: ignore[name-defined] diff --git a/torch/_inductor/codegen/cuda/cutlass_utils.py b/torch/_inductor/codegen/cuda/cutlass_utils.py index dd1170d909213..b03f4c79933c6 100644 --- a/torch/_inductor/codegen/cuda/cutlass_utils.py +++ b/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -17,6 +17,33 @@ log = logging.getLogger(__name__) +def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str: + for cutlass_module in cutlass_modules: + content = content.replace( + f"from {cutlass_module} import ", + f"from cutlass_library.{cutlass_module} import ", + ) + return content + + +def _gen_cutlass_file( + file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str +) -> None: + orig_full_path = os.path.abspath(os.path.join(src_dir, file_name)) + text = "" + with open(orig_full_path) as f: + text = f.read() + text = 
_rename_cutlass_import(text, cutlass_modules) + dst_full_path = os.path.abspath( + os.path.join( + dst_dir, + file_name, + ) + ) + with open(dst_full_path, "w") as f: + f.write(text) + + @functools.lru_cache(None) def try_import_cutlass() -> bool: # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 8a51c73a15ca0..2a7304758b5c4 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -1,21 +1,22 @@ import copy import logging import re -from typing import Dict, List, Optional, Tuple +from typing import cast, Dict, List, Optional, Tuple from ...config import cuda as inductor_cuda_config -from ...ir import Buffer, FixedLayout, IRNode, Layout +from ...ir import Buffer, CUDATemplateBuffer, FixedLayout, IRNode, Layout from ..common import IndentedBuffer from . import cutlass_utils from .cuda_kernel import CUDATemplateKernel from .cuda_template import CUTLASSTemplate +from .cutlass_epilogue_gen import ( + CutlassEVTEpilogueArgumentFormatter, + CutlassEVTEpilogueTypeFormatter, +) log = logging.getLogger(__name__) - -# Only supports alpha * A@B + beta * C now. -# TODO: Support arbitrary epilogue after epilogue visitor is released in cutlass 3.2. GEMM_TEMPLATE = r""" {{template.header().getvalue()}} {{template.globals().getvalue()}} @@ -36,7 +37,8 @@ using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; using coord_t = cutlass::gemm::GemmCoord::Index; {{instance_type}}::Arguments arguments; - {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, X, W, Bias, Y, alpha, beta, kernel)}} + {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}} {{instance_type}} gemm_op; if (workspace_size) { *workspace_size = gemm_op.get_workspace_size(arguments); @@ -132,8 +134,9 @@ """ GEMM_ARGS_CUTLASS_3X_EPILOGUE = r""" + // see https://tinyurl.com/4rk89z48 { - {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename ThreadEpilogueOp::Params thread + {{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C { {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}}, @@ -151,7 +154,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate): - # Calculates alpha * X@W + beta * Bias + """ + CUTLASS GEMM template, which is used to generate CUTLASS GEMM kernels + including those which allow flexible fusions with epilogues. + """ def __init__( self, @@ -160,10 +166,81 @@ def __init__( alpha: float, beta: float, input_reorder: Optional[List[int]] = None, + can_fuse_epilogue: Optional[bool] = None, ): + """ + Args: + input_nodes: input nodes of the kernel + layout: layout of the output node + alpha: alpha value of the GEMM operation + beta: beta value of the GEMM operation + input_reorder: reorder of the input nodes + can_fuse_epilogue: If set to True, will only list and use operators capable of flexible epilogue fusions. + If False, it will not use those. If None, both may be listed, but it will not allow fusions. 
+ Defaults to None + """ super().__init__("cutlass_gemm", input_nodes, layout, input_reorder) self.alpha = alpha self.beta = beta + self.can_fuse_epilogue = can_fuse_epilogue + + @staticmethod + def add_cutlass_gemm_choices( + choices, + layout, + input_nodes, + alpha=1, + beta=0, + input_reorder=None, + fuseable=True, + non_fuseable=True, + ): + if non_fuseable: + if fuseable: + # list both fuseable and non-fuseable ops, and treat them all as non-fuseable + can_fuse_epilogue = False + else: + can_fuse_epilogue = None + + cutlass_template = CUTLASSGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + can_fuse_epilogue=can_fuse_epilogue, + ) + ops = cutlass_template.gen_ops() + for op in ops: + cutlass_template.maybe_append_choice( + choices, + op=op, + ) + else: + ops = [] + if fuseable: + cutlass_template_evt = CUTLASSGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + can_fuse_epilogue=True, + ) + # This will list only ops capable of EVT fusion + ops_evt = cutlass_template_evt.gen_ops() + for op in ops_evt: + cutlass_template_evt.maybe_append_choice( + choices, + op=op, + ) + else: + ops_evt = [] + log.debug( + "Added %d cutlass gemm configs and %d fuseable gemm configs.", + len(ops), + len(ops_evt), + ) def header(self) -> IndentedBuffer: res = super().header() @@ -175,6 +252,13 @@ def header(self) -> IndentedBuffer: #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/default_epilogue.hpp" + #include "cutlass/epilogue/thread/linear_combination.h" + #include "cutlass/gemm/dispatch_policy.hpp" + #include "cutlass/gemm/kernel/tile_scheduler.hpp" + #include "cutlass/util/distribution.h" + #include "cutlass/util/packed_stride.hpp" + #include "cutlass/util/tensor_view_io.h" """ ) return res @@ -228,19 +312,67 @@ def has_tma_epilogue(op) -> bool: return result @staticmethod + def supports_evt(op: "cutlass_library.gemm_op.GemmOperation") -> bool: # type: ignore[name-defined] + """ + returns True if the op is capable of flexible epilogue fusions + using epilogue visitor trees. 
+ + See https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu#L283-L285 # noqa: B950 + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib # type: ignore[import] + + if op.gemm_kind != cutlass_lib.GemmKind.Universal3x: + return False + if op.epilogue_schedule not in ( + cutlass_lib.EpilogueScheduleType.TmaWarpSpecialized, + cutlass_lib.EpilogueScheduleType.TmaWarpSpecializedCooperative, + ): + return False + + return True + + def render_evt_epilogue_declaration( + self, + template_output_node_name: str, + evt_type_name: str, + epilogue_nodes: List[IRNode], + ) -> str: + """Generates the epilogue for the EVT epilogue fusion""" + return CutlassEVTEpilogueTypeFormatter.ir_to_evt_string( + template_output_node_name, evt_type_name, epilogue_nodes + ) + def define_gemm_instance( - op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined] + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] + output_buffer_name: str, + epilogue_nodes: Optional[List[IRNode]] = None, ) -> Tuple[str, str]: assert cutlass_utils.try_import_cutlass() import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import] import cutlass_library.library as cutlass_lib # type: ignore[import] + from torch._inductor.codegen.cuda.cutlass_lib_extensions.gemm_operation_extensions import ( + EmitGemmUniversal3xInstanceWithEVT, + ) + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: - emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance() + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + emitter = EmitGemmUniversal3xInstanceWithEVT() + op.epilogue_functor = lambda epilogue_functor_type_name: self.render_evt_epilogue_declaration( + output_buffer_name, epilogue_functor_type_name, epilogue_nodes + ) + else: + emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance() op_def = emitter.emit(op) pattern = re.compile(r"\s*struct\s(.*?)\s:") decl = [line for line in op_def.split("\n") if "struct " in line][-1] else: + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + raise RuntimeError( + "EVT epilogue fusion is not supported for Cutlass 2.x ops." + ) emitter = cutlass_gemm_op.EmitGemmInstance() op_def = emitter.emit(op) op_def = op_def.replace( @@ -275,7 +407,9 @@ def should_swap_XW( # return False @staticmethod - def swap_XW(op: "cutlass_gemm_op.GemmOperation") -> "cutlass_gemm_op.GemmOperation": # type: ignore[name-defined] + def swap_XW( + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # Swap X and W in GemmOperation. 
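        # Rationale: Y = X @ W implies Y^T = W^T @ X^T, so exchanging the A/B operands and
        # flipping every operand layout (flip_cutlass_layout) makes the op compute the
        # transposed problem. With rows and columns swapped, the Bias vector is presented to
        # the kernel as column-major, which is what the TMA epilogue requires for best
        # performance (see should_swap_XW and the corresponding check in render below).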
new_op = copy.deepcopy(op) new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout) @@ -287,8 +421,8 @@ def swap_XW(op: "cutlass_gemm_op.GemmOperation") -> "cutlass_gemm_op.GemmOperati def filter_op( self, - op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined] - ) -> "cutlass_gemm_op.GemmOperation": # type: ignore[name-defined] + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] assert cutlass_utils.try_import_cutlass() import cutlass_library.library as cutlass_lib # type: ignore[import] @@ -365,7 +499,13 @@ def filter_op( op.C.element = cutlass_lib.DataType.void else: op.C.layout = op.D.layout - + supports_evt: bool = self.supports_evt(op) + if (self.can_fuse_epilogue is not None) and ( + self.can_fuse_epilogue != supports_evt + ): + return None + if inductor_cuda_config.cutlass_only_evt_capable_ops and not supports_evt: + return None return op def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined] @@ -380,6 +520,7 @@ def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name for op_dict in ops.values(): for op_list in op_dict.values(): for op in op_list: + assert isinstance(op, cutlass_gemm_op.GemmOperation) filter_res = self.filter_op(op) if ( filter_res is not None @@ -419,6 +560,7 @@ def render_gemm_arguments( alpha: float, beta: float, kernel: CUDATemplateKernel, + epilogue_args, ) -> str: options = dict( alpha=self.alpha, @@ -431,6 +573,7 @@ def render_gemm_arguments( kernel=kernel, M="M", N="N", + epilogue_args=epilogue_args, ) if epilogue_template is not None: @@ -476,36 +619,72 @@ def clone_with_transposed_stride(node: IRNode) -> IRNode: def render( # type: ignore[override] self, kernel: CUDATemplateKernel, - op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined] - output_node: Optional[Buffer] = None, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] + template_buffer_node: Optional[CUDATemplateBuffer] = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, ) -> str: + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + assert self.can_fuse_epilogue and CUTLASSGemmTemplate.supports_evt( + op + ), "op does not support EVT epilogue fusion" + assert ( + template_buffer_node is not None + ), "Template node is required for epilogue fusion" + assert isinstance( + template_buffer_node, CUDATemplateBuffer + ), f"Template node has to be a CUDATemplateBuffer, is type {type(template_buffer_node)}" + assert ( + template_buffer_node.name is not None + ), "Output node has to be a Buffer with a name" + # This is the name of the output of the Matmul, before epilogues are applied. 
+        # it is not necessarily materialized in global memory if we have an epilogue
+
+        template_output_node_name = (
+            template_buffer_node.name if template_buffer_node is not None else None
+        )
+
         assert cutlass_utils.try_import_cutlass()
+        import cutlass_library.gemm_operation as cutlass_gemm_op  # type: ignore[import]
         import cutlass_library.library as cutlass_lib  # type: ignore[import]
 
-        if output_node is not None:
-            self.output_node = output_node
+        assert isinstance(
+            op, cutlass_gemm_op.GemmOperation
+        ), "op argument is required and has to be an instance of GemmOperation"
+        if template_buffer_node is not None:
+            self.output_node = template_buffer_node
+        if epilogue_nodes is not None and len(epilogue_nodes) > 0:
+            self.output_node = cast(Buffer, epilogue_nodes[-1])
+
         assert len(self.input_nodes) >= 2 and self.output_node is not None
         X, W = self.input_nodes[0], self.input_nodes[1]
         Y = self.output_node
         Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2]
 
         epilogue_template: Optional[str] = None
-        argument_template: Optional[str] = None
         should_swap_xw: bool = False
-
+        epilogue_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}"
         if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
             if Bias is not None and self.has_tma_epilogue(op):
                 if self.should_swap_XW(Bias, self.beta):
                     # TMA epilogue requires bias vector in column major to get best perf.
                     op = self.swap_XW(op)
                     should_swap_xw = True
+            if epilogue_nodes is not None and len(epilogue_nodes) > 0:
+                epilogue_args = (
+                    CutlassEVTEpilogueArgumentFormatter.ir_to_evt_argument_string(
+                        cast(str, template_output_node_name), epilogue_nodes
+                    )
+                )
             epilogue_template = GEMM_ARGS_CUTLASS_3X_EPILOGUE
             argument_template = GEMM_ARGS_CUTLASS_3X
         else:
             # TODO: Support split_k.
             argument_template = GEMM_ARGS_CUTLASS_2X
 
-        instance_definition, instance_type = self.define_gemm_instance(op)
+        instance_definition, instance_type = self.define_gemm_instance(
+            op, cast(str, template_output_node_name), epilogue_nodes
+        )
         options = dict(
             alpha=self.alpha,
             beta=self.beta,
@@ -521,6 +700,7 @@ def render(  # type: ignore[override]
             instance_definition=instance_definition,
             instance_type=instance_type,
             input_reorder=self.input_reorder,
+            epilogue_args=epilogue_args,
         )
         res = self._template_from_string(GEMM_TEMPLATE).render(**options)
         return res
diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py
new file mode 100644
index 0000000000000..6d04983ef4696
--- /dev/null
+++ b/torch/_inductor/codegen/cuda_combined_scheduling.py
@@ -0,0 +1,75 @@
+from typing import List
+
+from ..scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode
+from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
+
+from .triton import TritonScheduling
+
+
+class CUDACombinedScheduling(BaseScheduling):
+    """
+    Scheduler for CUDA Kernels, which delegates calls as appropriate
+    to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices
+    and use a unified-wrapper for codegen.
+
+    If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code,
+    this would also be the place to do it.
+ """ + + def __init__(self, scheduler: Scheduler): + super().__init__() + self._scheduler = scheduler + self._triton_scheduling = TritonScheduling(scheduler) + self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._cuda_cpp_scheduling.is_cuda_cpp_template( + node + ) or self._cuda_cpp_scheduling.is_cuda_cpp_fused_template(node): + return self._cuda_cpp_scheduling + return self._triton_scheduling + + def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2): + return True + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode): + for node in (node1, node2): + if self._cuda_cpp_scheduling.is_cuda_cpp_template( + node + ) or self._cuda_cpp_scheduling.is_cuda_cpp_fused_template(node): + return self._cuda_cpp_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn(self, sizes): + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node): + return self._cuda_cpp_scheduling.codegen_template( + template_node, epilogue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes + ) + + def codegen_nodes(self, nodes: List[BaseSchedulerNode]): + return self._triton_scheduling.codegen_nodes(nodes) + + def codegen_sync(self): + return self._triton_scheduling.codegen_sync() + + def flush(self): + return self._triton_scheduling.flush() + + def codegen_foreach(self, *args, **kwargs): + return self._triton_scheduling.codegen_foreach(*args, **kwargs) + + def benchmark_fused_nodes(self, nodes): + return self._triton_scheduling.benchmark_fused_nodes(nodes) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 5d60fbb277eda..1e94c62630a24 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1,4 +1,4 @@ -import os +import os # noqa: C101 import sys import torch @@ -559,6 +559,11 @@ class cuda: # 4) default system search PATH. cuda_cxx = None + # If set to True, it will ensure that only GEMM ops capable of + # epilogue fusion via CUTLASS Epilogue Visitor Trees ( EVT ) + # are enabled for the CUTLASS backend. 
+    cutlass_only_evt_capable_ops: bool = False
+
 
 # create a directory containing lots of debug information
 class trace:
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index aa9e6306a1130..91b4d08676dc9 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -156,9 +156,10 @@ def init_backend_registration(self):
             register_backend_for_device("cpu", CppScheduling, WrapperCodeGen)
 
         if get_scheduling_for_device("cuda") is None:
-            from .codegen.triton import TritonScheduling
+            from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
 
-            register_backend_for_device("cuda", TritonScheduling, WrapperCodeGen)
+            # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
+            register_backend_for_device("cuda", CUDACombinedScheduling, WrapperCodeGen)
 
     def __init__(
         self,
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index f259e0e50ad6c..6b48055a9ea8e 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -2804,6 +2804,17 @@ def codegen_reference(self, writer=None):
 class ComputedBuffer(Buffer):
     data: Loops
 
+    def get_computed_buffer_name(self):
+        """
+        Returns self.name if it exists, otherwise returns the name of the data node if that exists.
+        If neither exist, returns None.
+        """
+        if self.name is not None:
+            return self.name
+        if hasattr(self.data, "name"):
+            return self.data.name
+        return None
+
     @cache_on_self
     def num_reads(self):
         return len(self.get_read_writes().reads)
@@ -3129,11 +3140,13 @@ def __init__(
         layout,
         inputs,
         make_kernel_render,
-        workspace_size: int = 0,
+        workspace_size: int,
+        template: "CUDATemplate",  # type: ignore[name-defined]
     ):
         super().__init__(layout, inputs, make_kernel_render)
         # Global memory (in bytes) needed for this template.
         self.workspace_size = workspace_size
+        self.template = template
 
     def get_workspace_size(self):
         return self.workspace_size if self.workspace_size is not None else 0
diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py
index ed3ec98334e92..57c21fbc9df76 100644
--- a/torch/_inductor/kernel/mm.py
+++ b/torch/_inductor/kernel/mm.py
@@ -134,14 +134,9 @@ def tuned_mm(mat1, mat2, *, layout=None):
         )
 
     if m * n != 0 and use_cutlass_template(layout):
-        cutlass_template = CUTLASSGemmTemplate([mat1, mat2], layout, alpha=1, beta=0)
-        ops = cutlass_template.gen_ops()
-        for op in ops:
-            cutlass_template.maybe_append_choice(
-                choices,
-                op=op,
-            )
-        log.debug("Added %d cutlass gemm configs.", len(ops))
+        CUTLASSGemmTemplate.add_cutlass_gemm_choices(
+            choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True
+        )
 
     from torch._inductor.ir import FixedLayout, FlexibleLayout
 
@@ -242,20 +237,15 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
         )
 
     if use_cutlass_template(layout):
-        cutlass_template = CUTLASSGemmTemplate(
-            [mat1, mat2, inp_expanded],
+        CUTLASSGemmTemplate.add_cutlass_gemm_choices(
+            choices,
             layout,
+            [mat1, mat2, inp_expanded],
             alpha=alpha,
             beta=beta,
             input_reorder=[2, 0, 1],
+            fuseable=False,
         )
-        ops = cutlass_template.gen_ops()
-        for op in ops:
-            cutlass_template.maybe_append_choice(
-                choices,
-                op=op,
-            )
-        log.debug("Added %d cutlass gemm configs.", len(ops))
 
     return autotune_select_algorithm(
         "addmm", choices, [inp_expanded, mat1, mat2], layout
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index fcb1b0619aaa4..4598d84577a3d 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -837,7 +837,7 @@ def used_buffer_names(self) -> Set[str]:
     def used_or_aliased_buffer_names(self) -> Set[str]:
         return set.union(*[x.used_or_aliased_buffer_names() for x in self.snodes])
 
-    def get_nodes(self) -> Sequence[BaseSchedulerNode]:
+    def get_nodes(self) -> List[SchedulerNode]:
         return self.snodes
 
     def __repr__(self):
@@ -851,6 +851,13 @@ def is_reduction(self):
     def is_template(self):
         return any(x.is_template() for x in self.snodes)
 
+    @cache_on_self
+    def get_template_node(self):
+        for node in self.snodes:
+            if node.is_template():
+                return node
+        return None
+
     def get_device(self):
         return self.group[0]
 
@@ -2163,12 +2170,7 @@ def codegen(self):
 
             if node.is_template():
                 node, *epilogue = node.get_nodes()
-                if isinstance(node.node, ir.CUDATemplateBuffer):
-                    from .codegen.cuda.cuda_scheduling import CUDAScheduling
-
-                    CUDAScheduling(self).codegen_template(node, epilogue)
-                else:
-                    self.get_backend(device).codegen_template(node, epilogue)
+                self.get_backend(device).codegen_template(node, epilogue)
             elif node.is_extern():
                 self.codegen_extern_call(node)
             elif node.is_foreach():
@@ -2176,6 +2178,7 @@ def codegen(self):
             elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
                 self.get_backend(device).codegen_nodes(node.get_nodes())
             else:
+                assert isinstance(node, NopKernelSchedulerNode)
                 node.allocate()
 
         if config.debug_check_inf_and_nan:
@@ -2220,7 +2223,7 @@ def group_fn(self, sizes):
         raise NotImplementedError()
 
     def codegen_template(
-        self, template_node: BaseSchedulerNode, epilogue_nodes: List[BaseSchedulerNode]
+        self, template_node: SchedulerNode, epilogue_nodes: List[SchedulerNode]
     ):
         """
         Given a template node, generate a kernel.
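
Note on the semantics being fused: the default epilogue arguments rendered above,
"{ElementComputeEpilogue(alpha), ElementComputeEpilogue(beta)}", correspond to the
standard CUTLASS linear-combination epilogue D = alpha * (A @ B) + beta * C. With EVT
fusion, the pointwise epilogue nodes that the scheduler hands to codegen_template are
applied to that result inside the same CUTLASS kernel instead of in a follow-up Triton
kernel. A minimal eager-mode sketch of the fused computation (the function name and the
choice of ReLU as the fused pointwise op are illustrative assumptions, not part of the
patch):

from typing import Optional

import torch


def gemm_with_fused_epilogue(
    a: torch.Tensor,
    b: torch.Tensor,
    c: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    beta: float = 0.0,
) -> torch.Tensor:
    # Linear-combination epilogue encoded by
    # "{ElementComputeEpilogue(alpha), ElementComputeEpilogue(beta)}":
    #     D = alpha * (A @ B) + beta * C
    acc = a @ b
    d = alpha * acc if c is None else alpha * acc + beta * c
    # With EVT fusion, pointwise epilogue nodes (ReLU here, purely as an
    # example) are evaluated on the accumulator inside the same CUDA kernel
    # rather than in a separate Triton kernel.
    return torch.relu(d)

In the eager reference the matmul and the pointwise step are separate; the point of the
EVT path is that filter_op only keeps ops whose supports_evt(op) result matches the
requested fusion mode, so the epilogue can be emitted into the generated CUTLASS kernel
itself.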