Add docstrings for save_on_cpu hooks (#62410)
Summary:
Pull Request resolved: #62410

This PR adds docstrings for CPU hooks introduced in #61928.

Also uncomments the warning about pinned memory in CUDA semantics docs.

Depends on: #62361.

For now, the docstrings live on an orphan page at https://docs-preview.pytorch.org/62410/generated/torch.autograd.graph.set_save_on_cpu_hooks.html#torch-autograd-graph-set-save-on-cpu-hooks
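For context, a minimal usage sketch of the context manager these docstrings cover (the toy model and sizes are illustrative, and a CUDA-enabled build is assumed):

```python
import torch

# Illustrative model: any network whose saved activations dominate GPU memory.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).cuda()
x = torch.randn(64, 1024, device="cuda")

# Tensors saved for backward inside this context are offloaded to CPU,
# then copied back to the GPU when the backward pass needs them.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(x).sum()
loss.backward()
```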

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D29990129

Pulled By: Varal7

fbshipit-source-id: 7a98eeee6a0abb11e2c2d9169cd1aa35ad7ba3f4
Varal7 authored and facebook-github-bot committed Aug 4, 2021
1 parent 5542d59 commit 5830f12
Showing 2 changed files with 47 additions and 7 deletions.
8 changes: 4 additions & 4 deletions docs/source/notes/cuda.rst
@@ -468,11 +468,11 @@ also preserve :class:`torch.device` and :class:`torch.dtype` of a Tensor).
Use pinned memory buffers
^^^^^^^^^^^^^^^^^^^^^^^^^

-.. warning:
+.. warning::

-    This is an advanced tip. You overuse of pinned memory can cause serious
-    problems if you'll be running low on RAM, and you should be aware that
-    pinning is often an expensive operation.
+    This is an advanced tip. If you overuse pinned memory, it can cause serious
+    problems when running low on RAM, and you should be aware that pinning is
+    often an expensive operation.

Host to GPU copies are much faster when they originate from pinned (page-locked)
memory. CPU tensors and storages expose a :meth:`~torch.Tensor.pin_memory`
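As an illustration of the pattern this warning is about, a minimal sketch (assuming a CUDA device; `pin_memory()` returns a pinned copy, and `non_blocking=True` only overlaps the copy with host work when the source is pinned):

```python
import torch

x = torch.randn(1024, 1024)   # regular pageable CPU tensor
x_pinned = x.pin_memory()     # copy of x backed by page-locked memory

if torch.cuda.is_available():
    # With a pinned source, this host-to-GPU copy can run asynchronously.
    y = x_pinned.to("cuda", non_blocking=True)
```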
46 changes: 43 additions & 3 deletions torch/autograd/saved_variable_default_hooks.py
@@ -8,18 +8,58 @@ def reset_saved_tensors_default_hooks():
    torch._C._autograd._reset_default_hooks()

class save_on_cpu(object):
""""Context-manager under which tensors saved by the forward pass will be
stored on cpu, then retrieved for backward.
When performing operations within this context manager, intermediary
results saved in the graph during the forward pass will be moved to CPU,
then copied back to the original device when needed for the backward pass.
If the graph was already on CPU, no tensor copy is performed.
Use this context-manager to trade compute for GPU memory usage (e.g.
when your model doesn't fit in GPU memory during training).
Args:
pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
during packing and copied to GPU asynchronously during unpacking.
Defaults to ``False``.
Also see :ref:`cuda-memory-pinning`.
Example::
>>> a = torch.randn(5, requires_grad=True, device="cuda")
>>> b = torch.randn(5, requires_grad=True, device="cuda")
>>> c = torch.randn(5, requires_grad=True, device="cuda")
>>>
>>> def f(a, b, c):
... prod_1 = a * b # a and b are saved on GPU
... with torch.autograd.graph.save_on_cpu():
... prod_2 = prod_1 * c # prod_1 and c are saved on CPU
... y = prod_2 * a # prod_2 and a are saved on GPU
... return y
>>>
>>> y = f(a, b, c)
>>> del a, b, c # for illustration only
>>> # the content of a, b, and prod_2 are still alive on GPU
>>> # the content of prod_1 and c only live on CPU
>>> y.sum().backward() # all CPU tensors are moved back to GPU, for backward
>>> # all intermediary tensors are released (deleted) after the call to backward
"""
    def __init__(self, pin_memory=False):
        def pack_to_cpu(tensor):
            if not pin_memory:
                return (tensor.device, tensor.cpu())

-            storage = torch.empty(
+            packed = torch.empty(
                tensor.size(),
                dtype=tensor.dtype,
                layout=tensor.layout,
                pin_memory=(torch.cuda.is_available() and not tensor.is_sparse))
-            storage.copy_(tensor)
-            return (tensor.device, storage)
+            packed.copy_(tensor)
+            return (tensor.device, packed)

        def unpack_from_cpu(packed):
            device, tensor = packed
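Outside the diff, here is a minimal sketch of the pack/unpack round-trip these hooks implement. The unpack side in the diff is truncated above, so the `unpack_from_cpu` body below is an assumption for illustration (it simply moves the saved tensor back to its recorded device):

```python
import torch

def pack_to_cpu(tensor):
    # Record the original device and keep only a CPU copy alive.
    return (tensor.device, tensor.cpu())

def unpack_from_cpu(packed):
    # Assumed behavior: move the saved tensor back to where it came from.
    device, tensor = packed
    return tensor.to(device)

t = torch.randn(3, device="cuda" if torch.cuda.is_available() else "cpu")
restored = unpack_from_cpu(pack_to_cpu(t))
assert restored.device == t.device and torch.equal(restored.cpu(), t.cpu())
```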
