[inductor] Fix lowerings that create unexpected aliases #105173
This may give the wrong result in some cases, e.g.

```python
import torch

@torch.compile()
def fn(x):
    tmp = x.ceil()
    x.add_(10)
    return tmp

a = torch.zeros((), dtype=torch.int64)
fn(a)  # tensor(10), whereas eager mode returns tensor(0)
```

ghstack-source-id: 1b4683fe003cd65529760df3d9df2b35ed82d5fe
Pull Request resolved: #105173
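The failure mode in the example above can be illustrated without PyTorch. The following is only a sketch under stated assumptions: plain Python lists stand in for tensor buffers, and `ceil_aliasing` and `ceil_copying` are hypothetical names, not inductor functions. It shows why a lowering that returns its input unchanged (because `ceil` is a numerical no-op on integers) becomes observable once the input is later mutated in place.

```python
def ceil_aliasing(buf):
    # Hypothetical buggy lowering: ceil is a no-op for integers,
    # so the input buffer itself is returned (an unexpected alias).
    return buf


def ceil_copying(buf):
    # Hypothetical fixed lowering: still a numerical no-op,
    # but the result lives in a fresh buffer.
    return list(buf)


def fn(ceil, x):
    tmp = ceil(x)
    x[0] += 10  # stands in for the in-place x.add_(10)
    return tmp


aliased = fn(ceil_aliasing, [0])
copied = fn(ceil_copying, [0])
print(aliased, copied)  # [10] [0]
```

Only the copying variant matches what the eager-mode semantics of the original example would produce, since `tmp` must not observe the later in-place update of `x`.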
if arg.get_dtype() != dtype:
    return to_dtype(arg, dtype)
return arg
This is the only functional change here. It fixes an infinite recursion issue created by `to_dtype` calling `clone`, which is wrapped with type promotion that itself calls `to_dtype`. This only happens for no-op dtype conversions, where it was wasteful to clone anyway, so better to not call `to_dtype` when the dtypes already match here.
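The recursion described in this comment can be reproduced in miniature. The sketch below is not the inductor code: `FakeTensor` and `make_lowerings` are hypothetical stand-ins, and dtypes are plain strings. It only assumes what the comment states, namely that `to_dtype` calls `clone` for a no-op conversion and that `clone` is wrapped with a promotion step that may call `to_dtype`.

```python
class FakeTensor:
    """Hypothetical stand-in for a lowered tensor; it only tracks a dtype."""

    def __init__(self, dtype):
        self.dtype = dtype

    def get_dtype(self):
        return self.dtype


def make_lowerings(guard_noop):
    """Build toy to_dtype/clone functions, with or without the dtype guard."""

    def promote(arg, dtype):
        if guard_noop:
            # Guarded version, mirroring the snippet under review.
            if arg.get_dtype() != dtype:
                return to_dtype(arg, dtype)
            return arg
        # Unguarded version: always converts, even when dtypes match.
        return to_dtype(arg, dtype)

    def clone(arg):
        # Type promotion wrapper runs before the actual clone.
        arg = promote(arg, arg.get_dtype())
        return FakeTensor(arg.get_dtype())

    def to_dtype(arg, dtype):
        if arg.get_dtype() == dtype:
            # No-op conversion: clone so the result does not alias the input.
            return clone(arg)
        return FakeTensor(dtype)

    return to_dtype, clone


x = FakeTensor("int64")

to_dtype_guarded, _ = make_lowerings(guard_noop=True)
out = to_dtype_guarded(x, "int64")  # terminates and returns a fresh object

to_dtype_unguarded, _ = make_lowerings(guard_noop=False)
try:
    to_dtype_unguarded(x, "int64")
    recursed = False
except RecursionError:
    recursed = True
```

With the guard in place, the promotion step returns its argument untouched when the dtypes match, so `clone` is reached exactly once; without it, `to_dtype` and `clone` call each other until Python raises `RecursionError`.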
Hi @peterbell10, I tried your PR, but the unit test provided in the PR description still seems to give an incorrect result:
In addition, the change in this PR has caused a performance regression in several models:
As we're approaching the PyTorch 2.1 branch cut, could you help take a look at this issue?
I forgot to ask you to add the script in the OP as a test. Could you do so in a follow-up while investigating those issues, @peterbell10?
Stack from ghstack (oldest at bottom):
This may give the wrong result in some cases, e.g. the `x.ceil()` example in the commit message above.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov