Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] share cse cache during vectorized indirect load #124597

Closed
wants to merge 8 commits into from

Conversation

zhuhaozhe
Copy link
Collaborator

@zhuhaozhe zhuhaozhe commented Apr 22, 2024

Fix #123502

swap_buffer in not needed in vectorized indirect load, remove it to share cse buffer.

auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

Copy link

pytorch-bot bot commented Apr 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124597

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit dda79ae with merge base 59a1f1f (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zhuhaozhe added a commit that referenced this pull request Apr 22, 2024
ghstack-source-id: a11a76fdbe9b6eab89a881188ee66b405f4bc461
Pull Request resolved: #124597
@zhuhaozhe zhuhaozhe marked this pull request as draft April 22, 2024 06:48
@zhuhaozhe zhuhaozhe added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 22, 2024
[ghstack-poisoned]
zhuhaozhe added a commit that referenced this pull request Apr 22, 2024
ghstack-source-id: 08e123950741038a088517781a434629f122654a
Pull Request resolved: #124597
Fix #123502





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@zhuhaozhe zhuhaozhe changed the title also clone cse cache [inductor] allow clone cse cache during vectorized indirect load Apr 22, 2024
zhuhaozhe added a commit that referenced this pull request Apr 22, 2024
ghstack-source-id: 3d5eb12fdc823729df0a6b67f8bf042c5d6c9b73
Pull Request resolved: #124597
@zhuhaozhe zhuhaozhe requested a review from jgong5 April 22, 2024 13:12
…t load"

Fix #123502





cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@@ -1348,9 +1348,11 @@ def set_current_node(self, node):
self.current_node = prior

@contextlib.contextmanager
def swap_buffers(self, lb, cb=None, sb=None):
def swap_buffers(self, lb, cb=None, sb=None, cse_cache=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The "cse" object has cache, reduction_cache and store_cache. A single cse_cache cannot represent all of them. Maybe it is better to pass cse directly as to override the new "cse" after the buffers are swapped?

@@ -2318,7 +2318,7 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
assert vec_var.is_vec
code = BracesBuffer()
code.writeline("[&]")
with self.swap_buffers(code), code.indent():
with self.swap_buffers(code, cse_cache=self.cse.cache), code.indent():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When new variables are defined in the scope, they cannot be used outside the scope afterwards. So, it is not safe to reuse the existing self.cse.cache here. I guess we want to create a wrapper around the cache abstraction that supports scoping, allowing variables and expressions defined within a scope to be valid only within that scope while still leveraging the existing cache from the outer scope before it. How about we create something like below?

def scope_cse(cse):
    new_cse = cse.clone()
    new_cse.cache = ScopedDict(cse.cache)
    new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
    new_cse.store_cache = ScopedDict(cse.store_cache)
    return new_cse

Here the ScopedDict allows read-only access to the original dict while uses a separate dict to hold newly-added items:

class ScopedDict:
    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}
    
    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]
    
    def __setitem__(self, key, value):
        self.new_items[key] = value
    
    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict
    
    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When new variables are defined in the scope, they cannot be used outside the scope afterwards. So, it is not safe to reuse the existing self.cse.cache here. I guess we want to create a wrapper around the cache abstraction that supports scoping, allowing variables and expressions defined within a scope to be valid only within that scope while still leveraging the existing cache from the outer scope before it. How about we create something like below?

def scope_cse(cse):
    new_cse = cse.clone()
    new_cse.cache = ScopedDict(cse.cache)
    new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
    new_cse.store_cache = ScopedDict(cse.store_cache)
    return new_cse

Here the ScopedDict allows read-only access to the original dict while uses a separate dict to hold newly-added items:

class ScopedDict:
    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}
    
    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]
    
    def __setitem__(self, key, value):
        self.new_items[key] = value
    
    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict
    
    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)

Sounds better, I will try this way.

zhuhaozhe added a commit that referenced this pull request Apr 24, 2024
ghstack-source-id: 616a56d2341cd20ecca38d70e130b90e8f67bb3b
Pull Request resolved: #124597
…t load"

Fix #123502

`swap_buffer` do not clone the `cse.cache` which will bring redundant computation.
We may able to clone the `cse.cache` if there is no cse value in the `expr`
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
zhuhaozhe added a commit that referenced this pull request Apr 24, 2024
ghstack-source-id: f0bfb65da321fa7b02468e5980682e9e3f9d2160
Pull Request resolved: #124597
…t load"

