From 6b1267e717f0d3ef51b93120edcd42519bb862b5 Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Mon, 30 Oct 2017 10:52:57 +0100
Subject: [PATCH 1/2] Fix the EnZh task

---
 .../data_generators/translate_enzh.py | 39 +++++++++----------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/tensor2tensor/data_generators/translate_enzh.py b/tensor2tensor/data_generators/translate_enzh.py
index 7c77a05fc..5bb5b01b1 100644
--- a/tensor2tensor/data_generators/translate_enzh.py
+++ b/tensor2tensor/data_generators/translate_enzh.py
@@ -35,21 +35,23 @@
 # End-of-sentence marker.
 EOS = text_encoder.EOS_ID
 
-
-_ZHEN_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
+# This is far from the real WMT17 task - only a toy dataset is used here.
+# You need to register to get the UN data and the CWMT data.
+# By convention this is EN to ZH - use translate_enzh_wmt8k_rev for the ZH to EN task.
+_ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
                           "training-parallel-nc-v12.tgz"),
-                         ("training/news-commentary-v12.zh-en.zh",
-                          "training/news-commentary-v12.zh-en.en")]]
+                         ("training/news-commentary-v12.zh-en.en",
+                          "training/news-commentary-v12.zh-en.zh")]]
 
-_ZHEN_TEST_DATASETS = [[
+_ENZH_TEST_DATASETS = [[
     "http://data.statmt.org/wmt17/translation-task/dev.tgz",
-    ("dev/newsdev2017-zhen-src.zh.sgm", "dev/newsdev2017-zhen-ref.en.sgm")
+    ("dev/newsdev2017-zhen-src.en.sgm", "dev/newsdev2017-zhen-ref.zh.sgm")
 ]]
 
 
 @registry.register_problem
 class TranslateEnzhWmt8k(translate.TranslateProblem):
-  """Problem spec for WMT Zh-En translation."""
+  """Problem spec for WMT En-Zh translation."""
 
   @property
   def targeted_vocab_size(self):
@@ -61,16 +63,16 @@ def num_shards(self):
 
   @property
   def source_vocab_name(self):
-    return "vocab.zhen-zh.%d" % self.targeted_vocab_size
+    return "vocab.en-zh-en.%d" % self.targeted_vocab_size
 
   @property
   def target_vocab_name(self):
-    return "vocab.zhen-en.%d" % self.targeted_vocab_size
+    return "vocab.enzh-zh.%d" % self.targeted_vocab_size
 
   def generator(self, data_dir, tmp_dir, train):
-    datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
-    source_datasets = [[item[0], [item[1][0]]] for item in _ZHEN_TRAIN_DATASETS]
-    target_datasets = [[item[0], [item[1][1]]] for item in _ZHEN_TRAIN_DATASETS]
+    datasets = _ENZH_TRAIN_DATASETS if train else _ENZH_TEST_DATASETS
+    source_datasets = [[item[0], [item[1][0]]] for item in _ENZH_TRAIN_DATASETS]
+    target_datasets = [[item[0], [item[1][1]]] for item in _ENZH_TRAIN_DATASETS]
     source_vocab = generator_utils.get_or_generate_vocab(
         data_dir, tmp_dir, self.source_vocab_name, self.targeted_vocab_size,
         source_datasets)
@@ -79,21 +81,18 @@ def generator(self, data_dir, tmp_dir, train):
         target_datasets)
     tag = "train" if train else "dev"
     data_path = translate.compile_data(tmp_dir, datasets,
-                                       "wmt_zhen_tok_%s" % tag)
-    # We generate English->X data by convention, to train reverse translation
-    # just add the "_rev" suffix to the problem name, e.g., like this.
-    #   --problems=translate_enzh_wmt8k_rev
-    return translate.bi_vocabs_token_generator(data_path + ".lang2",
-                                               data_path + ".lang1",
+                                       "wmt_enzh_tok_%s" % tag)
+    return translate.bi_vocabs_token_generator(data_path + ".lang1",
+                                               data_path + ".lang2",
                                                source_vocab, target_vocab, EOS)
 
   @property
   def input_space_id(self):
-    return problem.SpaceID.ZH_TOK
+    return problem.SpaceID.EN_TOK
 
   @property
   def target_space_id(self):
-    return problem.SpaceID.EN_TOK
+    return problem.SpaceID.ZH_TOK
 
   def feature_encoders(self, data_dir):
     source_vocab_filename = os.path.join(data_dir, self.source_vocab_name)
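Note on the dataset-list convention used in the patch above: each entry pairs a downloadable archive URL with a (source_file, target_file) tuple, and the list comprehensions in generator() keep only one side of that tuple so a separate subword vocabulary can be built for English and for Chinese. The following is a minimal, self-contained sketch of that reshaping; the variable names and printed output are illustrative only and are not part of the patch.

# Illustrative sketch (not part of the patch): how the dataset spec is split
# into per-language lists before vocabulary generation.
ENZH_TRAIN_DATASETS = [[
    ("http://data.statmt.org/wmt17/translation-task/"
     "training-parallel-nc-v12.tgz"),          # archive to download
    ("training/news-commentary-v12.zh-en.en",  # source (English) file
     "training/news-commentary-v12.zh-en.zh"),  # target (Chinese) file
]]

# Keep only the English files for the source vocabulary...
source_datasets = [[item[0], [item[1][0]]] for item in ENZH_TRAIN_DATASETS]
# ...and only the Chinese files for the target vocabulary.
target_datasets = [[item[0], [item[1][1]]] for item in ENZH_TRAIN_DATASETS]

print(source_datasets)  # [[<archive url>, ['training/news-commentary-v12.zh-en.en']]]
print(target_datasets)  # [[<archive url>, ['training/news-commentary-v12.zh-en.zh']]]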
From 733de7b7535849195532540d98e7de031c8368ec Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Mon, 30 Oct 2017 16:49:55 +0100
Subject: [PATCH 2/2] typo fix

---
 tensor2tensor/data_generators/translate_enzh.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensor2tensor/data_generators/translate_enzh.py b/tensor2tensor/data_generators/translate_enzh.py
index 5bb5b01b1..6b0f36c23 100644
--- a/tensor2tensor/data_generators/translate_enzh.py
+++ b/tensor2tensor/data_generators/translate_enzh.py
@@ -63,7 +63,7 @@ def num_shards(self):
 
   @property
   def source_vocab_name(self):
-    return "vocab.en-zh-en.%d" % self.targeted_vocab_size
+    return "vocab.enzh-en.%d" % self.targeted_vocab_size
 
   @property
   def target_vocab_name(self):
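With both patches applied, the problem generates English-to-Chinese examples by default; per the convention noted in the removed comment, appending "_rev" to the problem name (translate_enzh_wmt8k_rev) trains the Chinese-to-English direction instead. For readers unfamiliar with the two-vocabulary generator that the first patch re-wires, here is a simplified sketch of what pairing the .lang1 (English) and .lang2 (Chinese) files with separate vocabularies amounts to. This is an assumption-laden illustration, not the actual translate.bi_vocabs_token_generator implementation; the vocab objects are only assumed to expose an encode() method, as tensor2tensor's SubwordTextEncoder does.

# Simplified sketch (illustrative, not the library implementation):
# read two line-aligned files, encode each side with its own vocabulary,
# and append an end-of-sentence id to both inputs and targets.
EOS_ID = 1  # assumed to mirror text_encoder.EOS_ID

def sketch_bi_vocab_token_generator(source_path, target_path,
                                    source_vocab, target_vocab, eos_id=EOS_ID):
  with open(source_path) as src_file, open(target_path) as tgt_file:
    for src_line, tgt_line in zip(src_file, tgt_file):
      yield {
          "inputs": source_vocab.encode(src_line.strip()) + [eos_id],
          "targets": target_vocab.encode(tgt_line.strip()) + [eos_id],
      }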