Batch Jacobian like tf.GradientTape #23475

Open
Gsunshine opened this issue Jul 27, 2019 · 16 comments

Comments

@Gsunshine Gsunshine commented Jul 27, 2019

🚀 Feature

We hope to get a parallel implementation of the batched Jacobian, like TensorFlow's, e.g.

import tensorflow as tf
from tensorflow.python.ops.parallel_for.gradients import jacobian

# pfor-based full Jacobian
jac = jacobian(y, x)

# or the batched variant on a GradientTape
with tf.GradientTape() as g:
  x = tf.placeholder(tf.float32, shape=(None, 3, 32, 32))
  g.watch(x)
  y = f(x)

batch_jacobian = g.batch_jacobian(y, x)
@colesbury colesbury commented Jul 29, 2019

@Gsunshine what does parallel mean in this context?

No one is currently working on this, but I think we'd happily accept a PR.

@colesbury colesbury added the triaged label Jul 29, 2019
@Gsunshine Gsunshine commented Jul 30, 2019

A toy implementation embedded in Python:

import numpy as np
import torch

def jacobian_in_batch(y, x):
    '''
    Compute the Jacobian matrix in batch form.
    Return (B, D_y, D_x)
    '''
    batch = y.shape[0]
    single_y_size = np.prod(y.shape[1:])
    y = y.view(batch, -1)
    vector = torch.ones(batch).to(y)

    # Compute the Jacobian row by row:
    # each grad call gives d y[:, i] / d x for the whole batch.
    # (B, D_x) rows stacked along dim=1 -> (B, D_y, D_x)
    jac = [torch.autograd.grad(y[:, i], x,
                               grad_outputs=vector,
                               retain_graph=True,
                               create_graph=True)[0].view(batch, -1)
           for i in range(single_y_size)]
    jac = torch.stack(jac, dim=1)

    return jac

However, it relies on a Python for loop with O(D_y) iterations, built on reverse-mode AD through the torch.autograd.grad API. With minimal modification to the PyTorch core, it would be reasonable to compute the Jacobian matrix via a parallel for loop, like pfor in TensorFlow. Otherwise we would need to change the autograd mechanism in PyTorch, which is expensive and unreasonable.
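One way to avoid the Python loop entirely, at the cost of D_y times the memory, is to replicate the batch once per output dimension and do a single backward pass (a rough sketch, similar in spirit to the gist linked later in this thread; it assumes f maps (N, *S) -> (N, D_y) and treats samples independently):

import torch

def batch_jacobian_repeat(f, x, d_y):
    # Returns (B, d_y, D_x) with D_x = prod of x.shape[1:].
    b = x.shape[0]
    # (B, *S) -> (d_y * B, *S): one full copy of the batch per output dimension.
    x_rep = x.detach().repeat(d_y, *([1] * (x.dim() - 1))).requires_grad_(True)
    y_rep = f(x_rep)                                           # (d_y * B, d_y)
    # Copy i of the batch gets the one-hot vector e_i as its upstream gradient.
    grad_out = torch.eye(d_y, dtype=y_rep.dtype,
                         device=y_rep.device).repeat_interleave(b, dim=0)
    y_rep.backward(grad_out)
    return x_rep.grad.reshape(d_y, b, -1).transpose(0, 1)      # (B, d_y, D_x)

This trades O(D_y) sequential grad calls for one backward over a D_y-times-larger batch, so it only helps when the extra memory and the larger forward pass are acceptable.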

The question is how difficult it is to add pfor to PyTorch. It does not seem easy, judging from the implementation in TensorFlow, but I think it is a desirable feature. If the complexity can be well estimated, I will implement it and send a PR.

Any suggestions? Thanks!

@chanchann chanchann commented Aug 1, 2019

Great suggestion! PyTorch should have an efficient implementation of the Jacobian.

@soumith soumith commented Aug 1, 2019

A naive "parallel" version can simply be a Python for-loop.

TF's pfor basically masks, runs the op in parallel, and unmasks; it only supports a limited, hand-registered set of operators in this "parallel" mode.

The closest thing we have in flight with the same rough concept is NestedTensor, which is a bit more general than pfor and is similar to TF's RaggedTensor.

Other instances of this in the literature include, for example, DyNet's dynamic batching.

An alternative approach is to run all of these in parallel (using Python threads or multiprocessing) and hope for a good inter-op / intra-op parallelism setup in the framework, as discussed here.

This is not a simple issue to work on.

@apaszke apaszke commented Aug 2, 2019

FWIW if we ramped up our symbolic autodiff in the JIT we could implement a JAX-like wrapper around PyTorch, and then this would be expressible. Otherwise it would be quite hard to transform all our derivatives to work in a batched setting (again, unless we move them to TorchScript, where this batching transform can be applied automatically).

@Gsunshine Gsunshine commented Aug 6, 2019

Note my toy implementation above: the for-loop runs over y's dimension, not the batch dimension. I tested it in a toy case, and the runtime seems to grow more than linearly.

In [82]: x = torch.randn([128, 3, 32, 32]).requires_grad_()

In [83]: conv = nn.Conv2d(3, 1, 32, bias=False)

In [84]: y = conv(x).view(128, -1)

In [85]: y.shape
Out[85]: torch.Size([128, 1])

In [86]: timeit batch_jacobian(y, x)
4.31 ms ± 1.36 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [87]: conv = nn.Conv2d(3, 128, 32, bias=False)

In [88]: y = conv(x).view(128, -1)

In [89]: y.shape
Out[89]: torch.Size([128, 128])

In [90]: timeit batch_jacobian(y, x)
3.29 s ± 227 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@soumith Besides, I tested some cases with TF's pfor. A basic observation is that it does not support tf.cond or tf.while_loop, so BN (batch norm) is not available. I am trying to follow #22169 and #22783; however, it seems there is no vectorized support in the current prototype.

@apaszke Is it possible right now to accelerate this via the current torch.jit.script?

I agree it is actually not a simple issue to work on.

@apaszke apaszke commented Aug 9, 2019

The non-linear behavior is likely not due to the differentiation logic, but to e.g. Conv2d picking different algorithms internally when you change the dimensionality. If you want to see the true complexity without hidden variables like that, use a simpler function (e.g. a polynomial).
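For example, a quick check on an element-wise polynomial (a sketch only; it reuses jacobian_in_batch from the snippet above, and the sizes are arbitrary) keeps the operator fixed while D_y grows:

import time
import torch

def poly(x):
    # element-wise cubic, so D_y == D_x == d and no algorithm selection is involved
    return x ** 3 + 2 * x ** 2 - x

for d in (8, 32, 128):
    x = torch.randn(64, d, requires_grad=True)
    y = poly(x)
    t0 = time.perf_counter()
    jacobian_in_batch(y, x)            # shape (64, d, d)
    print(d, time.perf_counter() - t0)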

I actually have an idea for how to do batched autograd, and it might not even be too complicated. I'll see what we can do.

@Gsunshine Gsunshine commented Aug 25, 2019

@apaszke It is a greatly desirable feature. We would really appreciate it if you could help.

Is it possible to hack together a faster implementation than a naive for-loop, if it really is not too complicated?

@zakizhou zakizhou commented Oct 14, 2019

FWIW if we ramped up our symbolic autodiff in the JIT we could implement a JAX-like wrapper around PyTorch, and then this would be expressible. Otherwise it would be quite hard to transform all our derivatives to work in a batched setting (again, unless we move them to TorchScript, where this batching transform can be applied automatically).

Is there a timeline for this feature? It would be great to have an explicit Hessian and Jacobian in PyTorch, like JAX.

@apaszke apaszke commented Oct 20, 2019

Explicit jacobian and hessian are still relatively easy to do as shown in this gist. We should upstream those functions so that people don't have to write them out by hand though. In the future we'll also be able to compute batched gradients, thanks to the nested tensor patches that @cpuhrsch is working on.
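For reference, the non-batched helpers look roughly like this (a sketch in the spirit of that gist, not its exact code; y must be built from an x with requires_grad=True):

import torch

def full_jacobian(y, x, create_graph=False):
    # One reverse-mode pass per output element; returns (y.numel(), x.numel()).
    flat_y = y.reshape(-1)
    rows = []
    for i in range(flat_y.numel()):
        grad_i, = torch.autograd.grad(flat_y[i], x,
                                      retain_graph=True,
                                      create_graph=create_graph)
        rows.append(grad_i.reshape(-1))
    return torch.stack(rows)

def full_hessian(y, x):
    # Hessian of a scalar y: the Jacobian of its gradient.
    grad, = torch.autograd.grad(y, x, create_graph=True)
    return full_jacobian(grad, x)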

@n-gao n-gao commented Oct 31, 2019

This feature would be great!
Is there any faster way than a for loop at the moment? I have to compute the derivatives of every network output (one per sample in a batch) w.r.t. every parameter in the network. So if my batch size is N and I have M parameters, I need an N×M matrix. The batch size has to be quite large, and this has to be evaluated at every training step. The for loop is just not fast enough for this purpose.
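For concreteness, the baseline being described is something like this per-sample loop (a sketch; the model and batch below are placeholders):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))  # placeholder model
x = torch.randn(8, 3)                                                # N = 8 samples
out = net(x).squeeze(-1)                                             # one scalar per sample

params = [p for p in net.parameters() if p.requires_grad]
rows = []
for i in range(out.shape[0]):
    grads = torch.autograd.grad(out[i], params, retain_graph=True)
    rows.append(torch.cat([g.reshape(-1) for g in grads]))
per_sample = torch.stack(rows)   # (N, M), M = total number of parameters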

@tsauri tsauri commented Nov 5, 2019

I need to get the Jacobian of a weight parameter w.r.t. outputs with ndim, say, 10. Can param.grad be expanded?
Currently I call output.backward(one_hot) multiple times, following this gist:
https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa
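The pattern in question looks roughly like this (a sketch; the model and input are placeholders, and each backward pass leaves one Jacobian row in param.grad):

import torch
import torch.nn as nn

net = nn.Linear(4, 10)                 # placeholder: 10 outputs
x = torch.randn(1, 4)
out = net(x).squeeze(0)                # (10,)

rows = []
for i in range(out.shape[0]):
    net.zero_grad()                    # clear accumulated grads before each row
    one_hot = torch.zeros_like(out)
    one_hot[i] = 1.0
    out.backward(one_hot, retain_graph=True)
    rows.append(net.weight.grad.reshape(-1).clone())
jac_w = torch.stack(rows)              # (10, weight.numel()): d out_i / d weight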

@radenmuaz radenmuaz commented Nov 30, 2019

FWIW if we ramped up our symbolic autodiff in the JIT we could implement a JAX-like wrapper around PyTorch, and then this would be expressible. Otherwise it would be quite hard to transform all our derivatives to work in a batched setting (again, unless we move them to TorchScript, where this batching transform can be applied automatically).

Any updates on this issue?

@cpuhrsch cpuhrsch commented Dec 2, 2019

@zou3519 zou3519 commented Dec 2, 2019

Explicit jacobian and hessian are still relatively easy to do as shown in this gist. We should upstream those functions so that people don't have to write them out by hand though. In the future we'll also be able to compute batched gradients, thanks to the nested tensor patches that @cpuhrsch is working on.

cc @albanD how do you feel about upstreaming these for now so that users don't have to write them out by hand? I see requests for these from time to time (this is the most recent one), and each time I write out the code, so I think it would be nicer to have them as helper functions even if they aren't the most efficient things in the world.

@albanD albanD commented Dec 2, 2019

Yes, I want to do it: #30632
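Assuming those helpers end up somewhere like torch.autograd.functional (an assumption; see #30632 for the actual plan), usage might look roughly like:

import torch
from torch.autograd.functional import jacobian  # assumed location, per #30632

def f(x):
    return (x ** 2).sum(dim=1)       # (B,) output

x = torch.randn(8, 3)
jac = jacobian(f, x)                 # (8, 8, 3): full Jacobian, not yet batched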
