Contributor

@pianpwk pianpwk commented Jul 15, 2024

Sets `prefer_deferred_runtime_asserts_over_guards=True` for export, so any guards emitted from `SymNode.expect_true` (for example, guards that are implicitly required to be true for an op to succeed) won't lead to constraint violations. Instead, these should appear in the graph as runtime asserts, or potentially as replacement expressions for placeholder shapes.

For example, this reshape op should emit s0 * s1 = s2, deferred as a runtime assert.

```
x = torch.randn(4, 8)  # [s0, s1]
y = torch.randn(32)  # [s2]
out = x.reshape(-1) + y
# this emits Eq(s0 * s1, s2), and we represent y's shape as [s0*s1] in the graph.
```

However, other complex guards can still cause export to fail: guards emitted from `SymNode.guard_bool`/`guard_size_oblivious` (e.g. explicit if-else conditions in user code, or lower-level op implementations hit during tracing) still raise constraint violations. These can be deferred with `allow_complex_guards_as_runtime_asserts=True`. We don't yet make this the default because, while it makes export more likely to succeed, it results in non-trivial asserts being emitted that often represent specialization to a variant of the op, or checks related to 0/1 specialization.
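To make this concrete, here is a minimal sketch of exporting the example above with every dimension marked dynamic (a sketch assuming a build that includes this change; the module and dim names are illustrative):

```
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x, y):
        return x.reshape(-1) + y

x, y = torch.randn(4, 8), torch.randn(32)
# Eq(s0*s1, s2) isn't expressible with Dims/DerivedDims, so this previously
# raised a constraint violation; with deferral it becomes a replacement or
# runtime assert instead, and export succeeds.
dynamic_shapes = {
    "x": (Dim("s0"), Dim("s1")),
    "y": (Dim("s2"),),
}
ep = export(M(), (x, y), dynamic_shapes=dynamic_shapes)
print(ep)  # y's placeholder shape should show up as [s0*s1]
```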

We also remove forced specializations for export and kill the `_disable_forced_specializations` flag - now any guard we can't express with Dims/DerivedDims is either handled with hybrid SymInts, or should be resolved by rewriting or deferring.

Follow up:
Currently, `ShapeEnv._set_replacement()` is called for complex equality expressions (e.g. s2 -> s0*s1 in the example above), and the ExportedProgram stores `s0*s1` in the input placeholder. This isn't checked for validity when the program is run, so an option is to avoid the replacement and/or emit a runtime assert on the equality.
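To see the replacement this refers to, one can inspect the input placeholders of the exported program (continuing the sketch above; the printed output is illustrative):

```
# Each placeholder's meta["val"] is a fake tensor carrying symbolic sizes.
for node in ep.graph.nodes:
    if node.op == "placeholder":
        print(node.name, tuple(node.meta["val"].shape))
# Illustrative output:
#   x (s0, s1)
#   y (s0*s1,)  <- the replaced expression, currently not re-validated at run time
```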

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang


pytorch-bot bot commented Jul 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130775

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Unrelated Failure

As of commit 899e3f2 with merge base ef05112:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D59778573

@facebook-github-bot
Contributor

@pianpwk has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pianpwk added a commit that referenced this pull request Jul 16, 2024
Summary:
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

Pull Request resolved: #130775

Differential Revision: D59778573

Pulled By: pianpwk
@pianpwk pianpwk force-pushed the export-D59778573 branch from 06bbc8e to 281790a on July 16, 2024 01:08
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D59778573

@facebook-github-bot
Contributor

@pianpwk has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pianpwk pianpwk changed the title from "[export] default ras" to "[export] turn on hybrid symints by default" on Jul 16, 2024
Contributor Author

@pianpwk pianpwk Jul 16, 2024


We no longer do any specializations. Guards either turn into runtime asserts (if from expect_true), or are blocking and require rewriting or allow_complex_guards_as_runtime_asserts=True.

Contributor


Nice. Might be good to name the exported program and then do some test to ensure what runtime assert was added?
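A hedged sketch of the kind of test being suggested, assuming deferred guards land as aten._assert_scalar nodes in the graph (equality guards may instead become placeholder replacements, per the description above). It reuses the names from the earlier sketch inside a unittest.TestCase, and is not the actual test:

```
ep = torch.export.export(M(), (x, y), dynamic_shapes=dynamic_shapes)
assert_nodes = [
    n
    for n in ep.graph.nodes
    if n.op == "call_function"
    and n.target == torch.ops.aten._assert_scalar.default
]
# Expect at least one runtime assert for a deferred guard.
self.assertTrue(len(assert_nodes) >= 1)
```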

Contributor Author


Ah, I should make a note here: this test case does require the complex guards flag to be set; it encounters a guard Ne(Mod(s0, s0-1), 0) (a special case of reshape) during tracing.

For the cases I've removed, some require this flag and some export fine with just hybrid SymInts. For those that are fine, the asserts aren't perfect, because replacement expressions are set for placeholders but not checked at runtime. I was planning to leave that as the follow-up task (undo the replacement, set it as a runtime assert).

Contributor Author


This is needed for SymBool expressions like Eq(u0, 1), where u0 isn't bound to any previous node; directly calling create_symboolnode() errors because we haven't tracked range info on u0, so we explicitly deserialize each symbol beforehand.

See test_sym_bool in test/export/test_serialize.py for example.

Contributor Author


This isn't strictly necessary, and after #128599 it isn't that much more efficient, but if we want to reason more about which guards/runtime asserts to emit, this makes it a bit cleaner to do from the export side. There's also no point in running this twice.

Contributor


I think comments like this are better added to the code itself.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D59778573

@pianpwk pianpwk force-pushed the export-D59778573 branch from ec77e87 to 0c08e1e on July 16, 2024 05:00
Contributor Author

@pianpwk pianpwk Jul 16, 2024


We run the export solver before the runtime asserts pass. Right now this doesn't mean much: the export solver is only there for suggested fixes, and we won't even get to constraint solving if that's needed. But if in the future we want to control which runtime asserts are emitted for export, or rely on produce_guards() for some simplification of runtime asserts, I think this ordering makes sense.

Also matches how dynamo currently does it.

Contributor


Same as above, maybe move such comments to code.

Contributor Author


Storing these flags in ShapeEnvSettings becomes a problem with AOTInductor, where the ShapeEnv is reused from export (with hybrid SymInts on) but is later frozen for Inductor's runtime asserts. We can't overwrite the settings, and this fails when we end up deferring another guard we want to keep.

@facebook-github-bot
Contributor

@pianpwk has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D59778573

pianpwk added a commit that referenced this pull request Jul 16, 2024
Summary:
Sets `prefer_deferred_runtime_asserts_over_guards=True` for export. Now any guards emitted from `SymNode.expect_true` (e.g. guards implicitly true for ops) won't cause constraint violations, instead appearing in the graph as runtime asserts, or potentially as replacement expressions in the case of equality guards. One example is "complex guards" that must be true for an op to succeed, e.g.
```
x = torch.randn(4, 8)  # [s0, s1]
y = torch.randn(32)  # [s2]
out = x.reshape(-1) + y

# this emits Eq(s0 * s1, s2), and we represent y's shape as [s0*s1] in the graph,
# though we could also keep [s2] and reify the runtime assert.
```

However, this doesn't yet handle all complex guards that might cause export to crash: guards emitted from `SymNode.guard_bool/guard_size_oblivious` (e.g. explicit if-else conditions in user code, or lower-level op implementations hit during tracing) can still raise constraint violations, and those can be deferred with `allow_complex_guards_as_runtime_asserts=True`. We don't yet turn this on by default because it results in many more asserts in the graph that aren't obvious from the ops at a high level, and that often represent specialization (a variant of the op implementation, or checks for 0/1 specialization) - similar to turning on real tensor propagation.

We also remove forced specializations for export - now any guard we can't express with Dims/DerivedDims is either handled with hybrid SymInts, or should be resolved by rewriting or deferring.

Follow up:
Currently, `ShapeEnv._set_replacement()` is called for complex equality expressions (e.g. s2 -> s0*s1 in the example above), and the ExportedProgram stores `s0*s1` in the input placeholder. This isn't checked for validity when the program is run, so an option is to avoid or undo the replacement and emit a runtime assert on the expression. A similar issue exists for derived dims: if a dim like `3*d0` is a placeholder shape, the value is never checked for divisibility.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

Pull Request resolved: #130775

Differential Revision: D59778573

Pulled By: pianpwk
@pianpwk pianpwk force-pushed the export-D59778573 branch from 02a7e62 to 2b63008 on July 16, 2024 06:04
@pianpwk pianpwk requested a review from angelayi July 16, 2024 16:31
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D59778573

pianpwk added a commit that referenced this pull request Jul 16, 2024
Pull Request resolved: #130775

Reviewed By: avikchaudhuri

Differential Revision: D59778573

Pulled By: pianpwk
@pianpwk pianpwk force-pushed the export-D59778573 branch from 86f16ba to 495de52 on July 16, 2024 23:22
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D59778573

@facebook-github-bot
Contributor

@pianpwk has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D59778573

@pianpwk pianpwk force-pushed the export-D59778573 branch from f5f6f67 to 899e3f2 on July 17, 2024 18:23
@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) on Jul 18, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@clee2000
Contributor

@pytorchbot merge -f "merged internally, job failures seem unrelated, bc breakage is intentional"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

DiweiSun pushed a commit to DiweiSun/pytorch that referenced this pull request Jul 22, 2024
Pull Request resolved: pytorch#130775
Approved by: https://github.com/avikchaudhuri
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Jul 25, 2024
Pull Request resolved: pytorch#130775
Approved by: https://github.com/avikchaudhuri
pianpwk added a commit that referenced this pull request Aug 5, 2024
Summary:
#130775 recently killed forced specializations for export on complex guards, so the only way we now get a specialized value is if we're able to solve for it. For example, if we have guards `s0 * 2 = s1`, `s0 + 6 = s1`, we specialize `s0 = 6; s1 = 12`.

That might look like this:
```
class Foo(torch.nn.Module):
    def forward(self, x, y):
        return x.reshape([-1]) + y

dy = Dim("dy", min=6)
x, y = torch.randn(6, 2), torch.randn(12)
dynamic_shapes = {
    "x": (dy - 6, 2),
    "y": (dy,),
}
```

Here `x.reshape([-1])` requires `(dy - 6) * 2 == dy`, so `dy` specializes to 12. Our current error message is:
`{symbol} must be specialized to {value} because the guards generated for it are too complex`
This is now misleading, so we change it to:
`solving the guards generated for {symbol} resulted in a specialized value of {value}`

Test Plan: test_export

Reviewed By: angelayi

Differential Revision: D60787430
pytorchmergebot pushed a commit that referenced this pull request Aug 6, 2024
Pull Request resolved: #132698
Approved by: https://github.com/avikchaudhuri
@github-actions github-actions bot deleted the export-D59778573 branch August 20, 2024 01:57