Decompose/add reference for view_as_complex
#108005
Conversation
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/108005. Note: links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit 93ce05d with merge base 9f37aec. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
@@ -172,7 +173,6 @@
     "item",
     "maximum_value",
     "minimum_value",
-    "to_dtype",
```
Removed this as this function doesn't exist.
torch/_refs/__init__.py (Outdated)
```python
new_storage_offset = self.storage_offset() // 2
if not utils.is_complex_dtype(input_dtype):
    self = prims.view_of_dtype(self, utils.corresponding_complex_dtype(input_dtype))

return self.as_strided(new_sizes, new_strides, new_storage_offset)
```
This decomposition isn't ideal for inductor since it 1) goes directly to prims and 2) calls as_strided. It would make more sense to decompose into aten.view.dtype.
Suggested change:

```diff
-new_storage_offset = self.storage_offset() // 2
-if not utils.is_complex_dtype(input_dtype):
-    self = prims.view_of_dtype(self, utils.corresponding_complex_dtype(input_dtype))
-return self.as_strided(new_sizes, new_strides, new_storage_offset)
+complex_dtype = utils.corresponding_complex_dtype(input_dtype)
+return self.view(complex_dtype)[..., 0]
```
Which you could in turn implement with prims.view_of_dtype.
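The geometry arithmetic being replaced above is easier to see in isolation. Below is a minimal pure-Python sketch of the size/stride/offset bookkeeping that viewing a real `(..., 2)` tensor as complex performs; the helper name and signature are mine, for illustration only, not part of the PR:

```python
def view_as_complex_geometry(sizes, strides, storage_offset):
    """Compute the output geometry of a view_as_complex-style view.

    The input's last dimension must have size 2 and stride 1; each
    (real, imag) pair becomes one complex element, so the remaining
    strides and the storage offset are halved and the last dim drops.
    """
    assert sizes[-1] == 2 and strides[-1] == 1, "last dim must be (size 2, stride 1)"
    new_sizes = list(sizes[:-1])
    new_strides = [s // 2 for s in strides[:-1]]
    new_storage_offset = storage_offset // 2
    return new_sizes, new_strides, new_storage_offset
```

For a contiguous `(3, 4, 2)` float tensor with strides `(8, 2, 1)`, this yields a `(3, 4)` complex view with strides `(4, 1)`, which is what either the `as_strided` version or the `view(complex_dtype)[..., 0]` version must produce.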
Another reason to do it this way is to fix the meta registration for view.dtype, which currently isn't a view and doesn't handle different-sized dtypes at all!
pytorch/torch/_meta_registrations.py, lines 2645 to 2647 in 808e088:

```python
@register_meta(aten.view.dtype)
def view_dtype(self, dtype):
    return utils.clone_preserve_strides(self).to(dtype)
```
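To make the "different-sized dtypes" point concrete, here is a pure-Python sketch (helper name hypothetical, not PyTorch API) of the last-dimension rescaling that a view-correct meta registration would need when source and target dtypes have different itemsizes:

```python
def view_dtype_geometry(sizes, strides, old_itemsize, new_itemsize):
    """Size/stride arithmetic for a dtype view between different itemsizes.

    Narrowing (e.g. complex64 -> float32) multiplies the last dimension
    by the size ratio; widening (e.g. float32 -> complex64) divides it
    and requires a divisible, contiguous last dimension.
    """
    sizes, strides = list(sizes), list(strides)
    if old_itemsize == new_itemsize:
        return sizes, strides
    if old_itemsize > new_itemsize:
        # narrowing: each element splits into `ratio` smaller elements
        ratio = old_itemsize // new_itemsize
        sizes[-1] *= ratio
        strides = [s * ratio for s in strides]
        strides[-1] = 1
    else:
        # widening: groups of `ratio` elements fuse into one
        ratio = new_itemsize // old_itemsize
        assert sizes[-1] % ratio == 0 and strides[-1] == 1
        sizes[-1] //= ratio
        strides = [s // ratio for s in strides]
        strides[-1] = 1
    return sizes, strides
```

A clone-and-cast meta like the one above gets none of this: it neither aliases the input nor rescales the shape, e.g. a `(3, 4)` complex64 tensor viewed as float32 should report size `(3, 8)`.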
Alright. I think I will work on the decomposition for aten.view.dtype (which means I also need to fix the meta). Will do that in another PR. Thanks for the review!
```diff
@@ -918,6 +919,47 @@ def trunc(a):
     return prims.trunc(a)


+# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
```
It should be fine to register the decomposition and inductor won't try to use it by default.
When I registered this as a decomposition, tests started failing; it seems Inductor generated code with complex64 and the like.
Oh, it is being picked up because you have it in the core_aten_decompositions list, which is used by inductor here:

pytorch/torch/_inductor/decomposition.py, line 64 in 78810d7:

```python
decompositions = {**core_aten_decompositions(), **inductor_decompositions}
```
The function is composite(explicit) and it's already symintified in core, so there's not much point in registering it as a decomposition, not now nor when we support complex.
@lezcano CompositeExplicitAutograd operators are not traced through by FakeTensor, so you still need a decomposition.
Any reason why we wouldn't want to trace through those? Just because we are assuming that there may be some in-place ops or smth?
Well, in the case of view.dtype it calls raw TensorImpl methods and wouldn't materialize anything meaningful in the fx graph if we traced through it. Although I would question whether that should be legal in composite methods: it certainly wouldn't work for any tensor subclasses, so maybe it should be registered as CPU, CUDA, Meta instead.

Either way, I'm guessing this function isn't alone in being registered as CompositeExplicit while not actually dispatching to anything.
torch/_meta_registrations.py (Outdated)

```diff
@@ -2606,7 +2606,54 @@ def meta_complex(real, imag):


 @register_meta(aten.view.dtype)
 def view_dtype(self, dtype):
```
Aten source:

```cpp
Tensor view_dtype(const Tensor& self, ScalarType dtype) {
```

As a low-level view-y operation, my personal preference would have been to just symbolify the underlying C++ implementation.
Co-Authored-By: peterbell10 <13238737+peterbell10@users.noreply.github.com>
I see. Just to clarify, you mean to support SymInts in both?
torch/_meta_registrations.py (Outdated)

```python
new_strides[-1] = 1

new_size = list(self.size())
new_size[self.dim() - 1] //= size_ratio
```
nit: In general, unless you are sure that what you are working with are Python types, avoid in-place ops, as in-place ops and tracing often don't get along. I'm not sure if torch.SymInt handles in-place ops well, but better safe than sorry.
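As a plain-int illustration of the nit (helper names are mine, not from the PR): both forms compute the same sizes, but the second builds a fresh list instead of floor-dividing an element in place, so it only relies on `//` being defined, not on in-place operator support on traced values such as torch.SymInt:

```python
def shrink_last_dim_inplace(sizes, ratio):
    # mutates an element with //=, which leans on __ifloordiv__-style
    # support for whatever type the sizes hold
    new_size = list(sizes)
    new_size[-1] //= ratio
    return new_size

def shrink_last_dim_functional(sizes, ratio):
    # out-of-place: only uses // and fresh list construction
    return [*sizes[:-1], sizes[-1] // ratio]
```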
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki.

Merge failed. Reason: new commits were pushed while merging. Please rerun the merge command.
Peter, thanks for taking the time to review my PR thoroughly. Really appreciate it!
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours).
Aten source: pytorch/aten/src/ATen/native/ComplexHelper.h, line 78 in d4a9963

Documentation reference: https://pytorch.org/docs/stable/generated/torch.view_as_complex.html

Note: this adds a new primitive, view_of_dtype, which is trivially implemented, as its meta function is already implemented elsewhere. Finally, this is not registered as a decomposition (yet), because TorchInductor does not yet support complex types. It should be added once we do.

Closes #108020 as well.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @lezcano