grpc RecvTensor is slow #6116

Closed
llhe opened this issue Dec 6, 2016 · 42 comments
Labels
stat:awaiting tensorflower Status - Awaiting response from tensorflower type:bug Bug

Comments

@llhe
Contributor

llhe commented Dec 6, 2016

I ran benchmarks for a distributed setup over the loopback network, profiled it, and found excessive memory copying on the client side of the RecvTensor call, which is actually one of the bottlenecks.

Here is the code, mostly borrowed from @yaroslavvb here:

  with tf.device(device1):
    params = tf.get_variable("params", shape=[params_size], dtype=dtype,
                             initializer=tf.zeros_initializer)
  with tf.device(device2):
    # constant node gets placed on device1 because of simple_placer
    #    update = tf.constant(1, shape=[params_size], dtype=dtype)
    update = tf.get_variable("update", shape=[params_size], dtype=dtype,
                             initializer=tf.ones_initializer())
    add_op = params.assign(update)

Here is the profiling result (google-perftools) with a 100MB tensor (note that throughput degrades as the tensor size increases):

From the result, the sending side (device2) looks fine, but the receiving side (device1, the grpc client) consumes too many CPU cycles for the data transfer.

By the way, I collected rough stats for this memmove call. For one round of the 100MB tensor assignment, roughly 2GB of data is moved (counting both the read and the write inside a naive memmove, it is closer to 4GB), which is a 20+ times amplification of RAM traffic (the numbers are averaged over 100 rounds, so they may not be precise, but the scale should be right).

@llhe
Contributor Author

llhe commented Dec 6, 2016

To make things clearer, I collected more detailed data for memmove:

  // int move_size = (sb->count - 1) * sizeof(gpr_slice);
  memmove(&sb->slices[0], &sb->slices[1], (sb->count - 1) * sizeof(gpr_slice));
  // int data_size = GPR_SLICE_LENGTH(slice);

A typical move_size / slice_size sequence:

  • move_size: 6096, slice_size: 608
  • move_size: 6072, slice_size: 8192
  • move_size: 6048, slice_size: 7583
  • move_size: 6024, slice_size: 600
  • move_size: 6000, slice_size: 8192
  • move_size: 5976, slice_size: 7591
  • move_size: 5952, slice_size: 592
  • move_size: 5928, slice_size: 8192
  • move_size: 5904, slice_size: 7599
  • move_size: 5880, slice_size: 584
  • move_size: 5856, slice_size: 8192
  • move_size: 5832, slice_size: 7607
  • move_size: 5808, slice_size: 576
  • move_size: 5784, slice_size: 8192
  • move_size: 5760, slice_size: 7615
  • move_size: 5736, slice_size: 568
  • move_size: 5712, slice_size: 8192
  • move_size: 5688, slice_size: 7623
  • move_size: 5664, slice_size: 560
  • move_size: 5640, slice_size: 8192
  • move_size: 5616, slice_size: 7631
  • move_size: 5592, slice_size: 552
  • move_size: 5568, slice_size: 8192
  • move_size: 5544, slice_size: 7639
  • move_size: 5520, slice_size: 544
  • move_size: 5496, slice_size: 8192
  • move_size: 5472, slice_size: 7647
  • move_size: 5448, slice_size: 536
  • move_size: 5424, slice_size: 8192
  • move_size: 5400, slice_size: 7655
  • move_size: 5376, slice_size: 528
  • move_size: 5352, slice_size: 8192
  • move_size: 5328, slice_size: 7663
  • move_size: 5304, slice_size: 520

So the problem is obvious (the slice_size values sum to 100MB per run). The root cause should be that grpc's buffer management does not handle large messages well. This also explains why throughput decreases as the tensor size increases.

I'm not very familiar with the grpc code, but would adding a grpc option to change 'gpr_slice_buffer_take_first' to 'gpr_slice_buffer_take_all' remove the unnecessary memory copy? Tuning the slice size can also help reduce the overhead, but it can't eliminate it.

@ctiller

ctiller commented Dec 6, 2016

Interesting!

The correct fix is probably to have grpc_chttp2_incoming_byte_stream become a ring-like buffer, so instead of doing a move down the slice array, we just increment an index. When we reach the end of slices, we can reset the counter to zero.

I'll make sure someone takes a look soon.
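
A minimal Python sketch of the two access patterns, purely for illustration (the actual buffer is gRPC's C slice buffer; the class and method names below are made up):

  class ShiftingSliceBuffer:
    """take_first shifts every remaining entry, like the memmove above."""
    def __init__(self):
      self.slices = []

    def put(self, s):
      self.slices.append(s)

    def take_first(self):
      s = self.slices[0]
      del self.slices[0]   # O(n): shifts all remaining slice headers down
      return s

  class RingSliceBuffer:
    """take_first advances a read index instead; no per-take shifting."""
    def __init__(self):
      self.slices = []
      self.head = 0

    def put(self, s):
      self.slices.append(s)

    def take_first(self):
      s = self.slices[self.head]
      self.head += 1       # O(1); reset once the buffer is fully drained
      if self.head == len(self.slices):
        self.slices = []
        self.head = 0
      return s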

@michaelisard michaelisard added the type:bug Bug label Dec 6, 2016
@yaroslavvb
Contributor

@llhe amazing investigation work!

@yaroslavvb
Contributor

I wonder if there's extra inefficiency in that benchmark because repeated fields are used (Tensor::AsProtoField) rather than tensor_data (Tensor::AsProtoTensorContent); I see 11% of the time being spent in RepeatedField::Reserve.

@llhe
Contributor Author

llhe commented Dec 8, 2016

@yaroslavvb Do you mean the serialization on the sending side? I haven't seen that issue. I use float32 with size 100MB, and it looks like it goes into this branch as expected. On the receiving side, it also goes into this branch and into protobuf.

However, the 18.2% time consumption looks unexpectedly high to me for a bare memcpy, compared to AssignOp's memcpy. Maybe it's just caused by poor alignment?

@yaroslavvb
Contributor

yaroslavvb commented Dec 9, 2016

As discussed off-channel, we were looking at slightly different benchmarks. My original local_distributed_benchmark.py does sess.run(add_op), which fetches the buffer back into the Python runtime, while the timings in this issue are for sess.run(add_op.op), which doesn't.

The "fetching into Python" version is ridiculously slow for gRPC runtime (0.05 GB/sec grpc vs 3.4 GB/sec in-process), whereas the non-fetching is just slow (0.9 GB/sec grpc vs 20.2 GB/sec in-process) for Xeon(R) CPU E5-2630 v3 @ 2.40GHz

@yaroslavvb
Contributor

BTW, you can partially mitigate this problem by running multiple ps processes and sharding your variables over them.
On a 32-core Xeon, I can get the transfer rate to go from 0.9 GB/s to 2.6 GB/s.
sharded_ps_benchmark.py

This makes the logic more complicated, unfortunately (a rough sketch of the idea follows after the numbers below).

./sharded_ps_benchmark.py --ps=8
...
worker 0 done: 2555.34 MB per second
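
Roughly what the sharded setup does, as a sketch only (num_ps, the 100MB size, and the device strings are assumptions here; see sharded_ps_benchmark.py for the real script):

  import tensorflow as tf

  num_ps = 8                              # e.g. --ps=8
  params_size = 100 * 1024 * 1024 // 4    # 100MB of float32
  shard_size = params_size // num_ps

  shards = []
  for i in range(num_ps):
    # Each shard lives on a different ps task, so its share of the
    # transfer is handled by a different process.
    with tf.device("/job:ps/task:%d" % i):
      shards.append(tf.get_variable("params_%d" % i, shape=[shard_size],
                                    dtype=tf.float32,
                                    initializer=tf.zeros_initializer()))

  with tf.device("/job:worker/task:0"):
    update = tf.ones([shard_size], dtype=tf.float32)

  # One run triggers num_ps transfers in parallel instead of one big one.
  add_all = tf.group(*[s.assign_add(update) for s in shards])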

@llhe
Contributor Author

llhe commented Dec 9, 2016

@yaroslavvb Good suggestion! Sharding into more pieces should help.

@llhe
Contributor Author

llhe commented Dec 15, 2016

This is being fixed here: grpc/grpc#8975.
I made an integration with the temp fix: https://github.com/llhe/tensorflow/tree/grpc-fix.

The issue with the grpc polling buffer is resolved:
device1-prof-grpc-fix.pdf

The remaining unnecessary memory copy is a known issue marked as a TODO by Jeff and Sanjay: decoding the raw tensor content, which is actually a buffer memcpy.

@yaroslavvb
Contributor

With that patch, I'm seeing a 10x improvement in transfer speed for a 1GB buffer:

./sharded_ps_benchmark.py --ps=1 --iters=1 --data_mb=1024
worker 0 done: 104.52 MB per second

after patch
worker 0 done: 1088.05 MB per second

@aselle
Contributor

aselle commented Dec 20, 2016

@jhseu, thought you might want to take a look at this issue. Thanks.

@mbz
Contributor

mbz commented Dec 30, 2016

Is there any news on this being ported over to TensorFlow? It seems to be a very straightforward patch.

@jhseu
Contributor

jhseu commented Dec 30, 2016

Looks like it hasn't been merged into gRPC yet. I'll make sure it gets into the TensorFlow 1.0 release.

@llhe
Contributor Author

llhe commented Feb 10, 2017

Is anyone working on this? The grpc patch has been merged. I can give it a try.

@jhseu
Contributor

jhseu commented Feb 10, 2017

Yeah, that'd be great! Note that it didn't make it into the gRPC 1.1 release, so you'd have to update gRPC to a more recent git version.

We also use a custom BUILD file in third_party/grpc.BUILD. You can diff it against the version it's derived from to see its changes. Ideally, we'd just use gRPC's build file and not have our own, but I'm not sure how feasible that is. I'm also not sure how many of those changes are still required, or whether more might be needed. We haven't updated gRPC for a few months.

@jhseu
Contributor

jhseu commented Feb 10, 2017

Also, if you update, it'd be useful to run this benchmark to compare before and after:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/distributed_runtime/rpcbench_test.cc

Run it with: bazel run -c opt rpcbench_test -- --benchmarks=all

@ctiller

ctiller commented Feb 10, 2017 via email

@aselle aselle added the stat:awaiting response Status - Awaiting response from author label Feb 10, 2017
@llhe
Contributor Author

llhe commented Feb 12, 2017

@ctiller Is bazel still supported on gRPC trunk? It looks like some things are broken: unreferenced headers, incorrect header paths, redefined symbols, etc. I'm trying to fix them so it builds with tensorflow.

@ctiller

ctiller commented Feb 12, 2017 via email

llhe added a commit to llhe/tensorflow that referenced this issue Feb 13, 2017
llhe added a commit to llhe/tensorflow that referenced this issue Feb 13, 2017
@yaroslavvb
Contributor

@tfboyd

llhe added a commit to llhe/tensorflow that referenced this issue Feb 17, 2017
jhseu pushed a commit to llhe/tensorflow that referenced this issue Feb 18, 2017
@aselle aselle added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Mar 1, 2017
vjpai pushed a commit to vjpai/tensorflow that referenced this issue May 3, 2017
@tfboyd
Member

tfboyd commented Jun 16, 2017

This PR is still being worked on. The PR was updated recently.

@byronyi
Contributor

byronyi commented Aug 9, 2017

Shall we close this issue, as the gRPC upgrade has been completed by @jhseu? I re-ran the benchmark_grpc_recv.py test and found it has improved from 100-200 MB/s to 800 MB/s. It now fully utilizes a 10Gbps link, which is actually much better than we expected.

@yaroslavvb
Contributor

@jhseu @byronyi woohoo! progress!

@byronyi
Contributor

byronyi commented Aug 9, 2017

Wow @yaroslavvb I suppose you are a committer to TF now? That is awesome.

@suiyuan2009
Contributor

@byronyi I re-ran the test using tf built from the latest master branch, and I still get:

Local rate:       1766.32 MB per second
Distributed rate: 287.89 MB per second

Not much improvement.
I also ran the script across 3 different machines; grpc+gdr is about 3400MB/s and grpc is about 500MB/s.

@byronyi
Contributor

byronyi commented Aug 14, 2017

For your reference:

$ md5sum benchmark_grpc_recv.py
cebe408e4063bb9db817a2f75d5cd792  benchmark_grpc_recv.py
$ python benchmark_grpc_recv.py
2017-08-14 14:00:54.045878: I tensorflow/core/common_runtime/gpu/gpu_device.cc:962] Found device 0 with properties:
name: Tesla K40m major: 3 minor: 5 memoryClockRate(GHz): 0.745
pciBusID: 0000:02:00.0
totalMemory: 11.17GiB freeMemory: 456.25MiB
2017-08-14 14:00:54.248079: I tensorflow/core/common_runtime/gpu/gpu_device.cc:962] Found device 1 with properties:
name: Tesla K40m major: 3 minor: 5 memoryClockRate(GHz): 0.745
pciBusID: 0000:03:00.0
totalMemory: 11.17GiB freeMemory: 456.25MiB
2017-08-14 14:00:54.463848: I tensorflow/core/common_runtime/gpu/gpu_device.cc:962] Found device 2 with properties:
name: Tesla K40m major: 3 minor: 5 memoryClockRate(GHz): 0.745
pciBusID: 0000:82:00.0
totalMemory: 11.17GiB freeMemory: 456.25MiB
2017-08-14 14:00:54.690266: I tensorflow/core/common_runtime/gpu/gpu_device.cc:962] Found device 3 with properties:
name: Tesla K40m major: 3 minor: 5 memoryClockRate(GHz): 0.745
pciBusID: 0000:83:00.0
totalMemory: 11.17GiB freeMemory: 456.25MiB
2017-08-14 14:00:54.691945: I tensorflow/core/common_runtime/gpu/gpu_device.cc:977] Device peer to peer matrix
2017-08-14 14:00:54.692043: I tensorflow/core/common_runtime/gpu/gpu_device.cc:983] DMA: 0 1 2 3
2017-08-14 14:00:54.692057: I tensorflow/core/common_runtime/gpu/gpu_device.cc:993] 0:   Y Y N N
2017-08-14 14:00:54.692067: I tensorflow/core/common_runtime/gpu/gpu_device.cc:993] 1:   Y Y N N
2017-08-14 14:00:54.692075: I tensorflow/core/common_runtime/gpu/gpu_device.cc:993] 2:   N N Y Y
2017-08-14 14:00:54.692083: I tensorflow/core/common_runtime/gpu/gpu_device.cc:993] 3:   N N Y Y
2017-08-14 14:00:54.692099: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1052] Creating TensorFlow device (/gpu:0) -> (device: 0, name: Tesla K40m, pci bus id: 0000:02:00.0, compute capability: 3.5)
2017-08-14 14:00:54.692111: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1052] Creating TensorFlow device (/gpu:1) -> (device: 1, name: Tesla K40m, pci bus id: 0000:03:00.0, compute capability: 3.5)
2017-08-14 14:00:54.692121: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1052] Creating TensorFlow device (/gpu:2) -> (device: 2, name: Tesla K40m, pci bus id: 0000:82:00.0, compute capability: 3.5)
2017-08-14 14:00:54.692130: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1052] Creating TensorFlow device (/gpu:3) -> (device: 3, name: Tesla K40m, pci bus id: 0000:83:00.0, compute capability: 3.5)
E0814 14:00:55.341150222    5197 ev_epoll1_linux.c:1051]     grpc epoll fd: 49
2017-08-14 14:00:56.331284: E tensorflow/stream_executor/cuda/cuda_driver.cc:406] failed call to cuInit: CUDA_ERROR_NO_DEVICE
E0814 14:00:56.331737994    5285 ev_epoll1_linux.c:1051]     grpc epoll fd: 3
2017-08-14 14:00:56.353470: E tensorflow/stream_executor/cuda/cuda_driver.cc:406] failed call to cuInit: CUDA_ERROR_NO_DEVICE
E0814 14:00:56.353890934    5283 ev_epoll1_linux.c:1051]     grpc epoll fd: 3
Local rate:       10118.01 MB/s
Distributed rate: 726.68 MB/s

@suiyuan2009
Contributor

I got:

Local rate:       10841.73 MB/s
Distributed rate: 971.94 MB/s

But have you run the workers on two different machines? How is the speed?

@suiyuan2009
Contributor

suiyuan2009 commented Aug 14, 2017

I used this script to test tensor transmission across 3 different machines; note that task 2 is a client responsible for submitting the job. grpc is about 500MB/s, grpc+gdr is about 3400MB/s.

python3 tensor_transmission.py --host=xxxx1 --port1=xx1 --host_2=xxxx2 --port2=xx2 --task=0
python3 tensor_transmission.py --host=xxxx1 --port1=xx1 --host_2=xxxx2 --port2=xx2 --task=1
python3 tensor_transmission.py --host=xxxx1 --port1=xx1 --host_2=xxxx2 --port2=xx2 --task=2

@yaroslavvb
Contributor

So I just tested this again on the latest version, and the speed is roughly the same. However, 971 MB/second is too slow. AWS now has 25 Gbps ethernet cards, so using that capacity requires being able to send a 100MB tensor in <=32ms. Currently a RecvTensor of that size takes >100ms locally.

wget -N https://gist.githubusercontent.com/yaroslavvb/e196107b5e0afc834652bd3153030c42/raw/5ff6df416933232fc2d3f09416e9bee50b221367/benchmark_grpc_recv.py
python benchmark_grpc_recv.py --data_mb=100 --iters=100

Local rate:       15320.74 MB/s
Distributed rate: 872.92 MB/s


@byronyi
Contributor

byronyi commented Mar 6, 2018

@yaroslavvb I guess you could give the latest accelerated networking feature a shot, which provides Mellanox ConnectX-3/ConnectX-3 Pro Virtual Function (and supposedly supports RDMA).

@yaroslavvb
Contributor

@byronyi well, this slowness is within a single machine, so the bottleneck must be software related. My suspicion is that RecvTensor involves a single-threaded memcpy somewhere; that would explain ~75% of the slowness -- I see 105ms, while 100 MB at 1.25 GB/s memcpy speed = 80ms.

@byronyi
Contributor

byronyi commented Mar 8, 2018

@yaroslavvb Well, I do agree more optimisations could be done in gRPC to eliminate memcpy even with the current TCP/IP stack. The loopback performance on my DigitalOcean box is around 20 Gbps:

$ sudo perf record netperf -t TCP_STREAM -H 127.0.0.1
MIGRATED TCP STREAM TEST from 0.0.0.0 (0.0.0.0) port 0 AF_INET to 127.0.0.1 () port 0 AF_INET : demo
Recv   Send    Send
Socket Socket  Message  Elapsed
Size   Size    Size     Time     Throughput
bytes  bytes   bytes    secs.    10^6bits/sec

 87380  16384  16384    10.00    18905.90
[ perf record: Woken up 6 times to write data ]
[ perf record: Captured and wrote 1.338 MB perf.data (30386 samples) ]

But I don't think it is really about memcpy, as the kernel also does a memcpy and seems to do it pretty fast:

$ sudo perf report --header
# ========
# captured on: Thu Mar  8 05:09:45 2018
# hostname : localhost
# os release : 4.15.0-1-amd64
# perf version : 4.15.4
# arch : x86_64
# nrcpus online : 4
# nrcpus avail : 4
# cpudesc : Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz
# cpuid : GenuineIntel,6,79,1
# total memory : 8172116 kB
# cmdline : /usr/bin/perf_4.15 record netperf -t TCP_STREAM -H 127.0.0.1
# event : name = cycles, , size = 112, { sample_period, sample_freq } = 1500, sample_type = IP|TID|TIME|PERIOD, disabled = 1, inherit = 1, mmap = 1, comm = 1,
# CPU_TOPOLOGY info available, use -I to display
# NUMA_TOPOLOGY info available, use -I to display
# pmu mappings: breakpoint = 5, cpu = 4, software = 1, tracepoint = 2, msr = 6
# CACHE info available, use -I to display
# missing features: TRACING_DATA BRANCH_STACK GROUP_DESC AUXTRACE STAT
# ========
#
#
# Total Lost Samples: 0
#
# Samples: 13K of event 'cycles'
# Event count (approx.): 19061735130
#
# Overhead  Command  Shared Object      Symbol
# ........  .......  .................  ..............................................
#
    34.79%  netperf  [kernel.kallsyms]  [k] copy_user_enhanced_fast_string
     4.40%  netperf  [kernel.kallsyms]  [k] tcp_sendmsg_locked
     3.08%  netperf  [kernel.kallsyms]  [k] __tcp_ack_snd_check
     2.55%  netperf  [kernel.kallsyms]  [k] __pv_queued_spin_lock_slowpath
     2.29%  netperf  [kernel.kallsyms]  [k] syscall_return_via_sysret
     1.77%  netperf  [kernel.kallsyms]  [k] pvclock_clocksource_read
     1.49%  netperf  [unknown]          [k] 0xfffffe000003201e
     1.38%  netperf  [kernel.kallsyms]  [k] __raw_callee_save___pv_queued_spin_unlock
     1.23%  netperf  [kernel.kallsyms]  [k] get_page_from_freelist

@yaroslavvb
Contributor

It sounds like your loopback performance is also bottlenecked by a single-threaded memcpy; 20 Gbps loopback seems low given that AWS can do 25 Gbps over ethernet (PS: my earlier numbers are GB/s rather than Gbps).

@byronyi
Contributor

byronyi commented Mar 8, 2018

@yaroslavvb If you are using the in-kernel TCP/IP stack, then it's rather unlikely you can avoid a memory copy on the receiving side, even with the new MSG_ZEROCOPY feature available since Linux 4.9 :)

@byronyi
Contributor

byronyi commented Mar 8, 2018

Just a quote from the article:

Readers might be wondering why the patch does not support zero-copy reception; while the patch set itself does not address this question, it is possible to make an educated guess. Reading is inherently harder because it is not generally known where a packet is headed when the network interface receives it. In particular, the interface itself, which must place the packet somewhere, is probably not in a position to know that a specific buffer should be used. So incoming packets end up in a pile and the kernel sorts them out afterward. Fancier interfaces have a fair amount of programmability, to the point that zero-copy reception is not entirely infeasible, but it remains a more complex problem. For many common use cases (web servers, for example), transmission is the more important problem anyway.

@yaroslavvb
Contributor

So I think the original problem with RecvTensor is the single-threaded memcpy. You can copy memory fast if you use multiple threads. For instance, putting a 100MB object into Ray's object store takes 17ms; that involves a memory copy and translates to about 50 Gbps. I can add a constant to 100MB worth of 1s and put the result into new memory in 3.5ms, which is about 250 Gbps and is probably the upper limit of how fast you can copy memory on a Xeon v4.
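
That back-of-the-envelope number is easy to reproduce with NumPy (a rough sketch; absolute results depend on the machine's memory bandwidth):

  import time
  import numpy as np

  x = np.ones(100 * 1024 * 1024 // 4, dtype=np.float32)  # ~100MB of 1s
  start = time.time()
  y = x + 1          # reads 100MB and writes 100MB into fresh memory
  elapsed = time.time() - start
  print("add + materialize took %.1f ms, ~%.1f GB/s of memory traffic"
        % (elapsed * 1e3, 2 * x.nbytes / elapsed / 1e9))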

@byronyi
Contributor

byronyi commented Mar 8, 2018

Would a possible workaround be to split your tensor into partitions and load-balance your RecvTensor calls? I think that's what TF currently does to load-balance PS tasks.
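
That is roughly what partitioned variables give you; a sketch under TF 1.x assumptions (ps_tasks and the 100MB size are placeholders):

  import tensorflow as tf

  ps_tasks = 4
  params_size = 100 * 1024 * 1024 // 4   # 100MB of float32
  with tf.device(tf.train.replica_device_setter(ps_tasks=ps_tasks)):
    params = tf.get_variable(
        "params", shape=[params_size], dtype=tf.float32,
        initializer=tf.zeros_initializer(),
        partitioner=tf.fixed_size_partitioner(ps_tasks))
  # Reading the variable concatenates the partitions; each partition is
  # fetched from a different ps task, spreading the RecvTensor traffic.
  fetch = tf.convert_to_tensor(params)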

@byronyi
Contributor

byronyi commented Mar 8, 2018

Plus, for a truly distributed case, since a TCP connection is stream oriented, multithreading over the same socket would involve a fair amount of locking and would at best make its performance equivalent to the single-threaded case. If we restrict it to the intranode case, why not just use immutable shared memory? You could avoid the bulk memcpy altogether.
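
For the intranode case, Python's own multiprocessing.shared_memory (3.8+) shows the flavor of the idea, though it is unrelated to what TF's runtime does (a sketch; the segment name "params0" is made up):

  import numpy as np
  from multiprocessing import shared_memory

  # Producer: place the tensor into a named shared segment once.
  src = np.ones(100 * 1024 * 1024 // 4, dtype=np.float32)
  shm = shared_memory.SharedMemory(create=True, size=src.nbytes, name="params0")
  np.ndarray(src.shape, dtype=src.dtype, buffer=shm.buf)[:] = src

  # Consumer (normally another process): attach by name and view the data
  # in place, with no bulk copy on the receiving side.
  peer = shared_memory.SharedMemory(name="params0")
  params = np.ndarray(src.shape, dtype=src.dtype, buffer=peer.buf)
  print(params[:4])

  del params          # release the buffer view before closing the segment
  peer.close()
  shm.close()
  shm.unlink()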
