
Inconsistency in CPU results and GPU results in model training #67137

Closed
Liqi1003 opened this issue May 8, 2024 · 4 comments
Labels
comp:gpu GPU related issues stat:awaiting response Status - Awaiting response from author TF 2.13 For issues related to Tensorflow 2.13 type:bug Bug

Comments


Liqi1003 commented May 8, 2024

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

No

Source

binary

TensorFlow version

tf 2.13.0

Custom code

Yes

OS platform and distribution

Linux Ubuntu 20.04.5

Mobile device

No response

Python version

3.8.10

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

11.8/8.7

GPU model and memory

No response

Current behavior?

I am reporting an issue encountered during distributed training of a model across different types of devices using TensorFlow. I initially encountered the bug with multiple GPUs involved, but was able to reproduce it in a single-GPU case. The TensorFlow version used is 2.13.0.

This is very likely an edge case: the inconsistency only occurs with the specific initial weights and inputs we provide. To make the difference more apparent given a limited amount of training data, we deliberately chose a relatively high learning rate (lr=10.0).

Before executing the code, place the model file in the same directory as the reproduction script so that the model weights can be loaded. Loading these weights is important, as randomly initialized weights do not reproduce the bug.

Standalone code to reproduce the issue

import sys
import keras
import random
import numpy as np
import tensorflow as tf
from keras import layers
from keras import optimizers

print("python version:", sys.version)
print("tensorflow version:", tf.__version__)
print("numpy version:", np.__version__)

def train_using_strategy(strategy, train_input, train_label, test_input):
    # Load model
    with strategy.scope():
        model = keras.models.load_model("./model.keras")
        optimizer = optimizers.SGD(learning_rate=10.0)
        loss = tf.keras.losses.CategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)

        model.compile(optimizer=optimizer, loss=loss)

    # Train for 1 step
    model.fit(
        train_input, train_label, verbose=0, shuffle=False, batch_size=2400)

    pred = model(test_input)
    pred = tf.nn.softmax(pred)
    return pred


# Set random seeds
seed = 54078027
random.seed(seed)
tf.random.set_seed(seed)

# Training data, batch_size=2400
train_input = tf.random.uniform(shape=(2400, 32, 32, 3))
train_label = tf.one_hot(tf.random.uniform(
    shape=(2400, ), minval=0, maxval=10, dtype=tf.int32), 10)
test_input = tf.random.uniform(shape=(1, 32, 32, 3))
test_label = tf.one_hot(tf.random.uniform(
    shape=(1, ), minval=0, maxval=10, dtype=tf.int32), 10)
        
# Original model
layer_0 = layers.Input(shape=(32, 32, 3,))
layer_1 = layers.Conv2D(
    filters=5,
    kernel_size=(13, 13),
    strides=(1, 1),
    padding="valid",
    data_format="channels_last",
    dilation_rate=(1, 1),
    activation="tanh",
    use_bias=True,
    kernel_initializer="random_uniform",
    bias_initializer="random_uniform")(layer_0)
layer_2 = layers.ReLU(max_value=0.08354582293069757)(layer_1)
layer_3 = layers.Flatten()(layer_2)
layer_4 = layers.Dense(
    units=10,
    activation="linear",
    use_bias=False,
    kernel_initializer="random_uniform",
    bias_initializer="random_uniform")(layer_3)
layer_5 = layers.Reshape((10,))(layer_4)

model = keras.Model(layer_0, layer_5)
model.summary()

model.load_weights("./tensorflow.h5")

# Alternatively, load the model directly
# model = keras.models.load_model("./tensorflow.h5")

keras.models.save_model(model, "./model.keras")

res_cpu = train_using_strategy(
    strategy=tf.distribute.MirroredStrategy(devices=["/CPU:0"]), 
    train_input=train_input, 
    train_label=train_label, 
    test_input=test_input)

res_gpu = train_using_strategy(
    strategy=tf.distribute.MirroredStrategy(devices=["/GPU:0"]), 
    train_input=train_input, 
    train_label=train_label, 
    test_input=test_input)

print("max diff:", np.max(np.abs(res_cpu - res_gpu)))
print("result on CPU device:", res_cpu)
print("result on GPU device:", res_gpu)

Relevant log output

2024-05-08 01:19:26.978000: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-08 01:19:28.059507: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
python version: 3.8.10 (default, Nov 22 2023, 10:22:35) 
[GCC 9.4.0]
tensorflow version: 2.13.0
numpy version: 1.23.5
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU')]
2024-05-08 01:19:29.719405: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10696 MB memory:  -> device: 0, name: NVIDIA TITAN V, pci bus id: 0000:af:00.0, compute capability: 7.0
2024-05-08 01:19:29.720298: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 11539 MB memory:  -> device: 1, name: NVIDIA TITAN Xp, pci bus id: 0000:3b:00.0, compute capability: 6.1
2024-05-08 01:19:29.721009: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 11539 MB memory:  -> device: 2, name: NVIDIA TITAN Xp, pci bus id: 0000:5e:00.0, compute capability: 6.1
2024-05-08 01:19:29.721728: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 11539 MB memory:  -> device: 3, name: NVIDIA TITAN Xp, pci bus id: 0000:86:00.0, compute capability: 6.1
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 conv2d (Conv2D)             (None, 20, 20, 5)         2540      
                                                                 
 re_lu (ReLU)                (None, 20, 20, 5)         0         
                                                                 
 flatten (Flatten)           (None, 2000)              0         
                                                                 
 dense (Dense)               (None, 10)                20000     
                                                                 
 reshape (Reshape)           (None, 10)                0         
                                                                 
=================================================================
Total params: 22540 (88.05 KB)
Trainable params: 22540 (88.05 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
2024-05-08 01:19:30.445908: W tensorflow/core/framework/dataset.cc:956] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
2024-05-08 01:19:31.126108: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f57a8015840 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2024-05-08 01:19:31.126233: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2024-05-08 01:19:31.154914: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
2024-05-08 01:19:31.408683: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:432] Loaded cuDNN version 8700
2024-05-08 01:19:32.675226: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f5c85a4c040 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-05-08 01:19:32.675279: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA TITAN V, Compute Capability 7.0
2024-05-08 01:19:32.675292: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (1): NVIDIA TITAN Xp, Compute Capability 6.1
2024-05-08 01:19:32.675304: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (2): NVIDIA TITAN Xp, Compute Capability 6.1
2024-05-08 01:19:32.675316: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (3): NVIDIA TITAN Xp, Compute Capability 6.1

max diff: 0.01396697
result on CPU device: tf.Tensor(
[[1.7262544e-03 1.6841354e-02 1.6154546e-02 4.0267220e-01 1.6546119e-03
  7.9213362e-14 4.0498021e-01 1.1821460e-01 1.8025419e-02 1.9730769e-02]], 
  shape=(1, 10), dtype=float32)
result on GPU device: tf.Tensor(
[[1.73812674e-03 1.88294742e-02 1.49312345e-02 4.12326276e-01 1.74528488e-03
  7.97685756e-14 4.08124179e-01 1.04247630e-01 1.85339060e-02 1.95238516e-02]], 
  shape=(1, 10), dtype=float32)
@google-ml-butler google-ml-butler bot added the type:bug Bug label May 8, 2024
@SuryanarayanaY SuryanarayanaY added comp:gpu GPU related issues TF 2.13 For issues related to Tensorflow 2.13 labels May 8, 2024
SuryanarayanaY (Collaborator) commented May 8, 2024

Hi @Liqi1003 ,

The difference appears to be precision-related, caused by XLA fusion. Please see the developer comment on the same topic for more details.
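(Illustrative only, not part of the original comment: one way to gauge whether the reported gap is consistent with float32 rounding is to look at the relative rather than absolute difference between the two predictions, using the res_cpu and res_gpu tensors from the reproduction script.)

# Hypothetical check on the reporter's outputs; res_cpu / res_gpu come from the script above.
cpu = res_cpu.numpy()
gpu = res_gpu.numpy()
diff = np.abs(cpu - gpu)
rel = diff / np.maximum(np.abs(cpu), np.finfo(np.float32).eps)
print("max abs diff:", diff.max(), "max rel diff:", rel.max())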

@SuryanarayanaY SuryanarayanaY added the stat:awaiting response Status - Awaiting response from author label May 8, 2024
Liqi1003 (Author) commented May 9, 2024

Hi @SuryanarayanaY,

Thanks for the pointer!

To test whether the problem is caused by XLA, I tried disabling XLA by adding TF_XLA_FLAGS=--tf_xla_auto_jit=-1 before the command, as mentioned here. However, the log still indicates that XLA compilation is enabled, and the outputs are the same as the ones I posted above. Am I disabling XLA the wrong way?
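(A minimal sketch, not from the original comment, of how the flag is usually applied: it has to be in the environment before TensorFlow is imported, so a shell prefix works but an assignment made after import tensorflow is read too late. The jit_compile argument to model.compile() is a separate, additional switch.)

import os
# Must run before tensorflow is imported, otherwise the flag is read too late.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=-1"

import tensorflow as tf
# Belt and braces: also ask the graph optimizer not to auto-cluster for XLA,
# and pass jit_compile=False to model.compile() in the reproduction script.
tf.config.optimizer.set_jit(False)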

Also, I was only able to reproduce this problem with TensorFlow 2.13: running the same code on Google Colab with TensorFlow 2.16.1 does not produce this inconsistency. Here is the colab.

I wonder whether this is a bug that was silently fixed in later versions? Thanks!

@google-ml-butler google-ml-butler bot removed the stat:awaiting response Status - Awaiting response from author label May 9, 2024
SuryanarayanaY (Collaborator) commented May 9, 2024

Hi @Liqi1003 ,

It seems that 2.16 has better precision than earlier versions; there may have been some internal changes that I am not aware of.
Since this is resolved in the latest version, can we mark the issue as closed? Thanks!

@SuryanarayanaY SuryanarayanaY added the stat:awaiting response Status - Awaiting response from author label May 9, 2024