Add tensor.view(dtype) #47951
Conversation
Oh, sorry, I forgot the doc.
💊 CI failures summary and remediations

As of commit a3295fc (more details on the Dr. CI page):

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

- pytorch_windows_vs2019_py36_cuda10.1_test2 (1/1), Step: "Test" (full log | diagnosis details | 🔁 rerun)
@@ -104,7 +104,7 @@ std::ostream& operator<<(
 static void printAttribute(std::ostream& out, const at::Tensor& tensor) {
   // 1-elem tensors are usually boxed scalars, so print them like it
   if (tensor.numel() == 1) {
-    auto scalar_tensor = tensor.view({}).item();
+    auto scalar_tensor = tensor.view(std::vector<int64_t>{}).item();
Is there a way to stop {} from being resolved to ScalarType?
Do we still need these changes if dtype overload is blocklisted?
Yes, it is blocklisted at runtime for schema matching, but not at compile time, as here.
Original related issue: #29013
Despite the historic NumPy-originated name view(...), a clearer alias such as reinterpret(...) would be nice as well.
// Note (@zasdfgbnm):
// This is a workaround for https://github.com/pytorch/pytorch/issues/47964
// Currently JIT does not distinguish ScalarType vs int, so there is really
// no way to distinguish x.view(1) vs x.view(torch.int8). So we have to hardcode
// the aten::view.dtype here to block this overload. This blocklist should be
// removed when JIT fully supports ScalarType as its own type.
bool isBlockListedSchema(const FunctionSchema& schema) {
  if (schema.name() == "aten::view" && schema.overload_name() == "dtype") {
    return true;
  }
  return false;
}
I looked at the codegen and some code in the dispatcher and JIT, and I feel that hardcoding the operator name here is the best solution to work around #47964.
It would be nice to allow it to work with the unambiguous torch.view(dtype=torch.int8).
Are you suggesting that if schema.name() == "aten::view", then I should check whether the kwarg is "dtype" and special-case it to make this work? I think that would require changing more places in tryMatchSchema. As a temporary workaround, I prefer the workaround to be simple and easy to revert once the real fix (adding ScalarType support to JIT) lands.
Codecov Report
@@            Coverage Diff             @@
##           master   #47951      +/-   ##
==========================================
+ Coverage   80.49%   80.68%   +0.19%
==========================================
  Files        1900     1900
  Lines      206305   206318      +13
==========================================
+ Hits       166056   166470     +414
+ Misses      40249    39848     -401
cc @albanD for autograd.
cc @bwasti
// Note (@zasdfgbnm):
// This is a workaround for https://github.com/pytorch/pytorch/issues/47964
// Currently JIT does not distinguish ScalarType vs int, so there is really
// no way to distinguish x.view(1) vs x.view(torch.int8). So we have to hardcode
Which overload is x.view(torch.int8) getting matched to without this logic?
Without this logic, both x.view(1) and x.view(torch.int8) get matched to aten::view.dtype. Even x.view(-1) gets matched to aten::view.dtype, although -1 is not a valid dtype. Actually, I think aten::view.dtype gets priority over aten::view because aten::view.dtype can be matched at tryMatchSchema(..., allow_conversions=False), while aten::view is usually matched by tryMatchSchema(..., allow_conversions=True).
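For illustration, here is a minimal eager-mode sketch (mine, not from the PR) of the two overloads that JIT schema matching currently cannot tell apart, since ScalarType values like torch.int32 are represented as plain ints in TorchScript:

```python
import torch

x = torch.zeros(4)           # float32, 4 bytes per element

y = x.view(4)                # size overload: aten::view(Tensor(a) self, int[] size)
z = x.view(torch.int32)      # new overload: aten::view.dtype(Tensor(a) self, ScalarType dtype)

# To the JIT, both arguments above look like a plain int, which is why the
# dtype overload is blocklisted for schema matching until ScalarType becomes
# its own type.
print(y.dtype, z.dtype)      # torch.float32 torch.int32
```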
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
of elements, but may have a different dtype. For a tensor to be viewed, the new
dtype must have the same number of bytes as its original dtype.

.. warning::
Could we blacklist this overload in torchscript to avoid this kind of confusing error?
@@ -1129,6 +1129,9 @@
 - name: view(Tensor(a) self, int[] size) -> Tensor(a)
   self: grad.reshape(self.sizes())

+- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
+  output_differentiability: [False]
That will work for now. You can ping me if you want to change that to True in the future :) (we already have similar things for complex dtypes so it shouldn't be too hard to add).
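A small sketch of what this entry means in practice (my own illustration, not from the PR): with output_differentiability: [False], the result of view(dtype) never requires grad, even when both dtypes support autograd.

```python
import torch

a = torch.zeros(2, dtype=torch.float64, requires_grad=True)
b = a.view(torch.complex64)   # both dtypes are 8 bytes per element

# The dtype overload is marked non-differentiable in derivatives.yaml,
# so the output does not require grad.
print(b.requires_grad)        # False
```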
Is the grad of this overload mathematically well defined for reinterpreting between two different types?
The int <--> float reinterpret will not work because int tensors do not support gradients. The only reinterpret in question, I think, is double <--> complex64. But I don't think this makes sense either. For example, if we change the real part of the complex64 from 0 to 1.1111111111 * 2^-80, then the exponent bits of the double tensor will be changed, and the limit

    lim_{dx -> 0} (f(x + dx) - f(x)) / dx

doesn't seem to converge.
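A rough illustrative sketch of the double <--> complex64 reinterpretation being discussed (my own example, assuming the only restriction on view(dtype) is the matching element size described above):

```python
import torch

# float64 and complex64 are both 8 bytes per element, so the view is allowed.
x = torch.zeros(2, dtype=torch.float64)
c = x.view(torch.complex64)

# A tiny perturbation of the complex view rewrites raw bits of the float64
# storage, so the change in the double values is unrelated to the size of
# the perturbation; the difference quotient above has no sensible limit
# through such a bit-level map.
c += 1.1111111111 * 2 ** -80
print(x)   # values change in a way unrelated to the magnitude of 2**-80
```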
Well, the forward isn't really defined mathematically either, right? haha
But I do agree that we most likely want to keep it non-differentiable for now.
@mruberry I think this PR is ready?
This PR looks good to me, and I think you're right, @zasdfgbnm, but there may be a docs formatting issue. Take a look at the docs artifact here: https://9950312-65600975-gh.circle-artifacts.com/0/docs/tensors.html?highlight=view#torch.Tensor.view. In particular, this line:
Needs to be rendered like the multiple schema definitions for other operations (see, for example, sum's documentation: https://pytorch.org/docs/master/generated/torch.sum.html?highlight=sum#torch.sum). Just ping me when that's fixed.
@mruberry This is fixed.
LGTM!
"Viewing a tensor as a new dtype with a different number of bytes per element is not supported."); | ||
Storage storage = self.storage(); | ||
auto new_tensor = detail::make_tensor<TensorImpl>( | ||
std::move(storage), self.key_set(), type_meta); |
@zasdfgbnm do you know what would happen if original tensor required grad, and you are viewing it as integer? Would key_set still have autograd key? The operation is non-differentiable, so it's not particularly important, unless it crashes or produces confusing error message. Can you add a test for what would happen?
It will just return a tensor with requires_grad=False. I have added the test.
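A minimal sketch of the behaviour described above (my own illustration, not the exact test added in the PR):

```python
import torch

# Viewing a float tensor that requires grad as an integer dtype is
# non-differentiable, so the result simply does not require grad.
x = torch.zeros(4, dtype=torch.float32, requires_grad=True)
y = x.view(torch.int32)
assert y.requires_grad is False
```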
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Fixes pytorch#42571

Note that this functionality is a subset of [`numpy.ndarray.view`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.view.html):
- this only supports viewing a tensor as a dtype with the same number of bytes
- this does not support viewing a tensor as a subclass of `torch.Tensor`

Pull Request resolved: pytorch#47951

Reviewed By: ngimel

Differential Revision: D25062301

Pulled By: mruberry

fbshipit-source-id: 9fefaaef77f15d5b863ccd12d836932983794475
Fixes #42571

Note that this functionality is a subset of [`numpy.ndarray.view`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.view.html):
- this only supports viewing a tensor as a dtype with the same number of bytes
- this does not support viewing a tensor as a subclass of `torch.Tensor`
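As a quick illustration of the new overload (a sketch of my own, following the constraints stated above):

```python
import torch

x = torch.tensor([0.0, 1.0, 2.0], dtype=torch.float32)

# Reinterpret the same 4-byte elements as int32; no data is copied.
print(x.view(torch.int32))

# A dtype with a different element size is rejected.
try:
    x.view(torch.int8)
except RuntimeError as err:
    # "Viewing a tensor as a new dtype with a different number of bytes
    #  per element is not supported."
    print(err)
```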