Fix #123502

`swap_buffer` do not clone the `cse.cache` which will bring redundant computation.
We may able to clone the `cse.cache` if there is no cse value in the `expr`
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
zhuhaozhe added a commit that referenced this pull request Apr 25, 2024
ghstack-source-id: 2f026a9472e118acc783edc57aefbc9db4e4d582
Pull Request resolved: #124597
…t load"

Fix #123502

`swap_buffer` do not clone the `cse.cache` which will bring redundant computation.
We may able to clone the `cse.cache` if there is no cse value in the `expr`
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
zhuhaozhe added a commit that referenced this pull request Apr 25, 2024
ghstack-source-id: 0bd0a3852d71d986f37e0acd337e64751c3fa678
Pull Request resolved: #124597
…t load"

Fix #123502

`swap_buffer` do not clone the `cse.cache` which will bring redundant computation.
We may able to clone the `cse.cache` if there is no cse value in the `expr`
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
…t load"

Fix #123502

`swap_buffer` do not clone the `cse.cache` which will bring redundant computation.
We may able to clone the `cse.cache` if there is no cse value in the `expr`
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@zhuhaozhe
Copy link
Collaborator Author

zhuhaozhe commented Apr 25, 2024

Hi, @jgong5. For this issue #123502, your recommendation for ScopedDict should not work because

            auto tmp8 =
            [&]
            {
                __at_align__ std::array<int64_t, 16> tmpbuf;
                tmp7.store(tmpbuf.data());
                return tmpbuf;
            }
            ()
            ;
            auto tmp9 =
            [&]
            {
                __at_align__ std::array<float, 16> tmpbuf;
                #pragma GCC unroll 16
                for (long x0_inner = 0; x0_inner < 16; x0_inner++)
                {
                    tmpbuf[x0_inner] = out_ptr0[static_cast<long>(tmp8[x0_inner])];
                }
                return at::vec::Vectorized<float>::loadu(tmpbuf.data());
            }
            ()
            ;

the tmp8 is created inside this scope here. But actually I observed the tmp8 should be able to be used outside because the buffer here in csevar = self.cse.generate(buffer, code) is not inside the lambda function (of tmp9) scope.

For my current understanding we may not need to always put a swap_buffer after a [&], for _load_or_store_non_contiguous ,should it be safe to remove swap buffer here?

@zhuhaozhe
Copy link
Collaborator Author

zhuhaozhe commented Apr 25, 2024

As for the ScopedDict, it indeed help.
I found a case test_AllenaiLongformerBase_repro_cpu and make it as an UT in another PR #124921.
It will only generate tmp_x=static_cast<int>(256) once after ScopedDict impl, and it will be 3 times before it.

@zhuhaozhe zhuhaozhe requested a review from jgong5 April 25, 2024 11:38
@zhuhaozhe zhuhaozhe requested a review from jansel April 26, 2024 04:45
@zhuhaozhe zhuhaozhe marked this pull request as ready for review April 26, 2024 04:48
@zhuhaozhe zhuhaozhe changed the title [inductor] allow clone cse cache during vectorized indirect load [inductor] share cse cache during vectorized indirect load Apr 26, 2024
@zhuhaozhe
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Apr 28, 2024
`swap_buffer` will make the `cse_cache` cannot be shared inside/outside of the lambda function scope.
For example,

```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```
`tmp12` should not be created since it is same with `tmp8`.

We make the `cse_cache` as a read only cache inside the scope (because it is unsafe to expose cache inside the scope,the outside scope cannot use it.)

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
the `static_cast<int>(256)` will only occur once after this PR since the inside scope can share the cse buffer outside the scope.

Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: #124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #124597
carmocca pushed a commit to carmocca/pytorch that referenced this pull request Apr 29, 2024
…24597)

Fix pytorch#123502

