Adding pin_memory kwarg to zeros, ones, empty, ... tensor constructors#18952

Closed
VitalyFedyunin wants to merge 14 commits into pytorch:master from VitalyFedyunin:pin_memory_try_3

Conversation

@VitalyFedyunin (Contributor) commented Apr 5, 2019

Make it possible to construct a pinned-memory tensor without creating a storage first and without calling the pin_memory() function. It is also faster, since no copy operation is needed.

Supported functions:

```python
torch.rand_like(t, pin_memory=True)
torch.randn_like(t, pin_memory=True)
torch.empty_like(t, pin_memory=True)
torch.full_like(t, 4, pin_memory=True)
torch.zeros_like(t, pin_memory=True)
torch.ones_like(t, pin_memory=True)
torch.tensor([10, 11], pin_memory=True)
torch.randn(3, 5, pin_memory=True)
torch.rand(3, pin_memory=True)
torch.zeros(3, pin_memory=True)
torch.randperm(3, pin_memory=True)
torch.empty(6, pin_memory=True)
torch.ones(6, pin_memory=True)
torch.eye(6, pin_memory=True)
torch.arange(3, 5, pin_memory=True)
```
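For instance, the new keyword argument allocates pinned (page-locked) host memory in one step, which can then feed an asynchronous host-to-device copy (a minimal sketch, not from this PR; pinning host memory requires a CUDA-capable build, so the snippet falls back to pageable memory otherwise):

```python
import torch

if torch.cuda.is_available():
    # One-step allocation of page-locked host memory: no intermediate
    # pageable tensor and no extra copy, unlike t.pin_memory().
    t = torch.empty(1024, pin_memory=True)
    assert t.is_pinned()
    # Pinned memory allows the host-to-device copy to run asynchronously.
    gpu_t = t.to("cuda", non_blocking=True)
else:
    # Without a CUDA runtime, pinning is unavailable; allocate normally.
    t = torch.empty(1024)
    assert not t.is_pinned()
```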

Part of the bigger `Remove Storage` plan.

Now compatible with both TorchScript forms:

```python
_1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu"), pin_memory=False)
```

and

```python
_1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu"))
```

The same was verified for all similar functions (`rand_like`, `empty_like`, and others).
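As a rough illustration of the compatibility claim (a hypothetical sketch, not a test from this PR): a function scripted after this change serializes the `torch.zeros` call with `pin_memory` filled in from its default, while graphs saved before the change simply omit the argument and resolve against the same schema.

```python
import torch

@torch.jit.script
def make_zeros() -> torch.Tensor:
    # The compiled call resolves against the aten::zeros schema, where
    # pin_memory defaults to False, so older serialized graphs that do
    # not mention the argument keep working.
    return torch.zeros(10, dtype=torch.float32)

assert make_zeros().shape == torch.Size([10])
```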

This is a fixed version of #18455.

…s. (pytorch#18455)

Pull Request resolved: pytorch#18455

Reviewed By: ezyang

Differential Revision: D14672084

Pulled By: VitalyFedyunin

fbshipit-source-id: 9d0997ec00f59500ee018f8b851934d334012124
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Apr 5, 2019
@facebook-github-bot left a comment
@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@VitalyFedyunin (Author) commented

@pytorchbot retest this please


@VitalyFedyunin VitalyFedyunin changed the title [WIP] Adding pin_memory kwarg to zeros, ones, empty, ... tensor constructors Adding pin_memory kwarg to zeros, ones, empty, ... tensor constructors Apr 5, 2019
```diff
  CUDA: _cudnn_rnn_backward

- - func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType dtype, Layout layout, Device device) -> Tensor
+ - func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor
```
Contributor:
This one isn't optional, but the other ones are. What's the difference?

Contributor:
Doesn't the fact that these are not optional imply that you might break JIT scripts which mention any of the non-optional variants here?

Contributor (Author):

I wish I could make it optional; however, the native-function and JIT generators only know how to wrap either all non-optional arguments or all optional arguments into TensorOptions. There is no easy way to get a combination of the two without a major rewrite, which makes no sense since we are migrating off this generation anyway.

Contributor:
I think my real question is, why weren't all of these optional in the first place?

Contributor (Author):

This is part of the bigger discussion about what we do with TensorOptions.

Contributor:

> Doesn't the fact that these are not optional imply that you might break JIT scripts which mention any of the non-optional variants here?

Answering my own question: no, JIT scripts should not be broken, because all occurrences of pin_memory are given default arguments.

I think there should still be a Note discussing the inconsistency of pin_memory at these sites. Put it in native_functions.yaml (or maybe the README in native).

Contributor:

There's a lot of discussion about this function when it doesn't even matter (outside of not breaking the JIT): no one calls this function directly. So presumably we can just do whatever makes our job easier here (giving everything a default, not using TensorOptions, etc.), as long as it doesn't break the JIT.

Contributor:
That's my fault, I stuck the PR comment on the very first occurrence that had this problem. If you look at the rest of native_functions.yaml there are a bunch of other, more public, functions that look this way too. (In the end it's a moot point, I don't believe there is any BC-breakage here, regardless of the status of the function.)

@ezyang (Contributor) commented Apr 11, 2019

Wasn't there going to be a test for TorchScript BC in this case?

@ezyang left a comment
see comments

@VitalyFedyunin (Author) commented:

> Wasn't there going to be a test for TorchScript BC in this case?

Done manually with binary files.

@ezyang ezyang self-requested a review April 11, 2019 17:50
@ezyang commented Apr 11, 2019

> Done manually with binary files.

The general expectation is that if you did manual testing to verify that a change worked, you should describe in the PR how exactly you tested it ;) (That's what "Test Plan" in FB infrastructure is all about.) But I think there's a case to be made for actually having a real test in the test suite for this case, so that we can avoid breaking it in the future. What ended up being difficult about constructing an ad hoc JIT IR with the missing kwarg field for this?

@VitalyFedyunin (Author) commented:

> What ended up being difficult about constructing an ad hoc JIT IR with the missing kwarg field for this?

This type of test would cover only this particular tiny BC case (much easier to test manually). Making a proper test would require either storing a binary or inlining all the zipped files inside the test code (including the attributes.pkl binary).

@ezyang commented Apr 11, 2019

> This type of test would cover only this particular tiny BC case (much easier to test manually)

I mean, sure, it will only cover this particular BC, but the point of adding regression tests when a regression happens is that we have some updated priors that this particular aspect of the system is likely to break. And it doesn't seem unreasonable to me that it will break exactly the same way the next time we need to add another option to TensorOptions.
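A lightweight in-memory round trip could serve as such a regression test (a hypothetical sketch; unlike the binary-artifact approach discussed above, it only exercises serialization against the current schema rather than against an old one):

```python
import io
import torch

@torch.jit.script
def zeros_fn() -> torch.Tensor:
    return torch.zeros(5)

# Serialize the scripted function, then load it back. Loading re-resolves
# the saved torch.zeros call against the current schema, so a defaulted
# pin_memory argument must still match.
buf = io.BytesIO()
torch.jit.save(zeros_fn, buf)
buf.seek(0)
loaded = torch.jit.load(buf)
assert torch.equal(loaded(), torch.zeros(5))
```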

@VitalyFedyunin (Author) commented:

> This type of test would cover only this particular tiny BC case (much easier to test manually)

> I mean, sure, it will only cover this particular BC, but the point of adding regression tests when a regression happens is that we have some updated priors that this particular aspect of the system is likely to break. And it doesn't seem unreasonable to me that it will break exactly the same way the next time we need to add another option to TensorOptions.

Will land #19174 first (need stamp ;) )

@ezyang commented Apr 11, 2019

You don't have to wait for #19174 to land first before landing this one; your assurance is good enough for me :)

@ezyang left a comment
I didn't heavily rereview the Python arg parser code, let me know if you want a careful audit.


facebook-github-bot pushed a commit that referenced this pull request Apr 15, 2019
…models. (#19174)

Summary:
Helps to test #18952
Pull Request resolved: #19174

Differential Revision: D14899474

Pulled By: VitalyFedyunin

fbshipit-source-id: a4854ad44da28bd0f5115ca316e6078cbfe29d0d


@VitalyFedyunin merged this pull request in 1c5073f.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Apr 16, 2019
…s (#18952)

Pull Request resolved: pytorch/pytorch#18952

Differential Revision: D14801792

Pulled By: VitalyFedyunin

fbshipit-source-id: 8dbc61078ff7a637d0ecdb95d4e98f704d5450ba
zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019
…models. (pytorch#19174)

zhangguanheng66 pushed a commit to zhangguanheng66/pytorch that referenced this pull request May 6, 2019
pytorch#18952)

@Majdoddin commented:

@VitalyFedyunin What is the status of this PR, please? What stops it from going into the stable API?
