Multiple CPU usage ineffective #583

Closed
alantus opened this issue Dec 21, 2015 · 29 comments
@alantus commented Dec 21, 2015

I'm running TensorFlow on a machine with 64 CPUs and no GPUs. Although TF detects the number of CPUs correctly (
I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 64
I tensorflow/core/common_runtime/direct_session.cc:58] Direct session inter op parallelism threads: 64
), the total CPU usage stays below 16 cores most of the time. It occasionally rises to 20-30, but only for very short periods, and even that is far below 64.
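For reference, the reported thread pool sizes can also be set explicitly rather than auto-detected; a minimal TF 1.x sketch (the values here are illustrative, not a recommendation):

```python
import tensorflow as tf

# Pin both thread pools explicitly instead of relying on auto-detection.
config = tf.ConfigProto(intra_op_parallelism_threads=64,   # threads within a single op
                        inter_op_parallelism_threads=64)   # ops that can run concurrently
sess = tf.Session(config=config)
```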

@yaroslavvb (Contributor) commented:
Can you try with a graph that is trivially parallelizable, to make sure the problem isn't a communication bottleneck somewhere?

For instance, rename the attached file to "parallel.py" and run it as "python parallel.py N", where N is the number of parallel ops. On my MacBook Pro I get the following numbers, which show linear scaling up to 4 cores:

python ~/g/src/parallel.py 1
done in 0.99, 10.13 ops/sec

python ~/g/src/parallel.py 2
done in 0.97, 20.58 ops/sec

python ~/g/src/parallel.py 3
done in 1.02, 29.55 ops/sec

python ~/g/src/parallel.py 4
done in 1.04, 38.29 ops/sec

python ~/g/src/parallel.py 5
done in 1.33, 37.45 ops/sec

parallel.txt
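The attachment isn't reproduced here; below is a minimal sketch of this kind of embarrassingly parallel benchmark, assuming N independent heavy matmul ops dispatched in a single session.run (sizes and the exact ops are illustrative):

```python
import sys
import time

import tensorflow as tf

n = int(sys.argv[1]) if len(sys.argv) > 1 else 1
dim = 2000  # large enough that each op is "heavy"

# Build n independent ops so inter-op parallelism can spread them across cores.
ops = []
for _ in range(n):
    a = tf.random_normal([dim, dim])
    ops.append(tf.reduce_sum(tf.matmul(a, a)))

with tf.Session() as sess:
    start = time.time()
    sess.run(ops)
    elapsed = time.time() - start
    print("done in %.2f, %.2f ops/sec" % (elapsed, n / elapsed))
```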

@alantus (Author) commented Dec 23, 2015

Two things:

  1. Here's what it writes to stderr:
    I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 32
    I tensorflow/core/common_runtime/direct_session.cc:58] Direct session inter op parallelism threads: 32
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    I tensorflow/core/kernels/logging_ops.cc:79] [heavy op #0]
    W tensorflow/core/common_runtime/executor.cc:1076] 0x7f77dc0125c0 Compute status: Out of range: FIFOQueue '_0_fifo_queue' is closed and has insufficient elements (requested 1, current size 0)
    [[Node: fifo_queue_Dequeue = QueueDequeuecomponent_types=[DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"]]
    W tensorflow/core/common_runtime/executor.cc:1076] 0x7f77dc012de0 Compute status: Aborted: Queue '_0_fifo_queue' is already closed.

[[Node: fifo_queue_Close_1 = QueueClosecancel_pending_enqueues=false, _device="/job:localhost/replica:0/task:0/cpu:0"]]

Is that ok?
2. I ran it with the following line:
for i in `seq 1 64`; do echo $i `./parallel.py $i 2>/dev/null`; done
and here's the output:
1 done in 1.38, 7.26 ops/sec
2 done in 1.45, 13.79 ops/sec
3 done in 1.65, 18.21 ops/sec
4 done in 1.70, 23.50 ops/sec
5 done in 1.73, 28.98 ops/sec
6 done in 1.82, 33.05 ops/sec
7 done in 1.86, 37.70 ops/sec
8 done in 1.88, 42.65 ops/sec
9 done in 1.73, 52.05 ops/sec
10 done in 1.92, 52.03 ops/sec
11 done in 2.07, 53.13 ops/sec
12 done in 2.17, 55.29 ops/sec
13 done in 1.83, 70.92 ops/sec
14 done in 2.05, 68.13 ops/sec
15 done in 2.36, 63.45 ops/sec
16 done in 2.48, 64.64 ops/sec
17 done in 2.03, 83.89 ops/sec
18 done in 1.94, 92.69 ops/sec
19 done in 2.51, 75.74 ops/sec
20 done in 2.31, 86.41 ops/sec
21 done in 2.04, 102.80 ops/sec
22 done in 2.29, 96.04 ops/sec
23 done in 2.19, 105.18 ops/sec
24 done in 2.08, 115.29 ops/sec
25 done in 2.39, 104.58 ops/sec
26 done in 2.18, 119.52 ops/sec
27 done in 2.36, 114.33 ops/sec
28 done in 2.43, 115.42 ops/sec
29 done in 2.77, 104.66 ops/sec
30 done in 2.27, 132.39 ops/sec
31 done in 2.57, 120.63 ops/sec
32 done in 2.54, 126.02 ops/sec
33 done in 2.59, 127.47 ops/sec
34 done in 2.57, 132.32 ops/sec
35 done in 2.50, 139.96 ops/sec
36 done in 2.34, 153.63 ops/sec
37 done in 2.71, 136.65 ops/sec
38 done in 2.70, 140.72 ops/sec
39 done in 2.71, 144.05 ops/sec
40 done in 3.07, 130.32 ops/sec
41 done in 2.73, 150.23 ops/sec
42 done in 3.22, 130.30 ops/sec
43 done in 2.82, 152.70 ops/sec
44 done in 3.09, 142.43 ops/sec
45 done in 2.68, 167.86 ops/sec
46 done in 2.69, 170.91 ops/sec
47 done in 2.73, 172.44 ops/sec
48 done in 2.77, 173.29 ops/sec
49 done in 3.28, 149.27 ops/sec
50 done in 2.92, 171.16 ops/sec
51 done in 3.25, 157.13 ops/sec
52 done in 3.11, 167.12 ops/sec
53 done in 3.32, 159.57 ops/sec
54 done in 2.93, 184.45 ops/sec
55 done in 3.46, 158.80 ops/sec
56 done in 3.44, 162.72 ops/sec
57 done in 3.26, 174.85 ops/sec
58 done in 3.35, 173.29 ops/sec
59 done in 3.72, 158.64 ops/sec
60 done in 3.38, 177.43 ops/sec
61 done in 3.50, 174.10 ops/sec
62 done in 3.47, 178.43 ops/sec
63 done in 3.97, 158.81 ops/sec

64 done in 3.27, 195.81 ops/sec

So it does scale up to 64 threads...

@tridemax commented:
I've tested on 40 threads (2x CPU) and got the attached results.
(attached chart: tensorflowscalability)

The slow increase in the 10-20 and 30-40 thread ranges is probably connected to hyper-threading limitations.
But models/mnist/convolutional.py rarely occupies more than 25% of the total CPU.

@benoitsteiner (Contributor) commented:
This could be related to #551: the code ends up spending a lot of time locking and unlocking a mutex in the threadpool code.

@jeremybarnes commented:
This is very likely the case. With 8 vCPUs, the mutex locking is nearly undetectable. With 16, it is about 10% of the runtime. With 32, it dominates. I can only imagine that with 64 it's completely unusable.

If you replace the mutex in the thread pool with a spinlock (I copied the source from the one here: https://github.com/mldbai/mldb/blob/master/arch/spinlock.h), does it change the shape of the graph?

I have started work on a lockless threadpool implementation; it's complicated by the fact that thread pools are created and destroyed dynamically, so it needs to be able to deal with threads coming and going.

jeremybarnes added a commit to mldbai/tensorflow that referenced this issue Jan 13, 2016
Addresses tensorflow#581, tensorflow#583

For a benchmark on running the December Inception model 100 times on a 32 vcore CPU,
results are:

Current implementation (baseline)

211.17user 154.92system 0:15.42elapsed 2373%CPU (0avgtext+0avgdata 950024maxresident)k
0inputs+0outputs (0major+186279minor)pagefaults 0swaps

Lock free implementation

204.62user 27.12system 0:09.08elapsed 2551%CPU (0avgtext+0avgdata 737888maxresident)k
0inputs+0outputs (0major+122158minor)pagefaults 0swaps

So throughput is improved by ~35% and system overhead is greatly reduced.
@jeremybarnes commented:
@tridemax PR #763 may address the issue. If it is easy to re-run the tests on that branch, it would be good to know what the effect is.

@jeremybarnes commented:
Confirmed: part of this is caused by the thread pool locking. The PR in #763 nearly halves the elapsed wall time to train the mnist convolutional.py example:

from:

2464.97user 767.13system 7:48.46elapsed 689%CPU (0avgtext+0avgdata 655848maxresident)k
0inputs+0outputs (0major+84760819minor)pagefaults 0swap

to:

2827.11user 724.31system 4:19.43elapsed 1368%CPU (0avgtext+0avgdata 1441204maxresident)k
0inputs+0outputs (0major+61934308minor)pagefaults 0swaps

We still only manage to utilize about half of the CPU, however. Further improvements would probably come from splitting the minibatch into multiple shards that execute the whole graph independently, so that the holes during the reduction operations can be filled with convolutions from another batch.
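A rough TF 1.x sketch of that sharding idea (the model, names, and shard count are illustrative, not the actual convolutional.py change):

```python
import tensorflow as tf

NUM_SHARDS = 4

images = tf.placeholder(tf.float32, [None, 784])
labels = tf.placeholder(tf.int64, [None])
w = tf.Variable(tf.random_normal([784, 10]))

# Split the minibatch into independent shards; their sub-graphs have no data
# dependencies on each other, so they can overlap under inter-op parallelism.
shard_losses = []
for img_shard, lbl_shard in zip(tf.split(images, NUM_SHARDS),
                                tf.split(labels, NUM_SHARDS)):
    logits = tf.matmul(img_shard, w)
    shard_losses.append(tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl_shard,
                                                       logits=logits)))

# Only this final reduction serializes the shards.
total_loss = tf.add_n(shard_losses) / NUM_SHARDS
```

The per-shard graphs can run concurrently, so heavy ops from one shard can fill the gaps left by reductions in another.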

@tridemax commented:
@jeremybarnes I've tested this on our experimental Windows build (as this server is Windows-based), so there is no direct way for me to test your branch. =(

jeremybarnes added a commit to mldbai/tensorflow that referenced this issue Jan 26, 2016
This adds an experimental (mostly) lock-free thread pool implementation which can
improve the throughput on 8 or more core machines when running CPU-only (or when
the CPU's ability to feed the GPUs is the bottleneck).

It is disabled by default, and should be tested by setting the environment
variable TF_THREAD_POOL=lock_free when launching the C++ or python executable
running the Tensorflow session.

Addresses tensorflow#551, tensorflow#583

For a benchmark on running the December Inception model 100 times on a 32 vcore CPU,
results are:

Default implementation (baseline)

211.17user 154.92system 0:15.42elapsed 2373%CPU (0avgtext+0avgdata 950024maxresident)k
0inputs+0outputs (0major+186279minor)pagefaults 0swaps

Lock free implementation

204.62user 27.12system 0:09.08elapsed 2551%CPU (0avgtext+0avgdata 737888maxresident)k
0inputs+0outputs (0major+122158minor)pagefaults 0swaps

So throughput is improved by ~35% and system overhead is greatly reduced.
@dvyukov commented Jan 28, 2016

How can I expose this contention with a realistic program, preferably from C++?

I've tried to modify tensorflow/examples/label_image/main.cc to execute session->Run in parallel:

--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -32,6 +32,9 @@ limitations under the License.
 // The googlenet_graph.pb file included by default is created from Inception.

 #include <fstream>
+#include <thread>
+#include <memory>
+#include <memory>

 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/image_ops.h"
@@ -290,15 +293,22 @@ int main(int argc, char* argv[]) {
   }
   const Tensor& resized_tensor = resized_tensors[0];

-  // Actually run the image through the model.
-  std::vector<Tensor> outputs;
-  Status run_status = session->Run({{input_layer, resized_tensor}},
+  const int N = 20;
+  std::vector<std::unique_ptr<std::thread>> threads(N);
+  for (int i = 0; i < N; i++)
+    threads[i].reset(new std::thread([&]() {
+      // Actually run the image through the model.
+      std::vector<Tensor> outputs;
+      Status run_status = session->Run({{input_layer, resized_tensor}},
                                    {output_layer}, {}, &outputs);
-  if (!run_status.ok()) {
-    LOG(ERROR) << "Running model failed: " << run_status;
-    return -1;
-  }
+      if (!run_status.ok()) {
+        LOG(ERROR) << "Running model failed: " << run_status;
+      }
+    }));
+  for (int i = 0; i < N; i++)
+    threads[i]->join();

+  /*
   // This is for automated testing to make sure we get the expected result with
   // the default settings. We know that label 866 (military uniform) should be
   // the top label for the Admiral Hopper image.
@@ -321,6 +331,6 @@ int main(int argc, char* argv[]) {
     LOG(ERROR) << "Running print failed: " << print_status;
     return -1;
   }
-
+  */
   return 0;
 }

and to create multiple sessions as well:

--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -32,6 +32,9 @@ limitations under the License.
 // The googlenet_graph.pb file included by default is created from Inception.

 #include <fstream>
+#include <thread>
+#include <memory>
+#include <memory>

 #include "tensorflow/cc/ops/const_op.h"
 #include "tensorflow/cc/ops/image_ops.h"
@@ -268,13 +271,18 @@ int main(int argc, char* argv[]) {
     return -1;
   }

+  const int N = 20;
+  std::vector<std::unique_ptr<std::thread>> threads(N);
+  for (int i = 0; i < N; i++)
+    threads[i].reset(new std::thread([&]() {
+
   // First we load and initialize the model.
   std::unique_ptr<tensorflow::Session> session;
   string graph_path = tensorflow::io::JoinPath(root_dir, graph);
   Status load_graph_status = LoadGraph(graph_path, &session);
   if (!load_graph_status.ok()) {
     LOG(ERROR) << load_graph_status;
-    return -1;
+    //return -1;
   }

   // Get the image from disk as a float array of numbers, resized and normalized
@@ -286,19 +294,22 @@ int main(int argc, char* argv[]) {
                               input_std, &resized_tensors);
   if (!read_tensor_status.ok()) {
     LOG(ERROR) << read_tensor_status;
-    return -1;
+    //return -1;
   }
   const Tensor& resized_tensor = resized_tensors[0];

-  // Actually run the image through the model.
-  std::vector<Tensor> outputs;
-  Status run_status = session->Run({{input_layer, resized_tensor}},
+      // Actually run the image through the model.
+      std::vector<Tensor> outputs;
+      Status run_status = session->Run({{input_layer, resized_tensor}},
                                    {output_layer}, {}, &outputs);
-  if (!run_status.ok()) {
-    LOG(ERROR) << "Running model failed: " << run_status;
-    return -1;
-  }
+      if (!run_status.ok()) {
+        LOG(ERROR) << "Running model failed: " << run_status;
+      }
+    }));
+  for (int i = 0; i < N; i++)
+    threads[i]->join();

+  /*
   // This is for automated testing to make sure we get the expected result with
   // the default settings. We know that label 866 (military uniform) should be
   // the top label for the Admiral Hopper image.
@@ -321,6 +332,6 @@ int main(int argc, char* argv[]) {
     LOG(ERROR) << "Running print failed: " << print_status;
     return -1;
   }
-
+  */
   return 0;
 }

But in both cases it nicely consumes all my cores, there is no significant system time, and the top functions in the profile are all doing useful work:

+  20.62%  label_image  label_image           [.] float __vector(4) Eigen::internal::pmul<float __vector(4)>(float __vector(4) const&, float __vector(4) const&)
+  16.66%  label_image  label_image           [.] void Eigen::internal::gebp_traits<float, float, false, false>::madd<float __vector(4), float __vector(4), float __vector(4)>(float __vector(4) const&, floa
+  15.80%  label_image  label_image           [.] float __vector(4) Eigen::internal::padd<float __vector(4)>(float __vector(4) const&, float __vector(4) const&)
+   6.55%  label_image  label_image           [.] float __vector(4) Eigen::internal::pload<float __vector(4)>(Eigen::internal::unpacket_traits<float __vector(4)>::type const*)
+   6.20%  label_image  label_image           [.] void Eigen::internal::pbroadcast4<float __vector(4)>(Eigen::internal::unpacket_traits<float __vector(4)>::type const*, float __vector(4)&, float __vector(4
+   4.57%  label_image  label_image           [.] Eigen::internal::gebp_kernel<float, float, long, Eigen::internal::blas_data_mapper<float, long, 0, 0>, 8, 4, false, false>::operator()(Eigen::internal::bla
+   4.25%  label_image  label_image           [.] Eigen::internal::TensorIntDivisor<long, false>::divide(long) const

Can somebody suggest a modification to the existing C++ examples that would expose the contention? Thanks.

I am on commit d4422ff (Jan 26).

@dvyukov commented Jan 29, 2016

label_image seems to deadlock (#929), so are there any other good parallel benchmarks?

dvyukov added a commit to dvyukov/tensorflow that referenced this issue Jan 29, 2016
The current thread pool implementation is centralized and does not scale.
This change adds a distributed, non-blocking thread pool implementation.
Both implementations co-exist and can be chosen using
TF_THREAD_POOL env var.

Fixes tensorflow#551
Fixes tensorflow#583
Update tensorflow#932
Update tensorflow#933
@jeremybarnes commented:
@dvyukov I ran the mnist training example that ships with TensorFlow (it's Python, but it exercises the C++ code). Unfortunately, Reviewable swallowed the comment with the exact command line. Note that you need a machine with more than 16 vcores and no GPU enabled to see the contention.

@dvyukov commented Feb 16, 2016

@jeremybarnes thanks for the mnist pointer. For the record, the command line is:

$ bazel build --dynamic_mode=off -c opt tensorflow/models/image/mnist:convolutional
$ bazel-bin/tensorflow/models/image/mnist/convolutional

Ha, that also deadlocks for me.
I think the reason you do not see deadlocks is that your pool implementation happens to provide FIFO ordering for externally submitted tasks (perhaps not accidentally!), and the tasks that require FIFO ordering are exactly the externally submitted ones. It is also important that these tasks are submitted from a single external thread. If any of the above stops being true, I think you will also see deadlocks.

I can sort of work around the deadlocks by creating 8x more threads than requested, but that results in 800 threads being created (there are 2 thread pools in the process; why?), and I guess it's not guaranteed to work anyway.

@dvyukov commented Feb 16, 2016

Doh. I've added a global FIFO queue for externally submitted work and the deadlocks go away. The pool is still faster than the default one, but what I see is that all work ends up on the global queue. So what provides the speedup in our case is merely the thread spinning before blocking. Ouch. If we remove the per-worker work queues, I bet it will just get faster...

Parallel algorithms like this are not going to scale regardless of which pool is used for scheduling. The parallelization part needs to be rewritten.

@tridemax commented:
Could we just end up using some kind of automatic multi-towering for large CPUs, like in the multi-GPU case? TensorFlow is supposed to have a clustering technique built in, so this should be an easy pick?

@dvyukov commented Feb 16, 2016

@tridemax Do you mean some kind of higher-level partitioning?

Even when I run tensorflow/models/image/mnist/convolutional with just 4 threads, it is still unable to utilize 4 cores (only about 3.5). Partitioning will also add some overhead for communication between partitions (I don't know how large it is, though).
Overall it does not look like a good long-term strategy. Partitioning is used to mitigate inherent communication overheads; in this case we would be trying to mitigate artificial overheads. The real cost of shared-memory synchronization is pretty low.

@jeremybarnes commented:
@dvyukov When I looked at the low CPU occupancy originally, I saw two main reasons:

  1. Massive contention on the lock for the global work queue. In order to solve point 2 below, we had structured our work into multiple overlapping and independent jobs, with a reduction at the end, so there were definitely lots of threads simultaneously submitting work. The lock contention was overwhelming; replacing the mutex with a spinlock helped quite a bit.
  2. Once that was fixed, it appears that there is also quite simply a lack of exploitable parallelism during execution of typical operation graphs. I didn't look at exactly which ops were running when, but it appears that there are choke points in the computation graph that have little or no parallelism (e.g., a tree-based reduction will only be able to occupy about half of the available cores on average). We attempted to work around this by running more than one independent graph, e.g. by running multiple shards of a minibatch independently in parallel, so that high-occupancy operations like convolutions from one could fill in the holes in the low-occupancy operations of the other. This gives up some stochasticity for better occupancy, which is not always a tradeoff worth making.

In any event, the workloads we have for Tensorflow do cause many threads to submit jobs, which is partially what drove the design for the PR that I have open; I'd be open to a simpler global queue but I'm pretty sure it wouldn't solve our problem.

@dvyukov commented Feb 16, 2016

Ok, then I guess we need to start with a representative set of benchmarks that we want to optimize. The only one that I was pointed to (mnist) submits work from a single thread.

Different requirements lead to different designs.

@jeremybarnes commented:
Agreed. I can submit a benchmark that is more representative of our use-case. It could be simulated by running label_image in 8-16 threads in parallel within a single process.
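As a stand-in until such a benchmark exists, here is a rough Python sketch of the same idea: many concurrent session.run calls within one process. session.run releases the GIL, so this exercises the C++ executor and its thread pool; the graph itself is just an illustrative placeholder rather than the Inception model label_image loads.

```python
import threading

import tensorflow as tf

# Placeholder workload; a real benchmark would load the Inception graph
# the way label_image does.
x = tf.random_normal([1000, 1000])
y = tf.reduce_sum(tf.matmul(x, x))

sess = tf.Session()

def worker(iters=50):
    # Each thread issues many session.run calls against the shared session.
    for _ in range(iters):
        sess.run(y)

threads = [threading.Thread(target=worker) for _ in range(16)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```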

It would help if someone working on TensorFlow at Google could provide a position on which external use-cases are interesting versus the internal ones, to help us agree on a starting point.

@dvyukov commented Feb 19, 2016

@benoitsteiner Can you please answer @jeremybarnes's question? Good representative benchmarks are a good thing regardless of the outcome of everything else here. I've struggled to find any ready-made benchmarks.

@linkerlin commented:
I think maybe OpenMP will help a lot.

@tridemax commented Apr 2, 2016

Actually, ANY kind of network we ran (MNIST, AlexNet, a few hybrid CNN+LSTM setups) on a 20-core (40-thread) machine yielded 20-30% load, regardless of the task.

@songgc commented May 18, 2016

With ab02c5, I do see a significant speedup on CPUs.

@benoitsteiner (Contributor) commented:
@songgc ab02c5a combined with a few other changes to make sure all the TensorFlow operations are multithreaded resulted in a 1.5x to 3x improvement on most benchmarks. Is that what you're seeing?

@songgc commented May 18, 2016

@benoitsteiner Yes. I trained relatively small networks on Intel CPUs and did inference on an ARM CPU. On both systems (training and inference) I saw 1.5x to 1.8x gains. Thanks a lot!

@BogusLogin commented:
Hi, my problem is similar to @alantus's post at the beginning:
I installed Keras with TF as the backend on a machine with multiple CPUs (each CPU can run only 1 thread). At runtime, Keras/TF only ever uses 1 CPU.

How can I check how many devices TF was able to recognize?
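One way to check (device_lib is a real but semi-internal TF 1.x module; treat this as a sketch rather than the official answer):

```python
import tensorflow as tf
from tensorflow.python.client import device_lib

# Lists every device TF detected (all CPUs usually show up as a single /cpu:0
# device; the thread pools control how many cores that device actually uses).
print(device_lib.list_local_devices())

# The intra/inter op thread counts can be overridden when creating the session.
config = tf.ConfigProto(intra_op_parallelism_threads=0,  # 0 means auto-detect
                        inter_op_parallelism_threads=0)
sess = tf.Session(config=config)
```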

@dvyukov commented May 19, 2016

@tridemax @jeremybarnes Can you please test your benchmarks after commit ab02c5?
Scaling on CPU should be significantly improved.

@vincentvanhoucke (Contributor) commented:
As of change ab02c5, we consider the CPU scaling issue generally resolved, so I'm going to close this bug. Please file new issues with benchmarks and steps to reproduce if you still see any performance issues.

andre-geldenhuis added a commit to andre-geldenhuis/kokako that referenced this issue Mar 9, 2017
as mutex locking begins to dominate the CPU time available
see: tensorflow/tensorflow#583
we have very many TensorFlow jobs to run (100s of 1000s), so
internal parallelism is irrelevant anyway.
@BogdanRuzh commented:
@vincentvanhoucke I still have a problem with multi-core CPU utilization for the MKL build:
#15320

@yunjiangster commented:
I have 48 CPU cores and 4 GPUs. I am using the TFRecord reader. Even with a trivial network, my CPU utilization is capped at 1800%. I tried sharding my input data and using parallel_interleave, but there is no improvement in utilization.
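For reference, a hedged sketch of that kind of input pipeline, assuming a TF 1.x build where tf.contrib.data.parallel_interleave is available (the file pattern, feature spec, and tuning values are all illustrative):

```python
import tensorflow as tf

def parse_fn(record):
    # Hypothetical feature spec; adapt to the actual TFRecord schema.
    features = tf.parse_single_example(record, {
        "image": tf.FixedLenFeature([784], tf.float32),
        "label": tf.FixedLenFeature([], tf.int64),
    })
    return features["image"], features["label"]

files = tf.data.Dataset.list_files("train-*.tfrecord")
# Read several shard files concurrently...
dataset = files.apply(tf.contrib.data.parallel_interleave(
    lambda f: tf.data.TFRecordDataset(f), cycle_length=8))
# ...and parse records on multiple threads, keeping batches ready ahead of time.
dataset = dataset.map(parse_fn, num_parallel_calls=16)
dataset = dataset.batch(128).prefetch(4)
```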

darkbuck pushed a commit to darkbuck/tensorflow that referenced this issue Jan 23, 2020
…pstream-deven-misc-190726

fix for a couple subtest failures in the unit-test //tensorflow/python:nn_fused_batchnorm_test