`swap_buffer` in not needed in vectorized indirect load, remove it to share cse buffer.
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```

Pull Request resolved: pytorch#124597
Approved by: https://github.com/jgong5, https://github.com/jansel
carmocca pushed a commit to carmocca/pytorch that referenced this pull request Apr 29, 2024
`swap_buffer` will make the `cse_cache` cannot be shared inside/outside of the lambda function scope.
For example,

```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```
`tmp12` should not be created since it is same with `tmp8`.

We make the `cse_cache` as a read only cache inside the scope (because it is unsafe to expose cache inside the scope,the outside scope cannot use it.)

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
the `static_cast<int>(256)` will only occur once after this PR since the inside scope can share the cse buffer outside the scope.

Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: pytorch#124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: pytorch#124597
andoorve pushed a commit to andoorve/pytorch that referenced this pull request May 1, 2024
…24597)

Fix pytorch#123502

`swap_buffer` in not needed in vectorized indirect load, remove it to share cse buffer.
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```

Pull Request resolved: pytorch#124597
Approved by: https://github.com/jgong5, https://github.com/jansel
andoorve pushed a commit to andoorve/pytorch that referenced this pull request May 1, 2024
`swap_buffer` will make the `cse_cache` cannot be shared inside/outside of the lambda function scope.
For example,

```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```
`tmp12` should not be created since it is same with `tmp8`.

We make the `cse_cache` as a read only cache inside the scope (because it is unsafe to expose cache inside the scope,the outside scope cannot use it.)

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
the `static_cast<int>(256)` will only occur once after this PR since the inside scope can share the cse buffer outside the scope.

Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: pytorch#124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: pytorch#124597
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
…24597)

Fix pytorch#123502

`swap_buffer` in not needed in vectorized indirect load, remove it to share cse buffer.
```
auto tmp8 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
//
// other codes
//
// also store tmp7 here (redundant tmp16)
auto tmp16 =
[&]
{
    __at_align__ std::array<int64_t, 16> tmpbuf;
    tmp7.store(tmpbuf.data());
    return tmpbuf;
}
()
;
```

Pull Request resolved: pytorch#124597
Approved by: https://github.com/jgong5, https://github.com/jansel
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
`swap_buffer` will make the `cse_cache` cannot be shared inside/outside of the lambda function scope.
For example,

```
auto tmp8 = -std::numeric_limits<float>::infinity();
auto tmp9 = [&]
{
    auto tmp12 = -std::numeric_limits<float>::infinity();
    return tmp12;
}
```
`tmp12` should not be created since it is same with `tmp8`.

We make the `cse_cache` as a read only cache inside the scope (because it is unsafe to expose cache inside the scope,the outside scope cannot use it.)

**Test Plan**
```
python test/inductor/test_torchinductor.py -k test_AllenaiLongformerBase_repro_cpu
```
the `static_cast<int>(256)` will only occur once after this PR since the inside scope can share the cse buffer outside the scope.

