Revert "[aoti] clear precomputed symbol replacements before cpp wrapp…
Browse files Browse the repository at this point in the history
…er compilation (#122882)"

This reverts commit 384de46.

Reverted #122882 on behalf of https://github.com/jithunnair-amd because it broke ROCm CI (see comment on #122882)
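(Context: #122882 had added two calls in torch/_inductor/graph.py that clear V.graph.sizevars.precomputed_replacements and V.graph.sizevars.inv_precomputed_replacements before the second, cpp-wrapper codegen pass, together with the test_reuse_kernel_dynamic regression test; the diff below removes both.)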
pytorchmergebot committed Mar 29, 2024
1 parent 2a137f7 commit a236fa9
Showing 2 changed files with 0 additions and 67 deletions.
65 changes: 0 additions & 65 deletions test/inductor/test_aot_inductor.py
@@ -1137,70 +1137,6 @@ def forward(self, x, y):
             exactly=True,
         ).run(src_code)
-
-    def test_reuse_kernel_dynamic(self):
-        class Model(torch.nn.Module):
-            def __init__(self, device):
-                super().__init__()
-                self.cst = torch.randn(48, device=device, dtype=torch.float)
-                self.weights = torch.randn(6, 48, 48, device=device, dtype=torch.float)
-                self.cst_1 = torch.randn(48, device=device, dtype=torch.float)
-                self.weights_1 = torch.randn(
-                    6, 48, 48, device=device, dtype=torch.float
-                )
-
-            def forward(self, x, y, z):
-                dim0 = x.size(1)
-                add_0 = z + z
-                expand_2 = add_0.expand(-1, -1, 48)
-                # [s0, 6, 48]
-                mul_3 = add_0 * expand_2
-                # [6, s0, 48]
-                permute_4 = torch.permute(mul_3, (1, 0, 2))
-                # [6, s0, 48]
-                bmm_5 = torch.bmm(permute_4, self.weights)
-                add_6 = bmm_5 + self.cst
-                reshape_7 = torch.reshape(add_6, [6, dim0 * 6, 8])
-                # [6*s0, 6, 8]
-                permute_8 = torch.permute(reshape_7, (1, 0, 2))
-                mul_9 = permute_8 * 0.123
-                reshape_10 = torch.reshape(y, [8, dim0 * 6, 4])
-                # [6*s0, 8, 4]
-                permute_11 = torch.permute(reshape_10, (1, 0, 2))
-                bmm_12 = torch.bmm(mul_9, permute_11)
-
-                add_0_1 = z + z
-                expand_2_1 = add_0_1.expand(-1, -1, 48)
-                # [s0, 6, 48]
-                mul_3_1 = add_0_1 * expand_2_1
-                # [6, s0, 48]
-                permute_4_1 = torch.permute(mul_3_1, (1, 0, 2))
-                # [6, s0, 48]
-                bmm_5_1 = torch.bmm(permute_4_1, self.weights_1)
-                add_6_1 = bmm_5_1 + self.cst_1
-                reshape_7_1 = torch.reshape(add_6_1, [6, dim0 * 6, 8])
-                # [6*s0, 6, 8]
-                permute_8_1 = torch.permute(reshape_7_1, (1, 0, 2))
-                mul_9_1 = permute_8_1 * 0.123
-                reshape_10_1 = torch.reshape(y, [8, dim0 * 6, 4])
-                # [6*s0, 8, 4]
-                permute_11_1 = torch.permute(reshape_10_1, (1, 0, 2))
-                bmm_12_1 = torch.bmm(mul_9_1, permute_11_1)
-                return bmm_12 + bmm_12_1
-
-        x = torch.randn(6, 2, 48, device=self.device, dtype=torch.float)
-        y = torch.randn(48, 2, 4, device=self.device, dtype=torch.float)
-        z = torch.randn(2, 6, 1, device=self.device, dtype=torch.float)
-        dim0 = Dim("dim0", min=1, max=2048)
-        dynamic_shapes = {
-            "x": {1: dim0},
-            "y": {1: dim0},
-            "z": {0: dim0},
-        }
-
-        example_inputs = (x, y, z)
-        m = Model(self.device).to(dtype=torch.float)
-        self.check_model(m, example_inputs, dynamic_shapes=dynamic_shapes)
 
     def test_fake_tensor_device_validation(self):
         if self.device != "cuda":
             raise unittest.SkipTest("requires CUDA")
@@ -2463,7 +2399,6 @@ def fail_non_abi_compatible_cuda(is_skip=False):
"test_repeat_interleave": fail_minimal_arrayref_interface(is_skip=True),
"test_return_constant": fail_minimal_arrayref_interface(is_skip=True),
"test_reuse_kernel": fail_minimal_arrayref_interface(is_skip=True),
"test_reuse_kernel_dynamic": fail_minimal_arrayref_interface(is_skip=True),
"test_simple": fail_minimal_arrayref_interface(is_skip=True),
"test_small_constant": fail_minimal_arrayref_interface(is_skip=True),
"test_with_no_triton_profiler": fail_minimal_arrayref_interface(
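For readers unfamiliar with the harness: check_model is this suite's helper, which AOT-compiles the model and compares its outputs against eager mode. A minimal standalone sketch of the same dynamic-shape flow follows; the Toy model is a hypothetical stand-in for the removed Model, and the torch._export.aot_compile / aot_load entry points are assumed to behave as they did around the time of this commit.

import torch
from torch.export import Dim

class Toy(torch.nn.Module):  # hypothetical stand-in for the removed Model
    def forward(self, x):
        return torch.relu(x) * 0.123

x = torch.randn(8, 16, device="cuda")
dim0 = Dim("dim0", min=1, max=2048)  # symbolic dimension, like the test's dim0
dynamic_shapes = {"x": {0: dim0}}    # dim 0 of x is dynamic

# Compile ahead of time to a shared library; dim0 stays symbolic (s0)
# in the generated wrapper code.
so_path = torch._export.aot_compile(Toy().cuda(), (x,), dynamic_shapes=dynamic_shapes)

# Load the compiled artifact and run it with a batch size different
# from the example input's.
compiled = torch._export.aot_load(so_path, device="cuda")
out = compiled(torch.randn(32, 16, device="cuda"))

Judging by its name and its two structurally identical subgraphs, the removed test appears to check that Inductor still reuses a generated kernel when the shapes involved contain the dynamic symbol s0.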
2 changes: 0 additions & 2 deletions torch/_inductor/graph.py
@@ -1273,8 +1273,6 @@ def materialize(x):
             self.cpp_wrapper = True
             self.removed_buffers.clear()
             self.inplaced_to_remove.clear()
-            V.graph.sizevars.precomputed_replacements.clear()
-            V.graph.sizevars.inv_precomputed_replacements.clear()
             return self.codegen()
         else:
             # cpu
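For orientation, the removed lines sat in the GPU branch of GraphLowering's cpp-wrapper codegen, which generates code twice: a Python wrapper first, then a C++ wrapper. A simplified sketch of that flow, with method and attribute names taken from the diff but the surrounding control flow paraphrased rather than quoted from the source:

# Paraphrased sketch of torch/_inductor/graph.py around the change
# (names from the diff; first-pass details abridged).
def codegen_with_cpp_wrapper(self):
    if "cuda" in self.device_types:
        # First pass: emit a Python wrapper, compile and run it so
        # Triton kernels get autotuned with real inputs.
        ...  # first-pass codegen and warm-up elided
        # Second pass: regenerate everything as a C++ wrapper.
        self.cpp_wrapper = True
        self.removed_buffers.clear()
        self.inplaced_to_remove.clear()
        # The now-reverted #122882 additionally cleared the cached
        # symbol replacements here, forcing the second pass to
        # recompute them:
        #   V.graph.sizevars.precomputed_replacements.clear()
        #   V.graph.sizevars.inv_precomputed_replacements.clear()
        return self.codegen()
    else:
        # cpu: a single cpp-wrapper pass suffices
        self.cpp_wrapper = True
        return self.codegen()

As I understand it, precomputed_replacements caches substitutions of complex symbolic size expressions by fresh precomputed symbols; the reverted patch's intent was presumably that entries cached during the Python-wrapper pass could be stale for the C++ pass.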
