Add docstrings for save_on_cpu hooks
This PR adds docstrings for CPU hooks introduced in #61928.

Also un-comments the warning about pinned memory in the CUDA semantics docs (with a single colon, ``.. warning:`` is parsed as a reST comment rather than a directive).

ghstack-source-id: 4c86248f5818546f5c82117a306e61d328389bb7
Pull Request resolved: #62410
Varal7 committed Aug 3, 2021
1 parent 95f9c52 commit bf149b3
Showing 2 changed files with 39 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/source/notes/cuda.rst
@@ -468,7 +468,7 @@ also preserve :class:`torch.device` and :class:`torch.dtype` of a Tensor).
Use pinned memory buffers
^^^^^^^^^^^^^^^^^^^^^^^^^

-.. warning:
+.. warning::

    This is an advanced tip. If you overuse pinned memory, it can cause serious
    problems when running low on RAM, and you should be aware that
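The pinned-memory tip above boils down to staging host tensors in page-locked memory so that host-to-device copies can be issued asynchronously. A minimal sketch of that pattern, using the standard ``Tensor.pin_memory()``, ``non_blocking`` and ``DataLoader(pin_memory=True)`` APIs (tensor names and shapes here are illustrative only):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative example: stage data in pinned (page-locked) host memory; with a
# pinned source, non_blocking=True lets the host-to-device copy overlap with host work.
cpu_tensor = torch.randn(256, 1024).pin_memory()
gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)

# DataLoader can hand out batches already placed in pinned memory.
dataset = TensorDataset(torch.randn(1024, 16))
loader = DataLoader(dataset, batch_size=64, pin_memory=True)
for (batch,) in loader:
    batch = batch.to("cuda", non_blocking=True)
    # ... use batch on the GPU ...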
41 changes: 38 additions & 3 deletions torch/autograd/saved_variable_default_hooks.py
@@ -7,7 +7,42 @@ def reset_saved_tensors_default_hooks():
    torch._C._autograd._reset_default_hooks()

def set_save_on_cpu_hooks(pin_memory=False):
-    def pack_hook(tensor):
"""Sets pack_to_cpu / unpack_from_cpu hooks for saved tensors.
When these hooks are set, intermediary results saved in the graph during
the forward pass will be moved to CPU, then copied back to the original device
when needed for the backward pass. If the graph was already on CPU, no tensor copy
is performed.
Use this hook to tradeoff speed for less GPU memory usage.
You can set these hooks once before creating the graph; or you can control
which part of the graph should be saved on CPU by registering these hooks
before - and resetting them after - creating the part of the graph to be saved
on CPU.
Args:
pin_memory (bool): If ``True`` tensors will be saved to CPU pinned memory
during packing and copied to GPU asynchronously during unpacking.
Defaults to ``False``.
Also see :ref:`cuda-memory-pinning`.
Example::
>>> a = torch.randn(5, requires_grad=True, device="cuda")
>>> b = torch.randn(5, requires_grad=True, device="cuda")
>>> c = torch.randn(5, requires_grad=True, device="cuda")
>>> d = a * b # a and b are saved in the graph (on GPU)
>>> torch.autograd.graph.set_save_on_cpu_hooks()
>>> e = d * c # d and c are saved on CPU
>>> torch.autograd.graph.reset_saved_tensors_default_hooks()
>>> f = a * e # a and e are saved on GPU
>>> del a, b, c, d, e
>>> # the content of a, b, e are still alive on GPU
>>> # the content of c and d only live on CPU
"""
+    def pack_to_cpu(tensor):
        if not pin_memory:
            return (tensor.device, tensor.cpu())

@@ -19,8 +54,8 @@ def pack_hook(tensor):
        storage.copy_(tensor)
        return (tensor.device, storage)

-    def unpack_hook(packed):
+    def unpack_from_cpu(packed):
        device, tensor = packed
        return tensor.to(device, non_blocking=pin_memory)

-    torch._C._autograd._register_default_hooks(pack_hook, unpack_hook)
+    torch._C._autograd._register_default_hooks(pack_to_cpu, unpack_from_cpu)
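To make the speed-for-memory trade-off described in the new docstring concrete, a rough measurement sketch could look like the following. It assumes a CUDA device and the ``torch.autograd.graph.set_save_on_cpu_hooks`` / ``torch.autograd.graph.reset_saved_tensors_default_hooks`` entry points exactly as they are named in the docstring example above; the helper ``peak_saved_memory`` is hypothetical and only for illustration:

import torch

def peak_saved_memory(offload: bool) -> int:
    # Hypothetical helper: run a chain of ops and report peak GPU memory.
    # Each sin() keeps its input alive for the backward pass.
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(4096, 4096, device="cuda", requires_grad=True)
    if offload:
        # Saved intermediaries below are packed to pinned CPU memory.
        torch.autograd.graph.set_save_on_cpu_hooks(pin_memory=True)
    y = x
    for _ in range(10):
        y = y.sin()
    if offload:
        torch.autograd.graph.reset_saved_tensors_default_hooks()
    peak = torch.cuda.max_memory_allocated()
    y.sum().backward()  # saved tensors are copied back to CUDA as needed
    return peak

print("peak GPU bytes without offload:", peak_saved_memory(offload=False))
print("peak GPU bytes with offload:   ", peak_saved_memory(offload=True))

With offloading enabled, each intermediary's GPU copy can be freed as soon as the next op has consumed it, so the peak reported for the second run should be noticeably lower, at the cost of extra host-device transfers.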
