[AOTAutograd] add export entrypoints #100587

Closed · wants to merge 4 commits
233 changes: 232 additions & 1 deletion in test/functorch/test_aotdispatch.py
@@ -31,7 +31,7 @@
grad, vjp, vmap, jacrev,
make_fx
)
from torch._functorch.aot_autograd import aot_module_simplified
from torch._functorch.aot_autograd import aot_module_simplified, aot_export_module, aot_export_joint_simple
from functorch.compile import (
nnc_jit, compiled_function, compiled_module,
min_cut_rematerialization_partition, aot_function, aot_module,
@@ -1929,6 +1929,237 @@ def get_fw_bw_graph(f, inps, partitioner=min_cut_rematerialization_partition, dy
dynamic=dynamic)(*inps).sum().backward()
return (fw_graph_cell[0], bw_graph_cell[0])

class TestMod(torch.nn.Module):
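# Small helper used by the export tests below: wraps `fn` so it is called with a
# single trainable parameter (self.p) followed by the user-supplied inputs.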
def __init__(self, fn):
super().__init__()
self.p = torch.nn.Parameter(torch.ones(2, requires_grad=True))
self.fn = fn

def forward(self, *args):
return self.fn(self.p, *args)

class TestAOTExport(AOTTestCase):

def test_aot_export_module_joint(self):
class ConvBatchnormRelu(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(1, 3, 1, 1)
self.bn = torch.nn.BatchNorm2d(3)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
user_out = torch.nn.functional.relu(x)
loss = user_out.sum()
return loss, user_out.detach()

mod = ConvBatchnormRelu()
mod.train()
inp = torch.randn(1, 1, 3, 3)
o_ref = mod(inp)
fx_g, signature = aot_export_module(mod, [inp], trace_joint=True, output_loss_index=0)
# Some important characteristics of the exported graph below:
# 8 arguments: 2 params from conv, 2 params from batchnorm, 3 buffers from batchnorm, 1 user input
# 9 outputs: 3 mutated buffers (from batchnorm), 2 user outputs, and 4 gradients (one per parameter)
self.assertExpectedInline(fx_g.print_readable(print_output=False), """\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_1: f32[3], arg4_1: f32[3], arg5_1: f32[3], arg6_1: i64[], arg7_1: f32[1, 1, 3, 3]):
# No stacktrace found for following nodes
convolution: f32[1, 3, 3, 3] = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg1_1 = None
add: i64[] = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); arg3_1 = arg4_1 = arg5_1 = None
getitem: f32[1, 3, 3, 3] = _native_batch_norm_legit_functional[0]
getitem_1: f32[3] = _native_batch_norm_legit_functional[1]
getitem_2: f32[3] = _native_batch_norm_legit_functional[2]
getitem_3: f32[3] = _native_batch_norm_legit_functional[3]
getitem_4: f32[3] = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem); getitem = None
detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu)
sum_1: f32[] = torch.ops.aten.sum.default(relu)
detach_1: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu)
detach_2: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(detach_1); detach_1 = None
ones_like: f32[] = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
expand: f32[1, 3, 3, 3] = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None
threshold_backward: f32[1, 3, 3, 3] = torch.ops.aten.threshold_backward.default(expand, relu, 0); expand = relu = None
native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None
getitem_5: f32[1, 3, 3, 3] = native_batch_norm_backward[0]
getitem_6: f32[3] = native_batch_norm_backward[1]
getitem_7: f32[3] = native_batch_norm_backward[2]; native_batch_norm_backward = None
convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]); getitem_5 = arg7_1 = arg0_1 = None
getitem_8 = convolution_backward[0]
getitem_9: f32[3, 1, 1, 1] = convolution_backward[1]
getitem_10: f32[3] = convolution_backward[2]; convolution_backward = None
return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7)
""") # noqa: B950


self.assertExpectedInline(str(signature.parameters), """['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']""")
self.assertExpectedInline(str(signature.buffers), """['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']""")
self.assertExpectedInline(str(signature.user_inputs), """['arg7_1']""")
self.assertExpectedInline(str(signature.inputs_to_parameters), """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""") # noqa: B950
self.assertExpectedInline(str(signature.inputs_to_buffers), """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""") # noqa: B950
self.assertExpectedInline(str(signature.buffers_to_mutate), """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 'add': 'bn.num_batches_tracked'}""") # noqa: B950
self.assertExpectedInline(str(signature.backward_signature.gradients_to_parameters), """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""") # noqa: B950
self.assertExpectedInline(str(signature.backward_signature.gradients_to_user_inputs), """{}""")
self.assertExpectedInline(str(signature.backward_signature.loss_output), """getitem_3""")

# Also check the inference graph.
# The main thing to note here is that there are 5 total outputs: 3 mutated buffers (from batchnorm) and 2 user outputs.
fx_g_inference, signature_inference = aot_export_module(mod, [inp], trace_joint=False)
self.assertExpectedInline(fx_g_inference.print_readable(print_output=False), """\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: f32[3, 1, 1, 1], arg1_1: f32[3], arg2_1: f32[3], arg3_1: f32[3], arg4_1: f32[3], arg5_1: f32[3], arg6_1: i64[], arg7_1: f32[1, 1, 3, 3]):
# No stacktrace found for following nodes
convolution: f32[1, 3, 3, 3] = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg7_1 = arg0_1 = arg1_1 = None
add: i64[] = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); convolution = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None
getitem: f32[1, 3, 3, 3] = _native_batch_norm_legit_functional[0]
getitem_3: f32[3] = _native_batch_norm_legit_functional[3]
getitem_4: f32[3] = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
relu: f32[1, 3, 3, 3] = torch.ops.aten.relu.default(getitem); getitem = None
sum_1: f32[] = torch.ops.aten.sum.default(relu)
detach: f32[1, 3, 3, 3] = torch.ops.aten.detach.default(relu); relu = None
return (getitem_3, getitem_4, add, sum_1, detach)
""") # noqa: B950
# As above, the 8 arguments are: 2 params from conv, 2 params from batchnorm, 3 buffers from batchnorm, 1 user input.
# There are no gradient outputs here, since trace_joint=False.
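
# Rough illustration of the calling convention implied by the graphs and
# signatures above (a sketch; the starred names below are placeholders for the
# corresponding tensors in signature order):
#
#   fx_g, sig = aot_export_module(mod, [inp], trace_joint=True, output_loss_index=0)
#   outs = fx_g(*params, *buffers, *user_inputs)
#   # outs == (*updated_buffers, *user_outputs, *parameter_gradients)
#
# With trace_joint=False, the gradient outputs are simply absent.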

def test_aot_export_simplified_basic(self):
def f(x, y):
return x * y, y * y.detach()

x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)

f_graph_fw = aot_export_joint_simple(f, [x, y], trace_joint=False)
out_ref = f(x, y)
# No calling convention changes necessary to invoke the traced graph
out_test = f_graph_fw(x, y)
self.assertEqual(out_ref, out_test)

# Now test the backward
x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)
x2 = x.clone().detach().requires_grad_(True)
y2 = y.clone().detach().requires_grad_(True)
x3 = x.clone().detach().requires_grad_(True)
y3 = y.clone().detach().requires_grad_(True)
f_graph_joint = aot_export_joint_simple(f, [x, y], trace_joint=True)
num_fw_outputs = 2
fw_g, bw_g = default_partition(f_graph_joint, [x, y], num_fwd_outputs=num_fw_outputs)
out_ref2 = f(x2, y2)
fw_outs = fw_g(x3, y3)
out_test2, activations = fw_outs[:num_fw_outputs], fw_outs[num_fw_outputs:]
self.assertEqual(out_ref2, out_test2)

# Test running the traced backward graph with a mocked-up grad_output
grad_outs = [torch.ones_like(x) for x in out_ref2]
grads_ref = torch.autograd.grad(out_ref2, [x2, y2], grad_outputs=grad_outs)
grads_test = bw_g(*activations, *grad_outs)
for g_ref, g_test in zip(grads_ref, grads_test):
self.assertEqual(g_ref, g_test)
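
# Summary of the convention exercised above: aot_export_joint_simple keeps the
# original calling convention for the forward, and after partitioning,
# fw_g(*inputs) returns the num_fw_outputs user outputs followed by the saved
# activations, while bw_g(*activations, *grad_outputs) returns one gradient per
# differentiable input.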

def test_aot_export_metadata_mutation_banned(self):
def fn(p, x):
x.t_()
return (x * 2,)
mod = TestMod(fn)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "Found an input that received a metadata mutation"
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
aot_export_module(mod, [inp], trace_joint=False)

def test_aot_export_input_mutation_on_parameter_banned(self):
def fn(p, x):
p.mul_(2)
return (p + x,)
mod = TestMod(fn)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "Found a graph input that requires gradients, and received a mutation"
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
aot_export_module(mod, [inp], trace_joint=False)

def test_aot_export_synthetic_bases_banned(self):
def fn(p, x, y):
x.mul_(2)
return (x + y,)
mod = TestMod(fn)
inp = torch.randn(2)
inp2 = inp.view(-1)
with self.assertRaisesRegex(
RuntimeError, "Encountered aliased inputs that are mutated"
):
aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=True)
aot_export_module(mod, [inp, inp2], trace_joint=False)

def test_aot_export_input_dupes_banned(self):
def fn(p, x, y):
x.mul_(2)
return (x + y,)
mod = TestMod(fn)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "Encountered duplicated inputs that are mutated in the graph"
):
aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=True)
aot_export_module(mod, [inp, inp], trace_joint=False)

def test_aot_export_multiple_outputs_require_grad_banned(self):
def fn(p, x):
out = p * x
return out, out.sum()
mod = TestMod(fn)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "Found an output of the forward that requires gradients, that was not"
):
aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1)

def test_aot_export_simplified_input_mutations_banned(self):
def fn(x):
x.mul_(2)
return (x + x,)
inp = torch.randn(2)
with self.assertRaisesRegex(
RuntimeError, "aot_export_joint_simple does not support input mutations"
):
aot_export_joint_simple(fn, [inp], trace_joint=False)
aot_export_joint_simple(fn, [inp], trace_joint=True)

def test_aot_export_simplified_pytrees_banned(self):
def fn(inps):
return (inps[0] + inps[1],)
inp1 = torch.randn(2)
inp2 = torch.randn(2)
inps = [inp1, inp2]
with self.assertRaisesRegex(
RuntimeError, "aot_export_joint_simple requires individual inputs not to be pytrees"
):
aot_export_joint_simple(fn, [inps], trace_joint=False)
aot_export_joint_simple(fn, [inps], trace_joint=True)

def test_aot_export_functionalized_rng_banned(self):
def fn(p, x):
return (p + x,)
mod = TestMod(fn)
inp = torch.randn(2)
with patch("functorch.compile.config.functionalize_rng_ops", True), self.assertRaisesRegex(
RuntimeError, "Functionalized RNG is not currently supported in the aot_export"
):
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False)
aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True)
aot_export_module(mod, [inp], trace_joint=False)


class TestPartitioning(AOTTestCase):
@unittest.skipIf(not USE_NETWORKX, "networkx not available")