diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py new file mode 100644 index 0000000000000..eadd619ba9d29 --- /dev/null +++ b/test/inductor/test_cpu_select_algorithm.py @@ -0,0 +1,135 @@ +# Owner(s): ["oncall: cpu inductor"] +import functools +import unittest +from unittest.mock import patch + +import torch +import torch._dynamo.config +import torch._dynamo.config as dynamo_config +import torch._inductor.config as inductor_config +import torch._inductor.select_algorithm as select_algorithm +from torch._dynamo.utils import counters +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_device_type import ( + dtypes, + instantiate_device_type_tests, +) + +from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL + +aten = torch.ops.aten + + +def patches(fn): + def skip_cache(self, choices, name, key, benchmark): + if benchmark is None: + return {} + return benchmark(choices) + + for patcher in [ + dynamo_config.patch(verbose=True), + inductor_config.patch( + debug=True, + max_autotune=True, + epilogue_fusion=True, + max_autotune_gemm_backends="CPP,ATEN", + ), + patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)), + patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache), + ]: + fn = patcher(fn) + + @functools.wraps(fn) + def wrapped(*args, **kwargs): + counters.clear() + torch.manual_seed(12345) + return fn(*args, **kwargs) + + return wrapped + + +class TestSelectAlgorithm(TestCase): + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("batch_size", (1, 2, 1000)) + @parametrize("in_features", (1, 2, 1000)) + @parametrize("out_features", (1, 32, 1024)) + @parametrize("bias", (True, False)) + @parametrize("input_3d", (True, False)) + @dtypes(torch.float) + def test_linear_static_shapes( + self, batch_size, in_features, out_features, bias, input_3d, dtype + ): + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias) + + @torch.compile + def forward(self, x): + return self.linear(x) + + counters.clear() + mod = M(bias=bias).to(dtype=dtype).eval() + B = (2, batch_size) if input_3d else (batch_size,) + v = torch.randn(*B, in_features).to(dtype=dtype) + mod(v) + if ( + counters["inductor"]["decompose_mm"] > 0 + or counters["inductor"]["decompose_addmm"] > 0 + ): + # This is a special case where we go directly with vectorized codegen + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) + else: + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("bias", (True, False)) + @dtypes(torch.float) + def test_linear_input_transpose(self, bias, dtype): + batch_size = 384 + in_features = 196 + out_features = 384 + + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias) + + @torch.compile + def forward(self, x): + return self.linear(x) + + counters.clear() + mod = M(bias=bias).to(dtype=dtype).eval() + v = torch.randn(in_features, batch_size).to(dtype=dtype) + mod(v.transpose(0, 1)) + # TODO(jgong5): support transposed input + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0) + + 
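The tests above exercise the new CPP GEMM autotuning path by patching inductor config (max_autotune, a backend list containing "CPP", and freezing) and compiling a small Linear module. A minimal usage sketch distilled from that configuration follows; it is not part of the patch, and the module and shapes are illustrative only.

# Sketch only: mirrors the config knobs patched by the tests above.
import torch
import torch._inductor.config as inductor_config

inductor_config.freezing = True          # freezing turns the weight into a graph constant that can be prepacked
inductor_config.max_autotune = True      # enable GEMM autotuning
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"

mod = torch.nn.Linear(1024, 1024, bias=True).eval()  # illustrative shapes, float32
compiled = torch.compile(mod)
with torch.no_grad():
    y = compiled(torch.randn(128, 1024))  # the GEMM may be served by the CPP packed template
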
+@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) +class _DynamicShapesTestBase(TestCase): + pass + + +class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): + test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes + + +instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") +instantiate_device_type_tests( + TestSelectAlgorithmDynamicShapes, globals(), only_for="cpu" +) + + +if __name__ == "__main__": + from torch.testing._internal.inductor_utils import HAS_CPU + + if HAS_CPU and not IS_MACOS: + run_tests() diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index c59d0af4a3e36..1737b9ffa65b6 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -8,7 +8,7 @@ import sys from copy import copy, deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Sequence, Set, Tuple, Union import sympy @@ -20,6 +20,7 @@ from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges +from ..._dynamo.utils import counters from .. import codecache, config, ir, metrics from ..codegen.wrapper import WrapperCodeGen @@ -3584,6 +3585,8 @@ def _can_fuse_horizontal_impl(self, node1, node2): return self._why_fuse_nodes(node1, node2) is not None def can_fuse_horizontal(self, node1, node2): + if node1.is_template() or node2.is_template(): + return False if ( len(node1.get_nodes()) + len(node2.get_nodes()) > config.cpp.max_horizontal_fusion_size @@ -3664,6 +3667,9 @@ def get_fusion_pair_priority(self, node1, node2): return 0 def can_fuse_vertical(self, node1, node2): + # TODO(jgong5): support vertical fusion for template nodes + if node1.is_template() or node2.is_template(): + return False return ( self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) @@ -3720,6 +3726,44 @@ def codegen_node( if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: self._set_flush_status(True) + def is_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ir.CppTemplateBuffer + ) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CPP template, possibly with fused epilogues + """ + counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cpp_template( + template_node + ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (_, rnumel) = template_node.group + assert rnumel == () + ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) + epilogue_ir_nodes: List[Optional[ir.Buffer]] = [n.node for n in epilogue_nodes] + assert all( + isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes + ), "Epilogue nodes must all be instances of ir.ComputedBuffer" + kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() # type: ignore[attr-defined] + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node, *epilogue_nodes] + 
kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() + def _get_scheduled_num_args(self): return self.kernel_group.get_num_args() @@ -3729,7 +3773,7 @@ def ready_to_flush(self): def codegen_sync(self): pass - def define_kernel(self, src_code, nodes): + def define_kernel(self, src_code, nodes, kernel_args=None): wrapper = V.graph.wrapper_code fused_name = ( get_fused_kernel_name(nodes, config.cpp.descriptive_names) @@ -3745,7 +3789,8 @@ def define_kernel(self, src_code, nodes): src_code = src_code.replace("#pragma CMT", "//") compile_wrapper = IndentedBuffer() - _, _, arg_types = self.kernel_group.args.cpp_argdefs() + args = self.kernel_group.args if kernel_args is None else kernel_args + _, _, arg_types = args.cpp_argdefs() if not V.graph.cpp_wrapper: compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") compile_wrapper.splice(src_code, strip=True) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py new file mode 100644 index 0000000000000..8ac82ef266002 --- /dev/null +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -0,0 +1,373 @@ +from typing import cast, List, Optional + +import torch +import torch.utils +from .. import ir, lowering as L + +from ..kernel.mm_common import mm_args +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import cache_on_self, has_free_symbols, parallel_num_threads +from ..virtualized import V +from .cpp_micro_gemm import create_micro_gemm +from .cpp_template import CppTemplate + +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import GemmBlocking + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} + +{{micro_gemm.codegen_define(kernel)}} + +extern "C" +{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y})}} +{ + {{kernel.maybe_codegen_profile()}} + constexpr int64_t num_threads = {{num_threads}}; + constexpr int64_t N = {{kernel.size(Y, 1)}}; + constexpr int64_t K = {{kernel.size(X, 1)}}; + constexpr int64_t M0 = {{micro_gemm.register_blocking.block_m}}; + constexpr int64_t N0 = {{micro_gemm.register_blocking.block_n}}; + constexpr int64_t K0 = {{micro_gemm.register_blocking.block_k}}; + constexpr int64_t N0_blocks = (N + N0 - 1) / N0; + constexpr int64_t K0_blocks = (K + K0 - 1) / K0; + + static_assert(N % N0 == 0, "N dimension must be multiple of N0"); + + // TODO(jgong5): improve cache blocking with CPU info (Mc, Kc) + {%- if is_dynamic_M %} + const int64_t M = {{kernel.size(Y, 0)}}; + const int64_t M0_blocks = (M + M0 - 1) / M0; + {%- if num_threads > 1 %} + int64_t Mt_blocks, Nt_blocks, Kt_blocks; + mm_get_thread_blocking(num_threads, M, N, K, M0, N0, K0, Mt_blocks, Nt_blocks, Kt_blocks); + {%- else %} + const auto Mt_blocks = M0_blocks; + const auto Nt_blocks = N0_blocks; + const auto Kt_blocks = K0_blocks; + {%- endif %} + const int64_t Mc_blocks = Mt_blocks; + const int64_t Kc_blocks = Kt_blocks; + {%- else %} + constexpr int64_t M = {{kernel.size(Y, 0)}}; + constexpr int64_t M0_blocks = (M + M0 - 1) / M0; + constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}}; + constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}}; + constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}}; + constexpr int64_t Mc_blocks = {{template.cache_blocking().block_m}}; + constexpr int64_t Kc_blocks = {{template.cache_blocking().block_k}}; + {%- 
endif %} + + // TODO(jgong5): support k-slicing + {{kernel.assert_function}}(Kt_blocks == K0_blocks, "Do not support k slicing yet."); + // make sure all partitions are assigned + {{kernel.assert_function}}( + Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= M0_blocks * N0_blocks * K0_blocks, + "Not all partitions are assigned." + ); + + {%- if num_threads > 1 %} + #pragma omp parallel num_threads({{num_threads}}) + { + int tid = omp_get_thread_num(); + int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end; + mm_get_thread_blocks( + tid, M0_blocks, N0_blocks, K0_blocks, Mt_blocks, Nt_blocks, Kt_blocks, + m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end); + {%- else %} + { + int64_t m_block_start = 0; + int64_t m_block_end = M0_blocks; + int64_t n_block_start = 0; + int64_t n_block_end = N0_blocks; + int64_t k_block_start = 0; + int64_t k_block_end = K0_blocks; + {%- endif %} + for (int64_t mc = m_block_start; mc < m_block_end; mc += Mc_blocks) { + int64_t m_start = mc * M0; + int64_t m_end = std::min((mc + Mc_blocks) * M0, M); + for (int64_t nc = n_block_start; nc < n_block_end; ++nc) { + int64_t n_start = nc * N0; + // TODO(jgong5): use float32 temporary buffer to support bfloat16/float16 gemm + {%- if inp is not none and beta != 0 %} + for (int64_t m = m_start; m < m_end; ++m) { + #pragma omp simd + for (int64_t n = n_start; n < n_start + N0; ++n) { + {{kernel.index(Y, ["m", "n"])}} = {{beta}} * {{kernel.index(inp, ["m", "n"])}}; + } + } + {%- endif %} + for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { + int64_t k_start = kc * K0; + int64_t k_end = std::min((kc + Kc_blocks) * K0, K); + {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} + {%- set tile_W_3d = kernel.slice_nd(W, [("nc", "nc + 1"), ("k_start", "k_end"), ()]) %} + {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} + {%- set tile_Y = kernel.slice_nd(Y, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} + {%- if inp is not none and beta != 0 %} + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True)|indent(20, false) }} + {%- else %} + if (kc == k_block_start) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=False)|indent(24, false) }} + } else { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True)|indent(24, false) }} + } + {%- endif %} + } + } + } + } +} +""" + + +class CppPackedGemmTemplate(CppTemplate): + def __init__( + self, + input_nodes, + layout: ir.Layout, + num_threads: int, + register_blocking: GemmBlocking, + beta=1, + alpha=1, + ): + super().__init__("packed_gemm", input_nodes, layout) + self.beta = beta + self.alpha = alpha + self.num_threads = num_threads + self.register_blocking = register_blocking + m, n = layout.size + _, k = input_nodes[0].get_size() + self.m, self.n, self.k = m, n, k + self.is_dynamic_M = has_free_symbols((m,)) + + @cache_on_self + def thread_blocking(self) -> GemmBlocking: + # TODO(jgong5): allow tuning various blocking options + def get_factors(number): + factors = [] + # priorize more evenly divided factors + for i in range(int(number**0.5), 0, -1): + if number % i == 0: + factors.append(number // i) + factors.append(i) + return factors + + def get_blocking(num_threads, factor, m_blocks, n_blocks, k_blocks): + thread_block_n = (n_blocks + factor - 1) // factor + cofactor = num_threads // factor + thread_block_m = (m_blocks + cofactor - 1) // 
cofactor + return GemmBlocking(thread_block_m, thread_block_n, k_blocks) + + assert ( + not self.is_dynamic_M + ), "Unable to determine thread blocking for dynamic M." + register_blocking = self.register_blocking + m_blocks = (self.m + register_blocking.block_m - 1) // register_blocking.block_m + n_blocks = (self.n + register_blocking.block_n - 1) // register_blocking.block_n + k_blocks = (self.k + register_blocking.block_k - 1) // register_blocking.block_k + factors = get_factors(self.num_threads) + assert len(factors) > 0 + for factor in factors: + if n_blocks % factor == 0 and m_blocks % (self.num_threads // factor) == 0: + return get_blocking( + self.num_threads, factor, m_blocks, n_blocks, k_blocks + ) + for factor in factors: + if n_blocks % factor == 0: + return get_blocking( + self.num_threads, factor, m_blocks, n_blocks, k_blocks + ) + cofactor = self.num_threads // factor + if m_blocks % cofactor == 0: + return get_blocking( + self.num_threads, factor, m_blocks, n_blocks, k_blocks + ) + raise AssertionError("Should not reach here.") + + @cache_on_self + def cache_blocking(self) -> GemmBlocking: + # TODO(jgong5): improve cache blocking with CPU info + assert ( + not self.is_dynamic_M + ), "Unable to determine cache blocking for dynamic M." + thread_blocking = self.thread_blocking() + return GemmBlocking(thread_blocking.block_m, 1, thread_blocking.block_k) + + @staticmethod + def add_choices( + choices, layout, input_nodes, beta=1, alpha=1, trans_w=False, input_indices=None + ): + if input_indices is None: + input_indices = list(range(len(input_nodes))) + + def reorder_and_filter(inputs, layout_or_out): + if len(input_indices) == 2: + x_idx = input_indices[0] + w_idx = input_indices[1] + return [inputs[x_idx], inputs[w_idx]], layout_or_out + else: + assert ( + len(input_indices) == 3 + ), "Cpp Packed GEMM template requires 2 or 3 input nodes." + # assume the input order is [inp, x, w] and we reorder it to [x, w, inp] + inp_idx = input_indices[0] + x_idx = input_indices[1] + w_idx = input_indices[2] + return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out + + def transpose_weight(inputs, layout_or_out): + if not trans_w: + return inputs, layout_or_out + + new_inputs = list(inputs) + W = inputs[1] + if isinstance(W, ir.IRNode): + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + new_inputs[1] = L.permute(W, [1, 0]) + return new_inputs, layout_or_out + else: + assert isinstance(W, torch.Tensor) + new_inputs[1] = W.transpose(0, 1) + return new_inputs, layout_or_out + + # TODO(jgong5): decide proper number of threads per problem size + num_threads = parallel_num_threads() + new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout)) + m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) + micro_gemm = create_micro_gemm( + "micro_gemm", m, n, k, layout.dtype, alpha=alpha, num_threads=num_threads + ) + assert micro_gemm is not None + _, block_n, _ = micro_gemm.register_blocking + + def pack_weight(inputs, layout_or_out): + W = inputs[1] + new_inputs = list(inputs) + if isinstance(W, ir.IRNode): + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + k, n = W.get_size() + assert ( + n % block_n == 0 + ), f"The last dimension of W must be a multiple of {block_n}." 
+ blocked_w = L.permute( + L.view(W, (k, n // block_n, block_n)), + [1, 0, 2], + ) + blocked_w = ir.ExternKernel.realize_input(blocked_w) + blocked_w = ir.ExternKernel.require_contiguous(blocked_w) + if isinstance(blocked_w, ir.ReinterpretView): + # normalize stride to be "contiguous_strides" per size + # this avoids the problems in L.view during template codegen + assert isinstance(blocked_w.layout, ir.FixedLayout) + blocked_w.layout = ir.FixedLayout( + blocked_w.layout.device, + blocked_w.layout.dtype, + blocked_w.layout.size, + ir.FlexibleLayout.contiguous_strides(blocked_w.layout.size), + blocked_w.layout.offset, + ) + else: + k, n = list(W.shape) + blocked_w = ( + W.reshape(k, n // block_n, block_n).transpose(0, 1).contiguous() + ) + # normalize stride to be "contiguous_strides" per size + # this avoids the problems in L.view during template codegen + new_stride = [1] + for sz in reversed(blocked_w.shape[1:]): + new_stride.insert(0, new_stride[0] * sz) + blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride) + new_inputs[1] = blocked_w + return new_inputs, layout_or_out + + def preprocessor(inputs, layout): + return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout))) + + def postprocessor(output): + if isinstance(output, ir.TensorBox): + # prepack the weight as input to the template buffer + # TODO(jgong5): prune the unused constants in V.graph + # Should we implement it with constant folding in the scheduler instead? + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + W_node = new_input_nodes[1] + assert W_node.get_name() in V.graph.constants + W = V.graph.constants[W_node.get_name()] + new_input_nodes[1] = W + new_input_nodes, _ = pack_weight( + *transpose_weight(new_input_nodes, layout) + ) + W_packed = new_input_nodes[1] + W_packed_constant = V.graph.add_tensor_constant(W_packed) + template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( + W_packed_constant + ) + return output + + template = DataProcessorTemplateWrapper( + CppPackedGemmTemplate, + preprocessor, + postprocessor, + input_nodes=input_nodes, + layout=layout, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, + ) + template.maybe_append_choice(choices) + return template + + def render( # type: ignore[override] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + epilogue_nodes: Optional[List[ir.IRNode]] = None, + **kwargs, + ) -> str: + assert not epilogue_nodes, "Epilogue nodes are not supported for GEMM template." 
+ assert len(self.input_nodes) >= 2 + + X, W = self.input_nodes[0], self.input_nodes[1] + inp = self.input_nodes[2] if len(self.input_nodes) > 2 else None + Y = self.output_node + + if template_buffer_node is not None: + # Use the updated prepacked weight buffer + W = template_buffer_node.inputs[1] + Y = template_buffer_node + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + Y = cast(ir.Buffer, epilogue_nodes[-1]) + assert self.output_node is not None + + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + self.layout.dtype, + alpha=self.alpha, + num_threads=self.num_threads, + ) + assert micro_gemm is not None + assert self.register_blocking == micro_gemm.register_blocking + + options = dict( + X=X, + W=W, + inp=inp, + Y=Y, + beta=self.beta, + alpha=self.alpha, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, + template=self, + kernel=kernel, + epilogues=epilogue_nodes, + ) + return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py new file mode 100644 index 0000000000000..649782ff158d8 --- /dev/null +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -0,0 +1,401 @@ +from collections import namedtuple +from typing import Dict, List, Optional, Type + +import sympy + +import torch + +from .. import ir +from ..codecache import pick_vec_isa, VecAVX2, VecAVX512 +from ..utils import IndentedBuffer, parallel_num_threads +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp + + +class CppMicroGemm: + """ + A class that codegens a kernel that computes small-sized matrix multiplication. + + A micro GEMM kernel is responsible for register blocking, instruction selection, + and other CPU architecture-specific optimizations. + + The subclasses need to override `codegen_define` to define the kernel function + that is called by the code generated by `codegen_call`. + """ + + # TODO(jgong5): support constant shapes and lds as template args. + DECLARE_KERNEL = r""" +template +inline void {{kernel_name}}( + const {{input_t}}* __restrict__ A, + const {{input_t}}* __restrict__ B, + {{output_t}}* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) +""" + + def __init__( + self, + name, + input_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha=1, + ): + self.name = name + self.input_dtype = input_dtype + self.output_dtype = output_dtype + self.compute_dtype = compute_dtype + self.register_blocking = register_blocking + self.alpha = alpha + + def get_common_options(self): + return { + "kernel_name": self.name, + "input_t": DTYPE_TO_CPP[self.input_dtype], + "output_t": DTYPE_TO_CPP[self.output_dtype], + "compute_t": DTYPE_TO_CPP[self.compute_dtype], + "alpha": self.alpha, + } + + def get_kernel_declaration(self): + options = self.get_common_options() + return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + raise NotImplementedError + + def codegen_call( + self, + kernel: CppTemplateKernel, + A: ir.Buffer, + B: ir.Buffer, + C: ir.Buffer, + accum: bool, + ) -> str: + """ + Generate the code for calling the templated kernel that computes + `C += alpha * A @ B` if `accum` is True, or `C = alpha * A @ B` otherwise. 
+ """ + A_ptr = f"&({kernel.index(A, [0, 0])})" + B_ptr = f"&({kernel.index(B, [0, 0])})" + C_ptr = f"&({kernel.index(C, [0, 0])})" + M = kernel.size(C, 0) + N = kernel.size(C, 1) + K = kernel.size(A, 1) + lda = kernel.stride(A, 0) + ldb = kernel.stride(B, 0) + ldc = kernel.stride(C, 0) + res = IndentedBuffer() + res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(") + with res.indent(): + res.writeline(f"{A_ptr},") + res.writeline(f"{B_ptr},") + res.writeline(f"{C_ptr},") + res.writeline(f"{M},") + res.writeline(f"{N},") + res.writeline(f"{K},") + res.writeline(f"{lda},") + res.writeline(f"{ldb},") + res.writeline(f"{ldc}") + res.writeline(");") + return res.getvalue() + + +CppMicroGemmConfig = namedtuple( + "CppMicroGemmConfig", + [ + "input_dtype", + "output_dtype", + "compute_dtype", + "vec_isa_cls", + "register_blocking", + ], +) + +micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {} + + +def register_micro_gemm(*configs): + def inner(cls): + assert ( + cls not in micro_gemm_configs + ), f"Duplicate micro_gemm registration for {cls}" + assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" + micro_gemm_configs[cls] = list(configs) + return cls + + return inner + + +class CppMicroGemmRef(CppMicroGemm): + """ + A reference implementation of the CppMicroGemm class with naive C++ code. + It is used for correctness debugging. + """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + {{compute_t}} result = accum ? C[m * ldc + n] : 0; + for (int64_t k = 0; k < K; ++k) { + result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; + } + C[m * ldc + n] = result; + } + } +} +""" + + def __init__(self, name, input_dtype, output_dtype, compute_dtype, alpha): + super().__init__( + name, input_dtype, output_dtype, compute_dtype, GemmBlocking(1, 1, 1), alpha + ) + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + **self.get_common_options(), + } + return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) + + +@register_micro_gemm( + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 48, 1) + ), + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 32, 1) + ), + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(16, 16, 1) + ), + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 24, 1) + ), + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 16, 1) + ), + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(8, 8, 1) + ), +) +class CppMicroGemmFP32Vec(CppMicroGemm): + """ + This class generates the code for fp32 micro gemm using vec instructions. 
+ """ + + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + // TODO(jgong5): loop unroll for M and N + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + for (int64_t n = 0; n < N; n += {{block_n}}) { + if (block_m == {{block_m}}) { + {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>( + A + m * lda, + B + n, + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); + } else { + switch (block_m) { + {%- for b in range(block_m - 1, 0, -1) %} + case {{b}}: + {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( + A + m * lda, + B + n, + C + m * ldc + n, + K, + lda, + ldb, + ldc + ); + break; + {%- endfor %} + default: + {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); + } + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" +template +inline void {{kernel_name}}_kernel( + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) { + using Vectorized = at::vec::Vectorized; + constexpr auto VLEN = Vectorized::size(); + constexpr auto ROWS = BLOCK_M; + constexpr auto COLS = BLOCK_N / VLEN; + + Vectorized va; + at::vec::VectorizedN vb; + at::vec::VectorizedN vc; + + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); + } else { + vc[i] = Vectorized(0.0f); + } + }; + c10::ForcedUnroll{}(loadc); + + auto compute = [&, COLS](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + {%- if alpha != 1 %} + va = Vectorized(A[row * lda + k] * {{alpha}}); + {%- else %} + va = Vectorized(A[row * lda + k]); + {%- endif %} + } + + if constexpr (row == 0) { + vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); + } + + constexpr int idx = row * COLS + col; + vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); + }; + + {{kernel.unroll_pragma(4)}} + for (int k = 0; k < K; ++k) { + c10::ForcedUnroll{}(compute, k); + } + + // store to C + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i].store(C + row * ldc + col * VLEN); + }; + c10::ForcedUnroll{}(storec); +} +""" + + def codegen_define(self, kernel: CppTemplateKernel) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + **self.get_common_options(), + } + result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + +def create_micro_gemm( + name, + m, + n, + k, + input_dtype, + output_dtype=None, + compute_dtype=None, + alpha=1, + num_threads=-1, + use_ref=True, +) -> Optional[CppMicroGemm]: + def create_from_config(cls, config: CppMicroGemmConfig): + return cls( + name, + config.input_dtype, + config.output_dtype, + config.compute_dtype, + config.register_blocking, + alpha, + ) + + assert isinstance(n, int) or n.is_number, n + assert isinstance(k, int) or k.is_number, k + m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m + assert isinstance(m, int), m + if output_dtype is None: + output_dtype = 
input_dtype + if compute_dtype is None: + compute_dtype = input_dtype + if num_threads < 0: + num_threads = parallel_num_threads() + vec_isa = pick_vec_isa() + matched_configs = [] + for cls, configs in micro_gemm_configs.items(): + for config in configs: + if not isinstance(vec_isa, config.vec_isa_cls): + continue + if ( + config.input_dtype == input_dtype + and config.output_dtype == output_dtype + and config.compute_dtype == compute_dtype + ): + block_m, block_n, block_k = config.register_blocking + # TODO(jgong5): support n % n_block_size != 0 + if n % block_n != 0: + continue + # Criteria on the ranking of configurations + # 1. Dividable by block sizes (block_m, block_k) + # 2. Number of mxn blocks is large enough to occupy all the threads + # 3. Register blocks are larger + dividable_score = 0 + if k % block_k == 0: + dividable_score += 1 + if m % block_m == 0: + dividable_score += 1 + occupancy_score = 0 + n_blocks = n // block_n + total_mxn_blocks = n // block_n * ((m + block_m - 1) // block_m) + if n_blocks >= num_threads: + occupancy_score += 1 + if total_mxn_blocks >= num_threads: + occupancy_score += 1 + matched_configs.append( + ( + (dividable_score, occupancy_score, block_m * block_n * block_k), + cls, + config, + ) + ) + if len(matched_configs) == 0: + if use_ref: + return CppMicroGemmRef( + name, input_dtype, output_dtype, compute_dtype, alpha + ) + else: + return None + # TODO(jgong5): allow autotuning on choices of configs + return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:]) diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 7e3483ca99948..6898a8a52112e 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -5,6 +5,7 @@ #include #include #include +#include #include // WARNING: be extra careful when including more ATen/c10 header files here! 
@@ -309,3 +310,106 @@ atomic_add(volatile T *addr, T offset) { std::atomic *atomic_addr = (std::atomic *)addr; atomic_addr->fetch_add(offset, std::memory_order_relaxed); } + +void mm_get_thread_blocking( + int num_threads, + int64_t M, + int64_t N, + int64_t K, + int64_t M0, + int64_t N0, + int64_t K0, + int64_t& Mt, + int64_t& Nt, + int64_t& Kt) { + auto get_factors = [](int64_t number) { + int count = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + count += 2; + } + } + auto factors = std::make_unique(count); + int index = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + factors[index++] = number / i; + factors[index++] = i; + } + } + return std::make_tuple(std::move(factors), count); + }; + + auto get_blocking = [](int64_t num_threads, + int64_t factor, + int64_t m_blocks, + int64_t n_blocks, + int64_t k_blocks) { + int64_t thread_block_n = (n_blocks + factor - 1) / factor; + int64_t cofactor = num_threads / factor; + int64_t thread_block_m = (m_blocks + cofactor - 1) / cofactor; + return std::make_tuple(thread_block_m, thread_block_n, k_blocks); + }; + + int64_t m_blocks = (M + M0 - 1) / M0; + int64_t n_blocks = (N + N0 - 1) / N0; + int64_t k_blocks = (K + K0 - 1) / K0; + + auto [factors, count] = get_factors(num_threads); + assert(count > 0); + + for (int i = 0; i < count; ++i) { + int64_t factor = factors[i]; + if (n_blocks % factor == 0 && + m_blocks % (num_threads / factor) == 0) { + std::tie(Mt, Nt, Kt) = get_blocking( + num_threads, factor, m_blocks, n_blocks, k_blocks); + return; + } + } + + for (int i = 0; i < count; ++i) { + int64_t factor = factors[i]; + if (n_blocks % factor == 0) { + std::tie(Mt, Nt, Kt) = get_blocking( + num_threads, factor, m_blocks, n_blocks, k_blocks); + return; + } + int64_t cofactor = num_threads / factor; + if (m_blocks % cofactor == 0) { + std::tie(Mt, Nt, Kt) = get_blocking( + num_threads, factor, m_blocks, n_blocks, k_blocks); + return; + } + } + + assert(false && "Should not reach here."); + // Dummy return to avoid compiler warning + return; +} + +inline void mm_get_thread_blocks( + int thread_id, + int64_t M_blocks, + int64_t N_blocks, + int64_t K_blocks, + int64_t Mt_blocks, + int64_t Nt_blocks, + int64_t Kt_blocks, + int64_t& m_block_start, + int64_t& m_block_end, + int64_t& n_block_start, + int64_t& n_block_end, + int64_t& k_block_start, + int64_t& k_block_end) { + int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks; + k_block_start = (thread_id % num_Kt) * Kt_blocks; + k_block_end = std::min(k_block_start + Kt_blocks, K_blocks); + thread_id /= num_Kt; + int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks; + n_block_start = (thread_id % num_Nt) * Nt_blocks; + n_block_end = std::min(n_block_start + Nt_blocks, N_blocks); + thread_id /= num_Nt; + m_block_start = std::min(thread_id * Mt_blocks, M_blocks); + m_block_end = std::min(m_block_start + Mt_blocks, M_blocks); +} diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py new file mode 100644 index 0000000000000..492aca83462a4 --- /dev/null +++ b/torch/_inductor/codegen/cpp_template.py @@ -0,0 +1,116 @@ +import functools +import itertools +import logging + +import sys +from typing import List, Optional +from unittest.mock import patch + +import sympy + +from .. 
import codecache, config, ir +from ..autotune_process import CppBenchmarkRequest, TensorMeta +from ..utils import IndentedBuffer, Placeholder, unique +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel + +log = logging.getLogger(__name__) + + +class CppTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes, + layout: ir.Layout, + ): + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: ir.Buffer = ir.Buffer("buf_out", layout) + self.layout = layout + + def generate(self, **kwargs): + kernel_name = f"cpp_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), CppTemplateKernel( + kernel_name=kernel_name, + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + expected_args = list( + unique(input_node.get_name() for input_node in self.input_nodes) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + + kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}" + + # Create the BenchmarkRequest for CPP + bmreq = CppBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: ir.CppTemplateBuffer, + epilogue_nodes: Optional[List[ir.IRNode]] = None, + ): + kernel = CppTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + return kernel, render + + return CppTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.writeline(codecache.cpp_prefix()) + res.splice( + """ + #include "c10/util/Unroll.h" + """ + ) + enable_kernel_profile = ( + config.cpp.enable_kernel_profile and sys.platform == "linux" + ) + if enable_kernel_profile: + res.writelines(["#include "]) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py new file mode 100644 index 0000000000000..07e61d83a9122 --- /dev/null +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -0,0 +1,200 @@ +import itertools +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import sympy +from sympy.parsing.sympy_parser import parse_expr + +import torch + +from torch._inductor.autotune_process import CppBenchmarkRequest +from torch._inductor.utils import sympy_index_symbol +from .. 
import codecache, config, ir, lowering as L +from ..virtualized import V +from .common import Kernel, OpOverrides +from .cpp_utils import cexpr_index, DTYPE_TO_CPP + + +def parse_expr_with_index_symbols(expr_str: str) -> sympy.Expr: + expr = parse_expr(expr_str) + int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} + return expr.subs(int_symbols) + + +def wrap_with_tensorbox(node) -> ir.TensorBox: + return ( + ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) + ) + + +class CppTemplateKernel(Kernel): + overrides = OpOverrides + + def __init__(self, kernel_name): + super().__init__() + self.kernel_name = kernel_name + + def def_kernel( + self, + inputs: Dict[str, ir.Buffer], + outputs: Dict[str, ir.Buffer], + ) -> str: + for name, inp in inputs.items(): + if inp is not None: + self.args.input_buffers[inp.get_name()] = name + for name, out in outputs.items(): + self.args.output_buffers[out.get_name()] = name + unique_sizevars = { + s + for input in inputs.values() + if input is not None + for sym in itertools.chain(input.get_size(), input.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + } + unique_sizevars |= { + s + for output in outputs.values() + for sym in itertools.chain(output.get_size(), output.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + } + sizevars = sorted(unique_sizevars, key=str) + for sizevar in sizevars: + self.args.sizevars[sizevar] = f"k{sizevar}" + cpp_argdefs, _, _ = self.args.cpp_argdefs() + return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" + + def call_kernel(self, name: str, node: ir.CppTemplateBuffer): + wrapper = V.graph.wrapper_code + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) + + def dtype(self, node: ir.Buffer) -> str: + return DTYPE_TO_CPP[node.get_dtype()] + + def acc_dtype(self, node: ir.Buffer) -> str: + if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]: + return "float" + else: + raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + + def size(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_size()[dim])) + + def stride(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_stride()[dim])) + + def index(self, node: ir.Buffer, indices: List[Any]) -> str: + indexer = node.layout.as_fixed().make_indexer() + index = indexer([parse_expr_with_index_symbols(str(idx)) for idx in indices]) + index = self.rename_indexing(index) + return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" + + def slice_nd(self, node, ranges: List[Tuple[Any]]) -> ir.ReinterpretView: + """ + Slice the given node with a list of ranges (start and end) corresponding to its dims. + The dim is not sliced if the corresponding range is empty. 
+ """ + assert len(ranges) == len(node.get_size()) + sliced = wrap_with_tensorbox(node) + for dim, _range in enumerate(ranges): + if len(_range) == 0: + continue + assert len(_range) == 2 + start, end = (parse_expr_with_index_symbols(str(r)) for r in _range) + sliced = L.slice_(sliced, dim, start, end, clamp=False) + assert isinstance(sliced.data, ir.ReinterpretView) + return sliced.data + + def view(self, node, sizes: List[Any]) -> ir.View: + node = wrap_with_tensorbox(node) + sizes = [parse_expr_with_index_symbols(str(s)) for s in sizes] + return L.view(node, sizes).data + + @property + def assert_function(self) -> str: + if V.graph.aot_mode: + return "AOTI_TORCH_CHECK" + else: + return "TORCH_CHECK" + + def maybe_codegen_profile(self) -> str: + if config.cpp.enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef({{}}));' + else: + return "" + + def unroll_pragma(self, unroll): + if codecache.is_gcc(): + return f"#pragma GCC unroll {unroll}" + else: + return f"#pragma unroll {unroll}" + + +class CppTemplateCaller(ir.ChoiceCaller): + """ + CppTemplateCaller + + This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CppBenchmarkRequest): The benchmark request for the caller. + template_buffer (ir.CppTemplateBuffer): The template buffer for the caller. + """ + + def __init__( + self, + name: str, + category: str, + input_nodes: List[ir.Buffer], + layout: ir.Layout, + make_kernel_render: Callable[ + [ir.CppTemplateBuffer, Optional[List[ir.IRNode]]], str + ], + bmreq: CppBenchmarkRequest, + template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 + info_kwargs: Optional[ + Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]] + ] = None, + ): + super().__init__(name, input_nodes, layout) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict( + self, + ) -> Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]]: + return {"backend": "CPP", "op_type": "unknown"} + + def output_node(self) -> ir.TensorBox: + return ir.TensorBox.create( + ir.CppTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + template=self.template, + choice=self, + ) + ) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 7d63b7acf2fb4..4ab33a5e26dc0 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -1,5 +1,7 @@ import math +from collections import namedtuple + import torch from .common import ExprPrinter @@ -60,6 +62,8 @@ INDEX_TYPE = "long" +GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) + class CppPrinter(ExprPrinter): def _print_Integer(self, expr): diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 1869a8cf8fd87..5b674ed53f205 100644 --- a/torch/_inductor/config.py 
+++ b/torch/_inductor/config.py @@ -234,12 +234,13 @@ def is_fbcode(): ) # Specify candidate backends for gemm autotune. -# Possible choices are combinations of: ATen, Triton, CUTLASS. +# Possible choices are combinations of: ATen, Triton, CUTLASS, CPP. # ATen: default Pytorch ATen kernels. # Triton: Triton templates defined in torch inductor. # CUTLASS: Cutlass templates and kernels. +# CPP: CPP templates and kernels for CPU. max_autotune_gemm_backends = os.environ.get( - "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON" + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" ).upper() # Specify the size of the search space for GEMM autotuning. diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 01803af152608..960c3a42e1f15 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -19,6 +19,7 @@ pw_cast_for_opmath, ) from torch._decomp.decompositions_for_rng import extra_random_decomps +from torch._dynamo.utils import counters from torch._higher_order_ops.out_dtype import out_dtype from torch._inductor.utils import pad_listlike from torch._prims_common import ( @@ -205,6 +206,7 @@ def bmm(self, batch2): return out if self.device.type == "cpu": if self.size(1) == 1 and batch2.size(-1) == 1: + counters["inductor"]["decompose_bmm"] += 1 return torch.sum( self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True ).unsqueeze(1) @@ -216,11 +218,13 @@ def bmm(self, batch2): def addmm(self, mat1, mat2, beta=1, alpha=1): if self.device.type == "cpu": if mat1.size(0) == 1 and mat2.size(-1) == 1: + counters["inductor"]["decompose_addmm"] += 1 out = torch.sum( mat1.squeeze(0) * mat2.squeeze(-1), dim=0, keepdim=True ).unsqueeze(0) return alpha * out + beta * self if mat1.size(0) == 1 and mat2.size(0) <= 16 and mat2.size(1) <= 16: + counters["inductor"]["decompose_addmm"] += 1 out = (mat1.T * mat2).sum(dim=0, keepdim=True) return alpha * out + beta * self return NotImplemented @@ -247,10 +251,12 @@ def mm(self, input2): and (self.dtype == input2.dtype) and definitely_true((torch.numel(self) + torch.numel(input2)) <= 32) ): + counters["inductor"]["decompose_mm"] += 1 return torch.cat([self[i, :] * input2 for i in range(self.size(0))]) if guard_size_oblivious(self.size(0) == 1) and guard_size_oblivious( input2.size(-1) == 1 ): + counters["inductor"]["decompose_mm"] += 1 return torch.sum( self.squeeze(0) * input2.squeeze(-1), dim=0, keepdim=True ).unsqueeze(0) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index f6b187de1f01d..49c0ed4864f2c 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3722,6 +3722,13 @@ def get_workspace_size(self): return self.workspace_size if self.workspace_size is not None else 0 +class CppTemplateBuffer(TemplateBuffer): + def __init__(self, layout, inputs, make_kernel_render, template, choice): + super().__init__(layout, inputs, make_kernel_render) + self.template = template + self.choice = choice + + @dataclasses.dataclass class InputsKernel(Buffer): inputs: List[Buffer] @@ -6266,7 +6273,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, packed_w, orig_w, batch_size): + def create(cls, x, packed_w, orig_w, B, batch_size): x = cls.require_stride1(cls.realize_input(x)) orig_w = cls.require_stride1(cls.realize_input(orig_w)) *m, _ = x.get_size() @@ -6274,7 +6281,11 @@ def create(cls, x, packed_w, orig_w, batch_size): output_size = list(m) + [oc] output_stride = make_contiguous_strides_for(output_size) inputs = [x, packed_w, orig_w] - constant_args = 
[None, batch_size] + constant_args = [batch_size] + if B is not None: + inputs += [B] + else: + constant_args.insert(0, None) return MKLPackedLinear( layout=FixedLayout( diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index d48209508c5bc..a90fdbfa33d90 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional import torch +from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate from torch._inductor.virtualized import V from .. import config as inductor_config from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate @@ -17,6 +18,7 @@ ) from ..utils import ( use_aten_gemm_kernels, + use_cpp_packed_gemm_template, use_cutlass_template, use_max_autotune, use_triton_template, @@ -156,6 +158,13 @@ def tuned_mm(mat1, mat2, *, layout=None): if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k): CUTLASSGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) + if use_cpp_packed_gemm_template(layout, mat1, mat2): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [mat1, mat2], + ) + if ( len(choices) == 0 and not use_aten_gemm_kernels() @@ -320,6 +329,15 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): beta=beta, ) + if use_cpp_packed_gemm_template(layout, mat1, mat2): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [inp_expanded, mat1, mat2], + alpha=alpha, + beta=beta, + ) + add_aten_fallback = False if len(choices) == 0: log.warning("No choices for GEMM, using ATen backend as fallback") diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 5a12a5c090bf8..1f64574d589ba 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,10 +1,22 @@ -from typing import List +from typing import List, Optional import torch import torch.utils._pytree as pytree +from torch._inductor.kernel.mm_common import mm_args from . 
import ir +from .codegen.cpp_gemm_template import CppPackedGemmTemplate from .ir import TensorBox -from .lowering import add, add_needs_realized_inputs, aten, register_lowering, to_dtype +from .lowering import ( + add, + add_needs_realized_inputs, + aten, + permute, + register_lowering, + to_dtype, +) +from .select_algorithm import autotune_select_algorithm, ExternKernelChoice +from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune +from .virtualized import V def register_onednn_fusion_ops(): @@ -403,6 +415,12 @@ def qlinear_binary( ) if torch._C.has_mkl: + aten_mkl_linear = ExternKernelChoice( + torch.ops.mkl._mkl_linear, + "mkl::_mkl_linear", + has_out_variant=False, + kernel_creator=ir.MKLPackedLinear.create, + ) cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) @register_lowering(torch.ops.mkl._mkl_linear) @@ -410,11 +428,48 @@ def mkl_packed_linear( x: TensorBox, packed_w: TensorBox, orig_w: TensorBox, - b: TensorBox, + b: Optional[TensorBox], batch_size, + *, + layout=None, ): - result = TensorBox.create( - ir.MKLPackedLinear.create(x, packed_w, orig_w, batch_size) + choices = ( + [ + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ] + if use_aten_gemm_kernels() + else [] + ) + if use_max_autotune(): + transposed_w = permute(orig_w, [1, 0]) + *_, layout, x, transposed_w = mm_args( + x, transposed_w, layout=layout + ) + if use_cpp_packed_gemm_template(layout, x, transposed_w): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, packed_w, orig_w], + trans_w=True, + input_indices=[0, 2], + ) + + assert packed_w.get_name() in V.graph.constants + assert orig_w.get_name() in V.graph.constants + # packed_w is a mkldnn tensor which we can't generate directly + # so we use the weights from the original tensor in autotune. 
+ input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + 2: lambda x: V.graph.constants[x.get_name()], + } + result: TensorBox = autotune_select_algorithm( + "packed_linear", + choices, + [x, packed_w, orig_w], + layout, + input_gen_fns=input_gen_fns, ) if b is not None: result = add(result, b) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index e220544127b60..bb868c241efb2 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -749,17 +749,19 @@ def __init__( has_out_variant=True, op_overload=None, use_fallback_kernel=False, + kernel_creator=None, ): super().__init__() name = name or kernel.__name__ assert callable(kernel) - assert not hasattr(extern_kernels, name), "duplicate extern kernel" + assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" self.name = name self.cpp_kernel_name = cpp_kernel self.has_out_variant = has_out_variant setattr(extern_kernels, name, kernel) self.op_overload = op_overload self.use_fallback_kernel = use_fallback_kernel + self.kernel_creator = kernel_creator def to_callable(self): return getattr(extern_kernels, self.name) @@ -926,6 +928,8 @@ def output_node(self): inner = ir.FallbackKernel.create( self.choice.op_overload, *self.input_nodes, **self.kwargs ) + elif self.choice.kernel_creator is not None: + inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) else: cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc inner = cls( @@ -976,6 +980,86 @@ def append_to_log(filename, data): json.dump(log_data, f, indent=4) +class DataProcessorChoiceCallerWrapper: + def __init__(self, wrapped, preprocessor, postprocessor): + self._wrapped = wrapped + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def benchmark(self, *args, out) -> float: + new_args, new_out = self._preprocessor(args, out) + result = self._wrapped.benchmark(*new_args, out=new_out) + new_out = self._postprocessor(new_out) + if out is not new_out: + out.copy_(new_out) + return result + + def output_node(self) -> ir.TensorBox: + result = self._wrapped.output_node() + return self._postprocessor(result) + + def __repr__(self) -> str: + return f"DataProcessorChoiceCallerWrapper({self._wrapped})" + + +class DataProcessorTemplateWrapper: + """ + A wrapper class for a kernel template. + + This class together with `DataProcessorChoiceCallerWrapper` provides a convenient way to + preprocess and postprocess data before and after using the wrapped template. A typical + usage is to reorder or filter the input nodes in order to match the expected input of other + kernel choices like a ATen kernel. A more complicated usage is to prepack the weights. + See the example from :mod:`cpp_gemm_template` for more details. 
+ """ + + def __init__( + self, + wrapped_template_cls, + preprocessor, + postprocessor, + **kwargs, + ): + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + assert "input_nodes" in kwargs + assert "layout" in kwargs + kwargs["input_nodes"], kwargs["layout"] = preprocessor( + kwargs["input_nodes"], kwargs["layout"] + ) + self._wrapped = wrapped_template_cls(**kwargs) + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def maybe_append_choice(self, choices, **kwargs): + return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) + + def generate(self, **kwargs): + choice_caller = self._wrapped.generate(**kwargs) + return DataProcessorChoiceCallerWrapper( + choice_caller, self._preprocessor, self._postprocessor + ) + + def __repr__(self) -> str: + return f"DataProcessorTemplateWrapper({self._wrapped})" + + class ErrorFromChoice(RuntimeError): def __init__(self, msg, choice: ChoiceCaller, inputs_str): msg += f"\nFrom choice {choice}\n{inputs_str}" @@ -1268,7 +1352,9 @@ def get_inputs(): } example_inputs = list(unique_example_inputs.values()) example_inputs_extern = [ - torch.as_strided( + unique_example_inputs[input_node.get_name()] + if unique_example_inputs[input_node.get_name()].is_mkldnn + else torch.as_strided( unique_example_inputs[input_node.get_name()], V.graph.sizevars.size_hints( input_node.get_size(), diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index c25623fdb6133..5ae285817941d 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1002,6 +1002,42 @@ def use_cutlass_template(layout, m, n, k): return res +def _use_template_for_cpu(layout): + return use_max_autotune() and layout.device.type == "cpu" + + +def use_cpp_packed_gemm_template(layout, mat1, mat2): + from . import ir + from .codegen.cpp_micro_gemm import create_micro_gemm + from .kernel.mm_common import mm_args + + if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): + return False + + if not config.cpp.weight_prepack: + return False + + layout_dtypes = [torch.float32] + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2) + # TODO(jgong5): support dynamic shapes for n or k + if has_free_symbols((n, k)): + return False + if isinstance(mat2, ir.BaseView): + mat2 = mat2.unwrap_view() + micro_gemm = create_micro_gemm( + "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() + ) + # TODO(jgong5): support n % n_block_size != 0 + return ( + layout.dtype in layout_dtypes + and micro_gemm is not None + and n % micro_gemm.register_blocking[1] == 0 + and mat1.get_stride()[-1] == 1 # TODO(jgong5): support transposed input + and isinstance(mat2, ir.StorageBox) + and mat2.is_module_buffer() + ) + + def use_aten_gemm_kernels(): return not use_max_autotune() or _use_autotune_backend("ATEN")
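For reference, the thread-blocking heuristic that appears twice in this patch (in Python as CppPackedGemmTemplate.thread_blocking() and in C++ as mm_get_thread_blocking() in cpp_prefix.h) reads as the standalone sketch below. It is illustrative only; the names and the worked example are not part of the patch.

def get_factors(number):
    # Emit factor pairs starting from the most even split (larger cofactor first).
    factors = []
    for i in range(int(number**0.5), 0, -1):
        if number % i == 0:
            factors.append(number // i)
            factors.append(i)
    return factors

def thread_blocking(num_threads, m_blocks, n_blocks, k_blocks):
    def blocking(factor):
        nt = (n_blocks + factor - 1) // factor      # N blocks per thread column
        cofactor = num_threads // factor
        mt = (m_blocks + cofactor - 1) // cofactor  # M blocks per thread row
        return (mt, nt, k_blocks)                   # K is not split across threads yet

    factors = get_factors(num_threads)
    # Prefer a factor that splits both the N and M block counts evenly ...
    for factor in factors:
        if n_blocks % factor == 0 and m_blocks % (num_threads // factor) == 0:
            return blocking(factor)
    # ... otherwise settle for splitting either dimension evenly.
    for factor in factors:
        if n_blocks % factor == 0:
            return blocking(factor)
        if m_blocks % (num_threads // factor) == 0:
            return blocking(factor)
    raise AssertionError("Should not reach here.")

# Example: 4 threads, 8 M-blocks, 4 N-blocks, 256 K-blocks
# -> (Mt, Nt, Kt) = (4, 2, 256), i.e. a 2x2 thread grid over M and N.
print(thread_blocking(4, 8, 4, 256))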