
[AOTAutograd] add export entrypoints #100587

Closed · wants to merge 4 commits

Commits on May 3, 2023

  1. [AOTAutograd] add export entrypoints

    [ghstack-poisoned]
    bdhirsh committed May 3, 2023 (SHA f38870f)

Commits on May 11, 2023

  1. Update on "[AOTAutograd] add export entrypoints"

    The main addition in this PR is two new APIs in AOTAutograd.
    
    **APIs**
    
    `aot_export_module`: Given a module, exports it into a functionalized FX graph, returning an `fx.GraphModule`, `GraphSignature` pair. The `GraphSignature` records various information about the graph, such as which graph inputs correspond to module params/buffers (and their FQNs) and how to pytree-ify the inputs and outputs of the graph. If you specify `trace_joint=True`, you'll get back a joint forward-backward graph that also returns parameter gradients in addition to the user outputs.
    
    There are several restrictions on this API, detailed in the comments. The most notable one is probably that this API does not handle partial graphs: if you want a backward graph, then your module's forward function is **required** to return a scalar loss that we can backprop through. It also does not support capturing the optimizer step.
    
    I (gratefully) used SherlockNoMad and suo's internal version of the `GraphSignature` object for this API, with a few minor changes in order to integrate it into AOTAutograd.
    
    `aot_export_joint_simple`: Given a function, we'll trace it into a joint forward-backward graph and return it. Unlike the above API, the function is **not** required to return a scalar loss. However, this API guarantees that you **do not** need to make any calling-convention changes between the original function and the exported one, provided that you do the following:
    * If you pass `trace_joint=False`, no work is needed: we'll export a functionalized forward graph with the same set of inputs as the original function
    * If you pass `trace_joint=True`, then you will need to manually use the `default_partitioner` or `min_cut_partitioner` from functorch. If you do, and get back a fw and bw graph, then the forward graph will be runnable identically to the original user function.
    
    The main use case for this API is higher order ops: a higher order op like `torch.cond()` can implement its derivative formula by using this API to export a joint graph (for both the true subgraph and the false subgraph), partition it into a fw/bw graph, and run cond on the `true_bw`, `false_bw` subgraphs. cc zou3519 Chillee 
    
    **Implementation Strategy**
    
    A lot of the work in this PR went into finding a reasonable way to re-use existing AOTAutograd components to expose these APIs. Concretely:
    
    * The two new APIs are both thin wrappers around `_aot_export_function`: this is a general-purpose export API that just re-uses `create_aot_dispatcher_function`. If we want to add e.g. an export API that includes the optimizer step in the future, we could probably implement it using `_aot_export_function`.
    * `aot_export_module` works extra hard to re-use as much of AOTAutograd as possible. For example, when tracing an inference graph, I perform the export under `torch.no_grad()` to make sure we don't accidentally trace out a backwards graph. When exporting a joint graph, I manually `.detach()` all user outputs except the loss, to make sure that we don't accidentally compute gradients for any other user outputs (even if the user forgot to manually detach them).
    * A large portion of `aot_export_module` comes from parsing out and creating a `GraphSignature` object. We discussed a few weeks ago that there's potentially a lot more information that we could stuff into this object (see [doc](https://docs.google.com/document/d/1_qzdKew5D1J2Q2GkZ1v5jsczSsIU-Sr0AJiPW7DdGjE/edit?usp=sharing)). For now, I ended up deciding to support the more limited use case of exporting a fwd-bwd full graph, without some of the extra annotations in that doc (for example, if we were to export partial graphs, we would need annotations for saved activations). My thought is that once a more concrete use case comes up that the existing API doesn't satisfy, we can revisit the annotations then.
    * I factored out `create_functional_call()` and `create_tree_flattened_fn()` for pytree-flattening and lifting-params-and-buffers, since I also need them in the export code.
    * I added an `AOTConfig.is_export` flag. The export API re-uses all of the same code paths as the rest of AOTAutograd, but there are a few points where we need to either exit early (and avoid building a runtime epilogue) or add extra error checking that is only valuable for export.
    * `aot_dispatch_autograd()` now exits early if it's being called in an export context, so it returns the full graph instead of also trying to create an `autograd.Function`. I think we probably want to factor this out, although I figured it would be safer to wait a bit for clarity on how functional RNG works with export.
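The detach safeguard described above (gradients flow only through the loss) can be illustrated in isolation. This is an illustrative sketch, not the PR's actual code; `detach_non_loss_outputs` and `loss_index` are hypothetical names:

```python
import torch

def detach_non_loss_outputs(outs, loss_index):
    # Keep gradients flowing only through the designated loss output;
    # every other user output is detached, so a joint graph traced from
    # these outputs never computes gradients for them.
    return [o if i == loss_index else o.detach() for i, o in enumerate(outs)]

loss = torch.randn((), requires_grad=True) * 2  # scalar loss, in the autograd graph
aux = torch.randn(3, requires_grad=True) * 2    # user forgot to detach this output
outs = detach_non_loss_outputs([loss, aux], loss_index=0)
print(outs[0].requires_grad, outs[1].requires_grad)  # True False
```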
    
    
    
    
    [ghstack-poisoned]
    bdhirsh committed May 11, 2023 (SHA 81b73a0)
  2. Update on "[AOTAutograd] add export entrypoints"

    bdhirsh committed May 11, 2023 (SHA 37bf0dc)

Commits on May 12, 2023

  1. Update on "[AOTAutograd] add export entrypoints"

    bdhirsh committed May 12, 2023 (SHA 59a4374)