Skip to content

Conversation

ezyang
Copy link
Contributor

@ezyang ezyang commented Feb 28, 2024

Stack from ghstack (oldest at bottom):

Previously, when we applied a replacement, a SymInt that was
previously an unbacked SymInt would then transmute into whatever
we replaced it into (e.g., a constant).

This has a major downside: we often look at SymInts associated with
FX nodes (e.g., the meta of x.item() return) to find out where the
unbacked SymInt was allocated. If we replace it, we no longer can
find out where, e.g., u1 was allocated! But we need to know this
so we can generate deferred runtime asserts like u1 == s0.

To solve this problem, I have a special mode for replace, resolve_unbacked=False, which lets you disable substitutions on unbacked SymInts. When reporting node.expr, we preferentially avoid applying unbacked SymInt substitutions. To understand if we might accidentally reapply the substitution later, before we have reached the deferred runtime assert, we must study the calls to simplify() in ShapeEnv. My audit turns up these sites:

  • produce_guards: this is fine, deferred runtime asserts never show up here, we must NOT have unbacked SymInts show up here. Similarly get_nontrivial_guards.
  • _maybe_evaluate_static: this is fine, we are using this to determine if it is necessary to produce a guard/runtime assert. We don't want to reissue a runtime assert if we've already asserted on it, and replacements can help us understand if this has occurred.
  • _simplify_floor_div: this is a legitimate bug, it needs to be resolve_unbacked=False
  • _refine_ranges: this is fine, a refined range doesn't affect what runtime asserts we issue
  • _update_divisible: this updates the self.divisible set, which specifies when we can simplify away divisibility constraints. Since this affects replacements only, it won't cause us to oversimplify a user provided expression.

There are some situations where we DO want to always apply the substitution, specifically when we have the duplicate symbol problem (we retrace an item call and get u0 and u1 which refer to the same thing.) I don't want two symbols in this case, so a special rename_unbacked_to is provided which sets up the unconditional renaming.

Along the way, I make a refinement to _update_var_to_range: if you update a var range for a size-like unbacked SymInt, you are now no longer allowed to set its lower bound below 2. This is because if you could, then our size oblivious tests for it would be inconsistent. Actually, I think there is still some inconsistency, because if you assert u0 == 0 we will still end up with this in deferred runtime asserts, and we will then use this to simplify these statements to be True everywhere else. Maybe we should forbid this kind of refinement; not done in this PR.

Fixes #119689

Fixes #118385

Signed-off-by: Edward Z. Yang ezyang@meta.com

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Feb 28, 2024
Copy link

pytorch-bot bot commented Feb 28, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/120816

Note: Links to docs will display an error until the docs builds have been completed.

❌ 41 New Failures, 5 Unrelated Failures

