Skip to content

Conversation

avikchaudhuri
Copy link
Contributor

@avikchaudhuri avikchaudhuri commented Jan 31, 2024

With the current Dim-based dynamic shapes API for export, one can express that shapes of different input shapes must be equal by reusing the same Dim. However, non-trivial relationships between such input shapes cannot be expressed.

Recently we are seeing more and more examples of code that require this additional expressibility, e.g., where a pair of shapes might differ by one, or a shape might be double another (or simply even).

This PR introduces the concept of a "derived" Dim, i.e., a linear arithmetic expression over a Dim. By using a combination of Dims and derived Dims to specify input shapes, the desired relationships can be expressed naturally. E.g., a pair of shapes might be dim and dim + 1, or dim and 2*dim, or even 2*dim and dim + 1.

We extend the current infrastructure that translates Dims to deprecated dynamic_dim-based constraints to work with derived Dims. As usual, we raise constraint violation errors when shape guards cannot be verified given a dynamic shapes spec; suggest fixes; and raise runtime errors when future inputs violate the spec.

Importantly, some guards that used to cause forced specializations in the constraint solver because they were deemed "too complex" now do not do so, because they can now be specified as constraints. Since this was what motivated the introduction of a disable_constraint_solver flag to some internal APIs, we may not need that flag any more.

Note that shapes of placeholders in exported programs can now contain symbolic expressions and not just symbols.

Differential Revision: D53254587

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames

@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Jan 31, 2024
Copy link

pytorch-bot bot commented Jan 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118729

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 7dca2b3 with merge base 7881b95 (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D53254587

avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Jan 31, 2024
Summary: Pull Request resolved: pytorch#118729

Differential Revision: D53254587
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D53254587

avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Jan 31, 2024
Summary: Pull Request resolved: pytorch#118729

Differential Revision: D53254587
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D53254587

avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Feb 1, 2024
Summary: Pull Request resolved: pytorch#118729

Differential Revision: D53254587
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D53254587

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D53254587

@avikchaudhuri avikchaudhuri added the suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) label Feb 1, 2024
avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Feb 2, 2024
Summary: Pull Request resolved: pytorch#118729

Differential Revision: D53254587
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D53254587

avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Feb 2, 2024
Summary: Pull Request resolved: pytorch#118729

Differential Revision: D53254587
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D53254587

avikchaudhuri added a commit to avikchaudhuri/pytorch that referenced this pull request Feb 27, 2024
Summary:
With the current `Dim`-based dynamic shapes API for export, one can express that shapes of different input shapes must be equal by reusing the same `Dim`. However, non-trivial relationships between such input shapes cannot be expressed.

Recently we are seeing more and more examples of code that require this additional expressibility, e.g., where a pair of shapes might differ by one, or a shape might be double another (or simply even).

This PR introduces the concept of a "derived" `Dim`, i.e., a linear arithmetic expression over a `Dim`. By using a combination of `Dim`s and derived `Dim`s to specify input shapes, the desired relationships can be expressed naturally. E.g., a pair of shapes might be `dim` and `dim + 1`, or `dim` and `2*dim`, or even `2*dim` and `dim + 1`.

We extend the current infrastructure that translates `Dim`s to deprecated `dynamic_dim`-based constraints to work with derived `Dim`s. As usual, we raise constraint violation errors when shape guards cannot be verified given a dynamic shapes spec; suggest fixes; and raise runtime errors when future inputs violate the spec.

Importantly, some guards that used to cause forced specializations in the constraint solver because they were deemed "too complex" now do not do so, because they can now be specified as constraints. Since this was what motivated the introduction of a `disable_constraint_solver` flag to some internal APIs, we may not need that flag any more.

Note that shapes of placeholders in exported programs can now contain symbolic expressions and not just symbols.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

Pull Request resolved: pytorch#118729

Differential Revision: D53254587

Pulled By: avikchaudhuri
Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just left a couple comments. As we discussed, all this logic could do with a big clean-up sooner than later.

