Multiple CPU usage ineffective #583
Comments
Can you try with a plan that is naively parallelizable, to make sure it's not a communication bottleneck somewhere? For instance, rename the attached file to "parallel.py" and then run it as "python parallel.py <num_threads>". On my MacBook Pro I get the following numbers, which show linear scaling up to 4 cores:

python ~/g/src/parallel.py 1
python ~/g/src/parallel.py 2
python ~/g/src/parallel.py 3
python ~/g/src/parallel.py 4
python ~/g/src/parallel.py 5
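A minimal sketch of what such a naively parallelizable benchmark might look like (the matmul workload, shapes, and iteration count are assumptions; the thread-count argument and the output format follow the surrounding comments; written against the v1-style API used in this thread):

```python
import sys
import time
import tensorflow as tf

n = int(sys.argv[1])

# Build n independent matmuls: no edges between them, so no communication.
ops = []
for _ in range(n):
    a = tf.random_normal([1000, 1000])
    ops.append(tf.matmul(a, a))
step = tf.group(*ops)

# One inter-op thread per independent op; keep each op single-threaded.
config = tf.ConfigProto(inter_op_parallelism_threads=n,
                        intra_op_parallelism_threads=1)
with tf.Session(config=config) as sess:
    iters = 20
    start = time.time()
    for _ in range(iters):
        sess.run(step)
    elapsed = time.time() - start
    print('%d done in %.2f, %.2f ops/sec' % (n, elapsed, iters * n / elapsed))
```

If the thread pool scales, ops/sec should grow roughly linearly with n.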
Two things:

1. I get this warning:

[[Node: fifo_queue_Close_1 = QueueClose[cancel_pending_enqueues=false, _device="/job:localhost/replica:0/task:0/cpu:0"]]]

Is that ok?

2. 64 done in 3.27, 195.81 ops/sec

So it does scale up to 64 threads...
This could be related to #551: the code ends up spending a lot of time locking and unlocking a mutex in the threadpool code.
This is very likely the case. With 8 vCPUs, the mutex locking is nearly undetectable. With 16, it is about 10% of the runtime. With 32, it dominates. I can only imagine that with 64 it's completely unusable. If you replace the mutex in the thread pool with a spinlock (I copied the source from the one here: https://github.com/mldbai/mldb/blob/master/arch/spinlock.h), does it change the shape of the graph? I have started work on a lockless threadpool implementation; it's complicated by the fact that thread pools are created and destroyed dynamically, so it needs to be able to deal with threads coming and going.
Addresses tensorflow#581, tensorflow#583

For a benchmark running the December Inception model 100 times on a 32-vcore CPU, results are:

Current implementation (baseline):
211.17user 154.92system 0:15.42elapsed 2373%CPU (0avgtext+0avgdata 950024maxresident)k 0inputs+0outputs (0major+186279minor)pagefaults 0swaps

Lock-free implementation:
204.62user 27.12system 0:09.08elapsed 2551%CPU (0avgtext+0avgdata 737888maxresident)k 0inputs+0outputs (0major+122158minor)pagefaults 0swaps

So throughput is improved by ~35% and system overhead is greatly reduced.
Confirmed: part of this is caused by the thread pool locking. The PR in #763 nearly halves the elapsed wall time to train the mnist convolutional.py example: from:
to:
We still only manage to utilize about half of the CPU, however. Further improvements would probably come from splitting the minibatch into multiple shards that execute the whole graph independently, so that the holes during the reduction operations can be filled with convolutions from another batch (a sketch of that idea follows).
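A rough sketch of that sharding idea in the Python API of the era (hypothetical: `build_model` is a placeholder for the real network and must create its weights with `tf.get_variable` so the towers can share them; shapes, batch size, and the optimizer are assumptions):

```python
import tensorflow as tf

NUM_SHARDS = 4

x = tf.placeholder(tf.float32, [512, 784])
y = tf.placeholder(tf.int64, [512])
x_shards = tf.split(x, NUM_SHARDS)  # one sub-batch per shard
y_shards = tf.split(y, NUM_SHARDS)

opt = tf.train.GradientDescentOptimizer(0.01)
tower_grads = []
for i in range(NUM_SHARDS):
    # Each shard is an independent copy of the graph; reuse shares the weights.
    with tf.variable_scope('model', reuse=(i > 0)):
        logits = build_model(x_shards[i])  # hypothetical model-building function
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=y_shards[i]))
        tower_grads.append(opt.compute_gradients(loss))

# Average the per-shard gradients and apply them once. Until this reduction,
# the shards have no edges between them, so they can run concurrently and
# fill each other's scheduling holes.
averaged = []
for grads_and_vars in zip(*tower_grads):
    grad = tf.add_n([g for g, _ in grads_and_vars]) / NUM_SHARDS
    averaged.append((grad, grads_and_vars[0][1]))
train_op = opt.apply_gradients(averaged)
```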
@jeremybarnes I've tested this on our experimental Windows build (as this server is Windows-based), so there is no direct way for me to test your branch. =(
This adds an experimental (mostly) lock-free thread pool implementation which can improve throughput on machines with 8 or more cores when running CPU-only (or when the CPU's ability to feed the GPUs is the bottleneck). It is disabled by default, and should be tested by setting the environment variable TF_THREAD_POOL=lock_free when launching the C++ or Python executable running the TensorFlow session.

Addresses tensorflow#551, tensorflow#583

For a benchmark running the December Inception model 100 times on a 32-vcore CPU, results are:

Default implementation (baseline):
211.17user 154.92system 0:15.42elapsed 2373%CPU (0avgtext+0avgdata 950024maxresident)k 0inputs+0outputs (0major+186279minor)pagefaults 0swaps

Lock-free implementation:
204.62user 27.12system 0:09.08elapsed 2551%CPU (0avgtext+0avgdata 737888maxresident)k 0inputs+0outputs (0major+122158minor)pagefaults 0swaps

So throughput is improved by ~35% and system overhead is greatly reduced.
How can I expose this contention on a realistic program? Preferably from C++? I've tried to modify label_image to run the model from multiple threads:

--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -32,6 +32,9 @@ limitations under the License.
// The googlenet_graph.pb file included by default is created from Inception.
#include <fstream>
+#include <thread>
+#include <memory>
+#include <memory>
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
@@ -290,15 +293,22 @@ int main(int argc, char* argv[]) {
}
const Tensor& resized_tensor = resized_tensors[0];
- // Actually run the image through the model.
- std::vector<Tensor> outputs;
- Status run_status = session->Run({{input_layer, resized_tensor}},
+ const int N = 20;
+ std::vector<std::unique_ptr<std::thread>> threads(N);
+ for (int i = 0; i < N; i++)
+ threads[i].reset(new std::thread([&]() {
+ // Actually run the image through the model.
+ std::vector<Tensor> outputs;
+ Status run_status = session->Run({{input_layer, resized_tensor}},
{output_layer}, {}, &outputs);
- if (!run_status.ok()) {
- LOG(ERROR) << "Running model failed: " << run_status;
- return -1;
- }
+ if (!run_status.ok()) {
+ LOG(ERROR) << "Running model failed: " << run_status;
+ }
+ }));
+ for (int i = 0; i < N; i++)
+ threads[i]->join();
+ /*
// This is for automated testing to make sure we get the expected result with
// the default settings. We know that label 866 (military uniform) should be
// the top label for the Admiral Hopper image.
@@ -321,6 +331,6 @@ int main(int argc, char* argv[]) {
LOG(ERROR) << "Running print failed: " << print_status;
return -1;
}
-
+ */
return 0;
}

and to create multiple sessions as well:

--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -32,6 +32,9 @@ limitations under the License.
// The googlenet_graph.pb file included by default is created from Inception.
#include <fstream>
+#include <thread>
+#include <memory>
+#include <memory>
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
@@ -268,13 +271,18 @@ int main(int argc, char* argv[]) {
return -1;
}
+ const int N = 20;
+ std::vector<std::unique_ptr<std::thread>> threads(N);
+ for (int i = 0; i < N; i++)
+ threads[i].reset(new std::thread([&]() {
+
// First we load and initialize the model.
std::unique_ptr<tensorflow::Session> session;
string graph_path = tensorflow::io::JoinPath(root_dir, graph);
Status load_graph_status = LoadGraph(graph_path, &session);
if (!load_graph_status.ok()) {
LOG(ERROR) << load_graph_status;
- return -1;
+ //return -1;
}
// Get the image from disk as a float array of numbers, resized and normalized
@@ -286,19 +294,22 @@ int main(int argc, char* argv[]) {
input_std, &resized_tensors);
if (!read_tensor_status.ok()) {
LOG(ERROR) << read_tensor_status;
- return -1;
+ //return -1;
}
const Tensor& resized_tensor = resized_tensors[0];
- // Actually run the image through the model.
- std::vector<Tensor> outputs;
- Status run_status = session->Run({{input_layer, resized_tensor}},
+ // Actually run the image through the model.
+ std::vector<Tensor> outputs;
+ Status run_status = session->Run({{input_layer, resized_tensor}},
{output_layer}, {}, &outputs);
- if (!run_status.ok()) {
- LOG(ERROR) << "Running model failed: " << run_status;
- return -1;
- }
+ if (!run_status.ok()) {
+ LOG(ERROR) << "Running model failed: " << run_status;
+ }
+ }));
+ for (int i = 0; i < N; i++)
+ threads[i]->join();
+ /*
// This is for automated testing to make sure we get the expected result with
// the default settings. We know that label 866 (military uniform) should be
// the top label for the Admiral Hopper image.
@@ -321,6 +332,6 @@ int main(int argc, char* argv[]) {
LOG(ERROR) << "Running print failed: " << print_status;
return -1;
}
-
+ */
return 0;
}

But in both cases it nicely consumes all my cores, no significant time is spent in system, and the top functions in the profile are all doing useful work.
Can somebody suggest a modification to the existing C++ examples that would expose the contention? Thanks. I am on commit d4422ff (Jan 26).
label_image seems to deadlock (#929)
Current thread pool implementation is centralized and non-scalable. This change adds distributed, non-blocking thread pool implementation. Both implementations co-exist and can be chosen using TF_THREAD_POOL env var. Fixes tensorflow#551 Fixes tensorflow#583 Update tensorflow#932 Update tensorflow#933
@dvyukov I ran the mnist training example distributed with TensorFlow (it's Python, but it exercises the C++ code). Unfortunately Reviewable swallowed the comment with the exact command line. Note that you need a machine with more than 16 vcores and no GPU enabled to see the contention.
@jeremybarnes thanks for the mnist pointer. FTR the command line is:
Ha, that also deadlocks for me. I can sort of work around the deadlocks by creating 8x more threads than requested, but that results in 800 threads being created (there are 2 thread pools in the process (why?)). And I guess it's not guaranteed to work anyway.
Doh. I've added a global FIFO queue for externally submitted work, and the deadlocks go away. The pool is still faster than the default one. But what I see is that all work ends up on the global queue, so what provides the speedup in our case is merely thread spinning before blocking. Ouch. If we remove the per-worker work queues, I bet it will just become faster... Such parallel algorithms are not going to scale whatever pool is used for scheduling. The parallelization part needs to be rewritten.
Could we just end up using some kind of automatic multi-towering for large CPUs, like in the multi-GPU case? TensorFlow is supposed to have clustering techniques built in, so this should be an easy pick?
@tridemax Do you mean some kind of higher-level partitioning? Even when I run tensorflow/models/image/mnist/convolutional with just 4 threads, it is still unable to utilize 4 cores (only about 3.5). The partitioning will also have some overhead for communication between partitions (I don't know how large it is, though).
@dvyukov When I looked at the low CPU occupancy originally, I saw two main reasons: contention on the thread pool mutex, and holes during the reduction operations where there isn't enough parallel work to keep the cores busy.
In any event, the workloads we have for TensorFlow do cause many threads to submit jobs, which is partially what drove the design for the PR that I have open; I'd be open to a simpler global queue, but I'm pretty sure it wouldn't solve our problem.
OK, then I guess we need to start with a representative set of benchmarks that we want to optimize. The only one that I was pointed to (mnist) submits work from a single thread. Different requirements lead to different designs.
Agreed. I can submit a benchmark that is more representative of our use case. It could be simulated by running label_image in 8-16 threads in parallel within a single process. It would help if someone working on TensorFlow at Google could provide a position on which kinds of external use cases are interesting vs. the internal ones, to help us agree on a starting point.
@benoitsteiner Can you please answer @jeremybarnes's question? Good representative benchmarks are a good thing regardless of the outcome of everything else here. I've struggled to find any ready-made benchmarks.
I think maybe OpenMP will help a lot. |
Actually, ANY kind of network we ran (MNIST, AlexNet, a few hybrid CNN+LSTM setups) on a 20-core (40-thread) machine yielded 20-30% load, regardless of the task.
With ab02c5, I do see a significant speedup on CPUs. |
@benoitsteiner Yes. I trained relatively small networks on Intel CPUs and did inference on an ARM CPU. On both systems (learning and inference), I saw 1.5x to 1.8x gains. Thanks a lot!
Hi, my problem is similar to @alantus's post at the beginning: how can I check how many devices TF was able to recognize?
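One common way to check (device_lib is an internal module rather than a public API, but it is widely used for exactly this):

```python
from tensorflow.python.client import device_lib

# Prints one entry per device TensorFlow has registered (CPU, GPU, ...).
for d in device_lib.list_local_devices():
    print(d.name, d.device_type)
```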
@tridemax @jeremybarnes Can you please test your benchmarks after commit ab02c5? |
As of change |
as mutex locking begins to dominate the available CPU time; see tensorflow/tensorflow#583. We have very many TensorFlow jobs to run (100s of 1000s), so internal parallelism is irrelevant anyway.
@vincentvanhoucke I still have a problem with multicore CPU utilization for the MKL build.
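For what it's worth, the knobs that usually matter for MKL builds are the Intel OpenMP environment variables plus the session thread-pool settings; a hedged example (the values shown are illustrative and machine-dependent, not settings from this thread):

```python
import os

# Intel OpenMP settings; set these before TensorFlow is imported.
os.environ["OMP_NUM_THREADS"] = "20"  # roughly the number of physical cores
os.environ["KMP_BLOCKTIME"] = "0"     # don't let OpenMP threads spin after work
os.environ["KMP_AFFINITY"] = "granularity=fine,compact,1,0"

import tensorflow as tf

config = tf.ConfigProto(
    intra_op_parallelism_threads=20,  # match OMP_NUM_THREADS
    inter_op_parallelism_threads=2)
sess = tf.Session(config=config)
```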
I have 48 CPU cores and 4 GPU cores. I am using the TFRecord reader. Even with a trivial network, my CPU utilization is capped at 1800%. I tried sharding my input data and using parallel_interleave, but there is no improvement in usage.
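For reference, a sketch of the sharded pipeline being described (the file pattern, parse_fn, and all tuning numbers are placeholders for this user's setup, not values from the thread):

```python
import tensorflow as tf

files = tf.data.Dataset.list_files("train-*.tfrecord", shuffle=True)

# Read many shard files concurrently instead of one TFRecord at a time.
dataset = files.apply(tf.data.experimental.parallel_interleave(
    tf.data.TFRecordDataset, cycle_length=16, sloppy=True))

dataset = dataset.map(parse_fn, num_parallel_calls=48)  # parse_fn: user-defined
dataset = dataset.batch(128)
dataset = dataset.prefetch(4)
```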
…pstream-deven-misc-190726 fix for a couple subtest failures in the unit-test //tensorflow/python:nn_fused_batchnorm_test
I'm running TensorFlow on a machine with 64 CPUs and no GPUs. I noticed that although TF detects the number of CPUs correctly (
I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 64
I tensorflow/core/common_runtime/direct_session.cc:58] Direct session inter op parallelism threads: 64
), the total CPU usage is less than 16 cores most of the time. It occasionally increases to 20-30, but only for very short periods, and even that is much lower than 64.