[AOTAutograd] add export entrypoints #100587
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/100587
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 59a4374. This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: ffa6f74fa8c9b4046d1e9934de8cfb59589d6ae6 Pull Request resolved: #100587
TODO: I added one large test for
test/functorch/test_aotdispatch.py
Outdated
pass

def test_aot_export_functionalized_rng_banned(self):
    pass
TODO I suppose
yep, will add these in before landing
it corresponds to (forward input)
(2) A mapping from each gradient (backwards output) to the user input
it corresponds to (forward input)
(3) Which of the forward outputs corresponds to the loss, that we backprop on.
worth saying explicitly what the str denotes (node.target, amirite?)
node.name
😛 yep, will update (it came from here: https://github.com/pytorch/pytorch/pull/100587/files#diff-df954bbf954d2dcb81f687876053267ffa4ddb36ed86b7d2bd76319ff2b94416R3300)
torch/_functorch/aot_autograd.py
Outdated
""" | ||
|
||
# Parameters/buffers are named according to their name in the GraphModule, | ||
# *not* their name in the input/output list. |
Confusing. Maybe worth a type alias?
inputs_to_parameters records the name mapping: the key is the graph input name, and the value is the parameter FQN.
I ended up making some type aliases: FQN, InputName, and OutputName.
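A sketch of what such aliases buy you in readability (the alias names come from the comment above, but the example mappings and graph input names here are illustrative, not taken from the PR):

```python
from typing import Dict

# Illustrative type aliases; string-valued, so they cost nothing at runtime.
FQN = str         # fully-qualified name of a param/buffer, e.g. "layer0.weight"
InputName = str   # name of a placeholder node in the traced graph
OutputName = str  # name of a graph output node

# With the aliases, the signature fields document themselves:
inputs_to_parameters: Dict[InputName, FQN] = {"arg0_1": "layer0.weight"}
inputs_to_buffers: Dict[InputName, FQN] = {"arg1_1": "bn.running_mean"}

print(inputs_to_parameters["arg0_1"])
```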
torch/_functorch/aot_autograd.py
Outdated
(2) `func` cannot mutate any inputs
(3) The outputs of `func` cannot alias any inputs.

Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops.
nit: If you actually want to docblock it, the docblock goes in the function, not on top.
)
out_spec.set(spec)
return flat_out
return flat_fn, out_spec
Any difference in code movement here?
nope, copy paste
"have tuple outputs or use aot_module instead." | ||
) | ||
return out | ||
return functional_call |
Any change from code motion here?
nope, just a straight move out into its own function
torch/_functorch/aot_autograd.py
Outdated
# of their forward),
# we'll automatically detach them here.
if o.requires_grad and i != output_loss_index:
    o.detach_()
OK to do this inplace?
I now error instead of detaching
trace_joint: bool,
# If trace_joint is True, we expect your module to return a scalar loss.
# Your module can return multiple outputs, so you must specify which output the loss is.
output_loss_index: Optional[int] = None,
If trace_joint=True but output_loss_index is None, what is the behavior?
This should be invalid input.
fixed
torch/_functorch/aot_autograd.py
Outdated
out_loss = out
num_fw_outs = 1
assert isinstance(out_loss, torch.Tensor)
assert out_loss.requires_grad
Use of asserts here is inappropriate, because users can violate these asserts via inputs. Do real RuntimeErrors with proper error messages.
updated them all to RuntimeErrors, thanks for the callout
torch/_functorch/aot_autograd.py
Outdated
if isinstance(a, torch.Tensor) and a.requires_grad:
    assert grad is not None, """\
Found a parameter that did not receive a gradient.
This is most likely a bug, but if this needs to be supported please file an issue on GitHub."""
I like to preemptively file bugs and link 'em in the error message, so that I know I will get notified if someone does need it.
Good call - updated, issue here #101192
torch/_functorch/aot_autograd.py
Outdated
fx_g, args, num_fwd_outputs=len(fw_metadata.output_infos)
)
# Attempt to run the fw_module with the original user inputs
fw_module(*args)
uhh, is args fake or not here?
good catch... we fakeify later, so they're not guaranteed to be fake here. I'll ensure fakeness
Looks great! Ship it!
torch/_functorch/aot_autograd.py
Outdated
inputs_to_parameters: Dict[str, str]
inputs_to_buffers: Dict[str, str]

buffers_to_mutate: Dict[str, str]
Let's add a comment for this field (or a better name that makes it explicit): the key is the buffer's graph output name, and the value is the buffer FQN.
# Calling convention assumptions:
# (1) graph inputs = (params, buffers, user_inputs)
# (2) graph outputs = (mutated_inputs, user_outs, param_gradients)
# (If we are capturing an inference graph, this convention is identical
# except that param_gradients is empty)
love it!
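The calling convention above can be illustrated with a toy that uses plain lists in place of tensors (purely a sketch of the ordering; the element names are made up, not AOTAutograd code):

```python
# Graph inputs: parameters first, then buffers, then user inputs.
params = ["w"]
buffers = ["running_mean"]
user_inputs = ["x"]
graph_inputs = (*params, *buffers, *user_inputs)

# Graph outputs: mutated inputs first, then user outputs, then param gradients.
mutated_inputs = ["running_mean_updated"]
user_outs = ["y"]
param_gradients = ["grad_w"]
graph_outputs = (*mutated_inputs, *user_outs, *param_gradients)

# Inference graph: identical layout, but param_gradients is empty.
inference_outputs = (*mutated_inputs, *user_outs)
```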
torch/_functorch/aot_autograd.py
Outdated
allow_unused=True,
)
# If our output is a scalar loss, we don't need to pass in tangents.
if len(needed_tangents) == 1 and needed_tangents[0].ndim == 0:
The loss could be a 1-d tensor with a single element? A more robust condition is `needed_tangents[0].numel() == 1`.
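The distinction can be shown at the shape level without torch: a 0-d "scalar" has shape `()`, while a 1-element 1-d tensor has shape `(1,)`. Checking `ndim == 0` misses the latter; counting elements (the analogue of torch's `.numel()`) catches both. A small torch-free sketch:

```python
import math

def is_scalar_by_ndim(shape):
    # mirrors `t.ndim == 0`: only true for shape ()
    return len(shape) == 0

def is_scalar_by_numel(shape):
    # mirrors `t.numel() == 1`: true for () and (1,), and even (1, 1)
    return math.prod(shape) == 1

print(is_scalar_by_ndim((1,)))   # shape (1,) has ndim 1
print(is_scalar_by_numel((1,)))  # but still a single element
```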
torch/_functorch/aot_autograd.py
Outdated
# TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect
# when we need to manually detach() some inputs in the forward.
# Higher order ops might eventually need to do the same.
return fx_g
Slightly concerned about returning a different type from aot_dispatch_autograd. The function's return signature will be confusing. I wonder if it can be refactored into something like this:

    def aot_dispatch_autograd_graph_version(...) -> GraphModule:
        ...
        return fx_g

    def aot_dispatch_autograd(...) -> Callable:
        fx_g = aot_dispatch_autograd_graph_version(...)
        fn = create_runtime_wrapper(fx_g)
        return fn

Please also feel free to ignore this nitpick, since I know it might be hard to do in practice.
Agreed - I originally planned to push this off and wait for more clarity on how functionalized RNG fits into the export picture. But I may as well do this refactor now (and I'll leave functionalized RNG only into the callable version for now).
updated
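The suggested split can be sketched generically: one function that stops at the graph (the export path), and a thin wrapper that turns it into a runtime callable (the compile path). The names below mirror the suggestion but the bodies are stand-ins, not the PR's code:

```python
def compile_to_graph(fn):
    # Imagine tracing `fn` into a GraphModule here; for the sketch we just
    # return the function itself as the "graph".
    return fn

def make_runtime_wrapper(graph):
    # A real runtime wrapper would also handle input mutations, RNG state, etc.
    def runtime(*args):
        return graph(*args)
    return runtime

def compile_for_runtime(fn):
    graph = compile_to_graph(fn)        # the export entrypoint stops here
    return make_runtime_wrapper(graph)  # the compile path goes one step further

f = compile_for_runtime(lambda x: x + 1)
print(f(2))  # → 3
```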
return flat_out
return flat_fn, out_spec

def _graph_input_names(gm): |
nitpick... this should be a GraphModule API....
return [node.name for node in gm.graph.nodes if node.op == "placeholder"]

def _graph_output_names(gm): |
same here... this should be a GraphModule API....
torch/_functorch/aot_autograd.py
Outdated
(4) Any input mutations will be converted into additional outputs in the graph,
    meaning whoever calls this graph is responsible for applying the mutations
    back to the original inputs.
(5) If is_backward is provided the graph will return parameter gradients in addition to user outputs.
if trace_joint is provided?
torch/_functorch/aot_autograd.py
Outdated
(3) Metadata mutations on params/buffers/inputs are banned.
(4) Data mutations on anything that requires gradients are banned (parameters)
(5) If an input is mutated, it is not allowed to alias any other inputs.
(6) Parameters must not be duplicated.
This is great!
torch/_functorch/aot_autograd.py
Outdated
# of their forward),
# we'll automatically detach them here.
if o.requires_grad and i != output_loss_index:
    o.detach_()
This should throw? We only allow the loss to require grad.
Calling detach on the user's behalf might be surprising behavior.
fair, I'll error instead.
torch/_functorch/aot_autograd.py
Outdated
)

"""
A simplified version of export. Used by higher order operators. |
Just trying to understand the use case here...
Say we have a torch.cond op that we wish to export. Will both branches be passed into this aot_export_joint_simple function?
I am also having a hard time imagining how trace_joint will behave for higher-order ops...
e.g. the backward subgraph for torch.cond's branch is not directly connected to the forward subgraph... so they are not strictly "joint"?
So the main use for this API isn't really strictly for export - it's so that we can generate the backward formulas for higher order ops, like torch.cond, map, etc.
I have a test coming soon, but I'm picturing that this API will be used by higher order ops roughly like this:
I'm open to feedback on exactly what this API should return (cc @zou3519). I went off the thought process of:
(1) partitioning a joint into a fw/bw is only a few lines of code (see code snippet above)
(2) returning a joint graph is a bit more general, and alleviates AOTAutograd from having to tell the caller how to use / call the partitioned backward graph (although maybe this isn't a real concern).
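For intuition, "partitioning a joint into a fw/bw" can be shown with a torch-free toy for f(x) = x*x, where the forward additionally returns the activation the backward needs (this is an analogy for what the functorch partitioners produce, not their actual implementation):

```python
# Joint graph: computes the forward output and the gradient together.
def joint(x, tangent):
    y = x * x                  # forward compute
    grad_x = 2 * x * tangent   # backward compute, reuses x
    return y, grad_x

# Partitioned forward: returns the output plus saved activations.
def fw(x):
    y = x * x
    saved = x                  # activation the backward will need
    return y, saved

# Partitioned backward: consumes saved activations plus the tangent.
def bw(saved, tangent):
    return 2 * saved * tangent

# The partitioned pair agrees with the joint graph.
y, saved = fw(3.0)
assert (y, bw(saved, 1.0)) == joint(3.0, 1.0)
```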
Thanks a lot @bdhirsh for pushing this to the finish line!!!
The main addition in this PR is two new APIs in AOTAutograd.

**APIs**

`aot_export_module`: Given a module, exports it into a functionalized FX graph. Returns an `fx.GraphModule`, `GraphSignature` pair. The `GraphSignature` tells you various information about the graph, such as which graph inputs correspond to module params/buffers (and their fqn's), and how to pytree-ify the inputs and the outputs of the graph. If you specify `trace_joint=True`, then you'll get back a joint forward-backward graph that also returns parameter gradients in addition to the user outputs. There are several restrictions on this API, detailed in the comments. The most notable one is probably that this API does not handle partial graphs: if you want a backward graph, then your module's forward function is **required** to return a scalar loss that we can backprop through. It also does not support capturing the optimizer step. I (gratefully) used SherlockNoMad and suo's internal version of the `GraphSignature` object for this API, with a few minor changes in order to integrate it into AOTAutograd.

`aot_export_joint_simple`: Given a function, we'll trace it into a joint forward-backward graph and return it. Unlike the above API, the function is **not** required to return a scalar loss. However, this API makes the guarantee that you **do not** need to make any calling convention changes between the original function and the exported one, provided that you do the following:
* If you pass `trace_joint=False`, no work is needed: we'll export a functionalized forward graph with the same set of inputs as the original function.
* If you pass `trace_joint=True`, then you will need to manually use the `default_partitioner` or `min_cut_partitioner` from functorch. If you do, and get back a fw and bw graph, then the forward graph will be runnable identically to the original user function.

The main use case for this API is higher order ops: a higher order op like `torch.cond()` can implement its derivative formula by using this API to export a joint graph (for both the true subgraph and the false subgraph), partition it into a fw/bw graph, and run cond on the `true_bw`, `false_bw` subgraphs. cc zou3519 Chillee

**Implementation Strategy**

A lot of the work in this PR went into trying to find a reasonable way to re-use existing AOTAutograd components to expose these APIs. Concretely:
* The two new APIs are both thin wrappers around `_aot_export_function`: this is a general purpose export API that just re-uses `create_aot_dispatcher_function`. If we want to add e.g. an export API that includes the optimizer step in the future, we could probably implement it using `_aot_export_function`.
* `aot_export_module` works extra hard to re-use as much of AOTAutograd as possible. For example, when tracing an inference graph, I perform the export under `torch.no_grad()` to make sure we don't accidentally trace out a backwards graph. When exporting a joint graph, I manually `.detach()` all user outputs except the loss, to make sure that we don't accidentally compute gradients for any other user outputs (even if the user forgot to manually detach them).
* A large portion of `aot_export_module` comes from parsing out and creating a `GraphSignature` object. We discussed a few weeks ago that there's potentially a lot more information that we could stuff into this object (see [doc](https://docs.google.com/document/d/1_qzdKew5D1J2Q2GkZ1v5jsczSsIU-Sr0AJiPW7DdGjE/edit?usp=sharing)). For now, I ended up deciding to support the more limited use case of exporting a fwd-bwd full graph, without some of the extra annotations in that doc (for example, if we were to export partial graphs, we would need annotations for saved activations). My thought is that once a more concrete use case comes up that the existing API doesn't satisfy, we can revisit the annotations then.
* I factored out `create_functional_call()` and `create_tree_flattened_fn()` for pytree-flattening and lifting-params-and-buffers, since I also need them in the export code.
* I added an `AOTConfig.is_export` flag. The export API re-uses all of the same code paths as the rest of AOTAutograd, but there are a few points where we need to either exit early (and avoid making a runtime epilogue), or add extra error checking, that is only valuable for export.
* `aot_dispatch_autograd()` now exits early if it's being called in an export context, so it returns the full graph instead of also trying to create an `autograd.Function`. I think we probably want to factor this out, although I figured it would be safer to wait a bit for clarity on how functional RNG works with export.

[ghstack-poisoned]
ghstack-source-id: 0d5a31951d4f0348b0fb15f98a6fe72e68656523 Pull Request resolved: #100587
ghstack-source-id: f876a0bb8c6e8b5f660bf4bc4798e128ef3058b4 Pull Request resolved: #100587
hmm. I added this logic to

Where the idea was that for export, we want to make sure that there are no tangent inputs to our joint graph, since we know we're just calling

This doesn't actually work though, because we don't know what the value of the scalar is. If the scalar is a

The easiest thing to do is probably to plumb
ghstack-source-id: a65ef1fed5759981e27f2f02b0c5d5fd3a119751 Pull Request resolved: #100587
# need to know if it is a 1 (no need to make it a graph input), or something else
# (requiring it to be a graph input).
# We don't know this info at trace time though, so we need to make it an explicit config.
no_tangents: bool = False,
@ezyang I don't love this but let me know what you think. My goal was to avoid having to duplicate the logic for calling autograd.grad() - if we want a joint graph with no tangent inputs, we need to make sure we don't specify grad_outputs=.... I originally thought I could detect that automatically by checking to see if our tangent is a single scalar, but that's not true - if the scalar is not 1, then it should still be an input to our graph (but we don't know its value at trace time).
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The main addition in this PR is two new API's in AOTAutograd. **APIs** `aot_export_module`: Given a module, exports it into a functionalized FX graph. Returns an `fx.GraphModule`, `GraphSignature` pair. The `GraphSignature` tells you various information about the graph, such as which graph inputs correspond to module params/buffers (and their fqn's), how to pytree-ify the inputs and the outputs of the graph. If you specify `trace_joint=True`, then you'll get back a joint forward-backward graph, that also returns parameter gradients in addition to the user outputs. There are several restrictions on this API, detailed in the comments. The most notable one is probably that this API does not handle partial graphs: If you want a backward graph, then you module's forward function is **required** to return a scalar loss that we can backprop through. It also does not support capturing the optimizer step. I (gratefully) used @SherlockNoMad and @suo's internal version of the `GraphSignature` object for this API, with a few minor changes in order to integrate it into AOTAutograd. `aot_export_joint_simple`: Given a function, we'll trace it into a joint forward-backward graph and return it. Unlike the above API, the function is **not** required to return a scalar loss. However, this API makes the guarantee that you **do not** need to make any calling convention changes between the original function, and the exported one, provided that you do that you do the following: * If you pass `trace_joint=False`, no work is needed: we'll export a functionalized forward graph with the same set of inputs as the original function * If you pass `trace_joint=True`, then you will need to manually use the `default_partitioner` or `min_cut_partitioner` from functorch. If you do, and get back a fw and bw graph, then the forward graph will be runnable identically to the original user function. 
The main use case for this API is higher order ops: a higher order op like `torch.cond()` can implement its derivative formula by using this API to export a joint graph (for both the true subgraph and the false subgraph), partition it into a fw/bw graph, and run cond on the `true_bw`, `false_bw` subgraphs. cc @zou3519 @Chillee

**Implementation Strategy**

A lot of the work in this PR went into trying to find a reasonable way to re-use existing AOTAutograd components to expose these APIs. Concretely:
* The two new APIs are both thin wrappers around `_aot_export_function`: this is a general-purpose export API that just re-uses `create_aot_dispatcher_function`. If we want to add e.g. an export API that includes the optimizer step in the future, we could probably implement it using `_aot_export_function`.
* `aot_export_module` works extra hard to re-use as much of AOTAutograd as possible. For example, when tracing an inference graph, I perform the export under `torch.no_grad()` to make sure we don't accidentally trace out a backwards graph. When exporting a joint graph, I manually `.detach()` all user outputs except the loss, to make sure that we don't accidentally compute gradients for any other user outputs (even if the user forgot to manually detach them).
* A large portion of `aot_export_module` comes from parsing out and creating a `GraphSignature` object. We discussed a few weeks ago that there's potentially a lot more information that we could stuff into this object (see [doc](https://docs.google.com/document/d/1_qzdKew5D1J2Q2GkZ1v5jsczSsIU-Sr0AJiPW7DdGjE/edit?usp=sharing)). For now, I ended up deciding to support the more limited use case of exporting a fwd-bwd full graph, without some of the extra annotations in that doc (for example, if we were to export partial graphs, we would need annotations for saved activations). My thought is that once a more concrete use case comes up that the existing API doesn't satisfy, we can revisit the annotations then.
* I factored out `create_functional_call()` and `create_tree_flattened_fn()` for pytree-flattening and lifting-params-and-buffers, since I also need them in the export code.
* I added an `AOTConfig.is_export` flag. The export API re-uses all of the same code paths as the rest of AOTAutograd, but there are a few points where we need to either exit early (and avoid making a runtime epilogue), or add extra error checking, that is only valuable for export.
* `aot_dispatch_autograd()` now exits early if it's being called in an export context, so it returns the full graph instead of also trying to create an `autograd.Function`. I think we probably want to factor this out, although I figured it would be safer to wait a bit for clarity on how functional RNG works with export.

Pull Request resolved: #100587
Approved by: https://github.com/ezyang, https://github.com/SherlockNoMad