[inductor] share cse cache during vectorized indirect load #124597
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124597
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) As of commit dda79ae with merge base 59a1f1f: BROKEN TRUNK - the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: a11a76fdbe9b6eab89a881188ee66b405f4bc461 Pull Request resolved: #124597
[ghstack-poisoned]
ghstack-source-id: 08e123950741038a088517781a434629f122654a Pull Request resolved: #124597
Fix #123502 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
ghstack-source-id: 3d5eb12fdc823729df0a6b67f8bf042c5d6c9b73 Pull Request resolved: #124597
…t load" Fix #123502 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
torch/_inductor/codegen/common.py (Outdated)

```diff
@@ -1348,9 +1348,11 @@ def set_current_node(self, node):
         self.current_node = prior

     @contextlib.contextmanager
-    def swap_buffers(self, lb, cb=None, sb=None):
+    def swap_buffers(self, lb, cb=None, sb=None, cse_cache=None):
```
The `cse` object has `cache`, `reduction_cache`, and `store_cache`. A single `cse_cache` cannot represent all of them. Maybe it is better to pass `cse` directly to override the new `cse` after the buffers are swapped?
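A minimal sketch of what that suggestion could look like (the class and method names below are hypothetical stand-ins, not the actual inductor API): `swap_buffers` accepts a whole `cse` object, installs it for the duration of the scope, and restores the prior one on exit, so all three caches travel together.

```python
import contextlib


class CSE:
    """Hypothetical stand-in for inductor's CSE object with its three caches."""

    def __init__(self):
        self.cache = {}
        self.reduction_cache = {}
        self.store_cache = {}


class KernelSketch:
    def __init__(self):
        self.cse = CSE()

    @contextlib.contextmanager
    def swap_buffers(self, cse=None):
        # Pass a pre-built `cse` to override the fresh one that would
        # otherwise replace self.cse while the buffers are swapped.
        prior = self.cse
        self.cse = cse if cse is not None else CSE()
        try:
            yield
        finally:
            self.cse = prior


kernel = KernelSketch()
kernel.cse.cache["expr"] = "tmp0"
# Reusing the outer cse object keeps cache, reduction_cache, and
# store_cache visible inside the swapped scope.
with kernel.swap_buffers(cse=kernel.cse):
    assert kernel.cse.cache["expr"] == "tmp0"
```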
torch/_inductor/codegen/cpp.py (Outdated)

```diff
@@ -2318,7 +2318,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
     assert vec_var.is_vec
     code = BracesBuffer()
     code.writeline("[&]")
-    with self.swap_buffers(code), code.indent():
+    with self.swap_buffers(code, cse_cache=self.cse.cache), code.indent():
```
When new variables are defined inside the scope, they cannot be used outside the scope afterwards, so it is not safe to reuse the existing `self.cse.cache` here. I suggest a wrapper around the cache abstraction that supports scoping: variables and expressions defined within a scope are valid only within that scope, while the existing cache from the outer scope can still be leveraged. How about we create something like the below?
```python
def scope_cse(cse):
    new_cse = cse.clone()
    new_cse.cache = ScopedDict(cse.cache)
    new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
    new_cse.store_cache = ScopedDict(cse.store_cache)
    return new_cse
```
Here the `ScopedDict` allows read-only access to the original dict while using a separate dict to hold newly added items:
```python
class ScopedDict:
    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict

    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)
```
Sounds better, I will try this way.
ghstack-source-id: 616a56d2341cd20ecca38d70e130b90e8f67bb3b Pull Request resolved: #124597
…t load" Fix #123502

`swap_buffer` does not clone the `cse.cache`, which introduces redundant computation. We may be able to reuse the `cse.cache` if there is no cse value in the `expr`:

```cpp
auto tmp8 = [&] {
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}();
// ... other code ...
// tmp7 is stored again here (redundant tmp16)
auto tmp16 = [&] {
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}();
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Hi, @jgong5. For this issue #123502, your recommendation for the … For my current understanding, we may not need to always put a …
As for the …
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
`swap_buffer` makes the `cse_cache` unable to be shared between the inside and the outside of a lambda function scope. For example,

```cpp
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&] {
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```

`tmp12` should not be created, since it is the same as `tmp8`. We make the `cse_cache` a read-only cache inside the scope (it is unsafe to expose the cache populated inside the scope, because the outside scope cannot use it).

**Test Plan**

```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```

The `static_cast<int>(256)` will only occur once after this PR, since the inside scope can share the cse buffer with the outside scope. Before this PR, ``` cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], ''' #include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr1) { #pragma omp parallel num_threads(128) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L)) { for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L)) { auto tmp0 = c10::convert<int>(x1); auto tmp1 = static_cast<int>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int>(x3); auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1); auto tmp6 = static_cast<int>(257); auto tmp7 = at::vec::Vectorized<int>(tmp6); auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7); auto tmp10 = at::vec::VecMask<float,1>::from(tmp2); auto tmp11 = tmp8 & tmp10; auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ; auto tmp13 = [&] { if 
(tmp11.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>()); } } () ; auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp15 = static_cast<int>(3); auto tmp16 = tmp14 < tmp15; auto tmp18 = tmp16 & tmp2; auto tmp17 = [&] { auto tmp19 = c10::convert<int>(x3); auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1); auto tmp21 = static_cast<int>(256); auto tmp22 = at::vec::Vectorized<int>(tmp21); auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22); auto tmp25 = at::vec::VecMask<float,1>::from(tmp18); auto tmp26 = tmp23 & tmp25; auto tmp24 = [&] { auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp27; } ; auto tmp28 = [&] { if (tmp26.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>()); } } () ; auto tmp29 = static_cast<float>(0.0); auto tmp30 = at::vec::Vectorized<float>(tmp29); auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>()); return tmp31; } ; auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp33 = static_cast<float>(0.0); auto tmp34 = at::vec::VecMask<float,1>::from(tmp16); auto tmp35 = at::vec::Vectorized<float>(tmp33); auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>()); auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>()); return tmp37; } ; auto tmp38 = tmp2 ? 
tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp40 = static_cast<int>(3); auto tmp41 = tmp39 < tmp40; auto tmp42 = [&] { auto tmp43 = c10::convert<int>(x3); auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1); auto tmp45 = static_cast<int>(256); auto tmp46 = at::vec::Vectorized<int>(tmp45); auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46); auto tmp49 = at::vec::VecMask<float,1>::from(tmp41); auto tmp50 = tmp47 & tmp49; auto tmp48 = [&] { auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp51; } ; auto tmp52 = [&] { if (tmp50.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>()); } } () ; auto tmp53 = static_cast<float>(0.0); auto tmp54 = at::vec::Vectorized<float>(tmp53); auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>()); return tmp55; } ; auto tmp56 = tmp41 ? 
tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp57 = static_cast<float>(0.0); auto tmp58 = at::vec::VecMask<float,1>::from(tmp41); auto tmp59 = at::vec::Vectorized<float>(tmp57); auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>()); auto tmp61 = at::vec::VecMask<float,1>::from(tmp2); auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>()); tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))); } #pragma omp simd simdlen(8) for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int64_t>(x1); auto tmp1 = static_cast<int64_t>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int64_t>(x3); auto tmp5 = static_cast<int64_t>(257); auto tmp6 = tmp4 < tmp5; auto tmp7 = [&] { auto tmp8 = -std::numeric_limits<float>::infinity(); return tmp8; } ; auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0); auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp11 = static_cast<int64_t>(3); auto tmp12 = tmp10 < tmp11; auto tmp13 = [&] { auto tmp14 = c10::convert<int64_t>(x3); auto tmp15 = static_cast<int64_t>(256); auto tmp16 = tmp14 >= tmp15; auto tmp17 = [&] { auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp18; } ; auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0); auto tmp20 = static_cast<float>(0.0); auto tmp21 = tmp16 ? tmp19 : tmp20; return tmp21; } ; auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0); auto tmp23 = static_cast<float>(0.0); auto tmp24 = tmp12 ? tmp22 : tmp23; auto tmp25 = tmp6 ? tmp9 : tmp24; return tmp25; } ; auto tmp26 = tmp2 ? 
tmp3() : static_cast<decltype(tmp3())>(0.0); auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp28 = static_cast<int64_t>(3); auto tmp29 = tmp27 < tmp28; auto tmp30 = [&] { auto tmp31 = c10::convert<int64_t>(x3); auto tmp32 = static_cast<int64_t>(256); auto tmp33 = tmp31 >= tmp32; auto tmp34 = [&] { auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp35; } ; auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0); auto tmp37 = static_cast<float>(0.0); auto tmp38 = tmp33 ? tmp36 : tmp37; return tmp38; } ; auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0); auto tmp40 = static_cast<float>(0.0); auto tmp41 = tmp29 ? tmp39 : tmp40; auto tmp42 = tmp2 ? tmp26 : tmp41; out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42; } } } } } } } ''') ``` After this PR, ``` cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], ''' #include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr1) { #pragma omp parallel num_threads(128) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L)) { for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L)) { auto tmp0 = c10::convert<int>(x1); auto tmp1 = static_cast<int>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int>(x3); auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1); auto tmp6 = static_cast<int>(257); auto tmp7 = 
at::vec::Vectorized<int>(tmp6); auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7); auto tmp10 = at::vec::VecMask<float,1>::from(tmp2); auto tmp11 = tmp8 & tmp10; auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ; auto tmp13 = [&] { if (tmp11.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>()); } } () ; auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp15 = static_cast<int>(3); auto tmp16 = tmp14 < tmp15; auto tmp18 = tmp16 & tmp2; auto tmp17 = [&] { auto tmp19 = at::vec::Vectorized<int>(tmp1); auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19); auto tmp22 = at::vec::VecMask<float,1>::from(tmp18); auto tmp23 = tmp20 & tmp22; auto tmp21 = [&] { auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp24; } ; auto tmp25 = [&] { if (tmp23.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>()); } } () ; auto tmp26 = static_cast<float>(0.0); auto tmp27 = at::vec::Vectorized<float>(tmp26); auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>()); return tmp28; } ; auto tmp29 = tmp16 ? 
tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp30 = static_cast<float>(0.0); auto tmp31 = at::vec::VecMask<float,1>::from(tmp16); auto tmp32 = at::vec::Vectorized<float>(tmp30); auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>()); auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>()); return tmp34; } ; auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp37 = static_cast<int>(3); auto tmp38 = tmp36 < tmp37; auto tmp39 = [&] { auto tmp40 = c10::convert<int>(x3); auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1); auto tmp42 = at::vec::Vectorized<int>(tmp1); auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42); auto tmp45 = at::vec::VecMask<float,1>::from(tmp38); auto tmp46 = tmp43 & tmp45; auto tmp44 = [&] { auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp47; } ; auto tmp48 = [&] { if (tmp46.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>()); } } () ; auto tmp49 = static_cast<float>(0.0); auto tmp50 = at::vec::Vectorized<float>(tmp49); auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>()); return tmp51; } ; auto tmp52 = tmp38 ? 
tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp53 = static_cast<float>(0.0); auto tmp54 = at::vec::VecMask<float,1>::from(tmp38); auto tmp55 = at::vec::Vectorized<float>(tmp53); auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>()); auto tmp57 = at::vec::VecMask<float,1>::from(tmp2); auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>()); tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))); } #pragma omp simd simdlen(8) for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int64_t>(x1); auto tmp1 = static_cast<int64_t>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int64_t>(x3); auto tmp5 = static_cast<int64_t>(257); auto tmp6 = tmp4 < tmp5; auto tmp7 = [&] { auto tmp8 = -std::numeric_limits<float>::infinity(); return tmp8; } ; auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0); auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp11 = static_cast<int64_t>(3); auto tmp12 = tmp10 < tmp11; auto tmp13 = [&] { auto tmp14 = tmp4 >= tmp1; auto tmp15 = [&] { auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp16; } ; auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0); auto tmp18 = static_cast<float>(0.0); auto tmp19 = tmp14 ? tmp17 : tmp18; return tmp19; } ; auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0); auto tmp21 = static_cast<float>(0.0); auto tmp22 = tmp12 ? tmp20 : tmp21; auto tmp23 = tmp6 ? tmp9 : tmp22; return tmp23; } ; auto tmp24 = tmp2 ? 
tmp3() : static_cast<decltype(tmp3())>(0.0); auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp26 = static_cast<int64_t>(3); auto tmp27 = tmp25 < tmp26; auto tmp28 = [&] { auto tmp29 = c10::convert<int64_t>(x3); auto tmp30 = tmp29 >= tmp1; auto tmp31 = [&] { auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp32; } ; auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0); auto tmp34 = static_cast<float>(0.0); auto tmp35 = tmp30 ? tmp33 : tmp34; return tmp35; } ; auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0); auto tmp37 = static_cast<float>(0.0); auto tmp38 = tmp27 ? tmp36 : tmp37; auto tmp39 = tmp2 ? tmp24 : tmp38; out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39; } } } } } } } ''') ``` Pull Request resolved: #124921 Approved by: https://github.com/jgong5, https://github.com/jansel ghstack dependencies: #124597
…24597) Fix pytorch#123502

`swap_buffer` is not needed in vectorized indirect load; remove it to share the cse buffer.

```cpp
auto tmp8 = [&] {
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}();
// ... other code ...
// tmp7 is stored again here (redundant tmp16)
auto tmp16 = [&] {
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}();
```

Pull Request resolved: pytorch#124597 Approved by: https://github.com/jgong5, https://github.com/jansel
tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp30 = static_cast<float>(0.0); auto tmp31 = at::vec::VecMask<float,1>::from(tmp16); auto tmp32 = at::vec::Vectorized<float>(tmp30); auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>()); auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>()); return tmp34; } ; auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp37 = static_cast<int>(3); auto tmp38 = tmp36 < tmp37; auto tmp39 = [&] { auto tmp40 = c10::convert<int>(x3); auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1); auto tmp42 = at::vec::Vectorized<int>(tmp1); auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42); auto tmp45 = at::vec::VecMask<float,1>::from(tmp38); auto tmp46 = tmp43 & tmp45; auto tmp44 = [&] { auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp47; } ; auto tmp48 = [&] { if (tmp46.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>()); } } () ; auto tmp49 = static_cast<float>(0.0); auto tmp50 = at::vec::Vectorized<float>(tmp49); auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>()); return tmp51; } ; auto tmp52 = tmp38 ? 
tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp53 = static_cast<float>(0.0); auto tmp54 = at::vec::VecMask<float,1>::from(tmp38); auto tmp55 = at::vec::Vectorized<float>(tmp53); auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>()); auto tmp57 = at::vec::VecMask<float,1>::from(tmp2); auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>()); tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))); } #pragma omp simd simdlen(8) for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int64_t>(x1); auto tmp1 = static_cast<int64_t>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int64_t>(x3); auto tmp5 = static_cast<int64_t>(257); auto tmp6 = tmp4 < tmp5; auto tmp7 = [&] { auto tmp8 = -std::numeric_limits<float>::infinity(); return tmp8; } ; auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0); auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp11 = static_cast<int64_t>(3); auto tmp12 = tmp10 < tmp11; auto tmp13 = [&] { auto tmp14 = tmp4 >= tmp1; auto tmp15 = [&] { auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp16; } ; auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0); auto tmp18 = static_cast<float>(0.0); auto tmp19 = tmp14 ? tmp17 : tmp18; return tmp19; } ; auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0); auto tmp21 = static_cast<float>(0.0); auto tmp22 = tmp12 ? tmp20 : tmp21; auto tmp23 = tmp6 ? tmp9 : tmp22; return tmp23; } ; auto tmp24 = tmp2 ? 
tmp3() : static_cast<decltype(tmp3())>(0.0); auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp26 = static_cast<int64_t>(3); auto tmp27 = tmp25 < tmp26; auto tmp28 = [&] { auto tmp29 = c10::convert<int64_t>(x3); auto tmp30 = tmp29 >= tmp1; auto tmp31 = [&] { auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp32; } ; auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0); auto tmp34 = static_cast<float>(0.0); auto tmp35 = tmp30 ? tmp33 : tmp34; return tmp35; } ; auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0); auto tmp37 = static_cast<float>(0.0); auto tmp38 = tmp27 ? tmp36 : tmp37; auto tmp39 = tmp2 ? tmp24 : tmp38; out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39; } } } } } } } ''') ``` Pull Request resolved: pytorch#124921 Approved by: https://github.com/jgong5, https://github.com/jansel ghstack dependencies: pytorch#124597
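The read-only scoped `cse_cache` described above can be sketched as follows. This is a minimal illustrative sketch in Python, not Inductor's actual `CSE` class: inside a nested scope (the generated lambda), lookups can hit every entry created outside, but entries created inside are discarded when the scope closes, since variables named inside the emitted C++ lambda are not visible outside it.

```python
# Hypothetical sketch (not the PyTorch implementation) of a CSE cache that is
# read-only with respect to the enclosing scope: inner code reuses outer
# entries, but inner writes do not leak out.
import contextlib

class CSECache:
    def __init__(self):
        self.cache = {}   # expression text -> emitted variable name
        self.counter = 0

    def generate(self, expr):
        """Return the cached variable for `expr`, or emit a new one."""
        if expr in self.cache:
            return self.cache[expr]
        name = f"tmp{self.counter}"
        self.counter += 1
        self.cache[expr] = name
        return name

    @contextlib.contextmanager
    def scoped(self):
        # The inner scope starts from a *copy* of the outer cache: it can
        # reuse every outer entry, but entries it adds are dropped on exit.
        outer = self.cache
        self.cache = dict(outer)
        try:
            yield
        finally:
            self.cache = outer

cse = CSECache()
a = cse.generate("static_cast<int>(256)")      # emitted once: tmp0
with cse.scoped():
    b = cse.generate("static_cast<int>(256)")  # cache hit: reuses tmp0
    c = cse.generate("x3 >= tmp0")             # tmp1, local to the scope
d = cse.generate("x3 >= tmp0")                 # scope entry dropped: re-emitted
assert a == b
```

Under this scheme the second `static_cast<int>(256)` inside the lambda is a cache hit, which is exactly why it appears only once in the "After this PR" kernel above.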
…24597) Fix pytorch#123502

`swap_buffer` is not needed in vectorized indirect load; remove it to share the cse buffer.

```
auto tmp8 = [&] { __at_align__ std::array<int64_t, 16> tmpbuf; tmp7.store(tmpbuf.data()); return tmpbuf; } () ;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 = [&] { __at_align__ std::array<int64_t, 16> tmpbuf; tmp7.store(tmpbuf.data()); return tmpbuf; } () ;
```

Pull Request resolved: pytorch#124597 Approved by: https://github.com/jgong5, https://github.com/jansel
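The redundant `tmp16` store above comes from resetting the CSE cache when codegen swaps buffers for the indirect-load lambda: the second request for the same expression misses and is re-emitted. A minimal, hypothetical sketch of the effect (illustrative names, not Inductor's real API):

```python
# Sketch: a fresh cache per buffer swap re-emits identical expressions;
# one shared cache turns the second request into a hit.
def emit(lines, cache, expr):
    """Emit `expr` once per cache; reuse the variable on a cache hit."""
    if expr in cache:
        return cache[expr]
    name = f"tmp{len(cache)}"
    cache[expr] = name
    lines.append(f"auto {name} = {expr};")
    return name

store_expr = "[&] { tmp7.store(tmpbuf.data()); return tmpbuf; }()"

# Old behavior: each swap starts with an empty cache -> two identical stores.
swapped_lines = []
emit(swapped_lines, {}, store_expr)
emit(swapped_lines, {}, store_expr)

# This PR's behavior: one shared cache -> the second request is a hit.
shared_lines, shared_cache = [], {}
emit(shared_lines, shared_cache, store_expr)
emit(shared_lines, shared_cache, store_expr)

assert len(swapped_lines) == 2   # redundant tmp16-style duplicate
assert len(shared_lines) == 1    # store emitted once
```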
Fix #123502
`swap_buffer` is not needed in vectorized indirect load; remove it to share the cse buffer.

Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang