Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] share more cse cache during swap buffer #124921

Closed
wants to merge 1 commit into from

Conversation

zhuhaozhe
Copy link
Contributor

@zhuhaozhe zhuhaozhe commented Apr 25, 2024

swap_buffer prevents the cse_cache from being shared between the inside and the outside of the lambda function scope.
For example,

auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}

tmp12 should not be created, since it is the same as tmp8.

We make the cse_cache a read-only cache inside the scope (because it is unsafe to expose cache entries created inside the scope, the outside scope cannot use them).

Test Plan

python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu

After this PR, static_cast<int>(256) occurs only once, since the inside scope can reuse the cse buffer from outside the scope.

Before this PR,

cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8) 
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')

After this PR,

cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8) 
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')

Stack from ghstack (oldest at bottom):

  • (to be filled)

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

Copy link

pytorch-bot bot commented Apr 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124921

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 03dd525 with merge base 59a1f1f (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zhuhaozhe zhuhaozhe marked this pull request as draft April 25, 2024 08:23
zhuhaozhe added a commit that referenced this pull request Apr 25, 2024
ghstack-source-id: 2cd2271ed9ecb25a3d92aaa00ff9537d32202535
Pull Request resolved: #124921
@zhuhaozhe zhuhaozhe requested a review from jansel April 26, 2024 04:45
@zhuhaozhe zhuhaozhe marked this pull request as ready for review April 26, 2024 04:47
@zhuhaozhe zhuhaozhe added the topic: not user facing topic category label Apr 28, 2024
@zhuhaozhe
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 28, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

andoorve pushed a commit to andoorve/pytorch that referenced this pull request May 1, 2024
`swap_buffer` prevents the `cse_cache` from being shared between the inside and the outside of the lambda function scope.
For example,

```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```
`tmp12` should not be created, since it is the same as `tmp8`.

We make the `cse_cache` a read-only cache inside the scope (because it is unsafe to expose cache entries created inside the scope, the outside scope cannot use them).

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
After this PR, `static_cast<int>(256)` occurs only once, since the inside scope can reuse the cse buffer from outside the scope.

Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: pytorch#124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: pytorch#124597
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
`swap_buffer` prevents the `cse_cache` from being shared between the inside and the outside of the lambda function scope.
For example,

```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```
`tmp12` should not be created, since it is the same as `tmp8`.

We make the outer `cse_cache` a read-only cache inside the scope (entries created inside the scope are not exposed to the outside scope, because it would be unsafe for the outside scope to use them).

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
After this PR, `static_cast<int>(256)` occurs only once, since the inside scope can reuse the CSE cache entries of the outside scope.

Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: #124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #124597
@github-actions github-actions bot deleted the gh/zhuhaozhe/27/head branch June 3, 2024 01:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

None yet

5 participants