Skip to content

Commit

Permalink
also clone cse cache
Browse files Browse the repository at this point in the history
ghstack-source-id: 08e123950741038a088517781a434629f122654a
Pull Request resolved: #124597
  • Loading branch information
zhuhaozhe committed Apr 22, 2024
1 parent 3af1244 commit 96be07f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 4 additions & 1 deletion torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,9 +1348,11 @@ def set_current_node(self, node):
self.current_node = prior

@contextlib.contextmanager
def swap_buffers(self, lb, cb=None, sb=None):
def swap_buffers(self, lb, cb=None, sb=None, cse_cache=None):
if cb is None:
cb = lb
if cse_cache is None:
cse_cache = {}
loads = self.loads
compute = self.compute
stores = self.stores
Expand All @@ -1359,6 +1361,7 @@ def swap_buffers(self, lb, cb=None, sb=None):
self.compute = cb
self.stores = sb
self.cse = cse.clone()
self.cse.cache = cse_cache
try:
yield
finally:
Expand Down
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2318,7 +2318,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
assert vec_var.is_vec
code = BracesBuffer()
code.writeline("[&]")
with self.swap_buffers(code), code.indent():
with self.swap_buffers(code, cse_cache=self.cse.cache), code.indent():
vec_dtype = vec_var.dtype
assert vec_dtype is not None
if vec_dtype == torch.bool:
Expand All @@ -2339,7 +2339,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
assert opt_ctx is not None
code = BracesBuffer()
code.writeline("[&]")
with self.swap_buffers(code), code.indent():
with self.swap_buffers(code, cse_cache=self.cse.cache), code.indent():
result_size = get_result_size(dtype)
result_declare = (
f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {result_size}> tmpbuf;"
Expand Down

0 comments on commit 96be07f

Please sign in to comment.