reland #96248 [inductor] show performance for each autotune config for a kernel #96458

15 changes: 14 additions & 1 deletion torch/_inductor/codegen/triton.py
@@ -1205,8 +1205,9 @@ def codegen_kernel_benchmark(self):
         result.writelines(["\n", "\n", "def call(args):"])
         grid = []
         extra_args = []
+        extra_args_str = None
+        index = V.graph.scheduler.current_device.index
         with result.indent():
-            index = V.graph.scheduler.current_device.index
             result.writeline(f"with torch.cuda._DeviceGuard({index}):")
             with result.indent():
                 result.writeline(
@@ -1226,6 +1227,18 @@ def codegen_kernel_benchmark(self):
f"triton_.run(*args, {extra_args_str}grid=grid({', '.join(grid)}), stream={stream_name})"
)

# benchmark all configs
result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
with result.indent():
result.writeline(f"with torch.cuda._DeviceGuard({index}):")
with result.indent():
result.writeline(
f"torch.cuda.set_device({index})"
) # no-op to ensure context
result.writeline(
f"return triton_.benchmark_all_configs(*args, {extra_args_str}grid=grid({', '.join(grid)}))"
)

result.writelines(["\n", "\n", "if __name__ == '__main__':"])
with result.indent():
result.writeline("from torch._inductor.utils import get_num_bytes")
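For context, each generated kernel-benchmark module now exposes a second entry point alongside call. A hypothetical sketch of the emitted code, assuming device index 0, no extra args, and an illustrative 1-D grid; the real codegen fills these in per kernel:

    def call(args):
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)  # no-op to ensure context
            stream0 = torch.cuda.current_stream(0).cuda_stream
            triton_.run(*args, grid=grid(1024), stream=stream0)

    def benchmark_all_configs(args):
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0)  # no-op to ensure context
            return triton_.benchmark_all_configs(*args, grid=grid(1024))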
5 changes: 4 additions & 1 deletion torch/_inductor/codegen/wrapper.py
@@ -614,13 +614,16 @@ def add_benchmark_harness(self, output):
"",
"parser = argparse.ArgumentParser()",
'parser.add_argument("--benchmark-kernels", "-k", action="store_true", help="Whether to benchmark each individual kernels")', # noqa: B950, line too long
'parser.add_argument("--benchmark-all-configs", "-c", action="store_true", help="Whether to benchmark each individual config for a kernel")', # noqa: B950, line too long
"args = parser.parse_args()",
"",
"if args.benchmark_kernels:",
]
)
with output.indent():
output.writeline(f"benchmark_all_kernels('{get_benchmark_name()}')")
output.writeline(
f"benchmark_all_kernels('{get_benchmark_name()}', args.benchmark_all_configs)"
)
output.writeline("else:")
with output.indent():
output.writeline("benchmark_compiled_module()")
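With the flag wired through, the generated wrapper can be driven from the command line; a hypothetical invocation, where output_code.py stands in for whatever file Inductor wrote:

    # benchmark each kernel using the single config the autotuner selected
    python output_code.py --benchmark-kernels

    # benchmark every candidate autotune config for each kernel
    python output_code.py --benchmark-kernels --benchmark-all-configs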
33 changes: 23 additions & 10 deletions torch/_inductor/config.py
@@ -90,16 +90,29 @@ def is_fbcode():
 # warnings intended for PyTorch developers, disable for point releases
 developer_warnings = is_fbcode() or "+" in torch.__version__
 
-compile_threads = (
-    1
-    if sys.platform == "win32" or is_fbcode()
-    else min(
-        32,
-        len(os.sched_getaffinity(0))
-        if hasattr(os, "sched_getaffinity")
-        else os.cpu_count(),
-    )
-)
+
+def decide_compile_threads():
+    """
+    Here is the precedence for deciding compile_threads:
+    1. The user can override it via TORCHINDUCTOR_COMPILE_THREADS. One may want
+       to disable async compiling by setting this to 1 to make pdb happy.
+    2. Set to 1 on the win32 platform or in fbcode builds.
+    3. Otherwise, decide by the number of CPU cores.
+    """
+    if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ:
+        return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"])
+    elif sys.platform == "win32" or is_fbcode():
+        return 1
+    else:
+        return min(
+            32,
+            len(os.sched_getaffinity(0))
+            if hasattr(os, "sched_getaffinity")
+            else os.cpu_count(),
+        )
+
+
+compile_threads = decide_compile_threads()
 
 # autotuning global cache path
 if is_fbcode():
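A minimal sketch of the new precedence in action; note the override must be set before torch._inductor.config is first imported, since compile_threads is computed at import time:

    import os

    # highest precedence: explicit override, e.g. to keep pdb usable
    os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"

    import torch._inductor.config as inductor_config

    # the env var wins over the win32/fbcode and CPU-count fallbacks
    assert inductor_config.compile_threads == 1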
17 changes: 13 additions & 4 deletions torch/_inductor/triton_ops/autotune.py
@@ -155,8 +155,7 @@ def kernel_call():
         return do_bench(kernel_call, rep=40, fast_flush=True)
 
     @dynamo_timed
-    def autotune_to_one_config(self, *args, **kwargs):
-        """Do the actual autotuning"""
+    def benchmark_all_configs(self, *args, **kwargs):
         from ..compile_fx import clone_preserve_strides
 
         # clone inplace buffers to avoid autotune contaminating them if
@@ -171,9 +170,14 @@ def autotune_to_one_config(self, *args, **kwargs):
             cloned_args.append(arg)
 
         timings = {
-            launcher: self.bench(launcher, *cloned_args, **kwargs)
+            launcher: self.bench(launcher, *cloned_args, **kwargs)[0]
             for launcher in self.launchers
         }
+        return timings
+
+    def autotune_to_one_config(self, *args, **kwargs):
+        """Do the actual autotuning"""
+        timings = self.benchmark_all_configs(*args, **kwargs)
         self.launchers = [builtins.min(timings, key=timings.get)]
         if self.save_cache_hook:
             self.save_cache_hook(self.launchers[0].config)
@@ -313,8 +317,13 @@ def cached_autotune(
     configs = unique_configs(configs)
     assert len(configs) == 1 or filename
 
+    # The autotune cache will simply replace the list of candidate configs with
+    # the best cached config. We don't want that when benchmarking triton kernels;
+    # we need the perf for each candidate config instead.
+    cache_autotune_result = not config.benchmark_kernel
+
     # on disk caching logic
-    if filename is not None and len(configs) > 1:
+    if cache_autotune_result and filename is not None and len(configs) > 1:
         cache_filename = os.path.splitext(filename)[0] + ".best_config"
         configs_hash = hash_configs(configs)
         best_config = load_cached_autotuning(cache_filename, configs_hash, configs)
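The autotune_to_one_config selection above is just a min over the timings dict keyed by launcher; a self-contained sketch with stand-in launcher names and made-up timings:

    import builtins

    # launcher -> measured milliseconds, shaped like benchmark_all_configs output
    timings = {"XBLOCK=64": 0.031, "XBLOCK=128": 0.024, "XBLOCK=256": 0.040}

    # keep only the fastest candidate, mirroring
    # self.launchers = [builtins.min(timings, key=timings.get)]
    best = builtins.min(timings, key=timings.get)
    assert best == "XBLOCK=128"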
36 changes: 26 additions & 10 deletions torch/_inductor/utils.py
@@ -621,7 +621,7 @@ def get_benchmark_name():
             return arg[len("--only=") :]
 
 
-def benchmark_all_kernels(benchmark_name):
+def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
     """
     An experimental API used only when config.benchmark_kernel is true.
 
@@ -638,18 +638,34 @@ def benchmark_all_kernels(benchmark_name):
         if not hasattr(kernel_mod, "get_args") or not hasattr(kernel_mod, "call"):
             continue
         args = kernel_mod.get_args()
-        ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
         num_gb = get_num_bytes(*args) / 1e9
-        gb_per_s = num_gb / (ms / 1e3)
 
-        # follow what we do in DebugAutotuner
-        info_str = f"{benchmark_name:20} {kernel_key[:10]} {ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
-        import colorama
+        def get_info_str(ms, prefix=""):
+            gb_per_s = num_gb / (ms / 1e3)
+            # follow what we do in DebugAutotuner
+            info_str = f"{prefix}{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s"
+            import colorama
 
-        if ms > 0.012 and gb_per_s < 650:
-            print(colorama.Fore.RED + info_str + colorama.Fore.RESET)
+            if ms > 0.012 and gb_per_s < 650:
+                info_str = colorama.Fore.RED + info_str + colorama.Fore.RESET
+            return info_str
+
+        bench_result = []
+        if benchmark_all_configs:
+            assert hasattr(kernel_mod, "benchmark_all_configs")
+            bench_result = kernel_mod.benchmark_all_configs(args)
+            bench_result = [
+                (launcher.config, ms) for launcher, ms in bench_result.items()
+            ]
+            print(f"{benchmark_name:20} {kernel_key[:10]}")
+            for cfg, ms in bench_result:
+                print(f"  {get_info_str(ms)} @ {cfg}")
         else:
-            print(info_str)
+            ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)[0]
+            assert (
+                len(kernel_mod.triton_.launchers) == 1
+            ), "Autotuner should have selected the best config"
+            print(get_info_str(ms, prefix=f"{benchmark_name:20} {kernel_key[:10]} "))
 
         nfound += 1
     if nfound == 0:
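The red highlighting in get_info_str is driven by achieved memory bandwidth; a sketch of the arithmetic with made-up numbers (thresholds copied from the diff):

    num_bytes = 1_500_000_000       # bytes moved, as from get_num_bytes(*args)
    ms = 2.5                        # measured kernel latency in milliseconds
    num_gb = num_bytes / 1e9        # 1.5 GB
    gb_per_s = num_gb / (ms / 1e3)  # 1.5 / 0.0025 = 600 GB/s
    slow = ms > 0.012 and gb_per_s < 650  # flagged in red when True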