Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion official/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ def __init__(self):
super(MNISTArgParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.ImageModelParser(),
parsers.ExportParser(),
])

self.set_defaults(
Expand Down
1 change: 0 additions & 1 deletion official/resnet/resnet_run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,6 @@ def __init__(self, resnet_size_choices=None):
parsers.BaseParser(),
parsers.PerformanceParser(),
parsers.ImageModelParser(),
parsers.ExportParser(),
parsers.BenchmarkParser(),
])

Expand Down
35 changes: 11 additions & 24 deletions official/utils/arg_parsers/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ class BaseParser(argparse.ArgumentParser):
batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging.
export_dir: Create a flag to specify where a SavedModel should be exported.
"""

def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_between_evals=True,
stop_threshold=True, batch_size=True, multi_gpu=True,
hooks=True):
hooks=True, export_dir=True):
super(BaseParser, self).__init__(add_help=add_help)

if data_dir:
Expand Down Expand Up @@ -176,6 +177,15 @@ def __init__(self, add_help=False, data_dir=True, model_dir=True,
metavar="<HK>"
)

if export_dir:
self.add_argument(
"--export_dir", "-ed",
help="[default: %(default)s] If set, a SavedModel serialization of "
"the model will be exported to this directory at the end of "
"training. See the README for more details and relevant links.",
metavar="<ED>"
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we provide a default directory if user provides none? Not sure if this feature is used frequently by users. If so, I think a default dir may be good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering that the functionality of SavedModel is currently limited to serving, I don't think this feature is frequently used.



class PerformanceParser(argparse.ArgumentParser):
"""Default parser for specifying performance tuning arguments.
Expand Down Expand Up @@ -292,29 +302,6 @@ def __init__(self, add_help=False, data_format=True):
)


class ExportParser(argparse.ArgumentParser):
  """Parser providing flags for exporting saved models or other graph defs.

  Kept as a standalone parser for the moment; it should be folded into
  BaseParser once all models are brought up to speed.

  Args:
    add_help: Create the "--help" flag. False if class instance is a parent.
    export_dir: Create a flag to specify where a SavedModel should be exported.
  """

  def __init__(self, add_help=False, export_dir=True):
    super(ExportParser, self).__init__(add_help=add_help)
    # Each flag is gated by its constructor switch so parents can opt out.
    if not export_dir:
      return
    self.add_argument(
        "--export_dir", "-ed",
        help="[default: %(default)s] If set, a SavedModel serialization of "
             "the model will be exported to this directory at the end of "
             "training. See the README for more details and relevant links.",
        metavar="<ED>")


class BenchmarkParser(argparse.ArgumentParser):
"""Default parser for benchmark logging.

Expand Down
31 changes: 31 additions & 0 deletions official/wide_deep/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,37 @@ Run TensorBoard to inspect the details about the graph and training progression.
tensorboard --logdir=/tmp/census_model
```

## Inference with SavedModel
You can export the model into the TensorFlow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the argument `--export_dir`:

```
python wide_deep.py --export_dir /tmp/wide_deep_saved_model
```

After the model finishes training, use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.

Try the following commands to inspect the SavedModel:

**Replace `${TIMESTAMP}` with the folder produced (e.g. 1524249124)**
```
# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/

# Show SignatureDefs for tag_set=serve. SignatureDefs define the inputs and outputs of the model.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --all
```

### Inference
Let's use the model to predict the income group of two examples:
```
saved_model_cli run --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --signature_def="predict" \
--input_examples='examples=[{"age":[46.], "education_num":[10.], "capital_gain":[7688.], "capital_loss":[0.], "hours_per_week":[38.]}, {"age":[24.], "education_num":[13.], "capital_gain":[0.], "capital_loss":[0.], "hours_per_week":[50.]}]'
```

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great, thanks.

This will print out the predicted classes and class probabilities. Class 0 is the <=50k group and 1 is the >50k group.

## Additional Links

If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).
Expand Down
25 changes: 25 additions & 0 deletions official/wide_deep/wide_deep.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,27 @@ def parse_csv(value):
return dataset


def export_model(model, model_type, export_dir):
  """Export the trained Estimator in SavedModel format.

  Args:
    model: Estimator object to serialize.
    model_type: string indicating model type. "wide", "deep" or "wide_deep";
      selects which feature columns define the serving input signature.
    export_dir: directory to export the model.
  """
  wide_columns, deep_columns = build_model_columns()
  # Pick the columns matching the trained variant; anything other than the
  # two single-tower names falls through to the combined wide+deep set.
  columns_by_type = {
      'wide': wide_columns,
      'deep': deep_columns,
  }
  columns = columns_by_type.get(model_type, wide_columns + deep_columns)
  feature_spec = tf.feature_column.make_parse_example_spec(columns)
  serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
      feature_spec)
  model.export_savedmodel(export_dir, serving_input_fn)


def main(argv):
parser = WideDeepArgParser()
flags = parser.parse_args(args=argv[1:])
Expand Down Expand Up @@ -216,6 +237,10 @@ def eval_input_fn():
flags.stop_threshold, results['accuracy']):
break

# Export the model
if flags.export_dir is not None:
export_model(model, flags.model_type, flags.export_dir)


class WideDeepArgParser(argparse.ArgumentParser):
"""Argument parser for running the wide deep model."""
Expand Down