Gradients of non-scalars (higher rank Jacobians) #675

Open
zackchase opened this Issue Jan 4, 2016 · 64 comments

@zackchase
Contributor

zackchase commented Jan 4, 2016

Currently, if you call gradients(ys, xs), it returns the sum of dy/dx over all ys for each x in xs. I believe this doesn't accord with the a priori mathematical notion of the derivative of a vector. I'd like a way to take the derivative of ys with respect to xs, where both are vectors, and have a Jacobian matrix returned. By extension, I'd like to take the derivative of a vector with respect to a matrix and get back a 3-tensor. There doesn't seem to be a convenient TensorFlow function to compute the Jacobian or higher-order derivatives. Am I missing something, or is this functionality we could add?

@keveman


Contributor

keveman commented Jan 4, 2016

zackchase@, you are right about the current gradients function. Currently, you can compute the Jacobian of, say, a vector by calling gradients multiple times, once for every scalar component (obtained by slicing) of the original vector, and reassembling the results. Contributions are welcome to make this nicer and more efficient.
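
A minimal sketch of this slice-and-reassemble approach, assuming graph-mode TensorFlow 1.x, a 1-D y with a statically known length, and the hypothetical helper name jacobian_by_slicing:

import tensorflow as tf

def jacobian_by_slicing(y, x):
    # One tf.gradients call per scalar component of y, then reassemble by stacking.
    rows = [tf.gradients(y[i], x)[0]            # dy[i]/dx, same shape as x
            for i in range(int(y.shape[0]))]    # requires a statically known length
    return tf.stack(rows)                       # shape: [len(y)] + x.shape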

@girving


Contributor

girving commented Jan 4, 2016

It'd be pretty hard to support gradients of non-scalars with our current setup, since it would require every gradient function to handle extra rank input. The one possibility I could see would be if we add some sort of map facility to register how to add extra ranks to ops, then compute gradients with respect to extra rank by computing lower rank and calling the registered map transformations.

Someone asked for map a while back, so if anyone wanted to tackle this task that might be the way to go. Handling it at the gradient function level is probably bad, since it would add required complexity to an existing feature. Warning: This is a pretty large change, so a good deal of discussion would be in order before starting.

@girving girving changed the title from Problems Calculating the Jacobian to Gradients of non-scalars (higher rank Jacobians) Jan 4, 2016

@girving girving added the enhancement label Jan 4, 2016

@zackchase


Contributor

zackchase commented Jan 4, 2016

Hi Geoffrey, thanks for taking an interest in this issue. I was initially confused by the use of "rank" to describe the number of dimensions of the array. Should we avoid this name in the thread title and documentation to preempt confusion from overloading the linear-algebra notion of rank?

@girving


Contributor

girving commented Jan 4, 2016

Tensor rank is very standard terminology: http://mathworld.wolfram.com/TensorRank.html

@zackchase


Contributor

zackchase commented Jan 4, 2016

Cool. The terminology gets funny when we talk about rank-R decompositions of tensors, meaning the tensor can be represented as a sum of R outer products of rank-1 tensors, but that's probably not a problem for us to solve here.

One thing I thought of is that I would like to compute the Frobenius norm of the Jacobian of the log probabilities for use as a smoothness penalty, much like the penalty used in a contractive autoencoder. In this case, since we only seek a scalar at the end, is there a more efficient method than separately calculating the derivative of each output with respect to the inputs?

@girving


Contributor

girving commented Jan 4, 2016

Are you saying your network has a bunch of outputs, and then you combine them into a single scalar that you are trying to optimize? In that case, you should differentiate with respect to that single scalar.

@zackchase


Contributor

zackchase commented Jan 4, 2016

Not exactly. I'm saying if one wants to penalize the norm of the Jacobian of the mapping function, the optimization objective would be (pseudocode):

cost(y, yhat, X) = loss(y, yhat) + norm(Jacobian(log(yhat), X))
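
A hedged sketch of how such a penalty could be assembled with today's tf.gradients, assuming log_yhat is a 1-D tensor with a static length, X is the input tensor, and jacobian_frobenius_penalty and lam are hypothetical names (the per-component loop is exactly the expensive part discussed here):

import tensorflow as tf

def jacobian_frobenius_penalty(log_yhat, X):
    # Squared Frobenius norm of d(log_yhat)/dX, built one output component at a time.
    sq_norms = []
    for i in range(int(log_yhat.shape[0])):     # requires a static output size
        g = tf.gradients(log_yhat[i], X)[0]     # one row of the Jacobian
        sq_norms.append(tf.reduce_sum(tf.square(g)))
    return tf.add_n(sq_norms)

# cost = loss + lam * jacobian_frobenius_penalty(tf.log(yhat), X)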

@girving


Contributor

girving commented Jan 4, 2016

Ah, sorry for not reading carefully. You're correct that (as far as I know) there's no easy way to do that in current TensorFlow. According to someone more knowledgeable than I, people generally implement such contractive autoencoders by writing out the first derivative manually. They also generally restrict themselves to one layer at a time for speed reasons, since computing the full Jacobian of a multilayer network is quite expensive.

@zackchase


Contributor

zackchase commented Jan 9, 2016

Regardless, it would be good to have a way to take derivatives of vectors and receive gradients of the expected shape.

@yaroslavvb


Contributor

yaroslavvb commented Jan 9, 2016

Differentiating with respect to one output variable at a time is similar to how it works in Theano. I agree it may be confusing that TensorFlow automatically turns many output variables into one by taking the sum. An alternative would be to fail if more than one output variable is specified, or to have a wrapper that automatically calls the existing gradients function on each output variable.

The reason for "one output variable at a time" in TensorFlow (and Theano) is that we do reverse-mode AD by default. In reverse AD you have a single target scalar quantity and you propagate sensitivities with respect to that quantity. In contrast, if we did forward AD instead, we would naturally support multiple output variables, but we could only compute derivatives with respect to one scalar input variable at a time. Supporting mixed-mode propagation to cover the "multiple inputs/multiple outputs" case in the most efficient way could require a lot of extra plumbing.

If you have a small number of output variables but a large number of input variables, the standard thing to do is to apply reverse AD to each output variable in a loop. This is what Theano recommends for computing the Hessian, for instance: http://deeplearning.net/software/theano/tutorial/gradients.html#computing-the-hessian. If you have a small number of input variables but a large number of output variables, the most efficient thing to do would be to run forward-mode AD for each input variable in a loop. Forward-mode AD is not implemented and would require adding an equivalent of Theano's "Rop" operator to the differentiable ops, plus some plumbing to call them instead of the existing op "gradient" functions (the existing gradient function is the equivalent of the Lop operation, i.e. "left-multiply a sensitivity vector by the op's Jacobian").
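
A minimal sketch of the "reverse AD in a loop" recipe in graph-mode TensorFlow 1.x, assuming a scalar loss, a 1-D x with a statically known size, and the hypothetical helper name hessian_by_loop:

import tensorflow as tf

def hessian_by_loop(loss, x):
    # Reverse AD once for the gradient, then once per gradient entry for the Hessian rows.
    grad = tf.gradients(loss, x)[0]             # shape [n]
    rows = [tf.gradients(grad[i], x)[0]         # d(grad[i])/dx, shape [n]
            for i in range(int(x.shape[0]))]    # needs a statically known n
    return tf.stack(rows)                       # shape [n, n]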

@tillahoffmann


Contributor

tillahoffmann commented Apr 27, 2016

I was hoping to implement higher-order derivatives using the map function but am getting an error message I can't quite get my head around. My implementation is (in pseudocode)

params = tf.Variable("some initial value")
loss = some_function(params)
grads = tf.gradients(loss, params)[0]
hess = tf.map_fn(lambda grad: tf.gradients(grad, params)[0], grads)

When I fetch the Hessian, I get the error message

InvalidArgumentError: All inputs to node map/while/gradients/map/TensorArrayUnpack_grad/TensorArrayGrad/TensorArrayGrad must be from the same frame.

I assumed that TensorFlow has an issue because it doesn't know about params in the loop (cf. non_sequences in Theano's scan), and extended map_fn to pass extra arguments to the loop. Unfortunately, the extra arguments get wrapped in an identity transformation, and tf.gradients(params, tf.identity(params)) gives [None], which seems a bit unintuitive.

Looping in Python is of course fine, but I'd like to avoid introducing an extra node in the graph for every parameter. Any suggestions?

@girving


Contributor

girving commented Apr 28, 2016

@yuanbyu: Do you understand this issue with tf.map_fn?

@girving girving added the triaged label Jun 8, 2016

@girving


Contributor

girving commented Jun 9, 2016

Note for anyone who comes across this thread: tf.map_fn is an unrelated thing involving control flow, not something related to mapping over extra rank tensors.

@yuanbyu


Contributor

yuanbyu commented Aug 24, 2016

We don't support higher-order gradients for while_loop/map_fn/scan/fold. You should see an informative error message if you try to do that.

@yuanbyu yuanbyu closed this Aug 24, 2016

@vladfi1


Contributor

vladfi1 commented Oct 3, 2016

@yaroslavvb Any plans on adding forward mode AD? I filed an issue on it a couple weeks ago but haven't heard back.

@yaroslavvb


Contributor

yaroslavvb commented Oct 3, 2016

@vladfi1 I'm no longer at Brain, so I wouldn't know. I would say it is unlikely to ever be part of core TensorFlow. There are >450 ops in TF, so the Brain team would have to implement a forward-AD method for all 450 ops and maintain them forever, or alternatively explain why someone's favorite op doesn't have forward-AD support. It seems more realistic that someone would create a separately maintained library that does forward AD and uses TensorFlow as a backend, kind of like autograd but with TensorFlow instead of NumPy as the backend.

@myaooo


myaooo commented Feb 20, 2017

Is tf.test.compute_gradient some kind of function we can use to get the Jacobian matrix (not as a tensor but as a numpy.ndarray) of a vector tensor y w.r.t. a vector tensor x?

@yaroslavvb

Contributor

yaroslavvb commented Feb 20, 2017

@myaooo


myaooo commented Feb 21, 2017

@yaroslavvb Thanks for your reply. I see why it's expensive to do that.
Do you have any suggestions if I want to get the numerical value of the Jacobian for particular inputs x? The only not-so-expensive workaround I can think of is to apply a perturbation to each dimension of x and approximate the result.
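
A small NumPy sketch of that perturbation idea, purely as an illustration; f stands for any Python callable that evaluates y at a given x (for example, a thin wrapper around sess.run), and numerical_jacobian is a hypothetical helper name:

import numpy as np

def numerical_jacobian(f, x, eps=1e-5):
    # Central-difference approximation of df/dx for f mapping R^n to R^m.
    x = np.asarray(x, dtype=np.float64)
    m = np.asarray(f(x)).size
    jac = np.zeros((m, x.size))
    for i in range(x.size):
        dx = np.zeros_like(x)
        dx.flat[i] = eps
        jac[:, i] = (np.asarray(f(x + dx)).ravel()
                     - np.asarray(f(x - dx)).ravel()) / (2.0 * eps)
    return jac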

@tillahoffmann


Contributor

tillahoffmann commented Feb 21, 2017

Hessians are supported. But, as you mentioned, they are expensive to compute.

@yaroslavvb


Contributor

yaroslavvb commented Feb 21, 2017

@myaooo currently tf.gradients gives the gradient of a function that outputs a scalar. You could call it multiple times, once for each output component, i.e. [tf.gradients(component, var) for component in tf.unstack(vector)].
There's a trick for computing Hessians more efficiently, described in #4897 (comment)

@el3ment


el3ment commented Mar 11, 2017

Adding a comment that this was also a bit of a gotcha for me, and simultaneously adding a vote to consider a change to the API. Calling tf.gradients(matrix, vector, aggregation_method=None), I would have expected it to return a tensor of the same shape as matrix, but it returns the gradient of sum(matrix) instead.

@yaroslavvb


Contributor

yaroslavvb commented Mar 11, 2017

It's a little less surprising when you know that there's no efficient algorithm to compute the Jacobian in this setting; hence all neural-net frameworks use reverse-mode AD, which requires the target to be a scalar.

@girving


Contributor

girving commented Mar 13, 2017

@el3ment I consider this a bug. Unfortunately, I've looked into fixing it, and there's a surprising number of tests within TensorFlow that depend on it. I didn't get past said piles of tests to see whether there's interesting downstream code in Google that uses it.

@deeptimhe


deeptimhe commented Mar 21, 2017

@el3ment @girving
I ran into this confusion as well.

x = tf.constant([3], dtype=tf.float32)
y = tf.constant([9,8,7,6,5], dtype=tf.float32)
z = x * y
sess.run(tf.gradients(z, x))

and I got array([35.])

@isabeaups


isabeaups commented Mar 25, 2017

Just trying to understand something. I was trying to make a hack of tf.gradients that would give, for a y of shape (M,N) and an x of shape (Q,P), a gradient tensor of shape (M,N,Q,P), as one would naturally expect. However, as mentioned already here, what one gets is a shape-(Q,P) tensor which is the gradient of the sum of the elements of y. Now what I can't figure out, looking into the TensorFlow code, is where that sum over the elements of y is made. Is it at the beginning or at the end? Could someone help me pinpoint the lines of code where that is done?

@yaroslavvb


Contributor

yaroslavvb commented Mar 25, 2017

Note that "summing ys together" is sort of buried deep in backprop. Note that TensorFlow uses Reverse Mode AD in order to compute gradients, and reverse mode AD only supports a scalar output function. If you have non-scalar output, we don't have any algorithm much better than calling reverse mode AD on each of the components of the output separate.

(there's a trick mentioned by ian here which is a bit better on Python overhead compared to calling gradients many times)

That said, the place where summation of ys happens is in line 479-481 in gradients_impl.py

First note that grad_ys are set to the same value in this line
grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
Later there's this block


 # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

Note that if you had a single y*=y1+y2+y3+..., then the gradients would propagate in the same way -- the same backprop value from y would be copied over to nodes y1,y2,y3, so this block implicitly treats ys as being added up into the final sum
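
A small sketch of this equivalence in graph-mode TensorFlow 1.x: seeding tf.gradients with its default all-ones grad_ys gives the same result as differentiating the explicit sum of the ys:

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
y = x * x                                       # non-scalar output

g_default = tf.gradients(y, x)[0]               # implicit all-ones grad_ys seed
g_sum = tf.gradients(tf.reduce_sum(y), x)[0]    # explicit sum of the ys

with tf.Session() as sess:
    print(sess.run([g_default, g_sum]))         # both are [2. 4. 6.]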

@isabeaups


isabeaups commented Mar 25, 2017

Thanks for your response @yaroslavvb!

But there are two things which confuse me from your answer:

(1)

we don't have any algorithm much better than calling reverse-mode AD on each component of the output separately.

I'm pretty sure that cannot be the case. If you think of a deep network outputting a vector y, then the gradient of y[0] is basically the same as that of y[1]; they only differ by the elements of the last weight matrix, i.e. y[0] = W(0,j) V(j, ...) while y[1] = W(1,j) V(j, ...).

(2) You say the summation happens at lines 479-481 of gradients_impl.py.
But just above those lines, there is the following comment in the code:
# grads: op => list of gradients received on each output endpoint of the
# op. The gradients for each endpoint are initially collected as a list.
# When it is time to call the op's gradient function, for each endpoint we
# aggregate the list of received gradients into a Add() Operation if there
# is more than one.

And indeed, when one looks into the function _SetGrad, there is no addition, only appending to a list, and all the ys, if I understand correctly, are still kept separate by being different keys of the dictionary grads.

So I am utterly confused by this.

Also, thank you very much for @goodfeli's trick, but I don't really understand what he means or how to implement it in practice.

@yaroslavvb


Contributor

yaroslavvb commented Mar 25, 2017

@isabeaups the two gradients can differ significantly when there are non-linear operators, as is the case for neural networks.

As a toy example, suppose there's some operation such that, when f(x)=1 and you compute its gradient, one of the backprop operators hits the limits of its numeric range and produces a NaN, which is a valid float32 value. If you want to backprop from the vector [1,2], you are going to have to do a full backprop from 1 and from 2 to determine which value causes the NaN. There's no shortcut that will let you keep the same size of intermediate matrices and backprop from both values in parallel.

@jeisses


jeisses commented Aug 3, 2017

@liber145 I have to compute semi-large Jacobians and am also having performance issues using a for loop.

After looking at the code by @tillahoffmann I came up with the following function, which gives a decent speedup for larger N:

def jacobian(y_flat, x):
    n = y_flat.shape[0]

    loop_vars = [
        tf.constant(0, tf.int32),
        tf.TensorArray(tf.float32, size=n),
    ]

    _, jacobian = tf.while_loop(
        lambda j, _: j < n,
        lambda j, result: (j+1, result.write(j, tf.gradients(y_flat[j], x))),
        loop_vars)

    return jacobian.stack()

In my case the speed is similar to Theano's jacobian function.

@tillahoffmann


Contributor

tillahoffmann commented Aug 3, 2017

@jeisses, the implementation of hessians on the master branch looks very similar to the one you suggested (see here).

@jeisses


jeisses commented Aug 3, 2017

@tillahoffmann, yes, I did derive the function directly from that code 😄.

But I'm looking for just a Jacobian matrix; doesn't the hessians function compute second-order derivatives, and still use aggregated gradients? Or could I use the hessians function for this?

Thanks

@tillahoffmann


Contributor

tillahoffmann commented Aug 4, 2017

@jeisses, of course, my bad. In fact, it would be great if we could integrate your proposed Jacobian implementation and simply implement the Hessian as hessian(y, x) = jacobian(gradient(y, x), x).
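
A hedged sketch of that composition, using a plain unstack-based jacobian rather than the while_loop version (which, as reported later in this thread, does not nest); both helper names are hypothetical, y is assumed scalar, and x is assumed 1-D with a static size:

import tensorflow as tf

def jacobian_unstacked(y, x):
    # Jacobian of a 1-D tensor y w.r.t. x: one tf.gradients call per component.
    return tf.stack([tf.gradients(y_i, x)[0] for y_i in tf.unstack(y)])

def hessian_via_jacobian(y, x):
    # hessian(y, x) = jacobian(gradient(y, x), x) for a scalar y and a 1-D x.
    grad = tf.gradients(y, x)[0]
    return jacobian_unstacked(grad, x)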

@shaifugpt


shaifugpt commented Sep 15, 2017

@shoyer I used your version to compute the Jacobian, passing y=model.total_loss and x=weights. The value of the Jacobian for the different weights is exactly the same as the value of the gradients. Why is that?

@shaifugpt


shaifugpt commented Sep 27, 2017

@candidj0


candidj0 commented Sep 27, 2017

Sorry @shaifugpt, try this one:

def jacobian(y, x, n):
    y_list = tf.unstack(y, num = n)
    jacobian_list = [[tf.gradients(y_, x)[0][i] for y_ in tf.unstack(y_list[i])] for i in range(n)] # list [grad(y0, x), grad(y1, x), ...]
    return tf.stack(jacobian_list)

n is the batch size.

@shaifugpt


shaifugpt commented Sep 28, 2017

@oborchers


oborchers commented Sep 29, 2017

@jeisses and @shoyer I am somewhat confused by the implementation because of the resulting shape of J.

According to https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant
Let f: R^n -> R^m; then J(f) is an m x n matrix. Given that the input has a batch size s, J should be an s x m x n array, which the implementation of @candidj0 gives us. (However, it is slow due to the for loop.)

[None, n] doesn't work.

Setting s = 1:
jacobian_1 gives (1, 1, 1, 500) in 3s,
jacobian_2 gives (1, 10, 500) in 3s,
jacobian_3 gives (1, 10, 1, 500) in 3s,

whereas s = 20 gives:
jacobian_1 gives (20, 1, 20, 500) in 3s,
jacobian_2 gives (20, 10, 500) in 19s,
jacobian_3 gives (20, 10, 20, 500) in 61s.

Have I missed a point somewhere in the implementation (because I would really like to get that speedup)?

import tensorflow as tf
import numpy as np
import time

def jacobian_1(y_flat, x):
    n = y_flat.shape[0]

    loop_vars = [
        tf.constant(0, tf.int32),
        tf.TensorArray(tf.float32, size=n),
    ]
    _, jacobian = tf.while_loop(
        lambda j, _: j < n,
        lambda j, result: (j+1, result.write(j, tf.gradients(y_flat[j], x))),
        loop_vars)
    return jacobian.stack()

def jacobian_2(y, x, n):
    y_list = tf.unstack(y, num = n)
    jacobian_list = [[tf.gradients(y_, x)[0][i] for y_ in tf.unstack(y_list[i])] for i in range(n)] # list [grad(y0, x), grad(y1, x), ...]
    return tf.stack(jacobian_list)

def jacobian_3(y, x):
  y_flat = tf.reshape(y, (-1,))
  jacobian_flat = tf.stack(
      [tf.gradients(y_i, x)[0] for y_i in tf.unstack(y_flat)])
  return tf.reshape(jacobian_flat, y.shape.concatenate(x.shape))

s = 20
n = 500
m = 10

x = tf.placeholder(tf.float32, [s, n])
w = tf.Variable(tf.truncated_normal([n, m], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[m]))
y = tf.matmul(x, w) + b

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

start = time.time()
j = jacobian_1(y, x)
j_out = sess.run(j, feed_dict={x:np.random.rand(s,n)})
print(str(int(time.time() - start)) + " Seconds: " + str(j_out.shape))

start = time.time()
j_2 = jacobian_2(y,x, s)
j_out = sess.run(j_2, feed_dict={x:np.random.rand(s,n)})
print(str(int(time.time() - start)) + " Seconds: " + str(j_out.shape))

start = time.time()
j_3 = jacobian_3(y,x)
j_out = sess.run(j_3, feed_dict={x:np.random.rand(s,n)})
print(str(int(time.time() - start)) + " Seconds: " + str(j_out.shape))
@ModarTensai


ModarTensai commented Oct 12, 2017

For my part, I prefer computing the Jacobian matrix using very lightweight operations in the graph.
Since tf.gradients returns the sum, I mask the output at a single index and then compute the gradient.
I compute the Jacobian for each point in batches, then stack them at the end outside the graph.
Here is a running example based on @oborchers' example that produces an (s x m x n) array:

import time
import numpy as np
import tensorflow as tf


def jacobian(session, y, x, points, batch_size, as_matrix=True):
    """The Jacobian matrix of `y` w.r.t. `x` at `points`

    Let f(x) be some function that has a Jacobian A at point p
    then, f(p) = y = Ap+b
    where A of shape mxn, p of shape nx1 and b of shape mx1

    Args:
        y: The output tensor
        x: The input tensor
        points: The points of linearization where it can be many points
            of shape [num_points, *self.features_shape]
        batch_size: How many rows of the Jacobian to compute at once
        as_matrix: Whether to return the Jacobian as a matrix or retain
            the shape of the input

    Returns:
        The Jacobian matrices for the given points
        of shape [num_points, *jacobian_shape]
        If `as_matrix`, jacobian_shape is [y.size, *x.shape]
        else, jacobian_shape is [y.size, x.size]
    """
    # add and/or get cached ops to the graph
    if not hasattr(session.graph, "_placeholder"):
        session.graph._placeholder = {}
    if not hasattr(session.graph, "_gradient"):
        session.graph._gradient = {}
    with session.graph.as_default():
        if y.dtype in session.graph._placeholder:
            placeholder = session.graph._placeholder[y.dtype]
        else:
            placeholder = tf.placeholder(y.dtype)
            session.graph._placeholder[y.dtype] = placeholder

        if (y, x) in session.graph._gradient:
            gradient = session.graph._gradient[(y, x)]
        else:
            gradient = tf.gradients(placeholder * y, x)[0]
            session.graph._gradient[(y, x)] = gradient

    # extract the Jacobians for all points
    jacobians_list = []
    for i in range(points.shape[0]):
        # extract the Jacobian matrix for a single point
        partials_list = []
        point = points[i:i + 1, :]
        shape = y.shape.as_list()[1:]
        repeated_point = point
        for mask in masks_batches(shape, batch_size):
            # repeat the point according to the mask's batch_size
            batch_size = mask.shape[0]
            if repeated_point.shape[0] < batch_size:
                repeated_point = np.vstack([point] * batch_size)
            if repeated_point.shape[0] > batch_size:
                repeated_point = repeated_point[:batch_size, :]
            feed = {placeholder: mask, x: repeated_point}
            partial = session.run(gradient, feed_dict=feed)
            partials_list.append(partial)
        jacobian = np.vstack(partials_list)

        # reshape it as a matrix
        if as_matrix:
            jacobian = jacobian.reshape(jacobian.shape[0], -1)

        jacobians_list.append(jacobian)

    # stack Jacobians
    jacobians = np.stack(jacobians_list)

    return jacobians


def masks_batches(shape, batch_size):
    """Batches iterator over all possible masks of the given shape

    A mask is a numpy.ndarray of shape `shape` of all zeros except
    for a single position it is one. It is useful to get those masks
    in batches instead of getting them one by one.

    Args:
        shape: The shape of each mask
        batch_size: How many masks to return in each iteration

    Returns:
        A batch of masks of shape [batch_size, *shape]
    """
    num_rows = np.prod(shape)
    if num_rows < batch_size:
        batch_size = num_rows

    eye = np.eye(batch_size)
    _mask = np.zeros((batch_size, *shape))
    mask = _mask.reshape(batch_size, -1)

    num_batches = -(-num_rows // batch_size)
    for i in range(num_batches):
        start = i * batch_size
        end = min(start + batch_size, num_rows)

        # check if last batch is smaller than batch size
        if end - start < batch_size:
            batch_size = end - start
            eye = np.eye(batch_size)
            _mask = np.zeros((batch_size, *shape))
            mask = _mask.reshape(batch_size, -1)

        mask[:, start:end] = eye
        yield _mask
        mask[:, start:end] = 0


if __name__ == '__main__':
    m = 10
    n = 500
    s = 20

    x = tf.placeholder(tf.float32)
    w = tf.Variable(tf.truncated_normal([n, m], stddev=0.1))
    b = tf.Variable(tf.constant(0.1, shape=[m]))
    y = tf.matmul(x, w) + b

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    start = time.time()
    j_out = jacobian(sess, y, x, np.random.rand(s, n), m)
    w_out = sess.run(w)
    # they should be equal and error ~ < 1e-6 (single precision)
    error = np.linalg.norm(w_out.T - np.mean(j_out, axis=0))
    if error < 1e-6:
        print("Correct Jacobian!")
    else:
        print("Error was {}".format(error))
    print(str(int(time.time() - start)) + " Seconds: " + str(j_out.shape))
@candidj0


candidj0 commented Oct 12, 2017

@oborchers I didn't check the timing, but maybe you can try:

def body(y, x, i):
    n = tf.shape(y)[0]
    loop_vars = [
        tf.constant(0, tf.int32),
        tf.TensorArray(tf.float32, size=n),
    ]
    _, jacobian = tf.while_loop(
        lambda j, _: j < n,
        lambda j, result: (j+1, result.write(j, tf.gradients(y[j], x)[0][i])),
        loop_vars)
    return jacobian.stack()

def tf_jacobian(y, x, n):
    loop_vars = [
        tf.constant(0, tf.int32),
        tf.TensorArray(tf.float32, size=n),
    ]
    _, jacobian = tf.while_loop(
        lambda i, _: i < n,
        lambda i, result: (i+1, result.write(i, body(y[i], x, i))),
        loop_vars)
    return jacobian.stack()

jacobians = tf_jacobian(y, x, n) where n is the batch size.

@aribenjamin


aribenjamin commented Oct 21, 2017

@ModarTensai this is great! Thank you.

It produces an error for me, though, when m=1. I'd like to use this to produce per-example gradients of a single-valued output, so my application will have m=1 as well. Is there an easy fix you know of, before I dig in?

@ModarTensai


ModarTensai commented Oct 21, 2017

@aribenjamin If m = 1, w is a vector, not a matrix, but I fixed and updated the code.
All I needed to change was line 53, from

repeated_point = np.zeros(1)

to

repeated_point = point

Thanks for noticing the bug!

@oborchers


oborchers commented Nov 9, 2017

@ModarTensai and @candidj0 : Awesome guys! You rock, thank you! I'll test them over the weekend and see how they work on a toy dataset.

@czhang96


czhang96 commented Dec 5, 2017

There are no built-in jacobians in TensorFlow, instead anything called 'grad' or 'gradient' computes Jacobian-vector product (also called LOp in theano),

Sanity checking, but LOp is a vector-Jacobian product, not a Jacobian-vector product, correct?
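
For reference, a small graph-mode TensorFlow 1.x sketch of the operation in question: passing an explicit grad_ys to tf.gradients contracts the sensitivity vector against the Jacobian from the left, i.e. a vector-Jacobian (Lop-style) product:

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
y = x * x                                   # Jacobian is diag(2x) = diag([2, 4, 6])
v = tf.constant([1.0, 0.0, 0.0])            # the "left" sensitivity vector

vjp = tf.gradients(y, x, grad_ys=v)[0]      # v^T J, a vector-Jacobian product

with tf.Session() as sess:
    print(sess.run(vjp))                    # [2. 0. 0.]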

@mholzel


Contributor

mholzel commented Feb 3, 2018

Nobody seems to have posted any follow-up, but the code proposed by @candidj0 and @jeisses does not work when nested (tested in tensorflow 1.5.0). So computing the hessian by nesting will not work (@tillahoffmann). Let me make this a bit more concrete. I am using the following jacobian implementation:

import tensorflow as tf

# Note: this helper shadows Python's built-in map().
def map(f, x, dtype=None, parallel_iterations=10):
    '''
    Apply f to each of the elements in x using the specified number of parallel iterations.

    Important points:
    1. By "elements in x", we mean that we will be applying f to x[0],...x[tf.shape(x)[0]-1].
    2. The output size of f(x[i]) can be arbitrary. However, if the dtype of that output
       is different than the dtype of x, then you need to specify that as an additional argument.
    '''
    if dtype is None:
        dtype = x.dtype

    n = tf.shape(x)[0]
    loop_vars = [
        tf.constant(0, n.dtype),
        tf.TensorArray(dtype, size=n),
    ]
    _, fx = tf.while_loop(
        lambda j, _: j < n,
        lambda j, result: (j + 1, result.write(j, f(x[j]))),
        loop_vars,
        parallel_iterations=parallel_iterations
    )
    return fx.stack()

def jacobian(fx, x, parallel_iterations=10):
    '''
    Given a tensor fx, which is a function of x, vectorize fx (via tf.reshape(fx, [-1])),
    and then compute the jacobian of each entry of fx with respect to x.
    Specifically, if x has shape (m,n,...,p), and fx has L entries (tf.size(fx)=L), then
    the output will be (L,m,n,...,p), where output[i] will be (m,n,...,p), with each entry denoting the
    gradient of output[i] wrt the corresponding element of x.
    '''
    return map(lambda fxi: tf.gradients(fxi, x)[0],
               tf.reshape(fx, [-1]),
               dtype=x.dtype,
               parallel_iterations=parallel_iterations)

I am using this because it supports dynamic sizes, that is, one of the dimensions can be None.

However, this simple test

    import numpy
    import tensorflow as tf
    from numpy.random import randn

    numpy.random.seed(0)

    # Here is how everything would look in numpy
    x = randn(3, 3)
    A = randn(2, 3)
    y = numpy.dot(A,numpy.dot(x,x))

    # and in tensorflow... 
    xtf = tf.constant(x, tf.float64)
    Atf = tf.constant(A, tf.float64)
    ytf = tf.matmul(tf.matmul(Atf, xtf), xtf)

    with tf.Session() as sess:

        # Now let's try to compute the jacobian 
        dydx = jacobian(ytf, xtf)
        print(sess.run(dydx))
        
        # and the hessian... 
        d2ydx2 = tf.squeeze(jacobian(dydx, xtf))
        print(sess.run(d2ydx2))

throws the error

ValueError: Cannot use 'while_1/while/gradients/f_count_1' as input to 'while_1/while/gradients/f_count' because they are in different while loops. See info log for more details.

Does anybody know the issue here?
The first jacobian is correct. The second one (essentially the hessian) throws an error.

Contributor

mholzel commented Feb 9, 2018

I just tested the previous code on the nightly Docker build, and the error remains.

dancasas commented Apr 24, 2018

@mholzel I wanted to thank you for that jacobian implementation; it seems to work perfectly for a problem I am working on. In my opinion, this capability should be pushed to the main TensorFlow branch, as it is useful in many problems.

marcociccone commented Apr 27, 2018

@mholzel Have you found a way to nest the while loops to make it work?

Contributor

mholzel commented Apr 27, 2018

No. It seems like a really subtle bug in either the gradient or the while-loop implementation. That requires somebody with a lot more knowledge than I have about what is going on there.

marcociccone commented Apr 27, 2018

Damn, thanks anyway... Is there anyone who can help? Maybe we can bother @ebrevdo?

dancasas commented Apr 27, 2018

@marcociccone it seems to be working fine in TF 1.4, if that is an alternative for you.

Member

skye commented May 4, 2018

This is indeed a bug in TF. It's caused by calling tf.gradients(y, x) inside a while loop when the computation of y from x goes through a different while loop, something like:

x = ...
y = while_loop(..., [x])
z = while_loop(..., tf.gradients(y, x), [y])

So in @mholzel's script, it's from passing the outcome of one jacobian call to the other jacobian call. (BTW, thanks very much for the easy repro.)

Unfortunately this is quite tricky to fix. I'll try to take another look at it tomorrow and see if I can come up with something.

Member

skye commented May 5, 2018

This is actually extremely tricky to fix. I'm not sure how this was working in 1.4; was it definitely giving the right answer?

The fundamental problem is that the second jacobian call (i.e. the hessian) is calling tf.gradients() inside a while loop, and that backprop calculation must go through the while loop from the first jacobian call. TF computes while loop gradients using stacks to store intermediate loop values, so if you're doing that calculation multiple times via another loop, we'd have to somehow re-use the stack values on each iteration. This is conceptually possible but would be a pretty big change. I can at least try to improve the error message though.
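
[Editor's note: one possible workaround, offered only as a sketch under the assumption that tf.size(fx) is statically known: build the Jacobian with a plain Python loop and tf.stack instead of tf.while_loop. With no while loop in the graph, tf.gradients can be nested to get the Hessian, at the cost of a graph that grows with the number of outputs. The name jacobian_static is mine, not part of TensorFlow.]

import tensorflow as tf

def jacobian_static(fx, x):
    """Jacobian of fx wrt x built with a Python loop; requires fx to have a fully static shape."""
    flat = tf.reshape(fx, [-1])
    n = int(flat.shape[0])                                 # known at graph-construction time
    rows = [tf.gradients(flat[i], x)[0] for i in range(n)]
    return tf.stack(rows)                                  # shape: (n,) + x.shape

# Nesting works because no tf.while_loop is involved:
# hessian = jacobian_static(jacobian_static(y, x), x)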

Contributor

mholzel commented May 5, 2018

I have never seen the nested call work, but I did not try 1.4.

Joshuaalbert commented Jun 12, 2018

@mholzel for the double jacobian, suppose you just map to tf.hessians instead of tf.gradients. It doesn't solve the problem to arbitrary order, but it does get the hessian of a tensor wrt variables.
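
[Editor's note: for reference, a minimal sketch of that suggestion (my own example, assuming a scalar objective and a 1-D variable, the case tf.hessians handles most directly).]

import tensorflow as tf

x = tf.Variable([1.0, 2.0, 3.0])
loss = tf.reduce_sum(x ** 3)        # scalar function of x

hess = tf.hessians(loss, x)[0]      # shape (3, 3): d^2 loss / dx_i dx_j

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(hess))           # diag([6., 12., 18.])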

agarwal-ashish commented Jul 12, 2018

There is now an experimental new approach to doing Jacobians here:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/parallel_for/gradients.py#L28
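
[Editor's note: a rough usage sketch for anyone who wants to try that module. The import path and the expected output shape are my assumptions based on that file and may differ between versions; treat this as a guess rather than settled API.]

import tensorflow as tf
# Experimental module; path and signature may change between releases.
from tensorflow.python.ops.parallel_for.gradients import jacobian

x = tf.constant([[1.0, 2.0],
                 [3.0, 4.0]])
y = tf.matmul(x, x)

J = jacobian(y, x)          # expected shape: y.shape + x.shape = (2, 2, 2, 2)

with tf.Session() as sess:
    print(sess.run(J).shape)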

Member

tensorflowbutler commented Aug 15, 2018

Nagging Assignee @skye: It has been 32 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.
