Update on "Update save_on_cpu and checkpointing to work with functorc…
Browse files Browse the repository at this point in the history
…h wrapped tensors"


Design doc: https://docs.google.com/document/d/1OX5__xKsZP-natgEnsRrK4gfD0wSN9WS6j4kekfn-IA/edit

This approach saves the inner-most tensor. As we unwrap, we append to a list of rewrap functions that capture the information needed to restore the metadata on the original wrappers. This PR tries to do most things in Python, but there are probably APIs that could be added (or that may already exist) which would simplify it.
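A minimal sketch of the idea (an editor's illustration, not the PR's code): it assumes the torch._C._functorch helpers that appear in the diff below (maybe_get_level, get_unwrapped) and that maybe_get_level returns -1 for a plain tensor; the CPU-offload/recompute logic is elided, and the metadata-stashing trick is the one detailed in the list that follows.

```python
import torch

def _make_rewrap_fn(wrapper):
    # Capture the wrapper's autograd/wrapper metadata without keeping its data
    # alive: clone the TensorWrapper, then swap the payload out for an empty
    # tensor (the "set_data" trick described in the list below).
    stashed = wrapper.clone()
    stashed.data = torch.empty(0)
    def rewrap(inner):
        stashed.data = inner  # put back the recomputed / offloaded value
        return stashed
    return rewrap

def pack_hook(tensor):
    # Peel off functorch wrappers one level at a time, remembering how to
    # rebuild each level. (The PR's real helper keeps unwrapping until the
    # level stops changing; we assume -1 means "plain tensor" here.)
    rewrap_fns = []
    while torch._C._functorch.maybe_get_level(tensor) > -1:
        rewrap_fns.append(_make_rewrap_fn(tensor))
        tensor = torch._C._functorch.get_unwrapped(tensor)
    return tensor, rewrap_fns  # inner-most tensor + how to rebuild the wrappers

def unpack_hook(packed):
    # Rebuild the wrappers in reverse pack order (inner-most first).
    inner, rewrap_fns = packed
    for fn in reversed(rewrap_fns):
        inner = fn(inner)
    return inner
```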

- This PR does very weird things to stash autograd metadata:
  - The rewrap function needs to capture autograd metadata so that the higher-order graphs don't get disconnected. We reuse TensorWrapper to do this, but in a way that is careful not to keep the original TensorWrapper's data alive.
  - During packing, we do a clone on the original TensorWrapper, then replace the value_ with an empty tensor, so this new dataless TensorWrapper is what gets captured by the rewrap fn.
  - During unpacking, when we run the rewrap fn, we just replace the value_ with the value we desire (either the recomputed value or the value that was previously offloaded).
  - The API exposed to replace value_ is set_data!
- There doesn't seem to be a reliable way to uniquely identify a tensor, since id() gets reused; using data_ptr helps, but it is also not always enough. In this PR, I'm also using the first element of the Tensor to get a test to pass (see the sketch after this list).
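For the identity problem in the last bullet, something along these lines is meant (a hypothetical illustration; tensor_key is not the PR's _get_tid helper, just the same idea):

```python
import torch

def tensor_key(t: torch.Tensor):
    # id() can be reused once a tensor is garbage collected, and data_ptr()
    # can also collide (views, recycled allocations), so additionally mix in
    # the value of the first element, as the commit message describes.
    first = t.reshape(-1)[0].item() if t.numel() > 0 else None
    return (id(t), t.data_ptr(), first)
```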

Unanswered questions:
- Why did we need to enable grad mode while packing (where was it disabled)?

Other prototypes:
- #89159 (alternate approach that saves the outer-most tensor instead and unwraps the necessary number of layers during unpack - the issue is that we cannot tell when we are saving the outer-most tensor)
- #88976 (same approach as this PR, but in cpp, unfinished)

TODO:
- verify in tests that we are actually saving the correct number of tensors
- try a non-zero bdim
- make that assert more precise


[ghstack-poisoned]
soulitzer committed Nov 28, 2022
1 parent 91d4539 commit a02b774
Showing 2 changed files with 27 additions and 5 deletions.
torch/autograd/graph.py (28 changes: 25 additions & 3 deletions)
```diff
@@ -162,15 +162,37 @@ def _functorch_unwrap_to_level_no_rewrap(tensor: torch.Tensor, target_level: int
     while current_level > target_level:
         tensor = torch._C._functorch._unwrap_for_grad(tensor, current_level)
         current_level = torch._C._functorch.maybe_get_level(tensor)
-    assert current_level == target_level, (current_level, target_level)
+    assert max(current_level, -1) == max(target_level, -1), (current_level, target_level)
     return tensor
 
+def _functorch_unwrap_to_level_no_rewrap_alltheway(tensor: torch.Tensor) -> torch.Tensor:
+    # I clearly don't know how wrapper level work, but this works lol!
+    #
+    # We're running into TensorWrapper(lvl=-2, inner=TensorWrapper(lvl=1, inner=Tensor(1.))
+    # We cannot use unwrap_for_grad here, but get_unwrapped works ok.
+    current_level = torch._C._functorch.maybe_get_level(tensor)
+    prev_level = 9999
+    while True:
+        try:
+            tensor = torch._C._functorch.get_unwrapped(tensor)
+            # tensor = torch._C._functorch._unwrap_for_grad(tensor, current_level)
+            current_level = torch._C._functorch.maybe_get_level(tensor)
+            if current_level == prev_level:
+                break
+            else:
+                prev_level = current_level
+        except Exception as e:
+            print("error", e)
+            break
+    return tensor
+
+
 # It might be better to do more things in cpp:
 # https://github.com/pytorch/pytorch/pull/88976
 def _functorch_unwrap_to_level(tensor: torch.Tensor, target_level: int) -> torch.Tensor:
     assert target_level != 0, "level 0 is not supported, you should pass -1 instead"
     current_level = torch._C._functorch.maybe_get_level(tensor)
-    assert current_level >= target_level, (current_level, target_level)
+    assert max(current_level, -1) >= max(target_level, -1), (current_level, target_level)
     rewrap_fns = []
     for _ in range(max(current_level, 0), max(target_level, 0), -1):
         current_level = torch._C._functorch.maybe_get_level(tensor)
@@ -181,7 +203,7 @@ def _functorch_unwrap_to_level(tensor: torch.Tensor, target_level: int) -> torch
         rewrap_fns.append(rewrap_fn)
 
     result_level = torch._C._functorch.maybe_get_level(tensor)
-    assert result_level == target_level, (result_level, target_level)
+    assert max(result_level, -1) == max(target_level, -1), (result_level, target_level)
     return tensor, rewrap_fns
 
 class save_on_cpu(saved_tensors_hooks):
```
torch/utils/checkpoint.py (4 changes: 2 additions & 2 deletions)
```diff
@@ -2,7 +2,7 @@
 import warnings
 import weakref
 from typing import Any, Iterable, List, Tuple
-from torch.autograd.graph import _get_tid, _functorch_unwrap_to_level
+from torch.autograd.graph import _get_tid, _functorch_unwrap_to_level, _functorch_unwrap_to_level_no_rewrap, _functorch_unwrap_to_level_no_rewrap_alltheway
 
 __all__ = [
     "checkpoint", "checkpoint_sequential", "CheckpointFunction",
@@ -451,7 +451,7 @@ def inner_unpack(packed):
 
         # Wrap all the way to the inner-most tensor, and rewrap using the
         # rewrap function saved from forward
-        ret = _functorch_unwrap_to_level(storage[handle], -1)[0]
+        ret = _functorch_unwrap_to_level_no_rewrap_alltheway(storage[handle])
         for fn in reversed(rewrap_fns):
             ret = fn(ret)
         return ret
```
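For context, a hedged sketch of the kind of usage this change is meant to support (not a test from this PR; whether the grad transform is spelled functorch.grad or torch.func.grad depends on the PyTorch version):

```python
import torch
from functorch import grad
from torch.utils.checkpoint import checkpoint

def loss(x):
    # Inside the grad transform, the activation saved by checkpoint is a
    # functorch TensorWrapper; the pack hook saves only the inner-most tensor
    # plus rewrap fns, and the unpack hook rebuilds the wrappers from them.
    return checkpoint(lambda t: t.sin().exp(), x, use_reentrant=False).sum()

x = torch.randn(3)
print(grad(loss)(x))
```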
