Fix error with int+SymBool #114828
Conversation
Fixes #104797 [ghstack-poisoned]
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/114828. Note: links to docs will display an error until the docs builds have completed.

✅ You can merge normally (1 unrelated failure). As of commit 2f4f346 with merge base f128616. Broken trunk: the following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
def promote(x):
    """Implements True+True=2, which works in python but not sympy"""
    if isinstance(x, SymBool):
        return SymInt(x.node.wrap_int(int(x)))
```
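For context, the docstring's claim relies on Python's standard bool-to-int promotion (`bool` is a subclass of `int`), which can be checked in plain Python; nothing here is PyTorch-specific:

```python
# Python's bool is a subclass of int, so True promotes to 1 in arithmetic.
# This is the behavior promote() emulates for symbolic bools.
a = True + True                         # bool + bool -> int
b = 1 + True                            # int + bool -> int
c = sum(x == -1 for x in [3, -1, -1])   # counting matches via promotion
```

The `sum(x == -1 for x in ...)` pattern is exactly the shape of the einops call in the linked issue, except that there the comparisons produce `SymBool` instead of `bool`.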
Calling `int(x)` guards on the value of `x`. It would be better to implement a casting method that casts a bool to an int without guarding, similar to `sym_float`.
- SymPy doesn't seem to support this
- All the usages of this pattern I found are immediately fed into conditionals, so we need to guard on it anyway
Regarding 1, we can work around it by using `Where` (pytorch/torch/utils/_sympy/functions.py, lines 151 to 163 in e6b3a8c):

```python
class Where(sympy.Function):
    """
    Good ol' ternary operator
    """

    nargs = (3,)

    @classmethod
    def eval(cls, c, p, q):
        if c == sympy.true:
            return p
        elif c == sympy.false:
            return q
```
Regarding 2, fair enough, although it's not an extremely unusual pattern to use `True` as `1` when doing things like modular computations. May be good to leave a comment.
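To make the `Where` workaround concrete, here is a minimal plain-Python sketch of its `eval` logic and of how a ternary could spell a bool-to-int cast without guarding. The helper names (`where_eval`, `bool_to_int`) are hypothetical illustrations, not the actual sympy or PyTorch code; symbolic true/false are modeled as Python bools, and an unknown condition is modeled by returning `None` (sympy would instead leave the expression unevaluated):

```python
# Sketch of Where's eval semantics (assumption: conditions are plain bools;
# None stands in for sympy leaving the expression unevaluated).
def where_eval(c, p, q):
    if c is True:
        return p
    if c is False:
        return q
    return None  # condition unknown: leave unevaluated

# A guard-free bool -> int cast could then be spelled Where(b, 1, 0):
def bool_to_int(b):
    return where_eval(b, 1, 0)
```

The point of the symbolic version is that `Where(b, 1, 0)` stays unevaluated when `b` is unknown, so no guard on the value of `b` is introduced.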
Since we guard on 0/1, wouldn't something like `sym_float` end up guarding anyway?
ah, good point :D
Inductor's codegen will likely choke on `sympy.Where`, so it might be extra work to support that. Since there isn't a real use case needing it, I'll leave it unimplemented for now.

@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki.

I added `Where` and I already implemented the relevant codegen. This is used for wrapping negative indices when we have indirect indexing. But sure, this is good as-is for now.
This reverts commit 4631206.
This is failing in trunk (https://hud.pytorch.org/pytorch/pytorch/commit/7b3429d97ce8cf4573977ab8cc904f22a83c8a66), updating the value after chatting with @jansel. Pull Request resolved: #114918. Approved by: https://github.com/jansel

Fixes pytorch#104797

```
File "/home/jansel/pytorch/torch/_dynamo/utils.py", line 1486, in <lambda>
  lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/home/jansel/pytorch/torch/_dynamo/utils.py", line 1591, in run_node
  raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
File "/home/jansel/pytorch/torch/_dynamo/utils.py", line 1570, in run_node
  return node.target(*args, **kwargs)
File "/home/jansel/conda/envs/pytorch/lib/python3.10/site-packages/einops/packing.py", line 153, in unpack
  n_unknown_composed_axes = sum(x == -1 for x in lengths_of_composed_axes)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function unpack at 0x7f644b962710>(*(FakeTensor(..., device='cuda:0', size=(1, s0*s1, 128)), [(s0, s1)], 'b * c'), **{}): unsupported operand type(s) for +: 'int' and 'SymBool'
```

Pull Request resolved: pytorch#114828. Approved by: https://github.com/lezcano
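The `TypeError` at the bottom of the traceback is the generic one Python raises when neither operand of `+` knows how to handle the other. A minimal reproduction of that error class, using a hypothetical `FakeSymBool` stand-in rather than the real `torch.SymBool`:

```python
# An object defining neither __add__ nor __radd__ against int triggers
# the same "unsupported operand type(s) for +" error seen in the traceback.
# FakeSymBool is an illustrative stand-in, not the real torch.SymBool.
class FakeSymBool:
    pass

def try_add():
    try:
        return 1 + FakeSymBool()
    except TypeError as e:
        return str(e)

err = try_add()
```

The fix in this PR is to promote such symbolic bools to symbolic ints before arithmetic, mirroring Python's own bool-to-int promotion.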
Stack from ghstack (oldest at bottom):
Fixes #104797
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng