sparse_softmax_cross_entropy_with_logits returns NaN instead of raising an error when the class int is not in the range [0, logit_size) when run on GPU #8484
Description
If you pass tf.nn.sparse_softmax_cross_entropy_with_logits() labels that fall outside the number of classes (that is, the length of the logits, or the length of their second dimension when the first dimension is used for batching), you get a loss of NaN, and then everything in the model becomes NaN. I spent quite some time debugging this in my model, and I believe it should not fail silently: there should be at least a warning at execution time.
One more detail I just discovered: if you run it on CPU, an InvalidArgumentError is raised; if you run it on GPU, you get NaN. So it may be enough to fix only the GPU implementation to behave like the CPU one.
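Until the GPU behavior changes, one possible workaround is to validate the labels on the host before feeding them. The sketch below is a minimal NumPy-only check; the helper name `check_sparse_labels` and its error messages are hypothetical and not part of TensorFlow.

```python
import numpy as np


def check_sparse_labels(labels, num_classes):
    """Raise ValueError if any label lies outside [0, num_classes).

    Hypothetical helper, not a TensorFlow API: it only inspects the
    NumPy array that would be passed in feed_dict.
    """
    labels = np.asarray(labels)
    if not np.issubdtype(labels.dtype, np.integer):
        raise ValueError("labels must have an integer dtype, got %s" % labels.dtype)
    # Flat positions of every label that is negative or >= num_classes.
    bad = np.flatnonzero((labels < 0) | (labels >= num_classes))
    if bad.size:
        raise ValueError(
            "%d label(s) outside [0, %d); first offending index: %d"
            % (bad.size, num_classes, bad[0]))
    return labels
```

Calling `check_sparse_labels(ys, classes)` right before `session.run` would then fail loudly on both CPU and GPU instead of letting NaN propagate through the model.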
Environment info
Operating System: Ubuntu 16.04 64bit
Installed version of CUDA and cuDNN: cuda 8.0 cudnn 5.1
Tensorflow version: 1.0.0
If possible, provide a minimal reproducible example (We usually don't have time to read hundreds of lines of your code)
Run it with CUDA_VISIBLE_DEVICES="" to force CPU execution and see the difference between CPU and GPU.
import numpy as np
import tensorflow as tf

classes = 5
num_datapoints = 100
xs = np.random.rand(num_datapoints, 4)
# Intentionally draws labels from [0, classes], so some equal `classes`
# and are out of range for logits with `classes` columns.
ys = np.random.randint(classes + 1, size=num_datapoints)

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 4])
    y = tf.placeholder(tf.int32, shape=[None, ])
    w = tf.Variable(tf.random_normal([4, classes]))
    b = tf.Variable(tf.random_normal([classes]))
    logits = tf.matmul(x, w) + b
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y))
    optimizer = tf.train.AdamOptimizer(0.01).minimize(loss)

with tf.Session(graph=graph) as session:
    tf.global_variables_initializer().run()
    for step in range(201):
        _, loss_val = session.run(
            [optimizer, loss],
            feed_dict={x: xs, y: ys})
        print("step {step} - loss: {loss_val:.6f}".format(step=step, loss_val=loss_val))
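For cross-checking the values the op returns, a small NumPy reference of the same loss can help. The sketch below computes the per-example loss as `-log(softmax(logits)[label])` and raises on out-of-range labels. The function name `sparse_softmax_xent_reference` is hypothetical, and this is not how TensorFlow's kernels are implemented; it assumes 2-D logits of shape `[batch, num_classes]`.

```python
import numpy as np


def sparse_softmax_xent_reference(logits, labels):
    """Per-example sparse softmax cross-entropy, computed in NumPy.

    Hypothetical reference for comparison only; assumes `logits` has
    shape [batch, num_classes] and `labels` has shape [batch].
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    num_classes = logits.shape[1]
    if np.any((labels < 0) | (labels >= num_classes)):
        raise ValueError("labels must lie in [0, %d)" % num_classes)
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the labelled class in each row.
    return -log_probs[np.arange(labels.shape[0]), labels]
```

With the arrays from the example above, evaluating `logits` through `session.run` and passing the result together with `ys` to this function should raise a `ValueError` as soon as a label equals `classes`, regardless of the device the graph ran on.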