tf.matmul gives wrong result on CPUs with avx512_vnni #61510
Description
Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
No
Source
binary
TensorFlow version
2.13.0
Custom code
No
OS platform and distribution
Linux
Mobile device
No response
Python version
No response
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current behavior?
When running on a CPU with avx512_vnni instructions (e.g. Xeon Platinum 8260), tf.matmul in 2.13 gives a completely wrong result that also changes from run to run. Other functions, e.g. tf.einsum, are affected too. A reproducer is included below. The same code works correctly in 2.12.
I believe this is due to a bug in oneDNN, since running with TF_ENABLE_ONEDNN_OPTS=0 restores the correct behaviour.
Could you please try building TF against the latest oneDNN to see if this bug is already fixed there, and either upgrade oneDNN or revert to the version used in 2.12? If the bug is still present in the latest oneDNN, could you also forward this issue to them so that they can fix it? tf.matmul is a fundamental part of TensorFlow, so it would be great to have a fix as soon as possible (and perhaps further tests to prevent this kind of bug from recurring).
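For anyone hitting this in the meantime, here is a minimal sketch of the workaround mentioned above: check whether the CPU advertises avx512_vnni and, if so, set TF_ENABLE_ONEDNN_OPTS=0 before TensorFlow is imported. The helper names (cpu_has_flag, read_cpuinfo, disable_onednn_if_vnni) are made up for illustration, and the sketch assumes Linux, where CPU flags are listed in /proc/cpuinfo. It does not import TensorFlow itself.

```python
import os


def cpu_has_flag(flag, cpuinfo_text):
    """Return True if `flag` is listed on any 'flags' line of /proc/cpuinfo-style text."""
    for line in cpuinfo_text.splitlines():
        key, _, value = line.partition(":")
        if key.strip() == "flags" and flag in value.split():
            return True
    return False


def read_cpuinfo(path="/proc/cpuinfo"):
    """Read the cpuinfo file; return an empty string if it is not available."""
    try:
        with open(path) as f:
            return f.read()
    except OSError:
        return ""


def disable_onednn_if_vnni(cpuinfo_text=None):
    """Set TF_ENABLE_ONEDNN_OPTS=0 when the CPU reports avx512_vnni.

    Returns True if the variable was set. This only has an effect if it
    runs before `import tensorflow`.
    """
    if cpuinfo_text is None:
        cpuinfo_text = read_cpuinfo()
    if cpu_has_flag("avx512_vnni", cpuinfo_text):
        os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
        return True
    return False


# Intended usage (hypothetical script layout):
#   disable_onednn_if_vnni()
#   import tensorflow as tf  # must come after the environment variable is set
```

Setting the variable in the shell (TF_ENABLE_ONEDNN_OPTS=0 python script.py) has the same effect and avoids any import-order concerns.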
Standalone code to reproduce the issue
import tensorflow as tf
tf.keras.utils.set_random_seed(1)
length = 5000
x = tf.concat([tf.ones([length, 1]), tf.random.normal([length, 2])], axis=1)
x = tf.tile(x[None, ...], [3, 1, 1])
xx = tf.matmul(x, x, transpose_a=True)
# xx = tf.einsum("ijk,ijm->ikm", x, x) # Also doesn't work
print(f"{xx.numpy()}")
Relevant log output
[[[ 2.0936000e+04 4.1334631e+02 8.8164221e+02]
[ 4.1334631e+02 2.0951623e+04 5.5098944e+02]
[ 8.1284619e+02 4.7815466e+02 1.9981070e+04]]
[[ 0.0000000e+00 6.8663625e-44 0.0000000e+00]
[ 2.7628342e-35 0.0000000e+00 2.6752920e-35]
[ 0.0000000e+00 2.3822074e-44 0.0000000e+00]]
[[-7.9164143e+31 8.9978802e+02 1.7936620e-43]
[ 0.0000000e+00 1.1210388e-43 0.0000000e+00]
[ 2.7767702e-35 0.0000000e+00 7.0064923e-45]]]
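To make explicit why the output above is wrong, here is a rough sanity check written in plain Python, so it does not depend on the library under suspicion. It relies only on properties of the reproducer: x is the same matrix tiled three times, so all three batch slices of xx should match; each slice is x^T x, so it should be symmetric; and the first column of x is all ones, so entry [0][0] should equal length (5000). The function name gram_problems and the tolerance are illustrative choices, not part of TensorFlow.

```python
def gram_problems(xx, length, tol=1e-3):
    """List violations of the properties x^T x should have in the reproducer.

    xx is a nested list of shape [batch][k][k], e.g. xx.numpy().tolist().
    """
    def close(a, b):
        # Relative comparison with a floor of 1.0 to tolerate float32 rounding.
        return abs(a - b) <= tol * max(abs(a), abs(b), 1.0)

    problems = []
    first = xx[0]
    for b, m in enumerate(xx):
        k = len(m)
        # Column 0 of x is all ones, so its squared norm is exactly `length`.
        if not close(m[0][0], float(length)):
            problems.append(f"batch {b}: [0][0] is {m[0][0]}, expected {length}")
        # x^T x is symmetric by construction.
        for i in range(k):
            for j in range(i + 1, k):
                if not close(m[i][j], m[j][i]):
                    problems.append(f"batch {b}: not symmetric at ({i}, {j})")
        # Every batch slice is computed from the same tiled matrix.
        if b > 0 and any(
            not close(m[i][j], first[i][j]) for i in range(k) for j in range(k)
        ):
            problems.append(f"batch {b}: differs from batch 0")
    return problems


# First slice of the log output above, repeated to form a 3-slice batch.
logged_first_slice = [
    [2.0936000e04, 4.1334631e02, 8.8164221e02],
    [4.1334631e02, 2.0951623e04, 5.5098944e02],
    [8.1284619e02, 4.7815466e02, 1.9981070e04],
]
for problem in gram_problems([logged_first_slice] * 3, length=5000):
    print(problem)
```

Even the first slice, which looks plausible at a glance, fails the check: its [0][0] entry is about 2.09e4 instead of 5000, and it is not symmetric.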