-
Notifications
You must be signed in to change notification settings - Fork 21.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[inductor] share more cse cache during swap buffer #124921
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124921
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 03dd525 with merge base 59a1f1f (): BROKEN TRUNK - The following job failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
ghstack-source-id: 2cd2271ed9ecb25a3d92aaa00ff9537d32202535 Pull Request resolved: #124921
[ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
`swap_buffer` will make the `cse_cache` cannot be shared inside/outside of the lambda function scope. For example, ``` auto tmp8 = -std::numeric_limits<float>::infinity(); auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ``` `tmp12` should not be created since it is same with `tmp8`. We make the `cse_cache` as a read only cache inside the scope (because it is unsafe to expose cache inside the scope,the outside scope cannot use it.) **Test Plan** ``` python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu ``` the `static_cast<int>(256)` will only occur once after this PR since the inside scope can share the cse buffer outside the scope. Before this PR, ``` cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], ''' #include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr1) { #pragma omp parallel num_threads(128) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L)) { for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L)) { auto tmp0 = c10::convert<int>(x1); auto tmp1 = static_cast<int>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int>(x3); auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1); auto tmp6 = static_cast<int>(257); auto tmp7 = at::vec::Vectorized<int>(tmp6); auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7); auto tmp10 = at::vec::VecMask<float,1>::from(tmp2); auto tmp11 = tmp8 & tmp10; auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ; auto tmp13 = [&] { if 
(tmp11.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>()); } } () ; auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp15 = static_cast<int>(3); auto tmp16 = tmp14 < tmp15; auto tmp18 = tmp16 & tmp2; auto tmp17 = [&] { auto tmp19 = c10::convert<int>(x3); auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1); auto tmp21 = static_cast<int>(256); auto tmp22 = at::vec::Vectorized<int>(tmp21); auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22); auto tmp25 = at::vec::VecMask<float,1>::from(tmp18); auto tmp26 = tmp23 & tmp25; auto tmp24 = [&] { auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp27; } ; auto tmp28 = [&] { if (tmp26.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>()); } } () ; auto tmp29 = static_cast<float>(0.0); auto tmp30 = at::vec::Vectorized<float>(tmp29); auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>()); return tmp31; } ; auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp33 = static_cast<float>(0.0); auto tmp34 = at::vec::VecMask<float,1>::from(tmp16); auto tmp35 = at::vec::Vectorized<float>(tmp33); auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>()); auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>()); return tmp37; } ; auto tmp38 = tmp2 ? 
tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp40 = static_cast<int>(3); auto tmp41 = tmp39 < tmp40; auto tmp42 = [&] { auto tmp43 = c10::convert<int>(x3); auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1); auto tmp45 = static_cast<int>(256); auto tmp46 = at::vec::Vectorized<int>(tmp45); auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46); auto tmp49 = at::vec::VecMask<float,1>::from(tmp41); auto tmp50 = tmp47 & tmp49; auto tmp48 = [&] { auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp51; } ; auto tmp52 = [&] { if (tmp50.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>()); } } () ; auto tmp53 = static_cast<float>(0.0); auto tmp54 = at::vec::Vectorized<float>(tmp53); auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>()); return tmp55; } ; auto tmp56 = tmp41 ? 
tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp57 = static_cast<float>(0.0); auto tmp58 = at::vec::VecMask<float,1>::from(tmp41); auto tmp59 = at::vec::Vectorized<float>(tmp57); auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>()); auto tmp61 = at::vec::VecMask<float,1>::from(tmp2); auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>()); tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))); } #pragma omp simd simdlen(8) for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int64_t>(x1); auto tmp1 = static_cast<int64_t>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int64_t>(x3); auto tmp5 = static_cast<int64_t>(257); auto tmp6 = tmp4 < tmp5; auto tmp7 = [&] { auto tmp8 = -std::numeric_limits<float>::infinity(); return tmp8; } ; auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0); auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp11 = static_cast<int64_t>(3); auto tmp12 = tmp10 < tmp11; auto tmp13 = [&] { auto tmp14 = c10::convert<int64_t>(x3); auto tmp15 = static_cast<int64_t>(256); auto tmp16 = tmp14 >= tmp15; auto tmp17 = [&] { auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp18; } ; auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0); auto tmp20 = static_cast<float>(0.0); auto tmp21 = tmp16 ? tmp19 : tmp20; return tmp21; } ; auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0); auto tmp23 = static_cast<float>(0.0); auto tmp24 = tmp12 ? tmp22 : tmp23; auto tmp25 = tmp6 ? tmp9 : tmp24; return tmp25; } ; auto tmp26 = tmp2 ? 
tmp3() : static_cast<decltype(tmp3())>(0.0); auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp28 = static_cast<int64_t>(3); auto tmp29 = tmp27 < tmp28; auto tmp30 = [&] { auto tmp31 = c10::convert<int64_t>(x3); auto tmp32 = static_cast<int64_t>(256); auto tmp33 = tmp31 >= tmp32; auto tmp34 = [&] { auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp35; } ; auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0); auto tmp37 = static_cast<float>(0.0); auto tmp38 = tmp33 ? tmp36 : tmp37; return tmp38; } ; auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0); auto tmp40 = static_cast<float>(0.0); auto tmp41 = tmp29 ? tmp39 : tmp40; auto tmp42 = tmp2 ? tmp26 : tmp41; out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42; } } } } } } } ''') ``` After this PR, ``` cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], ''' #include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr1) { #pragma omp parallel num_threads(128) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L)) { for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L)) { auto tmp0 = c10::convert<int>(x1); auto tmp1 = static_cast<int>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int>(x3); auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1); auto tmp6 = static_cast<int>(257); auto tmp7 = 
at::vec::Vectorized<int>(tmp6); auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7); auto tmp10 = at::vec::VecMask<float,1>::from(tmp2); auto tmp11 = tmp8 & tmp10; auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ; auto tmp13 = [&] { if (tmp11.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>()); } } () ; auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp15 = static_cast<int>(3); auto tmp16 = tmp14 < tmp15; auto tmp18 = tmp16 & tmp2; auto tmp17 = [&] { auto tmp19 = at::vec::Vectorized<int>(tmp1); auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19); auto tmp22 = at::vec::VecMask<float,1>::from(tmp18); auto tmp23 = tmp20 & tmp22; auto tmp21 = [&] { auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp24; } ; auto tmp25 = [&] { if (tmp23.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>()); } } () ; auto tmp26 = static_cast<float>(0.0); auto tmp27 = at::vec::Vectorized<float>(tmp26); auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>()); return tmp28; } ; auto tmp29 = tmp16 ? 
tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp30 = static_cast<float>(0.0); auto tmp31 = at::vec::VecMask<float,1>::from(tmp16); auto tmp32 = at::vec::Vectorized<float>(tmp30); auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>()); auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>()); return tmp34; } ; auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp37 = static_cast<int>(3); auto tmp38 = tmp36 < tmp37; auto tmp39 = [&] { auto tmp40 = c10::convert<int>(x3); auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1); auto tmp42 = at::vec::Vectorized<int>(tmp1); auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42); auto tmp45 = at::vec::VecMask<float,1>::from(tmp38); auto tmp46 = tmp43 & tmp45; auto tmp44 = [&] { auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp47; } ; auto tmp48 = [&] { if (tmp46.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>()); } } () ; auto tmp49 = static_cast<float>(0.0); auto tmp50 = at::vec::Vectorized<float>(tmp49); auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>()); return tmp51; } ; auto tmp52 = tmp38 ? 
tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp53 = static_cast<float>(0.0); auto tmp54 = at::vec::VecMask<float,1>::from(tmp38); auto tmp55 = at::vec::Vectorized<float>(tmp53); auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>()); auto tmp57 = at::vec::VecMask<float,1>::from(tmp2); auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>()); tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))); } #pragma omp simd simdlen(8) for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int64_t>(x1); auto tmp1 = static_cast<int64_t>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int64_t>(x3); auto tmp5 = static_cast<int64_t>(257); auto tmp6 = tmp4 < tmp5; auto tmp7 = [&] { auto tmp8 = -std::numeric_limits<float>::infinity(); return tmp8; } ; auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0); auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp11 = static_cast<int64_t>(3); auto tmp12 = tmp10 < tmp11; auto tmp13 = [&] { auto tmp14 = tmp4 >= tmp1; auto tmp15 = [&] { auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp16; } ; auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0); auto tmp18 = static_cast<float>(0.0); auto tmp19 = tmp14 ? tmp17 : tmp18; return tmp19; } ; auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0); auto tmp21 = static_cast<float>(0.0); auto tmp22 = tmp12 ? tmp20 : tmp21; auto tmp23 = tmp6 ? tmp9 : tmp22; return tmp23; } ; auto tmp24 = tmp2 ? 
tmp3() : static_cast<decltype(tmp3())>(0.0); auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp26 = static_cast<int64_t>(3); auto tmp27 = tmp25 < tmp26; auto tmp28 = [&] { auto tmp29 = c10::convert<int64_t>(x3); auto tmp30 = tmp29 >= tmp1; auto tmp31 = [&] { auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp32; } ; auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0); auto tmp34 = static_cast<float>(0.0); auto tmp35 = tmp30 ? tmp33 : tmp34; return tmp35; } ; auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0); auto tmp37 = static_cast<float>(0.0); auto tmp38 = tmp27 ? tmp36 : tmp37; auto tmp39 = tmp2 ? tmp24 : tmp38; out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39; } } } } } } } ''') ``` Pull Request resolved: pytorch#124921 Approved by: https://github.com/jgong5, https://github.com/jansel ghstack dependencies: pytorch#124597
`swap_buffer` will make the `cse_cache` cannot be shared inside/outside of the lambda function scope. For example, ``` auto tmp8 = -std::numeric_limits<float>::infinity(); auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ``` `tmp12` should not be created since it is same with `tmp8`. We make the `cse_cache` as a read only cache inside the scope (because it is unsafe to expose cache inside the scope,the outside scope cannot use it.) **Test Plan** ``` python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu ``` the `static_cast<int>(256)` will only occur once after this PR since the inside scope can share the cse buffer outside the scope. Before this PR, ``` cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], ''' #include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr1) { #pragma omp parallel num_threads(128) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L)) { for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L)) { auto tmp0 = c10::convert<int>(x1); auto tmp1 = static_cast<int>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int>(x3); auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1); auto tmp6 = static_cast<int>(257); auto tmp7 = at::vec::Vectorized<int>(tmp6); auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7); auto tmp10 = at::vec::VecMask<float,1>::from(tmp2); auto tmp11 = tmp8 & tmp10; auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ; auto tmp13 = [&] { if 
(tmp11.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>()); } } () ; auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp15 = static_cast<int>(3); auto tmp16 = tmp14 < tmp15; auto tmp18 = tmp16 & tmp2; auto tmp17 = [&] { auto tmp19 = c10::convert<int>(x3); auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1); auto tmp21 = static_cast<int>(256); auto tmp22 = at::vec::Vectorized<int>(tmp21); auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22); auto tmp25 = at::vec::VecMask<float,1>::from(tmp18); auto tmp26 = tmp23 & tmp25; auto tmp24 = [&] { auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp27; } ; auto tmp28 = [&] { if (tmp26.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>()); } } () ; auto tmp29 = static_cast<float>(0.0); auto tmp30 = at::vec::Vectorized<float>(tmp29); auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>()); return tmp31; } ; auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp33 = static_cast<float>(0.0); auto tmp34 = at::vec::VecMask<float,1>::from(tmp16); auto tmp35 = at::vec::Vectorized<float>(tmp33); auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>()); auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>()); return tmp37; } ; auto tmp38 = tmp2 ? 
tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp40 = static_cast<int>(3); auto tmp41 = tmp39 < tmp40; auto tmp42 = [&] { auto tmp43 = c10::convert<int>(x3); auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1); auto tmp45 = static_cast<int>(256); auto tmp46 = at::vec::Vectorized<int>(tmp45); auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46); auto tmp49 = at::vec::VecMask<float,1>::from(tmp41); auto tmp50 = tmp47 & tmp49; auto tmp48 = [&] { auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp51; } ; auto tmp52 = [&] { if (tmp50.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>()); } } () ; auto tmp53 = static_cast<float>(0.0); auto tmp54 = at::vec::Vectorized<float>(tmp53); auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>()); return tmp55; } ; auto tmp56 = tmp41 ? 
tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp57 = static_cast<float>(0.0); auto tmp58 = at::vec::VecMask<float,1>::from(tmp41); auto tmp59 = at::vec::Vectorized<float>(tmp57); auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>()); auto tmp61 = at::vec::VecMask<float,1>::from(tmp2); auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>()); tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))); } #pragma omp simd simdlen(8) for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int64_t>(x1); auto tmp1 = static_cast<int64_t>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int64_t>(x3); auto tmp5 = static_cast<int64_t>(257); auto tmp6 = tmp4 < tmp5; auto tmp7 = [&] { auto tmp8 = -std::numeric_limits<float>::infinity(); return tmp8; } ; auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0); auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp11 = static_cast<int64_t>(3); auto tmp12 = tmp10 < tmp11; auto tmp13 = [&] { auto tmp14 = c10::convert<int64_t>(x3); auto tmp15 = static_cast<int64_t>(256); auto tmp16 = tmp14 >= tmp15; auto tmp17 = [&] { auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp18; } ; auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0); auto tmp20 = static_cast<float>(0.0); auto tmp21 = tmp16 ? tmp19 : tmp20; return tmp21; } ; auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0); auto tmp23 = static_cast<float>(0.0); auto tmp24 = tmp12 ? tmp22 : tmp23; auto tmp25 = tmp6 ? tmp9 : tmp24; return tmp25; } ; auto tmp26 = tmp2 ? 
tmp3() : static_cast<decltype(tmp3())>(0.0); auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp28 = static_cast<int64_t>(3); auto tmp29 = tmp27 < tmp28; auto tmp30 = [&] { auto tmp31 = c10::convert<int64_t>(x3); auto tmp32 = static_cast<int64_t>(256); auto tmp33 = tmp31 >= tmp32; auto tmp34 = [&] { auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp35; } ; auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0); auto tmp37 = static_cast<float>(0.0); auto tmp38 = tmp33 ? tmp36 : tmp37; return tmp38; } ; auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0); auto tmp40 = static_cast<float>(0.0); auto tmp41 = tmp29 ? tmp39 : tmp40; auto tmp42 = tmp2 ? tmp26 : tmp41; out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42; } } } } } } } ''') ``` After this PR, ``` cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], ''' #include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h" extern "C" void kernel(const float* in_ptr0, float* out_ptr1) { #pragma omp parallel num_threads(128) { int tid = omp_get_thread_num(); { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L)) { for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L)) { auto tmp0 = c10::convert<int>(x1); auto tmp1 = static_cast<int>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int>(x3); auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1); auto tmp6 = static_cast<int>(257); auto tmp7 = 
at::vec::Vectorized<int>(tmp6); auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7); auto tmp10 = at::vec::VecMask<float,1>::from(tmp2); auto tmp11 = tmp8 & tmp10; auto tmp9 = [&] { auto tmp12 = -std::numeric_limits<float>::infinity(); return tmp12; } ; auto tmp13 = [&] { if (tmp11.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>()); } } () ; auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp15 = static_cast<int>(3); auto tmp16 = tmp14 < tmp15; auto tmp18 = tmp16 & tmp2; auto tmp17 = [&] { auto tmp19 = at::vec::Vectorized<int>(tmp1); auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19); auto tmp22 = at::vec::VecMask<float,1>::from(tmp18); auto tmp23 = tmp20 & tmp22; auto tmp21 = [&] { auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp24; } ; auto tmp25 = [&] { if (tmp23.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>()); } } () ; auto tmp26 = static_cast<float>(0.0); auto tmp27 = at::vec::Vectorized<float>(tmp26); auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>()); return tmp28; } ; auto tmp29 = tmp16 ? 
tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp30 = static_cast<float>(0.0); auto tmp31 = at::vec::VecMask<float,1>::from(tmp16); auto tmp32 = at::vec::Vectorized<float>(tmp30); auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>()); auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>()); return tmp34; } ; auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L)); auto tmp37 = static_cast<int>(3); auto tmp38 = tmp36 < tmp37; auto tmp39 = [&] { auto tmp40 = c10::convert<int>(x3); auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1); auto tmp42 = at::vec::Vectorized<int>(tmp1); auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42); auto tmp45 = at::vec::VecMask<float,1>::from(tmp38); auto tmp46 = tmp43 & tmp45; auto tmp44 = [&] { auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))); return tmp47; } ; auto tmp48 = [&] { if (tmp46.all_zero()) { return at::vec::Vectorized<float>(static_cast<float>(0.0)); } else { return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>()); } } () ; auto tmp49 = static_cast<float>(0.0); auto tmp50 = at::vec::Vectorized<float>(tmp49); auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>()); return tmp51; } ; auto tmp52 = tmp38 ? 
tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0)); auto tmp53 = static_cast<float>(0.0); auto tmp54 = at::vec::VecMask<float,1>::from(tmp38); auto tmp55 = at::vec::Vectorized<float>(tmp53); auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>()); auto tmp57 = at::vec::VecMask<float,1>::from(tmp2); auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>()); tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))); } #pragma omp simd simdlen(8) for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L)) { auto tmp0 = c10::convert<int64_t>(x1); auto tmp1 = static_cast<int64_t>(256); auto tmp2 = tmp0 < tmp1; auto tmp3 = [&] { auto tmp4 = c10::convert<int64_t>(x3); auto tmp5 = static_cast<int64_t>(257); auto tmp6 = tmp4 < tmp5; auto tmp7 = [&] { auto tmp8 = -std::numeric_limits<float>::infinity(); return tmp8; } ; auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0); auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp11 = static_cast<int64_t>(3); auto tmp12 = tmp10 < tmp11; auto tmp13 = [&] { auto tmp14 = tmp4 >= tmp1; auto tmp15 = [&] { auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp16; } ; auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0); auto tmp18 = static_cast<float>(0.0); auto tmp19 = tmp14 ? tmp17 : tmp18; return tmp19; } ; auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0); auto tmp21 = static_cast<float>(0.0); auto tmp22 = tmp12 ? tmp20 : tmp21; auto tmp23 = tmp6 ? tmp9 : tmp22; return tmp23; } ; auto tmp24 = tmp2 ? 
tmp3() : static_cast<decltype(tmp3())>(0.0); auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L)); auto tmp26 = static_cast<int64_t>(3); auto tmp27 = tmp25 < tmp26; auto tmp28 = [&] { auto tmp29 = c10::convert<int64_t>(x3); auto tmp30 = tmp29 >= tmp1; auto tmp31 = [&] { auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))]; return tmp32; } ; auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0); auto tmp34 = static_cast<float>(0.0); auto tmp35 = tmp30 ? tmp33 : tmp34; return tmp35; } ; auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0); auto tmp37 = static_cast<float>(0.0); auto tmp38 = tmp27 ? tmp36 : tmp37; auto tmp39 = tmp2 ? tmp24 : tmp38; out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39; } } } } } } } ''') ``` Pull Request resolved: #124921 Approved by: https://github.com/jgong5, https://github.com/jansel ghstack dependencies: #124597
swap_buffer
will make the cse_cache
unable to be shared inside/outside of the lambda function scope. For example,
tmp12
should not be created since it is the same as tmp8
. We make the
cse_cache
as a read-only cache inside the scope (because it is unsafe to expose the cache created inside the scope: the outside scope cannot use it). Test Plan
the
static_cast<int>(256)
will only occur once after this PR, since the inside scope can share the cse buffer with the outside scope. Before this PR,
After this PR,
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang