[inductor] generate triton kernel benchmark #95506
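A rough usage sketch for the flag this PR introduces (the model, shapes, and device below are illustrative, not part of the change): with TORCHINDUCTOR_BENCHMARK_KERNEL=1, Inductor prints the compiled module path to stderr and appends a standalone benchmark harness to each generated Triton kernel.

import os

# Hypothetical usage sketch: set the flag added in torch/_inductor/config.py
# before torch is imported, so the env-var lookup picks it up.
os.environ["TORCHINDUCTOR_BENCHMARK_KERNEL"] = "1"

import torch

def f(x, y):
    return (x + y).relu()

x = torch.randn(1024, 1024, device="cuda")
y = torch.randn(1024, 1024, device="cuda")
torch.compile(f)(x, y)
# stderr now shows "Compiled module path: ...", and each generated kernel file
# ends with get_args()/call()/__main__ stubs that can be run on their own.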

Closed
wants to merge 1 commit
100 changes: 86 additions & 14 deletions torch/_inductor/codegen/triton.py
@@ -14,6 +14,7 @@

from ..._dynamo import config as dynamo_config
from .. import config, ir, scheduler
from ..codecache import get_code_path
from ..ir import ReductionHint
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..utils import (
@@ -1169,6 +1170,78 @@ def codegen_body(self):
self.stores.clear()
self.suffix.clear()

def codegen_kernel_benchmark(self):
result = IndentedBuffer()
argdefs, call_args, signature = self.args.python_argdefs()

result.writelines(["", "", "def get_args():"])
with result.indent():
for arg_name in call_args:
buf = V.graph.get_buffer(arg_name)
if buf:
result.writeline(
f"{arg_name} = rand_strided({tuple(buf.get_size())}, {tuple(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
)
elif arg_name in V.graph.constants:
# note that random seed is put in V.graph.constants
const_tensor = V.graph.constants[arg_name]
result.writeline(
f"{arg_name} = rand_strided({tuple(const_tensor.size())}, {tuple(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # noqa: B950 line too long
)
else:
raise KeyError(
f"Don't find the buffer or const tensor for {arg_name}"
)
result.writeline(f"return {', '.join(call_args)},")

result.writelines(["\n", "\n", "def call(args):"])
grid = []
extra_args = []
with result.indent():
index = V.graph.scheduler.current_device.index
result.writeline(f"with torch.cuda._DeviceGuard({index}):")
with result.indent():
result.writeline(
f"torch.cuda.set_device({index})"
) # no-op to ensure context
for tree in self.range_trees:
expr = pexpr(tree.numel)
if tree.prefix != "r" or self.inside_reduction:
extra_args.append(expr)
if tree.prefix != "r":
grid.append(expr)

stream_name = f"stream{index}"
result.writeline(f"{stream_name} = get_cuda_stream({index})")
extra_args_str = ", ".join(map(str, extra_args)) + ", "
result.writeline(
f"triton_.run(*args, {extra_args_str}grid=grid({', '.join(grid)}), stream={stream_name})"
)

result.writelines(["\n", "\n", "if __name__ == '__main__':"])
with result.indent():
result.writeline(
"from torch._C import _cuda_getCurrentRawStream as get_cuda_stream"
)
result.writeline("from torch._dynamo.testing import rand_strided")
result.writeline("from torch._inductor.utils import get_num_bytes")
result.writeline("import torch")
result.writeline("from torch._inductor.triton_ops.autotune import grid")
result.writeline("from triton.testing import do_bench")
result.writeline("")

result.writeline("args = get_args()")
result.writeline(
"ms = do_bench(lambda: call(args), rep=40, fast_flush=True)[0]"
)
result.writeline("num_gb = get_num_bytes(*args) / 1e9")
result.writeline("gb_per_s = num_gb / (ms / 1e3)")
result.writeline(
'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")'
)

return result
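
For reference, the harness this method appends to a generated kernel file looks roughly like the sketch below; the argument names, shapes, and the 1048576 numel are hypothetical values for a single pointwise kernel, not output taken from this PR.

def get_args():
    arg0_1 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
    buf0 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.float32)
    return arg0_1, buf0,


def call(args):
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        stream0 = get_cuda_stream(0)
        triton_.run(*args, 1048576, grid=grid(1048576), stream=stream0)


if __name__ == '__main__':
    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import get_num_bytes
    import torch
    from torch._inductor.triton_ops.autotune import grid
    from triton.testing import do_bench

    args = get_args()
    ms = do_bench(lambda: call(args), rep=40, fast_flush=True)[0]
    num_gb = get_num_bytes(*args) / 1e9
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")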

def codegen_kernel(self, name=None):
from triton import next_power_of_2

@@ -1275,21 +1348,13 @@ def codegen_kernel(self, name=None):
code.writeline(f"{old} = {new}")
code.splice(self.body)

if config.benchmark_kernel:
code.splice(self.codegen_kernel_benchmark())

if name is not None:
return code.getvalue()

wrapper = IndentedBuffer()
wrapper.writeline("async_compile.triton('''")
wrapper.splice(code.getvalue(), strip=True)
wrapper.writeline("''')")
return wrapper.getvalue()

def codegen_template_wrapper(self, src_code):
wrapper = IndentedBuffer()
wrapper.writeline("async_compile.triton('''")
wrapper.splice(src_code, strip=True)
wrapper.writeline("''')")
return wrapper.getvalue()
return code.getvalue()

def codegen_static_numels(self, code):
"""
@@ -1577,7 +1642,14 @@ def define_kernel(self, src_code, node_schedule):
# TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
# not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
src_code = src_code.replace("#pragma CMT", "#")
wrapper.define_kernel(kernel_name, src_code)

_, _, kernel_path = get_code_path(src_code, "py", extra="")
compile_wrapper = IndentedBuffer()
compile_wrapper.writeline("async_compile.triton('''")
compile_wrapper.splice(src_code, strip=True)
compile_wrapper.writeline("''')")

wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), kernel_path)
return kernel_name

def codegen_template(self, template_node, epilogue_nodes):
@@ -1594,7 +1666,7 @@ def codegen_template(self, template_node, epilogue_nodes):
for node in epilogue_nodes:
node.codegen(kernel.split_and_set_ranges(node.get_ranges()))

src_code = kernel.codegen_template_wrapper(render())
src_code = render()
kernel_name = self.define_kernel(src_code, [template_node, *epilogue_nodes])
kernel.call_kernel(V.graph.wrapper_code, kernel_name)
self.scheduler.free_buffers()
5 changes: 3 additions & 2 deletions torch/_inductor/codegen/wrapper.py
@@ -606,8 +606,9 @@ def add_expr_input(name, val):
f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
)

def define_kernel(self, name: str, kernel: str):
self.header.splice(f"\n\n{name} = {kernel}")
def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
kernel_path_comment = f"# kernel path: {kernel_path}\n" if kernel_path else ""
self.header.splice(f"\n\n{kernel_path_comment}{name} = {kernel}")

def load_kernel(self, name: str = None, kernel: str = None, arg_types: List = None):
return
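
With the extra kernel_path argument, the header of the generated output_code.py defines each kernel roughly as follows; the path, hash, and kernel name are placeholders, and the Triton source in between is elided.

# kernel path: /tmp/torchinductor_<user>/ab/<hash>.py
triton_fused_add_relu_0 = async_compile.triton('''
    ... Triton kernel source (plus the benchmark stubs above when
    config.benchmark_kernel is set) ...
''')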
2 changes: 2 additions & 0 deletions torch/_inductor/config.py
@@ -82,6 +82,8 @@

comment_origin = False

benchmark_kernel = os.environ.get("TORCHINDUCTOR_BENCHMARK_KERNEL", "0") == "1"


def is_fbcode():
return not hasattr(torch.version, "git_version")
11 changes: 9 additions & 2 deletions torch/_inductor/graph.py
@@ -155,6 +155,13 @@ def warn_fallback(self, name):
def fake_mode(self):
return V.fake_mode

def get_buffer(self, buffer_name: str):
if buffer_name in self.name_to_buffer:
return self.name_to_buffer[buffer_name]
if buffer_name in self.graph_inputs:
return self.graph_inputs[buffer_name]
return None

def get_dtype(self, buffer_name: str):
if buffer_name in self.constants:
return self.constants[buffer_name].dtype
@@ -599,8 +606,8 @@ def compile_to_module(self):
for name, value in self.constants.items():
setattr(mod, name, value)

if dynamo_config.output_code:
log.info("Output code: %s", mod.__file__)
if config.benchmark_kernel:
print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
V.debug.output_code(mod.__file__)
V.debug.rename(os.path.splitext(mod.__file__)[0] + ".debug")
return mod
18 changes: 9 additions & 9 deletions torch/_inductor/triton_ops/autotune.py
@@ -17,7 +17,14 @@
from .. import config
from ..codecache import cache_dir
from ..ir import ReductionHint, TileHint
from ..utils import ceildiv, conditional_product, do_bench, has_triton, next_power_of_2
from ..utils import (
ceildiv,
conditional_product,
do_bench,
get_num_bytes,
has_triton,
next_power_of_2,
)
from .conv_perf_model import (
early_config_prune as conv_early_config_prune,
estimate_conv_time,
@@ -238,18 +245,11 @@ def run(self, *args, grid, stream):
super().run(*args, grid=grid, stream=stream)
(launcher,) = self.launchers

def get_num_bytes(*args):
return sum(
arg.numel() * arg.element_size()
for arg in args
if isinstance(arg, torch.Tensor)
)

ms = self.bench(launcher, *args, grid=grid)[0]
num_gb = get_num_bytes(*args) / 1e9
gb_per_s = num_gb / (ms / 1e3)

collected_calls.append((kernel_name, ms, num_gb, 1e3 * num_gb / ms))
collected_calls.append((kernel_name, ms, num_gb, gb_per_s))
import colorama

info_str = f"{kernel_name}\t {ms:.3f}ms\t{num_gb:.3f} GB \t {gb_per_s:.2f}GB/s"
11 changes: 11 additions & 0 deletions torch/_inductor/utils.py
@@ -559,3 +559,14 @@ def developer_warning(msg):
log.warning(msg)
else:
log.info(msg)


def get_num_bytes(*args):
"""
Return the total number of bytes taken by the tensor-typed arguments.
"""
return sum(
arg.numel() * arg.element_size()
for arg in args
if isinstance(arg, torch.Tensor)
)
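
A quick sanity check of the byte and bandwidth accounting used by the benchmark output (numbers are illustrative):

import torch
from torch._inductor.utils import get_num_bytes  # the helper added above

x = torch.empty(1024, 1024, dtype=torch.float32)
# Only tensor arguments are counted: 1024 * 1024 elements * 4 bytes each.
assert get_num_bytes(x, "not a tensor", 3) == 4194304

# A hypothetical kernel that moves 0.4 GB in 0.5 ms sustains
# 0.4 / (0.5e-3) = 800 GB/s, matching gb_per_s = num_gb / (ms / 1e3).
ms, num_gb = 0.5, 0.4
gb_per_s = num_gb / (ms / 1e3)  # 800.0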