[inductor cpp] fix non-contiguous reduction store #113261
Conversation
Fix #113018

The reduction store in this case works on a non-contiguous buffer. Previously, we only did the scalar fallback for normal stores but not for reduction stores. This PR fixes that.

Before fix:

```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(39L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L); x1+=static_cast<long>(16L))
    {
        {
            #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
            float tmp_acc0 = -std::numeric_limits<float>::infinity();
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
            for(long x2=static_cast<long>(0L); x2<static_cast<long>(18L); x2+=static_cast<long>(1L))
            {
                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (17L*x2) + (306L*x0)));
                tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp0);
            }
            tmp_acc0_vec.store(out_ptr1 + static_cast<long>(x0 + (39L*x1))); // this is wrong since x0 is not the vector dim
        }
    }
    #pragma omp simd simdlen(8)
    for(long x1=static_cast<long>(16L); x1<static_cast<long>(17L); x1+=static_cast<long>(1L))
    {
        {
            float tmp_acc0 = -std::numeric_limits<float>::infinity();
            for(long x2=static_cast<long>(0L); x2<static_cast<long>(18L); x2+=static_cast<long>(1L))
            {
                auto tmp0 = in_ptr1[static_cast<long>(x1 + (17L*x2) + (306L*x0))];
                tmp_acc0 = max_propagate_nan(tmp_acc0, tmp0);
            }
            out_ptr1[static_cast<long>(x0 + (39L*x1))] = tmp_acc0;
        }
    }
}
```

After fix:

```c++
#pragma omp for
for(long x0=static_cast<long>(0L); x0<static_cast<long>(39L); x0+=static_cast<long>(1L))
{
    for(long x1=static_cast<long>(0L); x1<static_cast<long>(16L); x1+=static_cast<long>(16L))
    {
        {
            #pragma omp declare reduction(max:at::vec::Vectorized<float>:omp_out = at::vec::maximum(omp_out, omp_in)) initializer(omp_priv={at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity())})
            float tmp_acc0 = -std::numeric_limits<float>::infinity();
            at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());
            for(long x2=static_cast<long>(0L); x2<static_cast<long>(18L); x2+=static_cast<long>(1L))
            {
                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr1 + static_cast<long>(x1 + (17L*x2) + (306L*x0)));
                tmp_acc0_vec = at::vec::maximum(tmp_acc0_vec, tmp0);
            }
            {
                __at_align__ float tmpbuf[16*sizeof(float)/sizeof(float)];
                tmp_acc0_vec.store(tmpbuf);
                for (long x1_inner = 0; x1_inner < 16; x1_inner++)
                    out_ptr1[static_cast<long>(x0 + (39L*x1) + (39L*x1_inner))] = tmpbuf[x1_inner];
            }
        }
    }
    #pragma omp simd simdlen(8)
    for(long x1=static_cast<long>(16L); x1<static_cast<long>(17L); x1+=static_cast<long>(1L))
    {
        {
            float tmp_acc0 = -std::numeric_limits<float>::infinity();
            for(long x2=static_cast<long>(0L); x2<static_cast<long>(18L); x2+=static_cast<long>(1L))
            {
                auto tmp0 = in_ptr1[static_cast<long>(x1 + (17L*x2) + (306L*x0))];
                tmp_acc0 = max_propagate_nan(tmp_acc0, tmp0);
            }
            out_ptr1[static_cast<long>(x0 + (39L*x1))] = tmp_acc0;
        }
    }
}
```

cc @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler
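To make the bug concrete, here is a small standalone illustration (plain numpy, not inductor code; the shapes and indices are taken from the kernel above). The 16 reduction lanes belong at stride 39 in the flattened output, but the pre-fix vector store writes them at stride 1:

```python
import numpy as np

# 16 distinct lane values stand in for the accumulated vector tmp_acc0_vec.
lanes = np.arange(1, 17, dtype=np.float32)

out_bad = np.zeros(39 * 17, dtype=np.float32)   # out_ptr1, flattened
out_good = np.zeros(39 * 17, dtype=np.float32)
x0, x1 = 5, 0

# Before the fix: a plain vector store writes the lanes contiguously starting
# at out[x0 + 39*x1], clobbering slots that belong to other values of x0.
out_bad[x0 + 39 * x1 : x0 + 39 * x1 + 16] = lanes

# After the fix: lane i is scattered to its strided slot out[x0 + 39*(x1 + i)],
# mirroring the tmpbuf loop in the generated code.
for i in range(16):
    out_good[x0 + 39 * (x1 + i)] = lanes[i]

assert not np.array_equal(out_bad, out_good)
```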
The reviewed hunk:

```python
def get_vec_store_line(self, value, var, index, dtype):
    """
    Get a store line str that stores `value` into `var` at `index` of `dtype`.
    :param value: Vectorized type templatized on `dtype`.
    ...
```
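For context, a minimal sketch of the decision such a helper has to make (illustrative only: `vec_store_line` and `is_contiguous_in` are hypothetical names, and the real implementation in torch/_inductor/codegen/cpp.py handles more dtypes and index forms). Vectorize the store only when the index is contiguous in the vectorized loop variable; otherwise emit the scalar fallback through an aligned temporary:

```python
import sympy

def is_contiguous_in(index, var):
    """The store index is contiguous in `var` iff it advances by 1 per step."""
    return sympy.simplify(sympy.diff(index, var)) == 1

def vec_store_line(value, ptr, index, var, nelem=16):
    if is_contiguous_in(index, var):
        return f"{value}.store({ptr} + static_cast<long>({index}));"
    # Scalar fallback: spill the vector into an aligned temporary, then copy
    # each lane to its (non-contiguous) destination, as in the fixed kernel.
    stride = sympy.diff(index, var)
    inner = sympy.Symbol(f"{var}_inner")
    scattered = index + stride * inner
    return (
        f"{{ __at_align__ float tmpbuf[{nelem}]; {value}.store(tmpbuf); "
        f"for (long {inner} = 0; {inner} < {nelem}; {inner}++) "
        f"{ptr}[static_cast<long>({scattered})] = tmpbuf[{inner}]; }}"
    )

x0, x1 = sympy.symbols("x0 x1")
# Non-contiguous in x1 (stride 39), so the fallback path is taken.
print(vec_store_line("tmp_acc0_vec", "out_ptr1", x0 + 39 * x1, x1))
```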
Do you want to static assert that `value` is indeed of type `Vectorized<dtype>`, or nah?
I will check it in Python after introducing `CppCSEVariable`.
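For reference, a hedged sketch of what that Python-side check could look like; `CppCSEVariable` here is a stand-in dataclass, and the real class (introduced in a follow-up PR) has a richer interface:

```python
from dataclasses import dataclass

@dataclass
class CppCSEVariable:
    name: str
    is_vec: bool  # True when the C++ variable holds an at::vec::Vectorized value

    def __str__(self):
        return self.name

def get_vec_store_line(value, var, index, dtype):
    # Reject non-vectorized values up front in Python, instead of emitting a
    # C++ static_assert into the generated kernel.
    assert isinstance(value, CppCSEVariable) and value.is_vec, (
        f"expected a vectorized value for a vector store, got {value}"
    )
    ...
```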
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.