[AOTI] Improve the two-pass wrapper codegen (#114067)
Summary: For the second pass, we don't have to rerun the whole Inductor flow. This PR moves that second pass to codegen time. This change not only speeds up compilation but also removes the kernel-scheduling inconsistency between the two passes. A future improvement is to have the second pass reuse the scheduler and perform only the wrapper codegen.

This is a copy of #113762, to land on GitHub first.

Pull Request resolved: #114067
Approved by: https://github.com/chenyang78
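
For context, a minimal sketch of how the cpp_wrapper path touched by this PR is typically exercised from user code. The toy module `M`, the tensor shapes, and the assumption of a CUDA device are illustrative only and not part of this commit; `torch._inductor.config.cpp_wrapper` is the existing Inductor flag that selects C++ wrapper codegen.

```python
# Minimal sketch (illustrative, not part of this commit): exercising the
# C++ wrapper path on a CUDA device. The module M and shapes are hypothetical.
import torch
import torch._inductor.config as inductor_config


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


inductor_config.cpp_wrapper = True  # ask Inductor to emit a C++ wrapper

model = M().cuda()
compiled = torch.compile(model)

# With this PR, the GPU path runs the generated Python wrapper once at codegen
# time to autotune and store cubins, then regenerates a C++ wrapper, instead of
# rerunning the whole Inductor flow for a second pass.
out = compiled(torch.randn(8, device="cuda"))
```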
desertfire authored and pytorchmergebot committed Nov 19, 2023
1 parent 226384b commit 5a96a42
Showing 5 changed files with 106 additions and 106 deletions.
41 changes: 41 additions & 0 deletions test/inductor/test_aot_inductor.py
@@ -18,6 +18,7 @@

from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_quantization import skip_if_no_torchvision

from torch.testing._internal.common_utils import (
IS_CI,
@@ -1060,6 +1061,42 @@ def forward(self, x):
example_inputs = (torch.randn(3, 10, device=self.device),)
self.check_model(Model(), example_inputs)

@skip_if_no_torchvision
def test_missing_cubin(self):
from torchvision.models.resnet import Bottleneck, ResNet

class Model(ResNet):
def __init__(self):
super().__init__(
block=Bottleneck,
layers=[3, 4, 6, 3],
replace_stride_with_dilation=[False, False, True],
norm_layer=None,
)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
f1 = x
x = self.maxpool(x)
x = self.layer1(x)
f2 = x
x = self.layer2(x)
f3 = x
x = self.layer3(x)
x = self.layer4(x)
f4 = x
return [f1, f2, f3, f4]

# Call eval() here so that batch_norm won't update the running stats
# Use float64 to avoid numeric difference failure
model = Model().to(device=self.device, dtype=torch.float64).eval()
example_inputs = (
torch.randn(4, 3, 64, 64, device=self.device, dtype=torch.float64),
)
self.check_model(model, example_inputs)

@common_utils.parametrize("grid_type", [1, 2, 3])
@common_utils.parametrize("num_dims", [1, 2])
@common_utils.parametrize("dynamic", [False, True])
@@ -1194,6 +1231,8 @@ class AOTInductorTestABICompatibleCpu(TestCase):
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
# NotImplementedError: Cannot access storage of OpaqueTensorImpl
"test_freezing": TestFailure(("abi_compatible_cpu",), is_skip=True),
# Need to support convolution
"test_missing_cubin": TestFailure(("abi_compatible_cpu",)),
"test_normal_functional": TestFailure(("abi_compatible_cpu",)),
"test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
# There is a double-free issue which will be fixed in another PR
@@ -1219,6 +1258,8 @@ class AOTInductorTestABICompatibleCuda(TestCase):
# test_failures, xfail by default, set is_skip=True to skip
{
"test_dup_unbacked_sym_decl": TestFailure(("abi_compatible_cuda",)),
# Need to support convolution
"test_missing_cubin": TestFailure(("abi_compatible_cuda",)),
"test_normal_functional": TestFailure(("abi_compatible_cuda",)),
# There is a double-free issue which will be fixed in another PR
"test_repeat_output": TestFailure(("abi_compatible_cuda",), is_skip=True),
15 changes: 2 additions & 13 deletions test/inductor/test_pattern_matcher.py
@@ -50,10 +50,6 @@ def common(
if len(codes) == 1:
codes = codes[0]
torch.testing.assert_close(actual, expected)
if inductor_config.cpp_wrapper:
# CPP wrapper runs everything twice, so we'll match the pattern twice
expected_matches *= 2
expected_nodes *= 2

self.assertEqual(
counters["inductor"]["pattern_matcher_count"], expected_matches
@@ -519,13 +515,6 @@ def fn(a, b, c):
self.common(fn, args, 2, 5)

def test_cat_slice_cat(self):
def check_counter(counter, expected):
if not inductor_config.cpp_wrapper:
self.assertEqual(counter, expected)
else:
# cpp_wrapper for the CUDA backend runs two passes
self.assertEqual(counter, 2 * expected)

def fn(a, b):
cat_1 = torch.ops.aten.cat.default([a, b], 1)
slice_1 = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807)
@@ -548,8 +537,8 @@ def fn(a, b):
torch.testing.assert_close(actual, expected)
# We don't recompile for dynamic-shape cases.
if dynamo_config.assume_static_by_default:
check_counter(counters["inductor"]["pattern_matcher_count"], 1)
check_counter(counters["inductor"]["pattern_matcher_nodes"], 3)
self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1)
self.assertEqual(counters["inductor"]["pattern_matcher_nodes"], 3)

# Verify we fallback to non-optimal path for negative `end`.
def fn(a, b):
21 changes: 7 additions & 14 deletions test/inductor/test_select_algorithm.py
@@ -44,13 +44,6 @@ def wrapped(*args, **kwargs):


class TestSelectAlgorithm(TestCase):
def check_counter(self, counter, expected):
if not inductor_config.cpp_wrapper:
self.assertEqual(counter, expected)
else:
# cpp_wrapper for the CUDA backend runs two passes
self.assertEqual(counter, 2 * expected)

@expectedFailureDynamicWrapper
@patches
def test_linear_relu(self):
@@ -64,7 +57,7 @@ def foo(input, weight, bias):
torch.randn(1, 16, device="cuda"),
)
# Autotuning checks correctness of each version
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
# It would be nice to assert this got fused into a single kernel, but that
# only happens if we select a triton template (and not aten).

@@ -82,7 +75,7 @@ def foo(input, weight, bias):
)

foo(*inps)
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
@patches
@@ -112,7 +105,7 @@ def foo(a, b):
torch.randn(8, 32, device="cuda"),
torch.randn(32, 8, device="cuda"),
)
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test__int_mm(self):
@@ -206,7 +199,7 @@ def foo(a, b, c, d):
torch.randn(512, 512, device="cuda"),
)
# Autotuning checks correctness of each version
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test_mm_dup_args(self):
@@ -215,7 +208,7 @@ def foo(a):
return torch.mm(a, a)

foo(torch.randn(32, 32, device="cuda"))
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@patches
def test_mm_dup_args_view(self):
@@ -226,7 +219,7 @@ def foo(a):
return torch.mm(q, k.transpose(0, 1))

foo(torch.randn(64, 64, device="cuda"))
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@skipIfRocm
@expectedFailureDynamicWrapper
@@ -252,7 +245,7 @@ def foo(x, w, b):
torch.randn(34, device="cuda"),
)
# Autotuning checks correctness of each version
self.check_counter(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

@skipIfRocm
@patches
81 changes: 6 additions & 75 deletions torch/_inductor/compile_fx.py
@@ -51,7 +51,6 @@
from .fx_passes.pre_grad import pre_grad_passes
from .graph import GraphLowering
from .ir import ExternKernelNode
from .pattern_matcher import clone_graph
from .utils import get_dtype_size, has_incompatible_cudagraph_ops
from .virtualized import V

@@ -217,79 +216,6 @@ def count_bytes_inner(
return make_boxed_func(gm.forward)


def inner_compile_with_cpp_wrapper(inner_compile: Callable[..., Any]):
@functools.wraps(inner_compile)
def wrapper(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
"""
Compile into cpp wrapper:
For CPU, this is currently done in one pass.
For GPU, this is done in two passes: JIT-compile the model with python wrapper code
and run it to generate autotuned kernel binaries in the first pass; and then generate
cpp wrapper code and compile it to a dynamic library in the second pass.
"""
devices = (
{t.device.type for t in gm.parameters()}
| {t.device.type for t in gm.buffers()}
| {t.device.type for t in example_inputs if isinstance(t, torch.Tensor)}
)

if "cuda" not in devices:
kwargs_patched = {**kwargs, "cpp_wrapper": True}
return inner_compile(gm, example_inputs, **kwargs_patched)
else:
with config.patch(
{
"triton.store_cubin": True,
}
):
# first pass with regular python wrapper code
kwargs_patched = {
**kwargs,
"cpp_wrapper": False,
}
# clone_graph(gm) makes sure no graph modification from the first pass will
# leak to the second pass. It does increase memory pressure, but the problem
# can be alleviated once we have parameters as FakeTensor.

compiled = inner_compile(
clone_graph(gm), example_inputs, **kwargs_patched
)

def materialize(x):
if isinstance(x, (torch.SymInt, torch.SymFloat)):
# Need concrete value to run dynamic shapes and tune the result
return x.node.hint
else:
assert not isinstance(x, FakeTensor)
return x

if tracing_context := torch._guards.TracingContext.try_get():
if tracing_context.output_strides:
tracing_context.output_strides.clear()

params_flat = [
param
for param in tracing_context.params_flat # type: ignore[union-attr]
if param is not None
]
real_inputs = [
materialize(x) for x in (params_flat + V.real_inputs)
]
else:
real_inputs = [materialize(x) for x in V.real_inputs]

with torch.utils._python_dispatch._disable_current_modes():
compiled(real_inputs)

del real_inputs

# second pass
kwargs_patched = {**kwargs, "cpp_wrapper": True}
return inner_compile(gm, example_inputs, **kwargs_patched)

return wrapper


def fake_tensor_prop(
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
@@ -592,6 +518,10 @@ def fx_codegen_and_compile(
with V.set_fake_mode(fake_mode):
graph = GraphLowering(
gm,
# example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
# For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
# we currently use fake tensors and defake them later.
example_inputs=V.real_inputs if is_inference else example_inputs,
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
@@ -1033,6 +963,7 @@ def compile_fx(
"cpp_wrapper": False,
"triton.autotune_cublasLt": False,
"triton.cudagraphs": False,
"triton.store_cubin": True,
}
), V.set_real_inputs(example_inputs_):
inputs_ = example_inputs_
@@ -1055,7 +986,7 @@
return compile_fx(
model_,
inputs_,
inner_compile=inner_compile_with_cpp_wrapper(inner_compile),
inner_compile=functools.partial(inner_compile, cpp_wrapper=True),
decompositions=decompositions,
)

54 changes: 50 additions & 4 deletions torch/_inductor/graph.py
@@ -15,8 +15,9 @@
import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import dynamo_timed
from torch._dynamo.utils import defake, dynamo_timed
from torch._logging import LazyString
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
from torch.utils._mode_utils import no_dispatch
@@ -164,6 +165,7 @@ def init_backend_registration(self):
def __init__(
self,
gm: torch.fx.GraphModule,
example_inputs: Optional[List[torch.Tensor]] = None,
shape_env=None,
num_static_inputs=None,
graph_id=None,
@@ -176,6 +178,7 @@ def __init__(
):
super().__init__(gm)

self.example_inputs = example_inputs
self.layout_opt = (
layout_opt if layout_opt is not None else self.decide_layout_opt(gm)
)
@@ -921,6 +924,46 @@ def init_wrapper_code(self):
assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
self.wrapper_code = wrapper_code_gen_cls()

def codegen_with_cpp_wrapper(self):
"""
For CPU, the cpp wrapper codegen is done in one pass.
For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
generate cpp wrapper code and compile it to a dynamic library in the second pass.
"""
if "cuda" in self.device_types:
# first pass
self.cpp_wrapper = False
compiled = self.compile_to_module().call

def materialize(x):
if isinstance(x, (torch.SymInt, torch.SymFloat)):
# Need concrete value to run dynamic shapes and tune the result
return x.node.hint
elif isinstance(x, FakeTensor):
return defake(x)
else:
assert isinstance(
x, torch.Tensor
), "Unknown type when creating real inputs"
return x

with torch.utils._python_dispatch._disable_current_modes():
assert self.example_inputs is not None
real_inputs = [materialize(x) for x in self.example_inputs]
compiled(real_inputs)
del real_inputs

# second pass
# TODO: reuse self.scheduler from the first pass to speed up the second pass
self.cpp_wrapper = True
self.removed_buffers.clear()
self.inplaced_to_remove.clear()
return self.codegen()
else:
# cpu
return self.codegen()

def codegen(self):
from .scheduler import Scheduler

@@ -952,7 +995,9 @@ def count_bytes(self):
def compile_to_module(self):
from .codecache import PyCodeCache

code, linemap = self.codegen()
code, linemap = (
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
)
linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
key, path = PyCodeCache.write(code)
mod = PyCodeCache.load_by_key_path(
@@ -975,10 +1020,11 @@ def compile_to_module(self):
return mod

def compile_to_fn(self):
if self.aot_mode and self.cpp_wrapper:
if self.aot_mode:
from .codecache import AotCodeCache

code, linemap = self.codegen()
assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
code, linemap = self.codegen_with_cpp_wrapper()
output_code_log.debug("Output code: \n%s", code)

serialized_extern_kernel_nodes = None
