diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index c41709baea02f..c7e20d2ad9baa 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -2929,11 +2929,8 @@ def fn(input, mask):
         expected = fn(*inputs1)
         actual = fn_opt(*inputs2)
         self.assertTrue(same(actual, expected))
-        self.assertEqual(dict(counters["frames"]), {"total": 2, "ok": 2})
-        self.assertEqual(
-            dict(counters["graph_break"]), {"autograd.Function with requires_grad": 1}
-        )
-        self.assertEqual(cnt.op_count, 6)
+        self.assertEqual(dict(counters["frames"]), {"total": 1, "ok": 1})
+        self.assertEqual(cnt.op_count, 1)
         self.assertEqual(cnt.frame_count, 1)
         cnt.clear()
         counters.clear()
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 6b573d279e0e2..4884ec3497159 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -561,7 +561,7 @@ def inline_user_function_return(self, fn, args, kwargs):
             result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
             self.output.guards.update(fn.guards)
             return result
-        except Exception:
+        except Exception as e:
             self.restore_graphstate(state)
             raise
 
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 8c42980627284..670f61d0cf42c 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -224,8 +224,16 @@ def visit(node):
         VariableTracker.apply(visit, (args, kwargs))
 
         if requires_grad and torch.is_grad_enabled():
-            # TODO(jansel): handle this in training mode
-            unimplemented("autograd.Function with requires_grad")
+            from .torch import TorchHigherOrderOperator
+
+            def trampoline_autograd_fn(*args, **kwargs):
+                return self.fn_cls.apply(*args, **kwargs)
+
+            # Speculate fwd
+            # TODO(voz): Check bwd soundness, or something, I dunno, bug Horace
+            # TODO(voz): NOTE: This is unguarded, but the odds of someone swapping self.fn_cls from autograd fn to something else
+            # is very low. We can add guarding before we ship this PR.
+            return TorchHigherOrderOperator(trampoline_autograd_fn).call_function(tx, args, kwargs)
 
         args = [AutogradFunctionContextVariable.create_for_inference(tx), *args]
         options = VariableTracker.propagate(self, args, kwargs.values())
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 6712a4def4e10..05b763de98c69 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -918,10 +918,13 @@ def speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint):
             args = []
             # One argument to graph per sub_args
             for a in sub_args:
-                if isinstance(a, TensorVariable):
+                if isinstance(a, ConstantVariable):
+                    proxy = tracer.create_graph_input("const")
+                    args.append(a)
+                elif isinstance(a, TensorVariable):
                     tracer.create_graph_input(a.as_proxy().node.name)
                     args.append(a)
-                else:
+                elif isinstance(a, torch.Tensor):
                     # call_function() needs a TensorVariable, therefore we construct
                     # one with inner graph proxy.
                     assert isinstance(a, torch.Tensor)
@@ -929,8 +932,11 @@ def speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint):
                     args.append(
                         wrap_fx_proxy(tx=tx, proxy=proxy, example_value=a)
                     )
+                else:
+                    raise unimplemented("Speculate subgraph with unsupported inputs.")
 
             output = f.call_function(tx, args, {})
+            # breakpoint()
             # Register output to graph
             # Modeled off of compile_and_call_fx_graph
             # TODO: support non single Tensor output
@@ -1177,6 +1183,31 @@ def speculate_branch(branch):
             )
             r = body_r.as_proxy().node.meta["example_value"]
             example_value = r
+        elif self.value.__name__ == "trampoline_autograd_fn":
+            fn = TorchVariable(
+                self.value
+            )
+
+            checkpoint = tx.copy_graphstate()
+            graph_checkpoint = tx.output.graph
+            (
+                body_r,
+                body_graph,
+                body_lifted_freevars,
+            ) = speculate_subgraph(
+                fn,
+                [
+                    *args,
+                ],
+                graph_checkpoint,
+                checkpoint,
+            )
+            p_args = (
+                *(arg.as_proxy() for arg in args),
+                *(arg for arg in body_lifted_freevars),
+            )
+            r = body_r.as_proxy().node.meta["example_value"]
+            example_value = r
         else:
             unimplemented(f"HigherOrderOperator {self.value.__name__}")
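
Reviewer note: the test_repros.py hunk now expects a requires_grad autograd.Function to compile as a single frame with no "autograd.Function with requires_grad" graph break, since the forward is speculated through the trampoline higher-order op instead of falling back. Below is a minimal repro sketch of that behavior on a build with this patch applied; MulBy2 and fn are illustrative names, not taken from the test suite, and exact op counts will differ for this toy function.

    import torch
    import torch._dynamo as dynamo
    from torch._dynamo.testing import CompileCounter
    from torch._dynamo.utils import counters


    class MulBy2(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x * 2

        @staticmethod
        def backward(ctx, grad_out):
            # d(x * 2)/dx == 2, so no saved tensors are needed
            return grad_out * 2


    def fn(x):
        return MulBy2.apply(x)


    cnt = CompileCounter()
    opt_fn = dynamo.optimize(cnt)(fn)
    opt_fn(torch.randn(4, requires_grad=True))

    # With the trampoline path, expect one compiled frame and an empty
    # graph_break counter instead of the old requires_grad fallback.
    print(cnt.frame_count, dict(counters["graph_break"]))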