[inductor][cpp] GEMM template (infra and fp32) #124021

Closed. jgong5 wants to merge 63 commits.

Commits (changes from all commits):
00eb31a  Update (jgong5, Apr 14, 2024)
b6ff5fe  Update (jgong5, Apr 16, 2024)
0355c46  Update (jgong5, Apr 16, 2024)
ba94cdf  Update (jgong5, Apr 17, 2024)
5ad7899  Update (jgong5, Apr 17, 2024)
1c4edcd  Update (jgong5, Apr 17, 2024)
a56957d  Update (jgong5, Apr 17, 2024)
5bf33c4  Update (jgong5, Apr 17, 2024)
f780f9c  Update (jgong5, Apr 18, 2024)
0580a46  Update (jgong5, Apr 18, 2024)
d795f31  Update (jgong5, Apr 26, 2024)
002bedb  Update (jgong5, Apr 27, 2024)
2bfc603  Update (jgong5, Apr 28, 2024)
a416d41  Update (jgong5, Apr 28, 2024)
8d3f8aa  Update (jgong5, Apr 28, 2024)
701a0cd  Update (jgong5, Apr 28, 2024)
85ce15a  Update (jgong5, Apr 28, 2024)
5f0133e  Update (jgong5, Apr 28, 2024)
b1f731b  Update (jgong5, Apr 28, 2024)
c0d77bc  Update (jgong5, Apr 28, 2024)
ab8e6a9  Update (jgong5, Apr 29, 2024)
c2c5d2d  Update (jgong5, Apr 29, 2024)
b079a2c  Update (jgong5, Apr 29, 2024)
fac3997  Update (jgong5, Apr 29, 2024)
b0e451c  Update (jgong5, Apr 29, 2024)
ff91a01  Update (jgong5, Apr 29, 2024)
59086de  Update (jgong5, Apr 29, 2024)
bfce7d8  Update (jgong5, Apr 29, 2024)
7e6490a  Update (jgong5, Apr 29, 2024)
7a4dc85  Update (jgong5, Apr 29, 2024)
0cec870  Update (jgong5, Apr 30, 2024)
fc8a9c8  Update (jgong5, Apr 30, 2024)
b337242  Update (jgong5, Apr 30, 2024)
1c5a149  Update (jgong5, Apr 30, 2024)
7ae7be0  Update (jgong5, May 6, 2024)
614a739  Update (jgong5, May 6, 2024)
6b682e2  Update (jgong5, May 6, 2024)
66f5e31  Update (jgong5, May 7, 2024)
d56bebf  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 7, 2024)
acb4a95  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 7, 2024)
70a6d7d  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 7, 2024)
0162cf6  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 8, 2024)
92f4ac4  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 8, 2024)
8cfdb7d  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 8, 2024)
55d98b0  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 9, 2024)
ca09328  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 9, 2024)
e96352e  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 9, 2024)
f427d85  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 12, 2024)
f0e2203  Update (jgong5, May 15, 2024)
0cacd09  Update (jgong5, May 15, 2024)
67877a6  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 19, 2024)
6687ccf  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 19, 2024)
70b35d3  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 19, 2024)
3bcbae9  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 19, 2024)
3a8012d  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 19, 2024)
fbb0064  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 19, 2024)
2993c2e  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 20, 2024)
ac36018  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 21, 2024)
ad5e500  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 23, 2024)
5f07582  Update on "[inductor][cpp] GEMM template (infra and fp32)" (jgong5, May 24, 2024)
c872b7c  Update (jgong5, May 28, 2024)
bdb239e  Update (jgong5, May 29, 2024)
6ad24d3  Update (jgong5, May 29, 2024)
135 changes: 135 additions & 0 deletions test/inductor/test_cpu_select_algorithm.py
@@ -0,0 +1,135 @@
# Owner(s): ["oncall: cpu inductor"]
import functools
import unittest
from unittest.mock import patch

import torch
import torch._dynamo.config
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)

from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL

aten = torch.ops.aten


def patches(fn):
    def skip_cache(self, choices, name, key, benchmark):
        if benchmark is None:
            return {}
        return benchmark(choices)

    for patcher in [
        dynamo_config.patch(verbose=True),
        inductor_config.patch(
            debug=True,
            max_autotune=True,
            epilogue_fusion=True,
            max_autotune_gemm_backends="CPP,ATEN",
        ),
        patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
        patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
    ]:
        fn = patcher(fn)

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        counters.clear()
        torch.manual_seed(12345)
        return fn(*args, **kwargs)

    return wrapped
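
# The `patches` decorator above pins the configuration under test: max-autotune
# with both the "CPP" and "ATEN" GEMM backends, epilogue fusion enabled, a 1e-4
# verification tolerance, and a bypassed autotune cache (`skip_cache`) so every
# run re-benchmarks all candidate kernels.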


class TestSelectAlgorithm(TestCase):
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (1, 2, 1000))
@parametrize("in_features", (1, 2, 1000))
@parametrize("out_features", (1, 32, 1024))
@parametrize("bias", (True, False))
@parametrize("input_3d", (True, False))
@dtypes(torch.float)
def test_linear_static_shapes(
self, batch_size, in_features, out_features, bias, input_3d, dtype
):
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)

@torch.compile
def forward(self, x):
return self.linear(x)

counters.clear()
mod = M(bias=bias).to(dtype=dtype).eval()
B = (2, batch_size) if input_3d else (batch_size,)
v = torch.randn(*B, in_features).to(dtype=dtype)
mod(v)
if (
counters["inductor"]["decompose_mm"] > 0
or counters["inductor"]["decompose_addmm"] > 0
):
# This is a special case where we go directly with vectorized codegen
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)
else:
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bias", (True, False))
@dtypes(torch.float)
def test_linear_input_transpose(self, bias, dtype):
batch_size = 384
in_features = 196
out_features = 384

class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)

@torch.compile
def forward(self, x):
return self.linear(x)

counters.clear()
mod = M(bias=bias).to(dtype=dtype).eval()
v = torch.randn(in_features, batch_size).to(dtype=dtype)
mod(v.transpose(0, 1))
# TODO(jgong5): support transposed input
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)


@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(TestCase):
    pass


class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
    test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes


instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
instantiate_device_type_tests(
    TestSelectAlgorithmDynamicShapes, globals(), only_for="cpu"
)


if __name__ == "__main__":
    from torch.testing._internal.inductor_utils import HAS_CPU

    if HAS_CPU and not IS_MACOS:
        run_tests()
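
For context, a minimal sketch of the scenario these tests exercise, assuming a PyTorch source build with inductor's CPP GEMM backend available (the Linear sizes are taken from the parametrization above, not mandated by the backend):

import torch
import torch._inductor.config as inductor_config

# Mirror the `patches` decorator: autotune between the ATen kernel and the
# new CPP GEMM template, with freezing and epilogue fusion enabled.
with inductor_config.patch(
    freezing=True,
    max_autotune=True,
    epilogue_fusion=True,
    max_autotune_gemm_backends="CPP,ATEN",
):
    mod = torch.nn.Linear(1000, 1024).eval()
    x = torch.randn(2, 1000)
    with torch.no_grad():
        out = torch.compile(mod)(x)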
51 changes: 48 additions & 3 deletions torch/_inductor/codegen/cpp.py
@@ -8,7 +8,7 @@
 import sys
 from copy import copy, deepcopy
 from enum import Enum
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Sequence, Set, Tuple, Union

 import sympy

@@ -20,6 +20,7 @@
 from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
 from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
+from ..._dynamo.utils import counters

 from .. import codecache, config, ir, metrics
 from ..codegen.wrapper import WrapperCodeGen
@@ -3584,6 +3585,8 @@ def _can_fuse_horizontal_impl(self, node1, node2):
         return self._why_fuse_nodes(node1, node2) is not None

     def can_fuse_horizontal(self, node1, node2):
+        if node1.is_template() or node2.is_template():
+            return False
         if (
             len(node1.get_nodes()) + len(node2.get_nodes())
             > config.cpp.max_horizontal_fusion_size
@@ -3664,6 +3667,9 @@ def get_fusion_pair_priority(self, node1, node2):
         return 0

     def can_fuse_vertical(self, node1, node2):
+        # TODO(jgong5): support vertical fusion for template nodes
+        if node1.is_template() or node2.is_template():
+            return False
         return (
             self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
         ) or self.can_fuse_vertical_outer_loop(node1, node2)
@@ -3720,6 +3726,44 @@ def codegen_node(
         if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
             self._set_flush_status(True)

+    def is_cpp_template(self, node: BaseSchedulerNode) -> bool:
+        return isinstance(node, SchedulerNode) and isinstance(
+            node.node, ir.CppTemplateBuffer
+        )
+
+    def codegen_template(
+        self,
+        template_node: BaseSchedulerNode,
+        epilogue_nodes: Sequence[BaseSchedulerNode],
+    ):
+        """
+        Codegen a CPP template, possibly with fused epilogues
+        """
+        counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
+        assert self.is_cpp_template(
+            template_node
+        ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
+        template_node = cast(SchedulerNode, template_node)
+        _, (_, rnumel) = template_node.group
+        assert rnumel == ()
+        ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node)
+        epilogue_ir_nodes: List[Optional[ir.Buffer]] = [n.node for n in epilogue_nodes]
+        assert all(
+            isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
+        ), "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
+        with kernel:
+            for node in [template_node, *epilogue_nodes]:
+                node.mark_run()  # type: ignore[attr-defined]
+            src_code = render()
+
+        with V.set_kernel_handler(kernel):
+            node_schedule = [template_node, *epilogue_nodes]
+            kernel_name = self.define_kernel(src_code, node_schedule, kernel.args)
+        kernel.call_kernel(kernel_name, ctb)
+        V.graph.removed_buffers |= kernel.removed_buffers
+        self.scheduler.free_buffers()
+
     def _get_scheduled_num_args(self):
         return self.kernel_group.get_num_args()

@@ -3729,7 +3773,7 @@ def ready_to_flush(self):
     def codegen_sync(self):
         pass

-    def define_kernel(self, src_code, nodes):
+    def define_kernel(self, src_code, nodes, kernel_args=None):
         wrapper = V.graph.wrapper_code
         fused_name = (
             get_fused_kernel_name(nodes, config.cpp.descriptive_names)
@@ -3745,7 +3789,8 @@ def define_kernel(self, src_code, nodes):
         src_code = src_code.replace("#pragma CMT", "//")

         compile_wrapper = IndentedBuffer()
-        _, _, arg_types = self.kernel_group.args.cpp_argdefs()
+        args = self.kernel_group.args if kernel_args is None else kernel_args
+        _, _, arg_types = args.cpp_argdefs()
         if not V.graph.cpp_wrapper:
             compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
         compile_wrapper.splice(src_code, strip=True)
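
Taken together, the scheduling changes route a CppTemplateBuffer through its own codegen path: template nodes are excluded from regular horizontal and vertical fusion, and codegen_template renders the template together with any fused epilogues before registering the result via define_kernel, which now accepts the template kernel's own kernel_args in place of the shared kernel group's. A rough sketch of the control flow, where flush_nodes and fused_node are hypothetical stand-ins for the scheduler's actual driver loop:

# Hypothetical driver illustrating how the new hooks compose; the real
# dispatch lives in torch/_inductor/scheduler.py and differs in detail.
def flush_nodes(scheduling, fused_node):
    nodes = fused_node.get_nodes()
    if scheduling.is_cpp_template(nodes[0]):
        # A template node plus whatever pointwise epilogues were fused after it.
        scheduling.codegen_template(nodes[0], nodes[1:])
    else:
        scheduling.codegen_node(fused_node)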