In particular, we should seriously streamline the class system here. One example: _PhantomRoot and _ConstraintTarget should be merged into one class, where is_phantom is given by t_id is None, for example.

Comment on lines 58 to 59
def __rsub__(cls, other):
raise NotImplementedError("non-monotonic operation not supported")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x -> -x is monotonic. Do you mean in all these "strictly increasing"? In particular, it looks like you want something that a < b -> f(a) < f(b), right? It would be good to put this equation in the docs of _DerivedDim.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duh yes, I'll replace "monotonic" throughout which "strictly non-decreasing."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think the technical term is "monotonically increasing." x -> -x is monotonically decreasing. We also have that a = b -> f(a) = f(b). I'll add the equation and change the proper terminology.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, if a = b you always have f(a) = f(b) :p
If you want a <= b -> f(a) <= f(b), that's simply "increasing".
Monotonically increasing is redundant. Monotonically = increasing or decreasing.
If you want a <= b -> f(a) <= f(b) then you need to add zero as an allowed value in __mul__.

At any rate, yes, please add the formula as it'll make things much clearer.


def __mul__(cls, other):
# e.g., dim * 2
assert type(other) is int, f"Expected int, got {type(other)}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert type(other) is int, f"Expected int, got {type(other)}"
raise NotImplementedError("non-linear operation not supported")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll merge the two messages into one NotImplementedError.

def root_value():
# given tensor.shape[i] is the value of dim = fn(root),
# find the value of root
symbol = sympy.Symbol(dim.root.__name__, integer=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we set it to be nonnegative?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll clamp the bounds. We haven't yet moved the bounds to >= 0 by the way, that's a TODO (somebody else is working on a PR for that).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear, I'm not going to use nonnegative (but instead enforce it via bounds) because we have some work TODO on serde to carry assumptions on all symbols in shape env, and I don't want to add assumptions that cannot be maintained through serde. (The only option is to also blanket-use nonnegative there right now, which is not correct for unbacked symints. I have a comment on that in this PR near serde logic.)

f"got {dynamic_shapes} instead",
)

from collections import defaultdict
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

push to top

@albanD albanD self-requested a review February 27, 2024 18:40
@albanD albanD dismissed their stale review February 27, 2024 18:41

Allowlist is not modified anymore

@avikchaudhuri avikchaudhuri force-pushed the export-D53254587 branch 2 times, most recently from 52a1331 to 0382090 Compare February 28, 2024 03:05
Summary:
With the current `Dim`-based dynamic shapes API for export, one can express that shapes of different input shapes must be equal by reusing the same `Dim`. However, non-trivial relationships between such input shapes cannot be expressed.

Recently we are seeing more and more examples of code that require this additional expressibility, e.g., where a pair of shapes might differ by one, or a shape might be double another (or simply even).

This PR introduces the concept of a "derived" `Dim`, i.e., a linear arithmetic expression over a `Dim`. By using a combination of `Dim`s and derived `Dim`s to specify input shapes, the desired relationships can be expressed naturally. E.g., a pair of shapes might be `dim` and `dim + 1`, or `dim` and `2*dim`, or even `2*dim` and `dim + 1`.

We extend the current infrastructure that translates `Dim`s to deprecated `dynamic_dim`-based constraints to work with derived `Dim`s. As usual, we raise constraint violation errors when shape guards cannot be verified given a dynamic shapes spec; suggest fixes; and raise runtime errors when future inputs violate the spec.

Importantly, some guards that used to cause forced specializations in the constraint solver because they were deemed "too complex" now do not do so, because they can now be specified as constraints. Since this was what motivated the introduction of a `disable_constraint_solver` flag to some internal APIs, we may not need that flag any more.

Note that shapes of placeholders in exported programs can now contain symbolic expressions and not just symbols.

bypass-github-pytorch-ci-checks

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

Pull Request resolved: pytorch#118729

Reviewed By: angelayi

Differential Revision: D53254587

Pulled By: avikchaudhuri
@facebook-github-bot
Copy link
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request fb-exported Merged module: dynamo module: inductor release notes: fx release notes category suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants