[inductor] refactor: device dispatch inside do_bench #125736

Closed · wants to merge 4 commits
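The changes below rename the GPU-only timing helper in torch/_inductor/runtime/runtime_utils.py from do_bench to do_bench_gpu and introduce a new do_bench entry point that takes the benchmarked function and its arguments separately, inspects the argument devices itself, and dispatches to do_bench_cpu or do_bench_gpu. Call sites that previously branched on is_cpu_device now make a single call. A rough caller-side sketch of the change (illustrative only; algo, args, and out stand in for whatever a call site actually benchmarks):

# Before: each call site picked the helper itself.
if is_cpu_device(args):
    ms = do_bench_cpu(lambda: algo(*args, out=out))
else:
    ms = do_bench(lambda: algo(*args, out=out))  # old do_bench was GPU-only

# After: the device check lives inside do_bench.
ms = do_bench(algo, args, {"out": out})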
4 changes: 2 additions & 2 deletions torch/_inductor/autotune_process.py
@@ -44,7 +44,7 @@
from torch._inductor.select_algorithm import TritonTemplateCaller

from . import config
-from .runtime.runtime_utils import do_bench, do_bench_cpu
+from .runtime.runtime_utils import do_bench_cpu, do_bench_gpu
from .virtualized import V

CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
@@ -592,7 +592,7 @@ def do_bench(
device_idx = torch.cuda.current_device()

with torch.cuda.device(device_idx):
-out = do_bench(fn)
+out = do_bench_gpu(fn)
torch.cuda.synchronize() # shake out any CUDA errors

return out
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/multi_kernel.py
@@ -6,7 +6,7 @@

from .. import config
from ..codecache import PyCodeCache, TritonFuture
-from ..runtime.runtime_utils import do_bench
+from ..runtime.runtime_utils import do_bench_gpu
from ..utils import cache_on_self
from ..virtualized import V
from .common import TensorArg
@@ -339,7 +339,7 @@ def benchmark_sub_kernels(kernel_calls):
be picked.
"""
return [
-do_bench(lambda: kernel_call(True), rep=40, fast_flush=True)
+do_bench_gpu(lambda: kernel_call(True), rep=40, fast_flush=True)
for kernel_call in kernel_calls
]

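For context, the timings returned by benchmark_sub_kernels feed the selection of which sub-kernel the multi-kernel wrapper dispatches to; a minimal sketch of that selection (illustrative only, the real MultiKernelCall bookkeeping is more involved):

timings = benchmark_sub_kernels(kernel_calls)
best = timings.index(min(timings))  # index of the fastest sub-kernel
kernel_calls[best](True)            # dispatch to the winner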
8 changes: 4 additions & 4 deletions torch/_inductor/codegen/triton.py
@@ -49,7 +49,7 @@
from ..optimize_indexing import indexing_dtype_strength_reduction
from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
from ..runtime.runtime_utils import (
-do_bench,
+do_bench_gpu,
get_max_y_grid,
green_text,
next_power_of_2,
@@ -2653,7 +2653,7 @@ def codegen_kernel_benchmark(self, num_gb, grid=None):

result.writeline("args = get_args()")
result.writeline(
"ms = do_bench(lambda: call(args), rep=40, fast_flush=True)"
"ms = do_bench_gpu(lambda: call(args), rep=40, fast_flush=True)"
)
result.writeline(f"num_gb = {num_gb}")
result.writeline("gb_per_s = num_gb / (ms / 1e3)")
@@ -4036,13 +4036,13 @@ def store_cache():
else:
# We have to clone the inplace updated arguments to avoid earlier calls
# generating out of range indices for later calls.
-ms = do_bench(lambda: call(wrapped_jit_function.clone_args(*args)[0]))
+ms = do_bench_gpu(lambda: call(wrapped_jit_function.clone_args(*args)[0]))

# overhead of cloning args gives bias for fusing the kernel
# in the case of mutating/in-placeable second fusion
# TODO - would be better as a hook in triton do_bench that reset
# the input values between benchmarking
-ms = ms - do_bench(lambda: wrapped_jit_function.clone_args(*args))
+ms = ms - do_bench_gpu(lambda: wrapped_jit_function.clone_args(*args))

log.debug(
"The fused kernel for %s took %.3f ms to run",
2 changes: 1 addition & 1 deletion torch/_inductor/fx_passes/pad_mm.py
@@ -251,7 +251,7 @@ def should_pad_bench(
return False

do_bench = functools.partial(
-torch._inductor.runtime.runtime_utils.do_bench,
+torch._inductor.runtime.runtime_utils.do_bench_gpu,
warmup=5,
)

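After the rename, the functools.partial above binds the GPU helper directly and forwards warmup=5 on every call. A hedged usage sketch (mat1 and mat2 are hypothetical CUDA tensors; the timing is returned in milliseconds):

import functools
import torch
from torch._inductor.runtime.runtime_utils import do_bench_gpu

bench = functools.partial(do_bench_gpu, warmup=5)
mat1 = torch.randn(1024, 1024, device="cuda")
mat2 = torch.randn(1024, 1024, device="cuda")
ms = bench(lambda: torch.mm(mat1, mat2))  # time the unpadded matmul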
8 changes: 2 additions & 6 deletions torch/_inductor/ir.py
@@ -70,7 +70,7 @@
)
from .ops_handler import OpCounterCSE
from .runtime.hints import ReductionHint
-from .runtime.runtime_utils import do_bench, do_bench_cpu
+from .runtime.runtime_utils import do_bench
from .utils import (
argsort,
cache_on_self,
@@ -79,7 +79,6 @@
convert_shape_to_symint,
developer_warning,
get_kernel_metadata,
-is_cpu_device,
is_dynamic,
is_gpu,
pad_listlike,
@@ -3628,10 +3627,7 @@ def __init__(self, name, input_nodes, layout):

def benchmark(self, *args, out) -> float:
algo = self.to_callable()
-if is_cpu_device(args):
-    return do_bench_cpu(lambda: algo(*args, out=out))
-else:
-    return do_bench(lambda: algo(*args, out=out))
+return do_bench(algo, args, {"out": out})

def call_name(self) -> str:
raise NotImplementedError
13 changes: 12 additions & 1 deletion torch/_inductor/runtime/runtime_utils.py
@@ -70,7 +70,18 @@ def get_max_y_grid():
return 65535


-def do_bench(*args, **kwargs):
+def do_bench(fn, fn_args, fn_kwargs, **kwargs):
+    from torch._inductor.utils import is_cpu_device
+
+    args = list(fn_args)
+    args.extend(fn_kwargs.values())
+    if is_cpu_device(args):
+        return do_bench_cpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)
+    else:
+        return do_bench_gpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)
+
+
+def do_bench_gpu(*args, **kwargs):
@functools.lru_cache(None)
def load_triton():
try:
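With this patch, device-agnostic call sites pass the function and its arguments separately so do_bench can inspect the devices itself; a minimal usage sketch (assuming is_cpu_device returns True only when no argument lives on a GPU):

import torch
from torch._inductor.runtime.runtime_utils import do_bench

a = torch.randn(256, 256)  # CPU tensors -> routed to do_bench_cpu
b = torch.randn(256, 256)
cpu_ms = do_bench(torch.mm, (a, b), {})

if torch.cuda.is_available():  # CUDA tensors -> routed to do_bench_gpu
    cuda_ms = do_bench(torch.mm, (a.cuda(), b.cuda()), {})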
4 changes: 2 additions & 2 deletions torch/_inductor/runtime/triton_heuristics.py
@@ -32,7 +32,7 @@
ceildiv,
conditional_product,
create_bandwidth_info_str,
-do_bench,
+do_bench_gpu,
dynamo_timed,
get_first_attr,
get_max_y_grid,
@@ -628,7 +628,7 @@ def kernel_call():
stream=stream,
)

-return do_bench(kernel_call, rep=40, fast_flush=True)
+return do_bench_gpu(kernel_call, rep=40, fast_flush=True)

def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
from ..compile_fx import clone_preserve_strides
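The value returned by this benchmark path feeds config selection: the autotuner times each candidate launch config and keeps the fastest. Conceptually (a simplified sketch, not the actual autotuner implementation; configs and bench_config are stand-ins):

timings = {cfg: bench_config(cfg) for cfg in configs}  # milliseconds per candidate
best_cfg = min(timings, key=timings.get)               # smallest runtime wins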
8 changes: 2 additions & 6 deletions torch/_inductor/select_algorithm.py
@@ -38,10 +38,9 @@
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .runtime.hints import DeviceProperties
-from .runtime.runtime_utils import do_bench, do_bench_cpu
+from .runtime.runtime_utils import do_bench
from .utils import (
get_dtype_size,
-is_cpu_device,
Placeholder,
restore_stdout_stderr,
sympy_dot,
@@ -845,10 +844,7 @@ def benchmark(self, *args, out):
out_new, tuple(out.size()), tuple(out.stride())
)
out.copy_(out_new) # for correctness checking
-if is_cpu_device(args):
-    return do_bench_cpu(lambda: algo(*args))
-else:
-    return do_bench(lambda: algo(*args))
+return do_bench(algo, args, {})

def to_callable(self):
fn = self.choice.to_callable()
8 changes: 6 additions & 2 deletions torch/_inductor/wrapper_benchmark.py
@@ -4,7 +4,11 @@

import torch
from torch.autograd import DeviceType
-from .runtime.runtime_utils import create_bandwidth_info_str, do_bench, get_num_bytes
+from .runtime.runtime_utils import (
+    create_bandwidth_info_str,
+    do_bench_gpu,
+    get_num_bytes,
+)

_kernel_category_choices = [
"foreach",
@@ -116,7 +120,7 @@ def get_info_str(ms, n_regs, n_spills, shared, prefix=""):
f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}"
)
else:
-ms = do_bench(lambda: kernel_mod.call(args), rep=40, fast_flush=True)
+ms = do_bench_gpu(lambda: kernel_mod.call(args), rep=40, fast_flush=True)
assert (
len(triton_kernel.launchers) == 1
), "Autotuner should have selected the best config"