
[WIP] Autodifferentiable scripts #22582

Closed
wants to merge 4 commits

Conversation

t-vi
Collaborator

t-vi commented Jul 8, 2019

This patch intends to enable auto-differentiated (i.e. source-to-source differentiated) script functions:

@torch.jit.autograd_script
def my_function(x):
    def backward(grad_out):
        return 99 * x * grad_out  # not the true gradient, mind you!
    return x**2, backward

@torch.jit.script
def test(x):
    res = my_function(x)
    return res

There is some discussion in #22329.
The interface was largely taken from symbolic_script.cpp.

The state of this patch is very raw:

  • It needs tests!
  • Needs more source comments.
  • It seems to handle multiple inputs and outputs. I have not fully checked whether it allows all combinations and types.
  • Some things (e.g. the number of grad_ins returned not matching the number of inputs) currently run into internal asserts rather than being caught by a TORCH_CHECK up front.
  • The code organization (cough): the compilation should likely take place in compiler.cpp rather than init.cpp.
  • There are various debatable aspects of the design ("what is done when, and by whom"). Some things at the top of my mind (but there will be others):
    • We need to keep track of the function and also keep the things captured for the backward from being cut by DCE. Currently, I keep the function call (prim::CallAutogradFunction) around until the autodiff graphs are created or until the GraphExecutor decides that it does not need gradients for certain bits. I'm not quite sure whether this causes us to miss out on passes / optimizations we want.
    • One annoying thing we need to deal with is that Python functions only have one return value, and multiple return values are wrapped in tuples. At the same time, we only want tensors or tensor lists, but not tuples containing tensors, as inputs/outputs of DifferentiableGraphs, because DifferentiableGraphBackward cannot handle them. This means that we want to get the prim::TupleUnpack after an autograd function into the DifferentiableGraph, but not at the top of a DifferentiableGraph (because that would give the backward a tuple output). Maybe in the end the right thing would be to teach DifferentiableGraphBackward about tuples.

I plan to clean this up over the next couple of weeks, but if you want to weigh in on the fundamental design issues, I would appreciate your input.

@pytorchbot added the oncall: jit (Add this issue/PR to JIT oncall triage queue), module: internals (Related to internal abstractions in c10 and ATen), and module: pybind (Related to our Python bindings / interactions with other Python libraries) labels on Jul 8, 2019
@soumith
Member

soumith commented Jul 10, 2019

I plan to clean this up over the next couple of weeks, but if you want to weigh in on the fundamental design issues, I would appreciate your input.

@zdevito and @driazati (and @apaszke )

@zdevito
Contributor

zdevito commented Jul 12, 2019

I haven't had time to look into the design in detail, but the premise seems good. However, one concern I have is that it violates the "everything in script works in eager mode" invariant. Here we have something that cannot be debugged in eager mode with the same behavior. We rely on this invariant for a good debugging experience, so this concerns me.

@t-vi
Collaborator Author

t-vi commented Jul 12, 2019

@zdevito I agree. So should I try to come up with a decorator that keeps the evaluation in Python itself? In the previous version (https://github.com/t-vi/pytorch/tree/autograd_script) I implemented JITing of the current autograd.Function, but I'm not sure I like my ctx elimination or can think of good alternatives.
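
For context, a purely illustrative sketch of what that rewrite amounts to (the toy example and names are mine, not from the branch): the state that a classic autograd.Function threads through ctx is instead captured by the backward closure, which is what the ctx elimination has to recover.

import torch

# Classic ctx-based autograd.Function ...
class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # state is threaded through ctx
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 2 * x * grad_out

# ... and the equivalent closure form this PR scripts: the state that would
# have gone through ctx is simply captured by the inner backward closure.
def square(x):
    def backward(grad_out):
        return 2 * x * grad_out
    return x ** 2, backward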

@dzhulgakov
Collaborator

This is nice!

@zdevito - I guess it might work out if we have a decorator in eager mode that turns that style of function into an autograd.Function. Then all we'd need to do for debugging is to replace @torch.jit.autograd_script with @torch.autograd.func or something like that.

@t-vi
Collaborator Author

t-vi commented Jul 13, 2019

So here is a classic PyTorch closure-based autograd_function decorator:

import torch

class MetaAutogradFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fn, *inputs):
        results = fn(*inputs)
        # the wrapped function returns (output, ..., backward_closure)
        ctx.backward_fn = results[-1]
        if len(results) == 2:
            return results[0]   # single output
        return results[:-1]     # multiple outputs

    @staticmethod
    def backward(ctx, *grad_out):
        grad_in = ctx.backward_fn(*grad_out)
        # the leading None is the gradient for the (non-differentiable) fn argument
        if not isinstance(grad_in, tuple):
            return None, grad_in
        return (None,) + grad_in


def autograd_function(fn):
    """Decorator making an autograd.Function from a function returning a backward closure."""
    def autograd_fn(*inputs):
        return MetaAutogradFn.apply(fn, *inputs)
    return autograd_fn

Use:

@autograd_function
def testfn(x):
    y = torch.arange(5, 10.)
    def backward(grad_in):
        return x  # not the true gradient of x**2, just for illustration
    return x**2, backward

x = torch.arange(5., requires_grad=True)
res = testfn(x)
print(res, torch.autograd.grad(res.sum(), x))

In order to implement the sanity checks offered by save_for_backward, we might record the relevant state from backward.__closure__ and check it in the backward.
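
A minimal sketch of what that could look like (the closure inspection and the MetaAutogradFnChecked name are assumptions for illustration, not something implemented anywhere):

import torch

class MetaAutogradFnChecked(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fn, *inputs):
        results = fn(*inputs)
        backward_fn = results[-1]
        # Record the tensors captured by the backward closure so that the usual
        # save_for_backward bookkeeping (e.g. detecting in-place modification
        # between forward and backward) applies to them as well.
        captured = tuple(
            cell.cell_contents
            for cell in (backward_fn.__closure__ or ())
            if isinstance(cell.cell_contents, torch.Tensor)
        )
        ctx.save_for_backward(*captured)
        ctx.backward_fn = backward_fn
        if len(results) == 2:
            return results[0]
        return results[:-1]

    @staticmethod
    def backward(ctx, *grad_out):
        ctx.saved_tensors  # accessing this triggers the version-counter checks
        grad_in = ctx.backward_fn(*grad_out)
        if not isinstance(grad_in, tuple):
            return None, grad_in
        return (None,) + grad_in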

@ezyang
Contributor

ezyang commented Jul 15, 2019

cc @malvika2147; with this and the Python-level decorator, the last remaining piece is C++ autograd function support, and that's what Malvika is working on.

@t-vi
Collaborator Author

t-vi commented Jan 5, 2020

This is a cool function, but I'm economizing for now.

@guillaumeBellec

guillaumeBellec commented Oct 14, 2020

Is there a way to use an experimental version of that?
I tried compiling the branch t-vi:autograd_script2 but it returns an error during compilation.

@t-vi
Collaborator Author

t-vi commented Oct 18, 2020

There is an easy variant that does the same thing for the user but uses autograd instead of autodiff.

@guillaumeBellec

Could you be more explicit or point to a piece of code?
