Use TPUInferenceContext in TPUEstimator ExportV2 API.
PiperOrigin-RevId: 263059884
tensorflower-gardener committed Aug 13, 2019
1 parent c9abf8d commit b0ec76e
Showing 1 changed file with 23 additions and 4 deletions.
tensorflow_estimator/python/estimator/tpu/tpu_estimator.py
@@ -1993,8 +1993,17 @@ def _call_model_fn(self, features, labels, is_export_mode=False):
     if (running_on_cpu and
         isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)):  # pylint: disable=protected-access
       # The estimator_spec will be passed to `Estimator` directly, which expects
-      # type `EstimatorSpec`.
-      return estimator_spec.as_estimator_spec()
+      # type `EstimatorSpec`. As we are running on the CPU, escape
+      # the TPUInferenceContext.
+      graph_context = ops.get_default_graph()._get_control_flow_context()
+      try:
+        if isinstance(graph_context, tpu._TPUInferenceContext):
+          ops.get_default_graph()._set_control_flow_context(
+              graph_context.outer_context)
+        return estimator_spec.as_estimator_spec()
+      finally:
+        ops.get_default_graph()._set_control_flow_context(
+            graph_context)
     else:
       return estimator_spec
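
In this first hunk, the CPU path can still be executing under the `_TPUInferenceContext` pushed while the serving graph was built, so before converting the internal `_TPUEstimatorSpec` into the public `EstimatorSpec`, the code saves the graph's current control-flow context, swaps in its `outer_context`, and restores the saved context in `finally` whether or not `as_estimator_spec()` raises. The same save/escape/restore pattern can be packaged as a context manager; the sketch below is illustrative only (the name `_escape_tpu_inference_context` is hypothetical, and the imports assume the TF 1.x-era internal modules that `ops` and `tpu` refer to in this file):

import contextlib

from tensorflow.python.framework import ops
from tensorflow.python.tpu import tpu


@contextlib.contextmanager
def _escape_tpu_inference_context():
  """Hypothetical helper: runs its body outside any _TPUInferenceContext."""
  graph = ops.get_default_graph()
  saved = graph._get_control_flow_context()  # pylint: disable=protected-access
  try:
    if isinstance(saved, tpu._TPUInferenceContext):  # pylint: disable=protected-access
      # Step out to the enclosing context so new ops are plain CPU ops.
      graph._set_control_flow_context(saved.outer_context)
    yield
  finally:
    # Always restore the saved context, even if the body raised.
    graph._set_control_flow_context(saved)

With such a helper, the CPU branch would collapse to `with _escape_tpu_inference_context(): return estimator_spec.as_estimator_spec()`, trading the inline try/finally for a reusable wrapper.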

@@ -2845,8 +2854,18 @@ def _call_model_fn(self, features, labels, mode, config):
       return super(TPUEstimator, self)._call_model_fn(features, labels, mode,
                                                       config)
     else:
-      return super(TPUEstimator, self)._call_model_fn(features, labels, mode,
-                                                      config)
+      if mode == _INFERENCE_ON_TPU_MODE:
+        context = tpu._TPUInferenceContext("tpu_inference")
+        try:
+          context.Enter()
+          result = super(TPUEstimator, self)._call_model_fn(
+              features, labels, mode, config)
+        finally:
+          context.Exit()
+        return result
+      else:
+        return super(TPUEstimator, self)._call_model_fn(
+            features, labels, mode, config)

   def _call_model_fn_for_inference(self, features, labels, mode, config):
     """Wraps `_call_model_fn` for `export_saved_model`."""
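
The second hunk is the other half of the change: when the model function is called in `_INFERENCE_ON_TPU_MODE`, the call is now bracketed by `Enter()`/`Exit()` on a freshly created `_TPUInferenceContext`, so every op the model function adds to the graph is attached to that context, and the `try`/`finally` keeps the graph's control-flow stack balanced even if the model function raises. Below is a minimal standalone sketch of the same bracketing; the helper name `build_tpu_inference_graph` and the `model_fn(features)` signature are assumptions, not part of this commit:

from tensorflow.python.tpu import tpu


def build_tpu_inference_graph(model_fn, features):
  """Hypothetical wrapper: builds model_fn's ops under a _TPUInferenceContext."""
  # pylint: disable=protected-access
  context = tpu._TPUInferenceContext("tpu_inference")
  try:
    # Ops created between Enter() and Exit() belong to this context and can
    # be checked against TPU inference constraints as they are built.
    context.Enter()
    outputs = model_fn(features)
  finally:
    # Exit unconditionally so the context stack is balanced on error paths.
    context.Exit()
  return outputs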