Before this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = c10::convert<int>(x3);
                                    auto tmp20 = at::vec::Vectorized<int>::arange(tmp19, 1);
                                    auto tmp21 = static_cast<int>(256);
                                    auto tmp22 = at::vec::Vectorized<int>(tmp21);
                                    auto tmp23 = at::vec::VecMask<int,1>(tmp20 >= tmp22);
                                    auto tmp25 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp26 = tmp23 & tmp25;
                                    auto tmp24 = [&]
                                    {
                                        auto tmp27 = tmp26.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp27;
                                    }
                                    ;
                                    auto tmp28 =
                                    [&]
                                    {
                                        if (tmp26.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp24())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp24(), tmp26.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp29 = static_cast<float>(0.0);
                                    auto tmp30 = at::vec::Vectorized<float>(tmp29);
                                    auto tmp31 = decltype(tmp28)::blendv(tmp30, tmp28, tmp23.template cast<float,1>());
                                    return tmp31;
                                }
                                ;
                                auto tmp32 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp33 = static_cast<float>(0.0);
                                auto tmp34 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp35 = at::vec::Vectorized<float>(tmp33);
                                auto tmp36 = decltype(tmp32)::blendv(tmp35, tmp32, tmp34.template cast<float,1>());
                                auto tmp37 = decltype(tmp13)::blendv(tmp36, tmp13, tmp8.template cast<float,1>());
                                return tmp37;
                            }
                            ;
                            auto tmp38 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp39 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp40 = static_cast<int>(3);
                            auto tmp41 = tmp39 < tmp40;
                            auto tmp42 = [&]
                            {
                                auto tmp43 = c10::convert<int>(x3);
                                auto tmp44 = at::vec::Vectorized<int>::arange(tmp43, 1);
                                auto tmp45 = static_cast<int>(256);
                                auto tmp46 = at::vec::Vectorized<int>(tmp45);
                                auto tmp47 = at::vec::VecMask<int,1>(tmp44 >= tmp46);
                                auto tmp49 = at::vec::VecMask<float,1>::from(tmp41);
                                auto tmp50 = tmp47 & tmp49;
                                auto tmp48 = [&]
                                {
                                    auto tmp51 = tmp50.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp51;
                                }
                                ;
                                auto tmp52 =
                                [&]
                                {
                                    if (tmp50.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp48())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp48(), tmp50.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp53 = static_cast<float>(0.0);
                                auto tmp54 = at::vec::Vectorized<float>(tmp53);
                                auto tmp55 = decltype(tmp52)::blendv(tmp54, tmp52, tmp47.template cast<float,1>());
                                return tmp55;
                            }
                            ;
                            auto tmp56 = tmp41 ? tmp42() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp57 = static_cast<float>(0.0);
                            auto tmp58 = at::vec::VecMask<float,1>::from(tmp41);
                            auto tmp59 = at::vec::Vectorized<float>(tmp57);
                            auto tmp60 = decltype(tmp56)::blendv(tmp59, tmp56, tmp58.template cast<float,1>());
                            auto tmp61 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp62 = decltype(tmp38)::blendv(tmp60, tmp38, tmp61.template cast<float,1>());
                            tmp62.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = c10::convert<int64_t>(x3);
                                    auto tmp15 = static_cast<int64_t>(256);
                                    auto tmp16 = tmp14 >= tmp15;
                                    auto tmp17 = [&]
                                    {
                                        auto tmp18 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp18;
                                    }
                                    ;
                                    auto tmp19 = tmp16 ? tmp17() : static_cast<decltype(tmp17())>(0.0);
                                    auto tmp20 = static_cast<float>(0.0);
                                    auto tmp21 = tmp16 ? tmp19 : tmp20;
                                    return tmp21;
                                }
                                ;
                                auto tmp22 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp23 = static_cast<float>(0.0);
                                auto tmp24 = tmp12 ? tmp22 : tmp23;
                                auto tmp25 = tmp6 ? tmp9 : tmp24;
                                return tmp25;
                            }
                            ;
                            auto tmp26 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp27 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp28 = static_cast<int64_t>(3);
                            auto tmp29 = tmp27 < tmp28;
                            auto tmp30 = [&]
                            {
                                auto tmp31 = c10::convert<int64_t>(x3);
                                auto tmp32 = static_cast<int64_t>(256);
                                auto tmp33 = tmp31 >= tmp32;
                                auto tmp34 = [&]
                                {
                                    auto tmp35 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp35;
                                }
                                ;
                                auto tmp36 = tmp33 ? tmp34() : static_cast<decltype(tmp34())>(0.0);
                                auto tmp37 = static_cast<float>(0.0);
                                auto tmp38 = tmp33 ? tmp36 : tmp37;
                                return tmp38;
                            }
                            ;
                            auto tmp39 = tmp29 ? tmp30() : static_cast<decltype(tmp30())>(0.0);
                            auto tmp40 = static_cast<float>(0.0);
                            auto tmp41 = tmp29 ? tmp39 : tmp40;
                            auto tmp42 = tmp2 ? tmp26 : tmp41;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp42;
                        }
                    }
                }
            }
        }
    }
}
''')
```
After this PR,
```
cpp_fused_copy_full_like_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_root/ub/cub6x5nmhqhp7xapkb3dlgjxef3t2bnkx7y7n4z2f4z5obnecxpy.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr1)
{
    #pragma omp parallel num_threads(128)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(4L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(12L); x2+=static_cast<long>(1L))
                    {
                        for(long x3=static_cast<long>(0L); x3<static_cast<long>(512L); x3+=static_cast<long>(16L))
                        {
                            auto tmp0 = c10::convert<int>(x1);
                            auto tmp1 = static_cast<int>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int>(x3);
                                auto tmp5 = at::vec::Vectorized<int>::arange(tmp4, 1);
                                auto tmp6 = static_cast<int>(257);
                                auto tmp7 = at::vec::Vectorized<int>(tmp6);
                                auto tmp8 = at::vec::VecMask<int,1>(tmp5 < tmp7);
                                auto tmp10 = at::vec::VecMask<float,1>::from(tmp2);
                                auto tmp11 = tmp8 & tmp10;
                                auto tmp9 = [&]
                                {
                                    auto tmp12 = -std::numeric_limits<float>::infinity();
                                    return tmp12;
                                }
                                ;
                                auto tmp13 =
                                [&]
                                {
                                    if (tmp11.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(at::vec::Vectorized<float>(tmp9()))::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), at::vec::Vectorized<float>(tmp9()), tmp11.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp14 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                                auto tmp15 = static_cast<int>(3);
                                auto tmp16 = tmp14 < tmp15;
                                auto tmp18 = tmp16 & tmp2;
                                auto tmp17 = [&]
                                {
                                    auto tmp19 = at::vec::Vectorized<int>(tmp1);
                                    auto tmp20 = at::vec::VecMask<int,1>(tmp5 >= tmp19);
                                    auto tmp22 = at::vec::VecMask<float,1>::from(tmp18);
                                    auto tmp23 = tmp20 & tmp22;
                                    auto tmp21 = [&]
                                    {
                                        auto tmp24 = tmp23.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                        return tmp24;
                                    }
                                    ;
                                    auto tmp25 =
                                    [&]
                                    {
                                        if (tmp23.all_zero())
                                        {
                                            return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                        }
                                        else
                                        {
                                            return decltype(tmp21())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp21(), tmp23.template cast<float,1>());
                                        }
                                    }
                                    ()
                                    ;
                                    auto tmp26 = static_cast<float>(0.0);
                                    auto tmp27 = at::vec::Vectorized<float>(tmp26);
                                    auto tmp28 = decltype(tmp25)::blendv(tmp27, tmp25, tmp20.template cast<float,1>());
                                    return tmp28;
                                }
                                ;
                                auto tmp29 = tmp16 ? tmp17() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                                auto tmp30 = static_cast<float>(0.0);
                                auto tmp31 = at::vec::VecMask<float,1>::from(tmp16);
                                auto tmp32 = at::vec::Vectorized<float>(tmp30);
                                auto tmp33 = decltype(tmp29)::blendv(tmp32, tmp29, tmp31.template cast<float,1>());
                                auto tmp34 = decltype(tmp13)::blendv(tmp33, tmp13, tmp8.template cast<float,1>());
                                return tmp34;
                            }
                            ;
                            auto tmp35 = tmp2 ? tmp3() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp36 = c10::convert<int>(c10::div_floor_integer(x1, 256L));
                            auto tmp37 = static_cast<int>(3);
                            auto tmp38 = tmp36 < tmp37;
                            auto tmp39 = [&]
                            {
                                auto tmp40 = c10::convert<int>(x3);
                                auto tmp41 = at::vec::Vectorized<int>::arange(tmp40, 1);
                                auto tmp42 = at::vec::Vectorized<int>(tmp1);
                                auto tmp43 = at::vec::VecMask<int,1>(tmp41 >= tmp42);
                                auto tmp45 = at::vec::VecMask<float,1>::from(tmp38);
                                auto tmp46 = tmp43 & tmp45;
                                auto tmp44 = [&]
                                {
                                    auto tmp47 = tmp46.template cast<float,1>().template loadu<float,1>(in_ptr0 + static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0)));
                                    return tmp47;
                                }
                                ;
                                auto tmp48 =
                                [&]
                                {
                                    if (tmp46.all_zero())
                                    {
                                        return at::vec::Vectorized<float>(static_cast<float>(0.0));
                                    }
                                    else
                                    {
                                        return decltype(tmp44())::blendv(at::vec::Vectorized<float>(static_cast<float>(0.0)), tmp44(), tmp46.template cast<float,1>());
                                    }
                                }
                                ()
                                ;
                                auto tmp49 = static_cast<float>(0.0);
                                auto tmp50 = at::vec::Vectorized<float>(tmp49);
                                auto tmp51 = decltype(tmp48)::blendv(tmp50, tmp48, tmp43.template cast<float,1>());
                                return tmp51;
                            }
                            ;
                            auto tmp52 = tmp38 ? tmp39() : at::vec::Vectorized<float>(static_cast<float>(0.0));
                            auto tmp53 = static_cast<float>(0.0);
                            auto tmp54 = at::vec::VecMask<float,1>::from(tmp38);
                            auto tmp55 = at::vec::Vectorized<float>(tmp53);
                            auto tmp56 = decltype(tmp52)::blendv(tmp55, tmp52, tmp54.template cast<float,1>());
                            auto tmp57 = at::vec::VecMask<float,1>::from(tmp2);
                            auto tmp58 = decltype(tmp35)::blendv(tmp56, tmp35, tmp57.template cast<float,1>());
                            tmp58.store(out_ptr1 + static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0)));
                        }
                        #pragma omp simd simdlen(8)
                        for(long x3=static_cast<long>(512L); x3<static_cast<long>(513L); x3+=static_cast<long>(1L))
                        {
                            auto tmp0 = c10::convert<int64_t>(x1);
                            auto tmp1 = static_cast<int64_t>(256);
                            auto tmp2 = tmp0 < tmp1;
                            auto tmp3 = [&]
                            {
                                auto tmp4 = c10::convert<int64_t>(x3);
                                auto tmp5 = static_cast<int64_t>(257);
                                auto tmp6 = tmp4 < tmp5;
                                auto tmp7 = [&]
                                {
                                    auto tmp8 = -std::numeric_limits<float>::infinity();
                                    return tmp8;
                                }
                                ;
                                auto tmp9 = tmp6 ? tmp7() : static_cast<decltype(tmp7())>(0.0);
                                auto tmp10 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                                auto tmp11 = static_cast<int64_t>(3);
                                auto tmp12 = tmp10 < tmp11;
                                auto tmp13 = [&]
                                {
                                    auto tmp14 = tmp4 >= tmp1;
                                    auto tmp15 = [&]
                                    {
                                        auto tmp16 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                        return tmp16;
                                    }
                                    ;
                                    auto tmp17 = tmp14 ? tmp15() : static_cast<decltype(tmp15())>(0.0);
                                    auto tmp18 = static_cast<float>(0.0);
                                    auto tmp19 = tmp14 ? tmp17 : tmp18;
                                    return tmp19;
                                }
                                ;
                                auto tmp20 = tmp12 ? tmp13() : static_cast<decltype(tmp13())>(0.0);
                                auto tmp21 = static_cast<float>(0.0);
                                auto tmp22 = tmp12 ? tmp20 : tmp21;
                                auto tmp23 = tmp6 ? tmp9 : tmp22;
                                return tmp23;
                            }
                            ;
                            auto tmp24 = tmp2 ? tmp3() : static_cast<decltype(tmp3())>(0.0);
                            auto tmp25 = c10::convert<int64_t>(c10::div_floor_integer(x1, 256L));
                            auto tmp26 = static_cast<int64_t>(3);
                            auto tmp27 = tmp25 < tmp26;
                            auto tmp28 = [&]
                            {
                                auto tmp29 = c10::convert<int64_t>(x3);
                                auto tmp30 = tmp29 >= tmp1;
                                auto tmp31 = [&]
                                {
                                    auto tmp32 = in_ptr0[static_cast<long>((-256L) + x3 + (513L*(static_cast<long>(x1) % static_cast<long>(256L))) + (262656L*(c10::div_floor_integer(x1, 256L))) + (787968L*x2) + (9455616L*x0))];
                                    return tmp32;
                                }
                                ;
                                auto tmp33 = tmp30 ? tmp31() : static_cast<decltype(tmp31())>(0.0);
                                auto tmp34 = static_cast<float>(0.0);
                                auto tmp35 = tmp30 ? tmp33 : tmp34;
                                return tmp35;
                            }
                            ;
                            auto tmp36 = tmp27 ? tmp28() : static_cast<decltype(tmp28())>(0.0);
                            auto tmp37 = static_cast<float>(0.0);
                            auto tmp38 = tmp27 ? tmp36 : tmp37;
                            auto tmp39 = tmp2 ? tmp24 : tmp38;
                            out_ptr1[static_cast<long>(x3 + (513L*x1) + (525312L*x2) + (6303744L*x0))] = tmp39;
                        }
                    }
                }
            }
        }
    }
}
''')
```

Pull Request resolved: pytorch#124921
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: pytorch#124597
@github-actions github-actions bot deleted the gh/zhuhaozhe/22/head branch June 3, 2024 01:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

None yet

5 participants