This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Unable to reproduce WMT En2De results #317

Open
edunov opened this issue Sep 21, 2017 · 52 comments

@edunov

edunov commented Sep 21, 2017

I tried to reproduce the paper's WMT En2De results with the base model. In my experiments I tried both a BPE and a word-piece model. Here are the steps I used to train the models:

# For BPE model I used this setup
PROBLEM=translate_ende_wmt_bpe32k
# For word piece model I used this setup
PROBLEM=translate_ende_wmt32k

MODEL=transformer
HPARAMS=transformer_base

DATA_DIR=$HOME/t2t_data
TMP_DIR=$HOME/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS

mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR

t2t-datagen \
  --data_dir=$DATA_DIR \
  --tmp_dir=$TMP_DIR \
  --problem=$PROBLEM

t2t-trainer \
  --data_dir=$DATA_DIR \
  --problems=$PROBLEM \
  --model=$MODEL \
  --hparams_set=$HPARAMS \
  --output_dir=$TRAIN_DIR

I trained both models until the trainer finished (~1 day). The last evaluation for the BPE model was:

INFO:tensorflow:Validation (step 250000): loss = 1.73099, metrics-translate_ende_wmt_bpe32k/accuracy = 0.633053, metrics-translate_ende_wmt_bpe32k/accuracy_per_sequence = 0.0, metrics-translate_ende_wmt_bpe32k/accuracy_top5 = 0.819939, metrics-translate_ende_wmt_bpe32k/approx_bleu_score = 0.306306, metrics-translate_ende_wmt_bpe32k/neg_log_perplexity = -1.98039, metrics-translate_ende_wmt_bpe32k/rouge_2_fscore = 0.38708, metrics-translate_ende_wmt_bpe32k/rouge_L_fscore = 0.589309, global_step = 249006

INFO:tensorflow:Saving dict for global step 250000: global_step = 250000, loss = 1.7457, metrics-translate_ende_wmt_bpe32k/accuracy = 0.638563, metrics-translate_ende_wmt_bpe32k/accuracy_per_sequence = 0.0, metrics-translate_ende_wmt_bpe32k/accuracy_top5 = 0.823388, metrics-translate_ende_wmt_bpe32k/approx_bleu_score = 0.290224, metrics-translate_ende_wmt_bpe32k/neg_log_perplexity = -1.93242, metrics-translate_ende_wmt_bpe32k/rouge_2_fscore = 0.373072, metrics-translate_ende_wmt_bpe32k/rouge_L_fscore = 0.574759


For the word-piece model, the last evaluation was:

INFO:tensorflow:Validation (step 250000): loss = 1.56711, metrics-translate_ende_wmt32k/accuracy = 0.655595, metrics-translate_ende_wmt32k/accuracy_per_sequence = 0.0360065, metrics-translate_ende_wmt32k/accuracy_top5 = 0.836071, metrics-translate_ende_wmt32k/approx_bleu_score = 0.358524, metrics-translate_ende_wmt32k/neg_log_perplexity = -1.84754, metrics-translate_ende_wmt32k/rouge_2_fscore = 0.440053, metrics-translate_ende_wmt32k/rouge_L_fscore = 0.628949, global_step = 248578

INFO:tensorflow:Saving dict for global step 250000: global_step = 250000, loss = 1.57279, metrics-translate_ende_wmt32k/accuracy = 0.65992, metrics-translate_ende_wmt32k/accuracy_per_sequence = 0.00284091, metrics-translate_ende_wmt32k/accuracy_top5 = 0.841923, metrics-translate_ende_wmt32k/approx_bleu_score = 0.368791, metrics-translate_ende_wmt32k/neg_log_perplexity = -1.80413, metrics-translate_ende_wmt32k/rouge_2_fscore = 0.445689, metrics-translate_ende_wmt32k/rouge_L_fscore = 0.636854

Then I evaluated both models on newstest2013, newstest2014 and newstest2015 with the commands below (mostly following the steps in https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/get_ende_bleu.sh):

For the BPE model:

YEAR=2013
#YEAR=2014
#YEAR=2015
BEAM_SIZE=5
ALPHA=0.6
t2t-decoder   --data_dir=$DATA_DIR   \
    --problems=$PROBLEM   --model=$MODEL   \
    --hparams_set=$HPARAMS   --output_dir=$TRAIN_DIR   \
    --decode_beam_size=$BEAM_SIZE   --decode_alpha=$ALPHA   \
    --decode_from_file=/tmp/t2t_datagen/newstest${YEAR}.tok.bpe.32000.en

#Tokenize reference
perl ~/mosesdecoder/scripts/tokenizer/tokenizer.perl -l de < /tmp/t2t_datagen/newstest${YEAR}.de > /tmp/t2t_datagen/newstest${YEAR}.de.tok
#Do compound splitting on the reference
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < /tmp/t2t_datagen/newstest${YEAR}.de.tok > /tmp/t2t_datagen/newstest${YEAR}.de.atat

#Remove BPE tokenization
cat /tmp/t2t_datagen/newstest${YEAR}.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.decodes | sed 's/@@ //g' > /tmp/t2t_datagen/newstest${YEAR}.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.words
#Do compound splitting on the translation
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < /tmp/t2t_datagen/newstest${YEAR}.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.words > /tmp/t2t_datagen/newstest${YEAR}.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.atat
#Score
perl ~/mosesdecoder/scripts/generic/multi-bleu.perl /tmp/t2t_datagen/newstest${YEAR}.de.atat < /tmp/t2t_datagen/newstest${YEAR}.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.atat

For the word-piece model:

YEAR=2013
#YEAR=2014
#YEAR=2015 
BEAM_SIZE=5
ALPHA=0.6
t2t-decoder   --data_dir=$DATA_DIR   --problems=$PROBLEM   \
    --model=$MODEL   --hparams_set=$HPARAMS   \
    --output_dir=$TRAIN_DIR   --decode_beam_size=$BEAM_SIZE   \
    --decode_alpha=$ALPHA   --decode_from_file=/tmp/t2t_datagen/newstest${YEAR}.en

#Tokenize the reference
perl ~/mosesdecoder/scripts/tokenizer/tokenizer.perl -l de < /tmp/t2t_datagen/newstest${YEAR}.de > /tmp/t2t_datagen/newstest${YEAR}.de.tok
#Do compound splitting on the reference
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < /tmp/t2t_datagen/newstest${YEAR}.de.tok > /tmp/t2t_datagen/newstest${YEAR}.de.atat

#Tokenize the translation
perl ~/mosesdecoder/scripts/tokenizer/tokenizer.perl -l de < /tmp/t2t_datagen/newstest${YEAR}.en.transformer.transformer_base.beam5.alpha0.6.decodes > /tmp/t2t_datagen/newstest${YEAR}.en.transformer.transformer_base.beam5.alpha0.6.tok
#Do compound splitting on the translation
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < /tmp/t2t_datagen/newstest${YEAR}.en.transformer.transformer_base.beam5.alpha0.6.tok > /tmp/t2t_datagen/newstest${YEAR}.en.transformer.transformer_base.beam5.alpha0.6.atat
#Score the translation
perl ~/mosesdecoder/scripts/generic/multi-bleu.perl /tmp/t2t_datagen/newstest${YEAR}.de.atat < /tmp/t2t_datagen/newstest${YEAR}.en.transformer.transformer_base.beam5.alpha0.6.atat

Here are the BLEU scores I got:

           newstest2013  newstest2014  newstest2015
BPE        10.81         11.31         12.75
wordpiece  22.41         22.75         25.46

There is a big mismatch with the results reported in the paper, so there must be something wrong with the way I ran these experiments. Could you please provide me some guidance on how to run this properly to reproduce the results from the paper?

@martinpopel
Contributor

  • Have you inspected the BPE output for obvious problems (proper de-BPE-ization, etc.)?
  • The results in the paper were on 8 GPUs, I think. It seems you are using just one GPU. In that case, use transformer_base_single_gpu or, even better, transformer_big_single_gpu, set the batch size as high as possible (with some reserve to prevent OOM) and train for at least 8 times more steps. I trained en-de transformer_base_single_gpu for 500k steps (almost three days) on a single GPU with batch_size=3072, T2T v1.1, and got BLEU(newstest2014) = 25.61. @vince62s replicated this with T2T v1.1, but got worse results with T2T v1.2.
  • You can get up to 1 BLEU point (usually less, sometimes more) of improvement by averaging the last few checkpoints. The paper suggests the last 20 (saved at the default 10-minute interval). I suggest using a 60-minute interval instead.
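The "train for more steps" advice can be made concrete with a rough back-of-the-envelope helper. The name equiv_steps is hypothetical. It keeps the total number of training tokens equal to a reference run, and this is a sketch under three assumptions:

- batch_size counts tokens per GPU per step;
- the paper's base run used 100k steps on 8 GPUs, as noted later in this thread;
- that run used 4096 tokens per GPU.

It says nothing about whether a smaller effective batch converges to the same quality, so treat the result as a rough starting point only.

```shell
# Hypothetical helper: number of steps that keeps the total number of
# training tokens equal to a reference run, assuming batch_size counts
# tokens per GPU per step. Token-budget heuristic only; it ignores
# optimization differences caused by a smaller effective batch.
# Usage: equiv_steps REF_STEPS REF_GPUS REF_BATCH MY_GPUS MY_BATCH
equiv_steps() {
  local ref_tokens per_step
  ref_tokens=$(( $1 * $2 * $3 ))   # tokens consumed by the reference run
  per_step=$(( $4 * $5 ))          # tokens per step in the new setup
  # Integer division rounded up, so the new run sees at least as many tokens.
  echo $(( (ref_tokens + per_step - 1) / per_step ))
}

# 100k steps on 8 GPUs x 4096 tokens vs. 1 GPU x 3072 tokens:
#   equiv_steps 100000 8 4096 1 3072   # prints 1066667
```

With the same per-GPU batch, this reduces to the "8 times more steps" rule of thumb. With batch_size=3072 it comes out at roughly 10.7 times.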

@edunov
Author

edunov commented Sep 21, 2017

Hi Martin,

Thank you for the quick reply.

  1. Yes, I checked the BPE output and it looks fine to me. Here are the first lines of the model output: https://gist.github.com/edunov/f5b0ead45a10deb5469bc58c790d05ca
    And this is how it looks after all the processing: https://gist.github.com/edunov/a46303ff05505d33eddfad0d78900d76
    Is there anything that seems obviously wrong to you?

  2. I actually ran it on an 8-GPU machine, and I believe it used all 8 GPUs. The log contains lines like these, so TensorFlow clearly sees all the GPUs:

    2017-09-20 10:56:29.551421: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:0) -> (device: 0, name: Tesla P100-SXM2-16GB, pci bus id: 0000:06:00.0)
    2017-09-20 10:56:29.551452: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:1) -> (device: 1, name: Tesla P100-SXM2-16GB, pci bus id: 0000:07:00.0)
    2017-09-20 10:56:29.551477: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:2) -> (device: 2, name: Tesla P100-SXM2-16GB, pci bus id: 0000:0a:00.0)
    2017-09-20 10:56:29.551484: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:3) -> (device: 3, name: Tesla P100-SXM2-16GB, pci bus id: 0000:0b:00.0)
    2017-09-20 10:56:29.551509: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:4) -> (device: 4, name: Tesla P100-SXM2-16GB, pci bus id: 0000:85:00.0)
    2017-09-20 10:56:29.551516: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:5) -> (device: 5, name: Tesla P100-SXM2-16GB, pci bus id: 0000:86:00.0)
    2017-09-20 10:56:29.551540: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:6) -> (device: 6, name: Tesla P100-SXM2-16GB, pci bus id: 0000:89:00.0)
    2017-09-20 10:56:29.551547: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1045] Creating TensorFlow device (/gpu:7) -> (device: 7, name: Tesla P100-SXM2-16GB, pci bus id: 0000:8a:00.0)

Is there anything I need to enable to utilize all 8 GPUs properly?

  3. Thank you, I'll try averaging. Is it correct to use this script to average multiple checkpoints? https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py
    But there is still quite a gap, even if averaging gives me +1 BLEU.

@martinpopel
Contributor

Is there anything that seems obviously wrong to you?

E.g. "diskutieren" is split across two lines, which shifts all following lines by one, unless that is just a copy-paste error when copying to the gist.

Is there anything I need to enable to utilize all 8 GPUs properly?

--worker_gpu=8

@edunov
Author

edunov commented Sep 21, 2017

Sorry, that was a copy-and-paste issue. Just to be sure, I checked the line counts of all the files, and they are the same:

$ wc -l /tmp/t2t_datagen/newstest2015.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.atat
2169 /tmp/t2t_datagen/newstest2015.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.atat
$ wc -l /tmp/t2t_datagen/newstest2015.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.decodes
2169 /tmp/t2t_datagen/newstest2015.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.decodes
$ wc -l /tmp/t2t_datagen/newstest2015.de
2169 /tmp/t2t_datagen/newstest2015.de

I'll try running with --worker_gpu=8 and report results here.
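The line-count check above can be wrapped in a small helper that fails loudly on a mismatch. This is a minimal sketch, and the name same_line_count is hypothetical. Note that equal counts are necessary but not sufficient: a line split in one place and two lines merged elsewhere would still pass.

```shell
# Hypothetical helper: check that all files have the same number of lines,
# a cheap test that hypothesis and reference are still aligned.
same_line_count() {
  local first="" n f
  for f in "$@"; do
    n=$(wc -l < "$f" | tr -d ' \t')   # strip the padding some wc versions add
    if [ -z "$first" ]; then
      first=$n
    elif [ "$n" != "$first" ]; then
      echo "line count mismatch: $f has $n lines, expected $first" >&2
      return 1
    fi
  done
  echo "all files have $first lines"
}

# Usage:
#   same_line_count newstest2015.de \
#     newstest2015.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.decodes
```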

@edunov
Author

edunov commented Sep 22, 2017

Unfortunately, it fails to run if I add --worker_gpu=8, with the following error:
https://gist.github.com/edunov/70f24b50a1c113b3505b019ca0652754

It seems to be the same issue as #266.
I'll try running with --local_eval_frequency=0

@mehmedes

@edunov Have you checked whether the test sets have already been tokenized? Stanford's test sets, for example, are already tokenized. Maybe you tokenized the reference text twice? Moreover, the sentences in the excerpt you posted seem to be missing full stops. Maybe this affects your BPE score?

@edunov
Author

edunov commented Sep 25, 2017

@mehmedes I'm using the test set that tensor2tensor provides; I believe it comes from here:
https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/wmt.py#L296

It has both a tokenized (newstest2014.tok.de) and an untokenized version (newstest2014.de); I'm using the untokenized one.

Btw, I re-ran all the experiments, and here are the numbers I got (all before model averaging):

            newstest2013  newstest2014  newstest2015
BPE         13            14.6          15.73
word piece  26.03         27.76         29.99

It seems the word-piece model works correctly, but the BPE model is still very far from the expected numbers. Periods seem to be completely missing from the BPE model's output, so maybe the sentences are truncated? Is there any option or parameter in the generator that limits the sentence length?

@lukaszkaiser
Contributor

For the BPE model, you need to use the tokenized test set and must not re-run the Moses tokenizer on the decoded output. Your results look like a bug of this kind; it's hard to believe the model itself would be that different. There is a decode parameter that limits the length, but I doubt that's the problem.

@vince62s
Contributor

@edunov What did you use to get 27.76 on newstest2014 (word piece)?
8 GPUs / workers?
big or base, single_gpu or not?
Batch size?

@lukaszkaiser Why does training on a single GPU seem to converge faster, but to a lower point (in BLEU terms)? Even if we train longer, it never reaches the same level.

@edunov
Author

edunov commented Oct 30, 2017

@vince62s I used 8 GPUs and the base model, with everything else set to its default. I trained the model until the trainer stopped (250k steps). I believe this differs from what they did in the paper (100k steps), and it takes longer.

@frankdiehl

Hi all,
I'm also struggling to reproduce the en-de results presented in the 'Attention Is All You Need' paper. I am using:
tensor2tensor version 1.2.7
tensorflow version 1.4

And I run the following setup on a single GPU:

--problem=translate_ende_wmt32k
--hparams_set=transformer_big_single_gpu (that corresponds to a batch_size=4096)
--model=transformer

I monitor the eval-bleu score with --eval_run_autoregressive set.

After 110861 steps I get an approx_bleu_score = 0.108027.
However, it took ~50k steps to reach 0.1, and the next ~60k steps gave me only ~0.008 of improvement. Running on newstest2014 gave BLEU = 13.61 (I took the already tokenized data from Stanford). Given the very slow increase in eval BLEU, I doubt I will come anywhere close to the 27.76 reported above, even if I keep training for another 400k steps. Does anyone have an idea about this? Thanks a lot!

@liesun1994

@edunov Have you solved the low BPE results problem? I'm puzzled: my BPE model gets only 12.16 BLEU on newstest2014. I am using exactly the same configuration as you described.

@liesun1994

@lukaszkaiser Hello, when using translate_ende_bpe_32k, I ran into the same problem as @edunov: the BLEU score is too low. When I print the test log, the source sentence is always truncated. Here are examples:
[Screenshot: newstest2014.tok.bpe.en]
[Screenshot: the output log]
The symbol "." is missing, and the translations rarely end with a sentence-final symbol such as ".".
This is just what @edunov mentioned: https://gist.github.com/edunov/a46303ff05505d33eddfad0d78900d76
So is there a proper way to solve this problem?
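One quick way to quantify the "missing periods" symptom is to compare how many lines end in sentence-final punctuation in the decoded output versus the reference. The helper below is a diagnostic sketch only, and the name final_punct_pct is hypothetical. It makes two assumptions:

- one sentence per line;
- ".", "!", "?" and a closing quote (raw or escaped as &quot; by tokenizer.perl) count as sentence-final.

```shell
# Hypothetical diagnostic: percentage of lines that end in sentence-final
# punctuation (. ! ? or a closing quote, raw or escaped as &quot; by
# tokenizer.perl), ignoring trailing blanks. Assumes one sentence per line.
final_punct_pct() {
  awk '{ n++ }
       /([.!?"]|&quot;)[ \t]*$/ { p++ }
       END { printf "%d\n", (n ? 100 * p / n : 0) }' "$1"
}

# Usage: compare the decoded output with the reference, e.g.
#   final_punct_pct newstest2014.tok.bpe.32000.en.transformer.transformer_base.beam5.alpha0.6.words
#   final_punct_pct newstest2014.tok.de
```

A value far below the reference's would point to truncated outputs rather than to a scoring problem, though it cannot tell you where the truncation happens.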

@martinpopel
Contributor

Can someone try T2T 1.3.0 and confirm whether the error is fixed (by 7909c69)?

@liesun1994

@martinpopel I will try it.

@liesun1994

@martinpopel I think the problem has been solved. I trained for 180K steps on 2 GPUs with transformer_base_single_gpu and everything else set to default. It got 22.72 on newstest2014.

@jiangbojian

jiangbojian commented Jan 8, 2018

Hi @edunov, I have a question: when you compute BLEU on newstest2014 for the word-piece model, is the reference file newstest2014.de? Have you tried another way to compute BLEU: first tokenize your output file, then compute BLEU against newstest2014.tok.de? How does the BLEU score change?

I ran the word-piece experiment with the same settings except train_steps=100k; BLEU on newstest2014 (with newstest2014.de as the reference) is 22.57. Will the two ways produce significantly different BLEU scores? Thanks!

@edunov
Author

edunov commented Jan 8, 2018

@jiangbojian I actually tokenize and then apply compound splitting to the reference before computing the BLEU score:

#Tokenize the reference
perl ~/mosesdecoder/scripts/tokenizer/tokenizer.perl -l de < /tmp/t2t_datagen/newstest${YEAR}.de > /tmp/t2t_datagen/newstest${YEAR}.de.tok
#Do compound splitting on the reference
perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' < /tmp/t2t_datagen/newstest${YEAR}.de.tok > /tmp/t2t_datagen/newstest${YEAR}.de.atat

/tmp/t2t_datagen/newstest${YEAR}.de.atat is what I use to compute BLEU.

It is important to apply exactly the same processing steps to both the reference and the system output. So whatever you do to the reference, you have to do to the system output as well.

@jiangbojian

@edunov Thank you for your detailed answer.
When I tokenize the output file and the reference file, the BLEU score improves noticeably (newstest2014, ensembling the last models, BLEU = 26.56). Could you tell me what the second script does, or how to get newstest${YEAR}.de.atat?

@jiangbojian

jiangbojian commented Jan 9, 2018

@edunov I only tokenize the output file and the reference file.
I trained a word-piece model (same settings as 'Attention Is All You Need': 250k train steps, 8 GPUs, batch_size=4096).
Ensembling the last 20 models, the BLEU score computed by multi-bleu.perl is 26.96.
Is applying compound splitting that important for BLEU?
Thanks~

@edunov
Author

edunov commented Jan 10, 2018

Yes. It splits each hyphenated compound into three tokens, e.g. "Science-Fiction" becomes "Science ##AT##-##AT## Fiction", so if your translation is correct, you get three consecutive matching tokens instead of one. It seems to give an extra 0.7-1 BLEU point.

To get newstest${YEAR}.de.atat I just applied the same compound splitting to the tokenized version of newstest${YEAR}.de
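The compound splitting described above can also be done with sed. The function below is an assumed equivalent of the perl one-liner used earlier in the thread; for ordinary text it should produce the same output. A second function applies the step identically to the reference and the system output, as recommended above. Both names, atat and atat_both, are hypothetical.

```shell
# Hypothetical sed version of the thread's perl one-liner
#   perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g'
# It splits hyphenated compounds: "Science-Fiction" -> "Science ##AT##-##AT## Fiction".
atat() {
  sed -E 's/([^[:space:]])-([^[:space:]])/\1 ##AT##-##AT## \2/g'
}

# Apply exactly the same step to every file given (reference and system
# output alike), writing FILE.atat next to each input.
atat_both() {
  local f
  for f in "$@"; do
    atat < "$f" > "$f.atat"
  done
}

# Usage:
#   atat_both newstest2014.de.tok newstest2014.en.transformer.transformer_base.beam5.alpha0.6.tok
```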

@jiangbojian

@edunov Yes, compound splitting gives an extra 0.7-1 BLEU point for the word-piece model. By the way, for the BPE model (same settings as the paper), applying compound splitting to the output file and the reference file brings BLEU on newstest2014 to 27.11 (the paper reports 27.30).

So, was the BLEU score reported in the paper (en-de, bpe32k, base model) computed this way? @lukaszkaiser

@SkyAndCloud
Contributor

@jiangbojian Hi, I'm now trying to reproduce the Transformer too. However, I can only achieve a ~22 BLEU score on the WMT'14 en-de dataset. My configs:

Tensorflow: 1.4.0
T2T: 1.5.5
hparams_set: transformer_base
gpu_num: 4
problem: translate_ende_wmt32k
testset: newstest2014-deen-src.en   newstest2014-deen-ref.de  (these two files are untokenized)

My model converged after ~8k steps. Here are my translation steps:

  1. tokenize newstest2014-deen-src.en
  2. use beam_size=4,alpha=0.6, decode as T2T readme walkthrough and get predict.txt
  3. tokenize newstest2014-deen-ref.de
  4. for predict.txt and newstest2014-deen-ref.de put compounds in ATAT format as this
  5. use t2t-bleu to calculate BLEU score

Could you please share the full details of your experiment, to help me reach ~27 BLEU as you did? Thank you!

@haoransh

@martinpopel Hi, I noticed your claim for the transformer_base_single_gpu configuration (with ~500k steps and batch_size 3072) with T2T version 1.1. However, I get a NaN loss when trying it with the current version (pip install). My command is:

t2t-trainer   --data_dir=$DATA_DIR   --problems=translate_ende_wmt_bpe32k   --model=transformer   --hparams_set=transformer_base_single_gpu --hparams="batch_size=3072"  --output_dir=$TRAIN_DIR --train_steps=500000 > print.txt 2>log.txt

Could you please show the command you used to reach the 25.61 BLEU score on the same task? Thanks very much!
Or could someone else share the command with which their model reached its highest BLEU score on a single GPU?

@martinpopel
Contributor

martinpopel commented Mar 27, 2018

@Shrshore: What I wrote was true for T2T version 1.1 and PROBLEM=translate_ende_wmt32k. I have never tried translate_ende_wmt_bpe32k. There are many differences in results between different versions (e.g. #529); usually you can get similarly good results with the newer version, but sometimes you must change some hyperparameters. @lukaszkaiser wrote they "have a regression test for wmt'14 en-de", but this is probably only for 8 GPUs and "--hparams_set=transformer_base and --problems=translate_ende_wmt32k. Train for 500k steps. Get BLEU, should be about 28." There is no guarantee for fewer GPUs, or if you need to lower the batch size.

@haoransh

@martinpopel Thanks for your reply! I'm also honestly wondering whether the evaluation pipeline should include the 'compound splitting' operation. Have you tried that before? I have a model output whose BLEU is only ~22 without compound splitting, but after applying compound splitting to the reference file and the output file, it goes up to ~27 with t2t-bleu. So I'm wondering whether the result reported in the paper was calculated this way too. @lukaszkaiser

@martinpopel
Contributor

I think (based on some post from @lukaszkaiser) the BLEU=28.4 in the Attention Is All You Need paper was computed with https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/get_ende_bleu.sh, which splits hyphen-compounds for "historical reasons". This splitting improves the BLEU score nominally, but (of course) not the translation quality, so I do not do it (I tried it just once, when I wanted to compare my scores to the paper and measure the effect of different BLEU evaluation).

@SkyAndCloud
Contributor

@jiangbojian @edunov @martinpopel @Shrshore Hi, I used --hparams="batch_size=1024", and after 360k steps I got 21.7 BLEU. Have you used the transformer_base_single_gpu hparams_set to reproduce the ~27 BLEU score on the translate_ende_wmt32k problem? I don't know whether there are some tricks.
My env:

tensor2tensor:1.5.5
tf:1.4.1

@DC-Swind

Hi, @edunov @martinpopel
Can you give some details about how to average the last few checkpoints? I tried the script avg_checkpoints.py, but it hurts performance: BLEU drops from 21.28 to 14.51.
(I got 21.28 BLEU using transformer_base_single_gpu after 250k steps.)
I use the following command to average checkpoints:
python avg_checkpoints.py --checkpoints=model.ckpt-250000,model.ckpt-249001,model.ckpt-249000,model.ckpt-248001,model.ckpt-248000,model.ckpt-247001,model.ckpt-247000,model.ckpt-246001,model.ckpt-246000,model.ckpt-245001,model.ckpt-245000,model.ckpt-244001,model.ckpt-244000,model.ckpt-243001,model.ckpt-243000,model.ckpt-242001,model.ckpt-242000,model.ckpt-241001,model.ckpt-241000,model.ckpt.ckpt-240001,model.ckpt-240000 --num_last_checkpoints=21 --prefix=${PREFIX}${CONFIG} --output_path=${PROJ_PATH}/avg_models/${CONFIG}last21.ckpt

And I evaluate the model using:
t2t-decoder --data_dir=data --problems=$PROBLEM --model=$MODEL --hparams_set=$HPARAMS --output_dir=$TRAIN_DIR --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" --decode_from_file=$DECODE_FILE --decode_to_file=translation.de
t2t-bleu --translation=translation.de --reference=data/newstest2014.de

@DC-Swind

DC-Swind commented Apr 11, 2018 via email

@martinpopel
Contributor

@DC-Swind: No, I haven't encountered this. I always average checkpoints stored at 1-hour intervals (--save_checkpoints_secs=3600 --schedule=train). I have no idea what the effect is of averaging the "extra" checkpoints (e.g. 249001 in addition to 249000) created by --schedule=continuous_train_and_eval, but see e.g. #556. The original paper used --schedule=train_and_evaluate with no extra checkpoints (although saved at 10-minute intervals).
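The checkpoint list does not have to be typed by hand; the list in the avg_checkpoints.py command above contains a doubled model.ckpt.ckpt-240001. Instead, the paths can be read from the "checkpoint" index file in the training directory. This is a minimal sketch with a hypothetical helper, last_checkpoints, and it assumes the usual one-entry-per-line layout of that file. The optional STRIDE argument drops the "extra" checkpoints such as 249001. Whether dropping them actually helps averaging is untested here.

```shell
# Hypothetical helper: print the last N checkpoint paths listed in a
# TensorFlow "checkpoint" index file, comma-separated. With STRIDE > 1,
# only steps divisible by STRIDE are kept (e.g. 1000 drops 249001).
# The step number is taken from the text after the last "-" in each path.
# Usage: last_checkpoints CHECKPOINT_FILE N [STRIDE]
last_checkpoints() {
  local file="$1" n="$2" stride="${3:-1}"
  sed -n 's/^all_model_checkpoint_paths: "\(.*\)"$/\1/p' "$file" |
    awk -F- -v s="$stride" '($NF % s) == 0' |
    tail -n "$n" |
    paste -sd, -
}

# Usage (other avg_checkpoints.py flags as in the command above):
#   python avg_checkpoints.py --checkpoints="$(last_checkpoints "$TRAIN_DIR/checkpoint" 20 1000)"
```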

@DC-Swind

@martinpopel: I will try --schedule=train. One more question: when I use t2t-decoder --output_dir=$TRAIN_DIR, which checkpoint will be used for decoding? model.ckpt-250000?
My checkpoint file in $TRAIN_DIR:
model_checkpoint_path: "model.ckpt-250000"
all_model_checkpoint_paths: "model.ckpt-240001"
all_model_checkpoint_paths: "model.ckpt-241000"
all_model_checkpoint_paths: "model.ckpt-241001"
all_model_checkpoint_paths: "model.ckpt-242000"
all_model_checkpoint_paths: "model.ckpt-242001"
all_model_checkpoint_paths: "model.ckpt-243000"
all_model_checkpoint_paths: "model.ckpt-243001"
all_model_checkpoint_paths: "model.ckpt-244000"
all_model_checkpoint_paths: "model.ckpt-244001"
all_model_checkpoint_paths: "model.ckpt-245000"
all_model_checkpoint_paths: "model.ckpt-245001"
all_model_checkpoint_paths: "model.ckpt-246000"
all_model_checkpoint_paths: "model.ckpt-246001"
all_model_checkpoint_paths: "model.ckpt-247000"
all_model_checkpoint_paths: "model.ckpt-247001"
all_model_checkpoint_paths: "model.ckpt-248000"
all_model_checkpoint_paths: "model.ckpt-248001"
all_model_checkpoint_paths: "model.ckpt-249000"
all_model_checkpoint_paths: "model.ckpt-249001"
all_model_checkpoint_paths: "model.ckpt-250000"

How do I specify a checkpoint by name if I want to decode using "model.ckpt-24500"?

By the way, is there any convenient method of drawing training loss or test BLEU curve?

@martinpopel
Contributor

@DC-Swind: t2t-decoder checks the checkpoint file in the output_dir and uses the checkpoint specified in that file (which is usually the last one, unless you edit it). However, you can use t2t-decoder --checkpoint_path=path/to/my/model.ckpt-24500 (which overrides the --output_dir).

By the way, is there any convenient method of drawing training loss or test BLEU curve?

Yes, tensorboard (plus t2t-bleu and t2t-translate-all if you prefer to see the real BLEU instead of approx_bleu or if you use --schedule=train).
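As described above, t2t-decoder falls back to the checkpoint named in the checkpoint file when only --output_dir is given. To see which one that is, print the file's model_checkpoint_path entry. This is a sketch with a hypothetical helper, default_checkpoint, assuming the usual one-entry-per-line layout of that file.

```shell
# Hypothetical helper: print the model_checkpoint_path entry of the
# "checkpoint" index file in a training directory, i.e. the checkpoint
# t2t-decoder uses when only --output_dir is given.
default_checkpoint() {
  sed -n 's/^model_checkpoint_path: "\(.*\)"$/\1/p' "$1/checkpoint"
}

# Usage:
#   default_checkpoint "$TRAIN_DIR"
# To decode a different checkpoint, pass it explicitly instead
# (other t2t-decoder flags as before):
#   t2t-decoder --checkpoint_path="$TRAIN_DIR/model.ckpt-STEP"
```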

@martinpopel
Contributor

If anyone is still trying to reproduce the en-de WMT14 BLEU scores of the Attention Is All You Need paper, note that you must manually tweak the newstest2014.de reference file (in addition to all the hacks in get_ende_bleu.sh): convert all Unicode quotes, including „, to ", which tokenizer.perl will then convert to &quot; (but make sure not to run tokenizer.perl twice, to prevent double escaping). There is no script in mosesdecoder that does this tweak (replace-unicode-punctuation.perl ignores low quotes, and normalize-punctuation.perl -l de changes the order of comma-quote to quote-comma, which is not what we want). Alternatively, download the tweaked reference from gs://tensor2tensor-checkpoints/transformer_ende_test/newstest2014.tok.de. As you can see below, the difference between the official BLEU (sacreBLEU) and the Google-tweaked BLEU can be about 1.5 BLEU. I hope this is good motivation for everyone to use sacreBLEU (including the signature) the next time they report any BLEU results.

transformer_base model from gs://tensor2tensor-checkpoints/transformer_ende_test/model.ckpt-3817425

wmt13  wmt14  evaluation
26.25  27.52  sacrebleu
26.59  28.33  sacrebleu -tok intl
27.10  28.85  sacrebleu -tok intl -lc
26.50  29.02  get_ende_bleu.sh (with tweaked reference)

transformer_base model from gs://tensor2tensor-checkpoints/transformer_ende_test/averaged.ckpt-0

wmt13  wmt14  evaluation
25.80  26.55  sacrebleu
26.18  27.25  sacrebleu -tok intl
26.67  27.78  sacrebleu -tok intl -lc
26.10  27.98  get_ende_bleu.sh (with tweaked reference)
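The quote normalization for the reference described above could be sketched as follows. It runs on the raw newstest2014.de before tokenizer.perl, and tokenizer.perl should then be run exactly once. The helper name normalize_quotes is hypothetical. It also assumes that "all unicode quotes" means the German low and high double quotes „ “ plus the English closing ”; the exact set used for Google's tweaked reference is not stated in the thread. Comparing the result with the downloadable newstest2014.tok.de is therefore advisable.

```shell
# Hypothetical helper: map typographic double quotes to a plain ASCII '"'
# before tokenizer.perl (which then escapes it as &quot;). Only „ “ ” are
# handled; whether the tweaked reference changed other characters too is
# an open assumption. Separate byte-sequence substitutions are used instead
# of a [„“”] bracket expression so the result does not depend on the locale.
normalize_quotes() {
  sed -e 's/„/"/g' -e 's/“/"/g' -e 's/”/"/g'
}

# Usage (tokenize exactly once, afterwards):
#   normalize_quotes < newstest2014.de \
#     | perl ~/mosesdecoder/scripts/tokenizer/tokenizer.perl -l de > newstest2014.quotes.tok.de
```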

@szhengac

Hi @martinpopel, I wonder what the difference is between model.ckpt-3817425 and averaged.ckpt-0? The former produces much better results. Thanks.

@martinpopel
Contributor

@szhengac: These checkpoints were uploaded by the T2T authors, not by me, so I am not sure. I guess the former model was trained for 3817425 steps (that is 3.8M steps, while the Attention Is All You Need paper used only 0.1M steps for the base models, though this also depends on the number of GPUs and the batch size) and with a newer T2T version.

@martinpopel
Contributor

As some users still ask how to replicate the BLEU scores (after downloading the trained checkpoint from gs://tensor2tensor-checkpoints/transformer_ende_test/model.ckpt-3817425), here is what I have:

wget http://ufallab.ms.mff.cuni.cz/~popel/replicate-attention-is-all-you-need-bleu.tar.gz
tar -xf replicate-attention-is-all-you-need-bleu.tar.gz
./multi-bleu.perl newstest2013.tok.de.atat < wmt13_deen-translated-bs4a0.6.de.tok.atat
./multi-bleu.perl newstest2014.tok.de.atat < wmt14_deen-translated-bs4a0.6.de.tok.atat

This gives BLEU 26.50 and 29.02 for wmt13 and wmt14, respectively. And (after pip3 install sacrebleu)

sacrebleu -l en-de -t wmt13 -tok intl < wmt13_deen-translated-bs4a0.6.de
sacrebleu -l en-de -t wmt14/full -tok intl < wmt14_deen-translated-bs4a0.6.de

this gives BLEU 26.59 and 28.33.

@szhengac

Hi @martinpopel. Thanks for your reply. Which beam_size and alpha were used to obtain these translations?

@martinpopel
Contributor

The default beam_size=4 and alpha=0.6 (as indicated in the filenames by bs4a0.6).

@libeineu

libeineu commented Jul 2, 2018

Hi @martinpopel, I'm confused about the 29.3 score in Scaling Neural Machine Translation. Do you know how they computed the BLEU score?

@martinpopel
Contributor

I have just read the paper, but it seems quite clear:

We measure case-sensitive tokenized BLEU
with multi-bleu (Hoang et al., 2006) and detokenized
BLEU with SacreBLEU (Post, 2018).
BLEU+case.mixed+lang.en-{de,fr}+numrefs.1+smooth.exp+test.wmt14/full+tok.13a+version.1.2.9

In most tables they report only the multi-bleu version; only in Table 3 do they report both, and we can see that multi-bleu 29.3 corresponds to sacreBLEU 28.6.

@libeineu

libeineu commented Jul 3, 2018

@martinpopel Yeah. So was the multi-bleu score they got computed with get_ende_bleu.sh?

@martinpopel
Contributor

They report their multi-bleu score of 29.3 in Table 2 as if it were comparable to Vaswani et al. (2017)'s 28.4, but I doubt they followed all the tweaks (unless there was personal communication between the authors of the two papers). Luckily, we can compare the sacreBLEU scores (case.mixed-tok.13a version): 28.6 (Ott et al.) vs. 27.52 (T2T as reported above, about half a point better than Vaswani et al.).

It should be noted that while it is good to use comparable and replicable BLEU (i.e. sacreBLEU), that is not everything, as most MT researchers know. Not only BLEU, but any automatic metric I am aware of that is based on similarity to a human reference (especially a single reference, as is the case in WMT) is potentially flawed. There are systems today (for some language pairs and the "WMT domain") surpassing the quality of human references (or at least coming close). Of course, this does not mean that the systems are better than human references in all aspects, just in some. But it means that single-reference BLEU (or any other automatic metric) is not reliable for such high-quality systems. I'm curious what correlation scores we will see in the WMT18 metrics task results.

@libeineu

libeineu commented Jul 4, 2018

OK, I just want to build a baseline whose BLEU score is comparable with the other systems above. I think it's a good starting point for my further research. Thanks for your answer!

@nxphi47

nxphi47 commented Jan 24, 2019

Hi @martinpopel ,
It's been a while since this thread was active. I have read it, but I am still unclear what to do to get the real BLEU score from the paper. The README says I just need to use t2t-bleu, but this thread mentions sacreBLEU, some Perl commands, get_ende_bleu.sh and tweaking of the reference.

Can you please explain the complete steps to reproduce the correct results?

I downloaded the dataset from google_drive_link for WMT16.

I trained the model with problem translate_ende_wmt_bpe32k and got a bunch of model checkpoints model.ckpt-xxxx.

Now what exactly should I do to get the averaged model and the correct BLEU score?

Thank you.

@martinpopel
Contributor

Hi @nxphi47,
the BLEU reported in the original paper was computed using get_ende_bleu.sh plus a few other tweaks. I don't recommend wasting time trying to reproduce it exactly. Note that if you trained the checkpoints yourself, most likely you had different GPUs (and a different effective batch size) than in the original paper, so you will get different results anyway.

For future research, it is recommended to use sacreBLEU, which when used with option -tok intl should give exactly the same results as t2t-bleu (comparing case-sensitive with case-sensitive or case-insensitive with case-insensitive, of course).

For averaging, use avg_checkpoints.py or t2t-avg-all.

@nxphi47

nxphi47 commented Jan 24, 2019

Hi @martinpopel
Thank you for the comment.
So if I use t2t-bleu, I will:
1/ generate with t2t-decoder
2/ remove BPE
3/ use t2t-bleu against newstest2014.tok.de
4/ report the uncased BLEU score from t2t-bleu in the paper
Is that the correct procedure?

Thank you

@martinpopel
Contributor

When using t2t-bleu or sacrebleu, always use non-tokenized (de-tokenized) version of translation and reference files, i.e. the version which would be presented to the users.

@nxphi47

nxphi47 commented Jan 24, 2019

@martinpopel Thank you for your reply.
So, starting from the BPE-removed, tokenized version of the output, how can I detokenize it?
Thank you so much.
