-
Notifications
You must be signed in to change notification settings - Fork 25.2k
[WIP] rewrite should_swap #149215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] rewrite should_swap #149215
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149215
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 2 Unrelated FailuresAs of commit c27fd1a with merge base b238e36 ( NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@@ -1247,6 +1247,12 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool: | |||
return x | |||
|
|||
|
|||
def _guard_semantics(x: Union[bool, SymBool]) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is basically guard_or_false?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep, let me change the name later
torch/_prims_common/__init__.py
Outdated
if _guard_semantics(b == 0): | ||
return True | ||
|
||
expr = (a.node.expr if isinstance(a, torch.SymInt) else a) // (b.node.expr if isinstance(b, torch.SymInt) else b) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not a//b am i missing something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I found I had to look at the inner sympy expression to see the expression structure (to figure out if there was a remaining //
anyways), so this was easier to work with.
@@ -504,23 +505,34 @@ def compute_elementwise_output_logical_to_physical_perm( | |||
shape = tensors[0].shape | |||
|
|||
def should_swap(idx_a, idx_b): | |||
def gte(a, b): | |||
# semantics for a >= b, assuming a != b, a >= 0, b >= 0 | |||
if _guard_semantics(a == 0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a comments maybe
# if we know that a is 0 or b is 0 then we know the answer.
torch/_prims_common/__init__.py
Outdated
return True | ||
|
||
expr = (a.node.expr if isinstance(a, torch.SymInt) else a) // (b.node.expr if isinstance(b, torch.SymInt) else b) | ||
if isinstance(expr, (int, sympy.Integer)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i see here it sounds like you want
try:
return expr>=1
except DataDepenedentErrpr:
return not isinstance(expr, sympy.floor)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe guard_or_else(expr, lambda) where you pass a continuation lambda to be called in case of dde?
I needed that in several occasions
torch/_prims_common/__init__.py
Outdated
expr = (a.node.expr if isinstance(a, torch.SymInt) else a) // (b.node.expr if isinstance(b, torch.SymInt) else b) | ||
if isinstance(expr, (int, sympy.Integer)): | ||
return expr >= 1 | ||
return not isinstance(expr, sympy.floor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if i understand this well
you want to divide a//b
and check if the result is >1 or <1
if i could not tell if the result is >1 or <1
then i want something like
u1u2//u1 to be simplified to u1
and the inverse
u1//u1u2 to be simplified to 1/u1
and you want a way to detect that something is >1 or <1
- i do not get how this floor does it :) can you explain it.
- do you know that u1*u2//u1 is simplified to u2? (do we just assume u1 is not zero and simplify now?)
…k/should_swap_oblivious
Fixes #ISSUE_NUMBER
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv