Closed
Labels
TF 2.13, comp:ops, stale, stat:awaiting response, type:bug
Description
Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
Yes
Source
source
TensorFlow version
2.13 / 2.14.0-dev20230706
Custom code
Yes
OS platform and distribution
Linux Ubuntu 20.04
Mobile device
No response
Python version
3.8
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
```
$ mamba list | grep cu
cuda-nvcc                 12.2.128      0              nvidia
cudatoolkit               11.8.0        h4ba93d1_12    conda-forge
nvidia-cublas-cu11        2022.4.8      pypi_0         pypi
nvidia-cublas-cu117       11.10.1.25    pypi_0         pypi
nvidia-cuda-nvrtc-cu11    2022.4.8      pypi_0         pypi
nvidia-cuda-nvrtc-cu117   11.7.50       pypi_0         pypi
nvidia-cudnn-cu11         8.9.4.25      pypi_0         pypi
```
GPU model and memory
NVIDIA GeForce RTX 3060 - 12 GB
Driver Version: 520.61.05
Current behavior?
I'm trying to do a batch matrix multiply: I have a batch of m x n matrices stored in one tensor, and I want to multiply each of them by the same other matrix. However, the results differ slightly from NumPy's and from those of previous TensorFlow versions (this script passed for me on 2.9).
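For reference, the operation described here is a broadcasting matmul: a (batch, m, n) stack times a single shared (n, p) matrix. The minimal NumPy sketch below shows three ways to write it that should agree to within float32 rounding; the batch size of 3 and the helper name `batch_matmul_loop` are illustrative assumptions, not taken from the report.

```python
import numpy as np

def batch_matmul_loop(x, c):
    """Hypothetical helper: multiply each (m, n) slice of x by the shared (n, p) matrix c."""
    return np.stack([xi @ c for xi in x])

rng = np.random.RandomState(0)
x = rng.uniform(-1, 1, size=(3, 25, 5)).astype(np.float32)  # batch of 3 matrices, each 25 x 5
c = rng.uniform(-1, 1, size=(5, 2)).astype(np.float32)      # one shared 5 x 2 matrix

broadcast = x @ c                           # np.matmul broadcasts c across the batch axis
explicit = np.einsum("btq,qr->btr", x, c)   # the same contraction with explicit indices
looped = batch_matmul_loop(x, c)            # one 2-D matmul per batch element

print(broadcast.shape)  # prints (3, 25, 2)
np.testing.assert_allclose(broadcast, looped, rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(broadcast, explicit, rtol=1e-5, atol=1e-6)
```

`tf.matmul(x, c)` on the same arrays is expected to follow these semantics, which is what the repro script relies on.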
Notably, the value 25 (the second dimension of `x`) matters: if I reduce it to 16 or below, the script passes. It also passes if I change the second dimension of `c` from 2 to 1.
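The shape dependence above can be mapped more systematically. Below is a minimal NumPy sketch of a sweep harness; `failing_shapes` is a hypothetical helper (not part of NumPy or TensorFlow) that builds the same kind of tiled batch as the repro script for each (rows, columns) pair and records the pairs where a supplied matmul function disagrees with NumPy beyond `rtol`/`atol`. The batch size of 2, inner dimension of 5 and tolerances mirror the script; the rest is an assumption.

```python
import numpy as np

def failing_shapes(matmul_fn, rows_options, cols_options, batch=2, inner=5,
                   rtol=1e-5, atol=1e-7, seed=0):
    """Hypothetical helper: return (rows, cols) pairs where matmul_fn disagrees with NumPy.

    matmul_fn takes float32 arrays of shape (batch, rows, inner) and (inner, cols)
    and returns an array of shape (batch, rows, cols).
    """
    rng = np.random.RandomState(seed)
    failures = []
    for rows in rows_options:
        for cols in cols_options:
            x = rng.uniform(-1, 1, size=(1, rows, inner)).astype(np.float32)
            c = rng.uniform(-1, 1, size=(inner, cols)).astype(np.float32)
            x_batch = np.tile(x, (batch, 1, 1))  # identical copies along the batch axis, as in the repro
            expected = x_batch @ c               # NumPy float32 result used as the reference
            actual = np.asarray(matmul_fn(x_batch, c))
            if not np.allclose(actual, expected, rtol=rtol, atol=atol):
                failures.append((rows, cols))
    return failures

# Sanity check: NumPy compared with itself should report no failing shapes.
print(failing_shapes(np.matmul, range(1, 33), [1, 2]))  # prints []
```

With TensorFlow installed, passing `lambda a, b: tf.matmul(a, b).numpy()` as `matmul_fn` would show whether the threshold described above (more than 16 rows, more than one output column) appears on a given machine.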
Standalone code to reproduce the issue
```python
import numpy as np
import tensorflow as tf

print(f"file: {tf.__file__}")
print(f"git version: {tf.version.GIT_VERSION}")
print(f"version: {tf.__version__}")

rng = np.random.RandomState(0)
x = rng.uniform(-1, 1, size=(1, 25, 5)).astype(np.float32)
c = rng.uniform(-1, 1, size=(5, 2)).astype(np.float32)
y = x @ c

# Batch of two identical copies of x; NumPy gives the same result for each copy.
x2 = np.tile(x, (2, 1, 1))
y2 = x2 @ c
for y2i in y2:
    np.testing.assert_allclose(y2i, y.squeeze(0))

tols = dict(atol=1e-7, rtol=1e-5)

# Unbatched TensorFlow result vs. NumPy
z = tf.matmul(x, c).numpy()
# z = tf.einsum("...tq,...qr->...tr", x, c)
np.testing.assert_allclose(z, y, **tols)

# Batched TensorFlow result vs. NumPy (the comparison that fails in the log below)
z2 = tf.matmul(x2, c).numpy()
# z2 = tf.einsum("...tq,...qr->...tr", x2, c)
np.testing.assert_allclose(z2, y2, **tols)
```
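Since the NumPy reference is itself a float32 computation, one way to tell which side is off is to compare each result against a float64 reference built from the same float32 inputs. The sketch below does this for NumPy only; `max_abs_error` is a hypothetical helper, and the TensorFlow line is left as a comment because it needs a TensorFlow install.

```python
import numpy as np

def max_abs_error(result, x, c):
    """Hypothetical helper: max |result - reference|, with the reference computed in float64."""
    reference = x.astype(np.float64) @ c.astype(np.float64)
    return float(np.max(np.abs(np.asarray(result, dtype=np.float64) - reference)))

# Same inputs as the repro script.
rng = np.random.RandomState(0)
x = rng.uniform(-1, 1, size=(1, 25, 5)).astype(np.float32)
c = rng.uniform(-1, 1, size=(5, 2)).astype(np.float32)
x2 = np.tile(x, (2, 1, 1))

numpy_err = max_abs_error(x2 @ c, x2, c)
print(f"NumPy float32 matmul, max abs error vs float64: {numpy_err:.2e}")

# With TensorFlow available, the same check could be applied to its result, e.g.:
# tf_err = max_abs_error(tf.matmul(x2, c).numpy(), x2, c)
```

For sums of five products of values in (-1, 1), float32 rounding error should stay well below the 0.00046 maximum absolute difference in the log, so a TensorFlow error of that size would point to the TensorFlow side.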
Relevant log output
```
file: /home/ehunsber/mambaforge/envs/tf213/lib/python3.8/site-packages/tensorflow/__init__.py
git version: v1.12.1-96406-gfa4d29bfef8
version: 2.14.0-dev20230706
Traceback (most recent call last):
  File "test_batch.py", line 26, in <module>
    np.testing.assert_allclose(z2, y2)
  File "/home/ehunsber/mambaforge/envs/tf213/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 1592, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/ehunsber/mambaforge/envs/tf213/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ehunsber/mambaforge/envs/tf213/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 862, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0
Mismatched elements: 100 / 100 (100%)
Max absolute difference: 0.00046098
Max relative difference: 0.01047752
 x: array([[[-0.187503,  0.005696],
        [-0.255552, -0.84404 ],
        [ 0.268836, -1.250793],...
 y: array([[[-0.187499,  0.005691],
        [-0.255404, -0.844322],
        [ 0.268935, -1.251033],...
```
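In this output, `x` is the actual array (TensorFlow's `z2`) and `y` is the desired one (NumPy's `y2`). The sketch below recomputes the statistics `assert_allclose` reports, using its documented criterion |actual - desired| <= atol + rtol * |desired|, for the six values visible in the truncated log. `mismatch_summary` is a hypothetical helper; because only 6 of the 100 elements are shown, its maxima will not match the full summary above.

```python
import numpy as np

def mismatch_summary(actual, desired, rtol=1e-7, atol=0.0):
    """Hypothetical helper: (mismatch count, max abs diff, max rel diff), assert_allclose-style."""
    actual = np.asarray(actual, dtype=np.float64)
    desired = np.asarray(desired, dtype=np.float64)
    abs_diff = np.abs(actual - desired)
    # An element mismatches when it exceeds atol + rtol * |desired|.
    mismatched = int(np.count_nonzero(abs_diff > atol + rtol * np.abs(desired)))
    rel_diff = abs_diff / np.abs(desired)  # assumes desired has no exact zeros
    return mismatched, float(abs_diff.max()), float(rel_diff.max())

# First three rows of each array, copied from the truncated log.
tf_rows = [[-0.187503, 0.005696], [-0.255552, -0.84404], [0.268836, -1.250793]]
numpy_rows = [[-0.187499, 0.005691], [-0.255404, -0.844322], [0.268935, -1.251033]]

# Default tolerances (as in the traceback) and the looser ones from the script.
for rtol, atol in [(1e-7, 0.0), (1e-5, 1e-7)]:
    count, max_abs, max_rel = mismatch_summary(tf_rows, numpy_rows, rtol, atol)
    print(f"rtol={rtol:g} atol={atol:g}: {count}/6 mismatched, "
          f"max abs {max_abs:.2e}, max rel {max_rel:.2e}")
```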