Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Multilingual training example (#527)
Summary: * Add example for multilingual translation on IWSLT'17 * Match dataset ordering for multilingual_translation and translation * Fix bug with LegacyDistributedDataParallel when calling forward of sub-modules Pull Request resolved: #527 Differential Revision: D14218372 Pulled By: myleott fbshipit-source-id: 2e3fe24aa39476bcc5c9af68ef9a40192db34a3b
- Loading branch information
1 parent
44d27e6
commit 0049349
Showing
10 changed files
with
388 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Prepare IWSLT'17 {de,fr}->en data for multilingual translation:
# download the archives, strip the TED XML markup, learn a joint
# sentencepiece BPE model over all four sides, and encode train/valid.

SRCS=(
    "de"
    "fr"
)
TGT=en

ROOT=$(dirname "$0")
SCRIPTS=$ROOT/../../scripts
SPM_TRAIN=$SCRIPTS/spm_train.py
SPM_ENCODE=$SCRIPTS/spm_encode.py

BPESIZE=16384
ORIG=$ROOT/iwslt17_orig
DATA=$ROOT/iwslt17.de_fr.en.bpe16k
mkdir -p "$ORIG" "$DATA"

TRAIN_MINLEN=1  # remove sentences with <1 BPE token
TRAIN_MAXLEN=250  # remove sentences with >250 BPE tokens

URLS=(
    "https://wit3.fbk.eu/archive/2017-01-trnted/texts/de/en/de-en.tgz"
    "https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz"
)
ARCHIVES=(
    "de-en.tgz"
    "fr-en.tgz"
)
VALID_SETS=(
    "IWSLT17.TED.dev2010.de-en IWSLT17.TED.tst2010.de-en IWSLT17.TED.tst2011.de-en IWSLT17.TED.tst2012.de-en IWSLT17.TED.tst2013.de-en IWSLT17.TED.tst2014.de-en IWSLT17.TED.tst2015.de-en"
    "IWSLT17.TED.dev2010.fr-en IWSLT17.TED.tst2010.fr-en IWSLT17.TED.tst2011.fr-en IWSLT17.TED.tst2012.fr-en IWSLT17.TED.tst2013.fr-en IWSLT17.TED.tst2014.fr-en IWSLT17.TED.tst2015.fr-en"
)

# download and extract data
for ((i=0;i<${#URLS[@]};++i)); do
    ARCHIVE=$ORIG/${ARCHIVES[i]}
    if [ -f "$ARCHIVE" ]; then
        echo "$ARCHIVE already exists, skipping download"
    else
        URL=${URLS[i]}
        wget -P "$ORIG" "$URL"
        if [ -f "$ARCHIVE" ]; then
            echo "$URL successfully downloaded."
        else
            echo "$URL not successfully downloaded."
            exit 1
        fi
    fi
    # Strip the ".tgz" suffix to get the extraction directory.
    # (Bug fix: ${ARCHIVE: -4} returned the LAST four characters ".tgz",
    # so the existence check below never matched and every run re-extracted.)
    FILE=${ARCHIVE%.tgz}
    if [ -e "$FILE" ]; then
        echo "$FILE already exists, skipping extraction"
    else
        tar -C "$ORIG" -xzvf "$ARCHIVE"
    fi
done

echo "pre-processing train data..."
for SRC in "${SRCS[@]}"; do
    for LANG in "${SRC}" "${TGT}"; do
        # drop TED metadata lines/tags and trim surrounding whitespace
        cat "$ORIG/${SRC}-${TGT}/train.tags.${SRC}-${TGT}.${LANG}" \
            | grep -v '<url>' \
            | grep -v '<talkid>' \
            | grep -v '<keywords>' \
            | grep -v '<speaker>' \
            | grep -v '<reviewer' \
            | grep -v '<translator' \
            | grep -v '<doc' \
            | grep -v '</doc>' \
            | sed -e 's/<title>//g' \
            | sed -e 's/<\/title>//g' \
            | sed -e 's/<description>//g' \
            | sed -e 's/<\/description>//g' \
            | sed 's/^\s*//g' \
            | sed 's/\s*$//g' \
            > "$DATA/train.${SRC}-${TGT}.${LANG}"
    done
done

echo "pre-processing valid data..."
for ((i=0;i<${#SRCS[@]};++i)); do
    SRC=${SRCS[i]}
    VALID_SET=${VALID_SETS[i]}
    # intentional word-splitting: VALID_SET is a space-separated list
    for FILE in $VALID_SET; do
        for LANG in "$SRC" "$TGT"; do
            # keep only <seg> lines, strip the tags, normalize curly apostrophes
            grep '<seg id' "$ORIG/${SRC}-${TGT}/${FILE}.${LANG}.xml" \
                | sed -e 's/<seg id="[0-9]*">\s*//g' \
                | sed -e 's/\s*<\/seg>\s*//g' \
                | sed -e "s/\’/\'/g" \
                > "$DATA/valid.${SRC}-${TGT}.${LANG}"
        done
    done
done

# learn BPE with sentencepiece
TRAIN_FILES=$(for SRC in "${SRCS[@]}"; do echo $DATA/train.${SRC}-${TGT}.${SRC}; echo $DATA/train.${SRC}-${TGT}.${TGT}; done | tr "\n" ",")
echo "learning joint BPE over ${TRAIN_FILES}..."
python "$SPM_TRAIN" \
    --input=$TRAIN_FILES \
    --model_prefix=$DATA/sentencepiece.bpe \
    --vocab_size=$BPESIZE \
    --character_coverage=1.0 \
    --model_type=bpe

# encode train/valid
echo "encoding train/valid with learned BPE..."
# Bug fixes vs. the original:
#  - the inner `for LANG in "$SRC" "$TGT"` loop never used LANG and encoded
#    each language pair twice; each spm_encode call already handles both sides
#  - --inputs/--outputs paths were quoted into a SINGLE shell word, so
#    spm_encode would have tried to open a file named "path1 path2"
for SRC in "${SRCS[@]}"; do
    python "$SPM_ENCODE" \
        --model "$DATA/sentencepiece.bpe.model" \
        --output_format=piece \
        --inputs "$DATA/train.${SRC}-${TGT}.${SRC}" "$DATA/train.${SRC}-${TGT}.${TGT}" \
        --outputs "$DATA/train.bpe.${SRC}-${TGT}.${SRC}" "$DATA/train.bpe.${SRC}-${TGT}.${TGT}" \
        --min-len $TRAIN_MINLEN --max-len $TRAIN_MAXLEN
    python "$SPM_ENCODE" \
        --model "$DATA/sentencepiece.bpe.model" \
        --output_format=piece \
        --inputs "$DATA/valid.${SRC}-${TGT}.${SRC}" "$DATA/valid.${SRC}-${TGT}.${TGT}" \
        --outputs "$DATA/valid.bpe.${SRC}-${TGT}.${SRC}" "$DATA/valid.bpe.${SRC}-${TGT}.${TGT}"
done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#!/usr/bin/env python | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from __future__ import absolute_import, division, print_function, unicode_literals | ||
|
||
import argparse | ||
|
||
import sentencepiece as spm | ||
|
||
|
||
def main():
    """Decode a file of sentencepiece pieces or ids back to plain text.

    Reads ``--input`` line by line, decodes each line with the sentencepiece
    model given by ``--model``, and prints the detokenized text to stdout.
    ``--input_format`` selects whether lines contain subword pieces (default)
    or integer ids.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True,
                        help="sentencepiece model to use for decoding")
    parser.add_argument("--input", required=True, help="input file to decode")
    parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
    args = parser.parse_args()

    sp = spm.SentencePieceProcessor()
    sp.Load(args.model)

    if args.input_format == "piece":
        def decode(l):
            return "".join(sp.DecodePieces(l))
    elif args.input_format == "id":
        def decode(l):
            return "".join(sp.DecodeIds(l))
    else:
        raise NotImplementedError

    def tok2int(tok):
        # remap reference-side <unk> (represented as <<unk>>) to 0
        return int(tok) if tok != "<<unk>>" else 0

    with open(args.input, "r", encoding="utf-8") as h:
        for line in h:
            if args.input_format == "id":
                # bug fix: only id inputs are converted to ints; the original
                # applied tok2int unconditionally, so the default "piece"
                # format crashed with ValueError on int("<subword>")
                print(decode(list(map(tok2int, line.rstrip().split()))))
            else:  # "piece": tokens are subword strings, pass them through
                print(decode(line.rstrip().split()))
|
||
|
||
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
Oops, something went wrong.