From cae9e7d8faa21bd5f82c5eb0faf537fec2cef538 Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Tue, 27 Dec 2022 18:45:06 -0500
Subject: [PATCH] [WIP] Checkpoint that can be nested and contain grad

ghstack-source-id: 5805313f59c92e32e791d10e3021cc3ed100a81d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90105
---
 torch/autograd/graph.py | 201 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 201 insertions(+)

diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index fc490a9d8e31..2de3d9171ef0 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -145,6 +145,205 @@ def unpack_from_cpu(packed):
 
         super().__init__(pack_to_cpu, unpack_from_cpu)
 
+
+# NOTE: [new checkpoint mechanism]
+#
+# Contents:
+# - Definition
+# - Mechanism design
+# - Example
+#
+# Definition: Checkpointing
+# =========================
+#
+# We define checkpoint as a context manager such that any variables that
+# were saved by forward under this context AND remain saved at the point
+# in time when we exit the context are cleared and marked as needed during
+# recomputation.
+#
+# In other words, if saved tensors are manually cleared, e.g. by running
+# backward, checkpoint ignores those tensors because it only cares about
+# the set of tensors still saved at the point in time when the context exits.
+#
+#     with checkpoint():
+#         y = x.sin()                # saves x
+#         z = y.exp()                # saves z
+#         torch.autograd.grad(z, z)  # clears z, only x remains saved
+#     # As we exit, clears x only
+#
+# This also means that we cannot simply halt execution early as soon as we've
+# saved the right number of buffers.
+#
+# Special handling of inputs for the nested case
+# ----------------------------------------------
+#
+# There is some special handling for the nested case: the inputs to a nested
+# checkpoint are treated as saved variables in the parent context.
+#
+#     with checkpoint0():
+#         with checkpoint1():   # saves `y` in check0
+#             y = f(x)          # f's saved variables are cleared by check1
+#         with checkpoint2():   # saves `z` in check0
+#             z = g(y)          # g's saved variables are cleared by check2
+#     # exiting check0, clears `y` and `z`
+#     # whatever f and g save is hidden
+#
+# NB: We never need to recompute a function until we have finished running it.
+#
+# TODO: Handling of free variables.
+# The current stack was created to recompute values for some call to
+# checkpoint. If that checkpointed function calls checkpoint one or more times
+# (possibly in a nested way), the inputs to the top-most checkpoints are
+# cleared, so if a checkpoint detects that it is a direct child of the ambient
+# frame, it saves its inputs.
+#
+# We can register as many hooks as we want; they all do the same thing.
+# Backward usually happens outside of the context of any checkpoint anyway,
+# so we'll need to register this hook at least once per recomputation stack.
+#
+# We can reuse the checkpoint code because what we are doing is very similar
+# to checkpointing: we want to manage the tensors saved and collect the ones
+# that remain alive at the very end.
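+#
+# For example, recomputing a frame later boils down to replaying its wrapped
+# fn under the same saved-tensor hooks, inside a fresh stack that targets the
+# frame being recomputed (this is what unpack_hook does below):
+#
+#     with torch.autograd.enable_grad():
+#         _checkpoint(frame.fn, *args, target_frame=frame, **kwargs)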
+# Creating a new CheckpointStack creates an ambient frame which manages this.
+#
+# Sketch of the mechanism
+# =======================
+# TODO: Why we need a stack of stacks
+#
+# Demonstrating the mechanism with an example
+# -------------------------------------------
+# TODO: reference cycle concerns
+#
+
+class CheckpointFrame():
+    def __init__(self, parent, fn):
+        self.parent = weakref.ref(parent) if parent is not None else parent
+        self.fn = fn
+        self.exited = False
+        self.needed_counter = 0
+        # Assume that entries are added/removed deterministically
+        # and that removal preserves order
+        self.saved_tensors: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
+        self.child_inputs: List[Tuple[Any, ...]] = []
+        # Set this when we leave
+        self.num_needs_recompute = 0
+        self.idx_of_input_in_parent = None
+        self.num_child_inputs_need_recompute = 0
+
+        # During recomputation
+        self.recomputed: List[torch.Tensor] = []
+        self.recomputed_child_inputs: List[Tuple[Any, ...]] = []
+
+class CheckpointStack():
+    def __init__(self, parent, target_frame=None):
+        self.parent: CheckpointStack = parent
+        self.target_frame: CheckpointFrame = target_frame
+        self.frames = [CheckpointFrame(parent=None, fn=None)]
+
+checkpoint_stacks: List[CheckpointStack] = []
+
+class _checkpoint_hook(saved_tensors_hooks):
+    def __init__(self):
+        def pack_hook(x):
+            # Snapshot the state of the current checkpoint stack
+            handle = _Handle()
+            frames = tuple(checkpoint_stacks[-1].frames)
+            frames[-1].saved_tensors[handle] = x
+            frames[-1].needed_counter += 1
+            return handle, frames, frames[-1].needed_counter - 1
+
+        def unpack_hook(saved):
+            handle, frames, idx = saved
+
+            # backward was called before leaving the checkpoint context
+            if not frames[-1].exited:
+                assert handle in frames[-1].saved_tensors
+                return frames[-1].saved_tensors[handle]
+            # TODO: give a nice error for when backward is called during the context
+
+            assert len(frames[-1].recomputed) <= frames[-1].num_needs_recompute
+            if len(frames[-1].recomputed) < frames[-1].num_needs_recompute:
+                # The first frame is always the ambient frame
+                for frame in frames[1:]:
+                    if (frame.exited is False
+                            or (len(frame.recomputed) == frame.num_needs_recompute
+                                and len(frame.recomputed_child_inputs) == frame.num_child_inputs_need_recompute)):
+                        continue
+                    assert frame.parent is not None and frame.parent() is not None
+                    parent = frame.parent()
+                    inps = parent.recomputed_child_inputs if parent.exited else parent.child_inputs
+                    args, kwargs = inps[frame.idx_of_input_in_parent]
+
+                    with torch.autograd.enable_grad():
+                        # Do recomputation in a fresh checkpoint stack
+                        _checkpoint(frame.fn, *args, target_frame=frame, **kwargs)
+
+            ret = frames[-1].recomputed[idx]
+            return ret
+
+        super().__init__(pack_hook, unpack_hook)
+
+def get_wrapped_fn(fn):
+    # Capture the current context, so we can replay it
+    def wrapped(*args, **kwargs):
+        return fn(*args, **kwargs)
+    return wrapped
+
+def _checkpoint(fn, *args, target_frame: CheckpointFrame = None, **kwargs):
+    needs_to_pop = False
+    if len(checkpoint_stacks) == 0 or target_frame is not None:
+        needs_to_pop = True
+        parent = None if len(checkpoint_stacks) == 0 else checkpoint_stacks[-1]
+        checkpoint_stacks.append(CheckpointStack(parent=parent, target_frame=target_frame))
+
+    curr_stack = checkpoint_stacks[-1]
+
+    if target_frame is None:
+        # Create a proper checkpoint frame and append it to the current stack
+        wrapped_fn = get_wrapped_fn(fn)
+        curr_frame = CheckpointFrame(parent=curr_stack.frames[-1], fn=wrapped_fn)
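+        # The new frame's parent is whatever frame is currently innermost on
+        # this stack (the ambient frame for a top-level checkpoint), so nested
+        # checkpoints chain to their enclosing checkpoint via weak references.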
+        curr_stack.frames.append(curr_frame)
+    else:
+        # Don't need a checkpoint frame if we're starting a stack for recomputation
+        assert len(curr_stack.frames) == 1
+        curr_frame = curr_stack.frames[0]
+
+    if (curr_stack.target_frame is not None
+            and curr_frame.parent is not None
+            and curr_frame.parent() is curr_stack.frames[0]):
+        # Top-level checkpoints save their inputs to the target frame during recomputation
+        assert len(curr_stack.frames) == 2
+        curr_stack.target_frame.recomputed_child_inputs.append((args, kwargs))
+
+    with _checkpoint_hook():
+        # We can register this hook as many times as we want; it only reads global state
+        ret = fn(*args, **kwargs)
+
+    curr_frame.num_needs_recompute = len(curr_frame.saved_tensors)
+    if target_frame is None:
+        # Children register their inputs with the parent checkpoint so they are
+        # cleared when the parent exits and restored when the parent is recomputed
+        inputs = (args, kwargs)
+        assert curr_frame.parent is not None
+        parent = curr_frame.parent()
+        assert parent is not None
+        curr_frame.idx_of_input_in_parent = len(parent.child_inputs)
+        parent.child_inputs.append(inputs)
+        parent.num_child_inputs_need_recompute += 1
+    else:
+        # The stack for recomputation is being destroyed; save into the target frame
+        assert len(target_frame.recomputed) == 0
+        detached_saved = [t.detach() for t in curr_frame.saved_tensors.values()]
+        target_frame.recomputed.extend(detached_saved)
+
+    curr_frame.child_inputs.clear()
+    curr_frame.saved_tensors.clear()
+    curr_frame.exited = True
+
+    curr_stack.frames.pop()
+    if needs_to_pop:
+        checkpoint_stacks.pop()
+
+    return ret
+
 
 @contextlib.contextmanager
 def disable_saved_tensors_hooks(error_message):
@@ -167,6 +366,8 @@ def disable_saved_tensors_hooks(error_message):
         ...         pass
 
     """
+    yield
+    return
     try:
         maybe_prev_message = torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
         torch._C._autograd._saved_tensors_hooks_disable(error_message)
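A minimal usage sketch, assuming the private `_checkpoint` entry point added above is importable from `torch.autograd.graph` as written; `f` and `g` are hypothetical helpers. A nested call followed by a backward pass exercises both the child-input stashing and the recomputation path:

    import torch
    from torch.autograd.graph import _checkpoint  # added by this patch

    def g(y):
        return y.exp()

    def f(x):
        y = x.sin()
        # Nested checkpoint: (y,) is recorded as a child input of f's frame
        z = _checkpoint(g, y)
        return z.cos()

    x = torch.randn(3, requires_grad=True)
    out = _checkpoint(f, x)   # on exit, tensors saved inside f and g are cleared
    out.sum().backward()      # unpack hooks recompute f (and g) as needed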