TensorFlow session.run() overhead for graphs with few flops #120

Closed
dementrock opened this issue Nov 11, 2015 · 20 comments

@dementrock

The following code will take 10 seconds to run:

import tensorflow as tf

with tf.Session():
    for _ in range(1000):
        tf.constant(0).eval()

Now, this might not be an issue for large-scale machine learning, but it makes running the graph in real-time very hard. I was trying to port a forward simulation model from Theano to TensorFlow, where 1000 runs of forward simulation took 0.2s in Theano vs. 17s in TensorFlow, and over half of the time was taken by this session.run() overhead.

@mrry
Contributor

mrry commented Nov 11, 2015

There's a subtle issue that crops up when performing these microbenchmarks on TensorFlow. I ran a version of your code on my laptop (a late-2014 MacBook Air):

>>> with tf.Graph().as_default():
...   with tf.Session():
...     start = time.time()
...     for _ in xrange(1000):
...       _ = tf.constant(0).eval()
...     end = time.time()
>>> print end - start
4.721920967102051

The perhaps surprising thing about this code is that the call to tf.constant(0) is actually the expensive part. If I run this slightly different program, it completes much faster:

>>> with tf.Graph().as_default():
...   c = tf.constant(0)
...   with tf.Session():
...     start = time.time()
...     for _ in xrange(1000):
...       _ = c.eval()
...     end = time.time()
>>> print end-start
0.0893728733063

In the first version (like in your snippet), I created 1000 identical constant nodes in the graph, and evaluated each of them: this took several seconds. In the second version, I created the constant node once, and evaluated it 1000 times: this took less than 100 milliseconds.

In summary, it's important to try to reuse the existing graph as much as possible for each step of your model. Adding nodes to the graph isn't free, and we could probably optimize that further. Feel free to share more of your simulation code, and we can look into any major performance issues that it might have.
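As a rough illustration of that advice (a sketch of mine, not from the original comment, assuming the TF 1.x graph API used throughout this thread): build the graph once with a placeholder and feed different data on each step, so the per-step cost is only the session.run() dispatch.

import time
import numpy as np
import tensorflow as tf

with tf.Graph().as_default():
  x = tf.placeholder(tf.float32, shape=(3, 3))  # graph node created once
  y = tf.matmul(x, x)
  with tf.Session() as sess:
    start = time.time()
    for _ in range(1000):
      # Each iteration reuses the same graph; only the input data changes.
      sess.run(y, feed_dict={x: np.random.rand(3, 3).astype(np.float32)})
    print(time.time() - start)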

@dementrock
Author

Thanks - I guess this means that the overhead is somewhere else then. I'll post more information once I can narrow down the cause of the slowdown.

@dementrock
Author

@mrry

I've put up some simple benchmark here:

https://github.com/dementrock/tensorfuse#running-benchmark

This is a small library that bridges the APIs of Theano, TensorFlow, and another library called CGT. I just wrote 3 simple tests, and both Theano and CGT beat TensorFlow on them:

Using Theano for TensorFuse
func:'time_sin' 10000 times took: 0.1418 sec
func:'time_matmul' 10000 times took: 0.0786 sec
func:'time_slicing' 10000 times took: 0.1508 sec
Using CGT for TensorFuse
func:'time_sin' 10000 times took: 0.0516 sec
func:'time_matmul' 10000 times took: 0.0619 sec
func:'time_slicing' 10000 times took: 0.0529 sec
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 8
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 8
Using TensorFlow for TensorFuse
func:'time_sin' 10000 times took: 0.8568 sec
func:'time_matmul' 10000 times took: 1.1223 sec
func:'time_slicing' 10000 times took: 1.1388 sec

It might be the way I'm bridging TensorFlow to theano.function, since I'm pretty new to TensorFlow. The implementation is at
https://github.com/dementrock/tensorfuse/blob/master/tensorfuse/__init__.py#L33. I'd really appreciate it if you could take a look. Thanks!

@dementrock dementrock reopened this Nov 11, 2015
@dementrock dementrock changed the title Overhead with session.run() TensorFlow is slow: some initial benchmark Nov 11, 2015
@mrry
Contributor

mrry commented Nov 11, 2015

Your benchmarks seem to align with what I can measure on my laptop. For example:

>>> with tf.Graph().as_default():
...   x = tf.constant(np.random.rand(3, 3).astype(np.float32))
...   y = tf.constant(np.random.rand(3, 3).astype(np.float32))
...   z = tf.matmul(x, y)
...   with tf.Session():
...     start = time.time()
...     for _ in xrange(10000):
...       _ = z.eval()
...     end = time.time()
>>> end - start
1.6025779247283936

I should add that the operations in your benchmark are very small - computable in a time that is on the order of microseconds or nanoseconds - compared to the size of computation for which TensorFlow is designed. While it would be nice to reduce any unnecessary overhead from step and op dispatch, it is unlikely that doing so would dramatically reduce the time taken to run an inference or training step in a realistic neural network. With that said, thanks for looking into this, and if you have any suggestions for how to reduce this overhead, we would be glad to hear them!

@dementrock
Author

Is the size defined in terms of the dimension of the data, or the complexity of the computation?

@mrry
Contributor

mrry commented Nov 11, 2015

I'd define it in terms of the number of floating-point operations needed to compute the result of the step. A 3 x 3 matrix multiplication (to take the time_matmul benchmark as an example) uses very few floating-point operations compared to the constant framework overhead: it would be quicker to do the computation inline than to dispatch it to another framework. Similarly, it would almost certainly not be worth offloading that computation to a GPU, because of the overhead of dispatching a kernel and fetching its results. By contrast, larger matrix multiplications and convolutions have a high flop count, and will tend to benefit from this approach.
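As a back-of-the-envelope sketch of that comparison (the per-step overhead and CPU throughput figures below are illustrative assumptions, not measurements from this thread):

def matmul_flops(n):
  # A dense n x n matrix multiplication needs roughly 2 * n**3 floating-point operations.
  return 2 * n ** 3

step_overhead_sec = 1e-4      # assume ~100 microseconds of fixed per-step dispatch cost
cpu_flops_per_sec = 1e10      # assume ~10 Gflop/s of sustained CPU throughput

for n in (3, 256, 2048):
  compute_sec = matmul_flops(n) / cpu_flops_per_sec
  print(n, matmul_flops(n), step_overhead_sec / compute_sec)
  # At n=3 the fixed overhead dwarfs the ~54 flops of useful work;
  # by n=2048 the arithmetic dominates and the overhead is negligible.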

It would still be very informative to learn how the overhead changes with the size of the data. Would you consider adding different sizes of input for each of the workloads?

@dementrock
Author

The overhead does diminish as the size of the matrix increases. Some results below (all tests, including the earlier ones, were run on the CPU):

Using Theano for TensorFuse
func:'time_matmul_3x3' 10000 times took: 0.0804 sec
func:'time_matmul_64x64' 10000 times took: 0.2886 sec
func:'time_matmul_256x256' 1000 times took: 0.3078 sec
func:'time_matmul_512x512' 1000 times took: 2.0757 sec
func:'time_matmul_1024x1024' 100 times took: 1.6689 sec
func:'time_matmul_2048x2048' 10 times took: 1.3290 sec
func:'time_matmul_4096x4096' 10 times took: 10.4791 sec
func:'time_matmul_8192x8192' 10 times took: 82.8672 sec
Using CGT for TensorFuse
func:'time_matmul_3x3' 10000 times took: 0.0608 sec
func:'time_matmul_64x64' 10000 times took: 0.2964 sec
func:'time_matmul_256x256' 1000 times took: 0.3348 sec
func:'time_matmul_512x512' 1000 times took: 2.4141 sec
func:'time_matmul_1024x1024' 100 times took: 1.9360 sec
func:'time_matmul_2048x2048' 10 times took: 1.4082 sec
func:'time_matmul_4096x4096' 10 times took: 10.7015 sec
func:'time_matmul_8192x8192' 10 times took: 83.5396 sec
Using TensorFlow for TensorFuse
I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 8
I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 8
func:'time_matmul_3x3' 10000 times took: 1.2752 sec
func:'time_matmul_64x64' 10000 times took: 2.5342 sec
func:'time_matmul_256x256' 1000 times took: 0.8030 sec
func:'time_matmul_512x512' 1000 times took: 4.4840 sec
func:'time_matmul_1024x1024' 100 times took: 2.1287 sec
func:'time_matmul_2048x2048' 10 times took: 1.5108 sec
func:'time_matmul_4096x4096' 10 times took: 11.0894 sec
func:'time_matmul_8192x8192' 10 times took: 84.9977 sec

The overhead is unnoticeable when size >= 2048.

However, I suspect there's some overhead for each operation added to the graph, even when they are all executed in a single session.run(), which might be why TensorFlow is so much slower for the forward physics simulation I was doing.
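A minimal sketch one could use to probe that per-op dispatch cost (my own construction, not from the thread): build chains of trivial ops of increasing length and time a single session.run() over each chain; roughly linear growth in the step time would point at per-op rather than per-step overhead.

import time
import tensorflow as tf

def time_chain(num_ops, num_runs=100):
  # Time one session.run() over a chain of `num_ops` trivial add nodes.
  with tf.Graph().as_default():
    x = tf.constant(0.0)
    for _ in range(num_ops):
      x = tf.add(x, 1.0)  # each add is a separate node, all evaluated in one run()
    with tf.Session() as sess:
      sess.run(x)  # warm-up, excludes one-time startup cost
      start = time.time()
      for _ in range(num_runs):
        sess.run(x)
      return (time.time() - start) / num_runs

for n in (1, 10, 100, 1000):
  print(n, time_chain(n))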

@vrv vrv changed the title TensorFlow is slow: some initial benchmark TensorFlow session.run() overhead for graphs with few flops Nov 12, 2015
@vrv

vrv commented Feb 19, 2016

Closing due to inactivity -- it would be great to get the overhead lower for compute-light graphs, but it's probably not a major focus.

@vrv vrv closed this as completed Feb 19, 2016
@TimZaman
Contributor

TimZaman commented Nov 6, 2018

I ran Derek's experiment again:

with tf.Graph().as_default():
  c = tf.constant(0)
  with tf.Session():
    start = time()
    for _ in xrange(1000):
      _ = c.eval()
    end = time()
print(end-start)

>>> 0.215318918228

That's around 5000 steps per second. His results were 2.5x faster back in 2015. Seems that the graph execution overhead is immense!

@mrry
Contributor

mrry commented Nov 6, 2018

@TimZaman Just speculating, but I think that a large fraction of the cost might come from the first call to c.eval(), which performs various one-time startup activities (and has generally grown in responsibility since 2015). I'd hope that the subsequent steps are faster than 200us per call. You might be interested in looking here to see some of the ways to reduce the overhead of invoking a graph:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/client/session_benchmark.py

However, the recent engineering focus has been on making single-op execution fast in eager mode, and using TensorFlow functions as a replacement for sess.run().
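One such option is tf.Session.make_callable, which lets you set up the fetches once and then invoke the same subgraph repeatedly; a minimal sketch (TF 1.x, timings will vary):

import time
import tensorflow as tf

with tf.Graph().as_default():
  c = tf.constant(0)
  with tf.Session() as sess:
    run_c = sess.make_callable(c)  # set up the run once for this fetch
    run_c()                        # warm-up, absorbs one-time startup work
    start = time.time()
    for _ in range(1000):
      run_c()
    print(time.time() - start)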

@TimZaman
Contributor

TimZaman commented Nov 6, 2018

Ran some tests... just a small improvement:

{cold, warm, warm} start:

with tf.Graph().as_default():
  c = tf.constant(0)
  with tf.Session():
    for e in range(3):
      start = time()
      for _ in xrange(1000):
        _ = c.eval()
      end = time()
      print(end-start)
>>> 0.240370988846
>>> 0.179282188416
>>> 0.208700180054

Evaluating just the .op makes it about twice as fast (still meh):

with tf.Graph().as_default():
  c = tf.constant(0)
  with tf.Session() as sess:
    for e in range(3):
      start = time()
      for _ in xrange(1000):
        sess.run(c.op)
      end = time()
      print(end-start)
>>> 0.167293071747
>>> 0.106302022934
>>> 0.102854967117

eager

Around 40k per second. The same code in Torch is 300k/s, NumPy 800k/s, Python (int) 17M/s.

tf.enable_eager_execution()
c = tf.constant(0)
c = tf.convert_to_tensor(c)
for e in range(3):
    start = time()
    for _ in xrange(1000):
        c += 1
    end = time()
    print(end-start)

>>> c= tf.Tensor(1000, shape=(), dtype=int32)
>>> 0.0253150463104
>>> c= tf.Tensor(2000, shape=(), dtype=int32)
>>> 0.0205891132355
>>> c= tf.Tensor(3000, shape=(), dtype=int32)
>>> 0.0187749862671

@mrry
Contributor

mrry commented Nov 6, 2018

CC @asimshankar for the Eager numbers, to see if there's any low-hanging fruit there.

@vrv

vrv commented Nov 6, 2018

In graph mode I think the constant is being placed on GPU (which you probably have given where you work :P) and so maybe you are partially timing GPU->CPU copy time. Things get a little faster (2x on my machine) if you don't use GPU, FWIW. Still not where it needs to be though...
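One way to check that is to pin the constant to the CPU explicitly, so the benchmark does not include a device-to-host copy; a minimal sketch (my addition, TF 1.x):

import time
import tensorflow as tf

with tf.Graph().as_default():
  with tf.device('/cpu:0'):  # keep the op and its output on the host
    c = tf.constant(0)
  with tf.Session() as sess:
    sess.run(c)  # warm-up
    start = time.time()
    for _ in range(1000):
      sess.run(c)
    print(time.time() - start)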

@asimshankar
Contributor

asimshankar commented Nov 6, 2018

RE: eager - There are two annoying things right now - the operator overloading for += and the conversion of the Python int 1 to a Tensor. Those need to be improved (CC @akshaym), but in the meantime these trivial changes make a big improvement (numbers from my machine running 1.12.0):

import tensorflow as tf
from time import time

tf.enable_eager_execution()
c = tf.constant(0)
o = tf.constant(1)
for e in range(3):
    start = time()
    for _ in xrange(1000):
        c = tf.add(c, o)
    end = time()
    print(end-start)

Results in:

0.00489807128906
0.00518703460693
0.00459504127502

And just for thrills, compiling into a graph (using tf.contrib.eager.defun right now, which will morph into tf.function in 2.0):

@tf.contrib.eager.defun
def f(c):
  o = tf.constant(1)
  for _ in xrange(1000):
    c = tf.add(c, o)
  return c

c = tf.constant(0)
# Discard the first run, since that includes graph building time
_ = f(c)
for e in range(3):
  start = time()
  c = f(c)
  end = time()
  print(end - start)

Results in:

0.00160002708435
0.00185298919678
0.00164079666138

@asimshankar
Contributor

I should add that when using the graph function, the overheads of the operator overload and the Python->Tensor conversion are paid only at graph construction time. So this should work too:

@tf.contrib.eager.defun
def f(c):
  for _ in xrange(1000):
    c += 1
  return c

c = tf.constant(0)
# Discard the first run, since that includes graph building time
_ = f(c)
for e in range(3):
  start = time()
  c = f(c)
  end = time()
  print(end - start)

@twiecki

twiecki commented Dec 17, 2018

We also ran into this issue, significant call overhead compared to Theano (~10x slower for small graphs, even with defun). Could we maybe re-open this issue to give it more significance?

@asimshankar
Contributor

@twiecki - re: defun, would you mind filing a new issue specifically for that, along with details of what exactly was being measured?

@twiecki

twiecki commented Jan 3, 2019

@asimshankar #24684
