Skip to content

Commit

Permalink
[inductor cpp] vectorize support for truediv (#112234)
Browse files Browse the repository at this point in the history
Ops like group_norm has `ops.truediv` that doesn't have vectorization support yet. This PR adds the support.

`test_group_norm_vec`
Before:
```c++
extern "C" void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    #pragma omp parallel num_threads(64)
    {
        {
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L))
            {
                {
                    #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()})
                    #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()})
                    Welford<float> tmp_acc0 = Welford<float>();
                    Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                    for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0)));
                        tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0);
                    }
                    tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                    out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean);
                    out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2);
                }
            }
        }
        {
            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L))
                {
                    #pragma GCC ivdep
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L))
                    {
                        auto tmp0 = in_ptr0[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))];
                        auto tmp1 = out_ptr0[static_cast<long>(x1 + (32L*x0))];
                        auto tmp3 = out_ptr1[static_cast<long>(x1 + (32L*x0))];
                        auto tmp10 = in_ptr1[static_cast<long>(x1)];
                        auto tmp12 = in_ptr2[static_cast<long>(x1)];
                        auto tmp2 = tmp0 - tmp1;
                        auto tmp4 = c10::convert<float>(1024.0);
                        auto tmp5 = tmp3 / tmp4;
                        auto tmp6 = c10::convert<float>(1e-05);
                        auto tmp7 = tmp5 + tmp6;
                        auto tmp8 = 1 / std::sqrt(tmp7);
                        auto tmp9 = decltype(tmp2)(tmp2 * tmp8);
                        auto tmp11 = decltype(tmp9)(tmp9 * tmp10);
                        auto tmp13 = tmp11 + tmp12;
                        out_ptr2[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))] = tmp13;
                    }
                }
            }
        }
    }
}
```

After:
```c++
extern "C" void kernel(const float* in_ptr0,
                       const float* in_ptr1,
                       const float* in_ptr2,
                       float* out_ptr0,
                       float* out_ptr1,
                       float* out_ptr2)
{
    #pragma omp parallel num_threads(64)
    {
        {
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L))
            {
                {
                    #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()})
                    #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()})
                    Welford<float> tmp_acc0 = Welford<float>();
                    Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>();
                    for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0)));
                        tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0);
                    }
                    tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec));
                    out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean);
                    out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2);
                }
            }
        }
        {
            #pragma omp for  collapse(2)
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L))
                {
                    for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(16L))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0)));
                        auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(out_ptr0[static_cast<long>(x1 + (32L*x0))]));
                        auto tmp3 = at::vec::Vectorized<float>(static_cast<float>(out_ptr1[static_cast<long>(x1 + (32L*x0))]));
                        auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(in_ptr1[static_cast<long>(x1)]));
                        auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(in_ptr2[static_cast<long>(x1)]));
                        auto tmp2 = tmp0 - tmp1;
                        auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(1024.0));
                        auto tmp5 = tmp3 / tmp4;
                        auto tmp6 = at::vec::Vectorized<float>(static_cast<float>(1e-05));
                        auto tmp7 = tmp5 + tmp6;
                        auto tmp8 = tmp7.rsqrt();
                        auto tmp9 = tmp2 * tmp8;
                        auto tmp11 = tmp9 * tmp10;
                        auto tmp13 = tmp11 + tmp12;
                        tmp13.store(out_ptr2 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0)));
                    }
                }
            }
        }
    }
}
```

Pull Request resolved: #112234
Approved by: https://github.com/lezcano, https://github.com/jansel
  • Loading branch information
jgong5 authored and pytorchmergebot committed Oct 31, 2023
1 parent b91fcdf commit a1c56df
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
30 changes: 30 additions & 0 deletions test/inductor/test_cpu_repro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2534,6 +2534,36 @@ def forward(self, x):
)
assert metrics.generated_kernel_count == 0

def test_group_norm_vec(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.group_norm = torch.nn.GroupNorm(32, 32)

def forward(self, x):
return self.group_norm(x)

metrics.reset()
mod = M().eval()
x = torch.randn(2, 32, 32, 32)
with torch.no_grad():
self.common(mod, (x,))
# 2 generated kernels (one for var_mean, the other for result)
assert metrics.generated_cpp_vec_kernel_count == 2

def test_int_div_vec(self):
def fn(x, y, mode):
return torch.div(x, y, rounding_mode=mode)

x = torch.randint(1, 100, (32, 32))
y = torch.randint(1, 100, (32, 32))
for mode in [None, "trunc", "floor"]:
with torch.no_grad():
metrics.reset()
self.common(fn, (x, y, mode))
# TODO: support vectorization for int div
assert metrics.generated_cpp_vec_kernel_count == 0


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
Expand Down
4 changes: 4 additions & 0 deletions torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,10 @@ def mul(a, b):
def div(a, b):
return f"{a} / {b}"

@staticmethod
def truediv(a, b):
return f"{a} / {b}"

@staticmethod
def abs(x):
return f"{x}.abs()"
Expand Down

0 comments on commit a1c56df

Please sign in to comment.