[export] handle new roots & root swapping in derived dims suggested fixes #125543
This pull request was exported from Phabricator. Differential Revision: D56974543
Summary:
The current suggested fixes for derived dims can lead to root dims/derived dims being swapped, e.g. `dy - 1, dy` -> `dx, dx + 1`. This leads to problematic suggested fixes that look like `dy - 1 = Dim("dy - 1")` since we don't have access to the original variable name.
This PR only adds a suggested fix for the root dim, and removes all other derived suggestions.
For example, with the export test case test_derived_dim_out_of_order_simplified:
```
_dimz = torch.export.Dim("_dimz", min=6, max=8)
dimy = _dimz - 1
dimx = dimy - 1
dimz = torch.export.Dim("dimz", min=6, max=8) # doesn't work, should be = _dimz
class Foo(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y[1:] + z[2:]

foo = Foo()
u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
export(
    foo,
    (u, v, w),
    dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
)
```
Before:
```
Suggested fixes:
_dimz = Dim('_dimz', min=3, max=9223372036854775807) # 2 <= _dimz - 1 <= 9223372036854775806
_dimz - 2 = Dim('_dimz - 2', min=4, max=6)
_dimz = Dim('_dimz', min=2, max=9223372036854775806) # 2 <= _dimz <= 9223372036854775806
_dimz - 1 = _dimz - 1
dimz = _dimz
```
New suggested fixes:
```
Suggested fixes:
dimz = _dimz
```
Note: This assumes the specified derived relations between dims are correct. This should be valid because: 1) if the relation is plain wrong (e.g. (dx, dx - 1) provided with inputs (6, 4)), it gets caught beforehand in produce_guards, and 2) if the relation is correct but does not match the emitted guard, for example:
```
def forward(self, x, y):
    return x.reshape([-1]) + y  # guard: s0 * 2 = s1

dx = Dim("dx")
export(
    model,
    (torch.randn(6, 2), torch.randn(12)),
    dynamic_shapes={"x": (dx, 2), "y": (dx + 6,)},
)
```
This produces two linear equations, leading to specialization, since a) produce_guards is able to solve for a concrete value, and b) the export constraint solver will force specializations anyway due to range constraints.
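A minimal sketch (plain sympy, outside the export stack) of why these two constraints pin a concrete value:
```
import sympy

s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)

# guard emitted by the reshape:        2 * s0 = s1
# relation implied by dynamic_shapes:  s1 = s0 + 6
print(sympy.solve([sympy.Eq(2 * s0, s1), sympy.Eq(s1, s0 + 6)], [s0, s1]))
# {s0: 6, s1: 12} -- a unique solution, so export specializes to these sizes
```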
Test Plan: export test cases
Differential Revision: D56974543
```
int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr)
return symbool.node.shape_env.create_symintnode(int_sym, hint=int(symbool.node.require_hint()) if has_hint(symbool) else None)

# use class name to avoid circular imports
```
eh, couldn't you do a local import to avoid circular imports?
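For reference, the local-import pattern being suggested would look roughly like this (a sketch; the `torch.export.dynamic_shapes` import path is assumed, not taken from this diff):
```
def _is_dim(dim):
    # importing inside the function body defers the import until call time,
    # so a module-level cycle with torch.export never materializes
    from torch.export.dynamic_shapes import _Dim  # assumed path
    return isinstance(dim, _Dim)
```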
```
def _is_dim(dim):
    return dim.__class__.__name__ == "_Dim"


def _solve_for_root(expr, val):
```
could you use try_solve instead?
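For context, here is what solving a derived-dim expression for its root means, sketched with plain `sympy.solve` (PyTorch's `try_solve` helper plays a similar role; this is illustrative, not the PR's code):
```
import sympy

_dx = sympy.Symbol("_dx", positive=True, integer=True)
expr = 3 * _dx - 1   # derived dim: dx = 3*_dx - 1
val = 35             # concrete value observed for dx

# solve 3*_dx - 1 == 35 for the root
(root_val,) = sympy.solve(sympy.Eq(expr, val), _dx)
print(root_val)  # 12
```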
```
    if k in name_to_dim and _is_derived_dim(name_to_dim[k])
}

# evaluate each root for either concrete val or min/max range
```
The algo here is hard to follow; please add a running example. It's intricate logic, so having the transformation explicitly laid out in comments, showing how it's achieved at every major step, would be good.
Also maybe add comments on why just suggesting fixes to the roots is sufficient.
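For readers of this thread, the summary's example spells out the transformation this loop performs (the dict layout below is illustrative, not the exact internal format):
```
# Intermediate results before unification:
#   "dx": {"eq": 3*_dx - 1, "max": 36}   # both a derived suggestion and a range
#   "dy": {"eq": dx + 1}                 # derived from dx, not from the root
#
# Step 1: rewrite dy in terms of the single root _dx:
#   dy = (3*_dx - 1) + 1 = 3*_dx
# Step 2: push dx's range down onto the root:
#   3*_dx - 1 <= 36  =>  _dx <= 12
#
# Final suggested fixes -- each entry is either a root range or a
# derived expression, never both:
#   _dx = Dim('_dx', max=12)
#   dx = 3 * _dx - 1
#   dy = 3 * _dx
```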
```
# retrieve dynamic shapes
name_to_dim = {}
for dim in pytree.tree_flatten(
```
Isn't there a different way to collect all dims? Like, have you checked the contents of source_name_to_debug_name?
I think the main issue here is collecting root dims which aren't attached to a size
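A sketch of the collection pattern under discussion, assuming `torch.utils._pytree` and that derived dims carry a `.root` attribute (which is how a root that never appears directly in a size still gets registered):
```
import torch.utils._pytree as pytree

def collect_name_to_dim(dynamic_shapes):
    # flatten the nested dynamic_shapes spec and record every Dim leaf,
    # plus the root of every derived dim, keyed by name
    name_to_dim = {}
    for dim in pytree.tree_flatten(dynamic_shapes)[0]:
        if dim is None or isinstance(dim, int):
            continue
        name_to_dim[dim.__name__] = dim
        root = getattr(dim, "root", None)  # assumed attribute on derived dims
        if root is not None:
            name_to_dim[root.__name__] = root
    return name_to_dim
```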
```
root = list(c["eq"].free_symbols)[0]
if not str(root) in name_to_dim:
    introduced_roots[str(root)] = k
# calculate necessary min & max
```
moved this code from prettify_results; it existed for newly introduced roots
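Concretely, on the summary's example, computing the necessary min/max for an introduced root amounts to inverting the derived relation and evaluating it at the known bound (a sketch with plain sympy):
```
import sympy

_dx, dx = sympy.symbols("_dx dx", positive=True, integer=True)

# derived relation dx = 3*_dx - 1, with a known bound dx <= 36
(inverse,) = sympy.solve(sympy.Eq(3 * _dx - 1, dx), _dx)  # _dx = (dx + 1)/3
print(sympy.floor(inverse.subs(dx, 36)))  # 12 -> _dx = Dim('_dx', max=12)
```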
```
if isinstance(other, int):
    others.append(f"{k} = {other}")
elif self._is_supported_equivalence(other):
    s = next(iter(other.free_symbols))
```
this handles newly introduced roots, move to new function
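As a rough illustration of the kind of expression `_is_supported_equivalence` accepts, the check amounts to "linear in exactly one free symbol" (a sketch, not the actual implementation):
```
import sympy

def is_linear_in_one_symbol(expr):
    # accepts shapes like a*s + b; rejects anything nonlinear or multi-symbol
    if len(expr.free_symbols) != 1:
        return False
    (s,) = expr.free_symbols
    return expr.diff(s).is_constant()

s = sympy.Symbol("s")
print(is_linear_in_one_symbol(2 * s + 1))  # True
print(is_linear_in_one_symbol(s * s))      # False
```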
Trusting the tests and the summary, the changes are great! Mostly skimmed over the code.
Thank you!
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
Pull Request resolved: #125543
This PR addresses two issues with derived dim suggested fixes: 1) newly introduced roots, and 2) root swapping.
1 | Newly introduced roots appear with modulo guards, e.g. `Mod(dx, 2) = 0` suggests `dx` is a derived dim equal to `2 * _dx`, introducing a new root `_dx`. Currently the final suggested fixes handle this correctly, but we can get intermediate results where related derived dims don't rely on a unified root and are a mixture of min/max range and derived suggestions.
For example:
```
"dx": {"eq": 3*_dx-1, "max": 36}
"dy": {"eq": dx+1}

This should lead to suggested fixes

_dx = Dim('_dx', max=12)
dx = 3 * _dx - 1
dy = 3 * _dx
```
This PR prettifies the suggested fixes routine by unifying to a single root, and making each intermediate suggestion either a derived dim or min/max range, not both.
2 | The current suggested fixes for derived dims can lead to root dims/derived dims being swapped, e.g. `dy - 1, dy` -> `dx, dx + 1`. This leads to problematic suggested fixes that look like `dy - 1 = Dim("dy - 1")` since we don't have access to the original variable name. This PR only adds a suggested fix for the root dim, and removes all other derived suggestions.
For example, with the export test case test_derived_dim_out_of_order_simplified:
```
_dimz = torch.export.Dim("_dimz", min=6, max=8)
dimy = _dimz - 1
dimx = dimy - 1
dimz = torch.export.Dim("dimz", min=6, max=8)  # doesn't work, should be = _dimz

class Foo(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y[1:] + z[2:]

foo = Foo()
u, v, w = torch.randn(5), torch.randn(6), torch.randn(7)
export(
    foo,
    (u, v, w),
    dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
)
```
Before:
```
Suggested fixes:
_dimz = Dim('_dimz', min=3, max=9223372036854775807)  # 2 <= _dimz - 1 <= 9223372036854775806
_dimz - 2 = Dim('_dimz - 2', min=4, max=6)
_dimz = Dim('_dimz', min=2, max=9223372036854775806)  # 2 <= _dimz <= 9223372036854775806
_dimz - 1 = _dimz - 1
dimz = _dimz
```
New suggested fixes:
```
Suggested fixes:
dimz = _dimz
```
Note: This assumes the specified derived relations between dims are correct. This should be valid because: 1) if the relation is plain wrong (e.g. (dx, dx - 1) provided with inputs (6, 4)), it gets caught beforehand in produce_guards, and 2) if the relation is correct but does not match the emitted guard, for example:
```
def forward(self, x, y):
    return x.reshape([-1]) + y  # guard: s0 * 2 = s1

dx = Dim("dx")
export(
    model,
    (torch.randn(6, 2), torch.randn(12)),
    dynamic_shapes={"x": (dx, 2), "y": (dx + 6,)},
)
```
This produces two linear equations, leading to specialization, since a) produce_guards is able to solve for a concrete value, and b) the export constraint solver will force specializations anyway due to range constraints.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang