Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add context manager to save tensors on CPU
Fix #57100 ghstack-source-id: 9285a37caf37f90fd65b309235d7027504ce4ca5 Pull Request resolved: #61928
- Loading branch information
Showing
2 changed files
with
47 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,30 @@ | ||
import torch | ||
|
||
from typing import Any | ||
|
||
def set_saved_tensors_default_hooks(pack_hook, unpack_hook): | ||
torch._C._autograd._register_default_hooks(pack_hook, unpack_hook) | ||
|
||
def reset_saved_tensors_default_hooks(): | ||
torch._C._autograd._reset_default_hooks() | ||
|
||
class save_on_cpu(object): | ||
r"""Context-manager under which tensors saved by the forward pass will be | ||
stored on cpu, then retrieved for backward | ||
""" | ||
def __init__(self): | ||
def pack_hook(tensor): | ||
return (tensor.device, tensor.cpu()) | ||
|
||
def unpack_hook(packed): | ||
device, tensor = packed | ||
return tensor.to(device) | ||
|
||
self.pack_hook = pack_hook | ||
self.unpack_hook = unpack_hook | ||
|
||
def __enter__(self) -> None: | ||
torch._C._autograd._register_default_hooks(self.pack_hook, self.unpack_hook) | ||
|
||
def __exit__(self, *args: Any) -> None: | ||
torch._C._autograd._reset_default_hooks() |