Skip to content

Commit

Permalink
[WIP] Checkpoint that can be nested and contain grad
Browse files Browse the repository at this point in the history
ghstack-source-id: 5805313f59c92e32e791d10e3021cc3ed100a81d
Pull Request resolved: #90105
  • Loading branch information
soulitzer committed Dec 27, 2022
1 parent eb01f47 commit cae9e7d
Showing 1 changed file with 201 additions and 0 deletions.
201 changes: 201 additions & 0 deletions torch/autograd/graph.py
Expand Up @@ -145,6 +145,205 @@ def unpack_from_cpu(packed):

super().__init__(pack_to_cpu, unpack_from_cpu)

# NOTE: [new checkpoint mechanism]
#
# Contents:
# - Definition
# - Mechanism design
# - Example
#
# Definition: Checkpointing
# =========================
#
# We define checkpoint as a context manager such that any variables that
# were saved by forward under this context AND remain saved at the point
# in time when we exit the context are cleared and marked as needed during
# recomputation.
#
# In other words, if saved tensors are manually cleared, e.g. by running
# backward, checkpoint will ignore those tensors because it only cares about
# the set of tensors saved at the point in time when the context exits.
#
# with checkpoint():
# y = x.sin() # saves x
# z = y.exp() # saves z
# torch.autograd.grad(z, z) # clears z, only x remains saved
# # As we exit, clears x only
#
# This may also mean that we cannot simply halt execution early as soon as we've
# saved the right number of buffers.
#
# Special handling of input for the nested case
# ---------------------------------------------
#
# There is some special handling for the nested case: the inputs to a
# nested checkpoint are treated as saved variables in the parent context.
#
# with checkpoint0():
# with checkpoint1(): # saves `y` in check0
# y = f(x) # f's saved variables are cleared by check1
# with checkpoint2(): # saves `z` in check0
# z = g(y) # g's saved variables are cleared by check1
# # exiting check0, clears `y` and `z`
# # whatever f and g save are hidden
#
# NB: We never need to recompute a function until we have finished running it
#
# TODO: Handling of free variables.
# The current stack was created to recompute values for some call to checkpoint.
# If that checkpointed function calls checkpoint one or more times (possibly
# in a nested way), the inputs to the top-most checkpoints are cleared, so if we
# detect that we are the direct children to the ambient frame, we save its inputs
#
# We can register as many hooks as we want, they all do the same thing
# Backward usually happens outside of the context of any checkpoint anyway
# so we'll at least need to call this once per recomputation stack
#
# We can reuse the checkpoint code because what we are doing is very similar
# to checkpointing, we want to manage the tensors saved and collect the ones
# that remain alive at the very end.
# Creating a new CheckpointStack creates an ambient frame which manages this
#
# Sketch of the mechanism
# =======================
# TODO: Why we need a stack of stacks
#
# Demonstrating the mechanism with an example
# -------------------------------------------
# TODO: reference cycle concerns
#

class CheckpointFrame():
    """Bookkeeping for a single checkpointed region.

    Tracks the tensors saved by autograd while the region ran, the inputs
    of any directly nested checkpoints, and — once the region has been
    replayed — the recomputed values.
    """
    def __init__(self, parent, fn):
        # Hold the parent weakly so frames don't keep each other alive.
        self.parent = None if parent is None else weakref.ref(parent)
        self.fn = fn
        self.exited = False
        self.needed_counter = 0
        # Assume entries are added/removed deterministically
        # and that removal preserves order.
        self.saved_tensors: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
        self.child_inputs: List[Tuple[Any, ...]] = []
        # The fields below are finalized when the checkpoint context exits.
        self.num_needs_recompute = 0
        self.idx_of_input_in_parent = None
        self.num_child_inputs_need_recompute = 0

        # Filled in during recomputation.
        self.recomputed_child_inputs: List[Tuple[Any, ...]] = []
        self.recomputed: List[torch.Tensor] = []

class CheckpointStack():
    """A stack of checkpoint frames rooted at one ambient (sentinel) frame.

    When ``target_frame`` is set, this stack exists solely to recompute the
    saved tensors of that previously-exited frame.
    """
    def __init__(self, parent, target_frame=None):
        self.parent: CheckpointStack = parent
        self.target_frame: CheckpointFrame = target_frame
        # The first entry is an ambient frame that owns top-level saves.
        self.frames = [CheckpointFrame(parent=None, fn=None)]

# Global stack of CheckpointStacks; the innermost (last) entry is the one
# the pack/unpack hooks and _checkpoint currently operate on.
checkpoint_stacks: List[CheckpointStack] = []

