[inductor] run all kernel benchmarks individually in a compiled module (pytorch#95845)

This is a follow-up to PR pytorch#95506 that runs each Triton kernel in a compiled module individually, as suggested by Horace.

Here are the steps:
1. Run the model as usual with a benchmark script, with TORCHINDUCTOR_BENCHMARK_KERNEL enabled, e.g.:
```
TORCHINDUCTOR_BENCHMARK_KERNEL=1 python benchmarks/dynamo/torchbench.py --backend inductor --amp --performance --dashboard --only resnet18 --disable-cudagraphs --training
```
2. The output will contain three lines like
```
Compiled module path: /tmp/torchinductor_shunting/rs/crsuc6zrt3y6lktz33jjqgpkuahya56xj6sentyiz7iv4pjud43j.py
```
There are three because we have one graph module each for the forward pass, the backward pass, and the optimizer, and each graph module emits one such line pointing to its compiled module.

3. We can run the compiled module directly. Without any extra arguments, it keeps the previous behavior of running the call function, which does what the original graph module does but more efficiently. If we add the '-k' argument, it instead benchmarks each individual kernel in the file:

```
python /tmp/torchinductor_shunting/rs/crsuc6zrt3y6lktz33jjqgpkuahya56xj6sentyiz7iv4pjud43j.py -k
```

Example output:
<img width="430" alt="Screenshot 2023-03-01 at 4 51 06 PM" src="https://user-images.githubusercontent.com/52589240/222302996-814a85be-472b-463c-9e85-39d2c9d20e1a.png">

Note: I use the first 10 characters of the hash to identify each kernel since
1. the hash is easier to get in the code :)
2. a name like `triton__3` only makes sense within a compiled module, while a hash is meaningful even without specifying the compiled module (assuming we keep enough bytes of the hash)

If we find a Triton kernel with a hash like c226iuf2wi that performs poorly, we can look it up in the original compiled module file. This works because each compiled Triton kernel is annotated with a comment containing its full hash.
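To make that lookup concrete, here is a minimal sketch (not part of this PR) that scans a compiled module for a comment carrying the full hash; the file path is the one from step 2 and the prefix is the example above:

```python
# Minimal sketch (not part of this PR): locate a kernel in a compiled module
# by the leading characters of its hash, relying on the fact that each
# compiled Triton kernel is preceded by a comment containing the full hash.
def find_kernel_by_hash_prefix(module_path: str, prefix: str):
    with open(module_path) as f:
        for lineno, line in enumerate(f, start=1):
            stripped = line.lstrip()
            # Any comment line containing the prefix is a candidate match.
            if stripped.startswith("#") and prefix in stripped:
                return lineno
    return None


# Example invocation with the path and hash prefix shown earlier.
print(
    find_kernel_by_hash_prefix(
        "/tmp/torchinductor_shunting/rs/crsuc6zrt3y6lktz33jjqgpkuahya56xj6sentyiz7iv4pjud43j.py",
        "c226iuf2wi",
    )
)
```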

Pull Request resolved: pytorch#95845
Approved by: https://github.com/Chillee
ydwu4 committed Mar 11, 2023
1 parent 77fe0df commit dc720a7
Showing 4 changed files with 117 additions and 15 deletions.
15 changes: 9 additions & 6 deletions torch/_inductor/codegen/triton.py

```diff
@@ -1224,13 +1224,7 @@ def codegen_kernel_benchmark(self):
 
         result.writelines(["\n", "\n", "if __name__ == '__main__':"])
         with result.indent():
-            result.writeline(
-                "from torch._C import _cuda_getCurrentRawStream as get_cuda_stream"
-            )
-            result.writeline("from torch._dynamo.testing import rand_strided")
             result.writeline("from torch._inductor.utils import get_num_bytes")
-            result.writeline("import torch")
-            result.writeline("from torch._inductor.triton_ops.autotune import grid")
             result.writeline("from triton.testing import do_bench")
             result.writeline("")
 
@@ -1273,6 +1267,15 @@ def codegen_kernel(self, name=None):
                 from torch._inductor.utils import instance_descriptor
             """
         )
+        if config.benchmark_kernel:
+            code.splice(
+                """
+                from torch._dynamo.testing import rand_strided
+                from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
+                import torch
+                from torch._inductor.triton_ops.autotune import grid
+                """
+            )
 
         argdefs, _, signature = self.args.python_argdefs()
         # maps actual expression to SizeArg if its in sizevars replacements
```
47 changes: 38 additions & 9 deletions torch/_inductor/codegen/wrapper.py

```diff
@@ -12,7 +12,13 @@
 
 from .. import codecache, config, ir
 from ..codecache import code_hash, cpp_compile_command, get_code_path
-from ..utils import cache_on_self, has_triton, sympy_dot, sympy_product
+from ..utils import (
+    cache_on_self,
+    get_benchmark_name,
+    has_triton,
+    sympy_dot,
+    sympy_product,
+)
 from ..virtualized import V
 from .common import CodeGen, DeferredLine, IndentedBuffer, Kernel, PythonPrinter
 
@@ -549,13 +555,7 @@ def generate(self):
 
         return result.getvalue()
 
-    def add_benchmark_harness(self, output):
-        """
-        Append a benchmark harness to generated code for debugging
-        """
-        if not config.benchmark_harness:
-            return
-
+    def benchmark_compiled_module(self, output):
         def add_fake_input(name, shape, stride, device, dtype):
             output.writeline(
                 f"{name} = rand_strided("
@@ -567,7 +567,7 @@ def add_fake_input(name, shape, stride, device, dtype):
         def add_expr_input(name, val):
             output.writeline(f"{name} = {val}")
 
-        output.writelines(["", "", 'if __name__ == "__main__":'])
+        output.writelines(["", "", "def benchmark_compiled_module():"])
         with output.indent():
             output.splice(
                 """
@@ -596,6 +596,35 @@ def add_expr_input(name, val):
                 f"print_performance(lambda: call([{', '.join(V.graph.graph_inputs.keys())}]))"
             )
 
+    def add_benchmark_harness(self, output):
+        """
+        Append a benchmark harness to generated code for debugging
+        """
+        if not config.benchmark_harness:
+            return
+
+        self.benchmark_compiled_module(output)
+
+        output.writelines(["", "", 'if __name__ == "__main__":'])
+        with output.indent():
+            output.writelines(
+                [
+                    "import argparse",
+                    "from torch._inductor.utils import benchmark_all_kernels",
+                    "",
+                    "parser = argparse.ArgumentParser()",
+                    'parser.add_argument("--benchmark-kernels", "-k", action="store_true", help="Whether to benchmark each individual kernels")',  # noqa: B950, line too long
+                    "args = parser.parse_args()",
+                    "",
+                    "if args.benchmark_kernels:",
+                ]
+            )
+            with output.indent():
+                output.writeline(f"benchmark_all_kernels('{get_benchmark_name()}')")
+            output.writeline("else:")
+            with output.indent():
+                output.writeline("benchmark_compiled_module()")
+
     def define_kernel(self, name: str, kernel: str, kernel_path: str = None):
         kernel_path_comment = f"# kernel path: {kernel_path}\n" if kernel_path else ""
         self.header.splice(f"\n\n{kernel_path_comment}{name} = {kernel}")
```
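For reference, the tail that add_benchmark_harness now appends to a compiled module would look roughly like the following — a sketch reconstructed from the codegen above; the benchmark name 'resnet18' is just the example from step 1, and the body of benchmark_compiled_module is elided:

```python
# Sketch of the generated tail (reconstructed from add_benchmark_harness above;
# 'resnet18' is an example benchmark name, not a fixed value).
def benchmark_compiled_module():
    # materializes fake inputs with rand_strided and times call(...)
    ...

if __name__ == "__main__":
    import argparse
    from torch._inductor.utils import benchmark_all_kernels

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--benchmark-kernels",
        "-k",
        action="store_true",
        help="Whether to benchmark each individual kernels",
    )
    args = parser.parse_args()

    if args.benchmark_kernels:
        benchmark_all_kernels("resnet18")
    else:
        benchmark_compiled_module()
```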
1 change: 1 addition & 0 deletions torch/_inductor/triton_ops/autotune.py

```diff
@@ -8,6 +8,7 @@
 import operator
 import os
 import os.path
+import re
 import threading
 from typing import List
 
```
69 changes: 69 additions & 0 deletions torch/_inductor/utils.py

```diff
@@ -8,6 +8,7 @@
 import operator
 import os
 import shutil
+import sys
 import tempfile
 import textwrap
 import time
@@ -579,3 +580,71 @@ def get_num_bytes(*args):
         for arg in args
         if isinstance(arg, torch.Tensor)
     )
+
+
+def get_benchmark_name():
+    """
+    An experimental API used only when config.benchmark_kernel is true.
+    The benchmark name is only available at codegen time. So we can not
+    directly call it in benchmark_all_kernels, which runs after codegen.
+    The function assumes the argument after --only is the benchmark name.
+    It works for torchbench.py/huggingface.py/timm_models.py. But for ad-hoc
+    scripts, this function may return None.
+
+    There are two flavors of the --only argument we need to handle:
+    1. --only model_name
+    2. --only=model_name
+    """
+    try:
+        idx = sys.argv.index("--only")
+        if (
+            idx + 1 < len(sys.argv)
+            and len(sys.argv[idx + 1]) > 0
+            and sys.argv[idx + 1][0] != "-"
+        ):
+            return sys.argv[idx + 1]
+    except ValueError:
+        pass
+
+    for arg in sys.argv:
+        if arg.startswith("--only="):
+            return arg[len("--only=") :]
+
+
+def benchmark_all_kernels(benchmark_name):
+    """
+    An experimental API used only when config.benchmark_kernel is true.
+    Run the kernel benchmarks for all the kernels cached in PyCodeCache.
+    Used in the compiled modules.
+    Put this method here rather than codegen it for convenience since its implementation
+    does not change based on different graph modules being compiled.
+    """
+    from torch._inductor.codecache import PyCodeCache
+
+    nfound = 0
+    for kernel_key, kernel_mod in PyCodeCache.cache.items():
+        if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
+            continue
+        args = kernel_mod.get_args()
+        ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
+        num_gb = get_num_bytes(*args) / 1e9
+        gb_per_s = num_gb / (ms / 1e3)
+
+        # follow what we do in DebugAutotuner
+        info_str = f"{benchmark_name:20} {kernel_key[:10]} {ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
+        import colorama
+
+        if ms > 0.012 and gb_per_s < 650:
+            print(colorama.Fore.RED + info_str + colorama.Fore.RESET)
+        else:
+            print(info_str)
+
+        nfound += 1
+    if nfound == 0:
+        print(
+            "No kernel with benchmark functionality found. Make sure you run inductor with config.benchmark_kernel being True"
+        )
```
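As a quick illustration of get_benchmark_name, here is a hedged sketch with hypothetical argv values mirroring the two --only flavors described in its docstring:

```python
import sys

from torch._inductor.utils import get_benchmark_name

# Hypothetical command lines; both flavors resolve to "resnet18".
sys.argv = ["torchbench.py", "--only", "resnet18"]
assert get_benchmark_name() == "resnet18"

sys.argv = ["torchbench.py", "--only=resnet18"]
assert get_benchmark_name() == "resnet18"
```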
