Skip to content

Commit

Permalink
[AOTAutograd] add export entrypoints (#100587)
Browse files Browse the repository at this point in the history
The main addition in this PR is two new API's in AOTAutograd.

**APIs**

`aot_export_module`: Given a module, exports it into a functionalized FX graph. Returns an `fx.GraphModule`, `GraphSignature` pair. The `GraphSignature` tells you various information about the graph, such as which graph inputs correspond to module params/buffers (and their fqn's), how to pytree-ify the inputs and the outputs of the graph. If you specify `trace_joint=True`, then you'll get back a joint forward-backward graph, that also returns parameter gradients in addition to the user outputs.

There are several restrictions on this API, detailed in the comments. The most notable one is probably that this API does not handle partial graphs: If you want a backward graph, then you module's forward function is **required** to return a scalar loss that we can backprop through. It also does not support capturing the optimizer step.

I (gratefully) used @SherlockNoMad and @suo's internal version of the `GraphSignature` object for this API, with a few minor changes in order to integrate it into AOTAutograd.

`aot_export_joint_simple`: Given a function, we'll trace it into a joint forward-backward graph and return it. Unlike the above API, the function is **not** required to return a scalar loss. However, this API makes the guarantee that you **do not** need to make any calling convention changes between the original function, and the exported one, provided that you do that you do the following:
* If you pass `trace_joint=False`, no work is needed: we'll export a functionalized forward graph with the same set of inputs as the original function
* If you pass `trace_joint=True`, then you will need to manually use the `default_partitioner` or `min_cut_partitioner` from functorch. If you do, and get back a fw and bw graph, then the forward graph will be runnable identically to the original user function.

The main use case for this API is higher order ops: a higher order op like `torch.cond()` can implement its derivative formula by using this API to export a joint graph (for both the true subgraph and the false subgraph), partition it into a fw/bw graph, and run cond on the `true_bw`, `false_bw` subgraphs. cc @zou3519 @Chillee

**Implementation Strategy**

A lot of the work in this PR went in to trying to find a reasonable way to re-use existing AOTAutograd components to expose these API's. Concretely:

* The two new API's are both thin wrappers around `_aot_export_function`: this is a general purpose export API, that just re-uses `create_aot_dispatcher_function`. If we want to add e.g. an export API that includes the optimizer step in the future, we could probably implement it using `_aot_export_function`.
* `aot_export_module` works extra hard to re-use as much of AOTAutograd as possible. For example, when tracing an inference graph, I perform the export under `torch.no_grad()` to make sure we don't accidentally trace out a backwards graph. When exporting a joint graph, I manually `.detach()` all user outputs except the loss, to make sure that we don't accidentally compute gradients for any other user outputs (even if the user forgot to manually detach them).
* A large portion of `aot_export_module` comes from parsing out and creating a `GraphSignature` object. We discussed a few weeks ago that there's potentially a lot more information that we could stuff into this object (see [doc](https://docs.google.com/document/d/1_qzdKew5D1J2Q2GkZ1v5jsczSsIU-Sr0AJiPW7DdGjE/edit?usp=sharing)). For now, I ended up deciding to support the more limited use case of exporting a fwd-bwd full graph, without some of the extra annotations in that doc (for example, if we were to export partial graphs, we would need annotations for saved activations). My thought is that once a more concrete use case comes up that the existing API doesn't satisfy, we can revisit the annotations then.
* I factored out `create_functional_call()` and `create_tree_flattened_fn()` for pytree-flattening and lifting-params-and-buffers, since I also need them in the export code
* I added an `AOTConfig.is_export` flag. The export API re-uses all of the same code paths as the rest of AOTAutograd, but there are a few points where we need to either exit early (and avoid making a runtime epilogue), or add extra error checking, that is only valuable for export.
* `aot_dispatch_autograd()` now exits early if it's being called in an export context, so it returns the full graph instead of also trying to create an `autograd.Function`. I think we probably want to factor this out, although I figured it would be safer to wait a bit for clarity on how functional RNG works with export.

Pull Request resolved: #100587
Approved by: https://github.com/ezyang, https://github.com/SherlockNoMad
  • Loading branch information
bdhirsh authored and pytorchmergebot committed May 15, 2023
1 parent bba12a4 commit ee40cce
Show file tree
Hide file tree
Showing 2 changed files with 933 additions and 85 deletions.
233 changes: 232 additions & 1 deletion test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
grad, vjp, vmap, jacrev,
make_fx
)
from torch._functorch.aot_autograd import aot_module_simplified
from torch._functorch.aot_autograd import aot_module_simplified, aot_export_module, aot_export_joint_simple
from functorch.compile import (
nnc_jit, compiled_function, compiled_module,
min_cut_rematerialization_partition, aot_function, aot_module,
Expand Down Expand Up @@ -1946,6 +1946,237 @@ def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition, dy
dynamic=dynamic)(*inps).sum().backward()
return (fw_graph_cell[0], bw_graph_cell[0])

class TestMod(torch.nn.Module):
def __init__(self, fn):
super().__init__()
self.p = torch.nn.Parameter(torch.ones(2, requires_grad=True))
self.fn = fn

def forward(self, *args):
return self.fn(self.p, *args)

class TestAOTExport(AOTTestCase):

def test_aot_export_module_joint(self):
class ConvBatchnormRelu(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
user_out = torch.nn.functional.relu(x)
loss = user_out.sum()
return loss, user_out.detach()

mod = ConvBatchnormRelu()
mod.train()
inp = torch.randn(1, 1, 3, 3)
o_ref = mod(inp)
fx_g, signature = aot_export_module(mod, [inp], trace_joint=True, output_loss_index=0)
# Some important characteristics of the exported graph below:
# 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input
# 9 outputs: 3 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters)
self.assertExpectedInline(fx_g.print_readable(print_output=False), """\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_1: f32[3], arg4_1: f32[3], arg5_1: f32[3], arg6_1: i64[], arg7_1: f32[1, 1, 3, 3]):
# No stacktrace found for following nodes
convolution: f32[1, 3, 3, 3] = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg1_1 = None
add: i64[] = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); arg3_1 = arg4_1 = arg5_1 = None
getitem: f32[1, 3, 3, 3] = _native_batch_norm_legit_functional[0]
getitem_1: f32[3] = _native_batch_norm_legit_functional[1]
getitem_2: f32[3] = _native_batch_norm_legit_functional[2]
getitem_3: f32[3] = _native_batch_norm_legit_functional[3]
getitem_4: f32[3] = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem); getitem = None
detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu)
sum_1: f32[] = torch.ops.aten.sum.default(relu)
detach_1: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu)
detach_2: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_1); detach_1 = None
ones_like: f32[] = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
expand: f32[1, 3, 3, 3] = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None
threshold_backward: f32[1, 3, 3, 3] = torch.ops.aten.threshold_backward.default(expand, relu, 0); expand = relu = None
native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
getitem_5: f32[1, 3, 3, 3] = native_batch_norm_backward[0]
getitem_6: f32[3] = native_batch_norm_backward[1]
getitem_7: f32[3] = native_batch_norm_backward[2]; native_batch_norm_backward = None
convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]); getitem_5 = arg7_1 = arg0_1 = None
getitem_8 = convolution_backward[0]
getitem_9: f32[3, 1, 1, 1] = convolution_backward[1]
getitem_10: f32[3] = convolution_backward[2]; convolution_backward = None
return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7)
""") # noqa: B950


self.assertExpectedInline(str(signature.parameters), """['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']""")
self.assertExpectedInline(str(signature.buffers), """['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']""")
self.assertExpectedInline(str(signature.user_inputs), """['arg7_1']""")
self.assertExpectedInline(str(signature.inputs_to_parameters), """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""") # noqa: B950
self.assertExpectedInline(str(signature.inputs_to_buffers), """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""") # noqa: B950
self.assertExpectedInline(str(signature.buffers_to_mutate), """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 'add': 'bn.num_batches_tracked'}""") # noqa: B950
self.assertExpectedInline(str(signature.backward_signature.gradients_to_parameters), """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""") # noqa: B950
self.assertExpectedInline(str(signature.backward_signature.gradients_to_user_inputs), """{}""")
self.assertExpectedInline(str(signature.backward_signature.loss_output), """getitem_3""")

# Also check the inference graph
# Main important thing here is that there are 5 total outputs: 3 total mutated buffers (from batchnorm), 2 user outputs.
fx_g_inference, signature_inference = aot_export_module(mod, [inp], trace_joint=False)
self.assertExpectedInline(fx_g_inference.print_readable(print_output=False), """\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_1: f32[3], arg4_1: f32[3], arg5_1: f32[3], arg6_1: i64[], arg7_1: f32[1, 1, 3, 3]):
# No stacktrace found for following nodes
convolution: f32[1, 3, 3, 3] = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg7_1 = arg0_1 = arg1_1 = None
add: i64[] = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); convolution = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None
getitem: f32[1, 3, 3, 3] = _native_batch_norm_legit_functional[0]
getitem_3: f32[3] = _native_batch_norm_legit_functional[3]
getitem_4: f32[3] = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem); getitem = None
sum_1: f32[] = torch.ops.aten.sum.default(relu)
detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu); relu = None
return (getitem_3, getitem_4, add, sum_1, detach)
""") # noqa: B950
# Some important characteristics of the exported graph below:
# 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input
# 9 outputs: 2 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters)

def test_aot_export_simplified_basic(self):
def f(x, y):
return x * y, y * y.detach()

x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)

f_graph_fw = aot_export_joint_simple(f, [x, y], trace_joint=False)
out_ref = f(x, y)
# No calling convention changes necessary to invoke the traced graph
out_test = f_graph_fw(x, y)
self.assertEqual(out_ref, out_test)

# Now test the backward
x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)
x2 = x.clone().detach().requires_grad_(True)
y2 = y.clone().detach().requires_grad_(True)
x3 = x.clone().detach().requires_grad_(True)
y3 = y.clone().detach().requires_grad_(True)
f_graph_joint = aot_export_joint_simple(f, [x, y], trace_joint=True)
num_fw_outputs = 2
fw_g, bw_g = default_partition(f_graph_joint, [x, y], num_fwd_outputs=num_fw_outputs)
out_ref2 = f(x2, y2)
fw_outs = fw_g(x3, y3)
out_test2, activations = fw_outs[:num_fw_outputs], fw_outs[num_fw_outputs:]
self.assertEqual(out_ref2, out_test2)

# Test running the traced backward graph with a mocked-up grad_output
grad_outs = [torch.ones_like(x) for x in out_ref2]
grads_ref = torch.autograd.grad(out_ref2, [x2, y2], grad_outputs=grad_outs)
grads_test = bw_g(*activations, *grad_outs)
for g_ref, g_test in zip(grads_ref, grads_test):
self.assertEqual(g_ref, g_test)

def test_aot_export_metadata_mutation_banned(self):
def fn(p, x):
x.t_()
return (x * 2,)
mod = TestMod(fn)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "Found an input that received a metadata mutation"
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
aot_export_module(mod, [inp], trace_joint=False)

def test_aot_export_input_mutation_on_parameter_banned(self):
def fn(p, x):
p.mul_(2)
return (p + x,)
mod = TestMod(fn)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "Found a graph input that requires gradients, and received a mutation"
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
aot_export_module(mod, [inp], trace_joint=False)

def test_aot_export_synthetic_bases_banned(self):
def fn(p, x, y):
x.mul_(2)
return (x + y,)
mod = TestMod(fn)
inp = torch.randn(2)
inp2 = inp.view(-1)
with self.assertRaisesRegex(
RuntimeError, "Encountered aliased inputs that are mutated"
):
aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=True)
aot_export_module(mod, [inp, inp2], trace_joint=False)

