Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] Handle aten.full's dtype in the decomposition #108443

Closed
wants to merge 2 commits into from

Conversation

peterbell10
Copy link
Collaborator

@peterbell10 peterbell10 commented Sep 1, 2023

Stack from ghstack (oldest at bottom):

In the lowering we don't have SymFloat and SymInt, we just have sympy.Expr
so it is impossible to accurately determine the expected dtype of a full call.
For example, sym_float(int_expr) has is_integer=True but should be treated
as a float. In the decomposition though, we can get this right.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov

In the lowering we don't have `SymFloat` and `SymInt`, we just have `sympy.Expr`
so it is impossible to accurately determine the expected dtype of a `full` call.
For example, `sym_float(int_expr)` has `is_integer=True` but should be treated
as a float. In the decomposition though, we can get this right.

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Sep 1, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/108443

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7c5a6b2 with merge base 29f17e1 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

peterbell10 added a commit that referenced this pull request Sep 1, 2023
In the lowering we don't have `SymFloat` and `SymInt`, we just have `sympy.Expr`
so it is impossible to accurately determine the expected dtype of a `full` call.
For example, `sym_float(int_expr)` has `is_integer=True` but should be treated
as a float. In the decomposition though, we can get this right.

ghstack-source-id: a5b8e323cb166621f796109d6bad4cfcc903a8de
Pull Request resolved: #108443
@peterbell10 peterbell10 added the topic: not user facing topic category label Sep 1, 2023
peterbell10 added a commit to peterbell10/pytorch that referenced this pull request Sep 1, 2023
In the lowering we don't have `SymFloat` and `SymInt`, we just have `sympy.Expr`
so it is impossible to accurately determine the expected dtype of a `full` call.
For example, `sym_float(int_expr)` has `is_integer=True` but should be treated
as a float. In the decomposition though, we can get this right.

ghstack-source-id: a5b8e323cb166621f796109d6bad4cfcc903a8de
Pull Request resolved: pytorch#108443
…ion"

In the lowering we don't have `SymFloat` and `SymInt`, we just have `sympy.Expr`
so it is impossible to accurately determine the expected dtype of a `full` call.
For example, `sym_float(int_expr)` has `is_integer=True` but should be treated
as a float. In the decomposition though, we can get this right.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
peterbell10 added a commit that referenced this pull request Sep 1, 2023
In the lowering we don't have `SymFloat` and `SymInt`, we just have `sympy.Expr`
so it is impossible to accurately determine the expected dtype of a `full` call.
For example, `sym_float(int_expr)` has `is_integer=True` but should be treated
as a float. In the decomposition though, we can get this right.

ghstack-source-id: 828e98b71dacf5c85b022c273b2e75fae466189d
Pull Request resolved: #108443
Copy link
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we mark the symfloats as not integer so that Yukio's solution was also correct? This is a bit odd...
At any rate, having a solution at a higher level of abstraction is always better, so approving.

@peterbell10
Copy link
Collaborator Author

Don't we mark the symfloats as not integer so that Yukio's solution was also correct?

No, the test case I've included fails with Yukio's PR because the sympy expression for sym_float(s0) is still just s0 and s0 is an integer. Maybe we could add some extra typing metadata but that would basically be re-inventing SymInt and SymFloat.

@peterbell10
Copy link
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 1, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/peterbell10/608/head branch September 5, 2023 14:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants