[inductor] allow clone cse cache during vectorized indirect load
ghstack-source-id: 3d5eb12fdc823729df0a6b67f8bf042c5d6c9b73
Pull Request resolved: #124597
zhuhaozhe committed Apr 22, 2024
1 parent 3af1244 commit 62fa869
Showing 3 changed files with 43 additions and 3 deletions.
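
In short: `Kernel.swap_buffers` previously entered the swapped scope with an empty CSE cache (the clone does not carry the cache over), so an expression already materialized outside a generated lambda could not be reused inside it. The new `cse_cache` parameter lets a caller hand the live cache into the cloned scope. Below is a minimal, runnable sketch of that pattern; `MiniCSE` and `MiniKernel` are hypothetical stand-ins, not the real `torch._inductor` classes.

# A minimal, runnable sketch of the pattern this commit enables. MiniCSE and
# MiniKernel are hypothetical stand-ins, not the real torch._inductor classes.
import contextlib


class MiniCSE:
    """Maps an expression string to the variable that already holds it."""

    def __init__(self):
        self.cache = {}

    def clone(self):
        # Like the real CSE.clone(), the copy starts with an empty cache.
        return MiniCSE()

    def generate(self, buffer, expr):
        # Emit expr once; later requests reuse the cached variable name.
        if expr not in self.cache:
            var = f"tmp{len(self.cache)}"
            buffer.append(f"auto {var} = {expr};")
            self.cache[expr] = var
        return self.cache[expr]


class MiniKernel:
    def __init__(self):
        self.loads = []
        self.cse = MiniCSE()

    @contextlib.contextmanager
    def swap_buffers(self, lb, cse_cache=None):
        # Mirrors the diff: by default the swapped scope starts with an
        # empty cache; passing the caller's cache lets the inner scope
        # reuse expressions that were already emitted outside it.
        if cse_cache is None:
            cse_cache = {}
        prior_loads, prior_cse = self.loads, self.cse
        self.loads, self.cse = lb, self.cse.clone()
        self.cse.cache = cse_cache
        try:
            yield
        finally:
            self.loads, self.cse = prior_loads, prior_cse


kernel = MiniKernel()
kernel.cse.generate(kernel.loads, "load(ptr, idx)")  # emitted once, outside
inner = []
with kernel.swap_buffers(inner, cse_cache=kernel.cse.cache):
    kernel.cse.generate(kernel.loads, "load(ptr, idx)")  # cache hit
assert inner == []  # nothing re-emitted inside the swapped scope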
test/inductor/test_cpu_repro.py: 37 additions, 0 deletions
@@ -3648,6 +3648,43 @@ def forward(self, x):
         x = torch.randn(1, 4, 2, 2)
         self.common(fn, (x,))
 
+    def test_vec_indirect_load_cse_cache(self):
+        # https://github.com/pytorch/pytorch/issues/123502.
+        from math import inf
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, arg0_1):
+                full_default = torch.ops.aten.full.default([209985], 1)
+                select = torch.ops.aten.select.int(arg0_1, 0, 0)
+                select_1 = torch.ops.aten.select.int(arg0_1, 0, 1)
+                view = torch.ops.aten.reshape.default(select_1, [-1])
+                expand = torch.ops.aten.expand.default(view, [209985])
+                full_default_1 = torch.ops.aten.full.default([10000], 0)
+                scatter_add = torch.ops.aten.scatter_add.default(
+                    full_default_1, 0, expand, full_default
+                )
+                pow_1 = torch.ops.aten.pow.Tensor_Scalar(scatter_add, -0.5)
+                eq = torch.ops.aten.eq.Scalar(pow_1, inf)
+                full_default_2 = torch.ops.aten.full.default([], 0.0)
+                where = torch.ops.aten.where.self(eq, full_default_2, pow_1)
+                index = torch.ops.aten.index.Tensor(where, [select])
+                index_1 = torch.ops.aten.index.Tensor(where, [select_1])
+                mul_1 = torch.ops.aten.mul.Tensor(index, index_1)
+                return (mul_1,)
+
+        fn = Model()
+        x = torch.zeros(2, 209985).to(torch.int64)
+        _fn_opt = torch.compile()(fn)
+        _, code = run_and_get_cpp_code(_fn_opt, x)
+        FileCheck().check_count(
+            "return at::vec::VectorizedN<int64_t,2>::loadu(tmpbuf.data(), 16);",
+            2,
+            exactly=True,
+        ).run(code)
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
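
(The FileCheck count is the regression signal: the vectorized load of the shared `where` buffer should appear exactly twice, once per index tensor. Before this change the swapped scope started with an empty CSE cache, so already-emitted loads could not be reused inside the generated lambda bodies; see issue #123502 linked above.)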
torch/_inductor/codegen/common.py: 4 additions, 1 deletion
@@ -1348,9 +1348,11 @@ def set_current_node(self, node):
         self.current_node = prior
 
     @contextlib.contextmanager
-    def swap_buffers(self, lb, cb=None, sb=None):
+    def swap_buffers(self, lb, cb=None, sb=None, cse_cache=None):
         if cb is None:
             cb = lb
+        if cse_cache is None:
+            cse_cache = {}
         loads = self.loads
         compute = self.compute
         stores = self.stores
@@ -1359,6 +1361,7 @@ def swap_buffers(self, lb, cb=None, sb=None):
         self.compute = cb
         self.stores = sb
         self.cse = cse.clone()
+        self.cse.cache = cse_cache
         try:
             yield
         finally:
torch/_inductor/codegen/cpp.py: 2 additions, 2 deletions
@@ -2318,7 +2318,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
             assert vec_var.is_vec
             code = BracesBuffer()
             code.writeline("[&]")
-            with self.swap_buffers(code), code.indent():
+            with self.swap_buffers(code, cse_cache=self.cse.cache), code.indent():
                 vec_dtype = vec_var.dtype
                 assert vec_dtype is not None
                 if vec_dtype == torch.bool:
@@ -2339,7 +2339,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
             assert opt_ctx is not None
             code = BracesBuffer()
             code.writeline("[&]")
-            with self.swap_buffers(code), code.indent():
+            with self.swap_buffers(code, cse_cache=self.cse.cache), code.indent():
                 result_size = get_result_size(dtype)
                 result_declare = (
                     f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {result_size}> tmpbuf;"
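
Both call sites now hand the kernel's live CSE cache into the swapped scope. Since the emitted lambdas capture by reference ("[&]"), a cache hit inside a lambda can legally refer to a temporary defined outside it, which is what lets the generated code keep exactly the two vectorized loads the new test checks for.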
