[MPSInductor] Fix remainder implementation for int types #155891
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155891
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 13 Pending
As of commit 031425d with merge base 6020440.
NEW FAILURE - The following job has failed:
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Introduce `c10::metal::remainder` and call it from both inductor and eager implementation (ghstack-source-id: f27f2cb; Pull Request resolved: #155891; full message duplicated in the PR description below)
Introduce `c10::metal::remainder` and call it from both inductor and eager implementation (ghstack-source-id: cd4c8a4; Pull Request resolved: #155891; full message duplicated in the PR description below)
@pytorchbot merge -f "Lint + MPS are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use `-f` as last resort and instead consider `-i/--ignore-current` to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Introduce `c10::metal::remainder` and call it from both the inductor and eager implementations, with an integer specialization, which should make it much faster than before while remaining compliant with the Python way of rounding negative numbers. This also allows one to remove the complex type-detection logic from the MPS codegen and rely on the Metal (C++) type system to figure out the input and output types.
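For intuition, here is a minimal C++ sketch of the semantics a Python-compliant integer remainder needs to satisfy; the name `py_remainder` is hypothetical, and this is not the actual `c10::metal::remainder` implementation:

```cpp
// Hypothetical sketch of Python-style remainder for integer types.
// In C++/Metal, the built-in % takes the sign of the dividend
// (-7 % 5 == -2), while Python's % takes the sign of the divisor
// (-7 % 5 == 3). The fixup below shifts the result accordingly.
template <typename T>
T py_remainder(T a, T b) {
  T r = a % b;
  // If r is nonzero and its sign differs from b's, shift by one divisor.
  if (r != 0 && ((r < 0) != (b < 0))) {
    r += b;
  }
  return r;
}
// e.g. py_remainder(-7, 5) == 3, matching Python's -7 % 5
```

Avoiding the float round-trip through `metal::floor` (visible in the generated kernel below) is presumably where the speedup for integer types comes from, and keeping the result an integer is what makes it usable as an array subscript.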
This fixes compilation of something like
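```python
@torch.compile
def f(x, y):
    return x[y % 5]
```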
which beforehand failed to compile with
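```
torch._inductor.exc.InductorError: SyntaxError: failed to compile

#include <c10/metal/utils.h>
kernel void generated_kernel(
    device float* out_ptr0,
    constant long* in_ptr0,
    constant float* in_ptr1,
    uint xindex [[thread_position_in_grid]]
) {
    int x0 = xindex;
    auto tmp0 = in_ptr0[x0];
    auto tmp1 = 12;
    auto tmp2 = static_cast<float>(tmp0) - static_cast<float>(tmp1) * metal::floor(static_cast<float>(tmp0) / static_cast<float>(tmp1));
    auto tmp3 = 1024;
    auto tmp4 = static_cast<long>(tmp3);
    auto tmp5 = tmp2 + tmp4;
    auto tmp6 = tmp2 < 0;
    auto tmp7 = tmp6 ? tmp5 : tmp2;
    if ((tmp7 < 0) && (tmp7 > 1024)) return;
    auto tmp9 = in_ptr1[tmp7];
    out_ptr0[x0] = static_cast<float>(tmp9);
}

with program_source:372:28: error: array subscript is not an integer
    auto tmp9 = in_ptr1[tmp7];
                        ^~~~~
```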
This fixes `fail_to_compile` for the GPT2ForSequenceClassification HuggingFace model using `transformers==4.44.2`.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov