[inductor][cpp] GEMM template
ghstack-source-id: 30af30579007f953540b54d0d2811a9ae6868234
Pull Request resolved: #124021
jgong5 committed Apr 17, 2024
1 parent ab66e95 commit 453bca3
Showing 12 changed files with 888 additions and 57 deletions.
71 changes: 71 additions & 0 deletions test/inductor/test_cpu_select_algorithm.py
@@ -0,0 +1,71 @@
# Owner(s): ["oncall: cpu inductor"]
import functools
import unittest
from unittest.mock import patch

import torch
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.select_algorithm as select_algorithm
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase

from torch.testing._internal.common_utils import IS_MACOS, TEST_MKL

aten = torch.ops.aten


def patches(fn):
def skip_cache(self, choices, name, key, generate):
return generate(choices)

for patcher in [
dynamo_config.patch(verbose=True),
inductor_config.patch(debug=True, max_autotune=True, epilogue_fusion=True),
patch.object(select_algorithm, "VERIFY", dict(atol=1e-4, rtol=1e-4)),
patch.object(select_algorithm.AlgorithmSelectorCache, "lookup", skip_cache),
]:
fn = patcher(fn)

@functools.wraps(fn)
def wrapped(*args, **kwargs):
counters.clear()
torch.manual_seed(12345)
return fn(*args, **kwargs)

return wrapped


class TestSelectAlgorithm(TestCase):
@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
def test_linear_fp32_cpu(self):
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(10, 32, bias)

@torch.compile
def forward(self, x):
return self.linear(x)

for bias in [True, False]:
counters.clear()
mod = M(bias=bias).eval()
v = torch.randn(2, 10)
mod(v)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)


@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class TestDynamicSelectAlgorithm(TestCase):
test_linear_fp32_dynamic_shapes_cpu = TestSelectAlgorithm.test_linear_fp32_cpu


if __name__ == "__main__":
from torch.testing._internal.inductor_utils import HAS_CPU

if HAS_CPU and not IS_MACOS:
run_tests()
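
For context (not part of the diff): a minimal usage sketch of the path this test exercises. With max_autotune and freezing enabled on a CPU build with MKL, compiling a Linear module lets the new C++ GEMM template compete in autotuning; the config names below are exactly the ones the test patches, everything else is illustrative.

import torch
import torch._inductor.config as inductor_config

# Same knobs the test patches: autotune GEMM choices and freeze weights.
inductor_config.max_autotune = True
inductor_config.freezing = True

mod = torch.nn.Linear(10, 32).eval()
compiled = torch.compile(mod)

with torch.no_grad():
    # with max_autotune, inductor benchmarks the available GEMM choices
    out = compiled(torch.randn(2, 10))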
133 changes: 112 additions & 21 deletions torch/_inductor/autotune_process.py
@@ -9,9 +9,10 @@
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p
from ctypes import byref, c_size_t, c_void_p, CDLL
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from types import ModuleType
from typing import (
Any,
Callable,
@@ -29,13 +30,19 @@
from torch._dynamo.testing import rand_strided

from torch._inductor import ir
from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache
from torch._inductor.codecache import (
CppCodeCache,
CUDACodeCache,
DLLWrapper,
get_hash,
PyCodeCache,
)

if TYPE_CHECKING:
from torch._inductor.select_algorithm import TritonTemplateCaller

from . import config
from .utils import do_bench
from .utils import do_bench, timed
from .virtualized import V

CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
@@ -427,6 +434,14 @@ def make_run_fn(
def cleanup_run_fn(self) -> None:
pass

def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
raise NotImplementedError()

def benchmark(
self,
*input_tensors: torch.Tensor,
@@ -452,22 +467,7 @@ def benchmark(
load_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
start_ts = time.time()

device_idx_set = {
tensor.device.index
for tensor in [*input_tensors, output_tensor]
if isinstance(tensor, torch.Tensor)
and tensor.is_cuda
and tensor.device.index is not None
}
assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}"
if len(device_idx_set) == 1:
device_idx = next(iter(device_idx_set))
else:
device_idx = torch.cuda.current_device()

with torch.cuda.device(device_idx):
out = do_bench(fn)
torch.cuda.synchronize() # shake out any CUDA errors
out = self.do_bench(fn, *input_tensors, output_tensor)

if debug:
bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
@@ -499,7 +499,34 @@ def benchmark
return self.value


class TritonBenchmarkRequest(BenchmarkRequest):
class GPUDeviceBenchmarkRequest(BenchmarkRequest):
def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
device_idx_set = {
tensor.device.index
for tensor in [*input_tensors, output_tensor]
if isinstance(tensor, torch.Tensor)
and tensor.is_cuda
and tensor.device.index is not None
}
assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}"
if len(device_idx_set) == 1:
device_idx = next(iter(device_idx_set))
else:
device_idx = torch.cuda.current_device()

with torch.cuda.device(device_idx):
out = do_bench(fn)
torch.cuda.synchronize() # shake out any CUDA errors

return out


class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!

@@ -573,7 +600,7 @@ def __str__(self) -> str:
return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"


class CUDABenchmarkRequest(BenchmarkRequest):
class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!

@@ -661,6 +688,70 @@ def __str__(self) -> str:
return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"


class CPUDeviceBenchmarkRequest(BenchmarkRequest):
def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
return timed(fn, ())


@dataclasses.dataclass
class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put Tensors in here!

def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
extra_args: Iterable[Any],
source_code: str,
):
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
self.source_code = source_code
self.hash_key = get_hash(source_code)
self.DLL: Optional[Union[CDLL, ModuleType]] = None

def precompile(self):
# Prepopulate CppCodeCache
# may happen in separate Threadpool
log.debug("Precompiling %s", self)
CppCodeCache.load(self.source_code, cuda=False)
log.debug("Done precompiling %s", self)

def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
self.DLL = CppCodeCache.load(self.source_code, cuda=False)
args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]]
log.debug(
"make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
self.kernel_name,
self.DLL,
args,
self.extra_args,
)
run_method = getattr(self.DLL, self.kernel_name)

# Generate partial function.
return functools.partial(
run_method,
*args,
*self.extra_args,
)

def cleanup_run_fn(self) -> None:
if self.DLL is not None:
self.DLL.close()

def __str__(self) -> str:
return f"{self.kernel_name=}"


def benchmark_in_sub_process(
choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
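
Aside (not part of the diff): the refactor above moves device handling out of BenchmarkRequest.benchmark() into a per-request do_bench() hook, with a GPU subclass (device selection plus CUDA sync around triton's do_bench) and a CPU subclass (plain wall-clock timing via timed). A rough standalone sketch of that dispatch, with simplified timing standing in for the real do_bench/timed helpers:

import time
from typing import Callable

class BenchmarkRequestSketch:
    # benchmark() builds the run function, then defers timing to the subclass
    def benchmark(self, fn: Callable[[], None]) -> float:
        return self.do_bench(fn)

    def do_bench(self, fn: Callable[[], None]) -> float:
        raise NotImplementedError

class CPUDeviceBenchmarkRequestSketch(BenchmarkRequestSketch):
    def do_bench(self, fn: Callable[[], None]) -> float:
        # CPU path: no device selection or synchronization needed
        start = time.perf_counter()
        fn()
        return time.perf_counter() - start

class GPUDeviceBenchmarkRequestSketch(BenchmarkRequestSketch):
    def do_bench(self, fn: Callable[[], None]) -> float:
        import torch
        # GPU path: pin the device, time the kernel, then synchronize to surface errors
        with torch.cuda.device(torch.cuda.current_device()):
            start = time.perf_counter()
            fn()
            torch.cuda.synchronize()
            return time.perf_counter() - start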
3 changes: 0 additions & 3 deletions torch/_inductor/codecache.py
@@ -183,9 +183,6 @@ def get_global_cache_path() -> Optional[Path]:
)

def __init__(self) -> None:
if not torch.cuda.is_available():
return

self.system = CacheBase.get_system()

self.local_cache_path = CacheBase.get_local_cache_path()
49 changes: 46 additions & 3 deletions torch/_inductor/codegen/cpp.py
@@ -8,7 +8,7 @@
import sys
from copy import copy, deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union

import sympy

@@ -19,6 +19,7 @@
from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from ..._dynamo.utils import counters

from .. import codecache, config, ir, metrics
from ..codegen.wrapper import WrapperCodeGen
@@ -3729,6 +3730,8 @@ def _can_fuse_horizontal_impl(self, node1, node2):
return self._why_fuse_nodes(node1, node2) is not None

def can_fuse_horizontal(self, node1, node2):
if node1.is_template() or node2.is_template():
return False
if (
len(node1.get_nodes()) + len(node2.get_nodes())
> config.cpp.max_horizontal_fusion_size
@@ -3809,6 +3812,9 @@ def get_fusion_pair_priority(self, node1, node2):
return 0

def can_fuse_vertical(self, node1, node2):
# TODO: support vertical fusion for template nodes
if node1.is_template() or node2.is_template():
return False
return (
self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
) or self.can_fuse_vertical_outer_loop(node1, node2)
@@ -3865,6 +3871,42 @@ def codegen_node(
if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
self._set_flush_status(True)

def is_cpp_template(self, node: BaseSchedulerNode) -> bool:
return isinstance(node, SchedulerNode) and isinstance(
node.node, ir.CppTemplateBuffer
)

def codegen_template(
self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode]
):
"""
Codegen a CPP template, possibly with fused epilogues
"""
counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cpp_template(
template_node
), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
template_node = cast(SchedulerNode, template_node)
_, (_, rnumel) = template_node.group
assert rnumel == ()
ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node)
epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes]
assert all(
isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
), "Epilogue nodes must all be instances of ir.ComputedBuffer"
kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
with kernel:
for node in [template_node, *epilogue_nodes]:
node.mark_run()
src_code = render()

with V.set_kernel_handler(kernel):
node_schedule = [template_node, *epilogue_nodes]
kernel_name = self.define_kernel(src_code, node_schedule, kernel.args)
kernel.call_kernel(kernel_name, ctb)
V.graph.removed_buffers |= kernel.removed_buffers
self.scheduler.free_buffers()

def _get_scheduled_num_args(self):
return self.kernel_group.get_num_args()

@@ -3874,7 +3916,7 @@ def ready_to_flush(self):
def codegen_sync(self):
pass

def define_kernel(self, src_code, nodes):
def define_kernel(self, src_code, nodes, kernel_args=None):
wrapper = V.graph.wrapper_code
fused_name = (
get_fused_kernel_name(nodes, config.cpp.descriptive_names)
@@ -3890,7 +3932,8 @@ def define_kernel(self, src_code, nodes):
src_code = src_code.replace("#pragma CMT", "//")

compile_wrapper = IndentedBuffer()
_, _, arg_types = self.kernel_group.args.cpp_argdefs()
args = self.kernel_group.args if kernel_args is None else kernel_args
_, _, arg_types = args.cpp_argdefs()
if not V.graph.cpp_wrapper:
compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
compile_wrapper.splice(src_code, strip=True)
