Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions tensor2tensor/data_generators/translate_enzh.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,23 @@

# End-of-sentence marker.
EOS = text_encoder.EOS_ID

_ZHEN_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
# This is far from the real WMT17 task - only a toy dataset is used here.
# You need to register to get the UN data and the CWMT data.
# Also, by convention this is EN to ZH - use translate_enzh_wmt8k_rev for the ZH to EN task.
_ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
"training-parallel-nc-v12.tgz"),
("training/news-commentary-v12.zh-en.zh",
"training/news-commentary-v12.zh-en.en")]]
("training/news-commentary-v12.zh-en.en",
"training/news-commentary-v12.zh-en.zh")]]

_ZHEN_TEST_DATASETS = [[
_ENZH_TEST_DATASETS = [[
"http://data.statmt.org/wmt17/translation-task/dev.tgz",
("dev/newsdev2017-zhen-src.zh.sgm", "dev/newsdev2017-zhen-ref.en.sgm")
("dev/newsdev2017-zhen-src.en.sgm", "dev/newsdev2017-zhen-ref.zh.sgm")
]]


@registry.register_problem
class TranslateEnzhWmt8k(translate.TranslateProblem):
"""Problem spec for WMT Zh-En translation."""
"""Problem spec for WMT En-Zh translation."""

@property
def targeted_vocab_size(self):
Expand All @@ -61,16 +63,16 @@ def num_shards(self):

@property
def source_vocab_name(self):
return "vocab.zhen-zh.%d" % self.targeted_vocab_size
return "vocab.enzh-en.%d" % self.targeted_vocab_size

@property
def target_vocab_name(self):
return "vocab.zhen-en.%d" % self.targeted_vocab_size
return "vocab.enzh-zh.%d" % self.targeted_vocab_size

def generator(self, data_dir, tmp_dir, train):
datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
source_datasets = [[item[0], [item[1][0]]] for item in _ZHEN_TRAIN_DATASETS]
target_datasets = [[item[0], [item[1][1]]] for item in _ZHEN_TRAIN_DATASETS]
datasets = _ENZH_TRAIN_DATASETS if train else _ENZH_TEST_DATASETS
source_datasets = [[item[0], [item[1][0]]] for item in _ENZH_TRAIN_DATASETS]
target_datasets = [[item[0], [item[1][1]]] for item in _ENZH_TRAIN_DATASETS]
source_vocab = generator_utils.get_or_generate_vocab(
data_dir, tmp_dir, self.source_vocab_name, self.targeted_vocab_size,
source_datasets)
Expand All @@ -79,21 +81,18 @@ def generator(self, data_dir, tmp_dir, train):
target_datasets)
tag = "train" if train else "dev"
data_path = translate.compile_data(tmp_dir, datasets,
"wmt_zhen_tok_%s" % tag)
# We generate English->X data by convention. To train reverse translation,
# just add the "_rev" suffix to the problem name, e.g.:
# --problems=translate_enzh_wmt8k_rev
return translate.bi_vocabs_token_generator(data_path + ".lang2",
data_path + ".lang1",
"wmt_enzh_tok_%s" % tag)
return translate.bi_vocabs_token_generator(data_path + ".lang1",
data_path + ".lang2",
source_vocab, target_vocab, EOS)

@property
def input_space_id(self):
return problem.SpaceID.ZH_TOK
return problem.SpaceID.EN_TOK

@property
def target_space_id(self):
return problem.SpaceID.EN_TOK
return problem.SpaceID.ZH_TOK

def feature_encoders(self, data_dir):
source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
Expand Down