[inductor] fix cpp_wrapper inputs mismatch (#116197)
Summary: Fixes #115035. In cpp_wrapper JIT Inductor, the real input args used to run the generated code must also contain the lifted model parameters, not just the user inputs.

Pull Request resolved: #116197
Approved by: https://github.com/jansel
desertfire authored and pytorchmergebot committed Dec 26, 2023
1 parent 7571511 commit e5bcfe2
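
For context, a minimal repro sketch of the failure class (hypothetical, not the exact reproducer from #115035; assumes a CUDA build): with cpp_wrapper on, module parameters are lifted to graph inputs, so the generated wrapper expects the lifted parameters ahead of the user inputs when it is dry-run.

```python
# Hypothetical repro sketch for the class of failure fixed here: a module
# that owns parameters (LayerNorm weight/bias), compiled with the inductor
# cpp_wrapper option. Before this fix, the dry-run of the generated C++
# wrapper could be fed only the user inputs, mismatching the lifted-param
# calling convention.
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.LayerNorm(32)  # owns weight and bias parameters

    def forward(self, x):
        return self.norm(x)

m = M().cuda().eval()
compiled = torch.compile(m, options={"cpp_wrapper": True})
with torch.no_grad():
    out = compiled(torch.randn(16, 32, device="cuda"))
```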
Showing 4 changed files with 46 additions and 27 deletions.
1 change: 1 addition & 0 deletions test/inductor/test_cuda_cpp_wrapper.py
@@ -165,6 +165,7 @@ class BaseTest(NamedTuple):
BaseTest("test_index_put_deterministic_fallback"),
BaseTest("test_adding_tensor_offsets"),
BaseTest("test_index_tensor"),
BaseTest("test_layer_norm"),
BaseTest("test_linear1"),
BaseTest("test_linear2"),
BaseTest("test_mm_views"),
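How entries in this list take effect (a hedged sketch; the real harness in test_cuda_cpp_wrapper.py differs in its details): each BaseTest names a test from test_torchinductor that gets re-registered to run with config.cpp_wrapper enabled, so the new entry adds cpp_wrapper coverage for the LayerNorm case from #115035.

```python
# Illustrative-only sketch of dynamic test registration of the kind the
# harness uses; names and mechanics here are assumptions, not the real API.
import unittest

class CudaWrapperTemplate(unittest.TestCase):
    pass

def make_cpp_wrapper_test(name: str):
    def test(self):
        # The real harness would run test_torchinductor's `name` with
        # torch._inductor.config.cpp_wrapper patched to True.
        self.assertTrue(name.startswith("test_"))
    test.__name__ = f"{name}_cuda_wrapper"
    setattr(CudaWrapperTemplate, test.__name__, test)

make_cpp_wrapper_test("test_layer_norm")
```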
53 changes: 28 additions & 25 deletions test/inductor/test_torchinductor.py
@@ -557,6 +557,12 @@ def _run_and_assert_no_indirect_indexing(test_case, func, *args, **kwargs):
return result


+def assertGeneratedKernelCountEqual(self: TestCase, expected: int):
+    if config.cpp_wrapper:
+        expected *= 2
+    self.assertEqual(torch._inductor.metrics.generated_kernel_count, expected)


class SweepInputs2:
input_gen_types1 = [
"dense",
@@ -804,7 +810,7 @@ def fn(sa, ct, p):
torch.randn(26),
),
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

def test_forced_buffer_realize(self):
# Test torch._test_inductor_realize forces a buffer to be realized
@@ -839,10 +845,7 @@ def fn(sa, ct, p):
),
)
self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 5)
-self.assertEqual(
-    torch._inductor.metrics.generated_kernel_count,
-    1 if self.device == "cuda" else 3,
-)
+assertGeneratedKernelCountEqual(self, 1 if self.device == "cuda" else 3)

def test_index_propagation(self):
def flip(x):
@@ -2902,7 +2905,7 @@ def fn(x):
(torch.randn(2, 4, 21, 21),),
check_lowp=False,
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+assertGeneratedKernelCountEqual(self, 0)

def test_multi_threading(self):
model = torch.nn.Linear(2, 3).eval()
@@ -3159,7 +3162,7 @@ def fn(x):
fn,
(torch.randn([16, 64, 55, 55]),),
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+assertGeneratedKernelCountEqual(self, 0)

# From https://github.com/pytorch/pytorch/issues/94775
def test_max_pool2d7(self):
@@ -3185,7 +3188,7 @@ def fn(x):
fn,
(torch.randn([2, 2, 3, 6]),),
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+assertGeneratedKernelCountEqual(self, 0)

def test_avg_pool2d1(self):
def fn(x):
@@ -3265,7 +3268,7 @@ def fn(x):
fn,
(-torch.arange(1 * 24 * 24, dtype=torch.float32).view(1, 1, 24, 24),),
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+assertGeneratedKernelCountEqual(self, 0)

def test_avg_pool2d8(self):
# https://github.com/pytorch/pytorch/issues/100987
@@ -3569,7 +3572,7 @@ def test_layer_norm(self):
with torch.no_grad():
self.common(m, (torch.randn([16, 32]),), check_lowp=False)
if self.device != "cpu":
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

def test_transpose_add(self):
def fn(a, b):
@@ -3579,7 +3582,7 @@ def fn(a, b):
fn, (torch.randn([16, 32]), torch.randn([32, 16])), check_lowp=False
)
if self.device != "cpu":
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

@patch.object(config.triton, "persistent_reductions", True)
def test_softmax_one_kernel_persist(self):
@@ -3592,7 +3595,7 @@ def fn(x):

self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
if self.device != "cpu":
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

@patch.object(config.triton, "persistent_reductions", False)
def test_softmax_one_kernel_loop(self):
@@ -3604,7 +3607,7 @@ def fn(x):

self.common(fn, (torch.randn([16, 32]),), check_lowp=False)
if self.device != "cpu":
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

def test_complex_fallback(self):
def fn(x):
@@ -3614,7 +3617,7 @@ def fn(x):
fn,
(torch.randn([1, 2, 4, 8]).to(dtype=torch.complex64),),
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+assertGeneratedKernelCountEqual(self, 0)

class ToComplex(nn.Module):
def forward(self, x):
@@ -3623,7 +3626,7 @@ def forward(self, x):
self.common(ToComplex(), (torch.rand([1, 2, 4, 8]),), check_lowp=False)

if self.device != "cpu":
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

def test_view_as_complex(self):
class Repro(torch.nn.Module):
@@ -3675,7 +3678,7 @@ def fn(x, y):
check_lowp=False,
)
if self.device != "cpu":
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

def test_gather_scatter(self):
def fn(node_feat, edge_index):
@@ -3701,7 +3704,7 @@ def fn(node_feat, edge_index):
check_lowp=False,
)
if self.device != "cpu":
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
+assertGeneratedKernelCountEqual(self, 2)

@config.patch(max_fusion_size=1)
def test_no_mega_fusion_during_lowering(self):
@@ -3728,7 +3731,7 @@ def fn(x):

self.common(fn, (torch.randn([32]),), check_lowp=False)
# if we have a copy there will be more than 1 kernel
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

def test_leaky_relu(self):
def fn(x):
@@ -6230,7 +6233,7 @@ def fn(a, b, c):
indices,
],
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

def test_max_pool2d_with_indices_backward5(self):
# Window size is too big. Should fallback
@@ -6257,7 +6260,7 @@ def fn(a, b, c):
indices,
],
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+assertGeneratedKernelCountEqual(self, 0)

# From https://github.com/pytorch/pytorch/issues/93384
def test_max_pool2d_with_indices_backward6(self):
@@ -6285,7 +6288,7 @@ def fn(a, b, c):
indices,
],
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+assertGeneratedKernelCountEqual(self, 0)

def test_issue102546(self):
def fn(x):
@@ -6356,7 +6359,7 @@ def fn(a, b):
torch.randn([1, 2016, 21, 21]),
],
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

def test_avg_pool2d_backward4(self):
def fn(a, b):
@@ -6380,7 +6383,7 @@ def fn(a, b):
],
check_lowp=False,
)
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+assertGeneratedKernelCountEqual(self, 0)

@config.patch(search_autotune_cache=False)
def test_mm_views(self):
@@ -7316,7 +7319,7 @@ def forward(arg6, arg7, arg16):
)

# expanded dim should not cause copy in require_stride_order
-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0)
+assertGeneratedKernelCountEqual(self, 0)

@requires_cuda()
@unittest.skipIf(
@@ -7913,7 +7916,7 @@ def fn(input, offsets, add_value):

self.common(fn, (input, boundaries, add_value), check_lowp=False)

-self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
+assertGeneratedKernelCountEqual(self, 1)

def test_bucketize_computed_offsets(self):
def fn(inp, offsets):
2 changes: 1 addition & 1 deletion torch/_inductor/compile_fx.py
@@ -518,7 +518,7 @@ def fx_codegen_and_compile(
# example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning.
# For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass,
# we currently use fake tensors and defake them later.
-example_inputs=V.real_inputs if is_inference else example_inputs,
+example_inputs=example_inputs,
shape_env=shape_env,
num_static_inputs=num_fixed,
graph_id=graph_id,
17 changes: 16 additions & 1 deletion torch/_inductor/graph.py
@@ -1,3 +1,4 @@
+import itertools
import logging
import operator
import os
@@ -1038,9 +1039,23 @@ def materialize(x):
), "Unknown type when creating real inputs" + str(type(x))
return x

+if tracing_context := torch._guards.TracingContext.try_get():
+    if tracing_context.output_strides:
+        tracing_context.output_strides.clear()
+
+    params_flat = [
+        param
+        for param in tracing_context.params_flat  # type: ignore[union-attr]
+        if param is not None
+    ]
+    real_inputs = [
+        materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
+    ]
+else:
+    real_inputs = [materialize(x) for x in V.real_inputs]

with torch.utils._python_dispatch._disable_current_modes():
assert self.example_inputs is not None
-real_inputs = [materialize(x) for x in self.example_inputs]
compiled(real_inputs)
del real_inputs

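The ordering contract the new block implements, as a self-contained sketch (names and the materialize stand-in here are illustrative, not PyTorch internals): when parameters are lifted into graph inputs, the dry-run argument list is the flattened parameters followed by the user inputs.

```python
# Self-contained illustration of the dry-run input assembly above; the
# tensors and materialize() are stand-ins for the real GraphLowering state.
import itertools
import torch

weight, bias = torch.randn(32), torch.randn(32)  # lifted parameters (params_flat)
user_input = torch.randn(16, 32)                 # what V.real_inputs would hold

def materialize(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for graph.py's materialize(): convert fake/functional
    # tensors to real ones; these example tensors are already real.
    return t

params_flat = [p for p in (weight, bias) if p is not None]
real_inputs = [materialize(x) for x in itertools.chain(params_flat, [user_input])]
# The generated wrapper is then dry-run as compiled(real_inputs), with the
# lifted parameters preceding the user inputs, exactly the convention whose
# violation produced the "inputs mismatch" in #115035.
print([tuple(t.shape) for t in real_inputs])
```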
