release 1.9-rc1 cherry-pick request: opt out of saving tpu graph #19925

Merged (3 commits) on Jun 12, 2018
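This change adds an `export_to_tpu` argument to the `TPUEstimator` constructor so callers can opt out of the TPU-tagged serving metagraph when exporting a SavedModel. A minimal usage sketch (not from this PR); `my_model_fn`, `my_run_config`, and `my_serving_input_receiver_fn` are hypothetical placeholders:

```python
import tensorflow as tf

# Hypothetical model_fn / RunConfig / serving input receiver; only the
# export_to_tpu flag below is what this PR actually adds.
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=my_model_fn,
    config=my_run_config,          # a tf.contrib.tpu.RunConfig
    train_batch_size=1024,
    export_to_tpu=False)           # opt out of the SERVING+TPU metagraph

# With export_to_tpu=False, export_savedmodel() writes only the CPU
# serving graph; the TPU-tagged metagraph is skipped entirely.
estimator.export_savedmodel('/tmp/export', my_serving_input_receiver_fn)
```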
20 changes: 11 additions & 9 deletions tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1807,7 +1807,7 @@ def host_call(logits):
         export_outputs['classes'] =
           export_output_lib.ClassificationOutput(classes=classes)
 
-      tpu.outside_compilation(host_call, logits)
+      tpu.outside_compilation(host_call, [logits])
 
       ...
       ```
@@ -1830,6 +1830,7 @@ def __init__(self,
                predict_batch_size=None,
                batch_axis=None,
                eval_on_tpu=True,
+               export_to_tpu=True,
                warm_start_from=None):
     """Constructs an `TPUEstimator` instance.

@@ -1872,6 +1873,8 @@ def __init__(self,
         False or `PER_HOST_V2`, batch_axis is ignored.
       eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
         model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
+      export_to_tpu: If True, `export_savedmodel()` exports a metagraph for
+        serving on TPU besides the one on CPU.
       warm_start_from: Optional string filepath to a checkpoint or SavedModel to
                        warm-start from, or a `tf.estimator.WarmStartSettings`
                        object to fully configure warm-starting. If the string
@@ -1943,6 +1946,8 @@ def __init__(self,
                                                  use_tpu,
                                                  eval_on_tpu)
 
+    self._export_to_tpu = export_to_tpu
+
     self._is_input_fn_invoked = None
 
   def _add_meta_graph_for_mode(self,
@@ -1965,11 +1970,11 @@ def _add_meta_graph_for_mode(self,
                                save_variables,
                                mode=mode)
 
-    input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE:
-                             input_receiver_fn_map[mode]}
-    export_tags = [tag_constants.SERVING, tag_constants.TPU]
-    mode = _REWRITE_FOR_INFERENCE_MODE
-    try:
+    if self._export_to_tpu:
+      input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE:
+                               input_receiver_fn_map[mode]}
+      export_tags = [tag_constants.SERVING, tag_constants.TPU]
+      mode = _REWRITE_FOR_INFERENCE_MODE
       (super(TPUEstimator, self).
        _add_meta_graph_for_mode(builder,
                                 input_receiver_fn_map,
@@ -1978,9 +1983,6 @@ def _add_meta_graph_for_mode(self,
                                 save_variables=False,
                                 mode=mode,
                                 export_tags=export_tags))
-    except Exception as error:  # pylint: disable=broad-except
-      logging.warning('Saving meta graph for TPU failed: {}.'
-                      .format(str(error)))
 
   def _call_model_fn(self, features, labels, mode, config):
     if mode == _REWRITE_FOR_INFERENCE_MODE:
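Note that the last hunk also drops the broad try/except, so with `export_to_tpu=True` a failure while building the TPU serving graph now propagates out of `export_savedmodel()` instead of being logged and swallowed. A caller who wants the old lenient behavior can recover it at the call site; a hedged sketch, reusing the hypothetical names from the example above:

```python
import tensorflow as tf

try:
  estimator.export_savedmodel('/tmp/export', my_serving_input_receiver_fn)
except Exception as error:  # broad catch mirrors the removed handler
  tf.logging.warning('Saving meta graph for TPU failed: %s', error)
  # Retry with TPU export disabled so at least the CPU graph is saved.
  cpu_only = tf.contrib.tpu.TPUEstimator(
      model_fn=my_model_fn,
      config=my_run_config,
      train_batch_size=1024,
      export_to_tpu=False)
  cpu_only.export_savedmodel('/tmp/export', my_serving_input_receiver_fn)
```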