[inductor] allow clone cse cache during vectorized indirect load
ghstack-source-id: 616a56d2341cd20ecca38d70e130b90e8f67bb3b
Pull Request resolved: #124597
zhuhaozhe committed Apr 24, 2024
1 parent 81740fd commit ecc79da
Showing 2 changed files with 40 additions and 2 deletions.
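
Context for readers (not part of the original commit message): the "cse cache" in the title is the common-subexpression-elimination bookkeeping in Inductor's C++ codegen, roughly a map from generated expression text to the temporary variable that already holds it, so a repeated expression is emitted once and then reused. The question this commit touches is which cache is in effect while the codegen emits the "[&] { ... tmpbuf ... }()" lambda used for vectorized indirect loads. A deliberately simplified sketch of such a cache (illustrative only, not Inductor's actual CSE class):

    # Toy CSE cache: identical expression text is emitted once, then reused.
    class TinyCSE:
        def __init__(self):
            self.cache = {}    # expression text -> variable name holding it
            self.counter = 0

        def generate(self, buffer, expr):
            if expr in self.cache:
                return self.cache[expr]              # cache hit: nothing re-emitted
            name = f"tmp{self.counter}"
            self.counter += 1
            buffer.append(f"auto {name} = {expr};")  # emit the definition once
            self.cache[expr] = name
            return name

    lines = []
    cse = TinyCSE()
    a = cse.generate(lines, "in_ptr0[i]")
    b = cse.generate(lines, "in_ptr0[i]")   # second request reuses tmp0
    assert a == b == "tmp0" and len(lines) == 1

In this toy, the second generate() call is a cache hit and emits nothing new; the change below adjusts how that caching behaves across the lambda scope of a vectorized indirect load.
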
38 changes: 38 additions & 0 deletions test/inductor/test_cpu_repro.py
@@ -3650,6 +3650,44 @@ def forward(self, x):
        x = torch.randn(1, 4, 2, 2)
        self.common(fn, (x,))

    @requires_vectorization
    def test_vec_indirect_load_cse_cache(self):
        # https://github.com/pytorch/pytorch/issues/123502.
        from math import inf

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg0_1):
                full_default = torch.ops.aten.full.default([209985], 1)
                select = torch.ops.aten.select.int(arg0_1, 0, 0)
                select_1 = torch.ops.aten.select.int(arg0_1, 0, 1)
                view = torch.ops.aten.reshape.default(select_1, [-1])
                expand = torch.ops.aten.expand.default(view, [209985])
                full_default_1 = torch.ops.aten.full.default([10000], 0)
                scatter_add = torch.ops.aten.scatter_add.default(
                    full_default_1, 0, expand, full_default
                )
                pow_1 = torch.ops.aten.pow.Tensor_Scalar(scatter_add, -0.5)
                eq = torch.ops.aten.eq.Scalar(pow_1, inf)
                full_default_2 = torch.ops.aten.full.default([], 0.0)
                where = torch.ops.aten.where.self(eq, full_default_2, pow_1)
                index = torch.ops.aten.index.Tensor(where, [select])
                index_1 = torch.ops.aten.index.Tensor(where, [select_1])
                mul_1 = torch.ops.aten.mul.Tensor(index, index_1)
                return (mul_1,)

        fn = Model()
        x = torch.zeros(2, 209985).to(torch.int64)
        _fn_opt = torch.compile()(fn)
        _, code = run_and_get_cpp_code(_fn_opt, x)
        FileCheck().check_count(
            "return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);",
            2,
            exactly=True,
        ).run(code)


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
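
Aside (not part of the commit): the new test asserts only on the generated C++ code; the loadu line of the vectorized gather lambda must appear exactly twice, once per indirect load (index and index_1). As a complementary illustration, a reduced, hypothetical stand-in for the same access pattern (two gathers from one buffer with different index rows) can be cross-checked against eager execution. This is an assumption-laden sketch, not the repro from #123502:

    # Hypothetical reduced pattern: two indirect (gather) loads from the same
    # source tensor with different index rows, combined elementwise, then
    # compared against eager mode for numerical correctness.
    import torch

    class TwoIndirectLoads(torch.nn.Module):
        def forward(self, idx):
            src = torch.arange(10000, dtype=torch.float32)
            return src[idx[0]] * src[idx[1]]

    m = TwoIndirectLoads()
    idx = torch.randint(0, 10000, (2, 209985))
    torch.testing.assert_close(torch.compile(m)(idx), m(idx))
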
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/cpp.py
@@ -2325,7 +2325,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
            assert vec_var.is_vec
            code = BracesBuffer()
            code.writeline("[&]")
-           with self.swap_buffers(code), code.indent():
+           with code.indent():
                vec_dtype = vec_var.dtype
                assert vec_dtype is not None
                if vec_dtype == torch.bool:
@@ -2346,7 +2346,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
        assert opt_ctx is not None
        code = BracesBuffer()
        code.writeline("[&]")
-       with self.swap_buffers(code), code.indent():
+       with code.indent():
            result_size = get_result_size(dtype)
            result_declare = (
                f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {result_size}> tmpbuf;"
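
Aside (not part of the commit): both changed lines generate the lambda body for the vectorized indirect load with plain code.indent() instead of wrapping it in self.swap_buffers(code). As I read swap_buffers (an interpretation, not a quote from the PR), it temporarily redirects emitted lines into the given local buffer and swaps in separate CSE state, so dropping it here lets the indirect-load lambda be generated against the kernel's existing CSE cache rather than a fresh one. A toy sketch of that kind of scope swap (illustrative only, not Inductor's actual Kernel.swap_buffers):

    # Toy swap_buffers-style context manager: while active, emitted lines go to
    # a local buffer and CSE bookkeeping starts from a fresh cache, so entries
    # cached outside are not visible inside the scope (and vice versa).
    import contextlib

    class TinyKernel:
        def __init__(self):
            self.buffer = []       # where generated lines normally go
            self.cse_cache = {}    # expression text -> variable name

        @contextlib.contextmanager
        def swap_buffers(self, local_buffer):
            saved_buffer, saved_cache = self.buffer, self.cse_cache
            self.buffer, self.cse_cache = local_buffer, {}
            try:
                yield
            finally:
                self.buffer, self.cse_cache = saved_buffer, saved_cache

    kernel = TinyKernel()
    kernel.cse_cache["in_ptr0[i]"] = "tmp0"    # cached before the lambda
    lambda_body = []
    with kernel.swap_buffers(lambda_body):
        kernel.buffer.append("auto tmp1 = in_ptr0[i];")  # lands in the lambda body
        assert "in_ptr0[i]" not in kernel.cse_cache      # outer entry not visible here
    assert lambda_body == ["auto tmp1 = in_ptr0[i];"]
    assert kernel.cse_cache == {"in_ptr0[i]": "tmp0"}    # outer state restored
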