As of commit 49f1a4f with merge base 7cd7a7a (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
@ezyang ezyang added ciflow/trunk Trigger trunk jobs on your pull request release notes: composability release notes category topic: bug fixes topic category labels Feb 29, 2024
Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue with this approach is that it's very error prone. You want to iterate over self.replacements... and forget about unbacked symints. You do if a in self.replacements and forget about unbacked symints...

A better approach would be to abstract these two into a class with methods that allow you to add a replacement (no need to figure out whether the symbol is unbacked or not) and perform a replacement on a expression.

Not needed for this PR, and no need to implement it at all. Just an idea.

@@ -529,14 +529,14 @@ def test_unbacked_substitution(self):
_constrain_range_for_size(i0)
_constrain_range_for_size(i1)
self.assertTrue(expect_true(i0 == i1 * 4))
self.assertExpectedInline(str(i0), """4*u1""")
Copy link
Collaborator

@lezcano lezcano Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should tag these two lines with a wikipedia's [[disputed]] given the amount of times we've changed from u0 to 4*u1 and back lol

if compute_hint:
e = canonicalize_bool_expr(ra.expr.xreplace(self.var_to_val))
e = canonicalize_bool_expr(self.simplify(ra.expr).xreplace(self.var_to_val))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Factor self.simplify(ra.expr) from the if plz

cur_replace = {s: self._find(s) for s in res.free_symbols}
self._set_replacement(a, self.replacements[a].xreplace(cur_replace), "find")
return self.replacements[a]
def doit(replacements):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

factoring out the whole function rather than factoring out the arg. Fancy

[ghstack-poisoned]
@ezyang
Copy link
Contributor Author

ezyang commented Feb 29, 2024

I think we can do this more simply. We'll continue to chuck everything in replacements, but when we do _find, when we compute cur_replace we will delete unbacked symbols from it unless we resolve unbacked.

[ghstack-poisoned]
ezyang added a commit that referenced this pull request Feb 29, 2024
Previously, when we applied a replacement, a SymInt that was
previously an unbacked SymInt would then transmute into whatever
we replaced it into (e.g., a constant).

This has a major downside: we often look at SymInts associated with
FX nodes (e.g., the meta of x.item() return) to find out where the
unbacked SymInt was allocated.  If we replace it, we no longer can
find out where, e.g., u1 was allocated!  But we need to know this
so we can generate deferred runtime asserts like u1 == s0.

To solve this problem, I separate out unbacked replacements into a
separate dict unbacked_replacements and avoid applying them when you
query for the sympy.Expr associated with a SymNode.  Everything else
still respects these; importantly, if you want to check if u1 statically
equals s0, we will still end up doing the replacement.

I'm open to other ways to solve this problem, but this one didn't seem
too hard to implement.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 89dbad7
Pull Request resolved: #120816
@ezyang
Copy link
Contributor Author

ezyang commented Feb 29, 2024

Drastically simplified!

@ezyang
Copy link
Contributor Author

ezyang commented Feb 29, 2024

@pytorchbot merge

@ezyang
Copy link
Contributor Author

ezyang commented Apr 11, 2024

@ezyang
Copy link
Contributor Author

ezyang commented Apr 11, 2024

With fix

image

Neutral now

ezyang added 2 commits April 15, 2024 06:12
[ghstack-poisoned]
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 15, 2024
Previously, when we applied a replacement, a SymInt that was
previously an unbacked SymInt would then transmute into whatever
we replaced it into (e.g., a constant).

This has a major downside: we often look at SymInts associated with
FX nodes (e.g., the meta of x.item() return) to find out where the
unbacked SymInt was allocated.  If we replace it, we no longer can
find out where, e.g., u1 was allocated!  But we need to know this
so we can generate deferred runtime asserts like u1 == s0.

To solve this problem, I separate out unbacked replacements into a
separate dict unbacked_replacements and avoid applying them when you
query for the sympy.Expr associated with a SymNode.  Everything else
still respects these; importantly, if you want to check if u1 statically
equals s0, we will still end up doing the replacement.

I'm open to other ways to solve this problem, but this one didn't seem
too hard to implement.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 80e339b
Pull Request resolved: #120816
[ghstack-poisoned]
[ghstack-poisoned]
@ezyang ezyang mentioned this pull request Apr 16, 2024
ezyang added a commit that referenced this pull request Apr 16, 2024
Previously, when we applied a replacement, a SymInt that was
previously an unbacked SymInt would then transmute into whatever
we replaced it into (e.g., a constant).

This has a major downside: we often look at SymInts associated with
FX nodes (e.g., the meta of x.item() return) to find out where the
unbacked SymInt was allocated.  If we replace it, we no longer can
find out where, e.g., u1 was allocated!  But we need to know this
so we can generate deferred runtime asserts like u1 == s0.

To solve this problem, I separate out unbacked replacements into a
separate dict unbacked_replacements and avoid applying them when you
query for the sympy.Expr associated with a SymNode.  Everything else
still respects these; importantly, if you want to check if u1 statically
equals s0, we will still end up doing the replacement.

I'm open to other ways to solve this problem, but this one didn't seem
too hard to implement.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: b7caed8
Pull Request resolved: #120816
[ghstack-poisoned]
[ghstack-poisoned]
ezyang added a commit that referenced this pull request Apr 16, 2024
Previously, when we applied a replacement, a SymInt that was
previously an unbacked SymInt would then transmute into whatever
we replaced it into (e.g., a constant).

This has a major downside: we often look at SymInts associated with
FX nodes (e.g., the meta of x.item() return) to find out where the
unbacked SymInt was allocated.  If we replace it, we no longer can
find out where, e.g., u1 was allocated!  But we need to know this
so we can generate deferred runtime asserts like u1 == s0.

To solve this problem, I separate out unbacked replacements into a
separate dict unbacked_replacements and avoid applying them when you
query for the sympy.Expr associated with a SymNode.  Everything else
still respects these; importantly, if you want to check if u1 statically
equals s0, we will still end up doing the replacement.

I'm open to other ways to solve this problem, but this one didn't seem
too hard to implement.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: cd25845
Pull Request resolved: #120816
@ezyang
Copy link
Contributor Author

ezyang commented Apr 20, 2024

This is going to be obsoleted in favor of #124394 but I haven't quite fixed the runtime assert generation yet, so keeping this open till then.

sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
…h#122988)

This reverts commit 476585b.

I did a bisect and this seems to be the cause of compile time regression in cudagraphs_dynamic test suite between 03/23 and 03/24:
![image](https://github.com/pytorch/pytorch/assets/4063635/21394e06-4906-4690-b5a2-7d16cc475843)
image Particularly BERT_pytorch and hf_T5 seem to have ~50% compile time regression.

Pull Request resolved: pytorch#122988
Approved by: https://github.com/eellison
pytorch-bot bot pushed a commit that referenced this pull request Apr 22, 2024
Previously, when we applied a replacement, a SymInt that was
previously an unbacked SymInt would then transmute into whatever
we replaced it into (e.g., a constant).

This has a major downside: we often look at SymInts associated with
FX nodes (e.g., the meta of x.item() return) to find out where the
unbacked SymInt was allocated.  If we replace it, we no longer can
find out where, e.g., u1 was allocated!  But we need to know this
so we can generate deferred runtime asserts like u1 == s0.

To solve this problem, I have a special mode for replace, resolve_unbacked=False, which lets you disable substitutions on unbacked SymInts. When reporting node.expr, we preferentially avoid applying unbacked SymInt substitutions. To understand if we might accidentally reapply the substitution later, before we have reached the deferred runtime assert, we must study the calls to simplify() in ShapeEnv. My audit turns up these sites:

* `produce_guards`: this is fine, deferred runtime asserts never show up here, we must NOT have unbacked SymInts show up here. Similarly `get_nontrivial_guards`.
* `_maybe_evaluate_static`: this is fine, we are using this to determine if it is necessary to produce a guard/runtime assert. We don't want to reissue a runtime assert if we've already asserted on it, and replacements can help us understand if this has occurred.
* `_simplify_floor_div`: this is a legitimate bug, it needs to be `resolve_unbacked=False`
* `_refine_ranges`: this is fine, a refined range doesn't affect what runtime asserts we issue
* `_update_divisible`: this updates the `self.divisible` set, which specifies when we can simplify away divisibility constraints. Since this affects replacements only, it won't cause us to oversimplify a user provided expression.

There are some situations where we DO want to always apply the substitution, specifically when we have the duplicate symbol problem (we retrace an item call and get u0 and u1 which refer to the same thing.) I don't want two symbols in this case, so a special `rename_unbacked_to` is provided which sets up the unconditional renaming.

Along the way, I make a refinement to `_update_var_to_range`: if you update a var range for a size-like unbacked SymInt, you are now no longer allowed to set its lower bound below 2. This is because if you could, then our size oblivious tests for it would be inconsistent. Actually, I think there is still some inconsistency, because if you assert `u0 == 0` we will still end up with this in deferred runtime asserts, and we will then use this to simplify these statements to be True everywhere else. Maybe we should forbid this kind of refinement; not done in this PR.

Fixes #119689

Fixes #118385

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #120816
Approved by: https://github.com/lezcano
@ezyang
Copy link
Contributor Author

ezyang commented Apr 24, 2024

This is now fully subsumed, including tests, by #124874

@ezyang ezyang closed this Apr 24, 2024
pytorchmergebot pushed a commit that referenced this pull request Apr 29, 2024
This completely subsumes #120816

This makes use of the unbacked binding machinery to teach Inductor how to generate deferred runtime asserts directly. There is some back story about why I did it this way, let me explain.

Previously, our strategy for generating runtime asserts was that Dynamo would insert them into the FX graph after finishing tracing, and we would attempt to code generate them based on the FX graph. This is a good strategy for export, where we immediately export the graph. However, this strategy was afflicted by problems in eager, where we reuse the same ShapeEnv as before. In particular, on subsequent graph passes, we would immediately turn all of these assertions into noops, because when we evaluated their expressions, we would see that because we had a deferred runtime assert in the ShapeEnv, we know "oh, of course this expression is True" already. Oops!

So, with this PR, we take the attitude that as long as the ShapeEnv sticks around, the ShapeEnv's list of deferred runtime asserts is the source of truth, and we don't put anything in the graph. So we just need to decide when to actually generate asserts, and the place I picked was Inductor lowering, since we already have an AssertScalar buffer concept, and so I just need to insert them at this point. AssertScalar also uses raw sympy.Expr rather than SymInt/Bool, so it is easier to prevent unrestricted simplification at this point.

There are a few things jumbled together in this PR. I can split them if you want, but some of the changes are before I changed my strategy, but they're useful changes anyway.

**torch/_dynamo/output_graph.py** and **torch/_inductor/lowering.py** - Here, we stop putting deferred runtime asserts in the graph. I also have to make sure we don't DCE unused symbol arguments; we're going to get some goofy graph arguments this way, will be good to restore that optimization eventually. We also just disable codegen for `_assert_scalar`  entirely; we assume that ShapeEnv will be good enough to capture all of these.

**torch/_inductor/codegen/wrapper.py** and **torch/_inductor/ir.py** - Add a way to codegen sizevars without forcing simplification

**torch/_inductor/graph.py** - The main logic. Our strategy is to interpose in the same place we are testing that unbacked SymInts are properly showing up in lowered code. The logic is directly analogous to the logic in the existing insert deferred runtime asserts FX pass, but it's simpler because sympy expressions can be directly stored on inductor IR nodes.

**torch/fx/experimental/symbolic_shapes.py** - For extra safety, we have a way of freezing runtime asserts, so that if you try to add more we error. This prevents us from adding runtime asserts after we've done lowering. There's a funny interaction with backwards which there's a comment for in graph.py

**torch/fx/passes/runtime_assert.py** - This is not really needed in this PR, but I rewrote the runtime assert logic to use unbacked_bindings rather than inferring it by looking for unbacked SymInts. Now, keypaths are translated into FX node acessors. Unfortunately, I couldn't delete the old inference code, because you still need it to find backed SymInts from arguments (as this pass may be used on graphs which don't explicitly bind all their shape variables as argments). There are some new tests exercising this.

TODO: I think we need to generate asserts for replacements too. This is a preexisting problem that the old FX pass had too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #124874
Approved by: https://github.com/jansel
ghstack dependencies: #124864
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
This completely subsumes pytorch#120816

This makes use of the unbacked binding machinery to teach Inductor how to generate deferred runtime asserts directly. There is some back story about why I did it this way, let me explain.

Previously, our strategy for generating runtime asserts was that Dynamo would insert them into the FX graph after finishing tracing, and we would attempt to code generate them based on the FX graph. This is a good strategy for export, where we immediately export the graph. However, this strategy was afflicted by problems in eager, where we reuse the same ShapeEnv as before. In particular, on subsequent graph passes, we would immediately turn all of these assertions into noops, because when we evaluated their expressions, we would see that because we had a deferred runtime assert in the ShapeEnv, we know "oh, of course this expression is True" already. Oops!

So, with this PR, we take the attitude that as long as the ShapeEnv sticks around, the ShapeEnv's list of deferred runtime asserts is the source of truth, and we don't put anything in the graph. So we just need to decide when to actually generate asserts, and the place I picked was Inductor lowering, since we already have an AssertScalar buffer concept, and so I just need to insert them at this point. AssertScalar also uses raw sympy.Expr rather than SymInt/Bool, so it is easier to prevent unrestricted simplification at this point.

There are a few things jumbled together in this PR. I can split them if you want, but some of the changes are before I changed my strategy, but they're useful changes anyway.

**torch/_dynamo/output_graph.py** and **torch/_inductor/lowering.py** - Here, we stop putting deferred runtime asserts in the graph. I also have to make sure we don't DCE unused symbol arguments; we're going to get some goofy graph arguments this way, will be good to restore that optimization eventually. We also just disable codegen for `_assert_scalar`  entirely; we assume that ShapeEnv will be good enough to capture all of these.

**torch/_inductor/codegen/wrapper.py** and **torch/_inductor/ir.py** - Add a way to codegen sizevars without forcing simplification

**torch/_inductor/graph.py** - The main logic. Our strategy is to interpose in the same place we are testing that unbacked SymInts are properly showing up in lowered code. The logic is directly analogous to the logic in the existing insert deferred runtime asserts FX pass, but it's simpler because sympy expressions can be directly stored on inductor IR nodes.

**torch/fx/experimental/symbolic_shapes.py** - For extra safety, we have a way of freezing runtime asserts, so that if you try to add more we error. This prevents us from adding runtime asserts after we've done lowering. There's a funny interaction with backwards which there's a comment for in graph.py

**torch/fx/passes/runtime_assert.py** - This is not really needed in this PR, but I rewrote the runtime assert logic to use unbacked_bindings rather than inferring it by looking for unbacked SymInts. Now, keypaths are translated into FX node acessors. Unfortunately, I couldn't delete the old inference code, because you still need it to find backed SymInts from arguments (as this pass may be used on graphs which don't explicitly bind all their shape variables as argments). There are some new tests exercising this.

TODO: I think we need to generate asserts for replacements too. This is a preexisting problem that the old FX pass had too.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#124874
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#124864
@github-actions github-actions bot deleted the gh/ezyang/2592/head branch June 1, 2024 01:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: dynamo module: inductor release notes: composability release notes category release notes: fx release notes category topic: bug fixes topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants