
[AOTAutograd] add export entrypoints #100587

Closed
wants to merge 4 commits

Conversation

bdhirsh
Contributor

@bdhirsh bdhirsh commented May 3, 2023

The main addition in this PR is two new APIs in AOTAutograd.

APIs

aot_export_module: Given a module, exports it into a functionalized FX graph. Returns an fx.GraphModule, GraphSignature pair. The GraphSignature gives you various information about the graph, such as which graph inputs correspond to module params/buffers (and their fqn's), and how to pytree-ify the inputs and the outputs of the graph. If you specify trace_joint=True, then you'll get back a joint forward-backward graph that also returns parameter gradients in addition to the user outputs.

There are several restrictions on this API, detailed in the comments. The most notable one is probably that this API does not handle partial graphs: If you want a backward graph, then your module's forward function is required to return a scalar loss that we can backprop through. It also does not support capturing the optimizer step.

I (gratefully) used @SherlockNoMad and @suo's internal version of the GraphSignature object for this API, with a few minor changes in order to integrate it into AOTAutograd.
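As a rough usage sketch (the import path and keyword names below follow this PR's additions in torch/_functorch/aot_autograd.py, but treat them as illustrative rather than a stable public surface):

    import torch
    from torch._functorch.aot_autograd import aot_export_module  # location as of this PR

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(4, 4)

        def forward(self, x):
            out = self.lin(x)
            loss = out.sum()  # trace_joint=True requires a scalar loss output
            # Non-loss outputs must not require grad when exporting a joint graph.
            return (loss, out.detach())

    m = M()
    x = torch.randn(2, 4)

    # Inference-only export: a functionalized forward graph plus signature metadata.
    fw_gm, fw_sig = aot_export_module(m, [x], trace_joint=False)

    # Joint export: tell AOTAutograd which output is the loss to backprop through.
    joint_gm, joint_sig = aot_export_module(m, [x], trace_joint=True, output_loss_index=0)
    print(joint_sig.inputs_to_parameters)  # graph input names -> "lin.weight", "lin.bias", ...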

aot_export_joint_simple: Given a function, we'll trace it into a joint forward-backward graph and return it. Unlike the above API, the function is not required to return a scalar loss. However, this API makes the guarantee that you do not need to make any calling convention changes between the original function and the exported one, provided that you do the following:

  • If you pass trace_joint=False, no work is needed: we'll export a functionalized forward graph with the same set of inputs as the original function
  • If you pass trace_joint=True, then you will need to manually use the default_partitioner or min_cut_partitioner from functorch. If you do, and get back a fw and bw graph, then the forward graph will be runnable identically to the original user function.

The main use case for this API is higher order ops: a higher order op like torch.cond() can implement its derivative formula by using this API to export a joint graph (for both the true subgraph and the false subgraph), partition it into a fw/bw graph, and run cond on the true_bw, false_bw subgraphs. cc @zou3519 @Chillee
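For illustration, the flow described above might look roughly like the sketch below; the partitioner and export entry point names are my best guess at the surfaces this PR refers to, not a verified recipe:

    import torch
    from torch._functorch.aot_autograd import aot_export_joint_simple  # location as of this PR
    from functorch.compile import min_cut_rematerialization_partition  # the "min cut" partitioner

    def f(x, w):
        # Unlike aot_export_module(trace_joint=True), there is no scalar-loss requirement here.
        return torch.sin(x * w)

    x = torch.randn(4, requires_grad=True)
    w = torch.randn(4, requires_grad=True)

    # Export the joint forward-backward graph for the function.
    joint_gm = aot_export_joint_simple(f, [x, w], trace_joint=True)

    # Partition it; the forward keeps the original calling convention, while the
    # backward consumes saved activations plus tangents.
    fw_gm, bw_gm = min_cut_rematerialization_partition(joint_gm, [x, w], num_fwd_outputs=1)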

Implementation Strategy

A lot of the work in this PR went into trying to find a reasonable way to re-use existing AOTAutograd components to expose these APIs. Concretely:

  • The two new APIs are both thin wrappers around _aot_export_function: this is a general-purpose export API that just re-uses create_aot_dispatcher_function. If we want to add e.g. an export API that includes the optimizer step in the future, we could probably implement it using _aot_export_function.
  • aot_export_module works extra hard to re-use as much of AOTAutograd as possible. For example, when tracing an inference graph, I perform the export under torch.no_grad() to make sure we don't accidentally trace out a backwards graph. When exporting a joint graph, I manually .detach() all user outputs except the loss, to make sure that we don't accidentally compute gradients for any other user outputs (even if the user forgot to manually detach them).
  • A large portion of aot_export_module comes from parsing out and creating a GraphSignature object. We discussed a few weeks ago that there's potentially a lot more information that we could stuff into this object (see doc). For now, I ended up deciding to support the more limited use case of exporting a fwd-bwd full graph, without some of the extra annotations in that doc (for example, if we were to export partial graphs, we would need annotations for saved activations). My thought is that once a more concrete use case comes up that the existing API doesn't satisfy, we can revisit the annotations then.
  • I factored out create_functional_call() and create_tree_flattened_fn() for pytree-flattening and lifting-params-and-buffers, since I also need them in the export code
  • I added an AOTConfig.is_export flag. The export API re-uses all of the same code paths as the rest of AOTAutograd, but there are a few points where we need to either exit early (and avoid making a runtime epilogue), or add extra error checking, that is only valuable for export.
  • aot_dispatch_autograd() now exits early if it's being called in an export context, so it returns the full graph instead of also trying to create an autograd.Function. I think we probably want to factor this out, although I figured it would be safer to wait a bit for clarity on how functional RNG works with export.

Stack from ghstack (oldest at bottom):

@pytorch-bot

pytorch-bot bot commented May 3, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100587

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 59a4374:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request May 3, 2023
ghstack-source-id: ffa6f74fa8c9b4046d1e9934de8cfb59589d6ae6
Pull Request resolved: #100587
@bdhirsh
Contributor Author

bdhirsh commented May 3, 2023

TODO: I added one large test for aot_export_module, but I still need to add some tests for aot_export_joint_simple, and for all of the error cases where we ban certain functions/modules (stubbed out for now).

@albanD albanD removed their request for review May 3, 2023 22:18
pass

def test_aot_export_functionalized_rng_banned(self):
pass
Contributor

TODO I suppose

Contributor Author

yep, will add these in before landing

it corresponds to (forward input)
(2) A mapping from each gradient (backwards output) to the user input
it corresponds to (forward input)
(3) Which of the forward outputs corresponds to the loss, that we backprop on.
Contributor

worth saying explicitly what the str denotes (node.target, amirite?)

Contributor Author

"""

# Parameters/buffers are named according to their name in the GraphModule,
# *not* their name in the input/output list.
Contributor

Confusing. Maybe worth a type alias?

Contributor

inputs_to_parameters records the name mapping:
key is the graph input name
value is the parameter fqn

Contributor Author

I ended up making some type aliases. FQN, InputName and OutputName
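For reference, a rough sketch of what those aliases and mappings could look like (the names follow the comments in this thread; the actual definitions in the PR may differ):

    from dataclasses import dataclass, field
    from typing import Dict

    FQN = str         # fully-qualified name of a param/buffer, e.g. "layer1.weight"
    InputName = str   # name of a placeholder node in the exported graph
    OutputName = str  # name of a graph output

    @dataclass
    class GraphSignature:  # abbreviated; the real class carries more fields
        # graph input name -> parameter/buffer fqn
        inputs_to_parameters: Dict[InputName, FQN] = field(default_factory=dict)
        inputs_to_buffers: Dict[InputName, FQN] = field(default_factory=dict)
        # graph output name -> fqn of the buffer that output should be written back to
        buffers_to_mutate: Dict[OutputName, FQN] = field(default_factory=dict)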

(2) `func` cannot mutate any inputs
(3) The outputs of `func` cannot alias any inputs.

Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops.
Contributor

nit: If you actually want to docblock it, the docblock goes in the function, not on top.

)
out_spec.set(spec)
return flat_out
return flat_fn, out_spec
Contributor

Any difference in code movement here?

Contributor Author

nope, copy paste

"have tuple outputs or use aot_module instead."
)
return out
return functional_call
Contributor

Any change from code motion here?

Contributor Author

nope, just a straight move out into its own function

# of their forward),
# we'll automatically detach them here.
if o.requires_grad and i != output_loss_index:
o.detach_()
Contributor

OK to do this inplace?

Contributor Author

I now error instead of detaching

trace_joint: bool,
# If trace_joint is True, we expect your module to return a scalar loss.
# Your module can return multiple outputs, so you must specify which output the loss is.
output_loss_index: Optional[int] = None,
Contributor

If trace_joint = True but output_loss_index is None, what is the behavior?

Contributor

This should be invalid input.

Contributor Author

fixed

out_loss = out
num_fw_outs = 1
assert isinstance(out_loss, torch.Tensor)
assert out_loss.requires_grad
Contributor

Use of asserts here is inappropriate, because users can violate these asserts via inputs. Do real RuntimeErrors with proper error messages.

Contributor Author

updated them all to RuntimeErrors, thanks for the callout
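For illustration, the shape of that change (hypothetical helper and messages, not the PR's exact code):

    import torch

    def _check_loss_output(out_loss) -> None:
        # Same checks as the asserts above, surfaced as RuntimeErrors so that bad
        # user inputs produce an actionable message instead of an AssertionError.
        if not isinstance(out_loss, torch.Tensor):
            raise RuntimeError(f"Expected the loss output to be a Tensor, got {type(out_loss)}")
        if not out_loss.requires_grad:
            raise RuntimeError("Expected the loss output to have requires_grad=True")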

if isinstance(a, torch.Tensor) and a.requires_grad:
assert grad is not None, """\
Found a parameter that did not receive a gradient.
"This is most likely a bug, but if this needs to be supported please file an issue on GitHub."""
Contributor

I like to preemptively file bugs and link 'em in the error message, so that I know I will get notified if someone does need it.

Contributor Author

Good call - updated, issue here #101192

fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos)
)
# Attempt to run the fw_module with the original user inputs
fw_module(*args)
Contributor

uhh, is args fake or not here?

Contributor Author

good catch... we fakeify later, so they're not guaranteed to be fake here. I'll ensure fakeness

Contributor

@ezyang ezyang left a comment

Looks great! Ship it!

inputs_to_parameters: Dict[str, str]
inputs_to_buffers: Dict[str, str]

buffers_to_mutate: Dict[str, str]
Contributor

Let's add a comment for this field (or a better name to make it explicit):
key is buffer graph output name
value is buffer fqn

Comment on lines +975 to +979
# Calling convention assumptions:
# (1) graph inputs = (params, buffers, user_inputs)
# (2) graph outputs = (mutated_inputs, user_outs, param_gradients)
# (If we are capturing an inference graph, this convention is identical
# except that param_gradients is empty)
Contributor

love it!
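To make the convention concrete, a small helper of this shape (hypothetical, not part of the PR) is all a caller needs in order to split the flat outputs:

    def split_joint_outputs(flat_outs, num_mutated_inputs, num_user_outputs):
        """Split flat graph outputs per the convention above:
        (mutated_inputs, user_outs, param_gradients).
        For an inference graph, param_gradients comes back empty."""
        mutated_inputs = flat_outs[:num_mutated_inputs]
        user_outs = flat_outs[num_mutated_inputs:num_mutated_inputs + num_user_outputs]
        param_gradients = flat_outs[num_mutated_inputs + num_user_outputs:]
        return mutated_inputs, user_outs, param_gradients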

allow_unused=True,
)
# If our output is a scalar loss, we don't need to pass in tangents.
if len(needed_tangents) == 1 and needed_tangents[0].ndim == 0:
Contributor

loss could be 1d scalar? A more robust condition is

needed_tangents[0].numel() == 1
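A quick check of the difference (a one-element 1-d tensor passes the numel() test but not the ndim one):

    import torch

    a = torch.tensor(1.0)    # 0-d scalar: ndim == 0, numel() == 1
    b = torch.tensor([1.0])  # 1-d, one element: ndim == 1, numel() == 1
    assert a.numel() == 1 and b.numel() == 1
    assert a.ndim == 0 and b.ndim == 1  # the ndim-based check would miss `b`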

# TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect
# when we need to manually detach() some inputs in the forward.
# Higher order ops might eventually need to do the same.
return fx_g
Contributor

Slightly concerned about returning a different type from aot_dispatch_autograd.
The function's return signature will be confusing because of this...

I wonder if it can be refactored into sth like this

def aot_dispatch_autograd_graph_version -> GraphModule:
      ....
      return fx_g

def aot_dispatch_autograd -> Callable:
     fx_g = aot_dispatch_autograd_graph_version(...)
     fn = create_runtime_wrapper(fx_g)
     return fn

Please also feel free to ignore this nitpick, coz I know it might be hard to do so in practice.

Contributor Author

Agreed - I originally planned to push this off and wait for more clarity on how functionalized RNG fits into the export picture. But I may as well do this refactor now (and I'll leave functionalized RNG only in the callable version for now).

Contributor Author

updated

return flat_out
return flat_fn, out_spec

def _graph_input_names(gm):
Contributor

nitpick... this should be a graph module API....

return [node.name for node in gm.graph.nodes if node.op == "placeholder"]


def _graph_output_names(gm):
Contributor

same here... this should be a graph module API....
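For context, since the body is cut off in the excerpt above, such a helper presumably does something along these lines (a reconstruction, not the PR's exact code):

    def _graph_output_names(gm):
        # An fx.Graph ends with a single `output` node whose first arg holds the
        # returned values; collect a name for each (constants have no .name).
        output_node = list(gm.graph.nodes)[-1]
        assert output_node.op == "output"
        return [getattr(out, "name", None) for out in output_node.args[0]]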

(4) Any input mutations will be converted into additional outputs in the graph,
meaning whoever calls this graph is responsible for applying the mutations
back to the original inputs.
(5) If is_backward is provided the graph will return parameter gradients in addition to user outputs.
Contributor

if trace_joint is provided?

(3) Metadata mutations on params/buffers/inputs are banned.
(4) Data mutations on anything that requires gradients are banned (parameters)
(5) If an input is mutated, it is not allowed to alias any other inputs.
(6) Parameters must not be duplicated.
Contributor

This is great!

# of their forward),
# we'll automatically detach them here.
if o.requires_grad and i != output_loss_index:
o.detach_()
Contributor

This should throw? We only allow the loss to require_grad.
Calling detach for the user might be surprising behavior.

Contributor Author

fair, I'll error instead.

)

"""
A simplified version of export. Used by higher order operators.
Contributor

Just trying to understand the use case here...

Say we have a torch.cond op, that we wish to export. Both branches will be passed into this aot_export_joint_simple function?

I am also having a hard time imagining how trace_joint will behave for higher-order ops...
e.g. the backward subgraph for torch.cond's branch is not directly connected to the forward subgraph... so they are not strictly "joint"?

Contributor Author

So the main use for this API isn't really strictly for export - it's so that we can generate the backward formulas for higher order ops, like torch.cond, map, etc.

I have a test coming soon, but I'm picturing that this API will be used by higher order ops roughly like this:
[image: example usage of aot_export_joint_simple by a higher-order op]

I'm open to feedback on exactly what this API should return (cc @zou3519). I went off the thought process of:
(1) partitioning a joint into a fw/bw is only a few lines of code (see code snippet above)
(2) returning a joint graph is a bit more general, and alleviates AOTAutograd from having to tell the caller how to use / call the partitioned backward graph (although maybe this isn't a real concern).

Contributor

@SherlockNoMad SherlockNoMad left a comment

Thanks a lot @bdhirsh for pushing this to the finish line!!!

bdhirsh added a commit that referenced this pull request May 11, 2023
ghstack-source-id: 0d5a31951d4f0348b0fb15f98a6fe72e68656523
Pull Request resolved: #100587
@bdhirsh bdhirsh added the ciflow/trunk Trigger trunk jobs on your pull request label May 11, 2023
bdhirsh added a commit that referenced this pull request May 11, 2023
ghstack-source-id: f876a0bb8c6e8b5f660bf4bc4798e128ef3058b4
Pull Request resolved: #100587
@bdhirsh
Contributor Author

bdhirsh commented May 12, 2023

hmm. I added this logic to create_joint():

                # If our output is a scalar loss, we don't need to pass in tangents.
                if len(needed_tangents) == 1 and needed_tangents[0].numel() == 1:
                    backward_out = torch.autograd.grad(
                        needed_outs,
                        grad_primals,
                        allow_unused=True,
                    )
                else:
                    backward_out = torch.autograd.grad(
                        needed_outs,
                        grad_primals,
                        grad_outputs=needed_tangents,
                        allow_unused=True,
                    )

Where the idea was that for export, we want to make sure that there are no tangent inputs to our joint graph, since we know we're just calling .backward() (which will implicitly pass a scalar 1 as the tangent).

This doesn't actually work though, because we don't know what the value of the scalar is. If the scalar is a 1, then the above is ok, but if it's something else, then we need to actually use the tangent value. And we don't know this info at trace time, since we don't know the value of the tangent.

The easiest thing to do is probably to plumb aot_config into create_joint() and branch off of aot_config.is_export
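A small standalone illustration of why the tangent can't be dropped unless it is known to be 1:

    import torch

    w = torch.randn(3, requires_grad=True)
    x = torch.randn(3)
    loss = (w * x).sum()

    # What loss.backward() effectively does: an implicit tangent of 1.
    g_implicit = torch.autograd.grad(loss, w, retain_graph=True)[0]

    # With a non-unit tangent the gradients are scaled, so the tangent has to stay
    # a graph input unless we know at trace time that it is exactly 1.
    g_scaled = torch.autograd.grad(loss, w, grad_outputs=torch.tensor(2.0))[0]
    assert torch.allclose(g_scaled, 2 * g_implicit)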

bdhirsh added a commit that referenced this pull request May 12, 2023
ghstack-source-id: a65ef1fed5759981e27f2f02b0c5d5fd3a119751
Pull Request resolved: #100587
# need to know if it is a 1 (no need to make it a graph input), or something else
# (requiring it to be a graph input).
# We don't know this info at trace time though, so we need to make it an explicit config.
no_tangents: bool = False,
Contributor Author

@ezyang I don't love this but let me know what you think. My goal was to avoid having to duplicate the logic for calling autograd.grad() - if we want a joint graph with no tangent inputs, we need to make sure we don't specify grad_outputs=.... I originally thought I could detect that automatically by checking to see if our tangent is a single scalar, but that's not true - if the scalar is not 1, then it should still be an input to our graph (but we don't know its value at trace time).

@ezyang
Contributor

ezyang commented May 15, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


jcaip pushed a commit that referenced this pull request May 23, 2023

Pull Request resolved: #100587
Approved by: https://github.com/ezyang, https://github.com/SherlockNoMad
@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/415/head branch June 8, 2023 15:48
Labels: ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: AO frontend