Decompose/add reference for view_as_complex (#108005)
@@ -289,6 +289,7 @@
     "view_as",
     "vsplit",
     "vstack",
+    "view_as_complex",
     "unflatten",
     "unbind",
     "triu",

@@ -949,6 +950,43 @@ def trunc(a):
     return prims.trunc(a)
# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
Review discussion:

- "It should be fine to register the decomposition and inductor won't try to use it by default."
- "When I registered this as a decomposition, tests started failing and it seems Inductor generated code with" [truncated]
- "Oh, it is being picked up because you have it in pytorch/torch/_inductor/decomposition.py (line 64 in 78810d7)."
- "The function is composite(explicit) and it's already symintified in core, so there's not much point in registering it as a decomposition, not now nor when we support complex."
- "@lezcano" [truncated]
- "Any reason why we wouldn't want to trace through those? Just because we are assuming that there may be some in-place ops or smth?"
- "Well in the case of" [truncated] "Either way I'm guessing this function isn't alone in being registered as" [truncated]
def view_as_complex(self: TensorLikeType) -> TensorLikeType:
    input_dtype = self.dtype
    torch._check(
        utils.is_float_dtype(input_dtype),
        lambda: f"view_as_complex is only supported for floating point "
        f"tensors, but got a tensor of scalar type: {input_dtype}",
    )
    sizes = self.size()
    torch._check(
        len(sizes) != 0,
        lambda: "Input tensor must have one or more dimensions",
    )
    torch._check(
        sizes[-1] == 2,
        lambda: "Tensor must have a last dimension of size 2",
    )

    old_strides = self.stride()
    torch._check(
        old_strides[-1] == 1,
        lambda: "Tensor must have a last dimension with stride 1",
    )
    dims = old_strides[:-1]
    torch._check(
        py_all(stride % 2 == 0 for stride in dims),
        lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
    )
    torch._check(
        self.storage_offset() % 2 == 0,
        lambda: "Tensor must have a storage_offset divisible by 2",
    )
    return prims.view_element_type(
        self, utils.corresponding_complex_dtype(input_dtype)
    ).squeeze(-1)
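The core of the reference is the final line: `prims.view_element_type` reinterprets each trailing (real, imag) float pair as a single complex element, and `squeeze(-1)` drops the resulting size-1 last dimension. A minimal sketch of the same reinterpretation, using NumPy's `ndarray.view` as a stand-in for the torch prims (NumPy is an assumption here, chosen only so the sketch runs without torch):

```python
import numpy as np

# Two (real, imag) float pairs along the last dimension.
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Reinterpret each float32 pair as one complex64 value, then drop the
# now-size-1 last dimension -- mirroring view_element_type + squeeze(-1).
z = x.view(np.complex64).squeeze(-1)

print(z)        # [1.+2.j 3.+4.j]
print(z.shape)  # (2,)
```

This also explains the stride and storage-offset checks above: the complex reinterpretation only works when each (real, imag) pair is adjacent in memory (last-dim stride 1) and aligned to an even element boundary.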
def _make_elementwise_binary_reference(
    type_promotion_kind,
    aten_op=infer_aten_op,
Review discussion:

- "Removed this as this function doesn't exist."