This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

[Tuning] Results are GPU-number and batch-size dependent #444

Open
vince62s opened this issue Nov 27, 2017 · 81 comments

@vince62s
Contributor

@lukaszkaiser
This is to illustrate what I have discussed on gitter.

Working with WMT EN-FR, I have observed the following.
You can replicate the paper results with "transformer_base" on 4 GPUs.
The approx BLEU looks like this (batch size 4096, warmup steps 6000):
[image: approx BLEU curve, 4 GPUs, batch size 4096, warmup steps 6000]

If I do the same on 3 GPUs (batch size 4096, warmup steps 8000), taking into account that I need to compare step 120K of the 4-GPU run with step 160K of the 3-GPU run, I get this, with a clear offset of 1 BLEU point.
The gap is never closed if we wait.
[image: approx BLEU curve, 3 GPUs, showing the ~1 BLEU offset]

If I do the same on 2 GPUs, it's even lower; same with 1 GPU.

Also, I observed that results are very dependent on the batch size.
For instance, if you lower it to 3072 you don't get the same results as with 4096,
and with 2048 they are even lower.

This makes it impossible to replicate the Transformer BIG results, since you can only fit a batch size of 2048 even on a GTX 1080 Ti.

Hope this helps for better tuning.

@martinpopel
Contributor

My experience is exactly the same. A lower batch size (either nominally lower with the same number of GPUs, or effectively lower because of a smaller number of GPUs) leads to worse results, even if I train long enough to compensate for the smaller batch.
I remember people working with Nematus or OpenNMT were surprised by this behavior of Transformer/T2T, because their experience was that a lower batch size leads to better results in the end (but slower of course, thus they sometimes start training with a big batch and then switch to a lower batch for fine-tuning).

@mehmedes

@vince62s you mentioned that the gap is never closed even if you wait to compensate for the batch size difference.

Have you been able to compensate by decreasing the learning rate when decreasing the batch size? As suggested here, maybe it would make sense to scale the learning rate by x and y if the batch size changes nominally by x (same number of GPUs) and/or effectively by y (different number of GPUs)?

@vince62s
Contributor Author

Not sure, because Adam / the noam schedule is supposed to be adaptive; all the tests I did (changing the lr) were no better.
All I know is that there are 2 places where the number of replicas has an impact:
here the warmup steps are multiplied by the number of GPUs

hparams.learning_rate_warmup_steps * num_worker_replicas)

and here the lr is divided by the sqrt of the number of GPUs
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/model_builder.py#L219
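
For reference, a minimal, self-contained sketch of how those two adjustments enter a noam-style schedule (names and constants are illustrative, not the exact T2T code):

import math

def noam_lr(step, hidden_size=512, base_lr=0.2, warmup_steps=8000, num_replicas=1):
    # Noam-style rate as in the Transformer paper, with the two adjustments
    # mentioned above: warmup is stretched by the number of replicas and the
    # resulting rate is divided by sqrt(num_replicas).
    warmup = warmup_steps * num_replicas
    rate = base_lr * hidden_size ** -0.5 * min(step * warmup ** -1.5, step ** -0.5)
    return rate / math.sqrt(float(num_replicas))

# e.g. compare the rate at step 20k with 1 vs. 4 replicas
print(noam_lr(20000, num_replicas=1), noam_lr(20000, num_replicas=4))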

@martinpopel
Contributor

martinpopel commented Dec 13, 2017

I have nice TensorBoard BLEU curves illustrating that training with more GPUs converges faster to higher BLEU (which is what everyone expects). When answering the question of whether N GPUs are N times faster (or less, or more, which is what this issue is about), it depends on whether we plot steps or hours on the x axis:

[image: bleu-vs-gpu-steps (BLEU vs. training steps for 1 to 8 GPUs)]
Here we can see e.g. that 8 GPUs (the topmost curve) achieved BLEU=24 after 68k steps of training, while 1 GPU (the bottom-most curve) needed something between 800k and 1050k steps to reach (and never again fall below) the same BLEU. So 8 GPUs are more than 8 times faster when measuring BLEU convergence vs. training steps. However, one step takes fewer seconds with 1 GPU than with 8 GPUs, so I think it is more relevant to plot time instead of steps on the x axis (switch to "Relative" in TensorBoard).

[image: bleu-vs-gpu-time (BLEU vs. training time for 1 to 8 GPUs)]
Here we can see that 8 GPUs achieved BLEU=24 after 12 hours of training, while 1 GPU needed something between 84 and 108 hours.
So now the expectation 8*12=96 is within the measured range.
The best BLEU per number of GPUs seems to be achieved by two GPUs, which crossed the BLEU=24 line after something between 31 and 38 hours, so here the expectation of 61-72 hours is clearly below the 1-GPU range of 84-108 hours.

I am still waiting to see whether 1 GPU can achieve results as good as 8 GPUs, given long enough time (several weeks). So the main question of this issue still remains unanswered.

All GPUs are of exactly the same type (1080 Ti). Of course, these results should still be taken with a grain of salt - it is just one test set (newstest2013); different test sets may give (slightly) different results.

@martinpopel
Contributor

martinpopel commented Dec 13, 2017

And one more post (this time with base models instead of big) showing that you should always use the highest possible batch_size (comparing batch sizes 1500, 3000 and 4500, all three trained on just one GPU):
[image: bleu-vs-batch-time (BLEU vs. training time for batch sizes 1500, 3000 and 4500)]

@fstahlberg
Contributor

Some results with T2T 1.3.2:

--problems=translate_ende_wmt32k --model=transformer --hparams_set=transformer_base_v2 --hparams=batch_size=8192,max_length=150 --worker_gpu=4

I'm running this on 4 P100 GPUs. As far as I understand the code, --hparams=batch_size=8192 --worker_gpu=4 should be the same as --hparams=batch_size=4096 --worker_gpu=8.
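
(The equivalence assumed here is simply that the effective number of subword tokens per update is the per-GPU batch_size times the number of GPUs; a trivial check:)

# batch_size in T2T counts subword tokens per GPU, so the effective batch per
# update is batch_size * worker_gpu.
def effective_tokens_per_update(batch_size, worker_gpu):
    return batch_size * worker_gpu

assert effective_tokens_per_update(8192, 4) == effective_tokens_per_update(4096, 8) == 32768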

On news-test2014 (with averaging):
after 100k steps: 27.0 BLEU
after 200k steps: 27.4 BLEU
after 300k steps: 27.5 BLEU

This is not very close to the 28.2 BLEU reported here.

After reading vince62s' post, maybe the t2t authors used worker replicas rather than --worker_gpu? In that case, shouldn't we be doing the same things to warmup_steps and the learning rate with worker_gpu as with worker_replicas, i.e. replacing num_worker_replicas in the code vince62s pointed to with something like worker_gpu*worker_replicas?
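
A minimal sketch of what that proposed change would amount to (hypothetical; the real code in model_builder.py differs in its details):

import math

def adjusted_lr_and_warmup(base_lr, warmup_steps, worker_gpu, worker_replicas):
    # Hypothetical variant of the adjustment discussed above: count every GPU
    # of every replica instead of counting replicas only.
    num_sync_workers = worker_gpu * worker_replicas
    return (base_lr / math.sqrt(float(num_sync_workers)),
            warmup_steps * num_sync_workers)

# e.g. one machine with 4 GPUs (the current code, counting replicas only,
# would leave lr and warmup unchanged in this case)
print(adjusted_lr_and_warmup(0.2, 8000, worker_gpu=4, worker_replicas=1))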

@vince62s
Contributor Author

vince62s commented Dec 14, 2017

I think you might be correct. I realize my above comment was misleading since I confused replicas and GPUs.
Now that I see exactly what their intention was, your last suggestion makes 100% sense.
@lukaszkaiser any insight?

EDIT: actually not exactly, because if the "default" values are tuned for 8 GPUs / 1 replica, then we would need to pro-rate the learning_rate and warmup_steps accordingly.

@vince62s
Contributor Author

Second thought (actually maybe the 1000th one):
when we do learning_rate /= math.sqrt(float(worker_replicas)),
if this is calibrated/tuned for 1 replica and 8 GPUs, it would mean that when we run on one machine with 4 GPUs, we would actually need to INCREASE the learning rate (the equivalent of replicas = 0.5).
Makes sense?

@fstahlberg
Contributor

Well, I think if it was calibrated for 1 replica we would be fine, since dividing the learning rate or multiplying warmup_steps by worker_replicas would have no effect. My current thesis is that the default parameters are tuned for 8 replicas, each with 1 GPU. In that case we need to decrease the learning rate and increase the warmup steps in order to simulate that setting on a single machine with multiple GPUs. I am trying that right now...

@vince62s
Contributor Author

Paper says one machine, 8 GPUs.

@fstahlberg
Contributor

You are right... that means I'm still deeply confused. @lukaszkaiser, the exact training command for replicating the 28.2 BLEU would be very helpful.

@mehmedes

mehmedes commented Dec 15, 2017

If the GPU memory is not sufficient for the ideal batch size of 4096, @martinpopel suggested in #446 to use transformer_big_single_gpu (orange) and to set the batch size as high as possible to get the best results.
After adjusting learning_rate and learning_rate_warmup_steps as @fstahlberg recommends, it looks like my loss curve for transformer_big (red) will eventually cross the transformer_big_single_gpu curve.

[image: loss curves, transformer_big (red) vs. transformer_big_single_gpu (orange)]

I adjusted learning_rate and learning_rate_warmup_steps based on transformer_v2 as follows:

[image: screenshot of the adjusted learning_rate and learning_rate_warmup_steps hparams]
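
(The screenshot of the exact values is not preserved. Purely as an illustration, an adjustment of this kind could be registered as a custom hparams set in the usual T2T style; the scaling rule and values below are made up.)

from tensor2tensor.models import transformer
from tensor2tensor.utils import registry

@registry.register_hparams
def transformer_big_scaled_lr():
  # Hypothetical example only: scale the lr down and the warmup up by the
  # ratio of the "ideal" batch size (4096) to the batch size that fits.
  hparams = transformer.transformer_big()
  hparams.batch_size = 2048
  scale = 4096.0 / hparams.batch_size
  hparams.learning_rate /= scale
  hparams.learning_rate_warmup_steps = int(hparams.learning_rate_warmup_steps * scale)
  return hparams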

@martinpopel
Contributor

As I have commented elsewhere, I think that in the Attention is all you need paper they use batch_size=3072 (which multiplied by 8 gives the approx. 25000 tokens per batch reported in the paper). However, the number 3072 never appeared in the source code of transformer.py.

@lukaszkaiser
Contributor

For the best base model so far (28.2) I used 8 GPUs with the default transformer_base. Some recent papers suggested scaling the learning rate linearly or square-root-like with the batch size, so according to them, if we go down from 8 to 2 GPUs we should scale the learning rate down by 2x or 4x. Martin: could you try these? I'll try to reproduce the above results to make sure we understand it better. If it's indeed the case, then we should probably add automatic LR scaling...

@vince62s
Contributor Author

I would have said the contrary, i.e. the more GPUs, the smaller the LR, since the batch size is x times bigger.

@martinpopel
Contributor

First a note: The graphs I posted above are all on the translate_encs_wmt32k task evaluated on the newstest2013 with the real BLEU, but I believe similar observations hold also for EnDe translation and other tasks.

Now I did some experiments with learning rate (and 1 GPU and a fixed batch size):

  • lr=0.01 converges noticeably slower than the default lr=0.2 (4 BLEU worse even after 2 days of training).
  • lr=0.3 (or higher) diverges after a few hours, so I had to double the warmup steps. Afterwards, it is about the same as the default lr=0.2.
  • lr=0.25, lr=0.1 and lr=0.05 are about the same as the default lr=0.2.
    By "about the same" I mean that the learning curves cross each other frequently (the variance from checkpoint to checkpoint is higher than the difference). I will try even higher lr.

@vince62s
Contributor Author

vince62s commented Dec 18, 2017

My results from the very first post were with 1.2.9, if I recall correctly.
Re-running the same 4-GPU experiment with 1.3.2 gives me worse results...
Has anything changed since?

@mehmedes

@martinpopel Have you had success with high learning rates? Increased learning rates seem to make up for a lower batch size. I'm currently training a transformer_big model with an LR of 0.8 (warmup=32k) on 4 GPUs, each with a batch size of 2000. So far this gives me the best loss curve on 4x 1080 Tis.
I multiplied the LR and warmup by the quotient ideal batch size / available batch size.
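
(A worked version of that rule of thumb, assuming the default base lr and warmup; the resulting numbers are illustrative and need not match the exact 0.8 / 32k quoted above, which may start from different base values.)

def scale_for_smaller_batch(base_lr, base_warmup, ideal_batch, available_batch):
    # Rule of thumb from this comment: multiply the lr and warmup steps by
    # the quotient ideal_batch / available_batch.
    q = ideal_batch / float(available_batch)
    return base_lr * q, int(base_warmup * q)

print(scale_for_smaller_batch(0.2, 8000, ideal_batch=4096, available_batch=2000))
# -> (~0.41, 16384)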

@martinpopel
Contributor

@mehmedes: No success with higher learning rates. I've tried lr=0.5 (and warmup 32k) and it is still about the same as other learning rates (except for lr=0.01 which is clearly worse). Then I tried lr=1 and it diverged (BLEU=0).

@vince62s
Contributor Author

@mehmedes did you get to the end of your training? How did it go?

@mehmedes

My training with LR=0.8 is still running. My loss curve hasn't flattened out yet, compared to previous models (see above) where I divided the LR by the ideal / available batch size ratio instead of multiplying it.

[image: loss curve of the LR=0.8 run]

@mehmedes

@martinpopel did you test LR=0.5 on a single GPU or on multiple?

@martinpopel
Contributor

@mehmedes: All my experiments with learning rate so far are on a single GPU.

As I think about it, I am afraid there is no easy way (no single magical formula) to exactly compensate for a lower batch size (caused e.g. by fewer GPUs) with learning rate scaling:
If we keep the same learning rate schedule, then after x steps with 8 GPUs we have the same lr as with 1 GPU, although we have seen 8 times more examples. So it is tempting to set the lr 8 times bigger on 1 GPU to compensate, but this does not work: if the lr is too high at any moment, the learning diverges. We can (and should) fight this with more warmup steps on 1 GPU, but this is not really equivalent to the multi-GPU setting with fewer warmup steps (and a smaller lr).
I think it also depends on how long you plan (or can afford) to train: there may be two hyperparameter setups, one outperforming the other only after N days or hours of training (this is the case of base vs. big). And this is also related to the training data size (training too long on small data leads to overtraining).
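
(To make the first point concrete: at a fixed step the schedule gives the same lr regardless of GPU count, but the amount of data consumed by that step differs; illustrative numbers only.)

step = 50000
tokens_per_gpu = 4096  # per-GPU batch_size, in subword tokens
for num_gpus in (1, 8):
    print(num_gpus, "GPU(s): same lr at step", step,
          "but", step * tokens_per_gpu * num_gpus, "subword tokens consumed")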

@mehmedes

@martinpopel yes, that's true.

What I find curious about T2T is that the LR scaling seems to behave inversely.
In "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour", Facebook multiplies the LR by 32 because they increase the batch size by 32.
[image: excerpt from the Facebook paper on linear LR scaling with batch size]

So far, my experience in T2T has been that if I decrease the batch size by x, I need to multiply the LR and warmup steps by x:
[image: screenshot of the adjusted T2T hparams]

Any ideas, why?

@vince62s
Contributor Author

@fstahlberg I did a quick experiment.
I confirm your initial thought.
1GPU+BS4096 = 2GPU+BS2048 = 4GPU+BS1024
I am getting very close loss values and approx BLEU values for these 3 scenarios.
Maybe you can try with another seed and a bit longer; they went up to 350K steps I think.

@mehmedes

mehmedes commented Mar 1, 2018

@vince62s you may also need to increase --worker_gpu_memory_fraction to 0.978. Otherwise it may train at first but then crash after a few hours. I had to stop my X server and run Ubuntu from the command line in order to allow T2T to use the entire GPU memory -- and you shouldn't be running TensorBoard in parallel on the same machine, which may also cause a crash.

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.111                Driver Version: 384.111                   |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce GTX 108...  Off  | 00000000:02:00.0 Off |                  N/A |
| 46%   57C    P2   235W / 250W |  11163MiB / 11164MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  Off  | 00000000:03:00.0 Off |                  N/A |
| 45%   55C    P2   253W / 250W |  11171MiB / 11172MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce GTX 108...  Off  | 00000000:81:00.0 Off |                  N/A |
| 42%   53C    P2   245W / 250W |  11171MiB / 11172MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+
|   3  GeForce GTX 108...  Off  | 00000000:82:00.0 Off |                  N/A |
| 33%   46C    P2   252W / 250W |  11171MiB / 11172MiB |    100%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0     16764      C   /usr/bin/python                            11153MiB |
|    1     16764      C   /usr/bin/python                            11161MiB |
|    2     16764      C   /usr/bin/python                            11161MiB |
|    3     16764      C   /usr/bin/python                            11161MiB |
+-----------------------------------------------------------------------------+

The params I use for batch_size=6000 are:

@registry.register_hparams
def transformer_big_adafactor():
  hparams = transformer_big()
  hparams.optimizer = "Adafactor"
  hparams.learning_rate_schedule = "rsqrt_decay"
  hparams.learning_rate_warmup_steps = 10000
  hparams.attention_dropout_broadcast_dims = "0,1"  # batch, heads
  hparams.relu_dropout_broadcast_dims = "1"  # length
  hparams.layer_prepostprocess_dropout_broadcast_dims = "1"  # length
  return hparams

And during data generation you also need to use _packed problems with max_length=100:

@registry.register_problem
class TranslateEndeWmt32kPacked(TranslateEndeWmt32k):

  @property
  def packed_length(self):
    return 100

@vince62s
Contributor Author

vince62s commented Mar 6, 2018

Ok here are my last results on WMT ENDE32k
Base, Adafactor, batch size 8192, 4 GPUs (total 32768, which is more than the ~25k of the paper)
Newstest 2014: after 100k steps, on avg10ckp = 26.70
after 500k steps on avg10ckp = 27.2, on avg20ckp = 27.35

Big, Adafactor, batchsize 5000, 4 GPU
Newstest 2014: after 200k steps, on avg20ckp = 27.6
after 300k steps on avg20ckp = 27.66

These results are still below the paper's: Base 100k steps 27.3, Big 300k steps 28.4.

I would really love to know if you guys at GG Brain @lukaszkaiser @rsepassi can still replicate
the paper results under the same conditions (100k steps for base, 300k steps for big) with the current code.

@liesun1994

@vince62s 😆

@liesun1994

@vince62s #317 may help you.

@duyvuleo
Contributor

duyvuleo commented Mar 29, 2018

Hi all, can I double-check the scores you produced in your experiments? Are they with t2t-bleu or sacreBLEU (with or without --tok intl)? Thanks!

@vince62s
Contributor Author

Speaking for myself, I always report BLEU from mteval-v13a.pl without international tokenization, which is the same as multi-bleu-detok.perl.
As far as I know, this is what is usually reported in papers.

@martinpopel
Contributor

I report BLEU with t2t-bleu, which should be the same as mteval-v14.pl --international-tokenization and as sacrebleu -lc -tok intl (note the lc for lowercase, i.e. uncased). The international tokenization has higher correlation with human ranking (unless the sentences are ASCII only, in which case it has no effect, but hey we live in the 21st century and not all languages are ASCII only), but I agree it is not used in all papers. Unfortunately, there is no single most popular BLEU variant used in papers - some use cased, some uncased, some use multi-bleu.perl with detokenization, some with detokenization and tokenize.perl...
My advice is to use sacreBLEU and always report its signature. Note that sacreBLEU can also report the chrF metric.
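
For completeness, a small sketch of computing that uncased, intl-tokenized variant with sacreBLEU's Python API (assuming a sacrebleu version whose corpus_bleu accepts the lowercase and tokenize keyword arguments):

import sacrebleu  # assumption: the keyword names below may differ between sacrebleu versions

hyps = ["This is a small test ."]    # detokenized system outputs
refs = [["This is a small test."]]   # one inner list per reference stream

# Uncased BLEU with international tokenization,
# roughly what t2t-bleu / "sacrebleu -lc -tok intl" report.
bleu = sacrebleu.corpus_bleu(hyps, refs, lowercase=True, tokenize="intl")
print(bleu.score)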

@duyvuleo
Contributor

Thanks @vince62s and @martinpopel for your replies.
In my case, scores with t2t-bleu and sacrebleu -tok intl are (a bit) higher than the ones without "-tok intl" (for English -> French, 2-3 BLEU points higher). This makes comparisons with the existing SOTAs unfair, and I cannot tell whether I trained the system properly. That confuses me.

@vince62s
Contributor Author

Just use sacreBLEU without -tok intl and without -lc and you will be comparable, I think.

@DC-Swind

DC-Swind commented Apr 12, 2018

@martinpopel: I saw you drew this picture (BLEU_uncased) in TensorBoard: https://user-images.githubusercontent.com/724617/33940325-d217ba74-e00e-11e7-9996-5132b62d51dc.png
But I only have the "approx_bleu_score" curve. Is "approx_bleu_score" computed on the training set? How do I get the "BLEU_uncased" curve for the test set?

@martinpopel
Contributor

@DC-Swind: approx_bleu is computed on the dev set, but using the internal subword tokenization, so it is not replicable (and it is not reliable because it conditions on the gold reference words). I use t2t-bleu and t2t-translate-all for plotting the (real) BLEU curves.

@DC-Swind

@martinpopel: No event file was generated by t2t-bleu and t2t-translate-all; how do I get the curve in TensorBoard? Could you provide the detailed command? I just want to plot a single BLEU curve on the test set for a specified model (which is stored in $TRAIN_DIR).

@martinpopel
Contributor

@DC-Swind: t2t-bleu creates the event file if called with proper parameters, see https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/bin/t2t_bleu.py#L27-L53

@AranKomat

Has anyone tried batch sizes greater than recommended, to see whether it is possible to get better performance than in the paper? With accumulation of gradients, one can increase the batch size arbitrarily. Has my question already been answered before?

@martinpopel
Contributor

The effect of batch size is discussed in this paper. The maximum batch size depends on the GPU memory (and optimizer - with Adafactor we can afford bigger batches, see Table 2 of the paper). The conclusions are: "for the BASE model, a higher batch size gives better results, although with diminishing returns. With the BIG model, we see quite different results."

With accumulation of gradients, one can arbitrarily increase the batch size.

Yes, I had this idea as well: accumulate gradients from N batches and do just one update of the weights afterwards, simulating an N times bigger batch (or N times more GPUs). I think it is worth trying, but someone would need to implement it first. There will be a question of how to compute the number of steps (which influences the learning rate schedule). However, I am not sure super-big batches will improve the convergence speed. Still, it may be useful for simulating multi-GPU experiments on a single GPU (so that after buying more GPUs, I will already know what the optimal hyperparams are).
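
A framework-agnostic sketch of that accumulation idea (compute_gradients and apply_gradients below are hypothetical stand-ins, not T2T or TensorFlow APIs):

def train_with_accumulation(batches, compute_gradients, apply_gradients, accum_steps=8):
    # Sum the gradients of accum_steps small batches and apply one averaged
    # update, simulating an accum_steps-times bigger batch (or that many GPUs).
    accumulated = None
    for i, batch in enumerate(batches, start=1):
        grads = compute_gradients(batch)
        if accumulated is None:
            accumulated = list(grads)
        else:
            accumulated = [a + g for a, g in zip(accumulated, grads)]
        if i % accum_steps == 0:
            apply_gradients([a / accum_steps for a in accumulated])
            accumulated = None

As noted above, the open question is whether the lr schedule should count one step per update (i.e. per accum_steps batches) or per batch.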

@AranKomat

Thanks for your response. As you said, I'm simulating such a situation (for language modeling with Transformer) with a single GPU at the expense of time per iteration (for this reason, achieving convergence isn't realistic). I hope someone will figure out how much performance gain is possible with a huge batch size on multiple GPUs.

@fstahlberg
Contributor

We have an upcoming ACL paper where we use this idea for neural machine translation with target side syntax. It turns out that using large batch sizes is even more important when generating long output sequences. I'll post a link to arXiv when ready. There is also a t2t implementation:

https://github.com/fstahlberg/tensor2tensor/blob/master/tensor2tensor/utils/largebatch_optimizer.py
Like the normal AdamOptimizer, but the n argument in the constructor is what you are describing: simulating n times more GPUs at the cost of n times more training iterations.

However, this is still t2t 1.3.1. I haven't had time to polish the code and update t2t to see if it still works. But I can do that and make a PR.

Regarding the original question: From my experience it is a good idea to try to match the number_of_gpus*batch_size setup, and n can compensate for reducing either of these values. I haven't seen gains from even larger batches.

@AranKomat

@fstahlberg I really appreciate your feedback, and I'm looking forward to reading your paper. Generating long sequences (though not for translation) is something I'm currently working on, so that's very beneficial to learn about. Maybe increasing the batch size enhances the generalizability of the text transduction model, which alleviates the issue of exposure bias when generating long sequences? I'm eager to hear from you about any of this, as well as a bit more relevant detail about your paper.

@vince62s
Contributor Author

Does anyone have a recent comparison between 4 and 8 GPUs for the same set of hparams
(batch size 4096, warmup steps 8000, lr 2 -- i.e. the defaults)?

@AranKomat

This recent paper achieved a 5x speedup on translation using Transformer with various techniques, including a batch size of 400k and mixed precision: Scaling Neural Machine Translation. Furthermore, it achieved BLEU of 29.3 and 43.2 on En-De and En-Fr, respectively. For those of us who don't have many GPUs, the diet variables in utils/diet.py would be helpful for increasing the batch size, if they work. Has anybody tried diet variables? Do they really work as expected?

@mehmedes

@AranKomat: Please note that in the aforementioned paper one crucial factor in the speed-up is switching from single to half precision, and that the hardware is the V100, which achieves 14 TFLOPs in single precision and 112 TFLOPs(!) in half precision. The P100, which was used in the T2T paper, would "only" increase from 9 TFLOPs to 18 TFLOPs when switching to half precision. The hardware should also be considered when evaluating the speed-up.
The 1080 Ti, for example, is faster at single precision than at half precision!

@AranKomat

@mehmedes I didn't notice that there was such a huge difference between the V100 and P100 in terms of half-precision TFLOPS! But I believe Table 1 accounts for the difference by citing the BLEU and speed with the V100. Maybe using diet variables wouldn't add much benefit in this case if half precision is already used.

@xerothermic

Hi @fstahlberg

Has your paper been published?

We have an upcoming ACL paper where we use this idea for neural machine translation with target side syntax. It turns out that using large batch sizes is even more important when generating long output sequences. I'll post a link to arXiv when ready. There is also a t2t implementation:

@fstahlberg
Contributor

Hi @xerothermic yes, we have used it in

https://www.aclweb.org/anthology/P18-2051

for syntax and in

https://www.aclweb.org/anthology/W18-6427

for a WMT18 submission.