class _checkpoint_hook(saved_tensors_hooks):
    """saved_tensors_hooks pair implementing the checkpoint mechanism.

    ``pack_hook`` records each tensor autograd saves into the innermost
    checkpoint frame of the current stack; ``unpack_hook`` either returns
    the still-live tensor (if the frame has not exited) or drives
    recomputation of the frames whose saves were cleared on exit.
    """
    def __init__(self):
        def pack_hook(x):
            # Snapshot the state of the current checkpoint stack
            handle = _Handle()
            frames = tuple(checkpoint_stacks[-1].frames)
            # Store x in the innermost frame, keyed weakly by the handle.
            frames[-1].saved_tensors[handle] = x
            frames[-1].needed_counter += 1
            # (handle, snapshot of frames, position of x among this frame's saves)
            return handle, frames, frames[-1].needed_counter - 1

        def unpack_hook(saved):
            handle, frames, idx = saved

            # backward was called before leaving checkpoint context:
            # the tensor is still stored, so hand it back directly.
            if not frames[-1].exited:
                assert handle in frames[-1].saved_tensors
                return frames[-1].saved_tensors[handle]
            # TODO: give a nice error for when we backwarded during the context

            assert len(frames[-1].recomputed) <= frames[-1].num_needs_recompute
            if len(frames[-1].recomputed) < frames[-1].num_needs_recompute:
                # The first frame is always the ambient frame
                for frame in frames[1:]:
                    # Skip frames that never exited, or that are already fully
                    # recomputed (both saved tensors and child inputs).
                    if (frame.exited is False
                        or (len(frame.recomputed) == frame.num_needs_recompute
                        and len(frame.recomputed_child_inputs) == frame.num_child_inputs_need_recompute)):
                        continue
                    assert frame.parent is not None and frame.parent() is not None
                    parent = frame.parent()
                    # NOTE(review): inputs come from the parent's recomputed
                    # values once the parent itself has exited — this assumes
                    # idx_of_input_in_parent indexes both lists identically;
                    # confirm against _checkpoint's append order.
                    inps = parent.recomputed_child_inputs if parent.exited else parent.child_inputs
                    args, kwargs = inps[frame.idx_of_input_in_parent]

                    with torch.autograd.enable_grad():
                        # Do recomputation in a fresh checkpoint stack
                        _checkpoint(frame.fn, *args, target_frame=frame, **kwargs)

            ret = frames[-1].recomputed[idx]
            return ret

        super().__init__(pack_hook, unpack_hook)

def get_wrapped_fn(fn):
    """Return a closure over ``fn`` so the call can be replayed later.

    NOTE(review): currently a straight passthrough — the intended capture
    of ambient context is not implemented yet (see [WIP] in the commit).
    """
    def wrapped(*call_args, **call_kwargs):
        return fn(*call_args, **call_kwargs)
    return wrapped

def _checkpoint(fn, *args, target_frame: CheckpointFrame = None, **kwargs):
    """Run ``fn(*args, **kwargs)`` as a checkpointed region.

    When ``target_frame`` is None this is a normal (possibly nested)
    checkpoint: a new frame is pushed, ``fn`` runs under the saved-tensor
    hooks, and on exit the frame's saved tensors are cleared so they can
    be recomputed later. When ``target_frame`` is given, we are replaying
    a previously-exited frame: a fresh stack is created and the values
    produced are stashed into ``target_frame.recomputed``.

    Returns whatever ``fn`` returns.
    """
    # A fresh stack is needed when there is no ambient stack yet, or when
    # recomputing (recomputation must not disturb the current stack).
    needs_to_pop = False
    if len(checkpoint_stacks) == 0 or target_frame is not None:
        needs_to_pop = True
        parent = None if len(checkpoint_stacks) == 0 else checkpoint_stacks[-1]
        checkpoint_stacks.append(CheckpointStack(parent=parent, target_frame=target_frame))

    curr_stack = checkpoint_stacks[-1]

    if target_frame is None:
        # Create a proper checkpoint frame and append it to the current stack
        wrapped_fn = get_wrapped_fn(fn)
        curr_frame = CheckpointFrame(parent=curr_stack.frames[-1], fn=wrapped_fn)
        curr_stack.frames.append(curr_frame)
    else:
        # Don't need a checkpoint frame if we're starting a stack for recomputation
        assert len(curr_stack.frames) == 1
        curr_frame = curr_stack.frames[0]

    # True only for a checkpoint that is a direct child of the ambient frame
    # of a recomputation stack.
    if (curr_stack.target_frame is not None
            and curr_frame.parent is not None
            and curr_frame.parent() is curr_stack.frames[0]):
        # Top-level checkpoints save their inputs to the target frame during recomputation
        assert len(curr_stack.frames) == 2
        curr_stack.target_frame.recomputed_child_inputs.append((args, kwargs))

    with _checkpoint_hook():
        # We can register this hook as many times as we want, it only reads global state
        ret = fn(*args, **kwargs)

    # Everything still saved at exit is what must be recomputed later.
    curr_frame.num_needs_recompute = len(curr_frame.saved_tensors)
    if target_frame is None:
        # Children register their inputs to the parent checkpoint to be cleared
        # when the parent exits and restored when the parent is recomputed
        inputs = (args, kwargs)
        assert curr_frame.parent is not None
        parent = curr_frame.parent()
        assert parent is not None
        curr_frame.idx_of_input_in_parent = len(parent.child_inputs)
        parent.child_inputs.append(inputs)
        parent.num_child_inputs_need_recompute += 1
    else:
        # Stack for recomputation is getting destroyed, save into the target frame
        assert len(target_frame.recomputed) == 0
        # detach() so the recomputed values don't keep a second graph alive.
        detached_saved = [t.detach() for t in curr_frame.saved_tensors.values()]
        target_frame.recomputed.extend(detached_saved)

    # Clearing saved_tensors here is what "checkpoints" the region; any
    # later unpack will go through the recomputation path.
    curr_frame.child_inputs.clear()
    curr_frame.saved_tensors.clear()
    curr_frame.exited = True

    curr_stack.frames.pop()
    if needs_to_pop:
        checkpoint_stacks.pop()

    return ret


@contextlib.contextmanager
def disable_saved_tensors_hooks(error_message):
Expand All @@ -167,6 +366,8 @@ def disable_saved_tensors_hooks(error_message):
... pass
"""
yield
return
try:
maybe_prev_message = torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
torch._C._autograd._saved_tensors_hooks_disable(error_message)
Expand Down

0 comments on commit cae9e7d

Please sign in to comment.