From 00eb31a67a3a18ed9612932ad3b566e226ea13c3 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sun, 14 Apr 2024 07:21:04 -0700 Subject: [PATCH 01/28] Update [ghstack-poisoned] --- torch/_inductor/autotune_process.py | 65 +++++++- torch/_inductor/codecache.py | 3 - torch/_inductor/codegen/cpp.py | 49 +++++- torch/_inductor/codegen/cpp_gemm_template.py | 71 ++++++++ torch/_inductor/codegen/cpp_template.py | 121 ++++++++++++++ .../_inductor/codegen/cpp_template_kernel.py | 151 ++++++++++++++++++ torch/_inductor/ir.py | 15 +- torch/_inductor/mkldnn_lowerings.py | 51 +++++- torch/_inductor/select_algorithm.py | 49 ++++-- 9 files changed, 542 insertions(+), 33 deletions(-) create mode 100644 torch/_inductor/codegen/cpp_gemm_template.py create mode 100644 torch/_inductor/codegen/cpp_template.py create mode 100644 torch/_inductor/codegen/cpp_template_kernel.py diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 790ec9d60ec0f..baf45d67e4557 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -9,9 +9,10 @@ import time import warnings from concurrent.futures import ThreadPoolExecutor -from ctypes import byref, c_size_t, c_void_p +from ctypes import byref, c_size_t, c_void_p, CDLL from multiprocessing.process import BaseProcess from multiprocessing.queues import Queue +from types import ModuleType from typing import ( Any, Callable, @@ -29,7 +30,13 @@ from torch._dynamo.testing import rand_strided from torch._inductor import ir -from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache +from torch._inductor.codecache import ( + CppCodeCache, + CUDACodeCache, + DLLWrapper, + get_hash, + PyCodeCache, +) if TYPE_CHECKING: from torch._inductor.select_algorithm import TritonTemplateCaller @@ -661,6 +668,60 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" +@dataclasses.dataclass +class CppBenchmarkRequest(BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, List[TensorMeta]], + output_tensor_meta: Union[TensorMeta, List[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ): + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.hash_key = get_hash(source_code) + self.DLL: Optional[Union[CDLL, ModuleType]] = None + + def precompile(self): + # Prepopulate CppCodeCache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + CppCodeCache.load(self.source_code, cuda=False) + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor + ) -> Callable[[], None]: + self.DLL = CppCodeCache.load(self.source_code, cuda=False) + args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]] + log.debug( + "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.DLL, + args, + self.extra_args, + ) + run_method = getattr(self.DLL, self.kernel_name) + + # Generate partial function. 
+ return functools.partial( + run_method, + *args, + *self.extra_args, + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + + def __str__(self) -> str: + return f"{self.kernel_name=}" + + def benchmark_in_sub_process( choices: List[TritonTemplateCaller], ) -> Dict[TritonTemplateCaller, float]: diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index d55b81e16d157..729d7bb8713c3 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -183,9 +183,6 @@ def get_global_cache_path() -> Optional[Path]: ) def __init__(self) -> None: - if not torch.cuda.is_available(): - return - self.system = CacheBase.get_system() self.local_cache_path = CacheBase.get_local_cache_path() diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 31dbe27c229ec..8f485bc3a5bf0 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -8,7 +8,7 @@ import sys from copy import copy, deepcopy from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union import sympy @@ -19,6 +19,7 @@ from torch.utils import _pytree as pytree from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges +from ..._dynamo.utils import counters from .. import codecache, config, ir, metrics from ..codegen.wrapper import WrapperCodeGen @@ -3729,6 +3730,8 @@ def _can_fuse_horizontal_impl(self, node1, node2): return self._why_fuse_nodes(node1, node2) is not None def can_fuse_horizontal(self, node1, node2): + if node1.is_template() or node2.is_template(): + return False if ( len(node1.get_nodes()) + len(node2.get_nodes()) > config.cpp.max_horizontal_fusion_size @@ -3809,6 +3812,9 @@ def get_fusion_pair_priority(self, node1, node2): return 0 def can_fuse_vertical(self, node1, node2): + # TODO: support vertical fusion for template nodes + if node1.is_template() or node2.is_template(): + return False return ( self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) @@ -3865,6 +3871,42 @@ def codegen_node( if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: self._set_flush_status(True) + def is_cpp_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ir.CppTemplateBuffer + ) + + def codegen_template( + self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode] + ): + """ + Codegen a CPP template, possibly with fused epilogues + """ + counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cpp_template( + template_node + ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" + template_node = cast(SchedulerNode, template_node) + _, (_, rnumel) = template_node.group + assert rnumel == () + ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) + epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes] + assert all( + isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes + ), "Epilogue nodes must all be instances of ir.ComputedBuffer" + kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = 
[template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.scheduler.free_buffers() + def _get_scheduled_num_args(self): return self.kernel_group.get_num_args() @@ -3874,7 +3916,7 @@ def ready_to_flush(self): def codegen_sync(self): pass - def define_kernel(self, src_code, nodes): + def define_kernel(self, src_code, nodes, kernel_args=None): wrapper = V.graph.wrapper_code fused_name = ( get_fused_kernel_name(nodes, config.cpp.descriptive_names) @@ -3890,7 +3932,8 @@ def define_kernel(self, src_code, nodes): src_code = src_code.replace("#pragma CMT", "//") compile_wrapper = IndentedBuffer() - _, _, arg_types = self.kernel_group.args.cpp_argdefs() + args = self.kernel_group.args if kernel_args is None else kernel_args + _, _, arg_types = args.cpp_argdefs() if not V.graph.cpp_wrapper: compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") compile_wrapper.splice(src_code, strip=True) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py new file mode 100644 index 0000000000000..b57b4f171e819 --- /dev/null +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -0,0 +1,71 @@ +from typing import cast, List, Optional + +from ..ir import Buffer, CppTemplateBuffer, IRNode, Layout +from .cpp_template import CppTemplate + +from .cpp_template_kernel import CppTemplateKernel + +GEMM_TEMPLATE = r""" +{{template.header().getvalue()}} +// TODO: use micro-kernel to replace this naive GEMM implementation below +// TODO: support weight prepack +extern "C" +{{kernel.def_kernel(inputs=[X, W], outputs=[Y], names_str="X, W, Y", input_reorder=input_reorder)}} +{ + // TODO: support dynamic shapes + int64_t M = {{kernel.size(Y, 0)}}; + int64_t N = {{kernel.size(Y, 1)}}; + int64_t K = {{kernel.size(X, 1)}}; + + #pragma omp parallel for collapse(2) + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + {{kernel.acc_dtype(Y)}} sum = 0; + for (int64_t k = 0; k < K; ++k) { + sum += {{kernel.index(X, ["i", "k"])}} * {{kernel.index(W, ["k", "j"])}}; + } + {{kernel.index(Y, ["i", "j"])}} = sum; + } + } +} +""" + + +class CppGemmTemplate(CppTemplate): + def __init__( + self, + input_nodes, + layout: Layout, + input_reorder: Optional[List[int]] = None, + ): + super().__init__("cpp_gemm", input_nodes, layout, input_reorder) + + def render( # type: ignore[override] + self, + kernel: CppTemplateKernel, + template_buffer_node: Optional[CppTemplateBuffer] = None, + epilogue_nodes: Optional[List[IRNode]] = None, + **kwargs, + ) -> str: + assert not epilogue_nodes, "Epilogue nodes are not supported for GEMM template." 
+ assert len(self.input_nodes) >= 2 + + if template_buffer_node is not None: + self.output_node = template_buffer_node + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + self.output_node = cast(Buffer, epilogue_nodes[-1]) + assert self.output_node is not None + + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + + options = dict( + X=X, + W=W, + Y=Y, + template=self, + kernel=kernel, + epilogues=epilogue_nodes, + input_reorder=self.input_reorder, + ) + return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py new file mode 100644 index 0000000000000..4eda7771062d0 --- /dev/null +++ b/torch/_inductor/codegen/cpp_template.py @@ -0,0 +1,121 @@ +import functools +import itertools +import logging +from typing import List, Optional +from unittest.mock import patch + +import sympy + +from .. import codecache +from ..autotune_process import CppBenchmarkRequest, TensorMeta +from ..ir import Buffer, IRNode, Layout +from ..utils import IndentedBuffer, Placeholder, unique +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateBuffer, CppTemplateCaller, CppTemplateKernel + +log = logging.getLogger(__name__) + + +class CppTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes, + layout: Layout, + input_reorder: Optional[List[int]] = None, + ): + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer("buf_out", layout) + self.input_reorder = input_reorder + self.layout = layout + + def generate(self, **kwargs): + kernel_name = f"cpp_{self.name}" + with patch.object( + V.graph, "get_dtype", self._fake_get_dtype(self.output_node) + ), CppTemplateKernel( + kernel_name=kernel_name, + ) as kernel: + code = self.render(kernel=kernel, **kwargs) + _, call_args, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + + kernel_hash_name = f"cpp_{self.name}_{next(self.index_counter)}" + + # Create the BenchmarkRequest for CPP + bmreq = CppBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: CppTemplateBuffer, + epilogue_nodes: Optional[List[IRNode]] = None, + ): + kernel = CppTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + return kernel, render + + return CppTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + ) + + def header(self) -> IndentedBuffer: + res = 
IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + res.writeline(codecache.cpp_prefix()) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py new file mode 100644 index 0000000000000..0cd973187361c --- /dev/null +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -0,0 +1,151 @@ +from typing import Callable, Dict, List, Optional, Union + +import sympy + +import torch + +from torch._inductor.autotune_process import CppBenchmarkRequest +from ..ir import ( + Buffer, + ChoiceCaller, + CppTemplateBuffer, + IRNode, + Layout, + PrimitiveInfoType, + TensorBox, +) +from ..virtualized import V +from .common import Kernel, OpOverrides +from .cpp import cexpr_index + + +class CppTemplateKernel(Kernel): + overrides = OpOverrides + + def __init__(self, kernel_name): + super().__init__() + self.kernel_name = kernel_name + + def def_kernel( + self, + inputs: List[Buffer], + outputs: List[Buffer], + names_str: str = "", + input_reorder: Optional[List[int]] = None, + ) -> str: + input_names = [inp.get_name() for inp in inputs] + output_names = [out.get_name() for out in outputs] + all_names = input_names + output_names + assert len(all_names) == len(names_str.split(",")), ( + all_names, + names_str, + ) + names = names_str.split(",") + for i, input_name in enumerate(input_names): + self.args.input_buffers[input_name] = names[i].strip() + for i, output_name in enumerate(output_names): + self.args.output_buffers[output_name] = names[i + len(input_names)].strip() + cpp_argdefs, _, _ = self.args.cpp_argdefs() + return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" + + def call_kernel(self, name: str, node: CppTemplateBuffer): + wrapper = V.graph.wrapper_code + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) + + def dtype(self, node: Buffer) -> str: + if node.get_dtype() == torch.float32: + return "float" + elif node.get_dtype() == torch.bfloat16: + return "float" + elif node.get_dtype() == torch.half: + return "float" + else: + raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + + def acc_dtype(self, node: Buffer) -> str: + if node.get_dtype() == torch.float32: + return "float" + elif node.get_dtype() == torch.bfloat16: + return "float" + elif node.get_dtype() == torch.half: + return "float" + else: + raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + + def size(self, node: Buffer, dim: int) -> str: + return str(node.get_size()[dim]) + + def index(self, node: Buffer, indices: List[str]) -> str: + indexer = node.make_indexer() + index = indexer([sympy.Symbol(idx) for idx in indices]) + return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" + + +class CppTemplateCaller(ChoiceCaller): + """ + CppTemplateCaller + + This class represents a caller for CPP template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CppBenchmarkRequest): The benchmark request for the caller. + template_buffer (CppTemplateBuffer): The template buffer for the caller. 
+ """ + + def __init__( + self, + name: str, + category: str, + input_nodes: List[Buffer], + layout: Layout, + make_kernel_render: Callable[[CppTemplateBuffer, Optional[List[IRNode]]], str], + bmreq: CppBenchmarkRequest, + template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 + info_kwargs: Optional[ + Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] + ] = None, + ): + super().__init__(name, input_nodes, layout) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + return self.bmreq.benchmark(*args, output_tensor=out) + + # def __str__(self): + # return f"CppTemplateCaller(source_file={self.bmreq.source_file})" + + # def call_name(self) -> str: + # return f"cpp_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + return {"backend": "CPP", "op_type": "unknown"} + + def output_node(self) -> TensorBox: + return TensorBox.create( + CppTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + template=self.template, + choice=self, + ) + ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0fd35193a877c..1f64f5d549fde 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3622,6 +3622,13 @@ def get_workspace_size(self): return self.workspace_size if self.workspace_size is not None else 0 +class CppTemplateBuffer(TemplateBuffer): + def __init__(self, layout, inputs, make_kernel_render, template, choice): + super().__init__(layout, inputs, make_kernel_render) + self.template = template + self.choice = choice + + @dataclasses.dataclass class InputsKernel(Buffer): inputs: List[Buffer] @@ -6015,7 +6022,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, packed_w, orig_w, batch_size): + def create(cls, x, packed_w, orig_w, bias, batch_size): x = cls.require_stride1(cls.realize_input(x)) orig_w = cls.require_stride1(cls.realize_input(orig_w)) *m, _ = x.get_size() @@ -6023,7 +6030,11 @@ def create(cls, x, packed_w, orig_w, batch_size): output_size = list(m) + [oc] output_stride = make_contiguous_strides_for(output_size) inputs = [x, packed_w, orig_w] - constant_args = [None, batch_size] + constant_args = [batch_size] + if bias is not None: + inputs += [bias] + else: + constant_args.insert(0, None) return MKLPackedLinear( layout=FixedLayout( diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 0ebccbf27ea3b..e39947e64c3b2 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,10 +1,20 @@ -from typing import List +from typing import List, Optional import torch import torch.utils._pytree as pytree -from . import ir +from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate +from torch._inductor.kernel.mm_common import mm_args +from . 
import config, ir from .ir import TensorBox -from .lowering import add, add_needs_realized_inputs, aten, register_lowering, to_dtype +from .lowering import ( + add, + add_needs_realized_inputs, + aten, + permute, + register_lowering, + to_dtype, +) +from .select_algorithm import autotune_select_algorithm, ExternKernelChoice def register_onednn_fusion_ops(): @@ -339,6 +349,12 @@ def qlinear_unary( ) if torch._C.has_mkl: + aten_mkl_linear = ExternKernelChoice( + torch.ops.mkl._mkl_linear, + "mkl::_mkl_linear", + has_out_variant=False, + kernel_creator=ir.MKLPackedLinear.create, + ) cpu_needs_realized_inputs.append(torch.ops.mkl._mkl_linear) @register_lowering(torch.ops.mkl._mkl_linear) @@ -346,12 +362,35 @@ def mkl_packed_linear( x: TensorBox, packed_w: TensorBox, orig_w: TensorBox, - b: TensorBox, + b: Optional[TensorBox], batch_size, + *, + layout=None, ): - result = TensorBox.create( - ir.MKLPackedLinear.create(x, packed_w, orig_w, batch_size) + choices = [ + # aten_mkl_linear.bind( + # (x, packed_w, orig_w, None, batch_size), layout + # ) + ] + if config.max_autotune or config.max_autotune_gemm: + transposed_w = permute( + orig_w, [*range(len(orig_w.get_size()) - 2), -1, -2] + ) + _, _, _, layout, x, transposed_w = mm_args( + x, transposed_w, layout=layout + ) + # TODO: match inputs of mkl_linear + template = CppGemmTemplate([x, transposed_w], layout) + template.maybe_append_choice(choices) + + # TODO: add input gen fn + chosen_node: TensorBox = autotune_select_algorithm( + "packed_linear", + choices, + [x, orig_w], + layout, ) + result = TensorBox.create(chosen_node) if b is not None: result = add(result, b) return result diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 03a8e63141ff9..64bb67075848f 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -662,17 +662,19 @@ def __init__( has_out_variant=True, op_overload=None, use_fallback_kernel=False, + kernel_creator=None, ): super().__init__() name = name or kernel.__name__ assert callable(kernel) - assert not hasattr(extern_kernels, name), "duplicate extern kernel" + assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" self.name = name self.cpp_kernel_name = cpp_kernel self.has_out_variant = has_out_variant setattr(extern_kernels, name, kernel) self.op_overload = op_overload self.use_fallback_kernel = use_fallback_kernel + self.kernel_creator = kernel_creator def to_callable(self): return getattr(extern_kernels, self.name) @@ -831,6 +833,8 @@ def output_node(self): inner = ir.FallbackKernel.create( self.choice.op_overload, *self.input_nodes, **self.kwargs ) + elif self.choice.kernel_creator is not None: + inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) else: cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc inner = cls( @@ -1033,8 +1037,11 @@ def make_benchmark_fn( unique_example_inputs = { x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x) for i, x in enumerate(input_nodes) + if isinstance(x, ir.IRNode) } - example_inputs = list(unique_example_inputs.values()) + example_inputs = list(unique_example_inputs.values()) + [ + x for x in input_nodes if not isinstance(x, ir.IRNode) + ] example_inputs_extern = [ torch.as_strided( unique_example_inputs[input_node.get_name()], @@ -1051,6 +1058,8 @@ def make_benchmark_fn( fallback=config.unbacked_symint_fallback, ), ) + if isinstance(input_node, ir.IRNode) + else input_node for input_node in input_nodes ] @@ -1076,7 +1085,9 @@ 
def tensor_repr(x): "inputs = [", ] for x in example_inputs: - lines.append(f" {tensor_repr(x)},") + lines.append( + f" {tensor_repr(x) if isinstance(x, torch.Tensor) else x}," + ) lines += ["]", f"out = {tensor_repr(out)}", ""] return "\n".join(lines) @@ -1226,20 +1237,24 @@ def key_of(node): """ sizevars = V.graph.sizevars return ( - node.get_device().type, - str(node.get_dtype()), - *sizevars.size_hints( - node.get_size(), - fallback=config.unbacked_symint_fallback, - ), - *sizevars.size_hints( - node.get_stride(), - fallback=config.unbacked_symint_fallback, - ), - sizevars.size_hint( - node.get_layout().offset, - fallback=config.unbacked_symint_fallback, - ), + ( + node.get_device().type, + str(node.get_dtype()), + *sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + *sizevars.size_hints( + node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + if isinstance(node, ir.IRNode) + else str(node) ) From b6ff5fe6b06b80efea845f8474195fcd2eaf6e97 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Tue, 16 Apr 2024 06:36:50 -0700 Subject: [PATCH 02/28] Update [ghstack-poisoned] --- test/inductor/test_mkldnn_pattern_matcher.py | 2 +- torch/_inductor/autotune_process.py | 70 +++++++---- torch/_inductor/codegen/cpp_gemm_template.py | 28 +++-- torch/_inductor/codegen/cpp_template.py | 9 +- torch/_inductor/ir.py | 7 +- torch/_inductor/mkldnn_lowerings.py | 118 ++++++++++++++++--- torch/_inductor/select_algorithm.py | 57 +++++---- 7 files changed, 212 insertions(+), 79 deletions(-) diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index e44f103571dab..51070fffa5534 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -318,7 +318,7 @@ def test_linear_fp32(self): class M(torch.nn.Module): def __init__(self, bias): super().__init__() - self.linear = torch.nn.Linear(10, 30, bias) + self.linear = torch.nn.Linear(10, 32, bias) def forward(self, x): return self.linear(x) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index baf45d67e4557..e7e66062abef9 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -42,7 +42,7 @@ from torch._inductor.select_algorithm import TritonTemplateCaller from . 
import config -from .utils import do_bench +from .utils import do_bench, timed from .virtualized import V CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" @@ -434,6 +434,14 @@ def make_run_fn( def cleanup_run_fn(self) -> None: pass + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + raise NotImplementedError() + def benchmark( self, *input_tensors: torch.Tensor, @@ -459,22 +467,7 @@ def benchmark( load_elapse = time.time() - start_ts # type: ignore[possibly-undefined] start_ts = time.time() - device_idx_set = { - tensor.device.index - for tensor in [*input_tensors, output_tensor] - if isinstance(tensor, torch.Tensor) - and tensor.is_cuda - and tensor.device.index is not None - } - assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" - if len(device_idx_set) == 1: - device_idx = next(iter(device_idx_set)) - else: - device_idx = torch.cuda.current_device() - - with torch.cuda.device(device_idx): - out = do_bench(fn) - torch.cuda.synchronize() # shake out any CUDA errors + out = self.do_bench(fn, *input_tensors, output_tensor) if debug: bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined] @@ -506,7 +499,34 @@ def benchmark( return self.value -class TritonBenchmarkRequest(BenchmarkRequest): +class GPUDeviceBenchmarkRequest(BenchmarkRequest): + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + device_idx_set = { + tensor.device.index + for tensor in [*input_tensors, output_tensor] + if isinstance(tensor, torch.Tensor) + and tensor.is_cuda + and tensor.device.index is not None + } + assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" + if len(device_idx_set) == 1: + device_idx = next(iter(device_idx_set)) + else: + device_idx = torch.cuda.current_device() + + with torch.cuda.device(device_idx): + out = do_bench(fn) + torch.cuda.synchronize() # shake out any CUDA errors + + return out + + +class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! @@ -580,7 +600,7 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" -class CUDABenchmarkRequest(BenchmarkRequest): +class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! @@ -668,8 +688,18 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" +class CPUDeviceBenchmarkRequest(BenchmarkRequest): + def do_bench( + self, + fn, + *input_tensors: torch.Tensor, + output_tensor: Optional[torch.Tensor] = None, + ) -> float: + return timed(fn, ()) + + @dataclasses.dataclass -class CppBenchmarkRequest(BenchmarkRequest): +class CppBenchmarkRequest(CPUDeviceBenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put Tensors in here! 
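
The refactor in the hunks above splits benchmarking into GPU and CPU request classes that each override a do_bench hook: the GPU path pins a CUDA device and uses triton's do_bench, while the CPU path simply times the prepared run function. A minimal standalone sketch of that dispatch follows; it is illustrative only — timed_cpu and CPUDeviceBenchmarkRequestSketch are placeholder names, and timed_cpu is a simplified stand-in for the timed utility imported from torch/_inductor/utils.py, not its actual implementation.

import time

def timed_cpu(fn, times=5):
    # Simplified wall-clock timer: call fn repeatedly and return the best
    # elapsed time in seconds (stand-in for the real timed() helper).
    best = float("inf")
    for _ in range(times):
        t0 = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - t0)
    return best

class CPUDeviceBenchmarkRequestSketch:
    # Mirrors the shape of the hook above: CPU requests time the run
    # function directly; GPU requests guard a CUDA device and use do_bench.
    def do_bench(self, fn, *input_tensors, output_tensor=None):
        return timed_cpu(fn)
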
diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index b57b4f171e819..04c4dbcf0e7a6 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -8,9 +8,8 @@ GEMM_TEMPLATE = r""" {{template.header().getvalue()}} // TODO: use micro-kernel to replace this naive GEMM implementation below -// TODO: support weight prepack extern "C" -{{kernel.def_kernel(inputs=[X, W], outputs=[Y], names_str="X, W, Y", input_reorder=input_reorder)}} +{{kernel.def_kernel(inputs=[X, W], outputs=[Y], names_str="X, W, Y")}} { // TODO: support dynamic shapes int64_t M = {{kernel.size(Y, 0)}}; @@ -19,26 +18,35 @@ #pragma omp parallel for collapse(2) for (int64_t i = 0; i < M; ++i) { - for (int64_t j = 0; j < N; ++j) { - {{kernel.acc_dtype(Y)}} sum = 0; + for (int64_t j = 0; j < N/{{n_bs}}; ++j) { + {{kernel.acc_dtype(Y)}} sum[16]; + for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { + sum[ni] = 0; + } for (int64_t k = 0; k < K; ++k) { - sum += {{kernel.index(X, ["i", "k"])}} * {{kernel.index(W, ["k", "j"])}}; + for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { + sum[ni] += {{kernel.index(X, ["i", "k"])}} * {{kernel.index(W, ["j", "k", "ni"])}}; + } + } + for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { + int64_t n = j * {{n_bs}} + ni; + {{kernel.index(Y, ["i", "n"])}} = sum[ni]; } - {{kernel.index(Y, ["i", "j"])}} = sum; } } } """ -class CppGemmTemplate(CppTemplate): +class CppPackedGemmTemplate(CppTemplate): def __init__( self, input_nodes, layout: Layout, - input_reorder: Optional[List[int]] = None, + n_block_size: int = 1, ): - super().__init__("cpp_gemm", input_nodes, layout, input_reorder) + super().__init__("cpp_gemm", input_nodes, layout) + self.n_block_size = n_block_size def render( # type: ignore[override] self, @@ -63,9 +71,9 @@ def render( # type: ignore[override] X=X, W=W, Y=Y, + n_bs=self.n_block_size, template=self, kernel=kernel, epilogues=epilogue_nodes, - input_reorder=self.input_reorder, ) return self._template_from_string(GEMM_TEMPLATE).render(**options) diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index 4eda7771062d0..99ec7305f8324 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -25,12 +25,10 @@ def __init__( name: str, input_nodes, layout: Layout, - input_reorder: Optional[List[int]] = None, ): super().__init__(name) self.input_nodes = input_nodes self.output_node: Buffer = Buffer("buf_out", layout) - self.input_reorder = input_reorder self.layout = layout def generate(self, **kwargs): @@ -49,13 +47,8 @@ def generate(self, **kwargs): kernel.args.python_argdefs(), ) - input_reorder = ( - self.input_reorder - if self.input_reorder is not None - else list(range(len(self.input_nodes))) - ) expected_args = list( - unique(self.input_nodes[idx].get_name() for idx in input_reorder) + unique(input_node.get_name() for input_node in self.input_nodes) ) expected_args.extend([self.output_node.get_name()]) assert list(call_args)[: len(expected_args)] == expected_args, ( diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 1f64f5d549fde..43cb10d2bc530 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -68,6 +68,7 @@ developer_warning, do_bench, get_kernel_metadata, + is_cpu_device, is_dynamic, is_gpu, pad_listlike, @@ -75,6 +76,7 @@ sympy_index_symbol, sympy_product, sympy_subs, + timed, ) from .virtualized import ops, V @@ -3531,7 +3533,10 @@ def __init__(self, name, input_nodes, layout): def 
benchmark(self, *args, out) -> float: algo = self.to_callable() - return do_bench(lambda: algo(*args, out=out)) + if is_cpu_device(args): + return timed(lambda: algo(*args, out=out), ()) + else: + return do_bench(lambda: algo(*args, out=out)) def call_name(self) -> str: raise NotImplementedError() diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index e39947e64c3b2..643c8f0c3d390 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -2,8 +2,7 @@ import torch import torch.utils._pytree as pytree -from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate -from torch._inductor.kernel.mm_common import mm_args +from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate from . import config, ir from .ir import TensorBox from .lowering import ( @@ -13,8 +12,14 @@ permute, register_lowering, to_dtype, + view, ) -from .select_algorithm import autotune_select_algorithm, ExternKernelChoice +from .select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, +) +from .virtualized import V def register_onednn_fusion_ops(): @@ -368,27 +373,106 @@ def mkl_packed_linear( layout=None, ): choices = [ - # aten_mkl_linear.bind( - # (x, packed_w, orig_w, None, batch_size), layout - # ) + aten_mkl_linear.bind( + (x, packed_w, orig_w, None, batch_size), layout + ) ] if config.max_autotune or config.max_autotune_gemm: - transposed_w = permute( - orig_w, [*range(len(orig_w.get_size()) - 2), -1, -2] - ) - _, _, _, layout, x, transposed_w = mm_args( - x, transposed_w, layout=layout - ) - # TODO: match inputs of mkl_linear - template = CppGemmTemplate([x, transposed_w], layout) - template.maybe_append_choice(choices) - # TODO: add input gen fn + class DataPreprocessorChoiceCallerWrapper: + def __init__(self, wrapped, processor): + self._wrapped = wrapped + self._processor = processor + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def benchmark(self, *args, out) -> float: + new_args, new_out = self._processor(args, out) + return self._wrapped.benchmark(*new_args, out=new_out) + + class DataPreprocessorTemplateWrapper: + def __init__(self, wrapped_template_cls, processor, **kwargs): + self._processor = processor + assert "input_nodes" in kwargs + assert "layout" in kwargs + kwargs["input_nodes"], kwargs["layout"] = processor( + kwargs["input_nodes"], kwargs["layout"] + ) + self._wrapped = wrapped_template_cls(**kwargs) + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def maybe_append_choice(self, choices, **kwargs): + return type(self._wrapped).maybe_append_choice( + self, choices, **kwargs + ) + + def generate(self, **kwargs): + choice_caller = self._wrapped.generate(**kwargs) + return DataPreprocessorChoiceCallerWrapper( + choice_caller, self._processor + ) + + *m, _ = x.get_size() + n, k = orig_w.get_size() + # TODO: decide block size per ISA + # TODO: use larger block size for larger batch sizes + # TODO: support n % n_block_size != 0 + n_block_size = 16 + if n % n_block_size == 0: + + def preprocessor(inputs, layout_or_out): + x = inputs[0] + w = inputs[2] + if isinstance(w, ir.IRNode): + blocked_w = permute( + view(w, (n / n_block_size, n_block_size, k)), + [0, 2, 1], + ) + blocked_w = ir.ExternKernel.require_contiguous( + blocked_w + ) + x, blocked_w = realize_inputs(x, blocked_w) + if layout_or_out is None: + layout_or_out = ir.FixedLayout( + x.get_device(), + x.get_dtype(), + [*m, n], + ) + else: + blocked_w = ( + 
w.reshape(n / n_block_size, n_block_size, k) + .transpose(1, 2) + .contiguous() + ) + return [x, blocked_w], layout_or_out + + template = DataPreprocessorTemplateWrapper( + CppPackedGemmTemplate, + preprocessor, + input_nodes=[x, packed_w, orig_w, None, batch_size], + layout=layout, + n_block_size=n_block_size, + ) + layout = template.layout + template.maybe_append_choice(choices) + + assert isinstance(packed_w.data, ir.StorageBox) + assert isinstance(packed_w.data.data, ir.ConstantBuffer) + assert isinstance(orig_w.data, ir.StorageBox) + assert isinstance(orig_w.data.data, ir.ConstantBuffer) + input_gen_fns = { + 1: lambda x: V.graph.constants[x.get_name()], + 2: lambda x: V.graph.constants[x.get_name()], + } chosen_node: TensorBox = autotune_select_algorithm( "packed_linear", choices, - [x, orig_w], + [x, packed_w, orig_w, None, batch_size], layout, + input_gen_fns=input_gen_fns, ) result = TensorBox.create(chosen_node) if b is not None: diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 64bb67075848f..fc230bf247b9a 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -38,9 +38,11 @@ from .utils import ( do_bench, get_dtype_size, + is_cpu_device, Placeholder, sympy_dot, sympy_product, + timed, unique, ) from .virtualized import V @@ -804,7 +806,10 @@ def benchmark(self, *args, out): out_new, tuple(out.size()), tuple(out.stride()) ) out.copy_(out_new) # for correctness checking - return do_bench(lambda: algo(*args)) + if is_cpu_device(args): + return timed(lambda: algo(*args), ()) + else: + return do_bench(lambda: algo(*args)) def to_callable(self): fn = self.choice.to_callable() @@ -1042,26 +1047,32 @@ def make_benchmark_fn( example_inputs = list(unique_example_inputs.values()) + [ x for x in input_nodes if not isinstance(x, ir.IRNode) ] - example_inputs_extern = [ - torch.as_strided( - unique_example_inputs[input_node.get_name()], - V.graph.sizevars.size_hints( - input_node.get_size(), - fallback=config.unbacked_symint_fallback, - ), - V.graph.sizevars.size_hints( - input_node.get_stride(), - fallback=config.unbacked_symint_fallback, - ), - V.graph.sizevars.size_hint( - input_node.get_layout().offset, - fallback=config.unbacked_symint_fallback, - ), - ) - if isinstance(input_node, ir.IRNode) - else input_node - for input_node in input_nodes - ] + example_inputs_extern = [] + for input_node in input_nodes: + if isinstance(input_node, ir.IRNode): + example_input = unique_example_inputs[input_node.get_name()] + if example_input.is_mkldnn: + example_inputs_extern.append(example_input) + else: + example_inputs_extern.append( + torch.as_strided( + example_input, + V.graph.sizevars.size_hints( + input_node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + input_node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + ) + else: + example_inputs_extern.append(input_node) out = cls.benchmark_example_value(layout) out_extern = torch.as_strided( @@ -1101,7 +1112,8 @@ def benchmark_choice_in_current_process(choice): result = choice.benchmark(*example_inputs, out=out) if VERIFY: torch.testing.assert_close(out_extern, expected, **VERIFY) - torch.cuda.synchronize() # shake out any CUDA errors + if torch.cuda.is_available(): + torch.cuda.synchronize() # shake out any CUDA errors return result def benchmark_in_current_process(choices): @@ -1175,6 
+1187,7 @@ def log_results( ) ) for n in input_nodes + if isinstance(n, ir.IRNode) ] ) n = None if log.getEffectiveLevel() == logging.DEBUG else 10 From 0355c46878a72a9d4f8960d3d6f09a55fcedfcfe Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Tue, 16 Apr 2024 08:10:24 -0700 Subject: [PATCH 03/28] Update [ghstack-poisoned] --- torch/_inductor/mkldnn_lowerings.py | 56 ++++++++----------------- torch/_inductor/select_algorithm.py | 64 +++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 40 deletions(-) diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 643c8f0c3d390..1d26596f8f0ac 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -2,8 +2,8 @@ import torch import torch.utils._pytree as pytree -from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate from . import config, ir +from .codegen.cpp_gemm_template import CppPackedGemmTemplate from .ir import TensorBox from .lowering import ( add, @@ -16,9 +16,11 @@ ) from .select_algorithm import ( autotune_select_algorithm, + DataProcessorTemplateWrapper, ExternKernelChoice, realize_inputs, ) +from .utils import sympy_product from .virtualized import V @@ -378,43 +380,6 @@ def mkl_packed_linear( ) ] if config.max_autotune or config.max_autotune_gemm: - - class DataPreprocessorChoiceCallerWrapper: - def __init__(self, wrapped, processor): - self._wrapped = wrapped - self._processor = processor - - def __getattr__(self, name): - return getattr(self._wrapped, name) - - def benchmark(self, *args, out) -> float: - new_args, new_out = self._processor(args, out) - return self._wrapped.benchmark(*new_args, out=new_out) - - class DataPreprocessorTemplateWrapper: - def __init__(self, wrapped_template_cls, processor, **kwargs): - self._processor = processor - assert "input_nodes" in kwargs - assert "layout" in kwargs - kwargs["input_nodes"], kwargs["layout"] = processor( - kwargs["input_nodes"], kwargs["layout"] - ) - self._wrapped = wrapped_template_cls(**kwargs) - - def __getattr__(self, name): - return getattr(self._wrapped, name) - - def maybe_append_choice(self, choices, **kwargs): - return type(self._wrapped).maybe_append_choice( - self, choices, **kwargs - ) - - def generate(self, **kwargs): - choice_caller = self._wrapped.generate(**kwargs) - return DataPreprocessorChoiceCallerWrapper( - choice_caller, self._processor - ) - *m, _ = x.get_size() n, k = orig_w.get_size() # TODO: decide block size per ISA @@ -427,6 +392,7 @@ def preprocessor(inputs, layout_or_out): x = inputs[0] w = inputs[2] if isinstance(w, ir.IRNode): + x = view(x, [-1, k]) blocked_w = permute( view(w, (n / n_block_size, n_block_size, k)), [0, 2, 1], @@ -439,9 +405,10 @@ def preprocessor(inputs, layout_or_out): layout_or_out = ir.FixedLayout( x.get_device(), x.get_dtype(), - [*m, n], + [sympy_product((*m,)), n], ) else: + x = x.view([-1, k]) blocked_w = ( w.reshape(n / n_block_size, n_block_size, k) .transpose(1, 2) @@ -449,9 +416,18 @@ def preprocessor(inputs, layout_or_out): ) return [x, blocked_w], layout_or_out - template = DataPreprocessorTemplateWrapper( + def postprocessor(out): + if not isinstance(m, (list, tuple)): + return out + if isinstance(out, ir.IRNode): + return view(out, [*m, n]) + else: + return out.view([*m, n]) + + template = DataProcessorTemplateWrapper( CppPackedGemmTemplate, preprocessor, + postprocessor, input_nodes=[x, packed_w, orig_w, None, batch_size], layout=layout, n_block_size=n_block_size, diff --git a/torch/_inductor/select_algorithm.py 
b/torch/_inductor/select_algorithm.py index fc230bf247b9a..2813012f118a2 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -862,6 +862,70 @@ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType } +class DataProcessorChoiceCallerWrapper: + def __init__(self, wrapped, preprocessor, postprocessor): + self._wrapped = wrapped + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x: x + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def benchmark(self, *args, out) -> float: + new_args, new_out = self._preprocessor(args, out) + result = self._wrapped.benchmark(*new_args, out=new_out) + new_out = self._postprocessor(new_out) + if out is not new_out: + out.copy_(new_out) + return result + + def output_node(self) -> ir.TensorBox: + result = self._wrapped.output_node() + return self._postprocessor(result) + + +class DataProcessorTemplateWrapper: + def __init__( + self, + wrapped_template_cls, + preprocessor, + postprocessor, + **kwargs, + ): + if preprocessor is not None: + self._preprocessor = preprocessor + else: + self._preprocessor = lambda x, y: (x, y) + if postprocessor is not None: + self._postprocessor = postprocessor + else: + self._postprocessor = lambda x, y: (x, y) + assert "input_nodes" in kwargs + assert "layout" in kwargs + kwargs["input_nodes"], kwargs["layout"] = preprocessor( + kwargs["input_nodes"], kwargs["layout"] + ) + self._wrapped = wrapped_template_cls(**kwargs) + + def __getattr__(self, name): + return getattr(self._wrapped, name) + + def maybe_append_choice(self, choices, **kwargs): + return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) + + def generate(self, **kwargs): + choice_caller = self._wrapped.generate(**kwargs) + return DataProcessorChoiceCallerWrapper( + choice_caller, self._preprocessor, self._postprocessor + ) + + class ErrorFromChoice(RuntimeError): def __init__(self, msg, choice: ChoiceCaller, inputs_str): msg += f"\nFrom choice {choice}\n{inputs_str}" From ba94cdff1a66728742ce34e7ee687713a2d29fdf Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Tue, 16 Apr 2024 22:42:15 -0700 Subject: [PATCH 04/28] Update [ghstack-poisoned] --- torch/_inductor/ir.py | 6 +- torch/_inductor/mkldnn_lowerings.py | 4 +- torch/_inductor/select_algorithm.py | 85 ++++++++++++----------------- 3 files changed, 41 insertions(+), 54 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 43cb10d2bc530..e094d23725f7f 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -6027,7 +6027,7 @@ def codegen(self, wrapper): ) @classmethod - def create(cls, x, packed_w, orig_w, bias, batch_size): + def create(cls, x, packed_w, orig_w, B, batch_size): x = cls.require_stride1(cls.realize_input(x)) orig_w = cls.require_stride1(cls.realize_input(orig_w)) *m, _ = x.get_size() @@ -6036,8 +6036,8 @@ def create(cls, x, packed_w, orig_w, bias, batch_size): output_stride = make_contiguous_strides_for(output_size) inputs = [x, packed_w, orig_w] constant_args = [batch_size] - if bias is not None: - inputs += [bias] + if B is not None: + inputs += [B] else: constant_args.insert(0, None) diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 1d26596f8f0ac..dd595658581c3 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ 
b/torch/_inductor/mkldnn_lowerings.py @@ -376,7 +376,7 @@ def mkl_packed_linear( ): choices = [ aten_mkl_linear.bind( - (x, packed_w, orig_w, None, batch_size), layout + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size ) ] if config.max_autotune or config.max_autotune_gemm: @@ -446,7 +446,7 @@ def postprocessor(out): chosen_node: TensorBox = autotune_select_algorithm( "packed_linear", choices, - [x, packed_w, orig_w, None, batch_size], + [x, packed_w, orig_w], layout, input_gen_fns=input_gen_fns, ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 2813012f118a2..f4c84a25e98c9 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1106,37 +1106,31 @@ def make_benchmark_fn( unique_example_inputs = { x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x) for i, x in enumerate(input_nodes) - if isinstance(x, ir.IRNode) } - example_inputs = list(unique_example_inputs.values()) + [ - x for x in input_nodes if not isinstance(x, ir.IRNode) - ] + example_inputs = list(unique_example_inputs.values()) example_inputs_extern = [] for input_node in input_nodes: - if isinstance(input_node, ir.IRNode): - example_input = unique_example_inputs[input_node.get_name()] - if example_input.is_mkldnn: - example_inputs_extern.append(example_input) - else: - example_inputs_extern.append( - torch.as_strided( - example_input, - V.graph.sizevars.size_hints( - input_node.get_size(), - fallback=config.unbacked_symint_fallback, - ), - V.graph.sizevars.size_hints( - input_node.get_stride(), - fallback=config.unbacked_symint_fallback, - ), - V.graph.sizevars.size_hint( - input_node.get_layout().offset, - fallback=config.unbacked_symint_fallback, - ), - ) - ) + example_input = unique_example_inputs[input_node.get_name()] + if example_input.is_mkldnn: + example_inputs_extern.append(example_input) else: - example_inputs_extern.append(input_node) + example_inputs_extern.append( + torch.as_strided( + example_input, + V.graph.sizevars.size_hints( + input_node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hints( + input_node.get_stride(), + fallback=config.unbacked_symint_fallback, + ), + V.graph.sizevars.size_hint( + input_node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), + ) + ) out = cls.benchmark_example_value(layout) out_extern = torch.as_strided( @@ -1160,9 +1154,7 @@ def tensor_repr(x): "inputs = [", ] for x in example_inputs: - lines.append( - f" {tensor_repr(x) if isinstance(x, torch.Tensor) else x}," - ) + lines.append(f" {tensor_repr(x)},") lines += ["]", f"out = {tensor_repr(out)}", ""] return "\n".join(lines) @@ -1251,7 +1243,6 @@ def log_results( ) ) for n in input_nodes - if isinstance(n, ir.IRNode) ] ) n = None if log.getEffectiveLevel() == logging.DEBUG else 10 @@ -1314,24 +1305,20 @@ def key_of(node): """ sizevars = V.graph.sizevars return ( - ( - node.get_device().type, - str(node.get_dtype()), - *sizevars.size_hints( - node.get_size(), - fallback=config.unbacked_symint_fallback, - ), - *sizevars.size_hints( - node.get_stride(), - fallback=config.unbacked_symint_fallback, - ), - sizevars.size_hint( - node.get_layout().offset, - fallback=config.unbacked_symint_fallback, - ), - ) - if isinstance(node, ir.IRNode) - else str(node) + node.get_device().type, + str(node.get_dtype()), + *sizevars.size_hints( + node.get_size(), + fallback=config.unbacked_symint_fallback, + ), + *sizevars.size_hints( + node.get_stride(), + 
fallback=config.unbacked_symint_fallback, + ), + sizevars.size_hint( + node.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ), ) From 5ad789902a6d72092d2334119e8bf1cd3d768603 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Wed, 17 Apr 2024 06:02:16 -0700 Subject: [PATCH 05/28] Update [ghstack-poisoned] --- test/inductor/test_cpu_select_algorithm.py | 71 ++++++++++++ test/inductor/test_mkldnn_pattern_matcher.py | 2 +- torch/_inductor/codegen/cpp_gemm_template.py | 102 +++++++++++++++++- .../_inductor/codegen/cpp_template_kernel.py | 5 +- torch/_inductor/kernel/mm.py | 19 ++++ torch/_inductor/mkldnn_lowerings.py | 75 +++---------- torch/_inductor/select_algorithm.py | 2 +- torch/_inductor/utils.py | 20 ++++ 8 files changed, 229 insertions(+), 67 deletions(-) create mode 100644 test/inductor/test_cpu_select_algorithm.py diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py new file mode 100644 index 0000000000000..706484fee8c41 --- /dev/null +++ b/test/inductor/test_cpu_select_algorithm.py @@ -0,0 +1,71 @@ +# Owner(s): ["oncall: cpu inductor"] +import functools +import unittest +from unittest.mock import patch + +import torch +import torch._dynamo.config as dynamo_config +import torch._inductor.config as inductor_config +import torch._inductor.select_algorithm as select_algorithm +from torch._dynamo.utils import counters +from torch._inductor.test_case import run_tests, TestCase + +from torch.testing._internal.common_utils import IS_MACOS, TEST_MKL + +aten = torch.ops.aten + + +def patches(fn): + def skip_cache(self, choices, name, key, generate): + return generate(choices) + + for patcher in [ + dynamo_config.patch(verbose=True), + inductor_config.patch(debug=True, max_autotune=True, epilogue_fusion=True), + patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)), + patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache), + ]: + fn = patcher(fn) + + @functools.wraps(fn) + def wrapped(*args, **kwargs): + counters.clear() + torch.manual_seed(12345) + return fn(*args, **kwargs) + + return wrapped + + +class TestSelectAlgorithm(TestCase): + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + def test_linear_fp32_cpu(self): + class M(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.linear = torch.nn.Linear(10, 32, bias) + + @torch.compile + def forward(self, x): + return self.linear(x) + + for bias in [True, False]: + counters.clear() + mod = M(bias=bias).eval() + v = torch.randn(2, 10) + mod(v) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + + +@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) +class TestDynamicSelectAlgorithm(TestCase): + test_linear_fp32_dynamic_shapes_cpu = TestSelectAlgorithm.test_linear_fp32_cpu + + +if __name__ == "__main__": + from torch.testing._internal.inductor_utils import HAS_CPU + + if HAS_CPU and not IS_MACOS: + run_tests() diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 51070fffa5534..e44f103571dab 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -318,7 +318,7 @@ def test_linear_fp32(self): class M(torch.nn.Module): def __init__(self, bias): super().__init__() - self.linear = torch.nn.Linear(10, 32, bias) + self.linear = torch.nn.Linear(10, 30, bias) def forward(self, x): return 
self.linear(x) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 04c4dbcf0e7a6..43fbbe2336ac3 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -1,6 +1,11 @@ from typing import cast, List, Optional +import torch +from torch._inductor.select_algorithm import DataProcessorTemplateWrapper +from .. import ir + from ..ir import Buffer, CppTemplateBuffer, IRNode, Layout +from ..lowering import permute, view from .cpp_template import CppTemplate from .cpp_template_kernel import CppTemplateKernel @@ -9,7 +14,7 @@ {{template.header().getvalue()}} // TODO: use micro-kernel to replace this naive GEMM implementation below extern "C" -{{kernel.def_kernel(inputs=[X, W], outputs=[Y], names_str="X, W, Y")}} +{{kernel.def_kernel(inputs=[X, W, inp], outputs=[Y], names_str="X, W, inp, Y")}} { // TODO: support dynamic shapes int64_t M = {{kernel.size(Y, 0)}}; @@ -21,7 +26,12 @@ for (int64_t j = 0; j < N/{{n_bs}}; ++j) { {{kernel.acc_dtype(Y)}} sum[16]; for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { + {% if inp is none %} sum[ni] = 0; + {% else %} + int64_t n = j * {{n_bs}} + ni; + sum[ni] = {{beta}} * {{kernel.index(inp, ["i", "n"])}}; + {% endif %} } for (int64_t k = 0; k < K; ++k) { for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { @@ -30,7 +40,7 @@ } for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { int64_t n = j * {{n_bs}} + ni; - {{kernel.index(Y, ["i", "n"])}} = sum[ni]; + {{kernel.index(Y, ["i", "n"])}} = {{alpha}} * sum[ni]; } } } @@ -43,11 +53,95 @@ def __init__( self, input_nodes, layout: Layout, + beta=1, + alpha=1, n_block_size: int = 1, ): super().__init__("cpp_gemm", input_nodes, layout) + self.beta = beta + self.alpha = alpha self.n_block_size = n_block_size + @staticmethod + def add_choices( + choices, layout, input_nodes, beta=1, alpha=1, trans_w=False, input_indices=None + ): + if input_indices is None: + input_indices = list(range(len(input_nodes))) + + def reorder_and_filter(inputs, layout_or_out): + if len(input_indices) == 2: + x_idx = input_indices[0] + w_idx = input_indices[1] + return [inputs[x_idx], inputs[w_idx]], layout_or_out + else: + assert ( + len(input_indices) == 3 + ), "Cpp Packed GEMM template requires 2 or 3 input nodes." + # assume the input order is [inp, x, w] and we reorder it to [x, w, inp] + inp_idx = input_indices[0] + x_idx = input_indices[1] + w_idx = input_indices[2] + return [inputs[x_idx], inputs[w_idx], inputs[inp_idx]], layout_or_out + + def transpose_weight(inputs, layout_or_out): + if not trans_w: + return inputs, layout_or_out + + new_inputs = list(inputs) + W = inputs[1] + if isinstance(W, ir.IRNode): + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + new_inputs[1] = permute(W, [1, 0]) + return new_inputs, layout_or_out + else: + assert isinstance(W, torch.Tensor) + new_inputs[1] = W.transpose(0, 1) + return new_inputs, layout_or_out + + n_block_size = 16 + + def pack_weight(inputs, layout_or_out): + W = inputs[1] + new_inputs = list(inputs) + if isinstance(W, ir.IRNode): + if not isinstance(W, ir.TensorBox): + W = ir.TensorBox(W) + k, n = W.get_size() + assert ( + n % n_block_size == 0 + ), f"The last dimension of W must be a multiple of {n_block_size}." 
+ blocked_w = permute( + view(W, (k, n // n_block_size, n_block_size)), + [1, 0, 2], + ) + blocked_w = ir.ExternKernel.require_contiguous(blocked_w) + blocked_w = ir.ExternKernel.realize_input(blocked_w) + else: + k, n = list(W.shape) + blocked_w = ( + W.reshape(k, n // n_block_size, n_block_size) + .transpose(0, 1) + .contiguous() + ) + new_inputs[1] = blocked_w + return new_inputs, layout_or_out + + def preprocessor(inputs, layout): + return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout))) + + template = DataProcessorTemplateWrapper( + CppPackedGemmTemplate, + preprocessor, + None, + input_nodes=input_nodes, + layout=layout, + n_block_size=n_block_size, + ) + template.maybe_append_choice(choices) + return template + def render( # type: ignore[override] self, kernel: CppTemplateKernel, @@ -65,12 +159,16 @@ def render( # type: ignore[override] assert self.output_node is not None X, W = self.input_nodes[0], self.input_nodes[1] + inp = self.input_nodes[2] if len(self.input_nodes) > 2 else None Y = self.output_node options = dict( X=X, W=W, + inp=inp, Y=Y, + beta=self.beta, + alpha=self.alpha, n_bs=self.n_block_size, template=self, kernel=kernel, diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 0cd973187361c..24323a59e359d 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -33,7 +33,7 @@ def def_kernel( names_str: str = "", input_reorder: Optional[List[int]] = None, ) -> str: - input_names = [inp.get_name() for inp in inputs] + input_names = [inp.get_name() if inp is not None else None for inp in inputs] output_names = [out.get_name() for out in outputs] all_names = input_names + output_names assert len(all_names) == len(names_str.split(",")), ( @@ -42,7 +42,8 @@ def def_kernel( ) names = names_str.split(",") for i, input_name in enumerate(input_names): - self.args.input_buffers[input_name] = names[i].strip() + if input_name is not None: + self.args.input_buffers[input_name] = names[i].strip() for i, output_name in enumerate(output_names): self.args.output_buffers[output_name] = names[i + len(input_names)].strip() cpp_argdefs, _, _ = self.args.cpp_argdefs() diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 2cb78e0c45c92..c4e5fc6573e9c 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional import torch +from torch._inductor.codegen.cpp_gemm_template import CppPackedGemmTemplate from torch._inductor.virtualized import V from .. 
import config as inductor_config from ..codegen.cuda.gemm_template import CUTLASSGemmTemplate @@ -14,6 +15,7 @@ ) from ..utils import ( use_aten_gemm_kernels, + use_cpp_packed_gemm_template, use_cutlass_template, use_max_autotune, use_triton_template, @@ -256,6 +258,23 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): fuseable=False, ) + if use_cpp_packed_gemm_template(layout, mat1, mat2): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [inp_expanded, mat1, mat2], + alpha=alpha, + beta=beta, + ) + choices.append( + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ) + return autotune_select_algorithm( "addmm", choices, [inp_expanded, mat1, mat2], layout ) diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index dd595658581c3..0a6ca89287849 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -2,6 +2,7 @@ import torch import torch.utils._pytree as pytree +from torch._inductor.kernel.mm_common import mm_args from . import config, ir from .codegen.cpp_gemm_template import CppPackedGemmTemplate from .ir import TensorBox @@ -12,15 +13,9 @@ permute, register_lowering, to_dtype, - view, ) -from .select_algorithm import ( - autotune_select_algorithm, - DataProcessorTemplateWrapper, - ExternKernelChoice, - realize_inputs, -) -from .utils import sympy_product +from .select_algorithm import autotune_select_algorithm, ExternKernelChoice +from .utils import use_cpp_packed_gemm_template from .virtualized import V @@ -380,60 +375,18 @@ def mkl_packed_linear( ) ] if config.max_autotune or config.max_autotune_gemm: - *m, _ = x.get_size() - n, k = orig_w.get_size() - # TODO: decide block size per ISA - # TODO: use larger block size for larger batch sizes - # TODO: support n % n_block_size != 0 - n_block_size = 16 - if n % n_block_size == 0: - - def preprocessor(inputs, layout_or_out): - x = inputs[0] - w = inputs[2] - if isinstance(w, ir.IRNode): - x = view(x, [-1, k]) - blocked_w = permute( - view(w, (n / n_block_size, n_block_size, k)), - [0, 2, 1], - ) - blocked_w = ir.ExternKernel.require_contiguous( - blocked_w - ) - x, blocked_w = realize_inputs(x, blocked_w) - if layout_or_out is None: - layout_or_out = ir.FixedLayout( - x.get_device(), - x.get_dtype(), - [sympy_product((*m,)), n], - ) - else: - x = x.view([-1, k]) - blocked_w = ( - w.reshape(n / n_block_size, n_block_size, k) - .transpose(1, 2) - .contiguous() - ) - return [x, blocked_w], layout_or_out - - def postprocessor(out): - if not isinstance(m, (list, tuple)): - return out - if isinstance(out, ir.IRNode): - return view(out, [*m, n]) - else: - return out.view([*m, n]) - - template = DataProcessorTemplateWrapper( - CppPackedGemmTemplate, - preprocessor, - postprocessor, - input_nodes=[x, packed_w, orig_w, None, batch_size], - layout=layout, - n_block_size=n_block_size, + transposed_w = permute(orig_w, [1, 0]) + *_, layout, x, transposed_w = mm_args( + x, transposed_w, layout=layout + ) + if use_cpp_packed_gemm_template(layout, x, transposed_w): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [x, packed_w, orig_w], + trans_w=True, + input_indices=[0, 2], ) - layout = template.layout - template.maybe_append_choice(choices) assert isinstance(packed_w.data, ir.StorageBox) assert isinstance(packed_w.data.data, ir.ConstantBuffer) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index f4c84a25e98c9..217053e4a344b 100644 --- 
a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -905,7 +905,7 @@ def __init__( if postprocessor is not None: self._postprocessor = postprocessor else: - self._postprocessor = lambda x, y: (x, y) + self._postprocessor = lambda x: x assert "input_nodes" in kwargs assert "layout" in kwargs kwargs["input_nodes"], kwargs["layout"] = preprocessor( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 5f85704c99380..feafbfccfb399 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1023,6 +1023,26 @@ def use_cutlass_template(layout, m, n, k): return res +def use_cpp_packed_gemm_template(layout, mat1, mat2): + from . import ir + + layout_dtypes = [torch.float32] + _, n = mat2.get_size() + if isinstance(mat2, ir.BaseView): + mat2 = mat2.unwrap_view() + # TODO: decide block size per ISA + # TODO: use larger block size for larger batch sizes + # TODO: support n % n_block_size != 0 + n_block_size = 16 + return ( + layout.device.type == "cpu" + and layout.dtype in layout_dtypes + and n % n_block_size == 0 + and isinstance(mat2, ir.StorageBox) + and mat2.is_module_buffer() + ) + + def use_aten_gemm_kernels(): return not use_max_autotune() or _use_autotune_backend("ATEN") From 1c4edcd69f46c2aeb821d4a4920fdf4cc7921aba Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Wed, 17 Apr 2024 06:50:51 -0700 Subject: [PATCH 06/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 8 +++--- .../_inductor/codegen/cpp_template_kernel.py | 27 +++++++++++++------ torch/_inductor/kernel/mm.py | 15 +++++------ 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 43fbbe2336ac3..c9bc00614774b 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -22,7 +22,7 @@ int64_t K = {{kernel.size(X, 1)}}; #pragma omp parallel for collapse(2) - for (int64_t i = 0; i < M; ++i) { + for (int64_t m = 0; m < M; ++m) { for (int64_t j = 0; j < N/{{n_bs}}; ++j) { {{kernel.acc_dtype(Y)}} sum[16]; for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { @@ -30,17 +30,17 @@ sum[ni] = 0; {% else %} int64_t n = j * {{n_bs}} + ni; - sum[ni] = {{beta}} * {{kernel.index(inp, ["i", "n"])}}; + sum[ni] = {{beta}} * {{kernel.index(inp, ["m", "n"])}}; {% endif %} } for (int64_t k = 0; k < K; ++k) { for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { - sum[ni] += {{kernel.index(X, ["i", "k"])}} * {{kernel.index(W, ["j", "k", "ni"])}}; + sum[ni] += {{kernel.index(X, ["m", "k"])}} * {{kernel.index(W, ["j", "k", "ni"])}}; } } for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { int64_t n = j * {{n_bs}} + ni; - {{kernel.index(Y, ["i", "n"])}} = {{alpha}} * sum[ni]; + {{kernel.index(Y, ["m", "n"])}} = {{alpha}} * sum[ni]; } } } diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 24323a59e359d..af9a341b49c81 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -1,3 +1,4 @@ +import itertools from typing import Callable, Dict, List, Optional, Union import sympy @@ -31,7 +32,6 @@ def def_kernel( inputs: List[Buffer], outputs: List[Buffer], names_str: str = "", - input_reorder: Optional[List[int]] = None, ) -> str: input_names = [inp.get_name() if inp is not None else None for inp in inputs] output_names = [out.get_name() for out in outputs] @@ -46,6 +46,22 @@ def def_kernel( self.args.input_buffers[input_name] = 
names[i].strip() for i, output_name in enumerate(output_names): self.args.output_buffers[output_name] = names[i + len(input_names)].strip() + inputs_not_none = [inp for inp in inputs if inp is not None] + unique_sizevars = { + s + for input in inputs_not_none + for sym in itertools.chain(input.get_size(), input.get_stride()) + for s in sym.free_symbols + } + unique_sizevars |= { + s + for output in outputs + for sym in itertools.chain(output.get_size(), output.get_stride()) + for s in sym.free_symbols + } + sizevars = sorted(unique_sizevars) + for sizevar in sizevars: + self.args.sizevars[sizevar] = f"k{sizevar}" cpp_argdefs, _, _ = self.args.cpp_argdefs() return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" @@ -75,11 +91,12 @@ def acc_dtype(self, node: Buffer) -> str: raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") def size(self, node: Buffer, dim: int) -> str: - return str(node.get_size()[dim]) + return str(self.rename_indexing(node.get_size()[dim])) def index(self, node: Buffer, indices: List[str]) -> str: indexer = node.make_indexer() index = indexer([sympy.Symbol(idx) for idx in indices]) + index = self.rename_indexing(index) return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" @@ -123,12 +140,6 @@ def benchmark(self, *args, out) -> float: assert self.bmreq is not None return self.bmreq.benchmark(*args, output_tensor=out) - # def __str__(self): - # return f"CppTemplateCaller(source_file={self.bmreq.source_file})" - - # def call_name(self) -> str: - # return f"cpp_template_kernels.{self.name}" - def hash_key(self) -> str: return "-".join( [ diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index c4e5fc6573e9c..2a092b80da50d 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -143,6 +143,13 @@ def tuned_mm(mat1, mat2, *, layout=None): choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True ) + if use_cpp_packed_gemm_template(layout, mat1, mat2): + CppPackedGemmTemplate.add_choices( + choices, + layout, + [mat1, mat2], + ) + from torch._inductor.ir import FixedLayout, FlexibleLayout if ( @@ -266,14 +273,6 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): alpha=alpha, beta=beta, ) - choices.append( - aten_addmm.bind( - (inp_expanded, mat1, mat2), - layout, - alpha=alpha, - beta=beta, - ) - ) return autotune_select_algorithm( "addmm", choices, [inp_expanded, mat1, mat2], layout From a56957da65acb49ea8152797257bb4f8a7b80d3a Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Wed, 17 Apr 2024 06:51:52 -0700 Subject: [PATCH 07/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index c9bc00614774b..17621abd69174 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -16,7 +16,6 @@ extern "C" {{kernel.def_kernel(inputs=[X, W, inp], outputs=[Y], names_str="X, W, inp, Y")}} { - // TODO: support dynamic shapes int64_t M = {{kernel.size(Y, 0)}}; int64_t N = {{kernel.size(Y, 1)}}; int64_t K = {{kernel.size(X, 1)}}; From 5bf33c43b12517fd5a25eaaf871519bc560fd479 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Wed, 17 Apr 2024 15:02:57 -0700 Subject: [PATCH 08/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py 
b/torch/_inductor/codegen/cpp_gemm_template.py index 17621abd69174..fe847b3657b2a 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -16,6 +16,7 @@ extern "C" {{kernel.def_kernel(inputs=[X, W, inp], outputs=[Y], names_str="X, W, inp, Y")}} { + // TODO: support >2D tensors int64_t M = {{kernel.size(Y, 0)}}; int64_t N = {{kernel.size(Y, 1)}}; int64_t K = {{kernel.size(X, 1)}}; From f780f9ca9d85ee2778641b66889bd8a2989ad5c5 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Wed, 17 Apr 2024 19:28:44 -0700 Subject: [PATCH 09/28] Update [ghstack-poisoned] --- torch/_inductor/mkldnn_lowerings.py | 3 +-- torch/_inductor/utils.py | 6 +++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 0a6ca89287849..586a3757d72a3 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -396,14 +396,13 @@ def mkl_packed_linear( 1: lambda x: V.graph.constants[x.get_name()], 2: lambda x: V.graph.constants[x.get_name()], } - chosen_node: TensorBox = autotune_select_algorithm( + result: TensorBox = autotune_select_algorithm( "packed_linear", choices, [x, packed_w, orig_w], layout, input_gen_fns=input_gen_fns, ) - result = TensorBox.create(chosen_node) if b is not None: result = add(result, b) return result diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index feafbfccfb399..effe4f5a8e618 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1023,6 +1023,10 @@ def use_cutlass_template(layout, m, n, k): return res +def _use_template_for_cpu(layout): + return use_max_autotune() and layout.device.type == "cpu" + + def use_cpp_packed_gemm_template(layout, mat1, mat2): from . import ir @@ -1035,7 +1039,7 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): # TODO: support n % n_block_size != 0 n_block_size = 16 return ( - layout.device.type == "cpu" + _use_template_for_cpu(layout) and layout.dtype in layout_dtypes and n % n_block_size == 0 and isinstance(mat2, ir.StorageBox) From 0580a46dc99915665adfc8b0a0806dd98c752510 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Wed, 17 Apr 2024 22:29:15 -0700 Subject: [PATCH 10/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 41 ++++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index fe847b3657b2a..9e0d671a4f4c4 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -6,6 +6,7 @@ from ..ir import Buffer, CppTemplateBuffer, IRNode, Layout from ..lowering import permute, view +from ..virtualized import V from .cpp_template import CppTemplate from .cpp_template_kernel import CppTemplateKernel @@ -131,10 +132,34 @@ def pack_weight(inputs, layout_or_out): def preprocessor(inputs, layout): return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout))) + def postprocessor(output): + if isinstance(output, ir.IRNode): + # prepack the weight as input to the template buffer + # TODO: prune the unused constants in V.graph + # TODO: should we implement it with constant folding in the scheduler instead? 
+ assert isinstance(output, ir.TensorBox) + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + W_node = ir.InputsKernel.unwrap_storage_for_input(new_input_nodes[1]) + assert isinstance(W_node, ir.ConstantBuffer) + assert W_node.get_name() in V.graph.constants + W = V.graph.constants[W_node.get_name()] + new_input_nodes[1] = W + new_input_nodes, _ = pack_weight( + *transpose_weight(new_input_nodes, layout) + ) + W_packed = new_input_nodes[1] + W_packed_constant = V.graph.add_tensor_constant(W_packed) + template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( + W_packed_constant + ) + return output + template = DataProcessorTemplateWrapper( CppPackedGemmTemplate, preprocessor, - None, + postprocessor, input_nodes=input_nodes, layout=layout, n_block_size=n_block_size, @@ -152,16 +177,18 @@ def render( # type: ignore[override] assert not epilogue_nodes, "Epilogue nodes are not supported for GEMM template." assert len(self.input_nodes) >= 2 - if template_buffer_node is not None: - self.output_node = template_buffer_node - if epilogue_nodes is not None and len(epilogue_nodes) > 0: - self.output_node = cast(Buffer, epilogue_nodes[-1]) - assert self.output_node is not None - X, W = self.input_nodes[0], self.input_nodes[1] inp = self.input_nodes[2] if len(self.input_nodes) > 2 else None Y = self.output_node + if template_buffer_node is not None: + # Use the updated prepacked weight buffer + W = template_buffer_node.inputs[1] + Y = template_buffer_node + if epilogue_nodes is not None and len(epilogue_nodes) > 0: + Y = cast(Buffer, epilogue_nodes[-1]) + assert self.output_node is not None + options = dict( X=X, W=W, From d795f31479a0cb5912d567dd80dfb47ded705fcf Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Fri, 26 Apr 2024 07:29:14 -0700 Subject: [PATCH 11/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp.py | 64 +--- torch/_inductor/codegen/cpp_gemm_template.py | 210 ++++++++++--- torch/_inductor/codegen/cpp_micro_gemm.py | 297 ++++++++++++++++++ torch/_inductor/codegen/cpp_prefix.h | 98 ++++++ torch/_inductor/codegen/cpp_template.py | 13 +- .../_inductor/codegen/cpp_template_kernel.py | 41 ++- torch/_inductor/codegen/cpp_utils.py | 68 ++++ torch/_inductor/mkldnn_lowerings.py | 6 +- torch/_inductor/utils.py | 12 +- 9 files changed, 687 insertions(+), 122 deletions(-) create mode 100644 torch/_inductor/codegen/cpp_micro_gemm.py create mode 100644 torch/_inductor/codegen/cpp_utils.py diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 8f485bc3a5bf0..7f4ca58f31374 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -61,56 +61,9 @@ OptimizationContext, ) -schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") - -DTYPE_TO_CPP = { - torch.float32: "float", - torch.float64: "double", - torch.float16: "half", - torch.int64: "int64_t", - torch.int32: "int", - torch.int16: "short", - torch.int8: "signed char", - torch.uint64: "uint64_t", - torch.uint32: "unsigned int", - torch.uint16: "unsigned short", - torch.uint8: "unsigned char", - torch.bool: "bool", - torch.bfloat16: "bfloat16", - torch.complex64: "complex64", - torch.float8_e4m3fn: "float8_e4m3fn", - torch.float8_e5m2: "float8_e5m2", -} - -DTYPE_TO_ATEN = { - torch.float32: "at::kFloat", - torch.float64: "at::kDouble", - torch.float16: "at::kHalf", - torch.int64: "at::kLong", - torch.int32: "at::kInt", - 
torch.int16: "at::kShort", - torch.int8: "at::kChar", - torch.uint64: "at::kUInt64", - torch.uint32: "at::kUInt32", - torch.uint16: "at::kUInt16", - torch.uint8: "at::kByte", - torch.uint32: "at::kUInt32", - torch.uint64: "at::kUInt64", - torch.bool: "at::kBool", - torch.bfloat16: "at::kBFloat16", - torch.complex32: "at::kComplexHalf", - torch.complex64: "at::kComplexFloat", - torch.complex128: "at::kComplexDouble", - torch.float8_e4m3fn: "at::kFloat8_e4m3fn", - torch.float8_e5m2: "at::kFloat8_e5m2", - torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", - torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", -} +from .cpp_utils import DTYPE_TO_CPP, value_to_cpp -DEVICE_TO_ATEN = { - "cpu": "at::kCPU", - "cuda": "at::kCUDA", -} +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") INDEX_TYPE = "long" @@ -164,19 +117,6 @@ BIN_CMP_OPS = ["eq", "ne", "le", "ge", "lt", "gt"] -def value_to_cpp(value, cpp_type): - if value == float("-inf"): - return f"-std::numeric_limits<{cpp_type}>::infinity()" - elif value == float("inf"): - return f"std::numeric_limits<{cpp_type}>::infinity()" - elif isinstance(value, bool): - return f"static_cast<{cpp_type}>({str(value).lower()})" - elif math.isnan(value): - return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" - else: - return f"static_cast<{cpp_type}>({repr(value)})" - - def reduction_init(reduction_type, dtype): if dtype in DTYPE_LOWP_FP: # Since load promotes all half-precision inputs to float, the initial diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 9e0d671a4f4c4..5bbd60116fc9a 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -1,47 +1,101 @@ from typing import cast, List, Optional import torch -from torch._inductor.select_algorithm import DataProcessorTemplateWrapper from .. 
import ir from ..ir import Buffer, CppTemplateBuffer, IRNode, Layout +from ..kernel.mm_common import mm_args from ..lowering import permute, view +from ..select_algorithm import DataProcessorTemplateWrapper +from ..utils import cache_on_self, parallel_num_threads from ..virtualized import V +from .cpp_micro_gemm import create_micro_gemm from .cpp_template import CppTemplate from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import GemmBlocking GEMM_TEMPLATE = r""" {{template.header().getvalue()}} -// TODO: use micro-kernel to replace this naive GEMM implementation below + +{{micro_gemm.codegen_define()}} + extern "C" {{kernel.def_kernel(inputs=[X, W, inp], outputs=[Y], names_str="X, W, inp, Y")}} { - // TODO: support >2D tensors - int64_t M = {{kernel.size(Y, 0)}}; - int64_t N = {{kernel.size(Y, 1)}}; - int64_t K = {{kernel.size(X, 1)}}; - - #pragma omp parallel for collapse(2) - for (int64_t m = 0; m < M; ++m) { - for (int64_t j = 0; j < N/{{n_bs}}; ++j) { - {{kernel.acc_dtype(Y)}} sum[16]; - for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { - {% if inp is none %} - sum[ni] = 0; - {% else %} - int64_t n = j * {{n_bs}} + ni; - sum[ni] = {{beta}} * {{kernel.index(inp, ["m", "n"])}}; - {% endif %} - } - for (int64_t k = 0; k < K; ++k) { - for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { - sum[ni] += {{kernel.index(X, ["m", "k"])}} * {{kernel.index(W, ["j", "k", "ni"])}}; + constexpr int64_t num_threads = {{num_threads}}; + constexpr int64_t N = {{kernel.size(Y, 1)}}; + constexpr int64_t K = {{kernel.size(X, 1)}}; + constexpr int64_t M0 = {{micro_gemm.register_blocking.block_m}}; + constexpr int64_t N0 = {{micro_gemm.register_blocking.block_n}}; + constexpr int64_t K0 = {{micro_gemm.register_blocking.block_k}}; + constexpr int64_t N0_blocks = (N + N0 - 1) / N0; + constexpr int64_t K0_blocks = (K + K0 - 1) / K0; + + static_assert(N % N0 == 0, "N dimension must be multiple of N0"); + + {% if is_dynamic_M %} + const int64_t M = {{kernel.size(Y, 0)}}; + const int64_t M0_blocks = (M + M0 - 1) / M0; + // TODO: implement below + const auto [Mt_blocks, Nt_blocks, Kt_blocks] = mm_get_thread_blocking(M, N, K, M0, N0, K0, num_threads); + const int64_t M2_blocks = Mt_blocks; // TODO: improve cache blocking + {% else %} + constexpr int64_t M = {{kernel.size(Y, 0)}}; + constexpr int64_t M0_blocks = (M + M0 - 1) / M0; + constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}}; + constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}}; + constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}}; + constexpr int64_t M2_blocks = {{template.cache_blocking().block_m}}; + {% endif %} + constexpr int64_t K2_blocks = {{template.cache_blocking().block_k}}; + + // TODO: support k-slicing + TORCH_CHECK(Kt_blocks == K0_blocks, "Do not support k slicing yet."); + // make sure all partitions are assigned + TORCH_CHECK( + Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= M0_blocks * N0_blocks * K0_blocks, + "Not all partitions are assigned." 
+ ); + + #pragma omp parallel num_threads({{num_threads}}) + { + int tid = omp_get_thread_num(); + int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end; + mm_get_thread_blocks( + tid, M, N, K, Mt_blocks, Nt_blocks, Kt_blocks, + m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end); + for (int64_t m2 = m_block_start; m2 < m_block_end; m2 += M2_blocks) { + int64_t m_start = m2 * M0; + int64_t m_end = std::min((m2 + M2_blocks) * M0, M); + for (int64_t n2 = n_block_start; n2 < n_block_end; ++n2) { + int64_t n_start = n2 * N0; + // TODO: use float32 temporary buffer to support bfloat16/float16 gemm + {% if inp is not none and beta != 0 %} + for (int64_t m = m_start; m < m_end; ++m) { + #pragma omp simd + for (int64_t n = n_start; n < n_start + N0; ++n) { + {{kernel.index(Y, "m", "n")}} = beta * {{kernel.index(inp, "m", "n")}}; + } + } + {% endif %} + for (int64_t k2 = k_block_start; k2 < k_block_end; k2 += K2_blocks) { + int64_t k_start = k2 * K0; + int64_t k_end = std::min((k2 + K2_blocks) * K0, K); + {% set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} + {% set tile_W_3d = kernel.slice_nd(W, [("n2", "n2 + 1"), ("k_start", "k_end"), ()]) %} + {% set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} + {% set tile_Y = kernel.slice_nd(Y, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} + {% if inp is not none and beta != 0 %} + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True) }} + {% else %} + if (k2 == k_block_start) { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=False) }} + } else { + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True) }} + } + {% endif %} } - } - for (int64_t ni = 0; ni < {{n_bs}}; ++ni) { - int64_t n = j * {{n_bs}} + ni; - {{kernel.index(Y, ["m", "n"])}} = {{alpha}} * sum[ni]; } } } @@ -54,14 +108,73 @@ def __init__( self, input_nodes, layout: Layout, + num_threads: int, + register_blocking: GemmBlocking, beta=1, alpha=1, - n_block_size: int = 1, ): super().__init__("cpp_gemm", input_nodes, layout) self.beta = beta self.alpha = alpha - self.n_block_size = n_block_size + self.num_threads = num_threads + self.register_blocking = register_blocking + m, n = layout.size + _, k = input_nodes[0].get_size() + self.m, self.n, self.k = m, n, k + self.is_dynamic_M = len(self.m.free_symbols) > 0 + + @cache_on_self + def thread_blocking(self) -> GemmBlocking: + # TODO: allow tuning various blocking options + def get_factors(number): + factors = [] + # priorize more evenly divided factors + for i in range(int(number**0.5), 0, -1): + if number % i == 0: + factors.append(number // i) + factors.append(i) + return factors + + def get_blocking(num_threads, factor, m_blocks, n_blocks, k_blocks): + thread_block_n = (n_blocks + factor - 1) // factor + cofactor = num_threads // factor + thread_block_m = (m_blocks + cofactor - 1) // cofactor + return GemmBlocking(thread_block_m, thread_block_n, k_blocks) + + assert ( + not self.is_dynamic_M + ), "Unable to determine thread blocking for dynamic M." 
+ register_blocking = self.register_blocking + m_blocks = (self.m + register_blocking.block_m - 1) // register_blocking.block_m + n_blocks = (self.n + register_blocking.block_n - 1) // register_blocking.block_n + k_blocks = (self.k + register_blocking.block_k - 1) // register_blocking.block_k + factors = get_factors(self.num_threads) + assert len(factors) > 0 + for factor in factors: + if n_blocks % factor == 0 and m_blocks % (self.num_threads // factor) == 0: + return get_blocking( + self.num_threads, factor, m_blocks, n_blocks, k_blocks + ) + for factor in factors: + if n_blocks % factor == 0: + return get_blocking( + self.num_threads, factor, m_blocks, n_blocks, k_blocks + ) + cofactor = self.num_threads // factor + if m_blocks % cofactor == 0: + return get_blocking( + self.num_threads, factor, m_blocks, n_blocks, k_blocks + ) + raise AssertionError("Should not reach here.") + + @cache_on_self + def cache_blocking(self) -> GemmBlocking: + # TODO: revise me + assert ( + not self.is_dynamic_M + ), "Unable to determine cache blocking for dynamic M." + thread_blocking = self.thread_blocking() + return GemmBlocking(thread_blocking.block_m, 1, thread_blocking.block_k) @staticmethod def add_choices( @@ -101,7 +214,13 @@ def transpose_weight(inputs, layout_or_out): new_inputs[1] = W.transpose(0, 1) return new_inputs, layout_or_out - n_block_size = 16 + num_threads = parallel_num_threads() + new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout)) + m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) + micro_gemm = create_micro_gemm( + "micro_gemm", m, n, k, layout.dtype, alpha=alpha, num_threads=num_threads + ) + _, block_n, _ = micro_gemm.register_blocking def pack_weight(inputs, layout_or_out): W = inputs[1] @@ -111,10 +230,10 @@ def pack_weight(inputs, layout_or_out): W = ir.TensorBox(W) k, n = W.get_size() assert ( - n % n_block_size == 0 - ), f"The last dimension of W must be a multiple of {n_block_size}." + n % block_n == 0 + ), f"The last dimension of W must be a multiple of {block_n}." 
blocked_w = permute( - view(W, (k, n // n_block_size, n_block_size)), + view(W, (k, n // block_n, block_n)), [1, 0, 2], ) blocked_w = ir.ExternKernel.require_contiguous(blocked_w) @@ -122,9 +241,7 @@ def pack_weight(inputs, layout_or_out): else: k, n = list(W.shape) blocked_w = ( - W.reshape(k, n // n_block_size, n_block_size) - .transpose(0, 1) - .contiguous() + W.reshape(k, n // block_n, block_n).transpose(0, 1).contiguous() ) new_inputs[1] = blocked_w return new_inputs, layout_or_out @@ -141,8 +258,7 @@ def postprocessor(output): template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) assert isinstance(template_buffer, ir.CppTemplateBuffer) new_input_nodes, _ = reorder_and_filter(input_nodes, layout) - W_node = ir.InputsKernel.unwrap_storage_for_input(new_input_nodes[1]) - assert isinstance(W_node, ir.ConstantBuffer) + W_node = new_input_nodes[1] assert W_node.get_name() in V.graph.constants W = V.graph.constants[W_node.get_name()] new_input_nodes[1] = W @@ -162,7 +278,10 @@ def postprocessor(output): postprocessor, input_nodes=input_nodes, layout=layout, - n_block_size=n_block_size, + num_threads=num_threads, + register_blocking=micro_gemm.register_blocking, + beta=beta, + alpha=alpha, ) template.maybe_append_choice(choices) return template @@ -189,6 +308,17 @@ def render( # type: ignore[override] Y = cast(Buffer, epilogue_nodes[-1]) assert self.output_node is not None + micro_gemm = create_micro_gemm( + f"{kernel.kernel_name}_micro_gemm", + self.m, + self.n, + self.k, + self.layout.dtype, + alpha=self.alpha, + num_threads=self.num_threads, + ) + assert self.register_blocking == micro_gemm.register_blocking + options = dict( X=X, W=W, @@ -196,7 +326,9 @@ def render( # type: ignore[override] Y=Y, beta=self.beta, alpha=self.alpha, - n_bs=self.n_block_size, + num_threads=self.num_threads, + micro_gemm=micro_gemm, + is_dynamic_M=self.is_dynamic_M, template=self, kernel=kernel, epilogues=epilogue_nodes, diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py new file mode 100644 index 0000000000000..ea9fb1d1f4051 --- /dev/null +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -0,0 +1,297 @@ +from collections import namedtuple + +import torch + +from .. 
import ir +from ..utils import parallel_num_threads +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateKernel +from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp + + +class CppMicroGemm: + DECLARE_KERNEL = r""" +template +inline void {{kernel_name}}( + const {{input_t}}* __restrict__ A, + const {{input_t}}* __restrict__ B, + {{output_t}}* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) +""" + + def __init__( + self, + name, + input_dtype, + output_dtype, + compute_dtype, + register_blocking, + alpha=1, + ): + self.name = name + self.input_dtype = input_dtype + self.output_dtype = output_dtype + self.compute_dtype = compute_dtype + self.register_blocking = register_blocking + self.alpha = alpha + + def get_common_options(self): + return { + "kernel_name": self.name, + "input_t": DTYPE_TO_CPP[self.input_dtype], + "output_t": DTYPE_TO_CPP[self.output_dtype], + "compute_t": DTYPE_TO_CPP[self.compute_dtype], + "alpha": self.alpha, + } + + def get_kernel_declaration(self): + options = self.get_common_options() + return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) + + def codegen_define(self) -> str: + raise NotImplementedError + + def codegen_call( + self, + kernel: CppTemplateKernel, + A: ir.Buffer, + B: ir.Buffer, + C: ir.Buffer, + accum: bool, + ) -> str: + A_ptr = f"&({kernel.index(A, [0, 0])})" + B_ptr = f"&({kernel.index(B, [0, 0])})" + C_ptr = f"&({kernel.index(C, [0, 0])})" + M = kernel.size(C, 0) + N = kernel.size(C, 1) + K = kernel.size(A, 1) + lda = kernel.stride(A, 0) + ldb = kernel.stride(B, 0) + ldc = kernel.stride(C, 0) + return f"{self.name}<{value_to_cpp(accum, 'bool')}>({A_ptr}, {B_ptr}, {C_ptr}, {M}, {N}, {K}, {lda}, {ldb}, {ldc});" + + +class CppMicroGemmRef(CppMicroGemm): + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + for (int64_t k = 0; k < K; ++k) { + C[m * ldc + n] = + ({{compute_t}})C[m * ldc + n] * accum + + ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; + } + } + } +} +""" + + def __init__(self, name, input_dtype, output_dtype, compute_dtype, alpha): + super().__init__( + name, input_dtype, output_dtype, compute_dtype, GemmBlocking(1, 1, 1), alpha + ) + + def codegen_define(self) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + **self.get_common_options(), + } + return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) + + +class CppMicroGemmFP32AVX(CppMicroGemm): + TEMPLATE_ENTRY = r""" +{{declare_kernel}} { + TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); + TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); + // TODO: loop unroll + for (int64_t m = 0; m < M; m += {{block_m}}) { + int64_t block_m = std::min(M - m, {{block_m}}); + for (int64_t n = 0; n < N; n += {{block_n}}) { + switch (block_m) { + {% for b in range(block_m, 0, -1) %} + case {{b}}: + {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( + A + m * lda, + B + n, + C + m * ldc + n, + {{block_k}}, + lda, + ldb, + ldc + ); + {% endfor %} + default: + TORCH_CHECK(false, "Unsupported block_m"); + } + } + } +} +""" + + TEMPLATE_KERNEL = r""" +template +inline void {{kernel_name}}_kernel( + const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc +) { + using Vectorized = at::vec::Vectorized; + 
constexpr auto VLEN = Vectorized::size(); + constexpr auto ROWS = BLOCK_M; + constexpr auto COLS = BLOCK_N / VLEN; + + Vectorized va; + at::vec::VectorizedN vb; + at::vec::VectorizedN vc; + + auto loadc = [&](auto i) { + if constexpr (accum) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i] = Vectorized::loadu(C + row * ldc + col * VLEN); + } else { + vc[i] = Vectorized(0.0f); + } + }; + c10::ForcedUnroll{}(loadc); + + auto compute = [&, COLS](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + {% if alpha != 1 %} + va = Vectorized(A[row * lda + k] * {{alpha}}); + {% else %} + va = Vectorized(A[row * lda + k]); + {% endif %} + } + + if constexpr (row == 0) { + vb[col] = Vectorized::loadu(B + k * ldb + col * VLEN); + } + + constexpr int idx = row * COLS + col; + vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); + }; + + // TODO: unroll k + for (int k = 0; k < K; ++k) { + c10::ForcedUnroll{}(compute, k); + } + + // store to C + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + vc[i].store(C + row * ldc + col * VLEN); + }; + c10::ForcedUnroll{}(storec); +} +""" + + def codegen_define(self) -> str: + options = { + "declare_kernel": self.get_kernel_declaration(), + "block_m": self.register_blocking.block_m, + "block_n": self.register_blocking.block_n, + "block_k": self.register_blocking.block_k, + **self.get_common_options(), + } + result = KernelTemplate._template_from_string(self.TEMPLATE_KERNEL).render( + options + ) + result += KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render( + options + ) + return result + + +CppMicroGemmConfig = namedtuple( + "CppMicroGemmConfig", + ["cls", "input_dtype", "output_dtype", "compute_dtype", "register_blocking"], +) + + +micro_gemm_configs = [ + # TODO: decide register_blocking per cpu arch, assume avx512 now + CppMicroGemmConfig( + CppMicroGemmFP32AVX, + torch.float32, + torch.float32, + torch.float32, + GemmBlocking(8, 32, 1), + ), + CppMicroGemmConfig( + CppMicroGemmFP32AVX, + torch.float32, + torch.float32, + torch.float32, + GemmBlocking(16, 16, 1), + ), +] + + +def create_micro_gemm( + name, + m, + n, + k, + input_dtype, + output_dtype=None, + compute_dtype=None, + alpha=1, + num_threads=-1, + use_ref=True, +) -> CppMicroGemm: + def create_from_config(config: CppMicroGemmConfig): + return config.cls( + name, + config.input_dtype, + config.output_dtype, + config.compute_dtype, + config.register_blocking, + alpha, + ) + + assert isinstance(n, int) or n.is_number + assert isinstance(k, int) or k.is_number + if output_dtype is None: + output_dtype = input_dtype + if compute_dtype is None: + compute_dtype = input_dtype + if num_threads < 0: + num_threads = parallel_num_threads() + matched_configs = [] + for config in micro_gemm_configs: + if ( + config.input_dtype == input_dtype + and config.output_dtype == output_dtype + and config.compute_dtype == compute_dtype + ): + score = 0 + block_m, block_n, block_k = config.register_blocking + if n % block_n == 0: + score += 1 + if k % block_k == 0: + score += 1 + if m % block_m == 0: + score += 1 + n_blocks = (n + block_n - 1) // block_n + if n_blocks >= num_threads: + score += 1 + matched_configs.append((score, config)) + if len(matched_configs) == 0 or use_ref: + return CppMicroGemmRef(name, input_dtype, output_dtype, compute_dtype, alpha) + return create_from_config(max(matched_configs, key=lambda x: x[0])[1]) diff --git a/torch/_inductor/codegen/cpp_prefix.h 
b/torch/_inductor/codegen/cpp_prefix.h index 5afb6195d48f4..cec3e5dfd1ee7 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -266,3 +266,101 @@ atomic_add(volatile T *addr, T offset) { std::atomic *atomic_addr = (std::atomic *)addr; atomic_addr->fetch_add(offset, std::memory_order_relaxed); } + +std::tuple mm_get_thread_blocking( + int64_t M, + int64_t N, + int64_t K, + int64_t M0, + int64_t N0, + int64_t K0, + int num_threads) { + auto get_factors = [](int64_t number) { + int count = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + count += 2; + } + } + auto factors = std::make_unique(count); + int index = 0; + for (int64_t i = std::sqrt(number); i > 0; --i) { + if (number % i == 0) { + factors[index++] = number / i; + factors[index++] = i; + } + } + return std::make_tuple(std::move(factors), count); + }; + + auto get_blocking = [](int64_t num_threads, + int64_t factor, + int64_t m_blocks, + int64_t n_blocks, + int64_t k_blocks) { + int64_t thread_block_n = (n_blocks + factor - 1) / factor; + int64_t cofactor = num_threads / factor; + int64_t thread_block_m = (m_blocks + cofactor - 1) / cofactor; + return std::make_tuple(thread_block_m, thread_block_n, k_blocks); + }; + + int64_t m_blocks = (M + M0 - 1) / M0; + int64_t n_blocks = (N + N0 - 1) / N0; + int64_t k_blocks = (K + K0 - 1) / K0; + + auto [factors, count] = get_factors(num_threads); + assert(count > 0); + + for (int i = 0; i < count; ++i) { + int64_t factor = factors[i]; + if (n_blocks % factor == 0 && + m_blocks % (num_threads / factor) == 0) { + return get_blocking( + num_threads, factor, m_blocks, n_blocks, k_blocks); + } + } + + for (int i = 0; i < count; ++i) { + int64_t factor = factors[i]; + if (n_blocks % factor == 0) { + return get_blocking( + num_threads, factor, m_blocks, n_blocks, k_blocks); + } + int64_t cofactor = num_threads / factor; + if (m_blocks % cofactor == 0) { + return get_blocking( + num_threads, factor, m_blocks, n_blocks, k_blocks); + } + } + + TORCH_CHECK(false, "Should not reach here."); + // Dummy return to avoid compiler warning + return std::make_tuple(0, 0, 0); +} + +inline void mm_get_thread_blocks( + int thread_id, + int64_t M_blocks, + int64_t N_blocks, + int64_t K_blocks, + int64_t Mt_blocks, + int64_t Nt_blocks, + int64_t Kt_blocks, + int64_t& m_block_start, + int64_t& m_block_end, + int64_t& n_block_start, + int64_t& n_block_end, + int64_t& k_block_start, + int64_t& k_block_end) { + int64_t num_Kt = (K_blocks + Kt_blocks - 1) / Kt_blocks; + k_block_start = (thread_id % num_Kt) * Kt_blocks; + k_block_end = std::min(k_block_start + Kt_blocks, K_blocks); + thread_id /= num_Kt; + int64_t num_Nt = (N_blocks + Nt_blocks - 1) / Nt_blocks; + n_block_start = (thread_id % num_Nt) * Nt_blocks; + n_block_end = std::min(n_block_start + Nt_blocks, N_blocks); + thread_id /= num_Nt; + int64_t num_Mt = (M_blocks + Mt_blocks - 1) / Mt_blocks; + m_block_start = (thread_id % num_Mt) * Mt_blocks; + m_block_end = std::min(m_block_start + Mt_blocks, M_blocks); +} \ No newline at end of file diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index 99ec7305f8324..04d4a3d4b12a4 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -98,16 +98,11 @@ def make_kernel_render( def header(self) -> IndentedBuffer: res = IndentedBuffer() - res.splice( - """ - #include - #include - #include - #include - #include - """ - ) res.writeline(codecache.cpp_prefix()) + 
headers = r""" +#include "c10/util/Unroll.h" +""" + res.splice(headers) return res def render(self, **kwargs) -> str: diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index af9a341b49c81..993437b3d1722 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -1,11 +1,14 @@ import itertools -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import sympy +from sympy.parsing.sympy_parser import parse_expr import torch from torch._inductor.autotune_process import CppBenchmarkRequest +from torch._inductor.utils import sympy_index_symbol +from .. import lowering as L from ..ir import ( Buffer, ChoiceCaller, @@ -13,13 +16,25 @@ IRNode, Layout, PrimitiveInfoType, + ReinterpretView, TensorBox, + View, ) from ..virtualized import V from .common import Kernel, OpOverrides from .cpp import cexpr_index +def parse_expr_with_index_symbols(expr_str: str) -> sympy.Expr: + expr = parse_expr(expr_str) + int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} + return expr.subs(int_symbols) + + +def wrap_with_tensorbox(node) -> TensorBox: + return TensorBox.create(node) if isinstance(node, Buffer) else TensorBox(node) + + class CppTemplateKernel(Kernel): overrides = OpOverrides @@ -93,12 +108,32 @@ def acc_dtype(self, node: Buffer) -> str: def size(self, node: Buffer, dim: int) -> str: return str(self.rename_indexing(node.get_size()[dim])) - def index(self, node: Buffer, indices: List[str]) -> str: + def stride(self, node: Buffer, dim: int) -> str: + return str(self.rename_indexing(node.get_stride()[dim])) + + def index(self, node: Buffer, indices: List[Any]) -> str: indexer = node.make_indexer() - index = indexer([sympy.Symbol(idx) for idx in indices]) + index = indexer([sympy.Symbol(str(idx), integer=True) for idx in indices]) index = self.rename_indexing(index) return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" + def slice_nd(self, node, ranges: List[Tuple[Any]]) -> ReinterpretView: + assert len(ranges) == len(node.get_size()) + sliced = wrap_with_tensorbox(node) + for dim, _range in enumerate(ranges): + if len(_range) == 0: + continue + assert len(_range) == 2 + start, end = (parse_expr_with_index_symbols(str(r)) for r in _range) + sliced = L.slice_(sliced, dim, start, end, clamp=False) + assert isinstance(sliced.data, ReinterpretView) + return sliced.data + + def view(self, node, sizes: List[Any]) -> View: + node = wrap_with_tensorbox(node) + sizes = [parse_expr_with_index_symbols(str(s)) for s in sizes] + return L.view(node, sizes).data + class CppTemplateCaller(ChoiceCaller): """ diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py new file mode 100644 index 0000000000000..e6250dc229ced --- /dev/null +++ b/torch/_inductor/codegen/cpp_utils.py @@ -0,0 +1,68 @@ +import math +from collections import namedtuple + +import torch + +DTYPE_TO_CPP = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "half", + torch.int64: "int64_t", + torch.int32: "int", + torch.int16: "short", + torch.int8: "signed char", + torch.uint64: "uint64_t", + torch.uint32: "unsigned int", + torch.uint16: "unsigned short", + torch.uint8: "unsigned char", + torch.bool: "bool", + torch.bfloat16: "bfloat16", + torch.complex64: "complex64", + torch.float8_e4m3fn: "float8_e4m3fn", + torch.float8_e5m2: "float8_e5m2", +} + +DTYPE_TO_ATEN = { + 
torch.float32: "at::kFloat", + torch.float64: "at::kDouble", + torch.float16: "at::kHalf", + torch.int64: "at::kLong", + torch.int32: "at::kInt", + torch.int16: "at::kShort", + torch.int8: "at::kChar", + torch.uint64: "at::kUInt64", + torch.uint32: "at::kUInt32", + torch.uint16: "at::kUInt16", + torch.uint8: "at::kByte", + torch.uint32: "at::kUInt32", + torch.uint64: "at::kUInt64", + torch.bool: "at::kBool", + torch.bfloat16: "at::kBFloat16", + torch.complex32: "at::kComplexHalf", + torch.complex64: "at::kComplexFloat", + torch.complex128: "at::kComplexDouble", + torch.float8_e4m3fn: "at::kFloat8_e4m3fn", + torch.float8_e5m2: "at::kFloat8_e5m2", + torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", +} + +DEVICE_TO_ATEN = { + "cpu": "at::kCPU", + "cuda": "at::kCUDA", +} + +GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) + + +def value_to_cpp(value, cpp_type): + if value == float("-inf"): + return f"-std::numeric_limits<{cpp_type}>::infinity()" + elif value == float("inf"): + return f"std::numeric_limits<{cpp_type}>::infinity()" + elif isinstance(value, bool): + return f"static_cast<{cpp_type}>({str(value).lower()})" + elif math.isnan(value): + return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" + else: + return f"static_cast<{cpp_type}>({repr(value)})" diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index 586a3757d72a3..ac032d80e257b 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -388,10 +388,8 @@ def mkl_packed_linear( input_indices=[0, 2], ) - assert isinstance(packed_w.data, ir.StorageBox) - assert isinstance(packed_w.data.data, ir.ConstantBuffer) - assert isinstance(orig_w.data, ir.StorageBox) - assert isinstance(orig_w.data.data, ir.ConstantBuffer) + assert packed_w.get_name() in V.graph.constants + assert orig_w.get_name() in V.graph.constants input_gen_fns = { 1: lambda x: V.graph.constants[x.get_name()], 2: lambda x: V.graph.constants[x.get_name()], diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index effe4f5a8e618..a2fc0fccf0346 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1029,19 +1029,21 @@ def _use_template_for_cpu(layout): def use_cpp_packed_gemm_template(layout, mat1, mat2): from . 
import ir + from .codegen.cpp_micro_gemm import create_micro_gemm + from .kernel.mm_common import mm_args layout_dtypes = [torch.float32] - _, n = mat2.get_size() + m, n, k, *_ = mm_args(mat1, mat2) if isinstance(mat2, ir.BaseView): mat2 = mat2.unwrap_view() - # TODO: decide block size per ISA - # TODO: use larger block size for larger batch sizes # TODO: support n % n_block_size != 0 - n_block_size = 16 + _, n0, _ = create_micro_gemm( + "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() + ).register_blocking return ( _use_template_for_cpu(layout) and layout.dtype in layout_dtypes - and n % n_block_size == 0 + and n % n0 == 0 and isinstance(mat2, ir.StorageBox) and mat2.is_module_buffer() ) From 002bedb6e9dade3b38927004c93148ea64e8db3d Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sat, 27 Apr 2024 06:40:26 -0700 Subject: [PATCH 12/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 6 +++--- torch/_inductor/codegen/cpp_micro_gemm.py | 13 +++++++------ torch/_inductor/codegen/cpp_prefix.h | 4 ++-- torch/_inductor/select_algorithm.py | 6 ++++++ 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 5bbd60116fc9a..aff50663a21a5 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -37,7 +37,6 @@ {% if is_dynamic_M %} const int64_t M = {{kernel.size(Y, 0)}}; const int64_t M0_blocks = (M + M0 - 1) / M0; - // TODO: implement below const auto [Mt_blocks, Nt_blocks, Kt_blocks] = mm_get_thread_blocking(M, N, K, M0, N0, K0, num_threads); const int64_t M2_blocks = Mt_blocks; // TODO: improve cache blocking {% else %} @@ -63,7 +62,7 @@ int tid = omp_get_thread_num(); int64_t m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end; mm_get_thread_blocks( - tid, M, N, K, Mt_blocks, Nt_blocks, Kt_blocks, + tid, M0_blocks, N0_blocks, K0_blocks, Mt_blocks, Nt_blocks, Kt_blocks, m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end); for (int64_t m2 = m_block_start; m2 < m_block_end; m2 += M2_blocks) { int64_t m_start = m2 * M0; @@ -113,7 +112,7 @@ def __init__( beta=1, alpha=1, ): - super().__init__("cpp_gemm", input_nodes, layout) + super().__init__("packed_gemm", input_nodes, layout) self.beta = beta self.alpha = alpha self.num_threads = num_threads @@ -214,6 +213,7 @@ def transpose_weight(inputs, layout_or_out): new_inputs[1] = W.transpose(0, 1) return new_inputs, layout_or_out + # TODO: decide proper number of threads per problem size num_threads = parallel_num_threads() new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout)) m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index ea9fb1d1f4051..c2b521a582344 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -82,11 +82,11 @@ class CppMicroGemmRef(CppMicroGemm): {{declare_kernel}} { for (int64_t m = 0; m < M; ++m) { for (int64_t n = 0; n < N; ++n) { + {{compute_t}} result = accum ? 
C[m * ldc + n] : 0; for (int64_t k = 0; k < K; ++k) { - C[m * ldc + n] = - ({{compute_t}})C[m * ldc + n] * accum - + ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; + result += ({{compute_t}})A[m * lda + k] * ({{compute_t}})B[k * ldb + n] * {{alpha}}; } + C[m * ldc + n] = result; } } } @@ -121,14 +121,15 @@ class CppMicroGemmFP32AVX(CppMicroGemm): A + m * lda, B + n, C + m * ldc + n, - {{block_k}}, + K, lda, ldb, ldc ); + break; {% endfor %} default: - TORCH_CHECK(false, "Unsupported block_m"); + TORCH_CHECK(false, "Unsupported block_m: ", block_m); } } } @@ -253,7 +254,7 @@ def create_micro_gemm( compute_dtype=None, alpha=1, num_threads=-1, - use_ref=True, + use_ref=False, ) -> CppMicroGemm: def create_from_config(config: CppMicroGemmConfig): return config.cls( diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index cec3e5dfd1ee7..d2a72865a5197 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -360,7 +361,6 @@ inline void mm_get_thread_blocks( n_block_start = (thread_id % num_Nt) * Nt_blocks; n_block_end = std::min(n_block_start + Nt_blocks, N_blocks); thread_id /= num_Nt; - int64_t num_Mt = (M_blocks + Mt_blocks - 1) / Mt_blocks; - m_block_start = (thread_id % num_Mt) * Mt_blocks; + m_block_start = std::min(thread_id * Mt_blocks, M_blocks); m_block_end = std::min(m_block_start + Mt_blocks, M_blocks); } \ No newline at end of file diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 217053e4a344b..69b1afc5928d6 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -889,6 +889,9 @@ def output_node(self) -> ir.TensorBox: result = self._wrapped.output_node() return self._postprocessor(result) + def __repr__(self) -> str: + return f"DataProcessorChoiceCallerWrapper({self._wrapped})" + class DataProcessorTemplateWrapper: def __init__( @@ -925,6 +928,9 @@ def generate(self, **kwargs): choice_caller, self._preprocessor, self._postprocessor ) + def __repr__(self) -> str: + return f"DataProcessorTemplateWrapper({self._wrapped})" + class ErrorFromChoice(RuntimeError): def __init__(self, msg, choice: ChoiceCaller, inputs_str): From 2bfc6035d710242a43946aa17d6b3eff499875ef Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sat, 27 Apr 2024 20:29:58 -0700 Subject: [PATCH 13/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 36 +++++++++++--------- torch/_inductor/codegen/cpp_micro_gemm.py | 14 ++++---- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index aff50663a21a5..b74a0599ed3d4 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -34,22 +34,24 @@ static_assert(N % N0 == 0, "N dimension must be multiple of N0"); - {% if is_dynamic_M %} + // TODO(jgong5): improve cache blocking with CPU info (M2, K2) + {%- if is_dynamic_M %} const int64_t M = {{kernel.size(Y, 0)}}; const int64_t M0_blocks = (M + M0 - 1) / M0; const auto [Mt_blocks, Nt_blocks, Kt_blocks] = mm_get_thread_blocking(M, N, K, M0, N0, K0, num_threads); - const int64_t M2_blocks = Mt_blocks; // TODO: improve cache blocking - {% else %} + const int64_t M2_blocks = Mt_blocks; + const int64_t K2_blocks = Kt_blocks; + {%- else %} constexpr int64_t M = {{kernel.size(Y, 0)}}; constexpr int64_t 
M0_blocks = (M + M0 - 1) / M0; constexpr int64_t Mt_blocks = {{template.thread_blocking().block_m}}; constexpr int64_t Nt_blocks = {{template.thread_blocking().block_n}}; constexpr int64_t Kt_blocks = {{template.thread_blocking().block_k}}; constexpr int64_t M2_blocks = {{template.cache_blocking().block_m}}; - {% endif %} constexpr int64_t K2_blocks = {{template.cache_blocking().block_k}}; + {%- endif %} - // TODO: support k-slicing + // TODO(jgong5): support k-slicing TORCH_CHECK(Kt_blocks == K0_blocks, "Do not support k slicing yet."); // make sure all partitions are assigned TORCH_CHECK( @@ -69,31 +71,31 @@ int64_t m_end = std::min((m2 + M2_blocks) * M0, M); for (int64_t n2 = n_block_start; n2 < n_block_end; ++n2) { int64_t n_start = n2 * N0; - // TODO: use float32 temporary buffer to support bfloat16/float16 gemm - {% if inp is not none and beta != 0 %} + // TODO(jgong5): use float32 temporary buffer to support bfloat16/float16 gemm + {%- if inp is not none and beta != 0 %} for (int64_t m = m_start; m < m_end; ++m) { #pragma omp simd for (int64_t n = n_start; n < n_start + N0; ++n) { - {{kernel.index(Y, "m", "n")}} = beta * {{kernel.index(inp, "m", "n")}}; + {{kernel.index(Y, ["m", "n"])}} = {{beta}} * {{kernel.index(inp, ["m", "n"])}}; } } - {% endif %} + {%- endif %} for (int64_t k2 = k_block_start; k2 < k_block_end; k2 += K2_blocks) { int64_t k_start = k2 * K0; int64_t k_end = std::min((k2 + K2_blocks) * K0, K); - {% set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} - {% set tile_W_3d = kernel.slice_nd(W, [("n2", "n2 + 1"), ("k_start", "k_end"), ()]) %} - {% set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} - {% set tile_Y = kernel.slice_nd(Y, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} - {% if inp is not none and beta != 0 %} + {%- set tile_X = kernel.slice_nd(X, [("m_start", "m_end"), ("k_start", "k_end")]) %} + {%- set tile_W_3d = kernel.slice_nd(W, [("n2", "n2 + 1"), ("k_start", "k_end"), ()]) %} + {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} + {%- set tile_Y = kernel.slice_nd(Y, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} + {%- if inp is not none and beta != 0 %} {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True) }} - {% else %} + {%- else %} if (k2 == k_block_start) { {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=False) }} } else { {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True) }} } - {% endif %} + {%- endif %} } } } @@ -168,7 +170,7 @@ def get_blocking(num_threads, factor, m_blocks, n_blocks, k_blocks): @cache_on_self def cache_blocking(self) -> GemmBlocking: - # TODO: revise me + # TODO(jgong5): improve cache blocking with CPU info assert ( not self.is_dynamic_M ), "Unable to determine cache blocking for dynamic M." 
diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index c2b521a582344..7fb411156cc2b 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -110,12 +110,12 @@ class CppMicroGemmFP32AVX(CppMicroGemm): {{declare_kernel}} { TORCH_CHECK(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}"); TORCH_CHECK(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}"); - // TODO: loop unroll + // TODO(jgong5): loop unroll for M and N for (int64_t m = 0; m < M; m += {{block_m}}) { int64_t block_m = std::min(M - m, {{block_m}}); for (int64_t n = 0; n < N; n += {{block_n}}) { switch (block_m) { - {% for b in range(block_m, 0, -1) %} + {%- for b in range(block_m, 0, -1) %} case {{b}}: {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>( A + m * lda, @@ -127,7 +127,7 @@ class CppMicroGemmFP32AVX(CppMicroGemm): ldc ); break; - {% endfor %} + {%- endfor %} default: TORCH_CHECK(false, "Unsupported block_m: ", block_m); } @@ -172,11 +172,11 @@ class CppMicroGemmFP32AVX(CppMicroGemm): constexpr int col = i % COLS; if constexpr (col == 0) { - {% if alpha != 1 %} + {%- if alpha != 1 %} va = Vectorized(A[row * lda + k] * {{alpha}}); - {% else %} + {%- else %} va = Vectorized(A[row * lda + k]); - {% endif %} + {%- endif %} } if constexpr (row == 0) { @@ -187,7 +187,7 @@ class CppMicroGemmFP32AVX(CppMicroGemm): vc[idx] = at::vec::fmadd(va, vb[col], vc[idx]); }; - // TODO: unroll k + // TODO(jgong5): unroll k for (int k = 0; k < K; ++k) { c10::ForcedUnroll{}(compute, k); } From a416d41b6ff67bb3c088793e30d7a29fdc6ea19a Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sat, 27 Apr 2024 22:17:50 -0700 Subject: [PATCH 14/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp.py | 162 +----------------- torch/_inductor/codegen/cpp_gemm_template.py | 6 +- torch/_inductor/codegen/cpp_micro_gemm.py | 17 +- .../_inductor/codegen/cpp_template_kernel.py | 4 +- torch/_inductor/codegen/cpp_utils.py | 161 +++++++++++++++++ torch/_inductor/codegen/cuda/cuda_kernel.py | 2 +- 6 files changed, 183 insertions(+), 169 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 7f4ca58f31374..4534c750f23c3 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -53,7 +53,6 @@ DataTypePropagation, DeferredLine, DTYPE_TO_COMPUTATION_DTYPE, - ExprPrinter, IndentedBuffer, Kernel, KernelArgs, @@ -61,12 +60,10 @@ OptimizationContext, ) -from .cpp_utils import DTYPE_TO_CPP, value_to_cpp +from .cpp_utils import cexpr, cexpr_index, DTYPE_TO_CPP, INDEX_TYPE, value_to_cpp schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") -INDEX_TYPE = "long" - NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"} RTYPE_TO_CPP = { "sum": "+", @@ -468,163 +465,6 @@ def _merge_outer_fusion_loop_levels( return cpp_kernel_proxy_list[0] -class CppPrinter(ExprPrinter): - def _print_Integer(self, expr): - return f"{int(expr)}L" - - def _print_Where(self, expr): - c = self.paren(self.doprint(expr.args[0])) - p = self.paren(self.doprint(expr.args[1])) - q = self.paren(self.doprint(expr.args[2])) - return f"{c} ? 
{p} : {q}" - - def _print_ModularIndexing(self, expr): - x, div, mod = expr.args - x = self.paren(self.doprint(x)) - if div != 1: - div = self.paren(self.doprint(div)) - if expr.is_integer: - x = f"c10::div_floor_integer({x}, {div})" - else: - x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" - mod = self.paren(self.doprint(mod)) - return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})" - - def _print_FloorDiv(self, expr): - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - if expr.is_integer: - return f"c10::div_floor_integer({x}, {div})" - return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" - - def _print_floor(self, expr): - assert len(expr.args) == 1 - r = f"std::floor({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_Pow(self, expr): - # Uses float constants to perform FP div - base, exp = expr.args - base = self._print(base) - - if exp == 0.5 or exp == -0.5: - return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" - assert exp.is_integer - exp = int(exp) - if exp > 0: - r = "*".join([self.paren(base)] * exp) - elif exp < 0: - r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - r = "1.0" - - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_Rational(self, expr): - # Uses float constants to perform FP div - if expr.q == 1: - r = f"{expr.p}" - else: - r = f"{expr.p}.0/{expr.q}.0" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_ceiling(self, expr): - assert len(expr.args) == 1 - r = f"std::ceil({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_Min(self, expr): - args = [self._print(a) for a in expr.args] - if len(args) == 2: - return f"std::min({args[0]}, {args[1]})" - else: - # Initializer list overload - il = "{" + ", ".join(args) + "}" - return f"std::min({il})" - - def _print_Max(self, expr): - args = [self._print(a) for a in expr.args] - if len(args) == 2: - return f"std::max({args[0]}, {args[1]})" - else: - # Initializer list overload - il = "{" + ", ".join(args) + "}" - return f"std::max({il})" - - def _print_Abs(self, expr): - assert len(expr.args) == 1 - return f"std::abs({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_cos(self, expr): - assert len(expr.args) == 1 - return f"std::cos({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_cosh(self, expr): - assert len(expr.args) == 1 - return f"std::cosh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_acos(self, expr): - assert len(expr.args) == 1 - return f"std::acos({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sin(self, expr): - assert len(expr.args) == 1 - return f"std::sin({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sinh(self, expr): - assert len(expr.args) == 1 - return f"std::sinh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_asin(self, expr): - assert len(expr.args) == 1 - return f"std::asin({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_tan(self, expr): - assert len(expr.args) == 1 - return f"std::tan({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_tanh(self, expr): - assert len(expr.args) == 1 - return f"std::tanh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_atan(self, expr): - assert len(expr.args) == 1 - return f"std::atan({self._print(expr.args[0])})" - - def 
_print_OpaqueUnaryFn_sqrt(self, expr): - return f"std::sqrt({self._print(expr.args[0])})" - - def _print_Round(self, expr): - assert len(expr.args) == 1 - return f"std::lrint({self._print(expr.args[0])})" - - def _print_RoundDecimal(self, expr): - assert len(expr.args) == 2 - number, ndigits = expr.args - if number.is_integer: - # ndigits < 0 should have been filtered by the sympy function - assert ndigits < 0 - raise ValueError( - f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." - ) - return f"static_cast(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})" - - def _print_BooleanTrue(self, expr): - return "true" - - def _print_BooleanFalse(self, expr): - return "false" - - -# A function to print, useful for printing sympy symbols. -cexpr = CppPrinter().doprint - - -def cexpr_index(index): - return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" - - class RecordOptimizationContext: def __init__(self, func_name: str = ""): self.func_name = func_name diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index b74a0599ed3d4..0a93cec64e1ae 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -88,12 +88,12 @@ {%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %} {%- set tile_Y = kernel.slice_nd(Y, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %} {%- if inp is not none and beta != 0 %} - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True) }} + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True)|indent(20, false) }} {%- else %} if (k2 == k_block_start) { - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=False) }} + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=False)|indent(24, false) }} } else { - {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True) }} + {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, tile_Y, accum=True)|indent(24, false) }} } {%- endif %} } diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 7fb411156cc2b..e027ead8a1ea8 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -3,7 +3,7 @@ import torch from .. 
import ir -from ..utils import parallel_num_threads +from ..utils import IndentedBuffer, parallel_num_threads from .common import KernelTemplate from .cpp_template_kernel import CppTemplateKernel from .cpp_utils import DTYPE_TO_CPP, GemmBlocking, value_to_cpp @@ -74,7 +74,20 @@ def codegen_call( lda = kernel.stride(A, 0) ldb = kernel.stride(B, 0) ldc = kernel.stride(C, 0) - return f"{self.name}<{value_to_cpp(accum, 'bool')}>({A_ptr}, {B_ptr}, {C_ptr}, {M}, {N}, {K}, {lda}, {ldb}, {ldc});" + res = IndentedBuffer() + res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(") + with res.indent(): + res.writeline(f"{A_ptr},") + res.writeline(f"{B_ptr},") + res.writeline(f"{C_ptr},") + res.writeline(f"{M},") + res.writeline(f"{N},") + res.writeline(f"{K},") + res.writeline(f"{lda},") + res.writeline(f"{ldb},") + res.writeline(f"{ldc}") + res.writeline(");") + return res.getvalue() class CppMicroGemmRef(CppMicroGemm): diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 993437b3d1722..7941ce3ec0dc8 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -22,7 +22,7 @@ ) from ..virtualized import V from .common import Kernel, OpOverrides -from .cpp import cexpr_index +from .cpp_utils import cexpr_index def parse_expr_with_index_symbols(expr_str: str) -> sympy.Expr: @@ -113,7 +113,7 @@ def stride(self, node: Buffer, dim: int) -> str: def index(self, node: Buffer, indices: List[Any]) -> str: indexer = node.make_indexer() - index = indexer([sympy.Symbol(str(idx), integer=True) for idx in indices]) + index = indexer([parse_expr_with_index_symbols(str(idx)) for idx in indices]) index = self.rename_indexing(index) return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index e6250dc229ced..4a606ea512dc9 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -3,6 +3,8 @@ import torch +from .common import ExprPrinter + DTYPE_TO_CPP = { torch.float32: "float", torch.float64: "double", @@ -54,6 +56,165 @@ GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) +INDEX_TYPE = "long" + + +class CppPrinter(ExprPrinter): + def _print_Integer(self, expr): + return f"{int(expr)}L" + + def _print_Where(self, expr): + c = self.paren(self.doprint(expr.args[0])) + p = self.paren(self.doprint(expr.args[1])) + q = self.paren(self.doprint(expr.args[2])) + return f"{c} ? 
{p} : {q}" + + def _print_ModularIndexing(self, expr): + x, div, mod = expr.args + x = self.paren(self.doprint(x)) + if div != 1: + div = self.paren(self.doprint(div)) + if expr.is_integer: + x = f"c10::div_floor_integer({x}, {div})" + else: + x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + mod = self.paren(self.doprint(mod)) + return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})" + + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + if expr.is_integer: + return f"c10::div_floor_integer({x}, {div})" + return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Pow(self, expr): + # Uses float constants to perform FP div + base, exp = expr.args + base = self._print(base) + + if exp == 0.5 or exp == -0.5: + return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" + assert exp.is_integer + exp = int(exp) + if exp > 0: + r = "*".join([self.paren(base)] * exp) + elif exp < 0: + r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + r = "1.0" + + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Rational(self, expr): + # Uses float constants to perform FP div + if expr.q == 1: + r = f"{expr.p}" + else: + r = f"{expr.p}.0/{expr.q}.0" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Min(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::min({args[0]}, {args[1]})" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::min({il})" + + def _print_Max(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::max({args[0]}, {args[1]})" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::max({il})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"std::abs({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"std::cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"std::cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"std::acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert len(expr.args) == 1 + return f"std::sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"std::sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"std::asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"std::tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"std::tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"std::atan({self._print(expr.args[0])})" + + def 
_print_OpaqueUnaryFn_sqrt(self, expr): + return f"std::sqrt({self._print(expr.args[0])})" + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return f"std::lrint({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + return f"static_cast(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})" + + def _print_BooleanTrue(self, expr): + return "true" + + def _print_BooleanFalse(self, expr): + return "false" + + +# A function to print, useful for printing sympy symbols. +cexpr = CppPrinter().doprint + + +def cexpr_index(index): + return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" + def value_to_cpp(value, cpp_type): if value == float("-inf"): diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index 1bb536f28967d..da963368696ab 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -16,7 +16,7 @@ from ...virtualized import V from ..common import IndentedBuffer, Kernel, OpOverrides -from ..cpp import CppPrinter, DTYPE_TO_CPP +from ..cpp_utils import CppPrinter, DTYPE_TO_CPP if TYPE_CHECKING: from torch._inductor.codegen.cuda.cuda_template import CUDATemplate From 8d3f8aad825483669a64e2d4d3a4c0bf08a825b2 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sat, 27 Apr 2024 22:28:26 -0700 Subject: [PATCH 15/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 0a93cec64e1ae..b7c70b2272fce 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -38,7 +38,13 @@ {%- if is_dynamic_M %} const int64_t M = {{kernel.size(Y, 0)}}; const int64_t M0_blocks = (M + M0 - 1) / M0; + {%- if num_threads > 1 %} const auto [Mt_blocks, Nt_blocks, Kt_blocks] = mm_get_thread_blocking(M, N, K, M0, N0, K0, num_threads); + {%- else %} + const auto Mt_blocks = M0_blocks; + const auto Nt_blocks = N0_blocks; + const auto Kt_blocks = K0_blocks; + {%- endif %} const int64_t M2_blocks = Mt_blocks; const int64_t K2_blocks = Kt_blocks; {%- else %} @@ -59,6 +65,7 @@ "Not all partitions are assigned." 
); + {%- if num_threads > 1 %} #pragma omp parallel num_threads({{num_threads}}) { int tid = omp_get_thread_num(); @@ -66,6 +73,15 @@ mm_get_thread_blocks( tid, M0_blocks, N0_blocks, K0_blocks, Mt_blocks, Nt_blocks, Kt_blocks, m_block_start, m_block_end, n_block_start, n_block_end, k_block_start, k_block_end); + {%- else %} + { + int64_t m_block_start = 0; + int64_t m_block_end = M0_blocks; + int64_t n_block_start = 0; + int64_t n_block_end = N0_blocks; + int64_t k_block_start = 0; + int64_t k_block_end = K0_blocks; + {%- endif %} for (int64_t m2 = m_block_start; m2 < m_block_end; m2 += M2_blocks) { int64_t m_start = m2 * M0; int64_t m_end = std::min((m2 + M2_blocks) * M0, M); From 701a0cd9588a9be52acf0783d4f8ced06dea155a Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sat, 27 Apr 2024 23:42:10 -0700 Subject: [PATCH 16/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_micro_gemm.py | 113 +++++++++++++--------- torch/_inductor/codegen/cpp_template.py | 9 +- 2 files changed, 72 insertions(+), 50 deletions(-) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index e027ead8a1ea8..9121da43d5500 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -1,8 +1,10 @@ from collections import namedtuple +from typing import Dict, List, Type import torch from .. import ir +from ..codecache import pick_vec_isa, VecAVX2, VecAVX512 from ..utils import IndentedBuffer, parallel_num_threads from .common import KernelTemplate from .cpp_template_kernel import CppTemplateKernel @@ -90,6 +92,32 @@ def codegen_call( return res.getvalue() +CppMicroGemmConfig = namedtuple( + "CppMicroGemmConfig", + [ + "input_dtype", + "output_dtype", + "compute_dtype", + "vec_isa_cls", + "register_blocking", + ], +) + +micro_gemm_configs: Dict[Type[CppMicroGemm], List[CppMicroGemmConfig]] = {} + + +def register_micro_gemm(*configs): + def inner(cls): + assert ( + cls not in micro_gemm_configs + ), f"Duplicate micro_gemm registration for {cls}" + assert len(configs) > 0, f"No micro_gemm configs provided for {cls}" + micro_gemm_configs[cls] = list(configs) + return cls + + return inner + + class CppMicroGemmRef(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { @@ -118,6 +146,20 @@ def codegen_define(self) -> str: return KernelTemplate._template_from_string(self.TEMPLATE_ENTRY).render(options) +@register_micro_gemm( + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(8, 32, 1) + ), + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX512, GemmBlocking(16, 16, 1) + ), + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(4, 16, 1) + ), + CppMicroGemmConfig( + torch.float32, torch.float32, torch.float32, VecAVX2, GemmBlocking(8, 8, 1) + ), +) class CppMicroGemmFP32AVX(CppMicroGemm): TEMPLATE_ENTRY = r""" {{declare_kernel}} { @@ -232,31 +274,6 @@ def codegen_define(self) -> str: return result -CppMicroGemmConfig = namedtuple( - "CppMicroGemmConfig", - ["cls", "input_dtype", "output_dtype", "compute_dtype", "register_blocking"], -) - - -micro_gemm_configs = [ - # TODO: decide register_blocking per cpu arch, assume avx512 now - CppMicroGemmConfig( - CppMicroGemmFP32AVX, - torch.float32, - torch.float32, - torch.float32, - GemmBlocking(8, 32, 1), - ), - CppMicroGemmConfig( - CppMicroGemmFP32AVX, - torch.float32, - torch.float32, - torch.float32, - GemmBlocking(16, 16, 1), - ), -] - - def create_micro_gemm( name, m, @@ -269,8 
+286,8 @@ def create_micro_gemm( num_threads=-1, use_ref=False, ) -> CppMicroGemm: - def create_from_config(config: CppMicroGemmConfig): - return config.cls( + def create_from_config(cls, config: CppMicroGemmConfig): + return cls( name, config.input_dtype, config.output_dtype, @@ -287,25 +304,29 @@ def create_from_config(config: CppMicroGemmConfig): compute_dtype = input_dtype if num_threads < 0: num_threads = parallel_num_threads() + vec_isa = pick_vec_isa() matched_configs = [] - for config in micro_gemm_configs: - if ( - config.input_dtype == input_dtype - and config.output_dtype == output_dtype - and config.compute_dtype == compute_dtype - ): - score = 0 - block_m, block_n, block_k = config.register_blocking - if n % block_n == 0: - score += 1 - if k % block_k == 0: - score += 1 - if m % block_m == 0: - score += 1 - n_blocks = (n + block_n - 1) // block_n - if n_blocks >= num_threads: - score += 1 - matched_configs.append((score, config)) + for cls, configs in micro_gemm_configs.items(): + for config in configs: + if not isinstance(vec_isa, config.vec_isa_cls): + continue + if ( + config.input_dtype == input_dtype + and config.output_dtype == output_dtype + and config.compute_dtype == compute_dtype + ): + score = 0 + block_m, block_n, block_k = config.register_blocking + if n % block_n == 0: + score += 1 + if k % block_k == 0: + score += 1 + if m % block_m == 0: + score += 1 + n_blocks = (n + block_n - 1) // block_n + if n_blocks >= num_threads: + score += 1 + matched_configs.append((score, cls, config)) if len(matched_configs) == 0 or use_ref: return CppMicroGemmRef(name, input_dtype, output_dtype, compute_dtype, alpha) - return create_from_config(max(matched_configs, key=lambda x: x[0])[1]) + return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:]) diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index 04d4a3d4b12a4..4a5272b7fdb07 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -99,10 +99,11 @@ def make_kernel_render( def header(self) -> IndentedBuffer: res = IndentedBuffer() res.writeline(codecache.cpp_prefix()) - headers = r""" -#include "c10/util/Unroll.h" -""" - res.splice(headers) + res.splice( + """ + #include "c10/util/Unroll.h" + """ + ) return res def render(self, **kwargs) -> str: From 85ce15a5f8f484d86d400863b9a51d0527d51ae6 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sun, 28 Apr 2024 00:07:46 -0700 Subject: [PATCH 17/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/common.py | 4 +- torch/_inductor/codegen/cpp.py | 2 +- torch/_inductor/codegen/cpp_gemm_template.py | 29 ++++---- torch/_inductor/codegen/cpp_prefix.h | 2 +- torch/_inductor/codegen/cpp_template.py | 13 ++-- .../_inductor/codegen/cpp_template_kernel.py | 67 +++++++++---------- torch/_inductor/codegen/cpp_wrapper_cpu.py | 18 +---- torch/_inductor/utils.py | 6 +- 8 files changed, 58 insertions(+), 83 deletions(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index e34d3311d1e98..eac34b847e954 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1006,7 +1006,7 @@ def wrap_size_arg(self, size): return str(size) def cpp_argdefs(self): - from .cpp import DTYPE_TO_CPP, INDEX_TYPE + from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE call_args = [] arg_defs = [] @@ -1155,7 +1155,7 @@ def update_on_args(self, name, args, kwargs): class CppWrapperKernelArgs(KernelArgs): def wrap_ptr_arg(self, buf, dtype): - from .cpp import 
DTYPE_TO_CPP + from .cpp_utils import DTYPE_TO_CPP if config.abi_compatible: # In the abi_compatible model, we just return the buf here. diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 4534c750f23c3..113359c21adc8 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -3592,7 +3592,7 @@ def get_fusion_pair_priority(self, node1, node2): return 0 def can_fuse_vertical(self, node1, node2): - # TODO: support vertical fusion for template nodes + # TODO(jgong5): support vertical fusion for template nodes if node1.is_template() or node2.is_template(): return False return ( diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index b7c70b2272fce..0870b6a6181eb 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -1,11 +1,9 @@ from typing import cast, List, Optional import torch -from .. import ir +from .. import ir, lowering as L -from ..ir import Buffer, CppTemplateBuffer, IRNode, Layout from ..kernel.mm_common import mm_args -from ..lowering import permute, view from ..select_algorithm import DataProcessorTemplateWrapper from ..utils import cache_on_self, parallel_num_threads from ..virtualized import V @@ -124,7 +122,7 @@ class CppPackedGemmTemplate(CppTemplate): def __init__( self, input_nodes, - layout: Layout, + layout: ir.Layout, num_threads: int, register_blocking: GemmBlocking, beta=1, @@ -142,7 +140,7 @@ def __init__( @cache_on_self def thread_blocking(self) -> GemmBlocking: - # TODO: allow tuning various blocking options + # TODO(jgong5): allow tuning various blocking options def get_factors(number): factors = [] # priorize more evenly divided factors @@ -224,14 +222,14 @@ def transpose_weight(inputs, layout_or_out): if isinstance(W, ir.IRNode): if not isinstance(W, ir.TensorBox): W = ir.TensorBox(W) - new_inputs[1] = permute(W, [1, 0]) + new_inputs[1] = L.permute(W, [1, 0]) return new_inputs, layout_or_out else: assert isinstance(W, torch.Tensor) new_inputs[1] = W.transpose(0, 1) return new_inputs, layout_or_out - # TODO: decide proper number of threads per problem size + # TODO(jgong5): decide proper number of threads per problem size num_threads = parallel_num_threads() new_inputs, _ = transpose_weight(*reorder_and_filter(input_nodes, layout)) m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) @@ -250,8 +248,8 @@ def pack_weight(inputs, layout_or_out): assert ( n % block_n == 0 ), f"The last dimension of W must be a multiple of {block_n}." - blocked_w = permute( - view(W, (k, n // block_n, block_n)), + blocked_w = L.permute( + L.view(W, (k, n // block_n, block_n)), [1, 0, 2], ) blocked_w = ir.ExternKernel.require_contiguous(blocked_w) @@ -268,11 +266,10 @@ def preprocessor(inputs, layout): return pack_weight(*transpose_weight(*reorder_and_filter(inputs, layout))) def postprocessor(output): - if isinstance(output, ir.IRNode): + if isinstance(output, ir.TensorBox): # prepack the weight as input to the template buffer - # TODO: prune the unused constants in V.graph - # TODO: should we implement it with constant folding in the scheduler instead? - assert isinstance(output, ir.TensorBox) + # TODO(jgong5): prune the unused constants in V.graph + # Should we implement it with constant folding in the scheduler instead? 
template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) assert isinstance(template_buffer, ir.CppTemplateBuffer) new_input_nodes, _ = reorder_and_filter(input_nodes, layout) @@ -307,8 +304,8 @@ def postprocessor(output): def render( # type: ignore[override] self, kernel: CppTemplateKernel, - template_buffer_node: Optional[CppTemplateBuffer] = None, - epilogue_nodes: Optional[List[IRNode]] = None, + template_buffer_node: Optional[ir.CppTemplateBuffer] = None, + epilogue_nodes: Optional[List[ir.IRNode]] = None, **kwargs, ) -> str: assert not epilogue_nodes, "Epilogue nodes are not supported for GEMM template." @@ -323,7 +320,7 @@ def render( # type: ignore[override] W = template_buffer_node.inputs[1] Y = template_buffer_node if epilogue_nodes is not None and len(epilogue_nodes) > 0: - Y = cast(Buffer, epilogue_nodes[-1]) + Y = cast(ir.Buffer, epilogue_nodes[-1]) assert self.output_node is not None micro_gemm = create_micro_gemm( diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index d2a72865a5197..f23ad21a9e791 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -363,4 +363,4 @@ inline void mm_get_thread_blocks( thread_id /= num_Nt; m_block_start = std::min(thread_id * Mt_blocks, M_blocks); m_block_end = std::min(m_block_start + Mt_blocks, M_blocks); -} \ No newline at end of file +} diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index 4a5272b7fdb07..deedf9a624a28 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -6,13 +6,12 @@ import sympy -from .. import codecache +from .. import codecache, ir from ..autotune_process import CppBenchmarkRequest, TensorMeta -from ..ir import Buffer, IRNode, Layout from ..utils import IndentedBuffer, Placeholder, unique from ..virtualized import V from .common import KernelTemplate -from .cpp_template_kernel import CppTemplateBuffer, CppTemplateCaller, CppTemplateKernel +from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel log = logging.getLogger(__name__) @@ -24,11 +23,11 @@ def __init__( self, name: str, input_nodes, - layout: Layout, + layout: ir.Layout, ): super().__init__(name) self.input_nodes = input_nodes - self.output_node: Buffer = Buffer("buf_out", layout) + self.output_node: ir.Buffer = ir.Buffer("buf_out", layout) self.layout = layout def generate(self, **kwargs): @@ -71,8 +70,8 @@ def generate(self, **kwargs): ) def make_kernel_render( - template_node: CppTemplateBuffer, - epilogue_nodes: Optional[List[IRNode]] = None, + template_node: ir.CppTemplateBuffer, + epilogue_nodes: Optional[List[ir.IRNode]] = None, ): kernel = CppTemplateKernel( kernel_name=str(Placeholder.KERNEL_NAME), diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 7941ce3ec0dc8..5d2f0681e2b43 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -8,18 +8,7 @@ from torch._inductor.autotune_process import CppBenchmarkRequest from torch._inductor.utils import sympy_index_symbol -from .. import lowering as L -from ..ir import ( - Buffer, - ChoiceCaller, - CppTemplateBuffer, - IRNode, - Layout, - PrimitiveInfoType, - ReinterpretView, - TensorBox, - View, -) +from .. 
import ir, lowering as L from ..virtualized import V from .common import Kernel, OpOverrides from .cpp_utils import cexpr_index @@ -31,8 +20,10 @@ def parse_expr_with_index_symbols(expr_str: str) -> sympy.Expr: return expr.subs(int_symbols) -def wrap_with_tensorbox(node) -> TensorBox: - return TensorBox.create(node) if isinstance(node, Buffer) else TensorBox(node) +def wrap_with_tensorbox(node) -> ir.TensorBox: + return ( + ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) + ) class CppTemplateKernel(Kernel): @@ -44,8 +35,8 @@ def __init__(self, kernel_name): def def_kernel( self, - inputs: List[Buffer], - outputs: List[Buffer], + inputs: List[ir.Buffer], + outputs: List[ir.Buffer], names_str: str = "", ) -> str: input_names = [inp.get_name() if inp is not None else None for inp in inputs] @@ -80,12 +71,12 @@ def def_kernel( cpp_argdefs, _, _ = self.args.cpp_argdefs() return f"void {self.kernel_name}({', '.join(cpp_argdefs)})" - def call_kernel(self, name: str, node: CppTemplateBuffer): + def call_kernel(self, name: str, node: ir.CppTemplateBuffer): wrapper = V.graph.wrapper_code _, call_args, arg_types = self.args.cpp_argdefs() wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) - def dtype(self, node: Buffer) -> str: + def dtype(self, node: ir.Buffer) -> str: if node.get_dtype() == torch.float32: return "float" elif node.get_dtype() == torch.bfloat16: @@ -95,7 +86,7 @@ def dtype(self, node: Buffer) -> str: else: raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") - def acc_dtype(self, node: Buffer) -> str: + def acc_dtype(self, node: ir.Buffer) -> str: if node.get_dtype() == torch.float32: return "float" elif node.get_dtype() == torch.bfloat16: @@ -105,19 +96,19 @@ def acc_dtype(self, node: Buffer) -> str: else: raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") - def size(self, node: Buffer, dim: int) -> str: + def size(self, node: ir.Buffer, dim: int) -> str: return str(self.rename_indexing(node.get_size()[dim])) - def stride(self, node: Buffer, dim: int) -> str: + def stride(self, node: ir.Buffer, dim: int) -> str: return str(self.rename_indexing(node.get_stride()[dim])) - def index(self, node: Buffer, indices: List[Any]) -> str: + def index(self, node: ir.Buffer, indices: List[Any]) -> str: indexer = node.make_indexer() index = indexer([parse_expr_with_index_symbols(str(idx)) for idx in indices]) index = self.rename_indexing(index) return f"{self.args.input(node.get_name())}[{cexpr_index(index)}]" - def slice_nd(self, node, ranges: List[Tuple[Any]]) -> ReinterpretView: + def slice_nd(self, node, ranges: List[Tuple[Any]]) -> ir.ReinterpretView: assert len(ranges) == len(node.get_size()) sliced = wrap_with_tensorbox(node) for dim, _range in enumerate(ranges): @@ -126,38 +117,40 @@ def slice_nd(self, node, ranges: List[Tuple[Any]]) -> ReinterpretView: assert len(_range) == 2 start, end = (parse_expr_with_index_symbols(str(r)) for r in _range) sliced = L.slice_(sliced, dim, start, end, clamp=False) - assert isinstance(sliced.data, ReinterpretView) + assert isinstance(sliced.data, ir.ReinterpretView) return sliced.data - def view(self, node, sizes: List[Any]) -> View: + def view(self, node, sizes: List[Any]) -> ir.View: node = wrap_with_tensorbox(node) sizes = [parse_expr_with_index_symbols(str(s)) for s in sizes] return L.view(node, sizes).data -class CppTemplateCaller(ChoiceCaller): +class CppTemplateCaller(ir.ChoiceCaller): """ CppTemplateCaller - This class represents a caller for CPP template 
kernels. It is a subclass of ChoiceCaller. + This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller. Attributes: name (str): The name of the caller. category (str): The category of the caller. bmreq (CppBenchmarkRequest): The benchmark request for the caller. - template_buffer (CppTemplateBuffer): The template buffer for the caller. + template_buffer (ir.CppTemplateBuffer): The template buffer for the caller. """ def __init__( self, name: str, category: str, - input_nodes: List[Buffer], - layout: Layout, - make_kernel_render: Callable[[CppTemplateBuffer, Optional[List[IRNode]]], str], + input_nodes: List[ir.Buffer], + layout: ir.Layout, + make_kernel_render: Callable[ + [ir.CppTemplateBuffer, Optional[List[ir.IRNode]]], str + ], bmreq: CppBenchmarkRequest, template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 info_kwargs: Optional[ - Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] + Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]] ] = None, ): super().__init__(name, input_nodes, layout) @@ -183,12 +176,14 @@ def hash_key(self) -> str: ] ) - def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + def info_dict( + self, + ) -> Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]]: return {"backend": "CPP", "op_type": "unknown"} - def output_node(self) -> TensorBox: - return TensorBox.create( - CppTemplateBuffer( + def output_node(self) -> ir.TensorBox: + return ir.TensorBox.create( + ir.CppTemplateBuffer( layout=self.layout, inputs=self.input_nodes, make_kernel_render=self.make_kernel_render, diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 34ecfda6e8c67..0cb0284581aea 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -16,6 +16,7 @@ from ..utils import cache_on_self, sympy_product from ..virtualized import V from .common import IndentedBuffer +from .cpp_utils import cexpr, CppPrinter, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP from .wrapper import EnterSubgraphLine, ExitSubgraphLine, pexpr, WrapperCodeGen @@ -52,9 +53,6 @@ def __init__(self): self.cached_output_id = count() self.scalar_to_tensor_id = count() self.custom_op_wrapper_loaded = False - - from .cpp import cexpr, CppPrinter - self.expr_printer = cexpr # CppPrinter sometimes calls at::native functions which causes problems in @@ -262,7 +260,6 @@ def write_input_output_info( @staticmethod def get_input_cpp_type(input): assert config.use_minimal_arrayref_interface - from .cpp import DTYPE_TO_CPP if isinstance(input, sympy.Expr): from ..graph import may_get_constant_buffer_dtype @@ -273,8 +270,6 @@ def get_input_cpp_type(input): return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" def generate_input_output_runtime_checks(self): - from .cpp import DTYPE_TO_ATEN - # In debug_compile mode, we generate checks to ensure the dtype/shape/stride of each # real input/output tensor match ones provided at compile time via sample # input/output. 
@@ -375,8 +370,6 @@ def write_wrapper_decl(self): inputs_len = len(V.graph.graph_inputs.keys()) if V.graph.aot_mode: if config.use_minimal_arrayref_interface and not V.graph.is_const_graph: - from .cpp import DTYPE_TO_CPP - input_cpp_types = ", ".join( f"{CppWrapperCpu.get_input_cpp_type(x)}" for x in V.graph.graph_inputs.values() @@ -537,7 +530,6 @@ def write_wrapper_decl(self): # unwrap input tensor back to scalar if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): from ..graph import may_get_constant_buffer_dtype - from .cpp import DTYPE_TO_CPP dtype = may_get_constant_buffer_dtype( V.graph.graph_inputs[input_key] @@ -1298,8 +1290,6 @@ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: return f"{{{', '.join(parts)}}}" def codegen_dynamic_scalar(self, node): - from .cpp import DTYPE_TO_ATEN, DTYPE_TO_CPP - (data,) = (t.codegen_reference() for t in node.inputs) if config.abi_compatible: dtype = node.inputs[0].get_dtype() @@ -1375,8 +1365,6 @@ def codegen_device(self, device): self.used_cached_devices.add(device.type) return f"cached_torch_device_type_{device.type},{device.index if device.index else 0}" else: - from .cpp import DEVICE_TO_ATEN - return ( f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})" if device.index is not None @@ -1389,8 +1377,6 @@ def codegen_dtype(self, dtype): self.used_cached_dtypes.add(dtype_str) return f"cached_torch_dtype_{dtype_str}" else: - from .cpp import DTYPE_TO_ATEN - return DTYPE_TO_ATEN[dtype] @functools.lru_cache(None) @@ -1453,8 +1439,6 @@ def make_allocation( device_type, device_id = device_str.split(",") device_idx = "this->device_idx_" if V.graph.aot_mode else device_id if buffer_if_can_stack_allocate is not None: - from .cpp import DTYPE_TO_CPP - self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate cpp_type = DTYPE_TO_CPP[dtype] numel = buffer_if_can_stack_allocate.get_numel() diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index a2fc0fccf0346..aef8e17da16b6 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1036,14 +1036,14 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): m, n, k, *_ = mm_args(mat1, mat2) if isinstance(mat2, ir.BaseView): mat2 = mat2.unwrap_view() - # TODO: support n % n_block_size != 0 - _, n0, _ = create_micro_gemm( + # TODO(jgong5): support n % n_block_size != 0 + _, n_block_size, _ = create_micro_gemm( "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() ).register_blocking return ( _use_template_for_cpu(layout) and layout.dtype in layout_dtypes - and n % n0 == 0 + and n % n_block_size == 0 and isinstance(mat2, ir.StorageBox) and mat2.is_module_buffer() ) From 5f0133ed657f60f255d8d7a1549a31a1f2a9b38e Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sun, 28 Apr 2024 07:20:26 -0700 Subject: [PATCH 18/28] Update [ghstack-poisoned] --- test/inductor/test_cpu_select_algorithm.py | 63 ++++++++++++++------ test/inductor/test_extension_backend.py | 4 +- torch/_inductor/autotune_process.py | 3 + torch/_inductor/codegen/cpp_gemm_template.py | 20 ++++++- torch/_inductor/codegen/cpp_micro_gemm.py | 4 +- torch/_inductor/utils.py | 3 + 6 files changed, 75 insertions(+), 22 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 706484fee8c41..5b59068c4dad8 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -9,8 +9,12 @@ import torch._inductor.select_algorithm as select_algorithm from torch._dynamo.utils import 
counters from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_device_type import ( + dtypes, + instantiate_device_type_tests, +) -from torch.testing._internal.common_utils import IS_MACOS, TEST_MKL +from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL aten = torch.ops.aten @@ -37,31 +41,56 @@ def wrapped(*args, **kwargs): class TestSelectAlgorithm(TestCase): - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - def test_linear_fp32_cpu(self): + def _test_linear(self, batch_size, in_features, out_features, bias, dtype): class M(torch.nn.Module): def __init__(self, bias): super().__init__() - self.linear = torch.nn.Linear(10, 32, bias) + self.linear = torch.nn.Linear(in_features, out_features, bias) @torch.compile def forward(self, x): return self.linear(x) - for bias in [True, False]: - counters.clear() - mod = M(bias=bias).eval() - v = torch.randn(2, 10) - mod(v) - self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) - + counters.clear() + mod = M(bias=bias).to(dtype=dtype).eval() + v = torch.randn(batch_size, in_features).to(dtype=dtype) + mod(v) + self.assertEqual( + counters["inductor"]["select_algorithm_autotune"], + 1 if out_features != 1 else 0, + ) -@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) -class TestDynamicSelectAlgorithm(TestCase): - test_linear_fp32_dynamic_shapes_cpu = TestSelectAlgorithm.test_linear_fp32_cpu + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("batch_size", (1, 2, 1000)) + @parametrize("in_features", (1, 2, 1000)) + @parametrize("out_features", (1, 32, 1024)) + @parametrize("bias", (True, False)) + @dtypes(torch.float) + def test_linear_static_shapes( + self, batch_size, in_features, out_features, bias, dtype + ): + self._test_linear(batch_size, in_features, out_features, bias, dtype) + + @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("batch_size", (1, 2, 1000)) + @parametrize("in_features", (1, 2, 1000)) + @parametrize("out_features", (1, 32, 1024)) + @parametrize("bias", (True, False)) + @dtypes(torch.float) + def test_linear_dynamic_shapes( + self, batch_size, in_features, out_features, bias, dtype + ): + self._test_linear(batch_size, in_features, out_features, bias, dtype) + + +instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") if __name__ == "__main__": diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py index 097fdaad73fb9..8bc1cfb6fbedc 100644 --- a/test/inductor/test_extension_backend.py +++ b/test/inductor/test_extension_backend.py @@ -24,7 +24,7 @@ import torch._inductor.config as config from torch._inductor import metrics -from torch._inductor.codegen import cpp +from torch._inductor.codegen import cpp_utils from torch._inductor.codegen.common import ( get_scheduling_for_device, get_wrapper_codegen_for_device, @@ -139,7 +139,7 @@ def test_open_device_registration(self): def fn(a, b, c): return a * b + c - cpp.DEVICE_TO_ATEN["extension_device"] = "at::kPrivateUse1" + cpp_utils.DEVICE_TO_ATEN["extension_device"] = "at::kPrivateUse1" for cpp_wrapper_flag in [True, False]: with config.patch({"cpp_wrapper": 
cpp_wrapper_flag}): metrics.reset() diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index e7e66062abef9..4e4a2112bf552 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import ctypes import dataclasses import functools import logging @@ -726,6 +727,7 @@ def precompile(self): def make_run_fn( self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor ) -> Callable[[], None]: + # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf self.DLL = CppCodeCache.load(self.source_code, cuda=False) args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]] log.debug( @@ -736,6 +738,7 @@ def make_run_fn( self.extra_args, ) run_method = getattr(self.DLL, self.kernel_name) + run_method.argtypes = [ctypes.c_ulonglong] * len(args) # Generate partial function. return functools.partial( diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 0870b6a6181eb..b306100e63e6d 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -1,6 +1,7 @@ from typing import cast, List, Optional import torch +import torch.utils from .. import ir, lowering as L from ..kernel.mm_common import mm_args @@ -252,13 +253,30 @@ def pack_weight(inputs, layout_or_out): L.view(W, (k, n // block_n, block_n)), [1, 0, 2], ) - blocked_w = ir.ExternKernel.require_contiguous(blocked_w) blocked_w = ir.ExternKernel.realize_input(blocked_w) + blocked_w = ir.ExternKernel.require_contiguous(blocked_w) + if isinstance(blocked_w, ir.ReinterpretView): + # normalize stride to be "contiguous_strides" per size + # this avoids the problems in L.view during template codegen + assert isinstance(blocked_w.layout, ir.FixedLayout) + blocked_w.layout = ir.FixedLayout( + blocked_w.layout.device, + blocked_w.layout.dtype, + blocked_w.layout.size, + ir.FlexibleLayout.contiguous_strides(blocked_w.layout.size), + blocked_w.layout.offset, + ) else: k, n = list(W.shape) blocked_w = ( W.reshape(k, n // block_n, block_n).transpose(0, 1).contiguous() ) + # normalize stride to be "contiguous_strides" per size + # this avoids the problems in L.view during template codegen + new_stride = [1] + for sz in reversed(blocked_w.shape[1:]): + new_stride.insert(0, new_stride[0] * sz) + blocked_w = blocked_w.as_strided(blocked_w.shape, new_stride) new_inputs[1] = blocked_w return new_inputs, layout_or_out diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 9121da43d5500..99a8e22abfc23 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -296,8 +296,8 @@ def create_from_config(cls, config: CppMicroGemmConfig): alpha, ) - assert isinstance(n, int) or n.is_number - assert isinstance(k, int) or k.is_number + assert isinstance(n, int) or n.is_number, n + assert isinstance(k, int) or k.is_number, k if output_dtype is None: output_dtype = input_dtype if compute_dtype is None: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index aef8e17da16b6..bf226fee17b16 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1034,6 +1034,9 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): layout_dtypes = [torch.float32] m, n, k, *_ = mm_args(mat1, mat2) + # TODO(jgong5): support dynamic shapes for n or k + if n.free_symbols or k.free_symbols: + return False if 
isinstance(mat2, ir.BaseView): mat2 = mat2.unwrap_view() # TODO(jgong5): support n % n_block_size != 0 From b1f731bd0537ceb7c7c0afe0984d5d3472c1d1ff Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sun, 28 Apr 2024 07:59:42 -0700 Subject: [PATCH 19/28] Update [ghstack-poisoned] --- test/inductor/test_cpu_select_algorithm.py | 17 +++++++++----- .../_inductor/codegen/cpp_template_kernel.py | 23 +++++-------------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 5b59068c4dad8..a0ad70efb7637 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -41,7 +41,9 @@ def wrapped(*args, **kwargs): class TestSelectAlgorithm(TestCase): - def _test_linear(self, batch_size, in_features, out_features, bias, dtype): + def _test_linear( + self, batch_size, in_features, out_features, bias, input_3d, dtype + ): class M(torch.nn.Module): def __init__(self, bias): super().__init__() @@ -53,7 +55,8 @@ def forward(self, x): counters.clear() mod = M(bias=bias).to(dtype=dtype).eval() - v = torch.randn(batch_size, in_features).to(dtype=dtype) + B = (2, batch_size) if input_3d else (batch_size,) + v = torch.randn(*B, in_features).to(dtype=dtype) mod(v) self.assertEqual( counters["inductor"]["select_algorithm_autotune"], @@ -68,11 +71,12 @@ def forward(self, x): @parametrize("in_features", (1, 2, 1000)) @parametrize("out_features", (1, 32, 1024)) @parametrize("bias", (True, False)) + @parametrize("input_3d", (True, False)) @dtypes(torch.float) def test_linear_static_shapes( - self, batch_size, in_features, out_features, bias, dtype + self, batch_size, in_features, out_features, bias, input_3d, dtype ): - self._test_linear(batch_size, in_features, out_features, bias, dtype) + self._test_linear(batch_size, in_features, out_features, bias, input_3d, dtype) @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) @inductor_config.patch({"freezing": True}) @@ -83,11 +87,12 @@ def test_linear_static_shapes( @parametrize("in_features", (1, 2, 1000)) @parametrize("out_features", (1, 32, 1024)) @parametrize("bias", (True, False)) + @parametrize("input_3d", (True, False)) @dtypes(torch.float) def test_linear_dynamic_shapes( - self, batch_size, in_features, out_features, bias, dtype + self, batch_size, in_features, out_features, bias, input_3d, dtype ): - self._test_linear(batch_size, in_features, out_features, bias, dtype) + self._test_linear(batch_size, in_features, out_features, bias, input_3d, dtype) instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 5d2f0681e2b43..6bbd974839749 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -11,7 +11,7 @@ from .. 
import ir, lowering as L from ..virtualized import V from .common import Kernel, OpOverrides -from .cpp_utils import cexpr_index +from .cpp_utils import cexpr_index, DTYPE_TO_CPP def parse_expr_with_index_symbols(expr_str: str) -> sympy.Expr: @@ -65,7 +65,7 @@ def def_kernel( for sym in itertools.chain(output.get_size(), output.get_stride()) for s in sym.free_symbols } - sizevars = sorted(unique_sizevars) + sizevars = sorted(unique_sizevars, key=str) for sizevar in sizevars: self.args.sizevars[sizevar] = f"k{sizevar}" cpp_argdefs, _, _ = self.args.cpp_argdefs() @@ -77,30 +77,19 @@ def call_kernel(self, name: str, node: ir.CppTemplateBuffer): wrapper.generate_kernel_call(name, call_args, cuda=False, arg_types=arg_types) def dtype(self, node: ir.Buffer) -> str: - if node.get_dtype() == torch.float32: - return "float" - elif node.get_dtype() == torch.bfloat16: - return "float" - elif node.get_dtype() == torch.half: - return "float" - else: - raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + return DTYPE_TO_CPP[node.get_dtype()] def acc_dtype(self, node: ir.Buffer) -> str: - if node.get_dtype() == torch.float32: - return "float" - elif node.get_dtype() == torch.bfloat16: - return "float" - elif node.get_dtype() == torch.half: + if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]: return "float" else: raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") def size(self, node: ir.Buffer, dim: int) -> str: - return str(self.rename_indexing(node.get_size()[dim])) + return cexpr_index(self.rename_indexing(node.get_size()[dim])) def stride(self, node: ir.Buffer, dim: int) -> str: - return str(self.rename_indexing(node.get_stride()[dim])) + return cexpr_index(self.rename_indexing(node.get_stride()[dim])) def index(self, node: ir.Buffer, indices: List[Any]) -> str: indexer = node.make_indexer() From ab8e6a95cb614aa9f40ff5647d34a0252ccf241a Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Sun, 28 Apr 2024 23:23:55 -0700 Subject: [PATCH 20/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 6 +++--- torch/_inductor/codegen/cpp_micro_gemm.py | 9 +++++---- torch/_inductor/codegen/cpp_prefix.h | 2 +- torch/_inductor/codegen/cpp_template_kernel.py | 7 +++++++ 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index b306100e63e6d..00314359dfa5e 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -17,7 +17,7 @@ GEMM_TEMPLATE = r""" {{template.header().getvalue()}} -{{micro_gemm.codegen_define()}} +{{micro_gemm.codegen_define(kernel)}} extern "C" {{kernel.def_kernel(inputs=[X, W, inp], outputs=[Y], names_str="X, W, inp, Y")}} @@ -57,9 +57,9 @@ {%- endif %} // TODO(jgong5): support k-slicing - TORCH_CHECK(Kt_blocks == K0_blocks, "Do not support k slicing yet."); + {{kernel.assert_function}}(Kt_blocks == K0_blocks, "Do not support k slicing yet."); // make sure all partitions are assigned - TORCH_CHECK( + {{kernel.assert_function}}( Mt_blocks * Nt_blocks * Kt_blocks * {{num_threads}} >= M0_blocks * N0_blocks * K0_blocks, "Not all partitions are assigned." 
); diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 99a8e22abfc23..e197ff134ef59 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -56,7 +56,7 @@ def get_kernel_declaration(self): options = self.get_common_options() return KernelTemplate._template_from_string(self.DECLARE_KERNEL).render(options) - def codegen_define(self) -> str: + def codegen_define(self, kernel: CppTemplateKernel) -> str: raise NotImplementedError def codegen_call( @@ -138,7 +138,7 @@ def __init__(self, name, input_dtype, output_dtype, compute_dtype, alpha): name, input_dtype, output_dtype, compute_dtype, GemmBlocking(1, 1, 1), alpha ) - def codegen_define(self) -> str: + def codegen_define(self, kernel: CppTemplateKernel) -> str: options = { "declare_kernel": self.get_kernel_declaration(), **self.get_common_options(), @@ -184,7 +184,7 @@ class CppMicroGemmFP32AVX(CppMicroGemm): break; {%- endfor %} default: - TORCH_CHECK(false, "Unsupported block_m: ", block_m); + {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m); } } } @@ -257,9 +257,10 @@ class CppMicroGemmFP32AVX(CppMicroGemm): } """ - def codegen_define(self) -> str: + def codegen_define(self, kernel: CppTemplateKernel) -> str: options = { "declare_kernel": self.get_kernel_declaration(), + "kernel": kernel, "block_m": self.register_blocking.block_m, "block_n": self.register_blocking.block_n, "block_k": self.register_blocking.block_k, diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 065817b221201..c034522b83332 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -339,7 +339,7 @@ std::tuple mm_get_thread_blocking( } } - TORCH_CHECK(false, "Should not reach here."); + assert(false && "Should not reach here."); // Dummy return to avoid compiler warning return std::make_tuple(0, 0, 0); } diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 6bbd974839749..8a53bc98f6fc9 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -114,6 +114,13 @@ def view(self, node, sizes: List[Any]) -> ir.View: sizes = [parse_expr_with_index_symbols(str(s)) for s in sizes] return L.view(node, sizes).data + @property + def assert_function(self) -> str: + if V.graph.aot_mode: + return "AOTI_TORCH_CHECK" + else: + return "TORCH_CHECK" + class CppTemplateCaller(ir.ChoiceCaller): """ From c2c5d2d5da8ad21b368453d68fcefa15631e614f Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Mon, 29 Apr 2024 05:47:19 -0700 Subject: [PATCH 21/28] Update [ghstack-poisoned] --- torch/_inductor/autotune_process.py | 5 ++--- torch/_inductor/codegen/cpp_gemm_template.py | 1 + torch/_inductor/codegen/cpp_template.py | 9 ++++++++- .../_inductor/codegen/cpp_template_kernel.py | 10 +++++++++- torch/_inductor/runtime/runtime_utils.py | 19 +++++++++++++++++++ torch/_inductor/select_algorithm.py | 5 ++--- 6 files changed, 41 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 4f5ef813f0b55..66c28bc4ba168 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -43,8 +43,7 @@ from torch._inductor.select_algorithm import TritonTemplateCaller from . 
import config -from .runtime.runtime_utils import do_bench -from .utils import timed +from .runtime.runtime_utils import do_bench, do_bench_cpu from .virtualized import V CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" @@ -759,7 +758,7 @@ def do_bench( *input_tensors: torch.Tensor, output_tensor: Optional[torch.Tensor] = None, ) -> float: - return timed(fn, ()) + return do_bench_cpu(fn) @dataclasses.dataclass diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 00314359dfa5e..cbbe968691ac2 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -22,6 +22,7 @@ extern "C" {{kernel.def_kernel(inputs=[X, W, inp], outputs=[Y], names_str="X, W, inp, Y")}} { + {{kernel.maybe_codegen_profile()}} constexpr int64_t num_threads = {{num_threads}}; constexpr int64_t N = {{kernel.size(Y, 1)}}; constexpr int64_t K = {{kernel.size(X, 1)}}; diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index deedf9a624a28..3d15010a8838b 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -1,12 +1,14 @@ import functools import itertools import logging + +import sys from typing import List, Optional from unittest.mock import patch import sympy -from .. import codecache, ir +from .. import codecache, config, ir from ..autotune_process import CppBenchmarkRequest, TensorMeta from ..utils import IndentedBuffer, Placeholder, unique from ..virtualized import V @@ -103,6 +105,11 @@ def header(self) -> IndentedBuffer: #include "c10/util/Unroll.h" """ ) + enable_kernel_profile = ( + config.cpp.enable_kernel_profile and sys.platform == "linux" + ) + if enable_kernel_profile: + res.writelines(["#include "]) return res def render(self, **kwargs) -> str: diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 8a53bc98f6fc9..181a22436d668 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -8,7 +8,7 @@ from torch._inductor.autotune_process import CppBenchmarkRequest from torch._inductor.utils import sympy_index_symbol -from .. import ir, lowering as L +from .. 
import config, ir, lowering as L from ..virtualized import V from .common import Kernel, OpOverrides from .cpp_utils import cexpr_index, DTYPE_TO_CPP @@ -121,6 +121,14 @@ def assert_function(self) -> str: else: return "TORCH_CHECK" + def maybe_codegen_profile(self) -> str: + if config.cpp.enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef({{}}));' + else: + return "" + class CppTemplateCaller(ir.ChoiceCaller): """ diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index c0fdf65ec9b7a..19f4debbfa9b9 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -7,6 +7,7 @@ import os import re import tempfile +import time import torch @@ -100,6 +101,24 @@ def load_triton(): return triton_do_bench(*args, **kwargs)[0] +def do_bench_cpu(fn, warmup=3, times=10): + assert times > 0 + for _ in range(warmup): + fn() + durations = [] + for _ in range(times): + t0 = time.perf_counter() + fn() + t1 = time.perf_counter() + durations.append((t1 - t0) * 1000) + # return the median time + sorted_durations = sorted(durations) + if times % 2 == 0: + return (sorted_durations[times // 2 - 1] + sorted_durations[times // 2]) / 2 + else: + return sorted_durations[times // 2] + + def cache_dir() -> str: cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") if cache_dir is None: diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 151534149cb36..48a532c81062b 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -35,7 +35,7 @@ from .codegen.triton_utils import config_of, signature_to_meta from .exc import CUDACompileError from .ir import ChoiceCaller, PrimitiveInfoType -from .runtime.runtime_utils import do_bench +from .runtime.runtime_utils import do_bench, do_bench_cpu from .utils import ( get_dtype_size, is_cpu_device, @@ -43,7 +43,6 @@ sympy_dot, sympy_index_symbol, sympy_product, - timed, unique, ) from .virtualized import V @@ -846,7 +845,7 @@ def benchmark(self, *args, out): ) out.copy_(out_new) # for correctness checking if is_cpu_device(args): - return timed(lambda: algo(*args), ()) + return do_bench_cpu(lambda: algo(*args)) else: return do_bench(lambda: algo(*args)) From b079a2c277591fbd93df3fcebde1679e6c45ad5e Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Mon, 29 Apr 2024 06:28:06 -0700 Subject: [PATCH 22/28] Update [ghstack-poisoned] --- test/inductor/test_extension_backend.py | 4 +- torch/_inductor/codegen/common.py | 4 +- torch/_inductor/codegen/cpp.py | 278 +++++++++++++++++---- torch/_inductor/codegen/cpp_utils.py | 234 ----------------- torch/_inductor/codegen/cpp_wrapper_cpu.py | 18 +- 5 files changed, 251 insertions(+), 287 deletions(-) delete mode 100644 torch/_inductor/codegen/cpp_utils.py diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py index 3cb473255e74b..7bb531d980770 100644 --- a/test/inductor/test_extension_backend.py +++ b/test/inductor/test_extension_backend.py @@ -24,7 +24,7 @@ import torch._inductor.config as config from torch._inductor import codecache, metrics -from torch._inductor.codegen import cpp_utils +from torch._inductor.codegen import cpp from torch._inductor.codegen.common import ( get_scheduling_for_device, get_wrapper_codegen_for_device, @@ -140,7 +140,7 @@ def test_open_device_registration(self): def fn(a, b, c): 
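The do_bench_cpu helper added above times a candidate kernel on CPU by wall clock and reports the median of several runs, which is how CPP template choices get ranked during autotuning. A standalone sketch of the same idea (the warmup/times defaults mirror the patch; everything else is illustrative):

    import time

    def bench_cpu_median(fn, warmup=3, times=10):
        # Discard warm-up runs so one-time allocation/caching effects do not skew the timing.
        for _ in range(warmup):
            fn()
        durations = []
        for _ in range(times):
            t0 = time.perf_counter()
            fn()
            t1 = time.perf_counter()
            durations.append((t1 - t0) * 1000)  # milliseconds
        durations.sort()
        mid = times // 2
        # Median of the sorted samples; average the two middle ones for an even count.
        return (durations[mid - 1] + durations[mid]) / 2 if times % 2 == 0 else durations[mid]

    print(bench_cpu_median(lambda: sum(range(10000))))
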
return a * b + c - cpp_utils.DEVICE_TO_ATEN["extension_device"] = "at::kPrivateUse1" + cpp.DEVICE_TO_ATEN["extension_device"] = "at::kPrivateUse1" for cpp_wrapper_flag in [True, False]: with config.patch({"cpp_wrapper": cpp_wrapper_flag}): metrics.reset() diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index a8876567c0c94..7126d565cf268 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -989,7 +989,7 @@ def wrap_size_arg(self, size): return str(size) def cpp_argdefs(self): - from .cpp_utils import DTYPE_TO_CPP, INDEX_TYPE + from .cpp import DTYPE_TO_CPP, INDEX_TYPE call_args = [] arg_defs = [] @@ -1138,7 +1138,7 @@ def update_on_args(self, name, args, kwargs): class CppWrapperKernelArgs(KernelArgs): def wrap_ptr_arg(self, buf, dtype): - from .cpp_utils import DTYPE_TO_CPP + from .cpp import DTYPE_TO_CPP if config.abi_compatible: # In the abi_compatible model, we just return the buf here. diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 7f6ca9098c3db..42b1aebe20f65 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -8,7 +8,7 @@ import sys from copy import copy, deepcopy from enum import Enum -from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import sympy @@ -19,7 +19,6 @@ from torch.utils import _pytree as pytree from torch.utils._sympy.functions import FloorDiv, ModularIndexing from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges -from ..._dynamo.utils import counters from .. import codecache, config, ir, metrics from ..codegen.wrapper import WrapperCodeGen @@ -53,6 +52,7 @@ DataTypePropagation, DeferredLine, DTYPE_TO_COMPUTATION_DTYPE, + ExprPrinter, IndentedBuffer, Kernel, KernelArgs, @@ -60,10 +60,59 @@ OptimizationContext, ) -from .cpp_utils import cexpr, cexpr_index, DTYPE_TO_CPP, INDEX_TYPE, value_to_cpp - schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +DTYPE_TO_CPP = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "half", + torch.int64: "int64_t", + torch.int32: "int", + torch.int16: "short", + torch.int8: "signed char", + torch.uint64: "uint64_t", + torch.uint32: "unsigned int", + torch.uint16: "unsigned short", + torch.uint8: "unsigned char", + torch.bool: "bool", + torch.bfloat16: "bfloat16", + torch.complex64: "complex64", + torch.float8_e4m3fn: "float8_e4m3fn", + torch.float8_e5m2: "float8_e5m2", +} + +DTYPE_TO_ATEN = { + torch.float32: "at::kFloat", + torch.float64: "at::kDouble", + torch.float16: "at::kHalf", + torch.int64: "at::kLong", + torch.int32: "at::kInt", + torch.int16: "at::kShort", + torch.int8: "at::kChar", + torch.uint64: "at::kUInt64", + torch.uint32: "at::kUInt32", + torch.uint16: "at::kUInt16", + torch.uint8: "at::kByte", + torch.uint32: "at::kUInt32", + torch.uint64: "at::kUInt64", + torch.bool: "at::kBool", + torch.bfloat16: "at::kBFloat16", + torch.complex32: "at::kComplexHalf", + torch.complex64: "at::kComplexFloat", + torch.complex128: "at::kComplexDouble", + torch.float8_e4m3fn: "at::kFloat8_e4m3fn", + torch.float8_e5m2: "at::kFloat8_e5m2", + torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", +} + +DEVICE_TO_ATEN = { + "cpu": "at::kCPU", + "cuda": "at::kCUDA", +} + +INDEX_TYPE = "long" + NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"} RTYPE_TO_CPP = { "sum": "+", @@ -114,6 +163,19 @@ BIN_CMP_OPS = ["eq", "ne", "le", 
"ge", "lt", "gt"] +def value_to_cpp(value, cpp_type): + if value == float("-inf"): + return f"-std::numeric_limits<{cpp_type}>::infinity()" + elif value == float("inf"): + return f"std::numeric_limits<{cpp_type}>::infinity()" + elif isinstance(value, bool): + return f"static_cast<{cpp_type}>({str(value).lower()})" + elif math.isnan(value): + return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" + else: + return f"static_cast<{cpp_type}>({repr(value)})" + + def reduction_init(reduction_type, dtype): if dtype in DTYPE_LOWP_FP: # Since load promotes all half-precision inputs to float, the initial @@ -465,6 +527,168 @@ def _merge_outer_fusion_loop_levels( return cpp_kernel_proxy_list[0] +class CppPrinter(ExprPrinter): + def _print_Integer(self, expr): + return f"{int(expr)}L" + + def _print_Where(self, expr): + c = self.paren(self.doprint(expr.args[0])) + p = self.paren(self.doprint(expr.args[1])) + q = self.paren(self.doprint(expr.args[2])) + return f"{c} ? {p} : {q}" + + def _print_ModularIndexing(self, expr): + x, div, mod = expr.args + x = self.paren(self.doprint(x)) + if div != 1: + div = self.paren(self.doprint(div)) + if expr.is_integer: + x = f"c10::div_floor_integer({x}, {div})" + else: + x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + mod = self.paren(self.doprint(mod)) + return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})" + + def _print_FloorDiv(self, expr): + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + if expr.is_integer: + return f"c10::div_floor_integer({x}, {div})" + return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Trunc(self, expr): + assert len(expr.args) == 1 + r = f"std::trunc({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Pow(self, expr): + # Uses float constants to perform FP div + base, exp = expr.args + base = self._print(base) + + if exp == 0.5 or exp == -0.5: + return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" + assert exp.is_integer + exp = int(exp) + if exp > 0: + r = "*".join([self.paren(base)] * exp) + elif exp < 0: + r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp))) + else: # exp == 0 + r = "1.0" + + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Rational(self, expr): + # Uses float constants to perform FP div + if expr.q == 1: + r = f"{expr.p}" + else: + r = f"{expr.p}.0/{expr.q}.0" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Min(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::min({args[0]}, {args[1]})" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::min({il})" + + def _print_Max(self, expr): + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::max({args[0]}, {args[1]})" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::max({il})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return f"std::abs({self._print(expr.args[0])})" + + 
def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"std::cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"std::cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"std::acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert len(expr.args) == 1 + return f"std::sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"std::sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"std::asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"std::tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"std::tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"std::atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sqrt(self, expr): + return f"std::sqrt({self._print(expr.args[0])})" + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return f"std::lrint({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr): + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + return f"static_cast(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})" + + def _print_BooleanTrue(self, expr): + return "true" + + def _print_BooleanFalse(self, expr): + return "false" + + +# A function to print, useful for printing sympy symbols. 
+cexpr = CppPrinter().doprint + + +def cexpr_index(index): + return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" + + class RecordOptimizationContext: def __init__(self, func_name: str = ""): self.func_name = func_name @@ -3517,8 +3741,6 @@ def _can_fuse_horizontal_impl(self, node1, node2): return self._why_fuse_nodes(node1, node2) is not None def can_fuse_horizontal(self, node1, node2): - if node1.is_template() or node2.is_template(): - return False if ( len(node1.get_nodes()) + len(node2.get_nodes()) > config.cpp.max_horizontal_fusion_size @@ -3599,9 +3821,6 @@ def get_fusion_pair_priority(self, node1, node2): return 0 def can_fuse_vertical(self, node1, node2): - # TODO(jgong5): support vertical fusion for template nodes - if node1.is_template() or node2.is_template(): - return False return ( self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) @@ -3658,42 +3877,6 @@ def codegen_node( if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM: self._set_flush_status(True) - def is_cpp_template(self, node: BaseSchedulerNode) -> bool: - return isinstance(node, SchedulerNode) and isinstance( - node.node, ir.CppTemplateBuffer - ) - - def codegen_template( - self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode] - ): - """ - Codegen a CPP template, possibly with fused epilogues - """ - counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes) - assert self.is_cpp_template( - template_node - ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer" - template_node = cast(SchedulerNode, template_node) - _, (_, rnumel) = template_node.group - assert rnumel == () - ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node) - epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes] - assert all( - isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes - ), "Epilogue nodes must all be instances of ir.ComputedBuffer" - kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes) - with kernel: - for node in [template_node, *epilogue_nodes]: - node.mark_run() - src_code = render() - - with V.set_kernel_handler(kernel): - node_schedule = [template_node, *epilogue_nodes] - kernel_name = self.define_kernel(src_code, node_schedule, kernel.args) - kernel.call_kernel(kernel_name, ctb) - V.graph.removed_buffers |= kernel.removed_buffers - self.scheduler.free_buffers() - def _get_scheduled_num_args(self): return self.kernel_group.get_num_args() @@ -3703,7 +3886,7 @@ def ready_to_flush(self): def codegen_sync(self): pass - def define_kernel(self, src_code, nodes, kernel_args=None): + def define_kernel(self, src_code, nodes): wrapper = V.graph.wrapper_code fused_name = ( get_fused_kernel_name(nodes, config.cpp.descriptive_names) @@ -3719,8 +3902,7 @@ def define_kernel(self, src_code, nodes, kernel_args=None): src_code = src_code.replace("#pragma CMT", "//") compile_wrapper = IndentedBuffer() - args = self.kernel_group.args if kernel_args is None else kernel_args - _, _, arg_types = args.cpp_argdefs() + _, _, arg_types = self.kernel_group.args.cpp_argdefs() if not V.graph.cpp_wrapper: compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''") compile_wrapper.splice(src_code, strip=True) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py deleted file mode 100644 index 2435ba19f2d06..0000000000000 --- a/torch/_inductor/codegen/cpp_utils.py +++ 
/dev/null @@ -1,234 +0,0 @@ -import math -from collections import namedtuple - -import torch - -from .common import ExprPrinter - -DTYPE_TO_CPP = { - torch.float32: "float", - torch.float64: "double", - torch.float16: "half", - torch.int64: "int64_t", - torch.int32: "int", - torch.int16: "short", - torch.int8: "signed char", - torch.uint64: "uint64_t", - torch.uint32: "unsigned int", - torch.uint16: "unsigned short", - torch.uint8: "unsigned char", - torch.bool: "bool", - torch.bfloat16: "bfloat16", - torch.complex64: "complex64", - torch.float8_e4m3fn: "float8_e4m3fn", - torch.float8_e5m2: "float8_e5m2", -} - -DTYPE_TO_ATEN = { - torch.float32: "at::kFloat", - torch.float64: "at::kDouble", - torch.float16: "at::kHalf", - torch.int64: "at::kLong", - torch.int32: "at::kInt", - torch.int16: "at::kShort", - torch.int8: "at::kChar", - torch.uint64: "at::kUInt64", - torch.uint32: "at::kUInt32", - torch.uint16: "at::kUInt16", - torch.uint8: "at::kByte", - torch.uint32: "at::kUInt32", - torch.uint64: "at::kUInt64", - torch.bool: "at::kBool", - torch.bfloat16: "at::kBFloat16", - torch.complex32: "at::kComplexHalf", - torch.complex64: "at::kComplexFloat", - torch.complex128: "at::kComplexDouble", - torch.float8_e4m3fn: "at::kFloat8_e4m3fn", - torch.float8_e5m2: "at::kFloat8_e5m2", - torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", - torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", -} - -DEVICE_TO_ATEN = { - "cpu": "at::kCPU", - "cuda": "at::kCUDA", -} - -GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) - -INDEX_TYPE = "long" - - -class CppPrinter(ExprPrinter): - def _print_Integer(self, expr): - return f"{int(expr)}L" - - def _print_Where(self, expr): - c = self.paren(self.doprint(expr.args[0])) - p = self.paren(self.doprint(expr.args[1])) - q = self.paren(self.doprint(expr.args[2])) - return f"{c} ? 
{p} : {q}" - - def _print_ModularIndexing(self, expr): - x, div, mod = expr.args - x = self.paren(self.doprint(x)) - if div != 1: - div = self.paren(self.doprint(div)) - if expr.is_integer: - x = f"c10::div_floor_integer({x}, {div})" - else: - x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" - mod = self.paren(self.doprint(mod)) - return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})" - - def _print_FloorDiv(self, expr): - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - if expr.is_integer: - return f"c10::div_floor_integer({x}, {div})" - return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" - - def _print_floor(self, expr): - assert len(expr.args) == 1 - r = f"std::floor({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_Trunc(self, expr): - assert len(expr.args) == 1 - r = f"std::trunc({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_Pow(self, expr): - # Uses float constants to perform FP div - base, exp = expr.args - base = self._print(base) - - if exp == 0.5 or exp == -0.5: - return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" - assert exp.is_integer - exp = int(exp) - if exp > 0: - r = "*".join([self.paren(base)] * exp) - elif exp < 0: - r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - r = "1.0" - - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_Rational(self, expr): - # Uses float constants to perform FP div - if expr.q == 1: - r = f"{expr.p}" - else: - r = f"{expr.p}.0/{expr.q}.0" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_ceiling(self, expr): - assert len(expr.args) == 1 - r = f"std::ceil({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_Min(self, expr): - args = [self._print(a) for a in expr.args] - if len(args) == 2: - return f"std::min({args[0]}, {args[1]})" - else: - # Initializer list overload - il = "{" + ", ".join(args) + "}" - return f"std::min({il})" - - def _print_Max(self, expr): - args = [self._print(a) for a in expr.args] - if len(args) == 2: - return f"std::max({args[0]}, {args[1]})" - else: - # Initializer list overload - il = "{" + ", ".join(args) + "}" - return f"std::max({il})" - - def _print_Abs(self, expr): - assert len(expr.args) == 1 - return f"std::abs({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_cos(self, expr): - assert len(expr.args) == 1 - return f"std::cos({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_cosh(self, expr): - assert len(expr.args) == 1 - return f"std::cosh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_acos(self, expr): - assert len(expr.args) == 1 - return f"std::acos({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sin(self, expr): - assert len(expr.args) == 1 - return f"std::sin({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sinh(self, expr): - assert len(expr.args) == 1 - return f"std::sinh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_asin(self, expr): - assert len(expr.args) == 1 - return f"std::asin({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_tan(self, expr): - assert len(expr.args) == 1 - return f"std::tan({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_tanh(self, expr): - assert len(expr.args) == 1 - return 
f"std::tanh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_atan(self, expr): - assert len(expr.args) == 1 - return f"std::atan({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sqrt(self, expr): - return f"std::sqrt({self._print(expr.args[0])})" - - def _print_Round(self, expr): - assert len(expr.args) == 1 - return f"std::lrint({self._print(expr.args[0])})" - - def _print_RoundDecimal(self, expr): - assert len(expr.args) == 2 - number, ndigits = expr.args - if number.is_integer: - # ndigits < 0 should have been filtered by the sympy function - assert ndigits < 0 - raise ValueError( - f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." - ) - return f"static_cast(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})" - - def _print_BooleanTrue(self, expr): - return "true" - - def _print_BooleanFalse(self, expr): - return "false" - - -# A function to print, useful for printing sympy symbols. -cexpr = CppPrinter().doprint - - -def cexpr_index(index): - return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" - - -def value_to_cpp(value, cpp_type): - if value == float("-inf"): - return f"-std::numeric_limits<{cpp_type}>::infinity()" - elif value == float("inf"): - return f"std::numeric_limits<{cpp_type}>::infinity()" - elif isinstance(value, bool): - return f"static_cast<{cpp_type}>({str(value).lower()})" - elif math.isnan(value): - return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" - else: - return f"static_cast<{cpp_type}>({repr(value)})" diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 90461166626f5..95e4ef3ac7015 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -18,7 +18,6 @@ from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import IndentedBuffer -from .cpp_utils import cexpr, CppPrinter, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP from .wrapper import EnterSubgraphLine, ExitSubgraphLine, WrapperCodeGen @@ -55,6 +54,9 @@ def __init__(self): self.cached_output_id = count() self.scalar_to_tensor_id = count() self.custom_op_wrapper_loaded = False + + from .cpp import cexpr, CppPrinter + self.expr_printer = cexpr # CppPrinter sometimes calls at::native functions which causes problems in @@ -271,6 +273,7 @@ def write_input_output_info( @staticmethod def get_input_cpp_type(input): assert config.use_minimal_arrayref_interface + from .cpp import DTYPE_TO_CPP if isinstance(input, sympy.Expr): from ..graph import may_get_constant_buffer_dtype @@ -281,6 +284,8 @@ def get_input_cpp_type(input): return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" def generate_input_output_runtime_checks(self): + from .cpp import DTYPE_TO_ATEN + # In debug_compile mode, we generate checks to ensure the dtype/shape/stride of each # real input/output tensor match ones provided at compile time via sample # input/output. 
@@ -381,6 +386,8 @@ def write_wrapper_decl(self): inputs_len = len(V.graph.graph_inputs.keys()) if V.graph.aot_mode: if config.use_minimal_arrayref_interface and not V.graph.is_const_graph: + from .cpp import DTYPE_TO_CPP + input_cpp_types = ", ".join( f"{CppWrapperCpu.get_input_cpp_type(x)}" for x in V.graph.graph_inputs.values() @@ -543,6 +550,7 @@ def write_wrapper_decl(self): # unwrap input tensor back to scalar if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): from ..graph import may_get_constant_buffer_dtype + from .cpp import DTYPE_TO_CPP dtype = may_get_constant_buffer_dtype( V.graph.graph_inputs[input_key] @@ -1381,6 +1389,8 @@ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: return f"{{{', '.join(parts)}}}" def codegen_dynamic_scalar(self, node): + from .cpp import DTYPE_TO_ATEN, DTYPE_TO_CPP + (data,) = (t.codegen_reference() for t in node.inputs) if config.abi_compatible: dtype = node.inputs[0].get_dtype() @@ -1466,6 +1476,8 @@ def codegen_device(self, device): self.used_cached_devices.add(device.type) return f"cached_torch_device_type_{device.type}, {device.index if device.index else 0}" else: + from .cpp import DEVICE_TO_ATEN + return ( f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})" if device.index is not None @@ -1478,6 +1490,8 @@ def codegen_dtype(self, dtype): self.used_cached_dtypes.add(dtype_str) return f"cached_torch_dtype_{dtype_str}" else: + from .cpp import DTYPE_TO_ATEN + return DTYPE_TO_ATEN[dtype] @functools.lru_cache(None) @@ -1540,6 +1554,8 @@ def make_allocation( device_type, device_id = device_str.split(",") device_idx = "this->device_idx_" if V.graph.aot_mode else device_id if buffer_if_can_stack_allocate is not None: + from .cpp import DTYPE_TO_CPP + self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate cpp_type = DTYPE_TO_CPP[dtype] numel = buffer_if_can_stack_allocate.get_numel() From 59086de7bcfd6dc04bf55294dfe3dca3d04a4335 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Mon, 29 Apr 2024 07:57:24 -0700 Subject: [PATCH 23/28] Update [ghstack-poisoned] --- torch/_inductor/autotune_process.py | 136 ++++------------------- torch/_inductor/codecache.py | 3 + torch/_inductor/ir.py | 7 +- torch/_inductor/runtime/runtime_utils.py | 19 ---- torch/_inductor/select_algorithm.py | 123 +++----------------- 5 files changed, 40 insertions(+), 248 deletions(-) diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 13b5ad97861e3..035961c311dc3 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import ctypes import dataclasses import functools import logging @@ -10,10 +9,9 @@ import time import warnings from concurrent.futures import ThreadPoolExecutor -from ctypes import byref, c_size_t, c_void_p, CDLL +from ctypes import byref, c_size_t, c_void_p from multiprocessing.process import BaseProcess from multiprocessing.queues import Queue -from types import ModuleType from typing import ( Any, Callable, @@ -31,19 +29,13 @@ from torch._dynamo.testing import rand_strided from torch._inductor import ir -from torch._inductor.codecache import ( - CppCodeCache, - CUDACodeCache, - DLLWrapper, - get_hash, - PyCodeCache, -) +from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache if TYPE_CHECKING: from torch._inductor.select_algorithm import TritonTemplateCaller from . 
import config -from .runtime.runtime_utils import do_bench, do_bench_cpu +from .runtime.runtime_utils import do_bench from .virtualized import V CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES" @@ -495,14 +487,6 @@ def make_run_fn( def cleanup_run_fn(self) -> None: pass - def do_bench( - self, - fn, - *input_tensors: torch.Tensor, - output_tensor: Optional[torch.Tensor] = None, - ) -> float: - raise NotImplementedError - def benchmark( self, *input_tensors: torch.Tensor, @@ -528,7 +512,22 @@ def benchmark( load_elapse = time.time() - start_ts # type: ignore[possibly-undefined] start_ts = time.time() - out = self.do_bench(fn, *input_tensors, output_tensor) + device_idx_set = { + tensor.device.index + for tensor in [*input_tensors, output_tensor] + if isinstance(tensor, torch.Tensor) + and tensor.is_cuda + and tensor.device.index is not None + } + assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" + if len(device_idx_set) == 1: + device_idx = next(iter(device_idx_set)) + else: + device_idx = torch.cuda.current_device() + + with torch.cuda.device(device_idx): + out = do_bench(fn) + torch.cuda.synchronize() # shake out any CUDA errors if debug: bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined] @@ -560,34 +559,7 @@ def benchmark( return self.value -class GPUDeviceBenchmarkRequest(BenchmarkRequest): - def do_bench( - self, - fn, - *input_tensors: torch.Tensor, - output_tensor: Optional[torch.Tensor] = None, - ) -> float: - device_idx_set = { - tensor.device.index - for tensor in [*input_tensors, output_tensor] - if isinstance(tensor, torch.Tensor) - and tensor.is_cuda - and tensor.device.index is not None - } - assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}" - if len(device_idx_set) == 1: - device_idx = next(iter(device_idx_set)) - else: - device_idx = torch.cuda.current_device() - - with torch.cuda.device(device_idx): - out = do_bench(fn) - torch.cuda.synchronize() # shake out any CUDA errors - - return out - - -class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest): +class TritonBenchmarkRequest(BenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! def __init__( @@ -663,7 +635,7 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}" -class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest): +class CUDABenchmarkRequest(BenchmarkRequest): # Important: Instances of this class have to be serializable # across process boundaries. Do not put CUDA Tensors in here! @@ -751,72 +723,6 @@ def __str__(self) -> str: return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" -class CPUDeviceBenchmarkRequest(BenchmarkRequest): - def do_bench( - self, - fn, - *input_tensors: torch.Tensor, - output_tensor: Optional[torch.Tensor] = None, - ) -> float: - return do_bench_cpu(fn) - - -@dataclasses.dataclass -class CppBenchmarkRequest(CPUDeviceBenchmarkRequest): - # Important: Instances of this class have to be serializable - # across process boundaries. Do not put Tensors in here! 
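As the comment above says, benchmark requests must stay picklable so they can cross process boundaries, which is why they carry tensor metadata rather than tensors and materialize random strided inputs on the benchmarking side. A stripped-down stand-in for that record (the real TensorMeta also tracks device and storage offset):

    import dataclasses
    from typing import Tuple

    import torch
    from torch._dynamo.testing import rand_strided

    @dataclasses.dataclass
    class TensorSpec:
        # Plain data only -- safe to pickle and send to a benchmarking subprocess.
        dtype: torch.dtype
        size: Tuple[int, ...]
        stride: Tuple[int, ...]

        def to_tensor(self) -> torch.Tensor:
            return rand_strided(self.size, self.stride, dtype=self.dtype, device="cpu")

    spec = TensorSpec(torch.float32, (4, 8), (8, 1))
    x = spec.to_tensor()
    print(x.shape, x.stride())
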
- - def __init__( - self, - kernel_name: str, - input_tensor_meta: Union[TensorMeta, List[TensorMeta]], - output_tensor_meta: Union[TensorMeta, List[TensorMeta]], - extra_args: Iterable[Any], - source_code: str, - ): - super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) - self.source_code = source_code - self.hash_key = get_hash(source_code) - self.DLL: Optional[Union[CDLL, ModuleType]] = None - - def precompile(self): - # Prepopulate CppCodeCache - # may happen in separate Threadpool - log.debug("Precompiling %s", self) - CppCodeCache.load(self.source_code, cuda=False) - log.debug("Done precompiling %s", self) - - def make_run_fn( - self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor - ) -> Callable[[], None]: - # TODO(jgong5): use CppPythonBindingsCodeCache for better binding perf - self.DLL = CppCodeCache.load(self.source_code, cuda=False) - args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]] - log.debug( - "make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s", - self.kernel_name, - self.DLL, - args, - self.extra_args, - ) - run_method = getattr(self.DLL, self.kernel_name) - run_method.argtypes = [ctypes.c_ulonglong] * len(args) - - # Generate partial function. - return functools.partial( - run_method, - *args, - *self.extra_args, - ) - - def cleanup_run_fn(self) -> None: - if self.DLL is not None: - self.DLL.close() - - def __str__(self) -> str: - return f"{self.kernel_name=}" - - def benchmark_in_sub_process( choices: List[TritonTemplateCaller], ) -> Dict[TritonTemplateCaller, float]: diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 144f9b4631a34..1d148f9d9968f 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -202,6 +202,9 @@ def get_global_cache_path() -> Optional[Path]: ) def __init__(self) -> None: + if not torch.cuda.is_available(): + return + self.system = CacheBase.get_system() def get_local_cache(self) -> Dict[str, Any]: diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a90adcb279591..f2d9c8a98fd48 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -78,7 +78,6 @@ convert_shape_to_symint, developer_warning, get_kernel_metadata, - is_cpu_device, is_dynamic, is_gpu, pad_listlike, @@ -86,7 +85,6 @@ sympy_index_symbol, sympy_product, sympy_subs, - timed, ) from .virtualized import ops, V @@ -3621,10 +3619,7 @@ def __init__(self, name, input_nodes, layout): def benchmark(self, *args, out) -> float: algo = self.to_callable() - if is_cpu_device(args): - return timed(lambda: algo(*args, out=out), ()) - else: - return do_bench(lambda: algo(*args, out=out)) + return do_bench(lambda: algo(*args, out=out)) def call_name(self) -> str: raise NotImplementedError diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 19f4debbfa9b9..c0fdf65ec9b7a 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -7,7 +7,6 @@ import os import re import tempfile -import time import torch @@ -101,24 +100,6 @@ def load_triton(): return triton_do_bench(*args, **kwargs)[0] -def do_bench_cpu(fn, warmup=3, times=10): - assert times > 0 - for _ in range(warmup): - fn() - durations = [] - for _ in range(times): - t0 = time.perf_counter() - fn() - t1 = time.perf_counter() - durations.append((t1 - t0) * 1000) - # return the median time - sorted_durations = sorted(durations) - if times % 2 == 0: - return (sorted_durations[times // 2 - 1] + 
sorted_durations[times // 2]) / 2 - else: - return sorted_durations[times // 2] - - def cache_dir() -> str: cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") if cache_dir is None: diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 48a532c81062b..c301c3394feb7 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -35,10 +35,9 @@ from .codegen.triton_utils import config_of, signature_to_meta from .exc import CUDACompileError from .ir import ChoiceCaller, PrimitiveInfoType -from .runtime.runtime_utils import do_bench, do_bench_cpu +from .runtime.runtime_utils import do_bench from .utils import ( get_dtype_size, - is_cpu_device, Placeholder, sympy_dot, sympy_index_symbol, @@ -694,19 +693,17 @@ def __init__( has_out_variant=True, op_overload=None, use_fallback_kernel=False, - kernel_creator=None, ): super().__init__() name = name or kernel.__name__ assert callable(kernel) - assert not hasattr(extern_kernels, name), f"duplicate extern kernel: {name}" + assert not hasattr(extern_kernels, name), "duplicate extern kernel" self.name = name self.cpp_kernel_name = cpp_kernel self.has_out_variant = has_out_variant setattr(extern_kernels, name, kernel) self.op_overload = op_overload self.use_fallback_kernel = use_fallback_kernel - self.kernel_creator = kernel_creator def to_callable(self): return getattr(extern_kernels, self.name) @@ -844,10 +841,7 @@ def benchmark(self, *args, out): out_new, tuple(out.size()), tuple(out.stride()) ) out.copy_(out_new) # for correctness checking - if is_cpu_device(args): - return do_bench_cpu(lambda: algo(*args)) - else: - return do_bench(lambda: algo(*args)) + return do_bench(lambda: algo(*args)) def to_callable(self): fn = self.choice.to_callable() @@ -876,8 +870,6 @@ def output_node(self): inner = ir.FallbackKernel.create( self.choice.op_overload, *self.input_nodes, **self.kwargs ) - elif self.choice.kernel_creator is not None: - inner = self.choice.kernel_creator(*self.input_nodes, **self.kwargs) else: cls = ir.ExternKernelOut if self.has_out_variant else ir.ExternKernelAlloc inner = cls( @@ -900,76 +892,6 @@ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType } -class DataProcessorChoiceCallerWrapper: - def __init__(self, wrapped, preprocessor, postprocessor): - self._wrapped = wrapped - if preprocessor is not None: - self._preprocessor = preprocessor - else: - self._preprocessor = lambda x, y: (x, y) - if postprocessor is not None: - self._postprocessor = postprocessor - else: - self._postprocessor = lambda x: x - - def __getattr__(self, name): - return getattr(self._wrapped, name) - - def benchmark(self, *args, out) -> float: - new_args, new_out = self._preprocessor(args, out) - result = self._wrapped.benchmark(*new_args, out=new_out) - new_out = self._postprocessor(new_out) - if out is not new_out: - out.copy_(new_out) - return result - - def output_node(self) -> ir.TensorBox: - result = self._wrapped.output_node() - return self._postprocessor(result) - - def __repr__(self) -> str: - return f"DataProcessorChoiceCallerWrapper({self._wrapped})" - - -class DataProcessorTemplateWrapper: - def __init__( - self, - wrapped_template_cls, - preprocessor, - postprocessor, - **kwargs, - ): - if preprocessor is not None: - self._preprocessor = preprocessor - else: - self._preprocessor = lambda x, y: (x, y) - if postprocessor is not None: - self._postprocessor = postprocessor - else: - self._postprocessor = lambda x: x - assert "input_nodes" in kwargs - assert 
"layout" in kwargs - kwargs["input_nodes"], kwargs["layout"] = preprocessor( - kwargs["input_nodes"], kwargs["layout"] - ) - self._wrapped = wrapped_template_cls(**kwargs) - - def __getattr__(self, name): - return getattr(self._wrapped, name) - - def maybe_append_choice(self, choices, **kwargs): - return type(self._wrapped).maybe_append_choice(self, choices, **kwargs) - - def generate(self, **kwargs): - choice_caller = self._wrapped.generate(**kwargs) - return DataProcessorChoiceCallerWrapper( - choice_caller, self._preprocessor, self._postprocessor - ) - - def __repr__(self) -> str: - return f"DataProcessorTemplateWrapper({self._wrapped})" - - class ErrorFromChoice(RuntimeError): def __init__(self, msg, choice: ChoiceCaller, inputs_str): msg += f"\nFrom choice {choice}\n{inputs_str}" @@ -1091,6 +1013,7 @@ def no_op(*args, **kwargs): [c for c in choices if hasattr(c, "precompile")], timeout=precompilation_timeout_seconds, ) + from triton.runtime.autotuner import OutOfResources @functools.lru_cache(None) def wait_on_futures(): @@ -1110,17 +1033,9 @@ def wait_on_futures(): ) except StopIteration: pass - except Exception as e: - try: - from triton.runtime.autotuner import OutOfResources - - if isinstance(e, OutOfResources): - # This config is invalid due to requiring too many resources - pass - else: - raise e - except ImportError: - raise e + except OutOfResources: + # This config is invalid due to requiring too many resources + pass executor.shutdown(wait=True) @@ -1223,9 +1138,7 @@ def get_inputs(): } example_inputs = list(unique_example_inputs.values()) example_inputs_extern = [ - unique_example_inputs[input_node.get_name()] - if unique_example_inputs[input_node.get_name()].is_mkldnn - else torch.as_strided( + torch.as_strided( unique_example_inputs[input_node.get_name()], V.graph.sizevars.size_hints( input_node.get_size(), @@ -1284,11 +1197,12 @@ def benchmark_choice_in_current_process( result = choice.benchmark(*example_inputs, out=out) if VERIFY: torch.testing.assert_close(out_extern, expected, **VERIFY) - if torch.cuda.is_available(): - torch.cuda.synchronize() # shake out any CUDA errors + torch.cuda.synchronize() # shake out any CUDA errors return result def benchmark_in_current_process(choices): + from triton.runtime.autotuner import OutOfResources + inputs = get_inputs() example_inputs, _, out, _, _ = inputs timings = {} @@ -1312,21 +1226,14 @@ def benchmark_in_current_process(choices): raise ErrorFromChoice( msg, choice, debug_str(example_inputs, out) ) from e + except OutOfResources as e: + log.warning(e) + timing = float("inf") + except AssertionError as e: raise AssertionError( # noqa: TRY200 f"Incorrect result from choice {choice}\n\n{e}" ) - except Exception as e: - try: - from triton.runtime.autotuner import OutOfResources - - if isinstance(e, OutOfResources): - log.warning(e) - timing = float("inf") - else: - raise e - except ImportError: - raise e timings[choice] = timing From 1c5a149b2b0b1d34f8f18671a03d69c03605b9a4 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Tue, 30 Apr 2024 05:37:18 -0700 Subject: [PATCH 24/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 2 ++ torch/_inductor/codegen/cpp_micro_gemm.py | 15 +++++++++++---- torch/_inductor/utils.py | 9 +++++---- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index cbbe968691ac2..9b8826abac1d9 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ 
b/torch/_inductor/codegen/cpp_gemm_template.py @@ -238,6 +238,7 @@ def transpose_weight(inputs, layout_or_out): micro_gemm = create_micro_gemm( "micro_gemm", m, n, k, layout.dtype, alpha=alpha, num_threads=num_threads ) + assert micro_gemm is not None _, block_n, _ = micro_gemm.register_blocking def pack_weight(inputs, layout_or_out): @@ -351,6 +352,7 @@ def render( # type: ignore[override] alpha=self.alpha, num_threads=self.num_threads, ) + assert micro_gemm is not None assert self.register_blocking == micro_gemm.register_blocking options = dict( diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 40a19b261704a..c70f63da73990 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -1,5 +1,5 @@ from collections import namedtuple -from typing import Dict, List, Type +from typing import Dict, List, Optional, Type import sympy @@ -308,7 +308,7 @@ def create_micro_gemm( alpha=1, num_threads=-1, use_ref=False, -) -> CppMicroGemm: +) -> Optional[CppMicroGemm]: def create_from_config(cls, config: CppMicroGemmConfig): return cls( name, @@ -322,6 +322,7 @@ def create_from_config(cls, config: CppMicroGemmConfig): assert isinstance(n, int) or n.is_number, n assert isinstance(k, int) or k.is_number, k m = V.graph.sizevars.size_hint(m) if isinstance(m, sympy.Expr) else m + assert isinstance(m, int), m if output_dtype is None: output_dtype = input_dtype if compute_dtype is None: @@ -340,6 +341,7 @@ def create_from_config(cls, config: CppMicroGemmConfig): and config.compute_dtype == compute_dtype ): block_m, block_n, block_k = config.register_blocking + # TODO(jgong5): support n % n_block_size != 0 if n % block_n != 0: continue # Criteria on the ranking of configurations @@ -365,7 +367,12 @@ def create_from_config(cls, config: CppMicroGemmConfig): config, ) ) - if len(matched_configs) == 0 or use_ref: - return CppMicroGemmRef(name, input_dtype, output_dtype, compute_dtype, alpha) + if len(matched_configs) == 0: + if use_ref: + return CppMicroGemmRef( + name, input_dtype, output_dtype, compute_dtype, alpha + ) + else: + return None # TODO(jgong5): allow autotuning on choices of configs return create_from_config(*max(matched_configs, key=lambda x: x[0])[1:]) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 37366ae060f52..461496e567619 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -960,14 +960,15 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): return False if isinstance(mat2, ir.BaseView): mat2 = mat2.unwrap_view() - # TODO(jgong5): support n % n_block_size != 0 - _, n_block_size, _ = create_micro_gemm( + micro_gemm = create_micro_gemm( "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() - ).register_blocking + ) + # TODO(jgong5): support n % n_block_size != 0 return ( _use_template_for_cpu(layout) and layout.dtype in layout_dtypes - and n % n_block_size == 0 + and micro_gemm is not None + and n % micro_gemm.register_blocking[1] == 0 and isinstance(mat2, ir.StorageBox) and mat2.is_module_buffer() ) From 614a739b7140d7b20d2fdcdf0078472c4baf68a4 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Mon, 6 May 2024 05:37:03 -0700 Subject: [PATCH 25/28] Update [ghstack-poisoned] --- test/inductor/test_cpu_select_algorithm.py | 52 +++++++++------------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 8bfabcf3749e5..df8bb44ee53d0 
100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -4,6 +4,7 @@ from unittest.mock import patch import torch +import torch._dynamo.config import torch._dynamo.config as dynamo_config import torch._inductor.config as inductor_config import torch._inductor.select_algorithm as select_algorithm @@ -43,7 +44,17 @@ def wrapped(*args, **kwargs): class TestSelectAlgorithm(TestCase): - def _test_linear( + @inductor_config.patch({"freezing": True}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("batch_size", (1, 2, 1000)) + @parametrize("in_features", (1, 2, 1000)) + @parametrize("out_features", (1, 32, 1024)) + @parametrize("bias", (True, False)) + @parametrize("input_3d", (True, False)) + @dtypes(torch.float) + def test_linear_static_shapes( self, batch_size, in_features, out_features, bias, input_3d, dtype ): class M(torch.nn.Module): @@ -65,39 +76,20 @@ def forward(self, x): 1 if out_features != 1 else 0, ) - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("batch_size", (1, 2, 1000)) - @parametrize("in_features", (1, 2, 1000)) - @parametrize("out_features", (1, 32, 1024)) - @parametrize("bias", (True, False)) - @parametrize("input_3d", (True, False)) - @dtypes(torch.float) - def test_linear_static_shapes( - self, batch_size, in_features, out_features, bias, input_3d, dtype - ): - self._test_linear(batch_size, in_features, out_features, bias, input_3d, dtype) - @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) - @inductor_config.patch({"freezing": True}) - @patches - @torch.no_grad - @unittest.skipIf(not TEST_MKL, "Test requires MKL") - @parametrize("batch_size", (1, 2, 1000)) - @parametrize("in_features", (1, 2, 1000)) - @parametrize("out_features", (1, 32, 1024)) - @parametrize("bias", (True, False)) - @parametrize("input_3d", (True, False)) - @dtypes(torch.float) - def test_linear_dynamic_shapes( - self, batch_size, in_features, out_features, bias, input_3d, dtype - ): - self._test_linear(batch_size, in_features, out_features, bias, input_3d, dtype) +@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) +class _DynamicShapesTestBase(TestCase): + pass + + +class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase): + test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu") +instantiate_device_type_tests( + TestSelectAlgorithmDynamicShapes, globals(), only_for="cpu" +) if __name__ == "__main__": From 6b682e2181af2ed1cae8c7288cddb65236a33c20 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Mon, 6 May 2024 05:58:07 -0700 Subject: [PATCH 26/28] Update [ghstack-poisoned] --- test/inductor/test_cpu_select_algorithm.py | 7 ++++++- torch/_inductor/config.py | 5 +++-- torch/_inductor/mkldnn_lowerings.py | 20 ++++++++++++-------- torch/_inductor/utils.py | 6 ++++-- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index df8bb44ee53d0..7e7d63513369c 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -28,7 +28,12 @@ def skip_cache(self, choices, name, key, benchmark): for patcher in [ dynamo_config.patch(verbose=True), - inductor_config.patch(debug=True, max_autotune=True, epilogue_fusion=True), + 
inductor_config.patch( + debug=True, + max_autotune=True, + epilogue_fusion=True, + max_autotune_gemm_backends="CPP,ATEN", + ), patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)), patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache), ]: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index d09621affe52c..e8d55371540f4 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -230,12 +230,13 @@ def is_fbcode(): True if is_fbcode() else os.environ.get("TORCHINDUCTOR_FORCE_SAME_PRECISION") == "1" ) # Specify candidate backends for gemm autotune. -# Possible choices are combinations of: ATen, Triton, CUTLASS. +# Possible choices are combinations of: ATen, Triton, CUTLASS, CPP. # ATen: default Pytorch ATen kernels. # Triton: Triton templates defined in torch inductor. # CUTLASS: Cutlass templates and kernels. +# CPP: CPP templates and kernels for CPU. max_autotune_gemm_backends = os.environ.get( - "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON" + "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP" ).upper() # the value used as a fallback for the unbacked SymInts diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index ac032d80e257b..a291f33c4f47a 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -3,7 +3,7 @@ import torch import torch.utils._pytree as pytree from torch._inductor.kernel.mm_common import mm_args -from . import config, ir +from . import ir from .codegen.cpp_gemm_template import CppPackedGemmTemplate from .ir import TensorBox from .lowering import ( @@ -15,7 +15,7 @@ to_dtype, ) from .select_algorithm import autotune_select_algorithm, ExternKernelChoice -from .utils import use_cpp_packed_gemm_template +from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template, use_max_autotune from .virtualized import V @@ -369,12 +369,16 @@ def mkl_packed_linear( *, layout=None, ): - choices = [ - aten_mkl_linear.bind( - (x, packed_w, orig_w), layout, B=None, batch_size=batch_size - ) - ] - if config.max_autotune or config.max_autotune_gemm: + choices = ( + [ + aten_mkl_linear.bind( + (x, packed_w, orig_w), layout, B=None, batch_size=batch_size + ) + ] + if use_aten_gemm_kernels() + else [] + ) + if use_max_autotune(): transposed_w = permute(orig_w, [1, 0]) *_, layout, x, transposed_w = mm_args( x, transposed_w, layout=layout diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 124f1fce22c5b..2a6b8cc0c6dc4 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -974,6 +974,9 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): from .codegen.cpp_micro_gemm import create_micro_gemm from .kernel.mm_common import mm_args + if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): + return False + layout_dtypes = [torch.float32] m, n, k, *_ = mm_args(mat1, mat2) # TODO(jgong5): support dynamic shapes for n or k @@ -986,8 +989,7 @@ def use_cpp_packed_gemm_template(layout, mat1, mat2): ) # TODO(jgong5): support n % n_block_size != 0 return ( - _use_template_for_cpu(layout) - and layout.dtype in layout_dtypes + layout.dtype in layout_dtypes and micro_gemm is not None and n % micro_gemm.register_blocking[1] == 0 and isinstance(mat2, ir.StorageBox) From 66f5e313bec295e9dd2a7b63dc1e3a8535cf623f Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Mon, 6 May 2024 22:01:44 -0700 Subject: [PATCH 27/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_gemm_template.py | 
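With "CPP" added to max_autotune_gemm_backends above, the lowering only seeds the ATen extern choice when that backend is enabled and then lets the CPP template append its own candidate. A small sketch of how such a comma-separated allow-list is typically parsed and applied (the helper names are illustrative, not the inductor functions; the env var and default come from the patch):

    import os

    def enabled_backends() -> set:
        raw = os.environ.get(
            "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS", "ATEN,TRITON,CPP"
        )
        return {b.strip().upper() for b in raw.split(",") if b.strip()}

    def build_gemm_choices(aten_choice, cpp_choice):
        backends = enabled_backends()
        # Start from the ATen extern kernel only if it is allowed...
        choices = [aten_choice] if "ATEN" in backends else []
        # ...and append the CPP template candidate when that backend is on.
        if "CPP" in backends:
            choices.append(cpp_choice)
        return choices

    print(build_gemm_choices("aten_mkl_linear", "cpp_packed_gemm_template"))
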
4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 9b8826abac1d9..70bd66e5f7f62 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -6,7 +6,7 @@ from ..kernel.mm_common import mm_args from ..select_algorithm import DataProcessorTemplateWrapper -from ..utils import cache_on_self, parallel_num_threads +from ..utils import cache_on_self, has_free_symbols, parallel_num_threads from ..virtualized import V from .cpp_micro_gemm import create_micro_gemm from .cpp_template import CppTemplate @@ -138,7 +138,7 @@ def __init__( m, n = layout.size _, k = input_nodes[0].get_size() self.m, self.n, self.k = m, n, k - self.is_dynamic_M = len(self.m.free_symbols) > 0 + self.is_dynamic_M = has_free_symbols((m,)) @cache_on_self def thread_blocking(self) -> GemmBlocking: From 0cacd096621558fd8c96c586c379ec2c7b608677 Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Tue, 14 May 2024 18:58:59 -0700 Subject: [PATCH 28/28] Update [ghstack-poisoned] --- torch/_inductor/codegen/cpp_micro_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 7d54bd8605ec4..353562923c91c 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -344,7 +344,7 @@ def create_from_config(cls, config: CppMicroGemmConfig): assert isinstance(n, int) or n.is_number, n assert isinstance(k, int) or k.is_number, k - m = V.graph.sizevars.size_hint(m) if isinstance(m, sympy.Expr) else m + m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m assert isinstance(m, int), m if output_dtype is None: output_dtype = input_dtype
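The last two patches make the template robust to symbolic shapes: is_dynamic_M is decided by whether the size expression still has free symbols, and size_hint(m, fallback=1) gives the micro-gemm chooser a concrete M even when the symbol has no backing hint. Simplified stand-ins for those two helpers (the real ones live in torch/_inductor/utils.py and V.graph.sizevars):

    import sympy

    def has_free_symbols(sizes) -> bool:
        # A dimension is dynamic if its sympy expression still contains unresolved symbols.
        return any(isinstance(s, sympy.Expr) and bool(s.free_symbols) for s in sizes)

    def size_hint(expr, fallback=1) -> int:
        # Concrete sizes pass through; purely symbolic sizes fall back to a default hint.
        if isinstance(expr, sympy.Expr) and expr.free_symbols:
            return fallback
        return int(expr)

    m = sympy.Symbol("s0")  # e.g. a dynamic batch dimension
    print(has_free_symbols((m, 64)))                      # True -> M treated as dynamic
    print(size_hint(m), size_hint(sympy.Integer(128)))    # 1 128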