[inductor][cpp] GEMM template
ghstack-source-id: 78e95234d720874d2de4d57f523821bec8b90461
Pull Request resolved: #124021
jgong5 committed Apr 16, 2024
1 parent ab66e95 commit e3458a8
Showing 10 changed files with 715 additions and 73 deletions.
2 changes: 1 addition & 1 deletion test/inductor/test_mkldnn_pattern_matcher.py
@@ -318,7 +318,7 @@ def test_linear_fp32(self):
class M(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.linear = torch.nn.Linear(10, 30, bias)
self.linear = torch.nn.Linear(10, 32, bias)

def forward(self, x):
return self.linear(x)
133 changes: 112 additions & 21 deletions torch/_inductor/autotune_process.py
@@ -9,9 +9,10 @@
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p
from ctypes import byref, c_size_t, c_void_p, CDLL
from multiprocessing.process import BaseProcess
from multiprocessing.queues import Queue
from types import ModuleType
from typing import (
Any,
Callable,
@@ -29,13 +30,19 @@
from torch._dynamo.testing import rand_strided

from torch._inductor import ir
from torch._inductor.codecache import CUDACodeCache, DLLWrapper, PyCodeCache
from torch._inductor.codecache import (
CppCodeCache,
CUDACodeCache,
DLLWrapper,
get_hash,
PyCodeCache,
)

if TYPE_CHECKING:
from torch._inductor.select_algorithm import TritonTemplateCaller

from . import config
from .utils import do_bench
from .utils import do_bench, timed
from .virtualized import V

CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
@@ -427,6 +434,14 @@ def make_run_fn(
def cleanup_run_fn(self) -> None:
pass

def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
raise NotImplementedError()

def benchmark(
self,
*input_tensors: torch.Tensor,
@@ -452,22 +467,7 @@ def benchmark(
load_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
start_ts = time.time()

device_idx_set = {
tensor.device.index
for tensor in [*input_tensors, output_tensor]
if isinstance(tensor, torch.Tensor)
and tensor.is_cuda
and tensor.device.index is not None
}
assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}"
if len(device_idx_set) == 1:
device_idx = next(iter(device_idx_set))
else:
device_idx = torch.cuda.current_device()

with torch.cuda.device(device_idx):
out = do_bench(fn)
torch.cuda.synchronize() # shake out any CUDA errors
out = self.do_bench(fn, *input_tensors, output_tensor)

if debug:
bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
@@ -499,7 +499,34 @@ def benchmark
return self.value


class TritonBenchmarkRequest(BenchmarkRequest):
class GPUDeviceBenchmarkRequest(BenchmarkRequest):
def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
device_idx_set = {
tensor.device.index
for tensor in [*input_tensors, output_tensor]
if isinstance(tensor, torch.Tensor)
and tensor.is_cuda
and tensor.device.index is not None
}
assert len(device_idx_set) <= 1, f"Can not mix devices {device_idx_set}"
if len(device_idx_set) == 1:
device_idx = next(iter(device_idx_set))
else:
device_idx = torch.cuda.current_device()

with torch.cuda.device(device_idx):
out = do_bench(fn)
torch.cuda.synchronize() # shake out any CUDA errors

return out


class TritonBenchmarkRequest(GPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!

@@ -573,7 +600,7 @@ def __str__(self) -> str:
return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"


class CUDABenchmarkRequest(BenchmarkRequest):
class CUDABenchmarkRequest(GPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!

@@ -661,6 +688,70 @@ def __str__(self) -> str:
return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"


class CPUDeviceBenchmarkRequest(BenchmarkRequest):
def do_bench(
self,
fn,
*input_tensors: torch.Tensor,
output_tensor: Optional[torch.Tensor] = None,
) -> float:
return timed(fn, ())


@dataclasses.dataclass
class CppBenchmarkRequest(CPUDeviceBenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put Tensors in here!

def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
extra_args: Iterable[Any],
source_code: str,
):
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
self.source_code = source_code
self.hash_key = get_hash(source_code)
self.DLL: Optional[Union[CDLL, ModuleType]] = None

def precompile(self):
# Prepopulate CppCodeCache
# may happen in separate Threadpool
log.debug("Precompiling %s", self)
CppCodeCache.load(self.source_code, cuda=False)
log.debug("Done precompiling %s", self)

def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
self.DLL = CppCodeCache.load(self.source_code, cuda=False)
args = [tensor.data_ptr() for tensor in list(input_tensors) + [output_tensor]]
log.debug(
"make_run_fn: self.kernel_name=%s, self.DLL=%s, args=%s, self.extra_args=%s",
self.kernel_name,
self.DLL,
args,
self.extra_args,
)
run_method = getattr(self.DLL, self.kernel_name)

# Generate partial function.
return functools.partial(
run_method,
*args,
*self.extra_args,
)

def cleanup_run_fn(self) -> None:
if self.DLL is not None:
self.DLL.close()

def __str__(self) -> str:
return f"{self.kernel_name=}"


def benchmark_in_sub_process(
choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
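The new do_bench hook factors device-specific timing out of the shared benchmark() path: GPUDeviceBenchmarkRequest keeps the CUDA-event based measurement and device handling, while CPUDeviceBenchmarkRequest just times the callable on the host via timed(). A self-contained sketch of that split (simplified names and timing, not the inductor implementation):

import time
import torch

class BenchmarkRequestSketch:
    # The shared benchmark() workflow would live in the base class; subclasses
    # only decide how a prepared callable is timed.
    def do_bench(self, fn) -> float:
        raise NotImplementedError

class GPUBenchSketch(BenchmarkRequestSketch):
    def do_bench(self, fn) -> float:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()        # also surfaces any deferred CUDA errors
        return start.elapsed_time(end)  # milliseconds

class CPUBenchSketch(BenchmarkRequestSketch):
    def do_bench(self, fn) -> float:
        t0 = time.perf_counter()
        fn()
        return (time.perf_counter() - t0) * 1e3  # milliseconds
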
3 changes: 0 additions & 3 deletions torch/_inductor/codecache.py
@@ -183,9 +183,6 @@ def get_global_cache_path() -> Optional[Path]:
)

def __init__(self) -> None:
if not torch.cuda.is_available():
return

self.system = CacheBase.get_system()

self.local_cache_path = CacheBase.get_local_cache_path()
49 changes: 46 additions & 3 deletions torch/_inductor/codegen/cpp.py
@@ -8,7 +8,7 @@
import sys
from copy import copy, deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union

import sympy

@@ -19,6 +19,7 @@
from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from ..._dynamo.utils import counters

from .. import codecache, config, ir, metrics
from ..codegen.wrapper import WrapperCodeGen
@@ -3729,6 +3730,8 @@ def _can_fuse_horizontal_impl(self, node1, node2):
return self._why_fuse_nodes(node1, node2) is not None

def can_fuse_horizontal(self, node1, node2):
if node1.is_template() or node2.is_template():
return False
if (
len(node1.get_nodes()) + len(node2.get_nodes())
> config.cpp.max_horizontal_fusion_size
@@ -3809,6 +3812,9 @@ def get_fusion_pair_priority(self, node1, node2):
return 0

def can_fuse_vertical(self, node1, node2):
# TODO: support vertical fusion for template nodes
if node1.is_template() or node2.is_template():
return False
return (
self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
) or self.can_fuse_vertical_outer_loop(node1, node2)
@@ -3865,6 +3871,42 @@ def codegen_node(
if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
self._set_flush_status(True)

def is_cpp_template(self, node: BaseSchedulerNode) -> bool:
return isinstance(node, SchedulerNode) and isinstance(
node.node, ir.CppTemplateBuffer
)

def codegen_template(
self, template_node: BaseSchedulerNode, epilogue_nodes: List[SchedulerNode]
):
"""
Codegen a CPP template, possibly with fused epilogues
"""
counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cpp_template(
template_node
), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
template_node = cast(SchedulerNode, template_node)
_, (_, rnumel) = template_node.group
assert rnumel == ()
ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node)
epilogue_ir_nodes: List[ir.Buffer] = [n.node for n in epilogue_nodes]
assert all(
isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
), "Epilogue nodes must all be instances of ir.ComputedBuffer"
kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
with kernel:
for node in [template_node, *epilogue_nodes]:
node.mark_run()
src_code = render()

with V.set_kernel_handler(kernel):
node_schedule = [template_node, *epilogue_nodes]
kernel_name = self.define_kernel(src_code, node_schedule, kernel.args)
kernel.call_kernel(kernel_name, ctb)
V.graph.removed_buffers |= kernel.removed_buffers
self.scheduler.free_buffers()

def _get_scheduled_num_args(self):
return self.kernel_group.get_num_args()

@@ -3874,7 +3916,7 @@ def ready_to_flush(self):
def codegen_sync(self):
pass

def define_kernel(self, src_code, nodes):
def define_kernel(self, src_code, nodes, kernel_args=None):
wrapper = V.graph.wrapper_code
fused_name = (
get_fused_kernel_name(nodes, config.cpp.descriptive_names)
@@ -3890,7 +3932,8 @@ def define_kernel(self, src_code, nodes):
src_code = src_code.replace("#pragma CMT", "//")

compile_wrapper = IndentedBuffer()
_, _, arg_types = self.kernel_group.args.cpp_argdefs()
args = self.kernel_group.args if kernel_args is None else kernel_args
_, _, arg_types = args.cpp_argdefs()
if not V.graph.cpp_wrapper:
compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
compile_wrapper.splice(src_code, strip=True)
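is_cpp_template and codegen_template give CppScheduling a template code-generation path next to the existing codegen_node path. A rough sketch of how a scheduling loop might dispatch between the two (the grouping of a template node with its epilogue nodes is hypothetical here, not the actual scheduler logic):

def codegen_cpp_nodes(backend, grouped_nodes):
    # grouped_nodes: hypothetical (node, epilogue_nodes) pairs produced by fusion
    for node, epilogue_nodes in grouped_nodes:
        if backend.is_cpp_template(node):
            # Template nodes render their own kernel and fold epilogues in themselves
            backend.codegen_template(node, epilogue_nodes)
        else:
            backend.codegen_node(node)
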
79 changes: 79 additions & 0 deletions torch/_inductor/codegen/cpp_gemm_template.py
@@ -0,0 +1,79 @@
from typing import cast, List, Optional

from ..ir import Buffer, CppTemplateBuffer, IRNode, Layout
from .cpp_template import CppTemplate

from .cpp_template_kernel import CppTemplateKernel

GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}
// TODO: use micro-kernel to replace this naive GEMM implementation below
extern "C"
{{kernel.def_kernel(inputs=[X, W], outputs=[Y], names_str="X, W, Y")}}
{
// TODO: support dynamic shapes
int64_t M = {{kernel.size(Y, 0)}};
int64_t N = {{kernel.size(Y, 1)}};
int64_t K = {{kernel.size(X, 1)}};
#pragma omp parallel for collapse(2)
for (int64_t i = 0; i < M; ++i) {
for (int64_t j = 0; j < N/{{n_bs}}; ++j) {
{{kernel.acc_dtype(Y)}} sum[16];
for (int64_t ni = 0; ni < {{n_bs}}; ++ni) {
sum[ni] = 0;
}
for (int64_t k = 0; k < K; ++k) {
for (int64_t ni = 0; ni < {{n_bs}}; ++ni) {
sum[ni] += {{kernel.index(X, ["i", "k"])}} * {{kernel.index(W, ["j", "k", "ni"])}};
}
}
for (int64_t ni = 0; ni < {{n_bs}}; ++ni) {
int64_t n = j * {{n_bs}} + ni;
{{kernel.index(Y, ["i", "n"])}} = sum[ni];
}
}
}
}
"""


class CppPackedGemmTemplate(CppTemplate):
def __init__(
self,
input_nodes,
layout: Layout,
n_block_size: int = 1,
):
super().__init__("cpp_gemm", input_nodes, layout)
self.n_block_size = n_block_size

def render( # type: ignore[override]
self,
kernel: CppTemplateKernel,
template_buffer_node: Optional[CppTemplateBuffer] = None,
epilogue_nodes: Optional[List[IRNode]] = None,
**kwargs,
) -> str:
assert not epilogue_nodes, "Epilogue nodes are not supported for GEMM template."
assert len(self.input_nodes) >= 2

if template_buffer_node is not None:
self.output_node = template_buffer_node
if epilogue_nodes is not None and len(epilogue_nodes) > 0:
self.output_node = cast(Buffer, epilogue_nodes[-1])
assert self.output_node is not None

X, W = self.input_nodes[0], self.input_nodes[1]
Y = self.output_node

options = dict(
X=X,
W=W,
Y=Y,
n_bs=self.n_block_size,
template=self,
kernel=kernel,
epilogues=epilogue_nodes,
)
return self._template_from_string(GEMM_TEMPLATE).render(**options)
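
The template indexes the weight as W[j, k, ni], i.e. it assumes W has already been packed into an [N / n_bs, K, n_bs] layout so the innermost n_bs accumulators read contiguous weights. A NumPy reference of the computation the generated C++ performs (illustrative only, assuming N is divisible by n_bs):

import numpy as np

def packed_gemm_reference(X, W_packed, n_bs):
    # X: [M, K], W_packed: [N // n_bs, K, n_bs] -> Y: [M, N]
    M, K = X.shape
    n_blocks = W_packed.shape[0]
    Y = np.zeros((M, n_blocks * n_bs), dtype=X.dtype)
    for i in range(M):
        for j in range(n_blocks):
            acc = np.zeros(n_bs, dtype=X.dtype)       # sum[ni] in the template
            for k in range(K):
                acc += X[i, k] * W_packed[j, k, :]     # X[i, k] * W[j, k, ni]
            Y[i, j * n_bs:(j + 1) * n_bs] = acc        # Y[i, j * n_bs + ni]
    return Y

Presumably this divisibility requirement is also why the test above moves from Linear(10, 30) to Linear(10, 32).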
