-
Notifications
You must be signed in to change notification settings - Fork 24.9k
[inductor cpp] vectorize support for truediv #112234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112234
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure) As of commit 9d430e5 with merge base 29844ad ( FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@@ -455,6 +455,10 @@ def mul(a, b): | |||
def div(a, b): | |||
return f"{a} / {b}" | |||
|
|||
@staticmethod | |||
def truediv(a, b): | |||
return f"{a} / {b}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this be correct for integer types?
In C++ int/int is floordiv, while float/float is truediv.
Add a test case for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Integer types don't hit it. I added a test case for int div with a TODO to vectorize them.
Ops like group_norm has `ops.truediv` that doesn't have vectorization support yet. This PR adds the support. `test_group_norm_vec` Before: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L)) { auto tmp0 = in_ptr0[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))]; auto tmp1 = out_ptr0[static_cast<long>(x1 + (32L*x0))]; auto tmp3 = out_ptr1[static_cast<long>(x1 + (32L*x0))]; auto tmp10 = in_ptr1[static_cast<long>(x1)]; auto tmp12 = in_ptr2[static_cast<long>(x1)]; auto tmp2 = tmp0 - tmp1; auto tmp4 = 
c10::convert<float>(1024.0); auto tmp5 = tmp3 / tmp4; auto tmp6 = c10::convert<float>(1e-05); auto tmp7 = tmp5 + tmp6; auto tmp8 = 1 / std::sqrt(tmp7); auto tmp9 = decltype(tmp2)(tmp2 * tmp8); auto tmp11 = decltype(tmp9)(tmp9 * tmp10); auto tmp13 = tmp11 + tmp12; out_ptr2[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))] = tmp13; } } } } } } ``` After: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); auto tmp1 = 
at::vec::Vectorized<float>(static_cast<float>(out_ptr0[static_cast<long>(x1 + (32L*x0))])); auto tmp3 = at::vec::Vectorized<float>(static_cast<float>(out_ptr1[static_cast<long>(x1 + (32L*x0))])); auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(in_ptr1[static_cast<long>(x1)])); auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(in_ptr2[static_cast<long>(x1)])); auto tmp2 = tmp0 - tmp1; auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(1024.0)); auto tmp5 = tmp3 / tmp4; auto tmp6 = at::vec::Vectorized<float>(static_cast<float>(1e-05)); auto tmp7 = tmp5 + tmp6; auto tmp8 = tmp7.rsqrt(); auto tmp9 = tmp2 * tmp8; auto tmp11 = tmp9 * tmp10; auto tmp13 = tmp11 + tmp12; tmp13.store(out_ptr2 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); } } } } } } ``` cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
Ops like group_norm has `ops.truediv` that doesn't have vectorization support yet. This PR adds the support. `test_group_norm_vec` Before: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L)) { auto tmp0 = in_ptr0[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))]; auto tmp1 = out_ptr0[static_cast<long>(x1 + (32L*x0))]; auto tmp3 = out_ptr1[static_cast<long>(x1 + (32L*x0))]; auto tmp10 = in_ptr1[static_cast<long>(x1)]; auto tmp12 = in_ptr2[static_cast<long>(x1)]; auto tmp2 = tmp0 - tmp1; auto tmp4 = 
c10::convert<float>(1024.0); auto tmp5 = tmp3 / tmp4; auto tmp6 = c10::convert<float>(1e-05); auto tmp7 = tmp5 + tmp6; auto tmp8 = 1 / std::sqrt(tmp7); auto tmp9 = decltype(tmp2)(tmp2 * tmp8); auto tmp11 = decltype(tmp9)(tmp9 * tmp10); auto tmp13 = tmp11 + tmp12; out_ptr2[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))] = tmp13; } } } } } } ``` After: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); auto tmp1 = 
at::vec::Vectorized<float>(static_cast<float>(out_ptr0[static_cast<long>(x1 + (32L*x0))])); auto tmp3 = at::vec::Vectorized<float>(static_cast<float>(out_ptr1[static_cast<long>(x1 + (32L*x0))])); auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(in_ptr1[static_cast<long>(x1)])); auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(in_ptr2[static_cast<long>(x1)])); auto tmp2 = tmp0 - tmp1; auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(1024.0)); auto tmp5 = tmp3 / tmp4; auto tmp6 = at::vec::Vectorized<float>(static_cast<float>(1e-05)); auto tmp7 = tmp5 + tmp6; auto tmp8 = tmp7.rsqrt(); auto tmp9 = tmp2 * tmp8; auto tmp11 = tmp9 * tmp10; auto tmp13 = tmp11 + tmp12; tmp13.store(out_ptr2 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); } } } } } } ``` cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
@pytorchbot merge |
Merge failed. Reason: This PR needs a `release notes:` label. If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`. If not, please add the `topic: not user facing` label. To add a label, you can comment to pytorchbot, for example `@pytorchbot label "topic: not user facing"`. For more information, see the PyTorch wiki. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Pull Request resolved: #112243 Approved by: https://github.com/lezcano ghstack dependencies: #112234
Ops like group_norm has `ops.truediv` that doesn't have vectorization support yet. This PR adds the support. `test_group_norm_vec` Before: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L)) { auto tmp0 = in_ptr0[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))]; auto tmp1 = out_ptr0[static_cast<long>(x1 + (32L*x0))]; auto tmp3 = out_ptr1[static_cast<long>(x1 + (32L*x0))]; auto tmp10 = in_ptr1[static_cast<long>(x1)]; auto tmp12 = in_ptr2[static_cast<long>(x1)]; auto tmp2 = tmp0 - tmp1; auto tmp4 = 
c10::convert<float>(1024.0); auto tmp5 = tmp3 / tmp4; auto tmp6 = c10::convert<float>(1e-05); auto tmp7 = tmp5 + tmp6; auto tmp8 = 1 / std::sqrt(tmp7); auto tmp9 = decltype(tmp2)(tmp2 * tmp8); auto tmp11 = decltype(tmp9)(tmp9 * tmp10); auto tmp13 = tmp11 + tmp12; out_ptr2[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))] = tmp13; } } } } } } ``` After: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); auto tmp1 = 
at::vec::Vectorized<float>(static_cast<float>(out_ptr0[static_cast<long>(x1 + (32L*x0))])); auto tmp3 = at::vec::Vectorized<float>(static_cast<float>(out_ptr1[static_cast<long>(x1 + (32L*x0))])); auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(in_ptr1[static_cast<long>(x1)])); auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(in_ptr2[static_cast<long>(x1)])); auto tmp2 = tmp0 - tmp1; auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(1024.0)); auto tmp5 = tmp3 / tmp4; auto tmp6 = at::vec::Vectorized<float>(static_cast<float>(1e-05)); auto tmp7 = tmp5 + tmp6; auto tmp8 = tmp7.rsqrt(); auto tmp9 = tmp2 * tmp8; auto tmp11 = tmp9 * tmp10; auto tmp13 = tmp11 + tmp12; tmp13.store(out_ptr2 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); } } } } } } ``` Pull Request resolved: pytorch#112234 Approved by: https://github.com/lezcano, https://github.com/jansel
Pull Request resolved: pytorch#112243 Approved by: https://github.com/lezcano ghstack dependencies: pytorch#112234
Ops like group_norm has `ops.truediv` that doesn't have vectorization support yet. This PR adds the support. `test_group_norm_vec` Before: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { #pragma GCC ivdep for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(1L)) { auto tmp0 = in_ptr0[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))]; auto tmp1 = out_ptr0[static_cast<long>(x1 + (32L*x0))]; auto tmp3 = out_ptr1[static_cast<long>(x1 + (32L*x0))]; auto tmp10 = in_ptr1[static_cast<long>(x1)]; auto tmp12 = in_ptr2[static_cast<long>(x1)]; auto tmp2 = tmp0 - tmp1; auto tmp4 = 
c10::convert<float>(1024.0); auto tmp5 = tmp3 / tmp4; auto tmp6 = c10::convert<float>(1e-05); auto tmp7 = tmp5 + tmp6; auto tmp8 = 1 / std::sqrt(tmp7); auto tmp9 = decltype(tmp2)(tmp2 * tmp8); auto tmp11 = decltype(tmp9)(tmp9 * tmp10); auto tmp13 = tmp11 + tmp12; out_ptr2[static_cast<long>(x2 + (1024L*x1) + (32768L*x0))] = tmp13; } } } } } } ``` After: ```c++ extern "C" void kernel(const float* in_ptr0, const float* in_ptr1, const float* in_ptr2, float* out_ptr0, float* out_ptr1, float* out_ptr2) { #pragma omp parallel num_threads(64) { { #pragma omp for for(long x0=static_cast<long>(0L); x0<static_cast<long>(64L); x0+=static_cast<long>(1L)) { { #pragma omp declare reduction(welford:Welford<float>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<float>()}) #pragma omp declare reduction(welford:Welford<at::vec::Vectorized<float>>:omp_out = welford_combine(omp_out, omp_in)) initializer(omp_priv={Welford<at::vec::Vectorized<float>>()}) Welford<float> tmp_acc0 = Welford<float>(); Welford<at::vec::Vectorized<float>> tmp_acc0_vec = Welford<at::vec::Vectorized<float>>(); for(long x1=static_cast<long>(0L); x1<static_cast<long>(1024L); x1+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x1 + (1024L*x0))); tmp_acc0_vec = welford_combine(tmp_acc0_vec, tmp0); } tmp_acc0 = welford_combine(tmp_acc0, welford_vec_reduce_all(tmp_acc0_vec)); out_ptr0[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.mean); out_ptr1[static_cast<long>(x0)] = static_cast<float>(tmp_acc0.m2); } } } { #pragma omp for collapse(2) for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L)) { for(long x1=static_cast<long>(0L); x1<static_cast<long>(32L); x1+=static_cast<long>(1L)) { for(long x2=static_cast<long>(0L); x2<static_cast<long>(1024L); x2+=static_cast<long>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); auto tmp1 = 
at::vec::Vectorized<float>(static_cast<float>(out_ptr0[static_cast<long>(x1 + (32L*x0))])); auto tmp3 = at::vec::Vectorized<float>(static_cast<float>(out_ptr1[static_cast<long>(x1 + (32L*x0))])); auto tmp10 = at::vec::Vectorized<float>(static_cast<float>(in_ptr1[static_cast<long>(x1)])); auto tmp12 = at::vec::Vectorized<float>(static_cast<float>(in_ptr2[static_cast<long>(x1)])); auto tmp2 = tmp0 - tmp1; auto tmp4 = at::vec::Vectorized<float>(static_cast<float>(1024.0)); auto tmp5 = tmp3 / tmp4; auto tmp6 = at::vec::Vectorized<float>(static_cast<float>(1e-05)); auto tmp7 = tmp5 + tmp6; auto tmp8 = tmp7.rsqrt(); auto tmp9 = tmp2 * tmp8; auto tmp11 = tmp9 * tmp10; auto tmp13 = tmp11 + tmp12; tmp13.store(out_ptr2 + static_cast<long>(x2 + (1024L*x1) + (32768L*x0))); } } } } } } ``` Pull Request resolved: pytorch#112234 Approved by: https://github.com/lezcano, https://github.com/jansel
Pull Request resolved: pytorch#112243 Approved by: https://github.com/lezcano ghstack dependencies: pytorch#112234
Stack from ghstack (oldest at bottom):
Ops like group_norm have
ops.truediv
that doesn't have vectorization support yet. This PR adds the support. `test_group_norm_vec`
Before:
After:
cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler