Skip to content

Commit

Permalink
Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)"
Browse files Browse the repository at this point in the history
This reverts commit 9da7efa.

Reverted #124021 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](#124021 (comment)))
  • Loading branch information
pytorchmergebot committed May 23, 2024
1 parent 45784cd commit 25b8dbc
Show file tree
Hide file tree
Showing 15 changed files with 14 additions and 1,603 deletions.
135 changes: 0 additions & 135 deletions test/inductor/test_cpu_select_algorithm.py

This file was deleted.

49 changes: 3 additions & 46 deletions torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
from copy import copy, deepcopy
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import sympy

Expand All @@ -20,7 +20,6 @@
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from ..._dynamo.utils import counters

from .. import codecache, config, ir, metrics
from ..codegen.wrapper import WrapperCodeGen
Expand Down Expand Up @@ -3583,8 +3582,6 @@ def _can_fuse_horizontal_impl(self, node1, node2):
return self._why_fuse_nodes(node1, node2) is not None

def can_fuse_horizontal(self, node1, node2):
if node1.is_template() or node2.is_template():
return False
if (
len(node1.get_nodes()) + len(node2.get_nodes())
> config.cpp.max_horizontal_fusion_size
Expand Down Expand Up @@ -3665,9 +3662,6 @@ def get_fusion_pair_priority(self, node1, node2):
return 0

def can_fuse_vertical(self, node1, node2):
# TODO(jgong5): support vertical fusion for template nodes
if node1.is_template() or node2.is_template():
return False
return (
self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
) or self.can_fuse_vertical_outer_loop(node1, node2)
Expand Down Expand Up @@ -3724,42 +3718,6 @@ def codegen_node(
if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
self._set_flush_status(True)

def is_cpp_template(self, node: BaseSchedulerNode) -> bool:
return isinstance(node, SchedulerNode) and isinstance(
node.node, ir.CppTemplateBuffer
)

def codegen_template(
self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode]
):
"""
Codegen a CPP template, possibly with fused epilogues
"""
counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cpp_template(
template_node
), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
template_node = cast(SchedulerNode, template_node)
_, (_, rnumel) = template_node.group
assert rnumel == ()
ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node)
epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes]
assert all(
isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
), "Epilogue nodes must all be instances of ir.ComputedBuffer"
kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
with kernel:
for node in [template_node, *epilogue_nodes]:
node.mark_run()
src_code = render()

with V.set_kernel_handler(kernel):
node_schedule = [template_node, *epilogue_nodes]
kernel_name = self.define_kernel(src_code, node_schedule, kernel.args)
kernel.call_kernel(kernel_name, ctb)
V.graph.removed_buffers |= kernel.removed_buffers
self.scheduler.free_buffers()

def _get_scheduled_num_args(self):
return self.kernel_group.get_num_args()

Expand All @@ -3769,7 +3727,7 @@ def ready_to_flush(self):
def codegen_sync(self):
pass

def define_kernel(self, src_code, nodes, kernel_args=None):
def define_kernel(self, src_code, nodes):
wrapper = V.graph.wrapper_code
fused_name = (
get_fused_kernel_name(nodes, config.cpp.descriptive_names)
Expand All @@ -3785,8 +3743,7 @@ def define_kernel(self, src_code, nodes, kernel_args=None):
src_code = src_code.replace("#pragma CMT", "//")

compile_wrapper = IndentedBuffer()
args = self.kernel_group.args if kernel_args is None else kernel_args
_, _, arg_types = args.cpp_argdefs()
_, _, arg_types = self.kernel_group.args.cpp_argdefs()
if not V.graph.cpp_wrapper:
compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
compile_wrapper.splice(src_code, strip=True)
Expand Down

0 comments on commit 25b8dbc

Please sign in to comment.