Permalink
Browse files

use existing inter and intra flags, and fix wide deep test. (#5110)

  • Loading branch information...
robieta authored and k-w-w committed Aug 16, 2018
1 parent b64f67d commit 909ee1b3151ff643c1d54fd73c8fd8f98f193bb0
Showing with 6 additions and 9 deletions.
  1. +2 −1 official/wide_deep/census_test.py
  2. +4 −8 official/wide_deep/wide_deep_run_loop.py
@@ -95,7 +95,8 @@ def build_and_test_estimator(self, model_type):
"""Ensure that model trains and minimizes loss."""
model = census_main.build_estimator(
self.temp_dir, model_type,
model_column_fn=census_dataset.build_model_columns)
model_column_fn=census_dataset.build_model_columns,
inter_op=0, intra_op=0)

# Train for 1 step to initialize model and evaluate initial loss
def get_input_fn(num_epochs, shuffle, batch_size):
@@ -38,6 +38,10 @@ def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base()
flags_core.define_benchmark()
flags_core.define_performance(
num_parallel_calls=False, inter_op=True, intra_op=True,
synthetic_data=False, max_train_steps=False, dtype=False,
all_reduce_alg=False)

flags.adopt_module_key_flags(flags_core)

@@ -48,14 +52,6 @@ def define_wide_deep_flags():
flags.DEFINE_boolean(
name="download_if_missing", default=True, help=flags_core.help_wrap(
"Download data to data_dir if it is not already present."))
flags.DEFINE_integer(
name="inter_op_parallelism_threads", short_name="inter", default=0,
help="Number of threads to use for inter-op parallelism. "
"If left as default value of 0, the system will pick an appropriate number.")
flags.DEFINE_integer(
name="intra_op_parallelism_threads", short_name="intra", default=0,
help="Number of threads to use for intra-op parallelism. "
"If left as default value of 0, the system will pick an appropriate number.")


def export_model(model, model_type, export_dir, model_column_fn):

0 comments on commit 909ee1b

Please sign in to comment.