[Experimental] [Needs more text] Rewrite autograd function w/ grad as a TorchVariable and simulate user invoked allow_in_graph, 2x faster Deberta training

ghstack-source-id: b6a20e9bb884150c10b4f5d205f3591d4426835e
Pull Request resolved: #99483

rm crap

Make it nice

cleanup

Fix test, source shenanigans
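
For context, a minimal sketch of the scenario this commit targets, under assumed names and shapes (the Deberta module from the benchmark is not reproduced here; ScaleByTwo, fn, and the tensor size are illustrative only): a user-defined torch.autograd.Function whose inputs require grad, called from a torch.compile'd function. Before this change Dynamo reported the graph break "autograd.Function with requires_grad" (see the removed assertions in test_repros.py below); with it, the call is speculated through the new TorchHigherOrderOperator path and stays in a single graph.

import torch

class ScaleByTwo(torch.autograd.Function):
    # Illustrative autograd.Function; not the one used in the Deberta benchmark.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2

@torch.compile(backend="eager")
def fn(x):
    return ScaleByTwo.apply(x).sum()

x = torch.randn(8, requires_grad=True)
fn(x).backward()  # previously: graph break "autograd.Function with requires_grad"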
voznesenskym committed Apr 28, 2023
1 parent 24935e5 commit 22494c4
Showing 4 changed files with 46 additions and 10 deletions.
7 changes: 2 additions & 5 deletions test/dynamo/test_repros.py
@@ -2929,11 +2929,8 @@ def fn(input, mask):
         expected = fn(*inputs1)
         actual = fn_opt(*inputs2)
         self.assertTrue(same(actual, expected))
-        self.assertEqual(dict(counters["frames"]), {"total": 2, "ok": 2})
-        self.assertEqual(
-            dict(counters["graph_break"]), {"autograd.Function with requires_grad": 1}
-        )
-        self.assertEqual(cnt.op_count, 6)
+        self.assertEqual(dict(counters["frames"]), {"total": 1, "ok": 1})
+        self.assertEqual(cnt.op_count, 1)
         self.assertEqual(cnt.frame_count, 1)
         cnt.clear()
         counters.clear()
2 changes: 1 addition & 1 deletion torch/_dynamo/symbolic_convert.py
@@ -561,7 +561,7 @@ def inline_user_function_return(self, fn, args, kwargs):
             result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
             self.output.guards.update(fn.guards)
             return result
-        except Exception:
+        except Exception as e:
             self.restore_graphstate(state)
             raise

12 changes: 10 additions & 2 deletions torch/_dynamo/variables/misc.py
@@ -224,8 +224,16 @@ def visit(node):
         VariableTracker.apply(visit, (args, kwargs))
 
         if requires_grad and torch.is_grad_enabled():
-            # TODO(jansel): handle this in training mode
-            unimplemented("autograd.Function with requires_grad")
+            from .torch import TorchHigherOrderOperator
+
+            def trampoline_autograd_fn(*args, **kwargs):
+                return self.fn_cls.apply(*args, **kwargs)
+
+            # Speculate fwd
+            # TODO(voz): Check bwd soundness, or something, I dunno, bug Horace
+            # TODO(voz): NOTE: This is unguarded, but the odds of someone swapping self.fn_cls from autograd fn to something else
+            # is very low. We can add guarding before we ship this PR.
+            return TorchHigherOrderOperator(trampoline_autograd_fn).call_function(tx, args, kwargs)
 
         args = [AutogradFunctionContextVariable.create_for_inference(tx), *args]
         options = VariableTracker.propagate(self, args, kwargs.values())
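
The "simulate user invoked allow_in_graph" wording in the title refers to the effect this trampoline achieves automatically: roughly what a user could previously approximate by hand by marking a wrapper around the apply call as allowed in the graph. A rough sketch of that manual pattern, assuming a hypothetical MyFunc autograd.Function and using the existing torch._dynamo.allow_in_graph API (this snippet is not part of the commit):

import torch
import torch._dynamo

class MyFunc(torch.autograd.Function):
    # Hypothetical stand-in; any custom autograd.Function plays the same role here.
    @staticmethod
    def forward(ctx, x):
        return x + 1

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

def call_my_func(x):
    return MyFunc.apply(x)

# Ask Dynamo to keep this call in the graph instead of breaking on it;
# the trampoline_autograd_fn path above simulates something analogous for fn_cls.apply.
torch._dynamo.allow_in_graph(call_my_func)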
35 changes: 33 additions & 2 deletions torch/_dynamo/variables/torch.py
@@ -918,19 +918,25 @@ def speculate_subgraph(f, sub_args, graph_checkpoint, checkpoint):
         args = []
         # One argument to graph per sub_args
         for a in sub_args:
-            if isinstance(a, TensorVariable):
+            if isinstance(a, ConstantVariable):
+                proxy = tracer.create_graph_input("const")
+                args.append(a)
+            elif isinstance(a, TensorVariable):
                 tracer.create_graph_input(a.as_proxy().node.name)
                 args.append(a)
-            else:
+            elif isinstance(a, torch.Tensor):
                 # call_function() needs a TensorVariable, therefore we construct
                 # one with inner graph proxy.
                 assert isinstance(a, torch.Tensor)
                 proxy = tracer.create_graph_input("arg")
                 args.append(
                     wrap_fx_proxy(tx=tx, proxy=proxy, example_value=a)
                 )
+            else:
+                raise unimplemented("Speculate subgraph with unsupported inputs.")
 
         output = f.call_function(tx, args, {})
+        # breakpoint()
         # Register output to graph
         # Modeled off of compile_and_call_fx_graph
         # TODO: support non single Tensor output
@@ -1177,6 +1183,31 @@ def speculate_branch(branch):
             )
             r = body_r.as_proxy().node.meta["example_value"]
             example_value = r
+        elif self.value.__name__ == "trampoline_autograd_fn":
+            fn = TorchVariable(
+                self.value
+            )
+
+            checkpoint = tx.copy_graphstate()
+            graph_checkpoint = tx.output.graph
+            (
+                body_r,
+                body_graph,
+                body_lifted_freevars,
+            ) = speculate_subgraph(
+                fn,
+                [
+                    *args,
+                ],
+                graph_checkpoint,
+                checkpoint,
+            )
+            p_args = (
+                *(arg.as_proxy() for arg in args),
+                *(arg for arg in body_lifted_freevars),
+            )
+            r = body_r.as_proxy().node.meta["example_value"]
+            example_value = r
         else:
             unimplemented(f"HigherOrderOperator {self.value.__name__}")
 
