Skip to content

Commit

Permalink
Update on "[AOTAutograd] add export entrypoints"
Browse files Browse the repository at this point in the history
The main addition in this PR is two new API's in AOTAutograd.

**APIs**

`aot_export_module`: Given a module, exports it into a functionalized FX graph. Returns an `fx.GraphModule`, `GraphSignature` pair. The `GraphSignature` tells you various information about the graph, such as which graph inputs correspond to module params/buffers (and their fqn's), how to pytree-ify the inputs and the outputs of the graph. If you specify `trace_joint=True`, then you'll get back a joint forward-backward graph, that also returns parameter gradients in addition to the user outputs.

There are several restrictions on this API, detailed in the comments. The most notable one is probably that this API does not handle partial graphs: If you want a backward graph, then you module's forward function is **required** to return a scalar loss that we can backprop through. It also does not support capturing the optimizer step.

I (gratefully) used SherlockNoMad and suo's internal version of the `GraphSignature` object for this API, with a few minor changes in order to integrate it into AOTAutograd.

`aot_export_joint_simple`: Given a function, we'll trace it into a joint forward-backward graph and return it. Unlike the above API, the function is **not** required to return a scalar loss. However, this API makes the guarantee that you **do not** need to make any calling convention changes between the original function, and the exported one, provided that you do that you do the following:
* If you pass `trace_joint=False`, no work is needed: we'll export a functionalized forward graph with the same set of inputs as the original function
* If you pass `trace_joint=True`, then you will need to manually use the `default_partitioner` or `min_cut_partitioner` from functorch. If you do, and get back a fw and bw graph, then the forward graph will be runnable identically to the original user function.

The main use case for this API is higher order ops: a higher order op like `torch.cond()` can implement its derivative formula by using this API to export a joint graph (for both the true subgraph and the false subgraph), partition it into a fw/bw graph, and run cond on the `true_bw`, `false_bw` subgraphs. cc zou3519 Chillee 

**Implementation Strategy**

A lot of the work in this PR went in to trying to find a reasonable way to re-use existing AOTAutograd components to expose these API's. Concretely:

* The two new API's are both thin wrappers around `_aot_export_function`: this is a general purpose export API, that just re-uses `create_aot_dispatcher_function`. If we want to add e.g. an export API that includes the optimizer step in the future, we could probably implement it using `_aot_export_function`.
* `aot_export_module` works extra hard to re-use as much of AOTAutograd as possible. For example, when tracing an inference graph, I perform the export under `torch.no_grad()` to make sure we don't accidentally trace out a backwards graph. When exporting a joint graph, I manually `.detach()` all user outputs except the loss, to make sure that we don't accidentally compute gradients for any other user outputs (even if the user forgot to manually detach them).
* A large portion of `aot_export_module` comes from parsing out and creating a `GraphSignature` object. We discussed a few weeks ago that there's potentially a lot more information that we could stuff into this object (see [doc](https://docs.google.com/document/d/1_qzdKew5D1J2Q2GkZ1v5jsczSsIU-Sr0AJiPW7DdGjE/edit?usp=sharing)). For now, I ended up deciding to support the more limited use case of exporting a fwd-bwd full graph, without some of the extra annotations in that doc (for example, if we were to export partial graphs, we would need annotations for saved activations). My thought is that once a more concrete use case comes up that the existing API doesn't satisfy, we can revisit the annotations then.
* I factored out `create_functional_call()` and `create_tree_flattened_fn()` for pytree-flattening and lifting-params-and-buffers, since I also need them in the export code
* I added an `AOTConfig.is_export` flag. The export API re-uses all of the same code paths as the rest of AOTAutograd, but there are a few points where we need to either exit early (and avoid making a runtime epilogue), or add extra error checking, that is only valuable for export.
* `aot_dispatch_autograd()` now exits early if it's being called in an export context, so it returns the full graph instead of also trying to create an `autograd.Function`. I think we probably want to factor this out, although I figured it would be safer to wait a bit for clarity on how functional RNG works with export.




[ghstack-poisoned]
  • Loading branch information
bdhirsh committed May 12, 2023
2 parents 37bf0dc + 11ea302 commit 59a4374
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions torch/_functorch/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,7 @@ class AOTConfig:
aot_id: int
keep_inference_input_mutations: bool
is_export: bool = False
no_tangents: bool = False
dynamic_shapes: bool = False
aot_autograd_arg_pos_to_source : Optional[List[Source]] = None
inference_compiler: Optional[Callable] = None
Expand Down Expand Up @@ -1195,7 +1196,7 @@ def inner_fn(*args):
# otherwise, when we compute autograd.grad(), we will not take those input mutations into account
# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
def create_joint(
fn: Callable,
fn: Callable, *, aot_config: AOTConfig
) -> Any:
def inner_fn(primals: List[Any], tangents: List[Any]):
outs, tangent_mask = fn(*primals)
Expand Down Expand Up @@ -1235,8 +1236,9 @@ def inner_fn(primals: List[Any], tangents: List[Any]):
# Call the backwards pass
if grad_primals:
with fx_traceback.preserve_node_meta():
# If our output is a scalar loss, we don't need to pass in tangents.
if len(needed_tangents) == 1 and needed_tangents[0].numel() == 1:
# for full graph export, we always export a joint graph where we assume no tangents are needed.
if aot_config.no_tangents:
assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1
backward_out = torch.autograd.grad(
needed_outs,
grad_primals,
Expand Down Expand Up @@ -2628,7 +2630,7 @@ def aot_dispatch_autograd_graph(flat_fn, flat_args: List[Any], aot_config: AOTCo
flat_fn,
fw_metadata,
)
joint_fn_to_trace = create_joint(fn_prepared_for_autograd)
joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)

fx_g = create_functionalized_graph(
joint_fn_to_trace,
Expand Down Expand Up @@ -3485,6 +3487,7 @@ def aot_function(
dynamic_shapes=dynamic,
aot_autograd_arg_pos_to_source=None,
is_export=False,
no_tangents=False,
enable_log=enable_log,
)
cached_res = None
Expand Down Expand Up @@ -3658,6 +3661,7 @@ def aot_module_simplified(
dynamic_shapes=dynamic_shapes,
aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
is_export=False,
no_tangents=False,
)

compiled_fn = create_aot_dispatcher_function(
Expand Down Expand Up @@ -3797,6 +3801,7 @@ def fn_to_trace(*args):
full_args,
decompositions=decompositions,
num_params_buffers=len(params_and_buffers_flat),
no_tangents=True,
)
if trace_joint:
def flattened_joint(*args):
Expand Down Expand Up @@ -3934,6 +3939,14 @@ def _aot_export_function(
*,
num_params_buffers: int = 0,
decompositions: Optional[Dict] = None,
# If we're exporting a joint graph and we don't want any tangent inputs in the graph
# (because we are backpropping through a scalar 1 loss),
# we need to explicitly specify not to include tangents in the graph.
# It's not enough just to check that our tangent is a scalar, since we also
# need to know if it is a 1 (no need to make it a graph input), or something else
# (requiring it to be a graph input).
# We don't know this info at trace time though, so we need to make it an explicit config.
no_tangents: bool = False,
) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
dynamic_shapes = False
for x in args:
Expand Down Expand Up @@ -3962,6 +3975,7 @@ def _aot_export_function(
dynamic_shapes=dynamic_shapes,
aot_autograd_arg_pos_to_source=None,
is_export=True,
no_tangents=no_tangents,
)

fx_g, meta = create_aot_dispatcher_function(
Expand Down

0 comments on commit 59a4374

Please sign in to comment.