diff --git a/official/mnist/dataset.py b/official/mnist/dataset.py index 0ba7c1a8bda..6b12d7eda41 100644 --- a/official/mnist/dataset.py +++ b/official/mnist/dataset.py @@ -97,7 +97,7 @@ def decode_image(image): def decode_label(label): label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8] label = tf.reshape(label, []) # label is a scalar - return tf.to_int32(label) + return tf.cast(label, tf.int32) images = tf.data.FixedLengthRecordDataset( images_file, 28 * 28, header_bytes=16).map(decode_image)