Skip to content

Commit

Permalink
Add context manager to save tensors on CPU
Browse the repository at this point in the history
Fix #57100

ghstack-source-id: a9e56df9653a4240b4d382f7dbec3080a3a07a3d
Pull Request resolved: #61928
  • Loading branch information
Varal7 committed Jul 21, 2021
1 parent 40b84d7 commit 268160e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
30 changes: 30 additions & 0 deletions test/test_autograd.py
Expand Up @@ -5900,6 +5900,36 @@ def test_default_saved_variable_hooks_double_backward(self):
finally:
torch.autograd.graph.reset_saved_tensors_default_hooks()

def test_graph_save_on_cpu(self):
    """Default save-on-CPU hooks keep saved tensors usable and grads exact."""
    try:
        torch.autograd.graph.set_save_on_cpu(True)
        inp = torch.randn(5, requires_grad=True)
        out = inp * inp
        # Both operands of the multiply are the same tensor, so the engine
        # must hand back values equal to the original from either slot.
        self.assertEqual(inp, out.grad_fn._saved_self)
        self.assertEqual(inp, out.grad_fn._saved_other)
        out.sum().backward()
        self.assertEqual(2 * inp, inp.grad)
    finally:
        # Always restore the default hooks so later tests are unaffected.
        torch.autograd.graph.set_save_on_cpu(False)

@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
def test_graph_save_on_cpu_cuda(self):
    """Graph built on CUDA under save-on-cpu hooks restores tensors to CUDA."""
    a = torch.randn(5, requires_grad=True, device="cuda")
    try:
        torch.autograd.graph.set_save_on_cpu(True)
        y = a * a
    finally:
        # Unregister the hooks before measuring so no later op re-packs.
        torch.autograd.graph.set_save_on_cpu(False)
    self.assertTrue(y.is_cuda)
    before = CudaMemoryLeakCheck.get_cuda_memory_usage()
    # NOTE: this access is the measurement itself — reading _saved_self
    # unpacks the CPU-saved tensor back onto CUDA, allocating device
    # memory between `before` and `after`. Do not reorder these lines.
    self.assertTrue(y.grad_fn._saved_self.is_cuda)
    after = CudaMemoryLeakCheck.get_cuda_memory_usage()
    self.assertGreater(after, before)
    # Values must round-trip unchanged through the CPU copy.
    self.assertEqual(a, y.grad_fn._saved_self)
    self.assertEqual(a, y.grad_fn._saved_other)
    y.sum().backward()
    self.assertEqual(2 * a, a.grad)


def index_perm_variable(shape, max_indices):
if not isinstance(shape, tuple):
Expand Down
16 changes: 16 additions & 0 deletions torch/autograd/saved_variable_default_hooks.py
Expand Up @@ -5,3 +5,19 @@ def set_saved_tensors_default_hooks(pack_hook, unpack_hook):

def reset_saved_tensors_default_hooks():
    """Remove any previously registered default saved-tensor hooks."""
    torch._C._autograd._reset_default_hooks()

def set_save_on_cpu(save_on_cpu):
    """Toggle default hooks that keep autograd-saved tensors on CPU.

    When enabled, every tensor saved for backward is packed into a CPU
    copy (pinned when CUDA is available, so restoring it can overlap
    with compute) and unpacked back onto its original device on access.

    Args:
        save_on_cpu (bool): ``True`` registers the pack/unpack hooks;
            ``False`` resets the default hooks so nothing is moved.
    """
    if not save_on_cpu:
        torch._C._autograd._reset_default_hooks()
        return

    def pack_hook(tensor):
        # Allocate the CPU copy with the saved tensor's own dtype:
        # torch.empty defaults to the global default dtype, so omitting
        # dtype= would make copy_ silently cast (e.g. int64 -> float32)
        # and the unpacked tensor would come back wrong.
        storage = torch.empty(
            tensor.size(),
            dtype=tensor.dtype,
            pin_memory=torch.cuda.is_available(),
        )
        storage.copy_(tensor)
        return (tensor.device, storage)

    def unpack_hook(packed):
        device, tensor = packed
        # non_blocking pairs with the pinned allocation above for an
        # asynchronous host-to-device copy back to the original device.
        return tensor.to(device, non_blocking=True)

    torch._C._autograd._register_default_hooks(pack_hook, unpack_hook)

0 comments on commit 268160e

Please sign in to comment.