From 453bca33106e3d119ed0edb7dd8b3642bc931215 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Wed, 17 Apr 2024 06:02:16 -0700 Subject: [PATCH] [inductor][cpp] GEMM template ghstack-source-id: 30af30579007f953540b54d0d2811a9ae6868234 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021 --- test/inductor/test_cpu_select_algorithm.py | 71 +++++++ torch/_inductor/autotune_process.py | 133 ++++++++++--- torch/_inductor/codecache.py | 3 - torch/_inductor/codegen/cpp.py | 49 ++++- torch/_inductor/codegen/cpp_gemm_template.py | 177 ++++++++++++++++++ torch/_inductor/codegen/cpp_template.py | 114 +++++++++++ .../_inductor/codegen/cpp_template_kernel.py | 152 +++++++++++++++ torch/_inductor/ir.py | 22 ++- torch/_inductor/kernel/mm.py | 19 ++ torch/_inductor/mkldnn_lowerings.py | 64 ++++++- torch/_inductor/select_algorithm.py | 121 +++++++++--- torch/_inductor/utils.py | 20 ++ 12 files changed, 888 insertions(+), 57 deletions(-) create mode 100644 test/inductor/test_cpu_select_algorithm.py create mode 100644 torch/_inductor/codegen/cpp_gemm_template.py create mode 100644 torch/_inductor/codegen/cpp_template.py create mode 100644 torch/_inductor/codegen/cpp_template_kernel.py diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py new file mode 100644 index 0000000000000..706484fee8c41 --- /dev/null +++ b/test/inductor/test_cpu_select_algorithm.py @@ -0,0 +1,71 @@ +# Owner(s): ["oncall: cpu inductor"] +import functools +import unittest +from unittest.mock import patch + +import torch +import torch._dynamo.config as dynamo_config +import torch._inductor.config as inductor_config +import torch._inductor.select_algorithm as select_algorithm +from torch._dynamo.utils import counters +from torch._inductor.test_case import run_tests, TestCase + +from torch.testing._internal.common_utils import IS_MACOS, TEST_MKL + +aten = torch.ops.aten + + +def patches(fn): + def skip_cache(self, choices, name, key, generate): + return generate(choices) + + for patcher in [ + dynamo_config.patch(verbose=True), + inductor_config.patch(debug=True, max_autotune=True, epilogue_fusion=True), + patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)), + patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache), + ]: + fn = patcher(fn) + + @functools.wraps(fn) + def wrapped(*args, **kwargs): + counters.clear() + torch.manual_seed(12345) + return fn(*args, **kwargs) + + return wrapped + + +class TestSelectAlgorithm(TestCase): + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + def test_linear_fp32_cpu(self): + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.linear = torch.nn.Linear(10, 32, bias) + + @torch.compile + def forward(self, x): + return self.linear(x) + + for bias in [True, False]: + counters.clear() + mod = M(bias=bias).eval() + v = torch.randn(2, 10) + mod(v) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + + +@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) +class TestDynamicSelectAlgorithm(TestCase): + test_linear_fp32_dynamic_shapes_cpu = TestSelectAlgorithm.test_linear_fp32_cpu + + +if __name__ == "__main__": + from torch.testing._internal.inductor_utils import HAS_CPU + + if HAS_CPU and not IS_MACOS: + run_tests() diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 790ec9d60ec0f..e7e66062abef9 100644 --- 
a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -9,9 +9,10 @@ import time import warnings from concurrent.futures import ThreadPoolExecutor -from ctypes import byref, c_size_t, c_void_p +from ctypes import byref, c_size_t, c_void_p, CDLL from multiprocessing.process import BaseProcess from multiprocessing.queues import Queue +from types import ModuleType from typing import ( Any, Callable, @@ -29,13 +30,19 @@ from torch._dynamo.testing import rand_strided from torch._inductor import ir -from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache +from torch._inductor.codecache import ( + CppCodeCache, + CUDACodeCache, + DLLWrapper, + get_hash, + PyCodeCache, +) if TYPE_CHECKING: from torch._inductor.select_algorithm import TritonTemplateCaller from . import config -from .utils import do_bench +from .utils import do_bench, timed from .virtualized import V CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" @@ -427,6 +434,14 @@ def make_run_fn( def cleanup_run_fn(self) -> None: pass + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + raise NotImplementedError() + def benchmark( self, *input_tensors: torch.Tensor, @@ -452,22 +467,7 @@ def benchmark( load_elapse = time.time() - start_ts # type: ignore[possibly-undefined] start_ts = time.time() - device_idx_set = { - tensor.device.index - for tensor in [*input_tensors, output_tensor] - if isinstance(tensor, torch.Tensor) - and tensor.is_cuda - and tensor.device.index is not None - } - assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" - if len(device_idx_set) == 1: - device_idx = next(iter(device_idx_set)) - else: - device_idx = torch.cuda.current_device() - - with torch.cuda.device(device_idx): - out = do_bench(fn) - torch.cuda.synchronize() # shake out any CUDA errors + out = self.do_bench(fn, *input_tensors, output_tensor) if debug: bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined] @@ -499,7 +499,34 @@ def benchmark( return self.value -class TritonBenchmarkRequest(BenchmarkRequest): +class GPUDeviceBenchmarkRequest(BenchmarkRequest): + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + device_idx_set = { + tensor.device.index + for tensor in [*input_tensors, output_tensor] + if isinstance(tensor, torch.Tensor) + and tensor.is_cuda + and tensor.device.index is not None + } + assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" + if len(device_idx_set) == 1: + device_idx = next(iter(device_idx_set)) + else: + device_idx = torch.cuda.current_device() + + with torch.cuda.device(device_idx): + out = do_bench(fn) + torch.cuda.synchronize() # shake out any CUDA errors + + return out + + +class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! @@ -573,7 +600,7 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" -class CUDABenchmarkRequest(BenchmarkRequest): +class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! 
@@ -661,6 +688,70 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" +class CPUDeviceBenchmarkRequest(BenchmarkRequest): + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + return timed(fn, ()) + + +@dataclasses.dataclass +class CppBenchmarkRequest(CPUDeviceBenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ): + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.hash_key = get_hash(source_code) + self.DLL: Optional[Union[CDLL, ModuleType]] = None + + def precompile(self): + # Prepopulate CppCodeCache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + CppCodeCache.load(self.source_code, cuda=False) + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + self.DLL = CppCodeCache.load(self.source_code, cuda=False) + args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]] + log.debug( + "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.DLL, + args, + self.extra_args, + ) + run_method = getattr(self.DLL, self.kernel_name) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *self.extra_args, + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + + def __str__(self) -> str: + return f"{self.kernel_name=}" + + def benchmark_in_sub_process( choices: List[TritonTemplateCaller], ) -> Dict[TritonTemplateCaller, float]: diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index d55b81e16d157..729d7bb8713c3 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -183,9 +183,6 @@ def get_global_cache_path() -> Optional[Path]: ) def __init__(self) -> None: - if not torch.cuda.is_available(): - return - self.system = CacheBase.get_system() self.local_cache_path = CacheBase.get_local_cache_path() diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 31dbe27c229ec..8f485bc3a5bf0 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -8,7 +8,7 @@ import sys from copy import copy, deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union import sympy @@ -19,6 +19,7 @@ from torch.utils import _pytree as pytree from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges +from ..._dynamo.utils import counters from .. 
import codecache, config, ir, metrics from ..codegen.wrapper import WrapperCodeGen @@ -3729,6 +3730,8 @@ def _can_fuse_horizontal_impl(self, node1, node2): return self._why_fuse_nodes(node1, node2) is not None def can_fuse_horizontal(self, node1, node2): + if node1.is_template() or node2.is_template(): + return False if ( len(node1.get_nodes()) + len(node2.get_nodes()) > config.cpp.max_horizontal_fusion_size @@ -3809,6 +3812,9 @@ def get_fusion_pair_priority(self, node1, node2): return 0 def can_fuse_vertical(self, node1, node2): + # TODO: support vertical fusion for template nodes + if node1.is_template() or node2.is_template(): + return False return ( self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) @@ -3865,6 +3871,42 @@ def codegen_node( if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: self._set_flush_status(True) + def is_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ir.CppTemplateBuffer + ) + + def codegen_template( + self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + """ + Codegen a CPP template, possibly with fused epilogues + """ + counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cpp_template( + template_node + ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (_, rnumel) = template_node.group + assert rnumel == () + ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) + epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes] + assert all( + isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes + ), "Epilogue nodes must all be instances of ir.ComputedBuffer" + kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() + def _get_scheduled_num_args(self): return self.kernel_group.get_num_args() @@ -3874,7 +3916,7 @@ def ready_to_flush(self): def codegen_sync(self): pass - def define_kernel(self, src_code, nodes): + def define_kernel(self, src_code, nodes, kernel_args=None): wrapper = V.graph.wrapper_code fused_name = ( get_fused_kernel_name(nodes, config.cpp.descriptive_names) @@ -3890,7 +3932,8 @@ def define_kernel(self, src_code, nodes): src_code = src_code.replace("#pragma CMT", "//") compile_wrapper = IndentedBuffer() - _, _, arg_types = self.kernel_group.args.cpp_argdefs() + args = self.kernel_group.args if kernel_args is None else kernel_args + _, _, arg_types = args.cpp_argdefs() if not V.graph.cpp_wrapper: compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") compile_wrapper.splice(src_code, strip=True) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py new file mode 100644 index 0000000000000..43fbbe2336ac3 --- /dev/null +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -0,0 +1,177 @@ +from typing import cast, List, Optional + +import torch +from torch._inductor.select_algorithm import 
DataProcessorTemplateWrapper +from .. import ir + +from ..ir import Buffer, CppTemplateBuffer, IRNode, Layout +from ..lowering import permute, view +from .cpp_template import CppTemplate + +from .cpp_template_kernel import CppTemplateKernel + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} +// TODO: use micro-kernel to replace this naive GEMM implementation below +extern "C" +{{kernel.def_kernel(inputs=[X, W, inp], outputs=[Y], names_str="X, W, inp, Y")}} +{ + // TODO: support dynamic shapes + int64_t M = {{kernel.size(Y, 0)}}; + int64_t N = {{kernel.size(Y, 1)}}; + int64_t K = {{kernel.size(X, 1)}}; + + #pragma omp parallel for collapse(2) + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N/{{n_bs}}; ++j) { + {{kernel.acc_dtype(Y)}} sum[16]; + for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { + {% if inp is none %} + sum[ni] = 0; + {% else %} + int64_t n = j * {{n_bs}} + ni; + sum[ni] = {{beta}} * {{kernel.index(inp, ["i", "n"])}}; + {% endif %} + } + for (int64_t k = 0; k < K; ++k) { + for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { + sum[ni] += {{kernel.index(X, ["i", "k"])}} * {{kernel.index(W, ["j", "k", "ni"])}}; + } + } + for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { + int64_t n = j * {{n_bs}} + ni; + {{kernel.index(Y, ["i", "n"])}} = {{alpha}} * sum[ni]; + } + } + } +} +""" + + +class CppPackedGemmTemplate(CppTemplate): + def __init__( + self, + input_nodes, + layout: Layout, + beta=1, + alpha=1, + n_block_size: int = 1, + ): + super().__init__("cpp_gemm", input_nodes, layout) + self.beta = beta + self.alpha = alpha + self.n_block_size = n_block_size + + @staticmethod + def add_choices( + choices, layout, input_nodes, beta=1, alpha=1, trans_w=False, input_indices=None + ): + if input_indices is None: + input_indices = list(range(len(input_nodes))) + + def reorder_and_filter(inputs, layout_or_out): + if len(input_indices) == 2: + x_idx = input_indices[0] + w_idx = input_indices[1] + return [inputs[x_idx], inputs[w_idx]], layout_or_out + else: + assert ( + len(input_indices) == 3 + ), "Cpp Packed GEMM template requires 2 or 3 input nodes." + # assume the input order is [inp, x, w] and we reorder it to [x, w, inp] + inp_idx = input_indices[0] + x_idx = input_indices[1] + w_idx = input_indices[2] + return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out + + def transpose_weight(inputs, layout_or_out): + if not trans_w: + return inputs, layout_or_out + + new_inputs = list(inputs) + W = inputs[1] + if isinstance(W, ir.IRNode): + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + new_inputs[1] = permute(W, [1, 0]) + return new_inputs, layout_or_out + else: + assert isinstance(W, torch.Tensor) + new_inputs[1] = W.transpose(0, 1) + return new_inputs, layout_or_out + + n_block_size = 16 + + def pack_weight(inputs, layout_or_out): + W = inputs[1] + new_inputs = list(inputs) + if isinstance(W, ir.IRNode): + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + k, n = W.get_size() + assert ( + n % n_block_size == 0 + ), f"The last dimension of W must be a multiple of {n_block_size}." 
+ blocked_w = permute( + view(W, (k, n // n_block_size, n_block_size)), + [1, 0, 2], + ) + blocked_w = ir.ExternKernel.require_contiguous(blocked_w) + blocked_w = ir.ExternKernel.realize_input(blocked_w) + else: + k, n = list(W.shape) + blocked_w = ( + W.reshape(k, n // n_block_size, n_block_size) + .transpose(0, 1) + .contiguous() + ) + new_inputs[1] = blocked_w + return new_inputs, layout_or_out + + def preprocessor(inputs, layout): + return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout))) + + template = DataProcessorTemplateWrapper( + CppPackedGemmTemplate, + preprocessor, + None, + input_nodes=input_nodes, + layout=layout, + n_block_size=n_block_size, + ) + template.maybe_append_choice(choices) + return template + + def render( # type: ignore[override] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[CppTemplateBuffer] = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, + ) -> str: + assert not epilogue_nodes, "Epilogue nodes are not supported for GEMM template." + assert len(self.input_nodes) >= 2 + + if template_buffer_node is not None: + self.output_node = template_buffer_node + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + self.output_node = cast(Buffer, epilogue_nodes[-1]) + assert self.output_node is not None + + X, W = self.input_nodes[0], self.input_nodes[1] + inp = self.input_nodes[2] if len(self.input_nodes) > 2 else None + Y = self.output_node + + options = dict( + X=X, + W=W, + inp=inp, + Y=Y, + beta=self.beta, + alpha=self.alpha, + n_bs=self.n_block_size, + template=self, + kernel=kernel, + epilogues=epilogue_nodes, + ) + return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py new file mode 100644 index 0000000000000..99ec7305f8324 --- /dev/null +++ b/torch/_inductor/codegen/cpp_template.py @@ -0,0 +1,114 @@ +import functools +import itertools +import logging +from typing import List, Optional +from unittest.mock import patch + +import sympy + +from .. 
import codecache +from ..autotune_process import CppBenchmarkRequest, TensorMeta +from ..ir import Buffer, IRNode, Layout +from ..utils import IndentedBuffer, Placeholder, unique +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateBuffer, CppTemplateCaller, CppTemplateKernel + +log = logging.getLogger(__name__) + + +class CppTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes, + layout: Layout, + ): + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer("buf_out", layout) + self.layout = layout + + def generate(self, **kwargs): + kernel_name = f"cpp_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), CppTemplateKernel( + kernel_name=kernel_name, + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + expected_args = list( + unique(input_node.get_name() for input_node in self.input_nodes) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + + kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}" + + # Create the BenchmarkRequest for CPP + bmreq = CppBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: CppTemplateBuffer, + epilogue_nodes: Optional[List[IRNode]] = None, + ): + kernel = CppTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + return kernel, render + + return CppTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + res.writeline(codecache.cpp_prefix()) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py new file mode 100644 index 0000000000000..24323a59e359d --- /dev/null +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -0,0 +1,152 @@ +from typing import Callable, Dict, List, Optional, Union + +import sympy + +import torch + +from torch._inductor.autotune_process import CppBenchmarkRequest +from ..ir import ( + Buffer, + ChoiceCaller, + CppTemplateBuffer, + IRNode, + Layout, + PrimitiveInfoType, + TensorBox, +) +from ..virtualized import V +from .common import Kernel, OpOverrides +from .cpp import cexpr_index + + +class CppTemplateKernel(Kernel): + overrides = OpOverrides + + def __init__(self, kernel_name): + super().__init__() + self.kernel_name = kernel_name + + def def_kernel( + self, + inputs: List[Buffer], + outputs: List[Buffer], + names_str: str = "", + input_reorder: 
Optional[List[int]] = None, + ) -> str: + input_names = [inp.get_name() if inp is not None else None for inp in inputs] + output_names = [out.get_name() for out in outputs] + all_names = input_names + output_names + assert len(all_names) == len(names_str.split(",")), ( + all_names, + names_str, + ) + names = names_str.split(",") + for i, input_name in enumerate(input_names): + if input_name is not None: + self.args.input_buffers[input_name] = names[i].strip() + for i, output_name in enumerate(output_names): + self.args.output_buffers[output_name] = names[i + len(input_names)].strip() + cpp_argdefs, _, _ = self.args.cpp_argdefs() + return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" + + def call_kernel(self, name: str, node: CppTemplateBuffer): + wrapper = V.graph.wrapper_code + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) + + def dtype(self, node: Buffer) -> str: + if node.get_dtype() == torch.float32: + return "float" + elif node.get_dtype() == torch.bfloat16: + return "float" + elif node.get_dtype() == torch.half: + return "float" + else: + raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + + def acc_dtype(self, node: Buffer) -> str: + if node.get_dtype() == torch.float32: + return "float" + elif node.get_dtype() == torch.bfloat16: + return "float" + elif node.get_dtype() == torch.half: + return "float" + else: + raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + + def size(self, node: Buffer, dim: int) -> str: + return str(node.get_size()[dim]) + + def index(self, node: Buffer, indices: List[str]) -> str: + indexer = node.make_indexer() + index = indexer([sympy.Symbol(idx) for idx in indices]) + return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" + + +class CppTemplateCaller(ChoiceCaller): + """ + CppTemplateCaller + + This class represents a caller for CPP template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CppBenchmarkRequest): The benchmark request for the caller. + template_buffer (CppTemplateBuffer): The template buffer for the caller. 
+ """ + + def __init__( + self, + name: str, + category: str, + input_nodes: List[Buffer], + layout: Layout, + make_kernel_render: Callable[[CppTemplateBuffer, Optional[List[IRNode]]], str], + bmreq: CppBenchmarkRequest, + template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 + info_kwargs: Optional[ + Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] + ] = None, + ): + super().__init__(name, input_nodes, layout) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + # def __str__(self): + # return f"CppTemplateCaller(source_file={self.bmreq.source_file})" + + # def call_name(self) -> str: + # return f"cpp_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + return {"backend": "CPP", "op_type": "unknown"} + + def output_node(self) -> TensorBox: + return TensorBox.create( + CppTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + template=self.template, + choice=self, + ) + ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0fd35193a877c..e094d23725f7f 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -68,6 +68,7 @@ developer_warning, do_bench, get_kernel_metadata, + is_cpu_device, is_dynamic, is_gpu, pad_listlike, @@ -75,6 +76,7 @@ sympy_index_symbol, sympy_product, sympy_subs, + timed, ) from .virtualized import ops, V @@ -3531,7 +3533,10 @@ def __init__(self, name, input_nodes, layout): def benchmark(self, *args, out) -> float: algo = self.to_callable() - return do_bench(lambda: algo(*args, out=out)) + if is_cpu_device(args): + return timed(lambda: algo(*args, out=out), ()) + else: + return do_bench(lambda: algo(*args, out=out)) def call_name(self) -> str: raise NotImplementedError() @@ -3622,6 +3627,13 @@ def get_workspace_size(self): return self.workspace_size if self.workspace_size is not None else 0 +class CppTemplateBuffer(TemplateBuffer): + def __init__(self, layout, inputs, make_kernel_render, template, choice): + super().__init__(layout, inputs, make_kernel_render) + self.template = template + self.choice = choice + + @dataclasses.dataclass class InputsKernel(Buffer): inputs: List[Buffer] @@ -6015,7 +6027,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, packed_w, orig_w, batch_size): + def create(cls, x, packed_w, orig_w, B, batch_size): x = cls.require_stride1(cls.realize_input(x)) orig_w = cls.require_stride1(cls.realize_input(orig_w)) *m, _ = x.get_size() @@ -6023,7 +6035,11 @@ def create(cls, x, packed_w, orig_w, batch_size): output_size = list(m) + [oc] output_stride = make_contiguous_strides_for(output_size) inputs = [x, packed_w, orig_w] - constant_args = [None, batch_size] + constant_args = [batch_size] + if B is not None: + inputs += [B] + else: + constant_args.insert(0, None) return MKLPackedLinear( layout=FixedLayout( diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 2cb78e0c45c92..c4e5fc6573e9c 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -3,6 +3,7 @@ from typing import Any, Dict, 
List, Optional import torch +from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate from torch._inductor.virtualized import V from .. import config as inductor_config from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate @@ -14,6 +15,7 @@ ) from ..utils import ( use_aten_gemm_kernels, + use_cpp_packed_gemm_template, use_cutlass_template, use_max_autotune, use_triton_template, @@ -256,6 +258,23 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): fuseable=False, ) + if use_cpp_packed_gemm_template(layout, mat1, mat2): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [inp_expanded, mat1, mat2], + alpha=alpha, + beta=beta, + ) + choices.append( + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ) + return autotune_select_algorithm( "addmm", choices, [inp_expanded, mat1, mat2], layout ) diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 0ebccbf27ea3b..0a6ca89287849 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,10 +1,22 @@ -from typing import List +from typing import List, Optional import torch import torch.utils._pytree as pytree -from . import ir +from torch._inductor.kernel.mm_common import mm_args +from . import config, ir +from .codegen.cpp_gemm_template import CppPackedGemmTemplate from .ir import TensorBox -from .lowering import add, add_needs_realized_inputs, aten, register_lowering, to_dtype +from .lowering import ( + add, + add_needs_realized_inputs, + aten, + permute, + register_lowering, + to_dtype, +) +from .select_algorithm import autotune_select_algorithm, ExternKernelChoice +from .utils import use_cpp_packed_gemm_template +from .virtualized import V def register_onednn_fusion_ops(): @@ -339,6 +351,12 @@ def qlinear_unary( ) if torch._C.has_mkl: + aten_mkl_linear = ExternKernelChoice( + torch.ops.mkl._mkl_linear, + "mkl::_mkl_linear", + has_out_variant=False, + kernel_creator=ir.MKLPackedLinear.create, + ) cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) @register_lowering(torch.ops.mkl._mkl_linear) @@ -346,12 +364,46 @@ def mkl_packed_linear( x: TensorBox, packed_w: TensorBox, orig_w: TensorBox, - b: TensorBox, + b: Optional[TensorBox], batch_size, + *, + layout=None, ): - result = TensorBox.create( - ir.MKLPackedLinear.create(x, packed_w, orig_w, batch_size) + choices = [ + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ] + if config.max_autotune or config.max_autotune_gemm: + transposed_w = permute(orig_w, [1, 0]) + *_, layout, x, transposed_w = mm_args( + x, transposed_w, layout=layout + ) + if use_cpp_packed_gemm_template(layout, x, transposed_w): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, packed_w, orig_w], + trans_w=True, + input_indices=[0, 2], + ) + + assert isinstance(packed_w.data, ir.StorageBox) + assert isinstance(packed_w.data.data, ir.ConstantBuffer) + assert isinstance(orig_w.data, ir.StorageBox) + assert isinstance(orig_w.data.data, ir.ConstantBuffer) + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + 2: lambda x: V.graph.constants[x.get_name()], + } + chosen_node: TensorBox = autotune_select_algorithm( + "packed_linear", + choices, + [x, packed_w, orig_w], + layout, + input_gen_fns=input_gen_fns, ) + result = TensorBox.create(chosen_node) if b is not None: result = add(result, b) return result diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py 
index 03a8e63141ff9..217053e4a344b 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -38,9 +38,11 @@ from .utils import ( do_bench, get_dtype_size, + is_cpu_device, Placeholder, sympy_dot, sympy_product, + timed, unique, ) from .virtualized import V @@ -662,17 +664,19 @@ def __init__( has_out_variant=True, op_overload=None, use_fallback_kernel=False, + kernel_creator=None, ): super().__init__() name = name or kernel.__name__ assert callable(kernel) - assert not hasattr(extern_kernels, name), "duplicate extern kernel" + assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" self.name = name self.cpp_kernel_name = cpp_kernel self.has_out_variant = has_out_variant setattr(extern_kernels, name, kernel) self.op_overload = op_overload self.use_fallback_kernel = use_fallback_kernel + self.kernel_creator = kernel_creator def to_callable(self): return getattr(extern_kernels, self.name) @@ -802,7 +806,10 @@ def benchmark(self, *args, out): out_new, tuple(out.size()), tuple(out.stride()) ) out.copy_(out_new) # for correctness checking - return do_bench(lambda: algo(*args)) + if is_cpu_device(args): + return timed(lambda: algo(*args), ()) + else: + return do_bench(lambda: algo(*args)) def to_callable(self): fn = self.choice.to_callable() @@ -831,6 +838,8 @@ def output_node(self): inner = ir.FallbackKernel.create( self.choice.op_overload, *self.input_nodes, **self.kwargs ) + elif self.choice.kernel_creator is not None: + inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) else: cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc inner = cls( @@ -853,6 +862,70 @@ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType } +class DataProcessorChoiceCallerWrapper: + def __init__(self, wrapped, preprocessor, postprocessor): + self._wrapped = wrapped + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def benchmark(self, *args, out) -> float: + new_args, new_out = self._preprocessor(args, out) + result = self._wrapped.benchmark(*new_args, out=new_out) + new_out = self._postprocessor(new_out) + if out is not new_out: + out.copy_(new_out) + return result + + def output_node(self) -> ir.TensorBox: + result = self._wrapped.output_node() + return self._postprocessor(result) + + +class DataProcessorTemplateWrapper: + def __init__( + self, + wrapped_template_cls, + preprocessor, + postprocessor, + **kwargs, + ): + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + assert "input_nodes" in kwargs + assert "layout" in kwargs + kwargs["input_nodes"], kwargs["layout"] = preprocessor( + kwargs["input_nodes"], kwargs["layout"] + ) + self._wrapped = wrapped_template_cls(**kwargs) + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def maybe_append_choice(self, choices, **kwargs): + return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) + + def generate(self, **kwargs): + choice_caller = self._wrapped.generate(**kwargs) + return DataProcessorChoiceCallerWrapper( + choice_caller, self._preprocessor, self._postprocessor + ) + + 
class ErrorFromChoice(RuntimeError): def __init__(self, msg, choice: ChoiceCaller, inputs_str): msg += f"\nFrom choice {choice}\n{inputs_str}" @@ -1035,24 +1108,29 @@ def make_benchmark_fn( for i, x in enumerate(input_nodes) } example_inputs = list(unique_example_inputs.values()) - example_inputs_extern = [ - torch.as_strided( - unique_example_inputs[input_node.get_name()], - V.graph.sizevars.size_hints( - input_node.get_size(), - fallback=config.unbacked_symint_fallback, - ), - V.graph.sizevars.size_hints( - input_node.get_stride(), - fallback=config.unbacked_symint_fallback, - ), - V.graph.sizevars.size_hint( - input_node.get_layout().offset, - fallback=config.unbacked_symint_fallback, - ), - ) - for input_node in input_nodes - ] + example_inputs_extern = [] + for input_node in input_nodes: + example_input = unique_example_inputs[input_node.get_name()] + if example_input.is_mkldnn: + example_inputs_extern.append(example_input) + else: + example_inputs_extern.append( + torch.as_strided( + example_input, + V.graph.sizevars.size_hints( + input_node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + input_node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + ) out = cls.benchmark_example_value(layout) out_extern = torch.as_strided( @@ -1090,7 +1168,8 @@ def benchmark_choice_in_current_process(choice): result = choice.benchmark(*example_inputs, out=out) if VERIFY: torch.testing.assert_close(out_extern, expected, **VERIFY) - torch.cuda.synchronize() # shake out any CUDA errors + if torch.cuda.is_available(): + torch.cuda.synchronize() # shake out any CUDA errors return result def benchmark_in_current_process(choices): diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 5f85704c99380..feafbfccfb399 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1023,6 +1023,26 @@ def use_cutlass_template(layout, m, n, k): return res +def use_cpp_packed_gemm_template(layout, mat1, mat2): + from . import ir + + layout_dtypes = [torch.float32] + _, n = mat2.get_size() + if isinstance(mat2, ir.BaseView): + mat2 = mat2.unwrap_view() + # TODO: decide block size per ISA + # TODO: use larger block size for larger batch sizes + # TODO: support n % n_block_size != 0 + n_block_size = 16 + return ( + layout.device.type == "cpu" + and layout.dtype in layout_dtypes + and n % n_block_size == 0 + and isinstance(mat2, ir.StorageBox) + and mat2.is_module_buffer() + ) + + def use_aten_gemm_kernels(): return not use_max_autotune() or _use_autotune_backend("ATEN")
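
Note on the packed GEMM layout: `pack_weight` in `CppPackedGemmTemplate.add_choices` reorders the weight from shape `(K, N)` into blocks of shape `(N // n_block_size, K, n_block_size)` so that the innermost `ni` loop of `GEMM_TEMPLATE` reads contiguous memory. The sketch below is a minimal Python illustration of that packing and of the loop nest the template emits; as written, the template computes `Y = alpha * (X @ W + beta * inp)`, accumulating `beta * inp` first and scaling the whole sum by `alpha`. The helper names `pack_weight_reference` and `packed_gemm_reference` are illustrative only and are not part of this patch.

import torch

N_BLOCK_SIZE = 16  # matches the fixed n_block_size chosen in CppPackedGemmTemplate.add_choices


def pack_weight_reference(w, n_bs=N_BLOCK_SIZE):
    # Mirrors pack_weight(): (K, N) -> (N // n_bs, K, n_bs), contiguous in memory.
    k, n = w.shape
    assert n % n_bs == 0, "the template requires N to be a multiple of the block size"
    return w.reshape(k, n // n_bs, n_bs).transpose(0, 1).contiguous()


def packed_gemm_reference(x, packed_w, inp=None, beta=1.0, alpha=1.0, n_bs=N_BLOCK_SIZE):
    # Python rendering of the loop nest in GEMM_TEMPLATE: for each output row i and each
    # block j of n_bs columns, start from beta * inp (or zero), add the K reduction, then
    # scale by alpha, giving alpha * (X @ W + beta * inp).
    m, _ = x.shape
    n_blocks = packed_w.shape[0]
    y = torch.empty(m, n_blocks * n_bs, dtype=x.dtype)
    for i in range(m):
        for j in range(n_blocks):
            cols = slice(j * n_bs, (j + 1) * n_bs)
            acc = torch.zeros(n_bs, dtype=x.dtype) if inp is None else beta * inp[i, cols]
            # Vectorized form of the k/ni loops in the template: (K,) @ (K, n_bs) -> (n_bs,)
            acc = acc + x[i] @ packed_w[j]
            y[i, cols] = alpha * acc
    return y


if __name__ == "__main__":
    x = torch.randn(4, 10)
    w = torch.randn(10, 32)            # (K, N) with N % 16 == 0
    inp = torch.randn(4, 32)
    out = packed_gemm_reference(x, pack_weight_reference(w), inp, beta=0.5, alpha=2.0)
    torch.testing.assert_close(out, 2.0 * (x @ w + 0.5 * inp))

The same packed layout is what `use_cpp_packed_gemm_template` guards for (float32 weight held as a module buffer with N divisible by the block size), and the reorder/transpose/pack steps are applied ahead of autotuning through `DataProcessorTemplateWrapper`, so both benchmarking and the final generated kernel see the blocked weight.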