derived dim #118729

Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118729
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit 7dca2b3 with merge base 7881b95. BROKEN TRUNK: the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D53254587
Just left a couple comments. As we discussed, all this logic could do with a big clean-up sooner rather than later.
In particular, we should seriously streamline the class system here. One example: `_PhantomRoot` and `_ConstraintTarget` should be merged into one class, where `is_phantom` is given by `t_id is None`, for example.
torch/export/dynamic_shapes.py (Outdated)

    def __rsub__(cls, other):
        raise NotImplementedError("non-monotonic operation not supported")
x -> -x is monotonic. Do you mean "strictly increasing" in all of these? In particular, it looks like you want something such that a < b -> f(a) < f(b), right? It would be good to put this equation in the docs of `_DerivedDim`.
Duh, yes, I'll replace "monotonic" throughout with "strictly non-decreasing."
Actually, I think the technical term is "monotonically increasing"; x -> -x is monotonically decreasing. We also have that a = b -> f(a) = f(b). I'll add the equation and change to the proper terminology.
So, if a = b you always have f(a) = f(b) :p
If you want a <= b -> f(a) <= f(b), that's simply "increasing". Monotonically increasing is redundant: monotonically = increasing or decreasing.
If you want a <= b -> f(a) <= f(b), then you need to add zero as an allowed value in `__mul__`.
At any rate, yes, please add the formula as it'll make things much clearer.
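For context, a quick hedged sketch (the bounds are made up) of how this restriction surfaces in the public API, using the `__rsub__` behavior quoted above:

```python
from torch.export import Dim

dim = Dim("dim", min=2, max=10)  # hypothetical bounds

ok = dim + 1       # increasing linear expressions are accepted
also_ok = dim * 2

try:
    bad = 1 - dim  # resolved via __rsub__, which rejects decreasing expressions
except NotImplementedError as err:
    print(err)     # non-monotonic operation not supported
```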
torch/export/dynamic_shapes.py (Outdated)

    def __mul__(cls, other):
        # e.g., dim * 2
        assert type(other) is int, f"Expected int, got {type(other)}"
Suggested change:
-        assert type(other) is int, f"Expected int, got {type(other)}"
+        raise NotImplementedError("non-linear operation not supported")
I'll merge the two messages into one `NotImplementedError`.
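A self-contained sketch of how the merged check could read; this is an illustration of the suggestion above, not the actual code in torch/export/dynamic_shapes.py:

```python
class _DimMetaSketch(type):
    # Stand-in for the metaclass that implements arithmetic on Dim types.
    def __mul__(cls, other):
        # e.g., dim * 2; fold the int check and the non-linear error into
        # a single NotImplementedError, as discussed above
        if type(other) is not int:
            raise NotImplementedError(
                f"Expected an int multiplier, got {type(other).__name__}; "
                "non-linear operations are not supported"
            )
        return cls  # placeholder: the real code constructs a derived Dim


class dim(metaclass=_DimMetaSketch):
    pass


print(dim * 2)  # fine
try:
    dim * dim   # not an int: raises the single, merged error
except NotImplementedError as err:
    print(err)
```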
    def root_value():
        # given tensor.shape[i] is the value of dim = fn(root),
        # find the value of root
        symbol = sympy.Symbol(dim.root.__name__, integer=True)
Could we set it to be `nonnegative`?
I'll clamp the bounds. We haven't yet moved the bounds to >= 0, by the way; that's a TODO (somebody else is working on a PR for that).
To be clear, I'm not going to use `nonnegative` (but instead enforce it via bounds) because we have some work TODO on serde to carry assumptions on all symbols in shape env, and I don't want to add assumptions that cannot be maintained through serde. (The only option is to also blanket-use `nonnegative` there right now, which is not correct for unbacked symints. I have a comment on that in this PR near the serde logic.)
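For illustration, a rough sketch of the root-solving step under discussion; the derived expression `2*root + 1` and the observed size are made-up values, and only the quoted `integer=True` symbol setup is from the PR:

```python
import sympy

# Suppose dim = fn(root) with fn(root) = 2*root + 1, and the traced input
# has tensor.shape[i] == 9 for this dim; recover the value of root.
symbol = sympy.Symbol("root", integer=True)  # nonnegative=True deliberately omitted;
                                             # bounds are clamped separately, as noted above
fn = 2 * symbol + 1
observed = 9
(root_val,) = sympy.solve(sympy.Eq(fn, observed), symbol)
print(root_val)  # 4
```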
torch/export/dynamic_shapes.py (Outdated)

            f"got {dynamic_shapes} instead",
        )

        from collections import defaultdict
Push this import to the top of the file.
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
With the current `Dim`-based dynamic shapes API for export, one can express that shapes of different inputs must be equal by reusing the same `Dim`. However, non-trivial relationships between such input shapes cannot be expressed. Recently we are seeing more and more examples of code that require this additional expressiveness, e.g., where a pair of shapes might differ by one, or a shape might be double another (or simply even).

This PR introduces the concept of a "derived" `Dim`, i.e., a linear arithmetic expression over a `Dim`. By using a combination of `Dim`s and derived `Dim`s to specify input shapes, the desired relationships can be expressed naturally. E.g., a pair of shapes might be `dim` and `dim + 1`, or `dim` and `2*dim`, or even `2*dim` and `dim + 1`.

We extend the current infrastructure that translates `Dim`s to deprecated `dynamic_dim`-based constraints to work with derived `Dim`s. As usual, we raise constraint violation errors when shape guards cannot be verified given a dynamic shapes spec, suggest fixes, and raise runtime errors when future inputs violate the spec.

Importantly, some guards that used to cause forced specializations in the constraint solver because they were deemed "too complex" no longer do so, because they can now be specified as constraints. Since this was what motivated the introduction of a `disable_constraint_solver` flag to some internal APIs, we may not need that flag any more.

Note that shapes of placeholders in exported programs can now contain symbolic expressions and not just symbols.
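For concreteness, here is a minimal sketch of the new spec in use; the module, input sizes, and bounds below are illustrative assumptions, not part of this PR:

```python
import torch
from torch.export import Dim, export


class M(torch.nn.Module):
    def forward(self, x, y):
        # only well-defined when y has exactly one more row than x
        return x + y[1:]


dim = Dim("dim", min=2, max=64)  # hypothetical bounds

ep = export(
    M(),
    (torch.randn(4, 3), torch.randn(5, 3)),
    # y's first dimension is expressed as a derived Dim over x's
    dynamic_shapes={"x": {0: dim}, "y": {0: dim + 1}},
)
print(ep)  # placeholder shapes can now contain expressions such as s0 + 1
```

If a later call violates the spec (say, a `y` with 7 rows paired with an `x` with 4), a runtime error is raised, as described above.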
Differential Revision: D53254587
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames