[inductor] allow clone cse cache during vectorized indirect load
ghstack-source-id: 2f026a9472e118acc783edc57aefbc9db4e4d582
Pull Request resolved: #124597
zhuhaozhe committed Apr 25, 2024
1 parent 81740fd commit 8edb728
Showing 2 changed files with 37 additions and 2 deletions.
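Background for the change: Inductor's C++ codegen keeps a common-subexpression-elimination (CSE) cache mapping already-emitted expressions to the temporary variables that hold them. The title refers to letting that cache carry over (via a clone) while generating the code for a vectorized indirect load, so values cached before the nested scope can be reused instead of re-emitted. The snippet below is a minimal toy model of that idea only; the class and method names are illustrative and assume nothing about Inductor's real CSE/CppCSEVariable API beyond cache-and-reuse behavior.

# Toy model of a CSE cache with clone-and-reuse semantics.
# Names here are illustrative only, not Inductor's actual API.
class ToyCSECache:
    def __init__(self, cache=None):
        self.cache = dict(cache or {})  # generated expression -> temp variable name
        self.counter = len(self.cache)

    def clone(self):
        # Keep existing entries visible to code generated in a nested scope.
        return ToyCSECache(self.cache)

    def generate(self, expr, lines):
        if expr in self.cache:  # cache hit: reuse the variable, emit nothing
            return self.cache[expr]
        var = f"tmp{self.counter}"
        self.counter += 1
        self.cache[expr] = var
        lines.append(f"auto {var} = {expr};")
        return var


outer_lines, nested_lines = [], []
outer = ToyCSECache()
outer.generate("at::vec::Vectorized<float>::loadu(ptr)", outer_lines)

# A fresh cache re-emits the load inside the nested scope ...
ToyCSECache().generate("at::vec::Vectorized<float>::loadu(ptr)", nested_lines)
assert len(nested_lines) == 1

# ... while a clone of the outer cache reuses the variable already holding it.
reused = outer.clone().generate("at::vec::Vectorized<float>::loadu(ptr)", nested_lines)
assert reused == "tmp0" and len(nested_lines) == 1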
35 changes: 35 additions & 0 deletions test/inductor/test_cpu_repro.py
@@ -2816,6 +2816,7 @@ def fn(x):
         x = torch.randn(1, 32, 16, 68)
         opt_fn = torch._dynamo.optimize("inductor")(fn)
         _, code = run_and_get_cpp_code(opt_fn, x)
+        print(code)
         self.assertTrue(same(fn(x), opt_fn(x)))
         # def and use
         FileCheck().check_count("cpp_fused", 2, exactly=True).run(code)
@@ -3650,6 +3651,40 @@ def forward(self, x):
         x = torch.randn(1, 4, 2, 2)
         self.common(fn, (x,))

+    @requires_vectorization
+    def test_vec_indirect_load_cse_cache(self):
+        # https://github.com/pytorch/pytorch/issues/123502
+        from math import inf
+
+        def fn(arg0_1):
+            full_default = torch.ops.aten.full.default([209985], 1)
+            select = torch.ops.aten.select.int(arg0_1, 0, 0)
+            select_1 = torch.ops.aten.select.int(arg0_1, 0, 1)
+            view = torch.ops.aten.reshape.default(select_1, [-1])
+            expand = torch.ops.aten.expand.default(view, [209985])
+            full_default_1 = torch.ops.aten.full.default([10000], 0)
+            scatter_add = torch.ops.aten.scatter_add.default(
+                full_default_1, 0, expand, full_default
+            )
+            pow_1 = torch.ops.aten.pow.Tensor_Scalar(scatter_add, -0.5)
+            eq = torch.ops.aten.eq.Scalar(pow_1, inf)
+            full_default_2 = torch.ops.aten.full.default([], 0.0)
+            where = torch.ops.aten.where.self(eq, full_default_2, pow_1)
+            index = torch.ops.aten.index.Tensor(where, [select])
+            index_1 = torch.ops.aten.index.Tensor(where, [select_1])
+            mul_1 = torch.ops.aten.mul.Tensor(index, index_1)
+            return (mul_1,)
+
+        x = torch.zeros(2, 209985).to(torch.int64)
+        opt_fn = torch._dynamo.optimize("inductor")(fn)
+        _, code = run_and_get_cpp_code(opt_fn, x)
+        print(code)
+        FileCheck().check_count(
+            "return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);",
+            2,
+            exactly=True,
+        ).run(code)
+

 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
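For orientation only: the aten graph in `fn` above is a decomposed form of a common edge-normalization pattern, where one buffer (`where`) is gathered with both rows of an edge-index tensor. The eager-mode rewrite below is an interpretation of what the graph computes, not part of the test; the function name and the eager ops chosen are assumptions.

import torch

def fn_eager(edge_index: torch.Tensor) -> torch.Tensor:
    # edge_index: [2, E] int64 tensor, as in the test's torch.zeros(2, 209985) input.
    src, dst = edge_index[0], edge_index[1]
    deg = torch.zeros(10000).scatter_add(0, dst, torch.ones(edge_index.shape[1]))
    inv_sqrt = deg.pow(-0.5)  # zero-degree nodes become inf
    inv_sqrt = torch.where(inv_sqrt == float("inf"), torch.zeros(()), inv_sqrt)
    # Two gathers from the same buffer: the pattern whose vectorized indirect
    # load the FileCheck above is counting.
    return inv_sqrt[src] * inv_sqrt[dst]

With the test's all-zero input, every edge lands on node 0, so deg[0] == 209985.0 and every other entry takes the inf-to-zero branch.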
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/cpp.py
@@ -2325,7 +2325,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
             assert vec_var.is_vec
             code = BracesBuffer()
             code.writeline("[&]")
-            with self.swap_buffers(code), code.indent():
+            with code.indent():
                 vec_dtype = vec_var.dtype
                 assert vec_dtype is not None
                 if vec_dtype == torch.bool:
@@ -2346,7 +2346,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
         assert opt_ctx is not None
         code = BracesBuffer()
         code.writeline("[&]")
-        with self.swap_buffers(code), code.indent():
+        with code.indent():
             result_size = get_result_size(dtype)
             result_declare = (
                 f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {result_size}> tmpbuf;"
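Both cpp.py hunks drop `self.swap_buffers(code)` from the context managers that assemble the `[&] { ... }` lambda used to move data through the aligned `tmpbuf`. That call presumably scopes what the kernel emits to the local `code` buffer so it lands inside the lambda; the new test only checks the observable outcome, namely that the generated C++ returns `VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16)` exactly twice. The toy context manager below sketches that redirection idea under this assumption; it is not Inductor's real `Kernel.swap_buffers`, which, judging from the commit title, also interacts with the CSE cache.

import contextlib

class ToyKernel:
    """Illustrative stand-in for a codegen kernel; not the real Inductor class."""

    def __init__(self):
        self.loads = []    # lines emitted for the main kernel body
        self.compute = []

    @contextlib.contextmanager
    def swap_buffers(self, local_buffer):
        # Temporarily route emitted lines into `local_buffer` so that anything
        # generated inside the `with` block lands in the nested lambda body.
        saved = (self.loads, self.compute)
        self.loads = self.compute = local_buffer
        try:
            yield
        finally:
            self.loads, self.compute = saved


kernel = ToyKernel()
kernel.loads.append("auto tmp0 = in_ptr0[idx];")   # emitted before the lambda

lambda_body = []
with kernel.swap_buffers(lambda_body):
    kernel.loads.append("tmpbuf[j] = static_cast<int64_t>(tmp0);")

assert kernel.loads == ["auto tmp0 = in_ptr0[idx];"]
assert lambda_body == ["tmpbuf[j] = static_cast<int64_t>(tmp0);"]

Keeping the surrounding kernel's state in place while the lambda body is generated is consistent with the deduplicated loadu count the test asserts, though the test itself is the only guarantee made here.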
