[aoti] clear precomputed symbol replacements before cpp wrapper compilation (pytorch#122882)

After we codegen a triton kernel in the triton codegen backend,
we cache the generated triton source code in the wrapper to avoid
producing multiple triton kernels with the same content.
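
Conceptually, the cache is a map from generated source code to a
kernel name. Below is a minimal sketch, assuming a plain dict and an
illustrative naming scheme (not Inductor's actual wrapper API):

    # Sketch: wrapper-level kernel deduplication. Identical Triton
    # source is defined once; later requests reuse the cached name.
    src_to_kernel: dict[str, str] = {}

    def define_kernel(src_code: str) -> str:
        name = src_to_kernel.get(src_code)
        if name is None:
            name = f"triton_fused_{len(src_to_kernel)}"  # hypothetical naming
            src_to_kernel[src_code] = name
            # ... emit the kernel definition into the wrapper exactly once ...
        return name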

In the AOTI compilation flow, this caching mechanism imposes a strong
requirement on the codegen: we must generate the same triton source
code for the same schedule node in both the python and cpp codegen
phases. Otherwise, we end up with a mismatch between the kernel name
formed in the cpp codegen and the cuda kernel key produced from the
python codegen, and consequently hit a missing-cuda-kernel error.

The precomputed symbol replacements saved in V.graph.sizevars can
cause exactly this kind of source-code inconsistency in the code that
indexes tensors. For example, suppose that in the python codegen
phase we produce "ks2*48" as part of indexing an input for schedule
node A, while recording the replacement pair "ks0 -> ks2*48" in the
precomputed replacements. In the second, cpp codegen phase, we would
then produce "ks0" for the same indexing code of schedule node A
because of that "ks0 -> ks2*48" replacement pair.
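
The mismatch can be reproduced in miniature with sympy (an
illustrative sketch of the two-pass behavior, not the actual
SizeVarAllocator implementation):

    import sympy

    ks0, ks2 = sympy.symbols("ks0 ks2", integer=True, positive=True)

    # Replacement pairs accumulated during the first (python) pass.
    inv_precomputed_replacements: dict[sympy.Expr, sympy.Symbol] = {}

    def codegen_index(expr: sympy.Expr) -> str:
        # Print an indexing expression, folding any recorded
        # precomputed sizes back into their symbols.
        for known, sym in inv_precomputed_replacements.items():
            expr = expr.subs(known, sym)
        return str(expr)

    # Pass 1 (python wrapper): nothing is recorded yet, so we emit
    # "48*ks2" and only afterwards save the pair ks0 -> ks2*48.
    print(codegen_index(48 * ks2))  # 48*ks2
    inv_precomputed_replacements[48 * ks2] = ks0

    # Pass 2 (cpp wrapper): the stale pair rewrites the same index to
    # "ks0", so the regenerated source no longer matches the cache.
    print(codegen_index(48 * ks2))  # ks0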

This PR fixes the issue by clearing precomputed_replacements and
inv_precomputed_replacements before the cpp wrapper codegen.
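
In terms of the sketch above, the fix amounts to resetting the saved
pairs between the two passes:

    # Clearing the stale pairs before the second pass makes both
    # passes print the same indexing expression again.
    inv_precomputed_replacements.clear()
    print(codegen_index(48 * ks2))  # 48*ks2, matching the first pass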

Pull Request resolved: pytorch#122882
Approved by: https://github.com/desertfire
chenyang78 authored and Sanket Jayant Purandare committed Apr 22, 2024
1 parent ff24f30 commit fb964ba
Showing 2 changed files with 67 additions and 0 deletions.
65 changes: 65 additions & 0 deletions test/inductor/test_aot_inductor.py
@@ -1137,6 +1137,70 @@ def forward(self, x, y):
                exactly=True,
            ).run(src_code)

    def test_reuse_kernel_dynamic(self):
        class Model(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.cst = torch.randn(48, device=device, dtype=torch.float)
                self.weights = torch.randn(6, 48, 48, device=device, dtype=torch.float)
                self.cst_1 = torch.randn(48, device=device, dtype=torch.float)
                self.weights_1 = torch.randn(
                    6, 48, 48, device=device, dtype=torch.float
                )

            def forward(self, x, y, z):
                dim0 = x.size(1)
                add_0 = z + z
                expand_2 = add_0.expand(-1, -1, 48)
                # [s0, 6, 48]
                mul_3 = add_0 * expand_2
                # [6, s0, 48]
                permute_4 = torch.permute(mul_3, (1, 0, 2))
                # [6, s0, 48]
                bmm_5 = torch.bmm(permute_4, self.weights)
                add_6 = bmm_5 + self.cst
                reshape_7 = torch.reshape(add_6, [6, dim0 * 6, 8])
                # [6*s0, 6, 8]
                permute_8 = torch.permute(reshape_7, (1, 0, 2))
                mul_9 = permute_8 * 0.123
                reshape_10 = torch.reshape(y, [8, dim0 * 6, 4])
                # [6*s0, 8, 4]
                permute_11 = torch.permute(reshape_10, (1, 0, 2))
                bmm_12 = torch.bmm(mul_9, permute_11)

                add_0_1 = z + z
                expand_2_1 = add_0_1.expand(-1, -1, 48)
                # [s0, 6, 48]
                mul_3_1 = add_0_1 * expand_2_1
                # [6, s0, 48]
                permute_4_1 = torch.permute(mul_3_1, (1, 0, 2))
                # [6, s0, 48]
                bmm_5_1 = torch.bmm(permute_4_1, self.weights_1)
                add_6_1 = bmm_5_1 + self.cst_1
                reshape_7_1 = torch.reshape(add_6_1, [6, dim0 * 6, 8])
                # [6*s0, 6, 8]
                permute_8_1 = torch.permute(reshape_7_1, (1, 0, 2))
                mul_9_1 = permute_8_1 * 0.123
                reshape_10_1 = torch.reshape(y, [8, dim0 * 6, 4])
                # [6*s0, 8, 4]
                permute_11_1 = torch.permute(reshape_10_1, (1, 0, 2))
                bmm_12_1 = torch.bmm(mul_9_1, permute_11_1)
                return bmm_12 + bmm_12_1

        x = torch.randn(6, 2, 48, device=self.device, dtype=torch.float)
        y = torch.randn(48, 2, 4, device=self.device, dtype=torch.float)
        z = torch.randn(2, 6, 1, device=self.device, dtype=torch.float)
        dim0 = Dim("dim0", min=1, max=2048)
        dynamic_shapes = {
            "x": {1: dim0},
            "y": {1: dim0},
            "z": {0: dim0},
        }

        example_inputs = (x, y, z)
        m = Model(self.device).to(dtype=torch.float)
        self.check_model(m, example_inputs, dynamic_shapes=dynamic_shapes)

    def test_fake_tensor_device_validation(self):
        if self.device != "cuda":
            raise unittest.SkipTest("requires CUDA")
@@ -2399,6 +2463,7 @@ def fail_non_abi_compatible_cuda(is_skip=False):
"test_repeat_interleave": fail_minimal_arrayref_interface(is_skip=True),
"test_return_constant": fail_minimal_arrayref_interface(is_skip=True),
"test_reuse_kernel": fail_minimal_arrayref_interface(is_skip=True),
"test_reuse_kernel_dynamic": fail_minimal_arrayref_interface(is_skip=True),
"test_simple": fail_minimal_arrayref_interface(is_skip=True),
"test_small_constant": fail_minimal_arrayref_interface(is_skip=True),
"test_with_no_triton_profiler": fail_minimal_arrayref_interface(
2 changes: 2 additions & 0 deletions torch/_inductor/graph.py
@@ -1273,6 +1273,8 @@ def materialize(x):
            self.cpp_wrapper = True
            self.removed_buffers.clear()
            self.inplaced_to_remove.clear()
            V.graph.sizevars.precomputed_replacements.clear()
            V.graph.sizevars.inv_precomputed_replacements.clear()
            return self.codegen()
        else:
            # cpu
