
[feature request] Forward-mode automatic differentiation #10223

Open
krishnap25 opened this issue Aug 3, 2018 · 13 comments

Comments

@krishnap25 krishnap25 commented Aug 3, 2018

Thanks for the awesome library! It would be great if PyTorch could support forward-mode automatic differentiation. The main use case is to compute a Jacobian-vector product. I tried using this trick that simulates forward-mode autodiff by running reverse-mode twice, but it causes my GPU to run out of memory with AlexNet. HIPS/autograd supports this operation, and it would be really nice if PyTorch could as well. Thanks!

@jeisner jeisner commented Aug 17, 2018

Yes! Another use case is to compute Hessian-vector products, which are useful in nested optimization such as SPEN, as well as in second-order optimizers such as stochastic meta-descent.

Pearlmutter (1994) gives the best algorithm. You want the Hessian of f at x times the vector v. First compute ∇f(x + ηv) (e.g., by backprop), where the scalar η represents the size of a step in direction v. Then apply forward-mode AD to that computation to determine the sensitivity of that gradient vector to η, evaluated at η = 0; that sensitivity is exactly Hv.

This is computing the partials of an output vector ∇f(x + ηv) with respect to an input scalar η, which is exactly the situation where forward-mode AD is O(n) times faster than reverse mode. Reverse mode computes a whole vector at each node of the computation graph, whereas forward mode computes only a scalar at each node.

(Admittedly, there is a simple alternative approach that doesn't require forward-mode AD, but that approach uses the finite-difference method, so it will suffer a little numerical error without being any faster, and it requires you to pick a magic stepsize. So it would be neat if PyTorch could do it right.)
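
A minimal sketch of such a Hessian-vector product using only the double-backward support PyTorch already has (the helper name hvp is hypothetical; a true forward mode would replace the second reverse pass):

import torch

def hvp(f, x, v):
    # Hessian-vector product H(x) @ v via reverse-over-reverse:
    # differentiate <grad f(x), v> with respect to x.
    x = x.detach().requires_grad_()
    (grad_f,) = torch.autograd.grad(f(x), x, create_graph=True)
    (hv,) = torch.autograd.grad(grad_f, x, grad_outputs=v)
    return hv

# Example: f(x) = sum(x**3), so H = diag(6*x) and H @ v = 6 * x * v.
x = torch.randn(5)
v = torch.randn(5)
print(hvp(lambda t: (t ** 3).sum(), x, v))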

@ezyang ezyang commented Jun 23, 2019

Some implementation notes, based on scribbles in my notebook.

The existing derivatives.yaml can stay. It's not necessary to write "forward mode" derivatives for all of the functions; the backward functions (really, their "transposes", but you will see what this means shortly) suffice.

As an example, consider cross product, whose reverse-mode derivatives are currently defined as:

- name: cross(Tensor self, Tensor other, int? dim=None) -> Tensor
  self: other.cross(grad, dim)
  other: grad.cross(self, dim)

The forward mode derivative is formed by computing each constituent derivative with the respective derivatives of the arguments, and then summing all the subexpressions together:

def cross_forward_ad(self, dself, other, dother, dim):
  return cross(self, other), other.cross(dself, dim) + dother.cross(self, dim)

You can verify that this formula is indeed what you would expect given that cross product is distributive over addition. (The "transposition" that happened here refers to the fact that in the reverse mode, we took a single grad input and produced two outputs; in forward mode, we took two grad inputs and produced a single output. But computationally, we have done essentially the same work: the key is that transposition is not an inverse!)

(It might still be profitable, however, to write more fused derivatives in some situations...)

EDIT: This is not right in the general case when inputs and outputs have different shapes, I think.
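
For the cross product specifically, the product rule d(a × b) = da × b + a × db gives the JVP directly; a minimal sketch, assuming tensors whose last dimension has size 3 (the helper name is hypothetical):

import torch

def cross_jvp(a, da, b, db, dim=-1):
    # Product rule for the cross product: d(a x b) = da x b + a x db
    primal = torch.cross(a, b, dim=dim)
    tangent = torch.cross(da, b, dim=dim) + torch.cross(a, db, dim=dim)
    return primal, tangent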

What should the API be? The easiest API is a HIPS/autograd- or Tangent-style API which takes a function and returns a new function that also computes the derivative of the output, given differentials for the inputs.

df = torch.autograd.forward_grad(f)
# if f(x, y) returns r, then
# df(x, dx, y, dy) returns r, dr

However, this isn't very consistent with the existing APIs we provide in torch.autograd, where we don't generally expect users to wrap up their models in a function and then pass it into a higher-order function we provide.

A more "idiomatic" API would be to follow the lead of requires_grad and allow users to associate derivatives with variables, which then get propagated in the forward style.

dx = torch.randn(2, 3)
x = torch.randn(2, 3, derivative=dx)

dy = torch.randn(3, 4)
y = torch.randn(3, 4, derivative=dy)

z = f(x, y)
dz = z.derivative

derivative should probably not coincide with grad, because if you have derivatives set you want them to propagate by default (grad doesn't have anything to do with propagation; that's requires_grad's job). Another benefit of this approach is that it is more obvious how to implement forward-mode propagation.

@apaszke apaszke commented Jun 23, 2019

Another option would be to have a new tensor type that represents dual numbers over the base type.

Also, there's no reason why we couldn't write a more HIPS-like closure-based API (for both forward and reverse mode) if that ends up being much easier conceptually.
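
For illustration, a minimal pure-Python sketch of dual numbers (names hypothetical; a real implementation would live at the tensor/dispatch level), showing how a tangent is carried alongside the value:

class Dual:
    # A dual number value + eps * tangent, with eps**2 == 0.
    def __init__(self, value, tangent):
        self.value = value
        self.tangent = tangent

    def __add__(self, other):
        return Dual(self.value + other.value, self.tangent + other.tangent)

    def __mul__(self, other):
        # (a + eps*da) * (b + eps*db) = a*b + eps*(a*db + da*b)
        return Dual(self.value * other.value,
                    self.value * other.tangent + self.tangent * other.value)

# Forward-mode derivative of x**2 + x at x = 3 is 2*3 + 1 = 7.
x = Dual(3.0, 1.0)
y = x * x + x
print(y.value, y.tangent)  # 12.0 7.0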

@ezyang ezyang commented Jun 23, 2019

Actually, the change in formula is not as simple as I thought. Consider matrix multiply:

name: mm(mat1, mat2)
mat1: grad.mm(mat2.t())
mat2: mat1.t().mm(grad)

The correct forwards mode formula is:

dmat1.mm(mat2) + mat1.mm(dmat2)

This is closely related to the original form (structurally, it's identical) but the captured variables were transposed.
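
Spelled out as a (hypothetical) helper, this JVP pairs the primal output with the tangent above:

import torch

def mm_jvp(mat1, dmat1, mat2, dmat2):
    # Product rule for matrix multiply: d(A @ B) = dA @ B + A @ dB
    return mat1.mm(mat2), dmat1.mm(mat2) + mat1.mm(dmat2)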

@zou3519 zou3519 commented Oct 28, 2019

It is possible to automatically get the result of the forward-mode formulas (JVP formulas) from derivatives.yaml (which has VJP formulas) by using the autograd trick described at https://j-towns.github.io/2017/06/12/A-new-trick.html and invoking autograd twice.

That is, for a given function f, its forward-mode result can be computed via two calls to autograd.grad (working Python code below):

import torch

def jvp(f, x, u):
    # Compute the Jacobian-vector product J_f(x) @ u using two reverse-mode
    # passes (the "double backward" trick).
    x = x.detach().requires_grad_()
    u = u.detach().requires_grad_()
    fx = f(x)
    if isinstance(fx, torch.Tensor):
        v = torch.ones_like(fx, requires_grad=True)
    else:
        v = [torch.ones_like(fxi, requires_grad=True) for fxi in fx]
    # First pass: the VJP v^T J, kept differentiable so it can be
    # differentiated with respect to v.
    vjp = torch.autograd.grad(fx, x, grad_outputs=v, create_graph=True)
    # Second pass: differentiate the VJP with respect to v in the
    # direction u, which yields J @ u.
    output = torch.autograd.grad(vjp, v, grad_outputs=u)
    return output

# one example
f = lambda x: torch.chunk(x, 2, 0)
x = torch.randn(2, 3)
u = torch.randn_like(x)
jvp(f, x, u)

In fact, we could "get rid" of the first call to autograd.grad, because that is just the reverse-mode formula for f. This example works for functions that take exactly one input, but it can probably be generalized (I haven't thought too much about that).

I'm not sure whether we can generate all forward-mode formulas from just derivatives.yaml without calling into autograd; it's easy to see the transpose in matrix multiplication, and pointwise ops can reuse the same formulas, but I'm not sure about more complicated things. If we can't, then the double-grad trick can help implement a first version of the feature.

@ezyang ezyang commented Oct 29, 2019

Note that you don't actually want to use the double-grad trick at runtime, as it is quite a bit slower than computing the forward-mode derivatives directly.

@RylanSchaeffer RylanSchaeffer commented Nov 23, 2019

@ezyang do you have a good reference/tutorial for a "normal" way to compute the end-to-end Jacobian?

@zou3519 zou3519 commented Dec 2, 2019

The best way to compute a full Jacobian matrix right now is to run backward through the graph multiple times. For example, given a function f that maps R^n to R^m, we can compute the Jacobian as follows:

import torch

def unit_vectors(length):
    # Standard basis vectors e_0, ..., e_{length - 1}.
    result = []
    for i in range(length):
        x = torch.zeros(length)
        x[i] = 1
        result.append(x)
    return result

x = torch.randn(3, requires_grad=True)
y = f(x)
# One backward pass per output component; each pass produces one row of the Jacobian.
result = [
    torch.autograd.grad(outputs=[y], inputs=[x], grad_outputs=[unit], retain_graph=True)[0]
    for unit in unit_vectors(y.size(0))
]
jacobian = torch.stack(result, dim=0)

Each grad call computes one row of the Jacobian matrix: (∂f_i/∂x_0, ∂f_i/∂x_1, ..., ∂f_i/∂x_{n-1}).

@RylanSchaeffer RylanSchaeffer commented Dec 2, 2019

Why are multiple grad calls necessary? Shouldn't one be sufficient?

@RylanSchaeffer RylanSchaeffer commented Dec 2, 2019

@zou3519 also, where do you use the unit_vectors() function?

@zou3519 zou3519 commented Dec 2, 2019

Why are multiple grad calls necessary? Shouldn't one be sufficient?

Let's move the discussion to https://discuss.pytorch.org/; the question of how to compute the full Jacobian is orthogonal to this feature request for forward-mode AD.

@RylanSchaeffer RylanSchaeffer commented Dec 2, 2019

@zou3519 I already posted. Let me link the post.
