
Model not deterministic, even though os.environ['TF_DETERMINISTIC_OPS'] = '1' is set #38197

Closed
Zethson opened this issue Apr 3, 2020 · 20 comments
Labels: comp:gpu (GPU related issues) · comp:keras (Keras related issues) · TF 2.2 (Issues related to TF 2.2) · type:bug (Bug)

Comments


Zethson commented Apr 3, 2020

System information

  • Have I written custom code (as opposed to using a stock
    example script provided in TensorFlow): Pretty much the MirroredStrategy fmnist example
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): tensorflow/tensorflow:2.2.0rc2-gpu-py3
  • TensorFlow installed from (source or
    binary): tensorflow/tensorflow:2.2.0rc2-gpu-py3
  • TensorFlow version (use command below): tensorflow/tensorflow:2.2.0rc2-gpu-py3
  • Python version: tensorflow/tensorflow:2.2.0rc2-gpu-py3
  • CUDA/cuDNN version: tensorflow/tensorflow:2.2.0rc2-gpu-py3
  • GPU model and memory: 1050M

Describe the current behavior
Model is not deterministic/reproducible.
Two runs:

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 2s 0us/step
Epoch 1, Loss: 0.17844311892986298, Accuracy: 0.9466999769210815,Test Loss: 0.057941436767578125, Test Accuracy: 0.9815000295639038
Epoch 2, Loss: 0.05286668613553047, Accuracy: 0.9836500287055969,Test Loss: 0.044471099972724915, Test Accuracy: 0.9853000044822693
Epoch 3, Loss: 0.03694676235318184, Accuracy: 0.9883000254631042,Test Loss: 0.034947194159030914, Test Accuracy: 0.9897000193595886
Epoch 4, Loss: 0.028592929244041443, Accuracy: 0.9910500049591064,Test Loss: 0.027234185487031937, Test Accuracy: 0.9907000064849854
Epoch 5, Loss: 0.022629836574196815, Accuracy: 0.9927666783332825,Test Loss: 0.029115190729498863, Test Accuracy: 0.9904000163078308
Epoch 6, Loss: 0.0172086451202631, Accuracy: 0.9944999814033508,Test Loss: 0.027797872200608253, Test Accuracy: 0.9902999997138977
Epoch 7, Loss: 0.013981950469315052, Accuracy: 0.9956499934196472,Test Loss: 0.02764272689819336, Test Accuracy: 0.9909999966621399
Epoch 8, Loss: 0.01210874691605568, Accuracy: 0.9961333274841309,Test Loss: 0.035009630024433136, Test Accuracy: 0.9896000027656555
Epoch 9, Loss: 0.008961305022239685, Accuracy: 0.9971666932106018,Test Loss: 0.034057389944791794, Test Accuracy: 0.9905999898910522
Epoch 10, Loss: 0.00800476036965847, Accuracy: 0.9972166419029236,Test Loss: 0.033878158777952194, Test Accuracy: 0.9900000095367432
GPU Run Time: 70.80781483650208 seconds
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 2s 0us/step
Epoch 1, Loss: 0.1761329025030136, Accuracy: 0.9478499889373779,Test Loss: 0.05224931612610817, Test Accuracy: 0.9835000038146973
Epoch 2, Loss: 0.05251472815871239, Accuracy: 0.9836666584014893,Test Loss: 0.04059470072388649, Test Accuracy: 0.9860000014305115
Epoch 3, Loss: 0.03771379590034485, Accuracy: 0.98785001039505,Test Loss: 0.03189479187130928, Test Accuracy: 0.9894000291824341
Epoch 4, Loss: 0.027971116825938225, Accuracy: 0.9912333488464355,Test Loss: 0.03176414594054222, Test Accuracy: 0.9890000224113464
Epoch 5, Loss: 0.022653400897979736, Accuracy: 0.9925000071525574,Test Loss: 0.03643624112010002, Test Accuracy: 0.9876999855041504
Epoch 6, Loss: 0.01727919466793537, Accuracy: 0.9942166805267334,Test Loss: 0.02887595444917679, Test Accuracy: 0.9901000261306763
Epoch 7, Loss: 0.01397143118083477, Accuracy: 0.9957500100135803,Test Loss: 0.03118096850812435, Test Accuracy: 0.9905999898910522
Epoch 8, Loss: 0.01202292088419199, Accuracy: 0.9961333274841309,Test Loss: 0.03164077177643776, Test Accuracy: 0.9909999966621399
Epoch 9, Loss: 0.008715414442121983, Accuracy: 0.9971333146095276,Test Loss: 0.04146642982959747, Test Accuracy: 0.9896000027656555
Epoch 10, Loss: 0.008586470037698746, Accuracy: 0.9969000220298767,Test Loss: 0.033046264201402664, Test Accuracy: 0.9902999997138977
GPU Run Time: 72.08828902244568 seconds

Describe the expected behavior
I expect the model to be reproducible, with the same loss, accuracy, etc. across runs.
Standalone code to reproduce the issue

#!/usr/bin/env python 
import tensorflow as tf
import numpy as np
import argparse
import time
import random
import os

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model


def random_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed) # Python general
    np.random.seed(seed)
    random.seed(seed) # Python random
    tf.random.set_seed(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'

# Not yet using click due to Docker issues
parser = argparse.ArgumentParser(description='Tensorflow entry point')
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()

# Detect GPUs
print(f'Num GPUs Available: {len(tf.config.experimental.list_physical_devices("GPU"))}')

# Load MNIST
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1), since the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Normalizing the images to [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

# Use MirroredStrategy for multi GPU support
# If the list of devices is not specified in the`tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

BUFFER_SIZE = len(train_images)
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Batch and distribute data
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

# Fix seeds
random_seed(0)

# Define model
def create_model():
    model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10)
    ])

    return model

# Define loss and accuracy metrics
with strategy.scope():
    # Set reduction to `none` so reduction can be done afterwards and divide by global batch size.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE)
    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)

        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

    test_loss = tf.keras.metrics.Mean(name='test_loss')

    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


# Define model, optimizer, training- and test step
with strategy.scope():
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()

  def train_step(inputs):
    images, labels = inputs

    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = compute_loss(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_accuracy.update_state(labels, predictions)

    return loss 

  def test_step(inputs):
    images, labels = inputs

    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)
    test_loss.update_state(t_loss)
    test_accuracy.update_state(labels, predictions)


with strategy.scope():
  # `run` replicates the provided computation and runs it with the distributed input.
  @tf.function
  def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
 
  @tf.function
  def distributed_test_step(dataset_inputs):
    return strategy.run(test_step, args=(dataset_inputs,))

  gpu_runtime = time.time()
  for epoch in range(args.epochs):
    # TRAIN LOOP
    total_loss = 0.0
    num_batches = 0
    for dist_dataset in train_dist_dataset:
      total_loss += distributed_train_step(dist_dataset)
      num_batches += 1
    train_loss = total_loss / num_batches

    # TEST LOOP
    for dist_dataset in test_dist_dataset:
      distributed_test_step(dist_dataset)

    print(f'Epoch {epoch + 1}, Loss: {train_loss}, Accuracy: {train_accuracy.result()},'
          f'Test Loss: {test_loss.result()}, Test Accuracy: {test_accuracy.result()}')

    # Reset states
    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()

  print(f'GPU Run Time: {str(time.time() - gpu_runtime)} seconds')

Other info / logs

def random_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed) # Python general
    np.random.seed(seed)
    random.seed(seed) # Python random
    tf.random.set_seed(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'

I guess this should cover everything?

The code is currently running on a SINGLE GPU, even though I'm planning to run it on several GPUs.

Zethson added the type:bug label on Apr 3, 2020
geetachavan1 added the TF 2.2 label on Apr 3, 2020
Saduf2019 (Contributor) commented:

@Zethson, I ran the code you shared and hit the error in this gist.


Zethson commented Apr 6, 2020

ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-cafa3e81-09ee-4d77-8548-81b49eb3128e.json

An exception has occurred, use %tb to see the full traceback.

SystemExit: 2

I don't see how this error is related to the code. It seems to be a Jupyter notebook kernel issue, no?

Saduf2019 added the comp:keras label on Apr 16, 2020
Saduf2019 assigned gowthamkpr and unassigned Saduf2019 on Apr 16, 2020
gowthamkpr added the comp:gpu label on May 14, 2020
gowthamkpr assigned sanjoy and unassigned gowthamkpr on May 14, 2020
gowthamkpr added the stat:awaiting tensorflower label on May 14, 2020

sanjoy commented May 22, 2020

@duncanriach Any ideas what could be going wrong here?

tensorflowbutler removed the stat:awaiting tensorflower label on May 24, 2020
duncanriach (Contributor) commented:

@duncanriach Any ideas what could be going wrong here?

Will take a look at this, hopefully today. Feel free to assign it to me, @sanjoy.

duncanriach (Contributor) commented:

I'm now actively working on this issue ...


duncanriach commented May 29, 2020

Hey @Zethson, I repro'd your issue and found a solution. To get determinism, you need to do the following:

In both calls to shuffle, you should (as sketched below):

  1. set seed=123 (or any fixed integer)
  2. set reshuffle_each_iteration=False
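
For concreteness, a minimal sketch of that change, reusing the dataset-pipeline names from the repro script above (the seed value is arbitrary; illustrative only):

SHUFFLE_SEED = 123  # any fixed integer

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)) \
    .shuffle(BUFFER_SIZE, seed=SHUFFLE_SEED, reshuffle_each_iteration=False) \
    .batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)) \
    .shuffle(BUFFER_SIZE, seed=SHUFFLE_SEED, reshuffle_each_iteration=False) \
    .batch(GLOBAL_BATCH_SIZE)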

tf.data.Dataset re-shuffling (which is the shuffle default and causes a re-shuffle before each epoch, including the first epoch) is currently not reproducible when used in conjunction with tf.distribute.MirroredStrategy (or I suspect any tf.distribute strategy). This is a bug that I came across recently in another context, but I have not yet had a chance to dig in and root-cause it. @sanjoy, I'll try to create a simple, direct test that demonstrates that the re-shuffle is not reproducible in the context of tf.distribute.MirroredStrategy (even with a single GPU) and possibly open a new issue for that.

Also, note that the code given in the original comment is almost the same as what's provided for Custom training with tf.distribute.Strategy except that everything from def train_step() onwards is, unnecessarily, in the strategy scope.


Zethson commented May 29, 2020

@duncanriach
Thanks!
I will try the approach in the next ~ 10 days and report back.
I also had reproducibility issues with the CPU, do you expect them to be related to the dataset shuffling?

duncanriach (Contributor) commented:

I also had reproducibility issues with the CPU, do you expect them to be related to the dataset shuffling?

You're welcome. Yes, with these changes you should see the CPU training become reproducible as well. (Let me know the outcome of that.) The sources of non-determinism that we are addressing here are not related to the ops and therefore not related to which type of processor the ops are running on. Therefore, TF_DETERMINISTIC_OPS was necessary but not sufficient for determinism.


Zethson commented Jun 1, 2020

@duncanriach
Thanks! This does indeed solve the reproducibility issue.
However, it looks to me like it is only reproducible on the same system.
I am running the whole code inside a Docker container and I get reproducible results on each of my two systems. However, the results differ between the two systems.

Is this to be expected? If yes, what is the reason for this behavior?


Zethson commented Jun 1, 2020

To add some numbers:

System 1:
Run 1:

Epoch 1/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9492 - loss: 0.1709
Epoch 2/5
938/938 [==============================] - 11s 12ms/step - accuracy: 0.9727 - loss: 0.0899
Epoch 3/5
938/938 [==============================] - 11s 12ms/step - accuracy: 0.9768 - loss: 0.0816
Epoch 4/5
938/938 [==============================] - 11s 11ms/step - accuracy: 0.9785 - loss: 0.0729
Epoch 5/5
938/938 [==============================] - 10s 11ms/step - accuracy: 0.9797 - loss: 0.0731
157/157 [==============================] - 2s 12ms/step - accuracy: 0.9861 - loss: 0.0562
Test loss: 0.0561639703810215, Test Accuracy: 0.9861000180244446

Run 2:

Epoch 1/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9492 - loss: 0.1709
Epoch 2/5
938/938 [==============================] - 10s 11ms/step - accuracy: 0.9727 - loss: 0.0899
Epoch 3/5
938/938 [==============================] - 11s 11ms/step - accuracy: 0.9768 - loss: 0.0816
Epoch 4/5
938/938 [==============================] - 10s 11ms/step - accuracy: 0.9785 - loss: 0.0729
Epoch 5/5
938/938 [==============================] - 10s 11ms/step - accuracy: 0.9797 - loss: 0.0731
157/157 [==============================] - 2s 12ms/step - accuracy: 0.9861 - loss: 0.0562
Test loss: 0.0561639703810215, Test Accuracy: 0.9861000180244446

System 2:
Run 1:

Epoch 1/5
938/938 [==============================] - 17s 18ms/step - accuracy: 0.9484 - loss: 0.1709
Epoch 2/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9738 - loss: 0.0882
Epoch 3/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9771 - loss: 0.0787
Epoch 4/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9789 - loss: 0.0764
Epoch 5/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9793 - loss: 0.0707
157/157 [==============================] - 3s 20ms/step - accuracy: 0.9848 - loss: 0.0672
Test loss: 0.06723744422197342, Test Accuracy: 0.9847999811172485

Run 2:

Epoch 1/5
938/938 [==============================] - 17s 18ms/step - accuracy: 0.9484 - loss: 0.1709
Epoch 2/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9738 - loss: 0.0882
Epoch 3/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9771 - loss: 0.0787
Epoch 4/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9789 - loss: 0.0764
Epoch 5/5
938/938 [==============================] - 13s 14ms/step - accuracy: 0.9793 - loss: 0.0707
157/157 [==============================] - 4s 22ms/step - accuracy: 0.9848 - loss: 0.0672
Test loss: 0.06723744422197342, Test Accuracy: 0.9847999811172485



duncanriach commented Jun 1, 2020

Is this to be expected? If yes, what is the reason for this behavior?

Hey @Zethson, from the GPU standpoint, bit-exact reproducibility between two systems is only guaranteed if the hardware-software stack is the same. Any changes in the stack could lead to differences in the way the computation workload is partitioned for (massively) parallel processing. The change in this partitioning will inevitably lead to differences in the accumulation of floating-point rounding errors in the computations. You can learn more about this by watching my GTC talk on the topic.

While a different version of anything in the hardware-software stack (e.g. different CUDA driver versions) could lead to slightly different results, you're most likely to see a difference if the GPU architecture is different, if the cuDNN version is different, or if the TensorFlow version is different. Since you're using the same container, I can infer that you're using the same versions of both cuDNN and TensorFlow on both machines. That leaves the GPU architecture. Does one of these machines contain a Pascal GPU and the other a Volta GPU perhaps? Please share the output from nvidia-smi on each of the machines.

It's also possible that there are hardware-software differences in the CPU-related stack that are introducing slightly different floating-point rounding errors, but that's less likely due to much less (or no) parallel computation on the CPU.
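
To make the rounding-error point concrete, here is a small, purely illustrative sketch (not part of the original discussion): reducing the same float32 values in a different order, as a different parallel partitioning effectively does, can give bit-different sums.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(100_000).astype(np.float32)

# One long sequential accumulation chain.
seq = np.float32(0.0)
for v in x:
    seq = np.float32(seq + v)

# The same values reduced as two partial sums (roughly what a 2-way
# parallel partition does), then combined.
half = len(x) // 2
par = np.float32(x[:half].sum(dtype=np.float32) + x[half:].sum(dtype=np.float32))

print(seq, par, seq == par)  # the two results typically differ in the last bits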


Zethson commented Jun 2, 2020

Thank you very much for your detailed response.
Your talk was already on my shortlist and I will absolutely be watching it. I'm highly interested in reproducible ML and I am sure that your talk will improve my understanding of the broader challenges.

Yes, the GPU architecture is very likely to be different. I am working on my own laptop (1050M) and a VM (2 K80s).
System 1:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.82       Driver Version: 440.82       CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 1050    Off  | 00000000:01:00.0 Off |                  N/A |
| N/A   45C    P0    N/A /  N/A |      0MiB /  4040MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

System 2:

ubuntu@mlflow ~> nvidia-smi
Tue Jun  2 08:57:48 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 440.64.00    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla K80           On   | 00000000:00:05.0 Off |                    0 |
| N/A   27C    P8    30W / 149W |      0MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Tesla K80           On   | 00000000:00:06.0 Off |                    0 |
| N/A   24C    P8    26W / 149W |      0MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

So, judging from your answer, I conclude that for full reproducibility the same GPU architecture has to be used. Hence, for an ML model (say, a 'terrorist detection model' or a 'cancer detection model') to be verifiable and reproducible, we would need not only the same code (solved by git) and the same environment (solved by containers), but also the same hardware (solved as long as the hardware still exists)?

Naive question: would it technically be possible to improve reproducibility (at the cost of precise model training) by decreasing the floating point precision and introducing a more eager rounding procedure?

I am a bioinformatician, and the state of reproducibility in data analysis has dramatically improved with the introduction of https://anaconda.org/bioconda/ and workflow languages such as Nextflow, which make these tools easy to use. As a result, analyses are not only fully reproducible but also portable, so researchers can very easily verify the results, which is very important for the peer-review process.

Nevertheless, I am aware of Nvidia's efforts to speed up the very computationally expensive bioinformatics analyses (and I support that!) and fear that we may lose that portability, since the very same GPU architecture would be required.

If my last two paragraphs are off topic, then please tell me and I will remove them and would be happy to move the discussion elsewhere (if you are interested).


duncanriach commented Jun 2, 2020

I'm happy to discuss this here. I suspect that our discussion may be helpful to others. Thanks for all the additional information.

The GeForce GTX 1050 contains a GPU that is based on the Pascal architecture and the Tesla K80s contain GPUs that are based on the Kepler architecture. So bit-exact reproducibility is not guaranteed, and in fact unlikely, between those two machines based solely on the GPU architecture they use.

However, a more significant factor is that you're doing multi-GPU training and on your laptop you have only a single GPU (one Pascal) while on the remote machine you have two GPUs (two Keplers). Because of the different number of GPUs, even if all those GPUs were from the same architecture, you would definitely not get bit-exact reproducibility between the laptop and the remote machine.

The reason for this (again) is that the extensive floating-point computations are parallelized by being distributed in different ways on these two systems. This distribution is necessary and inherent in the process of maximally parallelizing (and therefore maximally accelerating) these computations. Computations that involve reducing the partial results from these compute partitions will include slightly different rounding errors depending on the way that the computation was partitioned. In the case of data-parallel multi-GPU (or multi-node) training, there is always going to be a reduction of the partial gradients produced on each of the GPUs (or nodes).


Would it technically be possible to improve reproducibility (at the cost of precise model training) by decreasing the floating point precision and introducing a more eager rounding procedure?

You're on the right track. Theoretically, there are four different possible ways around this that I can think of:

1. Use integers

While floating-point operations are not perfectly associative (rounding errors differ based on the order of operations), integers are perfectly associative. Integers (e.g. INT8) can be used for inference, and they often are used because they result in increased performance and reduced memory footprint. However, integers cannot (currently) easily be used for training because both range and precision are required, especially in the gradients.
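
A tiny sketch of the associativity point (illustrative only, not from the original comment): permuting an integer reduction never changes the result, while permuting a float32 reduction can.

import numpy as np

rng = np.random.default_rng(1)
f = rng.standard_normal(100_000).astype(np.float32)
i = rng.integers(-1000, 1000, size=100_000, dtype=np.int64)
perm = rng.permutation(100_000)

print(f.sum(dtype=np.float32) == f[perm].sum(dtype=np.float32))  # often False: order-dependent rounding
print(i.sum() == i[perm].sum())                                  # always True: exact integer arithmetic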

2. Use double-precision floating-point (i.e. 64-bit floating point)

This will reduce the amount of floating-point rounding error that accumulates, but there would still be a difference between GPU architectures and/or numbers of GPUs. It will also reduce performance a lot (at least 4x?) and will at least double the memory footprint. I've never trained a model with 64-bit floats, though, and I don't know whether it's possible in TensorFlow and whether the precision propagates all the way through, including through the back-prop. Based on my experience with TensorFlow's source code, I think it's very unlikely for typical cases.
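
For what it's worth, a hedged sketch of what option 2 could look like in Keras, assuming the global floatx policy propagates through the layers used here (as noted above, this is not guaranteed, and it does not remove architecture-to-architecture differences):

import tensorflow as tf

tf.keras.backend.set_floatx('float64')  # create variables and compute in float64

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(10),
])
# Inputs would also need to be float64, e.g. train_images.astype('float64').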

3. Quantize after training

It's not possible to train all the way through in regular floating-point and then convert to integer or a reduced-precision floating-point format at the end to get (probably reduced-accuracy) between-stack reproducible training results (i.e. trainable variables), because it's fundamentally not possible to reproducibly quantize away the accumulated error differences. This is a challenging concept to understand or explain in text form, sorry.

4. Final-train on CPU

Another option to think about is to report results from running on a single thread on a CPU. You would do all your development using the massive amount of acceleration provided by GPUs (or other accelerators) and then run once to get values to report. However, it's going to take a long time for that final run. Also, since the exact implementation of the underlying math can change on different CPUs (especially when using MKL), even when only using a single thread, you should still really include the CPU architecture that you used along with your results. Someone could run the same container and git repo hashes on a different CPU architecture and theoretically get slightly different results (just as with GPUs).
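
A sketch of how the single-threaded CPU run in option 4 might be pinned down with TF 2.x configuration APIs (hiding the GPUs via CUDA_VISIBLE_DEVICES is one common way; treat this as illustrative, not as the thread's prescribed method):

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # hide all GPUs so TensorFlow runs on the CPU

import tensorflow as tf

# Restrict TensorFlow to a single thread for both within-op and
# between-op parallelism, fixing the reduction order.
tf.config.threading.set_intra_op_parallelism_threads(1)
tf.config.threading.set_inter_op_parallelism_threads(1)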

I imagine that none of the above are feasible for your needs.


We now have run-to-run training reproducibility on GPUs in TensorFlow. This is a relatively new achievement. I, and others, are now working on extending this support.

In reality, I think it's going to be totally practical for you to provide bit-exact results for one or more GPU architectures (e.g. Pascal or Volta). In terms of reproducibility, as far as I am aware, this goes way further than the current state-of-the-art in the ML/DL research community. I recommend that you qualify your bit-exact results as being achieved on a given hardware-software version stack, including the type and number of GPUs used. Given the underlying technical constraints, this is a reasonable compromise.

And with all of that said, it's important to remember that in most SGD-DL applications the amount of variance in the final result (e.g. test-set accuracy) is relatively small due to these differences in floating-point rounding error propagation. The advantage of a particular git hash running in a particular container image hash will, and should, attain most of the reproducibility demanded by peer-review. You can step-up the game even further by specifying the GPU architecture that the results were produced on.
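
One illustrative way to record the hardware-software stack alongside results (tf.config.experimental.get_device_details is only available from TF 2.3, so this is an assumption relative to the 2.2 container used in this issue):

import platform
import tensorflow as tf

print('TensorFlow :', tf.__version__)
print('Python     :', platform.python_version())
print('CPU        :', platform.processor())
for gpu in tf.config.list_physical_devices('GPU'):
    details = tf.config.experimental.get_device_details(gpu)  # TF >= 2.3
    print('GPU        :', details.get('device_name'), details.get('compute_capability'))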


Zethson commented Jun 6, 2020

@duncanriach
Thank you very much for your amazing in-depth answer!
Learning a lot!

However, a more significant factor is that you're doing multi-GPU training and on your laptop you have only a single GPU (one Pascal) while on the remote machine you have two GPUs (two Keplers).

Yes, I was aware. Hence, I restricted my Docker container on my multi-GPU machine to make only one of those two GPUs available. Both of course show up when running nvidia-smi.
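
For reference, a minimal sketch of making the same restriction from inside TensorFlow rather than via the container (assuming it runs before any GPU has been initialized; illustrative only):

import tensorflow as tf

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Expose only the first physical GPU to TensorFlow.
    tf.config.set_visible_devices(gpus[0], 'GPU')
print(tf.config.list_logical_devices('GPU'))  # should list a single GPU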

I will now ensure that any of my pipelines will output the CPU and GPU architecture and will advocate this whenever appropriate.

And with all of that said, it's important to remember that in most SGD-DL applications the amount of variance in the final result (e.g. test-set accuracy) is relatively small due to these differences in floating-point rounding error propagation.

Are you aware of any studies related to this? Any 'hard' numbers?
It would be interesting to know the fluctuation between the different network architectures and how close the found minima/maxima actually are.
I could imagine that the overall performance will always be relatively similar, but it might be possible that (imagine a manifold) we end up in completely different optimal solutions with more or less the same loss, but quite different models.

It would be nice to be able to have an expected variance between different GPU architectures.
Say: between Pascal and Kepler, empirical studies suggest that a variance of around 0.01% in the loss is to be expected, even when all reproducibility settings are used.

Cheers!


duncanriach commented Jun 9, 2020

Hence, I restricted my Docker container on my multi-GPU machine to make only one of those two GPUs available.

Good job.

I will now ensure that any of my pipelines will output the CPU and GPU architecture and will advocate this whenever appropriate.

Great. Thanks.

Are you aware of any studies related to this? Any 'hard' numbers?

No, but it's on my roadmap to do this.

I could imagine that the overall performance will always be relatively similar, but it might be possible that (imagine a manifold) we end up in completely different optimal solutions with more or less the same loss, but quite different models.

Yes. When there is non-determinism, this can result in training randomly and non-reproducibly failing (or not doing as well). Luckily, mini-batch training has the effect of avoiding local minima and finding the global minimum. If there is non-reproducible gradient explosion or disappearance on one of the effectively infinite paths to that global minimum, however, then that can make debugging almost impossible.

In training regimes in which there is no negative feedback (or where there is actually positive feedback, as with reinforcement learning), non-determinism will lead to completely different results on every run.

Note, and remember, that any system that does not have an end-to-end negative feedback loop can, and often will, amplify small differences in input to produce large differences in output.

These concepts apply, of course, to changing the hardware-software stack versioning and thereby potentially changing bit-accurate results, but it's less critical than run-to-run reproducibility (what we call determinism).

It would be nice to be able to have an expected variance between different GPU architectures.

Something we plan to do sooner is to characterize the variance due to non-determinism for different model architectures on a given GPU architecture. This variance will have a similar order of magnitude to the variance between GPU architectures for that model architecture (it should be small).

The goal of deterministic operation of TensorFlow on GPUs is run-to-run determinism. What this then gives us, as a side-effect, is the ability to characterize the effect of changes to hardware-software stack versioning on model accuracy. Right now, however, there is a lot of work to be done to consolidate and broaden run-to-run reproducibility (in all DL frameworks).


Zethson commented Jun 9, 2020

@duncanriach
Great. Thank you again! Learned a lot in this thread.

I am looking forward to reading/hearing about your results of the variance of non-deterministic models. Very interested in this matter.

tf.data.Dataset re-shuffling (which is the shuffle default and causes a re-shuffle before each epoch, including the first epoch) is currently not reproducible when used in conjunction with tf.distribute.MirroredStrategy (or I suspect any tf.distribute strategy). This is a bug that I came across recently in another context, but I have not yet had a chance to dig in and root-cause it. @sanjoy, I'll try to create a simple, direct test that demonstrates that the re-shuffle is not reproducible in the context of tf.distribute.MirroredStrategy (even with a single GPU) and possibly open a new issue for that.

I consider this issue solved, but we could also keep it open for your mentioned bug.
Whatever you prefer.


duncanriach commented Jun 10, 2020

Great. Thank you again! Learned a lot in this thread.

You're welcome, @Zethson. It's been a pleasure.

I am looking forward to reading/hearing about your results of the variance of non-deterministic models. Very interested in this matter.

You might want to star or watch https://github.com/NVIDIA/tensorflow-determinism because progress will be reported there first.

I consider this issue solved, but we could also keep it open for your mentioned bug.

Let's keep this current issue open for now. Once I've opened a new issue, with minimal repro code for the re-shuffle problem, then I'll inform you. Then you can close this current issue.

duncanriach (Contributor) commented:

Update: TensorFlow version 2.3.0 no longer exhibits the nondeterminism associated with using tf.data.Dataset.shuffle(reshuffle_each_iteration=True) together with tf.distribute.MirroredStrategy. Therefore, another issue does not need to be created. @Zethson, would you please now close this issue?


Zethson commented Aug 26, 2020

@duncanriach
Thank you very much for the update!

Zethson closed this as completed on Aug 26, 2020
