
Commit

reland #96248 [inductor] show performance for each autotune config for a kernel (#96458)

Pull Request resolved: #96458
Approved by: https://github.com/ngimel
shunting314 authored and pytorchmergebot committed Mar 10, 2023
1 parent cf3d3a5 commit cc699c5
Showing 5 changed files with 80 additions and 26 deletions.
15 changes: 14 additions & 1 deletion torch/_inductor/codegen/triton.py
@@ -1205,8 +1205,9 @@ def codegen_kernel_benchmark(self):
         result.writelines(["\n", "\n", "def call(args):"])
         grid = []
         extra_args = []
+        extra_args_str = None
+        index = V.graph.scheduler.current_device.index
         with result.indent():
-            index = V.graph.scheduler.current_device.index
             result.writeline(f"with torch.cuda._DeviceGuard({index}):")
             with result.indent():
                 result.writeline(
@@ -1226,6 +1227,18 @@ def codegen_kernel_benchmark(self):
f"triton_.run(*args, {extra_args_str}grid=grid({', '.join(grid)}), stream={stream_name})"
)

# benchmark all configs
result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
with result.indent():
result.writeline(f"with torch.cuda._DeviceGuard({index}):")
with result.indent():
result.writeline(
f"torch.cuda.set_device({index})"
) # no-op to ensure context
result.writeline(
f"return triton_.benchmark_all_configs(*args, {extra_args_str}grid=grid({', '.join(grid)}))"
)

result.writelines(["\n", "\n", "if __name__ == '__main__':"])
with result.indent():
result.writeline("from torch._inductor.utils import get_num_bytes")
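
To make the codegen above concrete, here is a sketch of the benchmark_all_configs function those writelines emit into the standalone kernel benchmark module (the device index 0 and the grid argument are placeholders for whatever the compiled kernel actually uses; triton_ and grid are defined earlier in the generated file):

    def benchmark_all_configs(args):
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)  # no-op to ensure context
            # Delegates to CachingAutotuner.benchmark_all_configs (added in
            # triton_ops/autotune.py below) and returns {launcher: latency_ms}
            # for every candidate config.
            return triton_.benchmark_all_configs(*args, grid=grid(1024))
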
5 changes: 4 additions & 1 deletion torch/_inductor/codegen/wrapper.py
@@ -614,13 +614,16 @@ def add_benchmark_harness(self, output):
"",
"parser = argparse.ArgumentParser()",
'parser.add_argument("--benchmark-kernels", "-k", action="store_true", help="Whether to benchmark each individual kernels")', # noqa: B950, line too long
'parser.add_argument("--benchmark-all-configs", "-c", action="store_true", help="Whether to benchmark each individual config for a kernel")', # noqa: B950, line too long
"args = parser.parse_args()",
"",
"if args.benchmark_kernels:",
]
)
with output.indent():
output.writeline(f"benchmark_all_kernels('{get_benchmark_name()}')")
output.writeline(
f"benchmark_all_kernels('{get_benchmark_name()}', args.benchmark_all_configs)"
)
output.writeline("else:")
with output.indent():
output.writeline("benchmark_compiled_module()")
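
Put together, the __main__ harness this emits into the generated module looks roughly like the sketch below (the benchmark name 'my_model' is a placeholder; argparse, benchmark_all_kernels and benchmark_compiled_module are imported or defined earlier in the generated file). Running the generated module with -k -c benchmarks every autotune config of every kernel instead of only the config the autotuner selected:

    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark-kernels", "-k", action="store_true", help="Whether to benchmark each individual kernels")
    parser.add_argument("--benchmark-all-configs", "-c", action="store_true", help="Whether to benchmark each individual config for a kernel")
    args = parser.parse_args()

    if args.benchmark_kernels:
        benchmark_all_kernels('my_model', args.benchmark_all_configs)
    else:
        benchmark_compiled_module()
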
33 changes: 23 additions & 10 deletions torch/_inductor/config.py
@@ -90,16 +90,29 @@ def is_fbcode():
 # warnings intended for PyTorch developers, disable for point releases
 developer_warnings = is_fbcode() or "+" in torch.__version__
 
-compile_threads = (
-    1
-    if sys.platform == "win32" or is_fbcode()
-    else min(
-        32,
-        len(os.sched_getaffinity(0))
-        if hasattr(os, "sched_getaffinity")
-        else os.cpu_count(),
-    )
-)
+
+def decide_compile_threads():
+    """
+    Here are the precedence to decide compile_threads
+    1. User can override it by TORCHINDUCTOR_COMPILE_THREADS. One may want to disable async compiling by
+       setting this to 1 to make pdb happy.
+    2. Set to 1 if it's win32 platform or it's a fbcode build
+    3. decide by the number of CPU cores
+    """
+    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
+        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
+    elif sys.platform == "win32" or is_fbcode():
+        return 1
+    else:
+        return min(
+            32,
+            len(os.sched_getaffinity(0))
+            if hasattr(os, "sched_getaffinity")
+            else os.cpu_count(),
+        )
+
+
+compile_threads = decide_compile_threads()
 
 # autotuning global cache path
 if is_fbcode():
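
As a usage note for the precedence documented in decide_compile_threads: a minimal sketch of overriding the thread count from the environment, assuming the variable is set before torch._inductor.config is first imported (the function runs once at import time):

    import os

    # Force synchronous, single-threaded compilation, e.g. to make stepping
    # through the compile path with pdb easier.
    os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

    import torch._inductor.config as inductor_config

    print(inductor_config.compile_threads)  # -> 1
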
17 changes: 13 additions & 4 deletions torch/_inductor/triton_ops/autotune.py
@@ -155,8 +155,7 @@ def kernel_call():
         return do_bench(kernel_call, rep=40, fast_flush=True)
 
     @dynamo_timed
-    def autotune_to_one_config(self, *args, **kwargs):
-        """Do the actual autotuning"""
+    def benchmark_all_configs(self, *args, **kwargs):
         from ..compile_fx import clone_preserve_strides
 
         # clone inplace buffers to avoid autotune contaminating them if
@@ -171,9 +170,14 @@ def autotune_to_one_config(self, *args, **kwargs):
                 cloned_args.append(arg)
 
         timings = {
-            launcher: self.bench(launcher, *cloned_args, **kwargs)
+            launcher: self.bench(launcher, *cloned_args, **kwargs)[0]
             for launcher in self.launchers
         }
+        return timings
+
+    def autotune_to_one_config(self, *args, **kwargs):
+        """Do the actual autotuning"""
+        timings = self.benchmark_all_configs(*args, **kwargs)
         self.launchers = [builtins.min(timings, key=timings.get)]
         if self.save_cache_hook:
             self.save_cache_hook(self.launchers[0].config)
@@ -313,8 +317,13 @@ def cached_autotune(
     configs = unique_configs(configs)
     assert len(configs) == 1 or filename
 
+    # The autotune cache will simply replace the list of candidate configs with
+    # the best config cached. We don't want that when we benchmark triton kernels.
+    # We need the perf for each of the candidate config instead.
+    cache_autotune_result = not config.benchmark_kernel
+
     # on disk caching logic
-    if filename is not None and len(configs) > 1:
+    if cache_autotune_result and filename is not None and len(configs) > 1:
         cache_filename = os.path.splitext(filename)[0] + ".best_config"
         configs_hash = hash_configs(configs)
         best_config = load_cached_autotuning(cache_filename, configs_hash, configs)
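
For illustration, benchmark_all_configs returns a dict mapping each launcher (one per candidate Triton config) to the latency reported by do_bench, and autotune_to_one_config keeps only the fastest entry. A toy sketch with plain strings standing in for launcher objects and made-up timings in milliseconds:

    import builtins

    # Stand-ins for compiled launchers, one per candidate config.
    timings = {
        "XBLOCK=128, num_warps=4": 0.031,
        "XBLOCK=256, num_warps=8": 0.024,
        "XBLOCK=512, num_warps=8": 0.027,
    }

    best = builtins.min(timings, key=timings.get)
    print(best)  # XBLOCK=256, num_warps=8
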
36 changes: 26 additions & 10 deletions torch/_inductor/utils.py
@@ -625,7 +625,7 @@ def get_benchmark_name():
             return arg[len("--only=") :]
 
 
-def benchmark_all_kernels(benchmark_name):
+def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
     """
     An experimental API used only when config.benchmark_kernel is true.
@@ -642,18 +642,34 @@ def benchmark_all_kernels(benchmark_name):
         if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
             continue
         args = kernel_mod.get_args()
-        ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
         num_gb = get_num_bytes(*args) / 1e9
-        gb_per_s = num_gb / (ms / 1e3)
 
-        # follow what we do in DebugAutotuner
-        info_str = f"{benchmark_name:20} {kernel_key[:10]} {ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
-        import colorama
-
-        if ms > 0.012 and gb_per_s < 650:
-            print(colorama.Fore.RED + info_str + colorama.Fore.RESET)
-        else:
-            print(info_str)
+        def get_info_str(ms, prefix=""):
+            gb_per_s = num_gb / (ms / 1e3)
+            # follow what we do in DebugAutotuner
+            info_str = f"{prefix}{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
+            import colorama
+
+            if ms > 0.012 and gb_per_s < 650:
+                info_str = colorama.Fore.RED + info_str + colorama.Fore.RESET
+            return info_str
+
+        bench_result = []
+        if benchmark_all_configs:
+            assert hasattr(kernel_mod, "benchmark_all_configs")
+            bench_result = kernel_mod.benchmark_all_configs(args)
+            bench_result = [
+                (launcher.config, ms) for launcher, ms in bench_result.items()
+            ]
+            print(f"{benchmark_name:20} {kernel_key[:10]}")
+            for cfg, ms in bench_result:
+                print(f" {get_info_str(ms)} @ {cfg}")
+        else:
+            ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
+            assert (
+                len(kernel_mod.triton_.launchers) == 1
+            ), "Autotuner should have selected the best config"
+            print(get_info_str(ms, prefix=f"{benchmark_name:20} {kernel_key[:10]} "))
 
         nfound += 1
     if nfound == 0:
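
A quick worked example of the bandwidth readout get_info_str produces (all values made up; the 0.012 ms / 650 GB/s thresholds are the ones used above to flag likely underperforming kernels in red):

    # Worked example of the bandwidth estimate (values are invented).
    num_gb = 0.268                  # bytes read + written by the kernel, in GB
    ms = 0.105                      # measured latency in milliseconds
    gb_per_s = num_gb / (ms / 1e3)  # 0.268 / 0.000105 s ≈ 2552 GB/s
    print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")  # 0.105ms 0.268GB 2552.38GB/s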
