Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[inductor cpp] vectorize support for truediv (#112234)
Ops like group_norm has `ops.truediv` that doesn't have vectorization support yet. This PR adds the support. `test_group_norm_vec` Before: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L)) { auto tmp0 = in_ptr0[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))]; auto tmp1 = out_ptr0[static_cast<long>(x1 + (32L*x0))]; auto tmp3 = out_ptr1[static_cast<long>(x1 + (32L*x0))]; auto tmp10 = in_ptr1[static_cast<long>(x1)]; auto tmp12 = in_ptr2[static_cast<long>(x1)]; auto tmp2 = tmp0 - tmp1; auto tmp4 = c10::convert<float>(1024.0); auto tmp5 = tmp3 / tmp4; auto tmp6 = c10::convert<float>(1e-05); auto tmp7 = tmp5 + tmp6; auto tmp8 = 1 / std::sqrt(tmp7); auto tmp9 = decltype(tmp2)(tmp2 * tmp8); auto tmp11 = decltype(tmp9)(tmp9 * tmp10); auto tmp13 = tmp11 + tmp12; out_ptr2[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))] = tmp13; } } } } } } ``` After: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); auto tmp1 = at::vec::Vectorized<float>(static_cast<float>(out_ptr0[static_cast<long>(x1 + (32L*x0))])); auto tmp3 = at::vec::Vectorized<float>(static_cast<float>(out_ptr1[static_cast<long>(x1 + (32L*x0))])); auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(in_ptr1[static_cast<long>(x1)])); auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(in_ptr2[static_cast<long>(x1)])); auto tmp2 = tmp0 - tmp1; auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(1024.0)); auto tmp5 = tmp3 / tmp4; auto tmp6 = at::vec::Vectorized<float>(static_cast<float>(1e-05)); auto tmp7 = tmp5 + tmp6; auto tmp8 = tmp7.rsqrt(); auto tmp9 = tmp2 * tmp8; auto tmp11 = tmp9 * tmp10; auto tmp13 = tmp11 + tmp12; tmp13.store(out_ptr2 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); } } } } } } ``` Pull Request resolved: #112234 Approved by: https://github.com/lezcano, https://github.com/jansel
- Loading branch information