[WIP] Autodifferentiable scripts #22582
Conversation
I haven't had time to look into the design in detail, but the premise seems good. However, one concern I have is that it violates the "everything in script works in eager mode" invariant. Here we have something that cannot be debugged in eager mode with the same behavior. We rely on this invariant for a good debugging experience, so this concerns me.
@zdevito I agree. So should I try to come up with a decorator that keeps the evaluation in Python itself? In the previous version (https://github.com/t-vi/pytorch/tree/autograd_script) I implemented JITing of the current autograd.Function, but I'm not sure I like my ctx elimination or can think of good alternatives.
This is nice! @zdevito - I guess it might work out if we have a decorator in eager mode that turns that style of function into an autograd.Function. Then all we'd need to do for debugging is to replace @torch.jit.autograd_script with @torch.autograd.func or something like that.
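In spirit, that switch could look something like the sketch below; the autograd_script name and the DEBUG_EAGER flag are made up for illustration, and only the closure-returning function style comes from this thread:

```python
import torch

DEBUG_EAGER = True  # hypothetical flag: run eagerly so the function can be stepped through

def _eager_autograd_wrapper(fn):
    # Minimal eager fallback: run fn, keep the backward closure it returns,
    # and route gradients through a plain autograd.Function.
    # (Assumes a single tensor output for brevity.)
    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *inputs):
            out, backward = fn(*inputs)
            ctx.backward_fn = backward
            return out

        @staticmethod
        def backward(ctx, *grad_out):
            grads = ctx.backward_fn(*grad_out)
            return grads if isinstance(grads, tuple) else (grads,)

    return _Wrapped.apply


def autograd_script(fn):
    """Hypothetical decorator: eager autograd.Function while debugging,
    JIT source-to-source autodiff (what this PR proposes) otherwise."""
    if DEBUG_EAGER:
        return _eager_autograd_wrapper(fn)
    # The scripted source-to-source path is exactly what this PR would add;
    # it is left unimplemented in this sketch.
    raise NotImplementedError("scripted autograd_script path")
```

With DEBUG_EAGER flipped off, the same decorated function would go through the JIT path instead, which is what would keep eager debugging available.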
So here is a classic PyTorch closure-based approach:

```python
import torch

class MetaAutogradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fn, *inputs):
        # fn returns its output(s) plus a backward closure as the last element
        results = fn(*inputs)
        ctx.backward_fn = results[-1]
        if len(results) == 2:
            return results[0]
        return results[:-1]

    @staticmethod
    def backward(ctx, *grad_out):
        grad_in = ctx.backward_fn(*grad_out)
        # the leading None is the gradient slot for the fn argument
        if not isinstance(grad_in, tuple):
            return None, grad_in
        return (None,) + grad_in


def autograd_function(fn):
    """Decorator to make an autograd.Function from a function returning a backward closure."""
    def autograd_fn(*inputs):
        return MetaAutogradFn.apply(fn, *inputs)
    return autograd_fn
```

Use:

```python
@autograd_function
def testfn(x):
    y = torch.arange(5, 10.)
    def backward(grad_in):
        return 2 * x * grad_in  # gradient of x**2
    return x**2, backward

x = torch.arange(5., requires_grad=True)
res = testfn(x)
print(res, torch.autograd.grad(res.sum(), x))
```

In order to implement the sanity checks offered by …
cc @malvika2147; with this and the Python-level decorator, the last remaining piece is C++ autograd function support, and that's what Malvika is working on.
This is a cool function, but I'm economizing for now.
Is there a way to use an experimental version of that? |
There is an easy variant that does the same thing for the user but uses autograd instead of autodiff. |
Could you be more explicit or point to a piece of code? |
This patch intends to enable auto-differentiated (i.e. source-to-source diff) script functions. There is some discussion in #22329. The interface was largely taken from symbolic_script.cpp; a rough sketch of the intended style follows.
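As a hypothetical sketch (the @torch.jit.autograd_script spelling is only the one floated in the discussion, not an existing API), the intended style is a function that returns its forward result together with a backward closure, mirroring the hand-written formulas in symbolic_script.cpp:

```python
import torch

# Hypothetical sketch of the intended style. With the proposed decorator it
# would read:
#
#   @torch.jit.autograd_script
#   def my_square(x):
#       ...
#
# The function computes its forward result and returns it together with a
# backward closure that maps the incoming gradient to gradients of the inputs.
def my_square(x):
    y = x * x

    def backward(grad_out):
        # gradient of x * x with respect to x
        return 2 * x * grad_out

    return y, backward
```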
The state of this patch is very raw:

- User errors (e.g. the number of grad_ins returned not matching the number of inputs) currently run into internal asserts rather than being caught in TORCH_CHECK at the beginning.
- Part of the implementation is in compiler.cpp rather than init.cpp.
- We keep the prim::CallAutogradFunction around until the autodiff graphs are created or the GraphExecutor decides that it does not need gradients for certain bits. I'm not quite sure whether this causes us to miss out on passes / optimizations we want.
- Tuples have to stay out of DifferentiableGraphs because the DifferentiableGraphBackward cannot handle them. This means that we want to get the prim::UnpackTuple after an autograd function into the DifferentiableGraph, but not at the top of a DifferentiableGraph (because that would have a tuple output of the backward). Maybe in the end the right thing would be to teach DifferentiableGraphBackward about tuples.

I plan to clean this up over the next couple of weeks, but if you want to weigh in on the fundamental design issues, I would appreciate your input.