Skip to content

Commit

Permalink
Add option to enable tfrt eager context
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 303254195
Change-Id: Ibee9c3a9cb4f0abf2e1738ed09c7a9ec326b5b64
  • Loading branch information
jaingaurav authored and tensorflower-gardener committed Mar 27, 2020
1 parent 269887e commit 857f0c9
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorflow/c/eager/c_api_experimental.cc
Expand Up @@ -499,6 +499,10 @@ void TFE_ContextOptionsSetLazyRemoteInputsCopy(TFE_ContextOptions* options,
options->lazy_remote_inputs_copy = lazy_copy;
}

void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
options->use_tfrt = use_tfrt;
}

TFE_CancellationManager* TFE_NewCancellationManager() {
return new TFE_CancellationManager;
}
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/c/eager/c_api_experimental.h
Expand Up @@ -296,6 +296,10 @@ TF_CAPI_EXPORT extern TFE_ContextMirroringPolicy TFE_ContextGetMirroringPolicy(
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy(
TFE_ContextOptions*, bool lazy_copy);

// Sets whether to use TFRT
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
bool use_tfrt);

// -----------------------------------------------------------------------------
// Cancellation APIs.

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/c/eager/c_api_internal.h
Expand Up @@ -60,6 +60,8 @@ struct TFE_ContextOptions {
TFE_ContextMirroringPolicy mirroring_policy{TFE_MIRRORING_NONE};
// If true, lazily copy the remote inputs of a function to the target devices.
bool lazy_remote_inputs_copy = true;
// If true, use TFRT backend
bool use_tfrt = false;
};

struct TFE_Context {
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/python/eager/context.py
Expand Up @@ -411,6 +411,7 @@ def __init__(self,
execution_mode = SYNC
self._default_is_async = execution_mode == ASYNC
self._lazy_remote_inputs_copy = None
self._use_tfrt = None
self._server_def = server_def
self._collective_ops_server_def = None
self._collective_leader = None
Expand Down Expand Up @@ -514,6 +515,8 @@ def ensure_initialized(self):
if self._lazy_remote_inputs_copy is not None:
pywrap_tfe.TFE_ContextOptionsSetLazyRemoteInputsCopy(
opts, self._lazy_remote_inputs_copy)
if self._use_tfrt is not None:
pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt)
context_handle = pywrap_tfe.TFE_NewContext(opts)
finally:
pywrap_tfe.TFE_DeleteContextOptions(opts)
Expand Down Expand Up @@ -1565,6 +1568,21 @@ def lazy_remote_inputs_copy(self, lazy_copy):
"lazy_remote_inputs_copy should be set before being initialized.")
self._lazy_remote_inputs_copy = lazy_copy

@property
def use_tfrt(self):
return self._use_tfrt

@use_tfrt.setter
def use_tfrt(self, tfrt):
"""Sets whether to use TFRT."""
if not isinstance(tfrt, bool):
raise ValueError("Expecting a boolean but got %s" % type(tfrt))

if self._use_tfrt != tfrt:
if self._initialized:
raise ValueError("use_tfrt should be set before being initialized.")
self._use_tfrt = tfrt

def enable_run_metadata(self):
"""Enables tracing of op execution via RunMetadata.
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/tfe_wrapper.cc
Expand Up @@ -713,6 +713,7 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
&TFE_ContextOptionsSetDevicePlacementPolicy);
m.def("TFE_ContextOptionsSetLazyRemoteInputsCopy",
&TFE_ContextOptionsSetLazyRemoteInputsCopy);
m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
m.def("TFE_ContextOptionsSetMirroringPolicy",
&TFE_ContextOptionsSetMirroringPolicy);
m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
Expand Down

0 comments on commit 857f0c9

Please sign in to comment.