Multi-GPU Training Performance #146

Closed
WeichaoRen opened this issue Jul 12, 2017 · 2 comments

@WeichaoRen

Having seen #124 and #17, I ran some further tests of multi-GPU performance on a single machine.
To make things clear, here are the settings and the corresponding metrics in a table:

| No. | batch_size | # of GPUs | global_step/sec | max GPU util |
|-----|------------|-----------|-----------------|--------------|
| 1   | 8192       | 1         | ~2.5            | 99%          |
| 2   | 8192       | 8         | ~0.6            | 95%          |
| 3   | 1024       | 1         | ~12.3           | 78%          |
| 4   | 1024       | 8         | ~1.5            | 23%          |

As for No. 1 and No. 2, the effective batch size is 8x larger for No. 2, so it's no surprise that global_step/sec is lower.
The problem is row No. 4: its effective batch size per step should be the same as in No. 1, but spread over 8 GPUs, so I expected global_step/sec to be several times higher than the ~2.5 of No. 1. Instead it's actually lower (around 1.5), and far slower than No. 3 as well, which uses the same per-GPU batch size. Am I doing something wrong?

I also noticed that gpu_util stays below 25% on all GPUs in setting No. 4. I then repeated the 8-GPU test with batch sizes of 2048, 512, etc., and GPU utilization was quite low on all GPUs in every case. Is this normal? What explains the difference in GPU utilization between No. 3 and No. 4?
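
For reference, here is a rough examples/sec comparison derived from the table above, under my assumption that each global step processes batch_size × number of GPUs examples (synchronous training):

```python
# Rough throughput (examples/sec) derived from the table above, assuming each
# global step processes batch_size * num_gpus examples (synchronous training).
settings = [
    # (no, batch_size, num_gpus, global_step_per_sec)
    (1, 8192, 1, 2.5),
    (2, 8192, 8, 0.6),
    (3, 1024, 1, 12.3),
    (4, 1024, 8, 1.5),
]

for no, batch, gpus, steps_per_sec in settings:
    examples_per_sec = batch * gpus * steps_per_sec
    print(f"No.{no}: ~{examples_per_sec:,.0f} examples/sec")

# No.1: ~20,480   No.2: ~39,322   No.3: ~12,595   No.4: ~12,288
```

If that assumption holds, No. 4 processes about the same number of examples per second as the single-GPU No. 3, which matches the low utilization I see.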

@lukaszkaiser
Contributor

When using 8 GPUs, there is overhead from copying the parameters and gradients between devices. You incur this overhead even if your batch size is small (you're mostly copying weights, which are independent of the batch size). So in No. 4 you're probably spending most of the time copying data between the GPUs and maybe the CPU. How fast the copies are depends on how your GPUs are connected; it can sometimes improve with the --daisy_chain_variables flag and --gpu_order, see here: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/trainer_utils.py#L95
It's hard to know how the GPU order maps between TF and the NVIDIA drivers, so sometimes you just need to try a few options and measure.

But looking at your numbers, this isn't surprising. With a large batch and 8 GPUs you lose about 5% of GPU utilization waiting for the copies, but with a small batch you lose much more, around 50% of the time, because each GPU finishes its compute quickly and then has nothing to do. It looks like there is a lot of copying going on, so you should try the options above or putting the variables on the CPU (a lot of the speed depends on which GPU interconnect you have; do you have NVLink?). I hope this helps a little!
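
To make the intuition concrete, here is a toy model (the constants are made up purely for illustration, not measured from tensor2tensor): a roughly fixed per-step copy cost plus compute time proportional to the per-GPU batch size.

```python
# Toy model: per-step time = fixed copy overhead + compute proportional to the
# per-GPU batch size. The constants below are made up purely for illustration.
COPY_SECONDS = 0.05                        # hypothetical fixed copy cost per step
COMPUTE_SECONDS_PER_EXAMPLE = 0.4 / 8192   # hypothetical per-example compute time

def gpu_utilization(per_gpu_batch):
    compute = per_gpu_batch * COMPUTE_SECONDS_PER_EXAMPLE
    return compute / (compute + COPY_SECONDS)

for batch in (8192, 1024):
    print(f"per-GPU batch {batch}: ~{gpu_utilization(batch):.0%} utilization")

# per-GPU batch 8192: ~89% utilization
# per-GPU batch 1024: ~50% utilization
```

The real numbers depend on the interconnect and on where the variables live, but the shape of the effect is the same: the smaller the per-GPU batch, the larger the share of each step spent waiting on copies.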

@WeichaoRen
Author

@lukaszkaiser
Thanks! I will definitely try the options you mentioned, though I don't actually have NVLink :).
