Skip to content

Conversation

@jiayisunx
Copy link
Collaborator

@jiayisunx jiayisunx commented Sep 17, 2025

Stack from ghstack (oldest at bottom):

Fix #151400.
Summary:
Optimize the heuristics of sum reduction, reduce the chunk size of cascade sum to improve numerical stability.
I ran the Inductor benchmark with this PR on CPU, and no performance regression is seen.

Example:
Take #151400 as an example:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor import config

config.fallback_random = True
torch.set_grad_enabled(False)
torch.manual_seed(0)


class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        vec = x.flatten()
        vec_one = torch.ones_like(vec)
        x = torch.outer(vec, vec_one)
        return torch.mean(x, dim=1)


model = Model()

x = torch.randn(3, 8, 64, 64)  # error will be amplified as the input tensor gets larger

inputs = [x]


def run_test(model, inputs, backend):
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    torch.manual_seed(0)
    output = model(*inputs)
    return output


output = run_test(model, inputs, 'eager')
c_output = run_test(model, inputs, 'inductor')
fp64 = run_test(model.to(dtype=torch.float64), [inputs[0].to(dtype=torch.float64)], 'eager')

print(torch.allclose(output, c_output, rtol=1e-3, atol=1e-3))
print(torch.max(torch.abs(c_output - output)))
print(torch._dynamo.utils.same(output, c_output, fp64))

logs:

  • Before
False
tensor(0.0052)
False
  • After
True
tensor(0.0004)
True

Generated code:

  • Before
cpp_fused_mean_mul_ones_like_view_0 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    #pragma omp parallel num_threads(240)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(98304L); x0+=static_cast<int64_t>(16L))
            {
                {
                    float tmp_acc0 = 0;
                    at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(98304L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                                auto tmp1 = static_cast<float>(1.0);
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 * tmp2;
                                tmp_acc0_vec = tmp_acc0_vec + tmp3;
                            }
                        }
                    }
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                    {
                        tmp_acc0_vec.store(out_ptr0 + static_cast<int64_t>(x0));
                    }
                }
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        auto tmp1 = static_cast<float>(98304.0);
                        auto tmp2 = at::vec::Vectorized<float>(tmp1);
                        auto tmp3 = tmp0 / tmp2;
                        tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
                    }
                }
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (3, 8, 64, 64), (32768, 4096, 64, 1))
        buf0 = empty_strided_cpu((98304, ), (1, ), torch.float32)
        buf1 = buf0; del buf0  # reuse
        # [Provenance debug handles] cpp_fused_mean_mul_ones_like_view_0:1
        cpp_fused_mean_mul_ones_like_view_0(buf1, arg0_1)
        del arg0_1
        return (buf1, )
  • After
cpp_fused_mean_mul_ones_like_view_0 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    #pragma omp parallel num_threads(240)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(98304L); x0+=static_cast<int64_t>(16L))
            {
                {
                    float tmp_acc0 = 0;
                    at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
                    at::vec::Vectorized<float> masked_tmp_acc0_vec = at::vec::Vectorized<float>(0);
                    CascadeSumHelper<float, 4096> scalar_cascade_helper0(static_cast<int64_t>(98304L));
                    CascadeSumHelper<at::vec::Vectorized<float>, 4096> cascade_helper0(static_cast<int64_t>(98304L));
                    CascadeSumHelper<at::vec::Vectorized<float>, 4096> masked_cascade_helper0(static_cast<int64_t>(0L));
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(98304L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                                auto tmp1 = static_cast<float>(1.0);
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 * tmp2;
                                tmp_acc0_vec = cascade_sum_combine(tmp3, &cascade_helper0);
                            }
                        }
                    }
                    tmp_acc0 = cascade_sum_final(&scalar_cascade_helper0);
                    tmp_acc0_vec = cascade_sum_final(&cascade_helper0);
                    masked_tmp_acc0_vec = cascade_sum_final(&masked_cascade_helper0);
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                    {
                        tmp_acc0_vec = tmp_acc0_vec + masked_tmp_acc0_vec;
                        tmp_acc0_vec.store(out_ptr0 + static_cast<int64_t>(x0));
                    }
                }
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        auto tmp1 = static_cast<float>(98304.0);
                        auto tmp2 = at::vec::Vectorized<float>(tmp1);
                        auto tmp3 = tmp0 / tmp2;
                        tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
                    }
                }
            }
        }
    }
}
''')


async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (3, 8, 64, 64), (32768, 4096, 64, 1))
        buf0 = empty_strided_cpu((98304, ), (1, ), torch.float32)
        buf1 = buf0; del buf0  # reuse
        # [Provenance debug handles] cpp_fused_mean_mul_ones_like_view_0:1
        cpp_fused_mean_mul_ones_like_view_0(buf1, arg0_1)
        del arg0_1
        return (buf1, )

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163144

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 638c751 with merge base fee1ac9 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jiayisunx added a commit that referenced this pull request Sep 17, 2025
ghstack-source-id: eb4a794
Pull Request resolved: #163144
@jiayisunx jiayisunx marked this pull request as draft September 17, 2025 05:37
[ghstack-poisoned]
@jiayisunx jiayisunx added release notes: inductor ciflow/trunk Trigger trunk jobs on your pull request labels Sep 17, 2025
jiayisunx added a commit that referenced this pull request Sep 17, 2025
ghstack-source-id: e88e7e6
Pull Request resolved: #163144
[ghstack-poisoned]
jiayisunx added a commit that referenced this pull request Sep 18, 2025
ghstack-source-id: 80a959d
Pull Request resolved: #163144
[ghstack-poisoned]
@jiayisunx jiayisunx requested review from CaoE and mingfeima September 19, 2025 01:51
jiayisunx added a commit that referenced this pull request Oct 11, 2025
ghstack-source-id: e917625
Pull Request resolved: #163144
[ghstack-poisoned]
jiayisunx added a commit that referenced this pull request Nov 4, 2025
ghstack-source-id: 6244e6d
Pull Request resolved: #163144
[ghstack-poisoned]
@jiayisunx jiayisunx requested a review from CaoE November 7, 2025 07:53
@jiayisunx jiayisunx requested a review from jansel November 10, 2025 07:11
@jiayisunx jiayisunx marked this pull request as ready for review November 10, 2025 08:12
@jiayisunx
Copy link
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants