From 7b4b4b17b323d55be744e714fdbe3cb761001168 Mon Sep 17 00:00:00 2001 From: Noe Casas Date: Mon, 17 Jul 2017 11:47:34 +0200 Subject: [PATCH 1/2] Refactor user directory loading functionality and use it also from t2t-datagen --- tensor2tensor/bin/t2t-datagen | 2 ++ tensor2tensor/bin/t2t-trainer | 26 ++------------------ tensor2tensor/utils/usr_dir.py | 45 ++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 24 deletions(-) create mode 100644 tensor2tensor/utils/usr_dir.py diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index 44e4b34d3..dc2e08b22 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -48,6 +48,7 @@ from tensor2tensor.data_generators import wiki from tensor2tensor.data_generators import wmt from tensor2tensor.data_generators import wsj_parsing from tensor2tensor.utils import registry +from tensor2tensor.utils import usr_dir import tensorflow as tf @@ -273,6 +274,7 @@ def set_random_seed(): def main(_): tf.logging.set_verbosity(tf.logging.INFO) + usr_dir.import_usr_dir() # Calculate the list of problems to generate. problems = sorted( diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 322957028..9bdaf4be7 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -36,38 +36,16 @@ import sys # Dependency imports from tensor2tensor.utils import trainer_utils as utils - +from tensor2tensor.utils import usr_dir import tensorflow as tf flags = tf.flags FLAGS = flags.FLAGS -flags.DEFINE_string("t2t_usr_dir", "", - "Path to a Python module that will be imported. The " - "__init__.py file should include the necessary imports. " - "The imported files should contain registrations, " - "e.g. @registry.register_model calls, that will then be " - "available to the t2t-trainer.") - - -def import_usr_dir(): - """Import module at FLAGS.t2t_usr_dir, if provided.""" - if not FLAGS.t2t_usr_dir: - return - dir_path = os.path.expanduser(FLAGS.t2t_usr_dir) - if dir_path[-1] == "/": - dir_path = dir_path[:-1] - containing_dir, module_name = os.path.split(dir_path) - tf.logging.info("Importing user module %s from path %s", module_name, - containing_dir) - sys.path.insert(0, containing_dir) - importlib.import_module(module_name) - sys.path.pop(0) - def main(_): tf.logging.set_verbosity(tf.logging.INFO) - import_usr_dir() + usr_dir.import_usr_dir() utils.log_registry() utils.validate_flags() utils.run( diff --git a/tensor2tensor/utils/usr_dir.py b/tensor2tensor/utils/usr_dir.py new file mode 100644 index 000000000..d6ae5699d --- /dev/null +++ b/tensor2tensor/utils/usr_dir.py @@ -0,0 +1,45 @@ +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility to load code from an external directory supplied by user.""" + +import os +import sys +import importlib +import tensorflow as tf + +flags = tf.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("t2t_usr_dir", "", + "Path to a Python module that will be imported. The " + "__init__.py file should include the necessary imports. " + "The imported files should contain registrations, " + "e.g. @registry.register_model calls, that will then be " + "available to the t2t-trainer.") + + +def import_usr_dir(): + """Import module at FLAGS.t2t_usr_dir, if provided.""" + if not FLAGS.t2t_usr_dir: + return + dir_path = os.path.expanduser(FLAGS.t2t_usr_dir) + if dir_path[-1] == "/": + dir_path = dir_path[:-1] + containing_dir, module_name = os.path.split(dir_path) + tf.logging.info("Importing user module %s from path %s", module_name, + containing_dir) + sys.path.insert(0, containing_dir) + importlib.import_module(module_name) + sys.path.pop(0) From a1edf01ab8f988429e901c856f63eaf22219ed9d Mon Sep 17 00:00:00 2001 From: Noe Casas Date: Tue, 18 Jul 2017 20:23:02 +0200 Subject: [PATCH 2/2] Move flag declaration to the binary files --- tensor2tensor/bin/t2t-datagen | 9 ++++++++- tensor2tensor/bin/t2t-trainer | 8 +++++++- tensor2tensor/utils/usr_dir.py | 18 ++++-------------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index dc2e08b22..63eb7e45e 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -65,6 +65,13 @@ flags.DEFINE_integer("max_cases", 0, "Maximum number of cases to generate (unbounded if 0).") flags.DEFINE_integer("random_seed", 429459, "Random seed to use.") +flags.DEFINE_string("t2t_usr_dir", "", + "Path to a Python module that will be imported. The " + "__init__.py file should include the necessary imports. " + "The imported files should contain registrations, " + "e.g. @registry.register_model calls, that will then be " + "available to the t2t-datagen.") + # Mapping from problems that we can generate data for to their generators. # pylint: disable=g-long-lambda _SUPPORTED_PROBLEM_GENERATORS = { @@ -274,7 +281,7 @@ def set_random_seed(): def main(_): tf.logging.set_verbosity(tf.logging.INFO) - usr_dir.import_usr_dir() + usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) # Calculate the list of problems to generate. problems = sorted( diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 9bdaf4be7..6b3f4de71 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -42,10 +42,16 @@ import tensorflow as tf flags = tf.flags FLAGS = flags.FLAGS +flags.DEFINE_string("t2t_usr_dir", "", + "Path to a Python module that will be imported. The " + "__init__.py file should include the necessary imports. " + "The imported files should contain registrations, " + "e.g. @registry.register_model calls, that will then be " + "available to the t2t-trainer.") def main(_): tf.logging.set_verbosity(tf.logging.INFO) - usr_dir.import_usr_dir() + usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) utils.log_registry() utils.validate_flags() utils.run( diff --git a/tensor2tensor/utils/usr_dir.py b/tensor2tensor/utils/usr_dir.py index d6ae5699d..ed5623c8e 100644 --- a/tensor2tensor/utils/usr_dir.py +++ b/tensor2tensor/utils/usr_dir.py @@ -19,22 +19,12 @@ import importlib import tensorflow as tf -flags = tf.flags -FLAGS = flags.FLAGS -flags.DEFINE_string("t2t_usr_dir", "", - "Path to a Python module that will be imported. The " - "__init__.py file should include the necessary imports. " - "The imported files should contain registrations, " - "e.g. @registry.register_model calls, that will then be " - "available to the t2t-trainer.") - - -def import_usr_dir(): - """Import module at FLAGS.t2t_usr_dir, if provided.""" - if not FLAGS.t2t_usr_dir: +def import_usr_dir(usr_dir): + """Import user module, if provided.""" + if not usr_dir: return - dir_path = os.path.expanduser(FLAGS.t2t_usr_dir) + dir_path = os.path.expanduser(usr_dir) if dir_path[-1] == "/": dir_path = dir_path[:-1] containing_dir, module_name = os.path.split(dir_path)