def test_aot_export_input_dupes_banned(self):
def fn(p, x, y):
x.mul_(2)
return (x + y,)
mod = TestMod(fn)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "Encountered duplicated inputs that are mutated in the graph"
):
aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=True)
aot_export_module(mod, [inp, inp], trace_joint=False)

def test_aot_export_multiple_outputs_require_grad_banned(self):
def fn(p, x):
out = p * x
return out, out.sum()
mod = TestMod(fn)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "Found an output of the forward that requires gradients, that was not"
):
aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1)

def test_aot_export_simplified_input_mutations_banned(self):
def fn(x):
x.mul_(2)
return (x + x,)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "aot_export_joint_simple does not support input mutations"
):
aot_export_joint_simple(fn, [inp], trace_joint=False)
aot_export_joint_simple(fn, [inp], trace_joint=True)

def test_aot_export_simplified_pytrees_banned(self):
def fn(inps):
return (inps[0] + inps[1],)
inp1 = torch.randn(2)
inp2 = torch.randn(2)
inps = [inp1, inp2]
with self.assertRaisesRegex(
RuntimeError, "aot_export_joint_simple requires individual inputs not to be pytrees"
):
aot_export_joint_simple(fn, [inps], trace_joint=False)
aot_export_joint_simple(fn, [inps], trace_joint=True)

def test_aot_export_functionalized_rng_banned(self):
def fn(p, x):
return (p + x,)
mod = TestMod(fn)
inp = torch.randn(2)
with patch("functorch.compile.config.functionalize_rng_ops", True), self.assertRaisesRegex(
RuntimeError, "Functionalized RNG is not currently supported in the aot_export"
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
aot_export_module(mod, [inp], trace_joint=False)


class TestPartitioning(AOTTestCase):
@unittest.skipIf(not USE_NETWORKX, "networkx not available")
Expand Down

0 comments on commit ee40cce

Please sign in to comment.