Multilingual training example (#527)
Summary:
* Add example for multilingual translation on IWSLT'17
* Match dataset ordering for multilingual_translation and translation
* Fix bug with LegacyDistributedDataParallel when calling forward of sub-modules
Pull Request resolved: #527

Differential Revision: D14218372

Pulled By: myleott

fbshipit-source-id: 2e3fe24aa39476bcc5c9af68ef9a40192db34a3b
myleott authored and facebook-github-bot committed Feb 26, 2019
1 parent 44d27e6 commit 0049349
Showing 10 changed files with 388 additions and 23 deletions.
69 changes: 68 additions & 1 deletion examples/translation/README.md
@@ -92,7 +92,6 @@ $ fairseq-generate data-bin/iwslt14.tokenized.de-en \
```


### prepare-wmt14en2de.sh

The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script.
@@ -163,3 +162,71 @@ $ fairseq-generate data-bin/fconv_wmt_en_fr \
--path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe
```

## Multilingual Translation

We also support training multilingual translation models. In this example we'll
train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets.

Note that we use slightly different preprocessing here than for the IWSLT'14
En-De data above. In particular, we learn a joint BPE code for all three
languages and use `fairseq-interactive` and `sacrebleu` for scoring the test set.

```
# First install sacrebleu and sentencepiece
$ pip install sacrebleu sentencepiece
# Then download and preprocess the data
$ cd examples/translation/
$ bash prepare-iwslt17-multilingual.sh
$ cd ../..
# Binarize the de-en dataset
$ TEXT=examples/translation/iwslt17.de_fr.en.bpe16k
$ fairseq-preprocess --source-lang de --target-lang en \
--trainpref $TEXT/train.bpe.de-en --validpref $TEXT/valid.bpe.de-en \
--joined-dictionary \
--destdir data-bin/iwslt17.de_fr.en.bpe16k \
--workers 10
# Binarize the fr-en dataset
# NOTE: it's important to reuse the en dictionary from the previous step
$ fairseq-preprocess --source-lang fr --target-lang en \
--trainpref $TEXT/train.bpe.fr-en --validpref $TEXT/valid.bpe.fr-en \
--joined-dictionary --tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \
--destdir data-bin/iwslt17.de_fr.en.bpe16k \
--workers 10
# Train a multilingual transformer model
# NOTE: the command below assumes 1 GPU, but accumulates gradients from
# 8 fwd/bwd passes to simulate training on 8 GPUs
$ mkdir -p checkpoints/multilingual_transformer
$ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \
--max-epoch 50 \
--ddp-backend=no_c10d \
--task multilingual_translation --lang-pairs de-en,fr-en \
--arch multilingual_transformer_iwslt_de_en \
--share-decoders --share-decoder-input-output-embed \
--optimizer adam --adam-betas '(0.9, 0.98)' \
--lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \
--warmup-updates 4000 --warmup-init-lr '1e-07' \
--label-smoothing 0.1 --criterion label_smoothed_cross_entropy \
--dropout 0.3 --weight-decay 0.0001 \
--save-dir checkpoints/multilingual_transformer \
--max-tokens 4000 \
--update-freq 8
# Generate and score the test set with sacrebleu
$ SRC=de
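# (set SRC=fr to score the fr-en direction with the same model)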
$ sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \
| python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
> iwslt17.test.${SRC}-en.${SRC}.bpe
$ cat iwslt17.test.${SRC}-en.${SRC}.bpe | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \
--task multilingual_translation --source-lang ${SRC} --target-lang en \
--path checkpoints/multilingual_transformer/checkpoint_best.pt \
--buffer-size 2000 --batch-size 128 \
--beam 5 --remove-bpe=sentencepiece \
> iwslt17.test.${SRC}-en.en.sys
$ grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \
| sacrebleu --test-set iwslt17 --language-pair ${SRC}-en
```
126 changes: 126 additions & 0 deletions examples/translation/prepare-iwslt17-multilingual.sh
@@ -0,0 +1,126 @@
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

SRCS=(
"de"
"fr"
)
TGT=en

ROOT=$(dirname "$0")
SCRIPTS=$ROOT/../../scripts
SPM_TRAIN=$SCRIPTS/spm_train.py
SPM_ENCODE=$SCRIPTS/spm_encode.py

BPESIZE=16384
ORIG=$ROOT/iwslt17_orig
DATA=$ROOT/iwslt17.de_fr.en.bpe16k
mkdir -p "$ORIG" "$DATA"

TRAIN_MINLEN=1 # remove sentences with <1 BPE token
TRAIN_MAXLEN=250 # remove sentences with >250 BPE tokens

URLS=(
"https://wit3.fbk.eu/archive/2017-01-trnted/texts/de/en/de-en.tgz"
"https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz"
)
ARCHIVES=(
"de-en.tgz"
"fr-en.tgz"
)
VALID_SETS=(
"IWSLT17.TED.dev2010.de-en IWSLT17.TED.tst2010.de-en IWSLT17.TED.tst2011.de-en IWSLT17.TED.tst2012.de-en IWSLT17.TED.tst2013.de-en IWSLT17.TED.tst2014.de-en IWSLT17.TED.tst2015.de-en"
"IWSLT17.TED.dev2010.fr-en IWSLT17.TED.tst2010.fr-en IWSLT17.TED.tst2011.fr-en IWSLT17.TED.tst2012.fr-en IWSLT17.TED.tst2013.fr-en IWSLT17.TED.tst2014.fr-en IWSLT17.TED.tst2015.fr-en"
)

# download and extract data
for ((i=0;i<${#URLS[@]};++i)); do
ARCHIVE=$ORIG/${ARCHIVES[i]}
if [ -f "$ARCHIVE" ]; then
echo "$ARCHIVE already exists, skipping download"
else
URL=${URLS[i]}
wget -P "$ORIG" "$URL"
if [ -f "$ARCHIVE" ]; then
echo "$URL successfully downloaded."
else
echo "$URL not successfully downloaded."
exit 1
fi
fi
FILE=${ARCHIVE%.tgz}  # strip the .tgz suffix to get the extraction directory
if [ -e "$FILE" ]; then
echo "$FILE already exists, skipping extraction"
else
tar -C "$ORIG" -xzvf "$ARCHIVE"
fi
done

echo "pre-processing train data..."
for SRC in "${SRCS[@]}"; do
for LANG in "${SRC}" "${TGT}"; do
cat "$ORIG/${SRC}-${TGT}/train.tags.${SRC}-${TGT}.${LANG}" \
| grep -v '<url>' \
| grep -v '<talkid>' \
| grep -v '<keywords>' \
| grep -v '<speaker>' \
| grep -v '<reviewer' \
| grep -v '<translator' \
| grep -v '<doc' \
| grep -v '</doc>' \
| sed -e 's/<title>//g' \
| sed -e 's/<\/title>//g' \
| sed -e 's/<description>//g' \
| sed -e 's/<\/description>//g' \
| sed 's/^\s*//g' \
| sed 's/\s*$//g' \
> "$DATA/train.${SRC}-${TGT}.${LANG}"
done
done

echo "pre-processing valid data..."
for ((i=0;i<${#SRCS[@]};++i)); do
SRC=${SRCS[i]}
VALID_SET=${VALID_SETS[i]}
for FILE in ${VALID_SET[@]}; do
for LANG in "$SRC" "$TGT"; do
grep '<seg id' "$ORIG/${SRC}-${TGT}/${FILE}.${LANG}.xml" \
| sed -e 's/<seg id="[0-9]*">\s*//g' \
| sed -e 's/\s*<\/seg>\s*//g' \
| sed -e "s/\’/\'/g" \
> "$DATA/valid.${SRC}-${TGT}.${LANG}"
done
done
done

# learn BPE with sentencepiece
TRAIN_FILES=$(for SRC in "${SRCS[@]}"; do echo $DATA/train.${SRC}-${TGT}.${SRC}; echo $DATA/train.${SRC}-${TGT}.${TGT}; done | tr "\n" ",")
echo "learning joint BPE over ${TRAIN_FILES}..."
python "$SPM_TRAIN" \
--input=$TRAIN_FILES \
--model_prefix=$DATA/sentencepiece.bpe \
--vocab_size=$BPESIZE \
--character_coverage=1.0 \
--model_type=bpe

# encode train/valid/test
echo "encoding train/valid with learned BPE..."
for SRC in "${SRCS[@]}"; do
for LANG in "$SRC" "$TGT"; do
python "$SPM_ENCODE" \
--model "$DATA/sentencepiece.bpe.model" \
--output_format=piece \
--inputs "$DATA/train.${SRC}-${TGT}.${SRC} $DATA/train.${SRC}-${TGT}.${TGT}" \
--outputs "$DATA/train.bpe.${SRC}-${TGT}.${SRC} $DATA/train.bpe.${SRC}-${TGT}.${TGT}" \
--min-len $TRAIN_MINLEN --max-len $TRAIN_MAXLEN
python "$SPM_ENCODE" \
--model "$DATA/sentencepiece.bpe.model" \
--output_format=piece \
--inputs "$DATA/valid.${SRC}-${TGT}.${SRC} $DATA/valid.${SRC}-${TGT}.${TGT}" \
--outputs "$DATA/valid.bpe.${SRC}-${TGT}.${SRC} $DATA/valid.bpe.${SRC}-${TGT}.${TGT}"
done
done
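
If the script runs to completion, the `$DATA` directory (`examples/translation/iwslt17.de_fr.en.bpe16k/`) should contain roughly the following files, which the `fairseq-preprocess` commands in the README consume. This listing is inferred from the script above, not taken from the commit itself:

```
$ ls examples/translation/iwslt17.de_fr.en.bpe16k/
sentencepiece.bpe.model  sentencepiece.bpe.vocab
train.de-en.de  train.de-en.en  train.bpe.de-en.de  train.bpe.de-en.en
train.fr-en.fr  train.fr-en.en  train.bpe.fr-en.fr  train.bpe.fr-en.en
valid.de-en.de  valid.de-en.en  valid.bpe.de-en.de  valid.bpe.de-en.en
valid.fr-en.fr  valid.fr-en.en  valid.bpe.fr-en.fr  valid.bpe.fr-en.en
```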
15 changes: 11 additions & 4 deletions fairseq/data/round_robin_zip_datasets.py
@@ -39,12 +39,11 @@ def __init__(self, datasets, eval_key=None):
self.longest_dataset = dataset
self.longest_dataset_key = key

self._ordered_indices = OrderedDict([
(key, dataset.ordered_indices())
for key, dataset in datasets.items()
])
self._ordered_indices = None

def _map_index(self, key, index):
assert self._ordered_indices is not None, \
'Must call RoundRobinZipDatasets.ordered_indices() first'
return self._ordered_indices[key][index % len(self.datasets[key])]

def __getitem__(self, index):
@@ -102,6 +101,14 @@ def size(self, index):

def ordered_indices(self):
"""Ordered indices for batching."""
if self._ordered_indices is None:
# Call the underlying dataset's ordered_indices() here, so that we
# get the same random ordering as we would have from using the
# underlying dataset directly.
self._ordered_indices = OrderedDict([
(key, dataset.ordered_indices())
for key, dataset in self.datasets.items()
])
return np.arange(len(self))

@property
2 changes: 1 addition & 1 deletion fairseq/legacy_distributed_data_parallel.py
@@ -75,7 +75,6 @@ def __setstate__(self, state):
self._register_grad_hook()

def forward(self, *inputs, **kwargs):
self.need_reduction = True
return self.module(*inputs, **kwargs)

def _register_grad_hook(self):
@@ -166,6 +165,7 @@ def reduction_fn():
for p in self.module.parameters():

def allreduce_hook(*unused):
self.need_reduction = True
Variable._execution_engine.queue_callback(reduction_fn)

if p.requires_grad:
1 change: 0 additions & 1 deletion fairseq/tasks/fairseq_task.py
@@ -226,7 +226,6 @@ def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
- logging outputs to display while training
"""
model.train()

loss, sample_size, logging_output = criterion(model, sample)
if ignore_grad:
loss *= 0
11 changes: 8 additions & 3 deletions fairseq/trainer.py
@@ -50,6 +50,7 @@ def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
self._num_updates = 0
self._optim_history = None
self._optimizer = None
self._prev_grad_norm = None
self._wrapped_model = None

self.init_meters(args)
@@ -215,12 +216,15 @@ def train_step(self, samples, dummy_batch=False):

# gather logging outputs from all replicas
if self.args.distributed_world_size > 1:
logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
[logging_outputs, sample_sizes, ooms],
))
logging_outputs, sample_sizes, ooms, prev_norms = \
zip(*distributed_utils.all_gather_list(
[logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
))
logging_outputs = list(chain.from_iterable(logging_outputs))
sample_sizes = list(chain.from_iterable(sample_sizes))
ooms = sum(ooms)
assert all(norm == prev_norms[0] for norm in prev_norms), \
'Fatal error: gradients are inconsistent between workers'

self.meters['oom'].update(ooms, len(samples))
if ooms == self.args.distributed_world_size * len(samples):
@@ -246,6 +250,7 @@

# clip grads
grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
self._prev_grad_norm = grad_norm

# take an optimization step
self.optimizer.step()
26 changes: 13 additions & 13 deletions preprocess.py
@@ -56,37 +56,37 @@ def build_dictionary(filenames, src=False, tgt=False):
padding_factor=args.padding_factor,
)

if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
raise FileExistsError(dict_path(args.source_lang))
if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
raise FileExistsError(dict_path(args.target_lang))

if args.joined_dictionary:
assert (
not args.srcdict or not args.tgtdict
), "cannot use both --srcdict and --tgtdict with --joined-dictionary"
assert not args.srcdict or not args.tgtdict, \
"cannot use both --srcdict and --tgtdict with --joined-dictionary"

if args.srcdict:
src_dict = task.load_dictionary(args.srcdict)
elif args.tgtdict:
src_dict = task.load_dictionary(args.tgtdict)
else:
assert (
args.trainpref
), "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary({train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True)
assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary(
{train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
)
tgt_dict = src_dict
else:
if args.srcdict:
src_dict = task.load_dictionary(args.srcdict)
else:
assert (
args.trainpref
), "--trainpref must be set if --srcdict is not specified"
assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = build_dictionary([train_path(args.source_lang)], src=True)

if target:
if args.tgtdict:
tgt_dict = task.load_dictionary(args.tgtdict)
else:
assert (
args.trainpref
), "--trainpref must be set if --tgtdict is not specified"
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
else:
tgt_dict = None
45 changes: 45 additions & 0 deletions scripts/spm_decode.py
@@ -0,0 +1,45 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse

import sentencepiece as spm


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True,
help="sentencepiece model to use for decoding")
parser.add_argument("--input", required=True, help="input file to decode")
parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
args = parser.parse_args()

sp = spm.SentencePieceProcessor()
sp.Load(args.model)

if args.input_format == "piece":
def decode(l):
return "".join(sp.DecodePieces(l))
elif args.input_format == "id":
def decode(l):
return "".join(sp.DecodeIds(l))
else:
raise NotImplementedError

    def tok2int(tok):
        # remap reference-side <unk> (represented as <<unk>>) to 0
        return int(tok) if tok != "<<unk>>" else 0

    with open(args.input, "r", encoding="utf-8") as h:
        for line in h:
            if args.input_format == "id":
                # ids arrive as integer strings and may contain <<unk>>
                print(decode(list(map(tok2int, line.rstrip().split()))))
            else:
                # pieces can be passed to the decoder as-is
                print(decode(line.rstrip().split()))


if __name__ == "__main__":
main()
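
For reference, here is one way the new script could be invoked (a usage sketch, not part of the commit): it reads a BPE-encoded file and writes plain text to stdout; the model and input paths below assume the IWSLT'17 example from the README above. The generation pipeline already strips BPE via `--remove-bpe=sentencepiece`, so this is only needed when decoding files by hand.

```
# Hypothetical invocation; paths follow the IWSLT'17 example above
$ python scripts/spm_decode.py \
  --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \
  --input iwslt17.test.de-en.de.bpe \
  > iwslt17.test.de-en.de.detok
```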
