
Add names= argument to torch.tensor ctor #25424

Closed · wants to merge 5 commits

Conversation

zou3519 (Contributor) commented Aug 29, 2019

Stack from ghstack:

Test Plan

  • new tests [namedtensor ci]

Differential Revision: [D17120399](https://our.internmc.facebook.com/intern/diff/D17120399)

@pytorchbot added the `module: internals` label (Related to internal abstractions in c10 and ATen) on Aug 29, 2019
zou3519 added a commit that referenced this pull request Aug 29, 2019
```diff
@@ -268,6 +281,9 @@ Tensor internal_new_from_data(
 auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(type_id));
 AutoNoGIL no_gil;
 maybe_initialize_cuda(device);
+#ifdef BUILD_NAMEDTENSOR
+at::namedinference::propagate_names(tensor, names);
+#endif
```
nairbv (Collaborator) commented:

seems like this code should be refactored a bit.

E.g., as far as I can tell, the bodies of the two USE_NUMPY conditions are nearly identical, so if you refactor you'd need fewer #ifdef BUILD_NAMEDTENSOR blocks, which would make the code more readable. Ideally you'd only need one. I think that with better code re-use this code would be more readable regardless of your change.

zou3519 (Contributor, Author) replied:

If I refactored this into a function, that function would require a guard on one of its arguments based on #ifdef BUILD_NAMEDTENSOR because it would take in a names argument.

We'd still have the same number of #ifdef BUILD_NAMEDTENSOR blocks (actually we'd have ~2 more because of the function declaration and definition), so this would not clean up the code.

I agree that this code could use a refactor, however. I can file a task and take care of it when I delete the BUILD_NAMEDTENSOR flag (that is slated to happen, hopefully, sometime next week).

@nairbv Did you have any other concerns on the code?

nairbv (Collaborator) replied:

I don't (necessarily) mean separate functions.

e.g. if we just wrote `if (PyObject_HasAttrString(data, "__cuda_array_interface__") || PyArray_Check(data))` above instead of having two separate conditions, the only thing that differs inside the if block is the call to get the tensor. The repeated code doesn't necessarily need to be moved to a helper function to be combined.

I guess this can be refactored some other time though.

zou3519 (Contributor, Author) replied:
> e.g. if we just wrote `if (PyObject_HasAttrString(data, "__cuda_array_interface__") || PyArray_Check(data))` above instead of having two separate conditions, the only thing that differs inside the if block is the call to get the tensor. The repeated code doesn't necessarily need to be moved to a helper function to be combined.

I agree that it would be marginally cleaner, but not to the extent that we should block this PR on it. A full refactor that would make the logic in this function a lot clearer, in my opinion, would split each of the codepaths into distinct cases (in the form of helper functions): `build_tensor_from_cuda_array`, `build_tensor_from_numpy`, `build_tensor_from_pylist`, `build_tensor_from_tensor`.

`build_tensor_from_cuda_array` would share a lot of common code with `build_tensor_from_numpy`, and all four of these functions would share a lot of common code in the place where we coerce the input to a device/dtype/names. However, if I were to do this right now, each function would be littered with #ifdef BUILD_NAMEDTENSOR blocks.
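The branch-merging idea being discussed can be sketched in pure Python. This is a hypothetical model, not PyTorch's actual implementation (which is C++ in `tensor_new.cpp`): the `has_cuda_iface`/`is_ndarray` flags and the stub helper bodies are stand-ins, and the helper names are taken from the comment above, not from the real codebase.

```python
# Hypothetical model: fold the two nearly identical array-like branches into
# one condition so the shared body appears only once. Stubs stand in for the
# real C++ conversion routines.

def build_tensor_from_cuda_array(data):
    # Stub: the real code would go through __cuda_array_interface__.
    return ("cuda", list(data))

def build_tensor_from_numpy(data):
    # Stub: the real code would wrap the ndarray's memory.
    return ("numpy", list(data))

def new_from_array_like(data, has_cuda_iface, is_ndarray, names=None):
    """Only the call that produces the tensor differs between the two original
    branches; everything after it (device/dtype/names coercion) is shared."""
    if not (has_cuda_iface or is_ndarray):
        return None  # not an array-like input; other codepaths handle it
    tensor = (build_tensor_from_cuda_array(data) if has_cuda_iface
              else build_tensor_from_numpy(data))
    # ...the shared coercion to device/dtype/names would go here, once...
    return {"tensor": tensor, "names": names}
```

With this shape, a single `#ifdef BUILD_NAMEDTENSOR` block around the shared tail would cover both branches, which is the readability win nairbv is pointing at.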

```diff
@@ -241,6 +246,11 @@ Tensor internal_new_from_data(
 if (copy_variables) {
   var = var.detach();
 }
+#ifdef BUILD_NAMEDTENSOR
+if (names) {
```
nairbv (Collaborator) commented:

Is it intentional that you test `names` for this propagate call but not for the others?

zou3519 (Contributor, Author) replied:

Yes. This is the case where we're passing a tensor to the torch.tensor constructor. It should behave differently if we pass in names=None (the names are ignored) as opposed to names=<something>.

For example, given x: Tensor[N, C]:

(1) torch.tensor(x): Tensor[N, C]. names=None implicitly.
(2) torch.tensor(x, names=None): Tensor[N, C]
(3) torch.tensor(x, names=[N, D]): Error!

(1) and (2) both preserve the names because we can't distinguish between cases (1) and (2).
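The decision rule in the three cases above can be written out as a small pure-Python sketch. This is an assumption about the intended semantics based on the comment, not PyTorch's actual C++ implementation; `resolve_names` is a hypothetical name.

```python
# Hypothetical model of the names= semantics described above:
# names=None leaves any existing names alone (cases (1) and (2) are
# indistinguishable), while an explicit names= that disagrees with the
# tensor's existing names is an error (case (3)).

def resolve_names(existing_names, names=None):
    if names is None:
        # torch.tensor(x) and torch.tensor(x, names=None): keep x's names.
        return existing_names
    if existing_names is not None and list(names) != list(existing_names):
        # torch.tensor(x, names=[...]) with different names: rejected.
        raise RuntimeError("cannot change the names of a tensor in construction")
    return list(names)
```

This also models the answer to the later question in the thread: copy construction preserves names by default, and changing them during construction is deliberately disallowed.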

nairbv (Collaborator) replied:

I can imagine that if we have a x=tensor([1,2,3], names=['foo']) and we try to call y=torch.tensor(x), we would want to make sure names is preserved in y. Is that what you're trying to handle here?

What's the case that should error? I assume we should allow a user to change the names of a tensor in copy construction?

nairbv (Collaborator) added:

I can see from the examples in tests that we don't allow changing names in construction. Should we / why don't we?

zou3519 (Contributor, Author) replied:

We don’t allow it because there’s no use case for it, and it’s better to be conservative. If a use case does come up, there’s a workaround, and it’s easy to add once someone asks for it.

vadimkantorov (Contributor) commented:

A naming question: if a general tensor.name attribute is to be introduced at a later point (e.g. for debugging or graph-inspection purposes), it would be confusing to have both name and names. A suggestion: names -> dim_names = 'NCHW'

zou3519 (Contributor, Author) commented Aug 30, 2019:

> A naming question: if a general tensor.name attribute is to be introduced at a later point (e.g. for debugging or graph-inspection purposes), it would be confusing to have both name and names.

I agree.

> A suggestion: names -> dim_names = 'NCHW'

Thank you for your suggestion; we'll keep it in mind. I've thought about names, dims, and dim_names as potential attributes. For now, because there isn't a widely used tensor.name (it actually does exist, and it is used for ONNX), I'd like to keep the named tensor API nice and simple with names.

zou3519 added commits that referenced this pull request on Sep 3, Sep 9, and Sep 10, 2019.
zdevito pushed a commit to zdevito/ATen that referenced this pull request Sep 11, 2019
@zou3519 merged this pull request in 4231287.

facebook-github-bot deleted the gh/zou3519/127/head branch on October 28, 2019.