Non-deterministic mean and sum reduction #3103

Closed
eamartin opened this Issue Jun 29, 2016 · 10 comments

eamartin commented Jun 29, 2016

I'm running TensorFlow 0.9.0, installed from a wheel, on Python 2.7 with a K40 GPU and CUDA 7.0.

The following test case attempts to minimize the mean of a vector through gradient descent. The script finds that the vectors are equal at all steps, but the means are not. I believe the vectors being equal at all steps is pure numerical luck, since a non-deterministic loss likely means a non-deterministic gradient, which in turn means non-deterministic (non-reproducible) iterative optimization. I've observed cases where training reaches different final losses and the only source of non-determinism is reduce_mean.

import numpy as np
import tensorflow as tf

n_dims = 1000
n_steps = 50

np.random.seed(2016)

vec = tf.Variable(np.random.randn(n_dims).astype(np.float32))
mean = tf.reduce_mean(vec)

optimizer = tf.train.GradientDescentOptimizer(0.01)
train_step = optimizer.minimize(mean)

def generate():
    # Run n_steps of gradient descent, recording the vector and its mean at each step.
    data = []
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        for _ in xrange(n_steps):
            _vec, _mean, _ = sess.run([vec, mean, train_step])
            data.append((_vec, _mean))

    return [np.array([f[i] for f in data]) for i in xrange(2)]

first_vec, first_mean = generate()
second_vec, second_mean = generate()
print 'vecs equal:', np.all(first_vec == second_vec)

print 'mean equal:', np.all(first_mean == second_mean)
print 'means not equal at idxs:', np.nonzero(first_mean != second_mean)[0]

Example output:

vecs equal: True
mean equal: False
means not equal at idxs: [ 4  5 11 18 34 38 44 49]

From looking through the code, it appears the GPU mean reduction is implemented in terms of the GPU sum reduction: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/reduction_ops_gpu.cu.cc#L43
I've confirmed that my test case still reproduces the issue when I replace "reduce_mean" with "reduce_sum".

The GPU sum reduction appears to be implemented using CUDA's atomicAdd: https://bitbucket.org/eigen/eigen/src/241472d2a52142e23b0b2ba5c301c6c146298fa9/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h?at=default&fileviewer=file-view-default#TensorReductionCuda.h-97

Atomic floating-point adds on the GPU are the problem. Accumulating floating-point values into the same address in an undefined order is inherently non-deterministic, because floating-point addition is not associative.
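
To make the non-associativity concrete, here is a tiny sketch (plain NumPy, float32, with values chosen only for illustration) showing that the grouping of additions changes the rounded result:

import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

# The two groupings round differently, so the summation order matters.
print((a + b) + c)  # 1.0
print(a + (b + c))  # 0.0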

This issue could be solved (and reduction performance improved) by using some sort of reduction tree to reduce within blocks, and then launching a second kernel (or doing some manual block synchronization tricks) to reduce across blocks.
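
For illustration only (plain NumPy, not the actual TF/Eigen kernel), a deterministic two-stage reduction could look like the sketch below: each "block" reduces its own slice with a fixed-order pairwise tree, and a second pass combines the per-block partial sums, again in a fixed order, so the result cannot depend on which block finishes first. The function names and the block size of 256 are arbitrary choices for the sketch.

import numpy as np

def tree_sum(x):
    # Fixed-order pairwise (tree) reduction of a 1-D float32 array.
    x = np.asarray(x, dtype=np.float32)
    while x.size > 1:
        if x.size % 2:  # pad with a zero so pairs line up
            x = np.append(x, np.float32(0.0))
        x = x[0::2] + x[1::2]
    return x[0]

def two_stage_sum(x, block_size=256):
    # Stage 1: each "block" reduces its slice with a tree reduction.
    partials = [tree_sum(x[i:i + block_size])
                for i in range(0, len(x), block_size)]
    # Stage 2: combine the per-block partial sums in a fixed order.
    return tree_sum(partials)

np.random.seed(2016)
data = np.random.randn(100000).astype(np.float32)

# Bitwise identical on every run, because the summation order is fixed,
# unlike accumulation via atomicAdd.
print(two_stage_sum(data) == two_stage_sum(data))  # True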

eamartin commented Jul 1, 2016

I created a smaller test case.

import numpy as np
import tensorflow as tf

np.random.seed(2016)
data = np.random.randn(100000).astype(np.float32)

vec = tf.placeholder(tf.float32, data.shape)
avg = tf.reduce_mean(vec)

avgs = []
with tf.Session() as sess:
    for _ in xrange(100):
        avgs.append(sess.run(avg, feed_dict={vec: data}))

print min(avgs) == max(avgs)
print max(avgs) - min(avgs)

with output

False
6.98492e-10
rmlarsen (Member) commented Jul 1, 2016

@lightcatcher as you point out, our current implementation of sum reduction (and the various ops that depend on it) is not deterministic on either GPUs or multi-threaded CPUs. It is primarily a speed/accuracy trade-off, and if we could get comparable speed from one of the approaches you mention, we would be happy to switch the implementation. In short, this is working as intended for now, but contributions toward a more accurate or even deterministic sum reduction (possibly as a separate op) would certainly be welcome.

rmlarsen changed the title from "Non-deterministic mean (and sum) on GPU" to "Non-deterministic mean and sum reduction" on Jul 1, 2016

rmlarsen (Member) commented Jul 13, 2016

I'm closing this as "working as intended".

TimZaman (Contributor) commented Sep 23, 2016

Which other ops are non-deterministic, aside from those that are non-deterministic by nature?

rasbt (Contributor) commented Mar 14, 2017

@TimZaman Unless something has changed in recent months, some of the cuDNN code is non-deterministic, for instance cudnn.SpatialConvolution. So I suspect some of the CNN-related parts of TensorFlow may be non-reproducible when run on a GPU. It would probably be a bit of work, but it would be nice to have a flag, or at least a note in the TF docstrings of the affected functions.

yaroslavvb (Contributor) commented Mar 14, 2017

As @zheng-xq mentioned earlier, anything using CUDA atomics is non-deterministic, so one way to narrow it down is to see which cuDNN algorithms use CUDA atomics. For CPU ops, the way to check might be to track down the parallel ops (see which ops use tensorflow/core/util/work_sharder.cc) and verify that the result is independent of the order in which individual work shards complete. Note that there are trickier cases of non-determinism: for instance, the same sequence of SSE instructions can give different results on a rerun, so for a stronger guarantee of determinism you need to disable both multi-threading and special instruction sets: http://blog.nag.com/2011/02/wandering-precision.html
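
On the CPU side, a minimal sketch of one way to take thread ordering out of the picture in TF 1.x is to pin the session to single-threaded execution; note this only removes the multi-threading source of non-determinism and does not address the instruction-set or GPU effects mentioned above.

import tensorflow as tf

# Force single-threaded execution so the completion order of work shards
# cannot influence the result.
config = tf.ConfigProto(intra_op_parallelism_threads=1,
                        inter_op_parallelism_threads=1)

with tf.Session(config=config) as sess:
    pass  # build and run the graph here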

ahmedhosny commented Dec 12, 2017

@eamartin I ran your "smaller test case" from above and I am getting fully deterministic behavior. I am just wondering what has changed since you ran it. I am using Python 3.5 and TF 1.4.1.

eamartin commented Sep 17, 2018

Are reductions still non-deterministic by default on GPU? If so, can this issue be re-opened? Deterministic computation is critical for reproducibility, and reductions are a critical part of neural nets.

Finally:
Several comments on this issue and other linked issues claim that "reductions are non-deterministic for performance". This is not the case. A reduction tree is both deterministic and generally faster than using atomic adds (which cause the non-determinism). https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ describes reduction trees and shows that a reduction tree with very limited use of atomic adds is the fastest option, while a pure (deterministic) reduction tree is only marginally slower.
Last I checked (quite a while ago), the TF reduction implementation used atomics exclusively and no reduction tree, so switching to a reduction-tree-only implementation would actually provide a performance boost. My guess is that atomics were used in the TF implementation not for performance but because the implementation with atomics is somewhat simpler to write.

yaroslavvb (Contributor) commented Sep 17, 2018

I believe there has been some work to make reductions deterministic.

eamartin commented Sep 17, 2018

@yaroslavvb
I just re-ran my initial examples on TF 1.5.0 and found that the reduce_mean call produced consistent results across 10K trials, with the mean kernel running on a K80 GPU. This agrees with @ahmedhosny's observation.

I dug through the source a little and I think the reduction logic lives in https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h . This file (and the TensorReductionGpu.h file) is new since I opened this issue in 2016 and contains reduction-tree logic. However, these files still contain atomic floating-point adds; I think those are used to combine values across the different GPU blocks.

This file has recently been modified by @rmlarsen. @rmlarsen, do you have any insight into how deterministic these reductions are? When is the atomic floating-point add used?
