Add default hooks to save tensors on CPU #61928
Conversation
Fix #57100 [ghstack-poisoned]
🔗 Helpful links
💊 CI failures summary and remediations
As of commit 679572c (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚
This comment was automatically generated by Dr. CI.
Fix #57100. Creates a context manager `torch.autograd.graph.save_on_cpu()` under which all tensors saved during the forward pass are copied to CPU, then copied back to the appropriate device for the backward pass. [ghstack-poisoned]
@Varal7 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fix #57100. Creates a function `torch.autograd.graph.set_save_on_cpu_hooks()` which can be used to register default hooks under which all tensors saved during the forward pass are moved* to CPU, then copied back to the appropriate device for the backward pass.

*If the tensor was already on CPU, the entire operation is a no-op. If the tensor is on GPU, we move it to pinned memory during packing so that the unpacking can be done asynchronously.

With the current PR, hooks are set with `torch.autograd.graph.set_save_on_cpu_hooks()` and unset with `torch.autograd.graph.reset_saved_tensors_default_hooks`. In the near future, we want to make these hooks thread-local and expose a context manager `torch.autograd.graph.save_on_cpu`.

See [benchmark](#61928 (comment)) and [note about training large models](#61928 (comment))

Differential Revision: [D29848526](https://our.internmc.facebook.com/intern/diff/D29848526)

[ghstack-poisoned]
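As a rough sketch of the pack/unpack pair such default hooks install (illustrative only, not the PR's exact code; the names `pack_to_cpu` and `unpack_from_cpu` are made up for this example):

```python
import torch

def pack_to_cpu(tensor):
    # Runs during the forward pass: stage the saved tensor on CPU.
    # Pinned memory is used (when CUDA is available and the tensor is
    # dense) so that the copy back to GPU can be asynchronous.
    packed = torch.empty(
        tensor.size(),
        dtype=tensor.dtype,
        layout=tensor.layout,
        pin_memory=(torch.cuda.is_available() and not tensor.is_sparse),
    )
    packed.copy_(tensor)
    return (tensor.device, packed)

def unpack_from_cpu(packed):
    # Runs during the backward pass: copy the tensor back to its
    # original device. non_blocking=True lets the copy overlap with
    # computation because the source is in pinned memory.
    device, tensor = packed
    return tensor.to(device, non_blocking=True)
```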
This PR adds docstrings for the CPU hooks introduced in #61928. Also uncomments the warning about pinned memory in the CUDA semantics docs. Depends on #62361. For now, the docstrings are an orphan page at https://docs-preview.pytorch.org/62410/generated/torch.autograd.graph.set_save_on_cpu_hooks.html#torch-autograd-graph-set-save-on-cpu-hooks

Differential Revision: [D29990129](https://our.internmc.facebook.com/intern/diff/D29990129)

[ghstack-poisoned]
@Varal7 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
```python
storage = torch.empty(
    tensor.size(),
    dtype=tensor.dtype,
    layout=tensor.layout,
    pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
storage.copy_(tensor)
```
Why can't we do `storage = tensor.to(device='cpu', pin_memory=True)`?

Also nit: we should rename `storage` to something else; `storage` can be confused with `torch.Storage`.
I don't think that `pin_memory` is an acceptable argument of `torch.Tensor.to`. Ok for the `storage` rename.
> I don't think that `pin_memory` is an acceptable argument of `torch.Tensor.to`.

Good point, thanks for the clarification.
Would something like `storage = tensor.to("cpu", non_blocking=True).pin_memory()` work?
That would mean two copies (device → pageable CPU memory, then pageable → pinned).
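A small sketch to make the copy count concrete (assumes a CUDA device is available; variable names are illustrative):

```python
import torch

x = torch.randn(1024, 1024, device="cuda")

# Two copies: device -> pageable CPU memory, then pageable -> pinned.
two_copies = x.to("cpu", non_blocking=True).pin_memory()

# One copy: allocate pinned CPU memory up front, then copy device -> pinned.
pinned = torch.empty(x.size(), dtype=x.dtype, layout=x.layout, pin_memory=True)
pinned.copy_(x)
```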
This reverts commit 9beb279. [ghstack-poisoned]
Summary:
Pull Request resolved: #62410

This PR adds docstrings for the CPU hooks introduced in #61928. Also uncomments the warning about pinned memory in the CUDA semantics docs. Depends on #62361. For now, the docstrings are an orphan page at https://docs-preview.pytorch.org/62410/generated/torch.autograd.graph.set_save_on_cpu_hooks.html#torch-autograd-graph-set-save-on-cpu-hooks

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D29990129

Pulled By: Varal7

fbshipit-source-id: 7a98eeee6a0abb11e2c2d9169cd1aa35ad7ba3f4
Stack from ghstack:
Fix #57100.
Creates a context manager `torch.autograd.graph.save_on_cpu()` under which all tensors saved during the forward pass are moved* to CPU, then copied back to the appropriate device for the backward pass.

*If the tensor was already on CPU, the entire operation is a no-op. If the user so desires, we move the tensor to pinned memory during packing so that the unpacking can be done asynchronously.

With the current PR, hooks are registered globally, across threads. In the near future, we want to make these hooks thread-local.
See benchmark and note about training large models
Differential Revision: D29848526