[aotinductor] Avoid generating redundant kernel loading code
Summary: 1) Stop forcing triton.unique_kernel_names to True for AOTInductor, because the unique kernel name can be read from the kernel metadata; 2) Only generate the loadKernel call once for each kernel, since there is no control flow in our generated code. This solves #105553.

ghstack-source-id: 4725771e16833a6fd2e49cc4488e686f38fe9b00
Pull Request resolved: #110510
desertfire committed Oct 5, 2023
1 parent cf1b494 commit 52ab93f
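
For illustration, a minimal sketch of item 1 of the summary (the function name save_cuda_kernel_sketch, the registry dict, and the kernel names below are made up for this example and are not the real Inductor API): the autotuner saves a compiled kernel's parameters under the Inductor-assigned kernel name carried in the kernel metadata, instead of under the Python launcher's __qualname__, which is what previously required forcing triton.unique_kernel_names to True. Compare with the torch/_inductor/triton_heuristics.py hunk at the bottom of the diff.

def save_cuda_kernel_sketch(meta, bin_metadata, registry):
    # Key the saved launch parameters by the metadata-provided kernel name,
    # not by the launcher function's __qualname__.
    key = meta.get("kernel_name", None)
    assert key is not None, "kernel_name can not be None"
    registry[key] = {"mangled_name": bin_metadata["name"]}


registry = {}
save_cuda_kernel_sketch(
    {"kernel_name": "triton_poi_fused_sin_0"},
    {"name": "triton_poi_fused_sin_0_mangled"},
    registry,
)
save_cuda_kernel_sketch(
    {"kernel_name": "triton_poi_fused_mm_1"},
    {"name": "triton_poi_fused_mm_1_mangled"},
    registry,
)
print(sorted(registry))  # two distinct keys without forcing any config option
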
Showing 4 changed files with 51 additions and 24 deletions.
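
Similarly, a minimal standalone sketch of item 2 (WrapperSketch and its methods are toy stand-ins, not the real wrapper codegen classes): decorating the code-emitting method with functools.lru_cache makes repeated calls for the same kernel no-ops, so each kernel's loadKernel line appears exactly once in the generated wrapper no matter how many call sites use it. This is what the new test_reuse_kernel test below verifies with FileCheck.

import functools


class WrapperSketch:
    """Toy stand-in for the CUDA wrapper codegen; it only collects output lines."""

    def __init__(self):
        self.lines = []

    def writeline(self, line):
        self.lines.append(line)

    @functools.lru_cache(None)
    def generate_load_kernel_once(self, name, cubin_path, shared_mem):
        # lru_cache memoizes on (self, name, cubin_path, shared_mem), so the
        # load line below is emitted only the first time per kernel.
        self.writeline(f'{name} = loadKernel("{cubin_path}", {shared_mem});')

    def generate_kernel_call(self, name, cubin_path):
        self.generate_load_kernel_once(name, cubin_path, 0)
        self.writeline(f"launchKernel({name}, ...);")


w = WrapperSketch()
for _ in range(3):  # three call sites reusing the same kernel
    w.generate_kernel_call("triton_poi_fused_sin_0", "/tmp/triton_poi_fused_sin_0.cubin")
print("\n".join(w.lines))  # loadKernel appears once, launchKernel three times
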
test/inductor/test_aot_inductor.py: 34 additions & 4 deletions
@@ -1,18 +1,20 @@
# Owner(s): ["module: inductor"]
import copy
import os
import sys
import tempfile
import unittest

import torch
import torch._export
import torch._inductor

import torch.fx._pytree as fx_pytree
from torch._dynamo.testing import same
from torch._inductor import config
from torch._inductor.utils import aot_inductor_launcher

from torch.testing import FileCheck

from torch.testing._internal.common_utils import (
    IS_CI,
    IS_FBCODE,
@@ -183,7 +185,7 @@ def forward(self, x, y):
            torch.randn(10, 10, device=self.device),
            torch.randn(10, 10, device=self.device),
        )
        self.check_model(Model().to(self.device), example_inputs)
        self.check_model(Model(), example_inputs)

    def test_output_path(self):
        class Model(torch.nn.Module):
@@ -199,7 +201,7 @@ def forward(self, x, y):
            torch.randn(10, 10, device=self.device),
        )
        with config.patch("aot_inductor.output_path", "tmp_output_"):
            self.check_model(Model().to(self.device), example_inputs)
            self.check_model(Model(), example_inputs)

    @requires_cuda()
    def test_multi_device(self):
@@ -227,7 +229,7 @@ def forward(self, x, y):
            torch.randn(1, 250112, device=self.device),
            torch.randn(1, 512, device=self.device),
        )
        self.check_model(Model().to(self.device), example_inputs)
        self.check_model(Model(), example_inputs)

    def test_with_offset(self):
        class Model(torch.nn.Module):
@@ -781,6 +783,34 @@ def forward(self, x, y):
        self.assertTrue(same(result_cpu, result_cuda_0.cpu()))
        self.assertTrue(same(result_cpu, result_cuda_1.cpu()))

    def test_reuse_kernel(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                a = torch.sin(x)
                b = torch.mm(a, y)
                c = torch.sin(b)
                d = torch.mm(b, c)
                return d

        example_inputs = (
            torch.randn(87, 87, device=self.device),
            torch.randn(87, 87, device=self.device),
        )
        self.check_model(Model(), example_inputs)

        if self.device == "cuda":
            so_path, _ = torch._export.aot_compile(Model(), example_inputs)
            with open(os.path.splitext(so_path)[0] + ".cpp") as cpp:
                src_code = cpp.read()
            FileCheck().check_count(
                "triton_poi_fused_sin_0 = loadKernel(",
                1,
                exactly=True,
            ).run(src_code)


class AOTInductorTestABICompatibleCpu(TestCase):
    device = "cpu"
torch/_inductor/codegen/wrapper.py: 15 additions & 17 deletions
@@ -2009,28 +2009,18 @@ def generate(self):
        self.prefix.writeline("\n")
        return super().generate()

    def generate_load_kernel(self, name, params):
        mangled_name = params.get("mangled_name", None)
        assert mangled_name is not None, "missing mangled_name"
        cubin_path = params.get("cubin_path", None)
        assert os.path.exists(
            cubin_path
        ), f"cubin file should already exist at this moment: {cubin_path}"

        shared_mem = params.get("shared_mem", 0)

    @functools.lru_cache(None)
    def generate_load_kernel_once(
        self, name: str, mangled_name: str, cubin_path: str, shared_mem: int
    ):
        if V.graph.aot_mode:
            self.writeline(f"if (kernels.{name} == nullptr) {{")
            self.writeline(
                f""" kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
                f"""kernels.{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem}, this->cubin_dir_);"""
            )
            self.writeline("}")
        else:
            self.writeline(f"if ({name} == nullptr) {{")
            self.writeline(
                f""" {name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
                f"""{name} = loadKernel("{cubin_path}", "{mangled_name}", {shared_mem});"""
            )
            self.writeline("}")

    def generate_args_decl(self, call_args):
        dynamic_symbols = V.graph.sizevars.free_symbols()
@@ -2092,6 +2082,7 @@ def generate_kernel_call(
        self, name, call_args, grid=None, device_index=None, cuda=True, triton=True
    ):
        if not cuda:
            # Even in CudaWrapperCodeGen, we may see cpp kernels
            return super().generate_kernel_call(
                name, call_args, grid, device_index, cuda, triton
            )
@@ -2100,8 +2091,15 @@
        assert (
            params is not None
        ), f"cuda kernel parameters for {name} should already exist at this moment"
        mangled_name = params.get("mangled_name", None)
        assert mangled_name is not None, "missing mangled_name"
        cubin_path = params.get("cubin_path", None)
        assert cubin_path is not None and os.path.exists(
            cubin_path
        ), f"cubin file should already exist at this moment: {cubin_path}"
        shared_mem = params.get("shared_mem", 0)

        self.generate_load_kernel(name, params)
        self.generate_load_kernel_once(name, mangled_name, cubin_path, shared_mem)

        call_args = self.generate_args_decl(call_args)
        kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}"
torch/_inductor/compile_fx.py: 0 additions & 2 deletions
@@ -962,8 +962,6 @@ def compile_fx(
"cpp_wrapper": False,
"triton.autotune_cublasLt": False,
"triton.cudagraphs": False,
# CudaWrapperCodeGen relies on kernel name to find the autotuned cubin file
"triton.unique_kernel_names": True,
}
), V.set_real_inputs(
example_inputs_
torch/_inductor/triton_heuristics.py: 2 additions & 1 deletion
@@ -424,7 +424,8 @@ def save_cuda_kernel(self, grid, stream, launcher):
        else:
            grid_x, grid_y, grid_z = grid

        key = launcher.fn.fn.__qualname__ # unique kernel name
        key = self.meta.get("kernel_name", None) # unique kernel name
        assert key is not None, "kernel_name can not be None"
        params = {
            "mangled_name": launcher.bin.metadata["name"],
            "grid_x": grid_x,