Error in hessians() and _hessian_vector_product() for tf.nn.sparse_softmax_cross_entropy_with_logits() #5876

Closed
kohpangwei opened this Issue Nov 27, 2016 · 16 comments


kohpangwei commented Nov 27, 2016

On a simple model that implements logistic regression, constructing the loss with tf.nn.sparse_softmax_cross_entropy_with_logits() makes both hessians() and _hessian_vector_product() return identically zero results, which is incorrect. If I instead write the loss function manually using tf.log, tf.sigmoid, etc., hessians() and _hessian_vector_product() return the correct answers. The two versions of the loss function agree on their values and their gradients; only the Hessian output differs.
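
For context, _hessian_vector_product() is roughly two chained tf.gradients calls, so a second derivative that silently comes out as zero propagates straight into its result. A minimal sketch of that construction, assuming the loss, weights, and v_placeholder tensors defined in the example further down:

# Hessian-vector product H·v via two tf.gradients calls (Pearlmutter-style).
# Assumes `loss`, `weights`, and `v_placeholder` from the minimal example below.
grads = tf.gradients(loss, [weights])                                  # first derivative
grad_dot_v = tf.reduce_sum(grads[0] * tf.stop_gradient(v_placeholder))
hvp = tf.gradients(grad_dot_v, [weights])                              # second derivative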

Here is some sample output:

Using sparse_softmax_cross_entropy_with_logits:
Loss before first step: 0.686726
Loss after first step : 0.686181
Actual diff in grad:
[ 0.000122    0.00014928]
Predicted diff in grad using _hessian_vector_product:
[array([ 0.,  0.], dtype=float32)]
Hessian:
[array([[ 0.,  0.],
       [ 0.,  0.]], dtype=float32)]

Using custom loss function:
Loss before first step: 0.686726
Loss after first step : 0.686181
Actual diff in grad:
[ 0.00012201  0.00014931]
Predicted diff in grad using _hessian_vector_product:
[array([ 0.00012199,  0.0001493 ], dtype=float32)]
Hessian:
[array([[ 0.08229966,  0.        ],
       [ 0.        ,  0.08278375]], dtype=float32)]

What related GitHub issues or StackOverflow threads have you found by searching the web for your problem?

None that I can find. The code below uses hessians() and _hessian_vector_product() from https://github.com/tensorflow/tensorflow/blob/a4c8df209d7413068f4ed3e71c43eb798fbd5580/tensorflow/python/ops/gradients_impl.py

Here is the PR that implemented hessians(): #5329

Environment info

Operating System: Ubuntu 16.04

Installed version of CUDA and cuDNN:

-rw-r--r-- 1 root root   558720 Oct  1 00:18 /usr/local/cuda/lib64/libcudadevrt.a
lrwxrwxrwx 1 root root       16 Oct  1 00:18 /usr/local/cuda/lib64/libcudart.so -> libcudart.so.8.0
lrwxrwxrwx 1 root root       19 Oct  1 00:18 /usr/local/cuda/lib64/libcudart.so.8.0 -> libcudart.so.8.0.44
-rwxr-xr-x 1 root root   415432 Oct  1 00:18 /usr/local/cuda/lib64/libcudart.so.8.0.44
-rw-r--r-- 1 root root   775162 Oct  1 00:18 /usr/local/cuda/lib64/libcudart_static.a
-rwxr-xr-x 1 root root 78065952 Oct  1 16:19 /usr/local/cuda/lib64/libcudnn.so
-rwxr-xr-x 1 root root 78065952 Oct  1 16:19 /usr/local/cuda/lib64/libcudnn.so.5
-rwxr-xr-x 1 root root 78065952 Oct  1 16:19 /usr/local/cuda/lib64/libcudnn.so.5.0.5
-rw-r--r-- 1 root root 68709594 Oct  1 16:19 /usr/local/cuda/lib64/libcudnn_static.a

The same behavior occurs when running on CPU only.

Installed from: https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.11.0-cp27-none-linux_x86_64.whl
v0.11.0

If possible, provide a minimal reproducible example (We usually don't have time to read hundreds of lines of your code)

import math

import numpy as np
import tensorflow as tf
# hessians() and _hessian_vector_product() live in the gradients_impl.py linked
# above; the import path below assumes that layout and may differ by TF version
# (the two functions can also simply be copied from that file).
from tensorflow.python.ops.gradients_impl import hessians, _hessian_vector_product

tf.set_random_seed(0)

### Setup toy data and weights
images_placeholder = tf.placeholder(tf.float32, shape=(3, 2))
labels_placeholder = tf.placeholder(tf.int32, shape=(3))
feed_dict = {
    images_placeholder: np.array([[0, 0], [0, 1], [1, 0]]),
    labels_placeholder: np.array([0, 1, 1]),
}
  
weights = tf.Variable(
  tf.truncated_normal([2],
                      stddev=1.0 / math.sqrt(float(2))),
  name='weights')

### Calculate loss using built-in TF function
weights_with_zeros = tf.pack([tf.zeros([2]), weights], axis=1)
logits = tf.matmul(images_placeholder, weights_with_zeros)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels_placeholder)
loss = tf.reduce_mean(cross_entropy)

### Calculate loss using manually constructed TF function
logits2 = tf.matmul(images_placeholder, tf.reshape(weights, [2, 1]))
labels2 = (tf.to_float(labels_placeholder) * 2) - 1
logits_mul_labels = tf.mul(tf.reshape(logits2, [-1]), tf.reshape(labels2, [-1]))
cross_entropy2 = - tf.log(tf.sigmoid(logits_mul_labels))
loss2 = tf.reduce_mean(cross_entropy2)

### Create train_op
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

### Calculate gradients, Hessians, and Hessian-vector products for both versions of loss
grad = tf.gradients(loss, [weights])
grad2 = tf.gradients(loss2, [weights])
v_placeholder = tf.placeholder(tf.float32, shape=weights.get_shape())
hessian_vector = _hessian_vector_product(loss, [weights], [v_placeholder])
hessian_vector2 = _hessian_vector_product(loss2, [weights], [v_placeholder])
hessian = hessians(loss, [weights])
hessian2 = hessians(loss2, [weights])

### Run training for a single step to get the parameters to change.
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

old_weights_val, old_loss_val, old_grad_val, old_loss2_val, old_grad2_val = sess.run(
  [weights, loss, grad, loss2, grad2], 
  feed_dict=feed_dict)

_ = sess.run(train_op, feed_dict=feed_dict)

new_weights_val, new_loss_val, new_grad_val, new_loss2_val, new_grad2_val = sess.run(
  [weights, loss, grad, loss2, grad2], 
  feed_dict=feed_dict)

hessian_val, hessian2_val = sess.run(
  [hessian, hessian2], 
  feed_dict=feed_dict)

### Calculate the actual difference in gradients before and after the train step,
### and compare with the predicted difference in gradients based on the Hessian.
diff_in_weights = new_weights_val - old_weights_val
actual_diff_in_grad = new_grad_val[0] - old_grad_val[0]
actual_diff_in_grad2 = new_grad2_val[0] - old_grad2_val[0]

feed_dict[v_placeholder] = diff_in_weights
predicted_diff_in_grad = sess.run(hessian_vector, feed_dict=feed_dict)
predicted_diff_in_grad2 = sess.run(hessian_vector2, feed_dict=feed_dict)

print('Diff in weights:\n%s' % diff_in_weights)

print('\nUsing sparse_softmax_cross_entropy_with_logits:')
print('Loss before first step: %s' % old_loss_val)
print('Loss after first step : %s' % new_loss_val)
print('Actual diff in grad:\n%s' % actual_diff_in_grad)
print('Predicted diff in grad using _hessian_vector_product:\n%s' % predicted_diff_in_grad)
print('Hessian:\n%s' % hessian_val)

print('\nUsing custom loss function:')
print('Loss before first step: %s' % old_loss2_val)
print('Loss after first step : %s' % new_loss2_val)
print('Actual diff in grad:\n%s' % actual_diff_in_grad2)
print('Predicted diff in grad using _hessian_vector_product:\n%s' % predicted_diff_in_grad2)
print('Hessian:\n%s' % hessian2_val)

sess.close()

What other attempted solutions have you tried?

Running in CPU or GPU makes no difference.

Using more complicated networks (e.g., adding some non-linear hidden layers before the final linear softmax step) makes the Hessian returned from sparse_softmax_cross_entropy_with_logits() non-zero, but it is still wrong in the sense that it does not match the empirically measured change in gradients. In contrast, the same custom loss function above returns the correct Hessians.

The same problem occurs when using "real" data (e.g., MNIST) or with more examples.

Logs or other output that would be helpful

Full output when using CPU:

Diff in weights:
[ 0.00148226  0.0018035 ]

Using sparse_softmax_cross_entropy_with_logits:
Loss before first step: 0.686726
Loss after first step : 0.686181
Actual diff in grad:
[ 0.000122    0.00014928]
Predicted diff in grad using _hessian_vector_product:
[array([ 0.,  0.], dtype=float32)]
Hessian:
[array([[ 0.,  0.],
       [ 0.,  0.]], dtype=float32)]

Using custom loss function:
Loss before first step: 0.686726
Loss after first step : 0.686181
Actual diff in grad:
[ 0.00012201  0.00014931]
Predicted diff in grad using _hessian_vector_product:
[array([ 0.00012199,  0.0001493 ], dtype=float32)]
Hessian:
[array([[ 0.08229966,  0.        ],
       [ 0.        ,  0.08278375]], dtype=float32)]
prb12 (Member) commented Nov 27, 2016

@goodfeli, @tillahoffmann, @vrv This is possibly related to #5329 which you were all involved with. Could somebody please comment on the above?

tillahoffmann (Contributor) commented Nov 27, 2016

Will have a look tomorrow.

tillahoffmann (Contributor) commented Nov 28, 2016

I won't have time to dig into this properly for a while but here's a hunch: The sparse_softmax_cross_entropy_with_logits presumably indexes the tensor of logits and some of the indexing operations don't have gradients defined (cf. #206). It may be possible to compute the first derivative of a function but the second derivative may not yet be implemented. @benoitsteiner and @yuefengz probably have a better understanding of the implementation.

goodfeli commented Nov 29, 2016

I agree with @tillahoffmann that this sounds like the gradient of sparse_softmax_cross_entropy_with_logits does not implement its own gradient correctly.

If this is indeed caused by an op failing to implement its gradient, it seems like tf.gradients ought to raise a NotImplementedError or issue some other error message, rather than silently returning numerically incorrect values.

Vincent Dumoulin and Alex Kurakin have both told me they have had trouble with ops silently returning zero as their second derivative in the past.

kohpangwei commented Nov 29, 2016

Thanks for looking into this! On further testing, it seems like the problem is not localized to sparse_softmax_cross_entropy_with_logits. Implementing the same example using softmax_cross_entropy_with_logits (i.e., changing the labels from [0, 1, 1] to [[1, 0], [0, 1], [0, 1]]) has the same issue of a zero second derivative. So the problem might be in the softmax function itself and not the sparse indexing.

goodfeli commented Nov 29, 2016

The softmax function shouldn't actually be involved in the implementation of either of those costs (the numerically stable way to implement them involves simplifying log softmax x to x - log sum exp x), so I doubt that it's a problem with the softmax op.
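
For reference, a quick numpy check of the simplification goodfeli describes (the input values here are just illustrative):

import numpy as np

x = np.array([2.0, -1.0, 0.5])
# Direct (numerically risky) log softmax.
direct = np.log(np.exp(x) / np.sum(np.exp(x)))
# Stable form: x - logsumexp(x), using the usual max-shift trick.
stable = x - (np.max(x) + np.log(np.sum(np.exp(x - np.max(x)))))
print(np.allclose(direct, stable))  # True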

kohpangwei commented Nov 29, 2016

Oh, right. Sorry, I was sloppy with that statement. I meant to say that the problem is probably in the (presumably shared) softmax_cross_entropy_with_logits part of the function and not in the additional indexing operations that sparse_softmax_cross_entropy_with_logits does on top of that, which @tillahoffmann suggested was the culprit.

tillahoffmann (Contributor) commented Nov 29, 2016

Hm, I'm computing Hessians of cross entropy losses in my work and don't seem to have a problem with them. Will have a look into a minimum working example and get back to you.

vrv (Contributor) commented Nov 29, 2016

According to the 0.12 release notes, some ops have had their second gradients fixed, so it might be worth upgrading to 0.12rc0 (just released last night) to see whether the problem persists.

kohpangwei commented Nov 30, 2016

Thanks for the tip! I just tested the above example in 0.12rc0 but unfortunately the behavior is unchanged (zeros for Hessians).

vrv added the bug label Dec 6, 2016

ebrevdo (Contributor) commented Dec 7, 2016

This is unfortunately a known bug, caused by the fact that the cross_entropy_loss ops are fused: they calculate the forward and bprop values at the same time and return them together (the bprop output is hidden from the user and is used only during the tf.gradients call). This doesn't work well with TF's gradient registration mechanism when you want the second derivative. Solutions may be forthcoming, but no promises on the timeline.

kohpangwei commented Dec 7, 2016

Ok. Thank you for the update!

goodfeli commented Dec 7, 2016

@ebrevdo Do you have a link to the github issue for this? More generally, is there a way to at least cause a NotImplementedError rather than returning the wrong numbers?

vrv (Contributor) commented Dec 7, 2016

@goodfeli this is the only github issue afaik (the other one is internal and doesn't have much more info).

I have one idea about how to get this to fail loudly when differentiating twice; I'll talk to Eugene today and see if it's possible in the short term. One workaround for now is to avoid the fused cross entropy functions and use their stable, primitive-op expansion instead. It will be slower, but at least your results will be more correct.
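
A sketch of what such a primitive-op expansion could look like, written against the TF 0.11-era API used in the repro above and reusing its logits, labels_placeholder, weights, and v_placeholder tensors; this is illustrative rather than the exact expansion vrv has in mind:

# Softmax cross-entropy built from primitive ops so tf.gradients can be
# applied twice (stable form: -sum(labels * (logits - logsumexp(logits)))).
labels_one_hot = tf.one_hot(labels_placeholder, depth=2)
max_logits = tf.reduce_max(logits, reduction_indices=[1], keep_dims=True)
log_sum_exp = max_logits + tf.log(
    tf.reduce_sum(tf.exp(logits - max_logits), reduction_indices=[1], keep_dims=True))
log_softmax = logits - log_sum_exp
unfused_cross_entropy = -tf.reduce_sum(labels_one_hot * log_softmax, reduction_indices=[1])
unfused_loss = tf.reduce_mean(unfused_cross_entropy)
# hessians(unfused_loss, [weights]) and
# _hessian_vector_product(unfused_loss, [weights], [v_placeholder])
# should now return non-trivial second derivatives.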

andrewharp pushed a commit to andrewharp/tensorflow that referenced this issue Dec 9, 2016

Fail if second gradient is called on (sparse_)softmax_cross_entropy_with_logits

since it is known to not be correct.

The construction of these fused loss ops does not work with our tf.gradients
interface when taking second derivatives (e.g., for hessians).

This introduces an op that produces the identity on the forward pass,
but has no gradient registered for the backward pass.  So the first
derivative passes, but the second one doesn't.

This op (PreventGradient) may prove useful in the future, so it seems okay
to add.

Partially addresses tensorflow#5876 in that it will no longer silently do the
wrong thing, though supporting the second derivative for this function
is something we want to do.
Change: 141578951
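
The following is a speculative illustration, not the actual TensorFlow change, of the mechanism the commit message describes: wrap the precomputed backprop tensor in an op that is the identity on the forward pass but whose gradient is deliberately unavailable, so the first tf.gradients call succeeds while a second one errors out instead of returning zeros. The names RaiseOnGrad and prevent_gradient_like below are made up for this sketch.

import tensorflow as tf
from tensorflow.python.framework import ops

# Register a gradient function that always fails; "RaiseOnGrad" is just a label
# for this sketch, not a real TF op type.
@ops.RegisterGradient("RaiseOnGrad")
def _raise_on_grad(op, grad):
    raise LookupError("Gradient is intentionally unavailable for %s" % op.name)

def prevent_gradient_like(x, name=None):
    # Identity on the forward pass; differentiating through it raises.
    g = tf.get_default_graph()
    with g.gradient_override_map({"Identity": "RaiseOnGrad"}):
        return tf.identity(x, name=name)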
vrv (Contributor) commented Feb 1, 2017

Okay, with df52532 we now fail loudly when you try to take second derivatives through ops that don't support them. I think we're just going to rely on removing these fused ops via XLA one day as the eventual solution.

vrv closed this Feb 1, 2017

kohpangwei commented Feb 2, 2017

Cool. Thanks!
