Preserve unbacked SymInt on SymNode #120816
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/120816
Note: Links to docs will display an error until the docs builds have been completed.
❌ 41 New Failures, 5 Unrelated Failures
As of commit 49f1a4f with merge base 7cd7a7a:
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The issue with this approach is that it's very error-prone. You want to iterate over `self.replacements`... and forget about unbacked symints. You do `if a in self.replacements` and forget about unbacked symints...
A better approach would be to abstract these two into a class with methods that let you add a replacement (with no need to figure out whether the symbol is unbacked) and perform a replacement on an expression.
Not needed for this PR, and no need to implement it at all. Just an idea.
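The idea in the comment above could look roughly like the sketch below. This is a toy model, not code from the PR: `ReplacementTable`, `substitute`, `is_unbacked`, and the tuple-based expression type are all hypothetical names invented for illustration, and the "unbacked symbols start with `u`" rule is an assumption.

```python
from typing import Dict, Iterator, Tuple, Union

# Toy expression type: an int constant, a symbol name, or ("mul", k, expr).
Expr = Union[int, str, Tuple[str, int, "Expr"]]


def is_unbacked(symbol: str) -> bool:
    # Assumption for this sketch: unbacked symbols are named u0, u1, ...
    return symbol.startswith("u")


def substitute(expr: Expr, mapping: Dict[str, Expr]) -> Expr:
    """Replace symbols in expr, following chains (terminates when acyclic)."""
    if isinstance(expr, int):
        return expr
    if isinstance(expr, str):
        return substitute(mapping[expr], mapping) if expr in mapping else expr
    op, k, inner = expr
    return (op, k, substitute(inner, mapping))


class ReplacementTable:
    """Hypothetical wrapper that hides the backed/unbacked split."""

    def __init__(self) -> None:
        self._backed: Dict[str, Expr] = {}
        self._unbacked: Dict[str, Expr] = {}

    def add(self, symbol: str, expr: Expr) -> None:
        # Callers never decide which dict a symbol belongs to.
        target = self._unbacked if is_unbacked(symbol) else self._backed
        target[symbol] = expr

    def __contains__(self, symbol: str) -> bool:
        # Membership checks cover both dicts, so neither can be forgotten.
        return symbol in self._backed or symbol in self._unbacked

    def items(self) -> Iterator[Tuple[str, Expr]]:
        # Iteration also covers both dicts.
        yield from self._backed.items()
        yield from self._unbacked.items()

    def apply(self, expr: Expr) -> Expr:
        return substitute(expr, {**self._backed, **self._unbacked})


table = ReplacementTable()
table.add("s1", "s0")
table.add("u0", ("mul", 4, "u1"))
```

With a wrapper like this, `"u0" in table` and `table.items()` see unbacked entries automatically, which is the failure mode the comment is worried about.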
@@ -529,14 +529,14 @@ def test_unbacked_substitution(self):
_constrain_range_for_size(i0)
_constrain_range_for_size(i1)
self.assertTrue(expect_true(i0 == i1 * 4))
self.assertExpectedInline(str(i0), """4*u1""")
We should tag these two lines with a Wikipedia-style [[disputed]] given the number of times we've changed from `u0` to `4*u1` and back lol
if compute_hint:
e = canonicalize_bool_expr(ra.expr.xreplace(self.var_to_val))
e = canonicalize_bool_expr(self.simplify(ra.expr).xreplace(self.var_to_val))
Factor `self.simplify(ra.expr)` from the if plz
cur_replace = {s: self._find(s) for s in res.free_symbols}
self._set_replacement(a, self.replacements[a].xreplace(cur_replace), "find")
return self.replacements[a]
def doit(replacements):
factoring out the whole function rather than factoring out the arg. Fancy
I think we can do this more simply. We'll continue to chuck everything in replacements, but when we do
Previously, when we applied a replacement, a SymInt that was previously an unbacked SymInt would then transmute into whatever we replaced it into (e.g., a constant).

This has a major downside: we often look at SymInts associated with FX nodes (e.g., the meta of x.item() return) to find out where the unbacked SymInt was allocated. If we replace it, we no longer can find out where, e.g., u1 was allocated! But we need to know this so we can generate deferred runtime asserts like u1 == s0.

To solve this problem, I separate out unbacked replacements into a separate dict `unbacked_replacements` and avoid applying them when you query for the sympy.Expr associated with a SymNode. Everything else still respects these; importantly, if you want to check if u1 statically equals s0, we will still end up doing the replacement.

I'm open to other ways to solve this problem, but this one didn't seem too hard to implement.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
ghstack-source-id: 89dbad7
Pull Request resolved: #120816
Drastically simplified!
@pytorchbot merge
This is going to be obsoleted in favor of #124394 but I haven't quite fixed the runtime assert generation yet, so keeping this open till then.
…h#122988)
This reverts commit 476585b. I did a bisect and this seems to be the cause of the compile time regression in the cudagraphs_dynamic test suite between 03/23 and 03/24. In particular, BERT_pytorch and hf_T5 seem to have a ~50% compile time regression.
Pull Request resolved: pytorch#122988
Approved by: https://github.com/eellison
Previously, when we applied a replacement, a SymInt that was previously an unbacked SymInt would then transmute into whatever we replaced it into (e.g., a constant).

This has a major downside: we often look at SymInts associated with FX nodes (e.g., the meta of x.item() return) to find out where the unbacked SymInt was allocated. If we replace it, we no longer can find out where, e.g., u1 was allocated! But we need to know this so we can generate deferred runtime asserts like u1 == s0.

To solve this problem, I have a special mode for replace, `resolve_unbacked=False`, which lets you disable substitutions on unbacked SymInts. When reporting node.expr, we preferentially avoid applying unbacked SymInt substitutions. To understand if we might accidentally reapply the substitution later, before we have reached the deferred runtime assert, we must study the calls to simplify() in ShapeEnv. My audit turns up these sites:

* `produce_guards`: this is fine, deferred runtime asserts never show up here, we must NOT have unbacked SymInts show up here. Similarly `get_nontrivial_guards`.
* `_maybe_evaluate_static`: this is fine, we are using this to determine if it is necessary to produce a guard/runtime assert. We don't want to reissue a runtime assert if we've already asserted on it, and replacements can help us understand if this has occurred.
* `_simplify_floor_div`: this is a legitimate bug, it needs to be `resolve_unbacked=False`
* `_refine_ranges`: this is fine, a refined range doesn't affect what runtime asserts we issue
* `_update_divisible`: this updates the `self.divisible` set, which specifies when we can simplify away divisibility constraints. Since this affects replacements only, it won't cause us to oversimplify a user provided expression.

There are some situations where we DO want to always apply the substitution, specifically when we have the duplicate symbol problem (we retrace an item call and get u0 and u1 which refer to the same thing.) I don't want two symbols in this case, so a special `rename_unbacked_to` is provided which sets up the unconditional renaming.

Along the way, I make a refinement to `_update_var_to_range`: if you update a var range for a size-like unbacked SymInt, you are now no longer allowed to set its lower bound below 2. This is because if you could, then our size oblivious tests for it would be inconsistent. Actually, I think there is still some inconsistency, because if you assert `u0 == 0` we will still end up with this in deferred runtime asserts, and we will then use this to simplify these statements to be True everywhere else. Maybe we should forbid this kind of refinement; not done in this PR.

Fixes #119689
Fixes #118385

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #120816
Approved by: https://github.com/lezcano
This is now fully subsumed, including tests, by #124874
This completely subsumes #120816

This makes use of the unbacked binding machinery to teach Inductor how to generate deferred runtime asserts directly. There is some back story about why I did it this way, let me explain.

Previously, our strategy for generating runtime asserts was that Dynamo would insert them into the FX graph after finishing tracing, and we would attempt to code generate them based on the FX graph. This is a good strategy for export, where we immediately export the graph. However, this strategy was afflicted by problems in eager, where we reuse the same ShapeEnv as before. In particular, on subsequent graph passes, we would immediately turn all of these assertions into noops, because when we evaluated their expressions, we would see that because we had a deferred runtime assert in the ShapeEnv, we know "oh, of course this expression is True" already. Oops!

So, with this PR, we take the attitude that as long as the ShapeEnv sticks around, the ShapeEnv's list of deferred runtime asserts is the source of truth, and we don't put anything in the graph. So we just need to decide when to actually generate asserts, and the place I picked was Inductor lowering, since we already have an AssertScalar buffer concept, and so I just need to insert them at this point. AssertScalar also uses raw sympy.Expr rather than SymInt/Bool, so it is easier to prevent unrestricted simplification at this point.

There are a few things jumbled together in this PR. I can split them if you want; some of the changes predate my change of strategy, but they're useful changes anyway.

**torch/_dynamo/output_graph.py** and **torch/_inductor/lowering.py** - Here, we stop putting deferred runtime asserts in the graph. I also have to make sure we don't DCE unused symbol arguments; we're going to get some goofy graph arguments this way, will be good to restore that optimization eventually. We also just disable codegen for `_assert_scalar` entirely; we assume that ShapeEnv will be good enough to capture all of these.

**torch/_inductor/codegen/wrapper.py** and **torch/_inductor/ir.py** - Add a way to codegen sizevars without forcing simplification

**torch/_inductor/graph.py** - The main logic. Our strategy is to interpose in the same place we are testing that unbacked SymInts are properly showing up in lowered code. The logic is directly analogous to the logic in the existing insert deferred runtime asserts FX pass, but it's simpler because sympy expressions can be directly stored on inductor IR nodes.

**torch/fx/experimental/symbolic_shapes.py** - For extra safety, we have a way of freezing runtime asserts, so that if you try to add more we error. This prevents us from adding runtime asserts after we've done lowering. There's a funny interaction with backwards which there's a comment for in graph.py

**torch/fx/passes/runtime_assert.py** - This is not really needed in this PR, but I rewrote the runtime assert logic to use unbacked_bindings rather than inferring it by looking for unbacked SymInts. Now, keypaths are translated into FX node accessors. Unfortunately, I couldn't delete the old inference code, because you still need it to find backed SymInts from arguments (as this pass may be used on graphs which don't explicitly bind all their shape variables as arguments). There are some new tests exercising this.

TODO: I think we need to generate asserts for replacements too. This is a preexisting problem that the old FX pass had too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #124874
Approved by: https://github.com/jansel
ghstack dependencies: #124864
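The "freezing runtime asserts" safety check described above can be pictured with a small toy model. This is only a sketch: `ToyAssertLog`, `defer`, and `freeze` are hypothetical names, assertions are plain strings here, and the real ShapeEnv logic is more involved.

```python
from typing import List


class ToyAssertLog:
    """Toy stand-in for a ShapeEnv's list of deferred runtime asserts."""

    def __init__(self) -> None:
        self.deferred: List[str] = []
        self.frozen = False

    def defer(self, expr: str) -> None:
        # Once frozen, adding another assert is an error rather than a
        # silent no-op, so late additions cannot be lost.
        if self.frozen:
            raise RuntimeError(f"runtime asserts are frozen; cannot add {expr!r}")
        self.deferred.append(expr)

    def freeze(self) -> None:
        self.frozen = True


log = ToyAssertLog()
log.defer("u1 == s0")
log.freeze()  # e.g. once lowering has emitted code for the asserts
try:
    log.defer("u2 >= 0")
    late_add_rejected = False
except RuntimeError:
    late_add_rejected = True
```

The point of erroring is that an assert added after code generation would never be emitted, so failing loudly is safer than accepting it.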
Stack from ghstack (oldest at bottom):
Previously, when we applied a replacement, a SymInt that was
previously an unbacked SymInt would then transmute into whatever
we replaced it into (e.g., a constant).
This has a major downside: we often look at SymInts associated with
FX nodes (e.g., the meta of x.item() return) to find out where the
unbacked SymInt was allocated. If we replace it, we no longer can
find out where, e.g., u1 was allocated! But we need to know this
so we can generate deferred runtime asserts like u1 == s0.
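To make the problem concrete, here is a minimal toy model of it. Nothing below is PyTorch API: `node_exprs` is a hypothetical stand-in for the SymInt metadata recorded on FX nodes, and expressions are simplified to plain symbol names or constants.

```python
from typing import Dict, Optional, Union

Expr = Union[int, str]  # toy model: a constant or a symbol name

# Hypothetical stand-in for FX node metadata: node name -> SymInt expression.
node_exprs: Dict[str, Expr] = {"x": "s0", "item": "u1"}
replacements: Dict[str, Expr] = {"u1": "s0"}


def find_allocation_site(symbol: str, exprs: Dict[str, Expr]) -> Optional[str]:
    """Return the first node whose recorded expression is exactly `symbol`."""
    for node, expr in exprs.items():
        if expr == symbol:
            return node
    return None


# Eagerly applying the replacement rewrites the metadata and loses u1.
eager: Dict[str, Expr] = {n: replacements.get(e, e) for n, e in node_exprs.items()}

site_preserved = find_allocation_site("u1", node_exprs)  # finds the "item" node
site_after_replace = find_allocation_site("u1", eager)   # nothing left to find
```

In the eagerly replaced metadata, the `item` node reports `s0`, so there is no longer any node to attach a `u1 == s0` runtime assert to.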
To solve this problem, I have a special mode for replace, `resolve_unbacked=False`, which lets you disable substitutions on unbacked SymInts. When reporting node.expr, we preferentially avoid applying unbacked SymInt substitutions. To understand if we might accidentally reapply the substitution later, before we have reached the deferred runtime assert, we must study the calls to simplify() in ShapeEnv. My audit turns up these sites:

* `produce_guards`: this is fine, deferred runtime asserts never show up here, we must NOT have unbacked SymInts show up here. Similarly `get_nontrivial_guards`.
* `_maybe_evaluate_static`: this is fine, we are using this to determine if it is necessary to produce a guard/runtime assert. We don't want to reissue a runtime assert if we've already asserted on it, and replacements can help us understand if this has occurred.
* `_simplify_floor_div`: this is a legitimate bug, it needs to be `resolve_unbacked=False`
* `_refine_ranges`: this is fine, a refined range doesn't affect what runtime asserts we issue
* `_update_divisible`: this updates the `self.divisible` set, which specifies when we can simplify away divisibility constraints. Since this affects replacements only, it won't cause us to oversimplify a user provided expression.

There are some situations where we DO want to always apply the substitution, specifically when we have the duplicate symbol problem (we retrace an item call and get u0 and u1 which refer to the same thing.) I don't want two symbols in this case, so a special `rename_unbacked_to` is provided which sets up the unconditional renaming.

Along the way, I make a refinement to `_update_var_to_range`: if you update a var range for a size-like unbacked SymInt, you are now no longer allowed to set its lower bound below 2. This is because if you could, then our size oblivious tests for it would be inconsistent. Actually, I think there is still some inconsistency, because if you assert `u0 == 0` we will still end up with this in deferred runtime asserts, and we will then use this to simplify these statements to be True everywhere else. Maybe we should forbid this kind of refinement; not done in this PR.

Fixes #119689
Fixes #118385
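The interaction between `resolve_unbacked=False` and `rename_unbacked_to` described above can be sketched with a toy model. Only those two names come from the description; `ToyShapeEnv`, its fields, and the convention that unbacked symbols start with `u` are assumptions made for illustration, and the real ShapeEnv operates on sympy expressions rather than strings.

```python
from typing import Dict, Union

Expr = Union[int, str]  # toy model: a constant or a symbol name


class ToyShapeEnv:
    """Toy model only; not the real ShapeEnv."""

    def __init__(self) -> None:
        self.replacements: Dict[str, Expr] = {}
        self.renames: Dict[str, str] = {}  # unconditional unbacked renames

    def rename_unbacked_to(self, old: str, new: str) -> None:
        self.renames[old] = new

    def replace(self, expr: Expr, resolve_unbacked: bool = True) -> Expr:
        seen = set()  # guards against cycles in this toy
        while isinstance(expr, str) and expr not in seen:
            seen.add(expr)
            if expr in self.renames:
                expr = self.renames[expr]  # renames always apply
            elif expr in self.replacements and (
                resolve_unbacked or not expr.startswith("u")
            ):
                expr = self.replacements[expr]
            else:
                break
        return expr


env = ToyShapeEnv()
env.replacements["u0"] = "s0"       # learned from an assert like u0 == s0
env.rename_unbacked_to("u1", "u0")  # u1 is a retraced duplicate of u0

static_view = env.replace("u1")                           # full resolution
node_view = env.replace("u1", resolve_unbacked=False)     # keeps the symbol
```

In this sketch, static reasoning sees `s0`, while the expression reported for a node stays at `u0`: the duplicate symbol is collapsed, but the unbacked symbol itself survives so its allocation site can still be found.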
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang