Skip to content

Commit

Permalink
Merge pull request #53515 from skye:jitstate
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 417898285
Change-Id: Id6a2c0c48482e1186e46022729f84a858d5bfe1a
  • Loading branch information
tensorflower-gardener committed Dec 23, 2021
2 parents dc8b97e + 7ba7339 commit ebdbd61
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 82 deletions.
83 changes: 41 additions & 42 deletions tensorflow/compiler/xla/python/jax_jit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,39 @@ namespace py = pybind11;
namespace {

// Protected by the GIL.
GlobalJitState& global_state = *new GlobalJitState();
JitState& global_state = *new JitState();

// TODO(phawkins): Google style guide forbids thread-local values with
// non-trivial destructors.
ABSL_CONST_INIT thread_local ThreadLocalJitState thread_local_state; // NOLINT

bool JitIsDisabled() {
return thread_local_state.disable_jit.value_or(global_state.disable_jit);
}
ABSL_CONST_INIT thread_local JitState thread_local_state; // NOLINT

} // namespace

GlobalJitState& GetGlobalState() { return global_state; }
ThreadLocalJitState& GetLocalState() { return thread_local_state; }
JitState& GetGlobalState() { return global_state; }
JitState& GetLocalState() { return thread_local_state; }

bool GetDisableJit() {
CHECK(global_state.disable_jit.has_value());
return thread_local_state.disable_jit.value_or(*global_state.disable_jit);
}

bool GetEnableX64() {
return thread_local_state.enable_x64.value_or(global_state.enable_x64);
CHECK(global_state.enable_x64.has_value());
return thread_local_state.enable_x64.value_or(*global_state.enable_x64);
}

absl::optional<pybind11::function> GetPostHook() {
return thread_local_state.post_hook.has_value() ? thread_local_state.post_hook
: global_state.post_hook;
}

static std::string OptionalDebugString(
const absl::optional<py::object> optional) {
if (optional.has_value()) {
return py::cast<std::string>(py::str(optional.value()));
} else {
return "None";
}
}

std::string CallSignature::DebugString() const {
Expand All @@ -103,13 +119,6 @@ std::string CallSignature::DebugString() const {
const xla::PyArgSignature& s) {
out->append(s.DebugString());
};
std::string thread_local_extra_jit_context_str;
if (thread_local_extra_jit_context.has_value()) {
thread_local_extra_jit_context_str =
py::cast<std::string>(py::str(thread_local_extra_jit_context.value()));
} else {
thread_local_extra_jit_context_str = "None";
}
return absl::StrFormat(
"static args (positional + keyword): %s\nstatic arg keyword names: %s\n"
"dynamic arg signatures (positional + keyword): %s\n"
Expand All @@ -123,9 +132,8 @@ std::string CallSignature::DebugString() const {
absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter),
absl::StrJoin(dynamic_arg_names, ",", py_object_formatter),
absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter), // new line
device, jax_enable_x64,
py::cast<std::string>(py::str(global_extra_jit_context)),
thread_local_extra_jit_context_str);
device, jax_enable_x64, OptionalDebugString(global_extra_jit_context),
OptionalDebugString(thread_local_extra_jit_context));
}

bool CallSignature::operator==(const CallSignature& other) const {
Expand All @@ -152,7 +160,10 @@ bool CallSignature::operator==(const CallSignature& other) const {
". The error was:\n", e.what()));
}
}) &&
global_extra_jit_context.equal(other.global_extra_jit_context) &&
(global_extra_jit_context.has_value() ==
other.global_extra_jit_context.has_value()) &&
(!global_extra_jit_context.has_value() ||
global_extra_jit_context->equal(*other.global_extra_jit_context)) &&
(thread_local_extra_jit_context.has_value() ==
other.thread_local_extra_jit_context.has_value()) &&
(!thread_local_extra_jit_context.has_value() ||
Expand Down Expand Up @@ -784,7 +795,7 @@ xla::StatusOr<py::object> CompiledFunction::Call(
xla::GlobalPyRefManager()->MaybeCollectGarbage();

auto& tls = thread_local_state;
if (tls.disable_jit.value_or(global_state.disable_jit)) {
if (GetDisableJit()) {
return fun_(*py::reinterpret_borrow<py::args>(args),
**kwargs.value_or(py::kwargs()));
}
Expand Down Expand Up @@ -817,7 +828,7 @@ xla::StatusOr<py::object> CompiledFunction::Call(
**kwargs.value_or(py::kwargs())))[0]);
}

bool jax_enable_x64 = tls.enable_x64.value_or(global_state.enable_x64);
bool jax_enable_x64 = GetEnableX64();
arguments.signature.jax_enable_x64 = jax_enable_x64;
// The C++ jit do not support Tracers arguments inputs yet. The Python-based
// jit function will be called if any of the dynamic arguments is unsupported.
Expand Down Expand Up @@ -930,8 +941,7 @@ xla::StatusOr<py::object> CompiledFunction::Call(
py::object out = cache_entry->out_pytree_def.Unflatten(flat_device_arrays);

// If there is a post-hook function, call it with the inputs and the outputs.
absl::optional<py::object> post_hook =
tls.post_hook.has_value() ? tls.post_hook : global_state.post_hook;
absl::optional<py::object> post_hook = GetPostHook();
if (post_hook) {
(*post_hook)(AsPyHandle(), args,
py::cast<py::dict>(kwargs.value_or(py::kwargs())), out);
Expand Down Expand Up @@ -1281,23 +1291,12 @@ void BuildJaxjitSubmodule(py::module& m) {
std::move(donate_argnums), std::move(cache));
},
py::is_method(cfun_type));
py::class_<GlobalJitState> global_state_(jitlib, "GlobalJitState");
global_state_.def_readwrite("disable_jit", &GlobalJitState::disable_jit);
global_state_.def_readwrite("enable_x64", &GlobalJitState::enable_x64);
global_state_.def_readwrite("extra_jit_context",
&GlobalJitState::extra_jit_context);
global_state_.def_readwrite("post_hook", &GlobalJitState::post_hook);

py::class_<ThreadLocalJitState> thread_local_state_(jitlib,
"ThreadLocalJitState");
thread_local_state_.def_readwrite("disable_jit",
&ThreadLocalJitState::disable_jit);
thread_local_state_.def_readwrite("enable_x64",
&ThreadLocalJitState::enable_x64);
thread_local_state_.def_readwrite("extra_jit_context",
&ThreadLocalJitState::extra_jit_context);
thread_local_state_.def_readwrite("post_hook",
&ThreadLocalJitState::post_hook);

py::class_<JitState> jit_state_(jitlib, "JitState");
jit_state_.def_readwrite("disable_jit", &JitState::disable_jit);
jit_state_.def_readwrite("enable_x64", &JitState::enable_x64);
jit_state_.def_readwrite("extra_jit_context", &JitState::extra_jit_context);
jit_state_.def_readwrite("post_hook", &JitState::post_hook);

jitlib.def(
"global_state", [&]() { return &global_state; },
Expand All @@ -1306,7 +1305,7 @@ void BuildJaxjitSubmodule(py::module& m) {
"thread_local_state", [&]() { return &thread_local_state; },
py::return_value_policy::reference);

jitlib.def("jit_is_disabled", &JitIsDisabled);
jitlib.def("jit_is_disabled", &GetDisableJit);
jitlib.def("get_enable_x64", &GetEnableX64);

// TODO(phawkins): delete the following methods after dropping compatibility
Expand Down
43 changes: 19 additions & 24 deletions tensorflow/compiler/xla/python/jax_jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,42 +37,37 @@ namespace jax {
// - possibly a thread-local value, which initially is absl::nullopt and
// overrides the global value if set. The thread-local state is
// used to implement context managers that locally override the global state.
// TODO(phawkins): consider changing the global state to optional types to
// catch cases where we fail to set it.
struct GlobalJitState {
bool disable_jit = false;
bool enable_x64 = false;

// Extra context that should be included in the JIT cache key. Must be
// hashable and have an equality defined.
pybind11::object extra_jit_context = pybind11::none();

// A callback that, if present, is called when a JITted function is executed
// from cache.
absl::optional<pybind11::function> post_hook;
};

struct ThreadLocalJitState {
~ThreadLocalJitState() {
struct JitState {
~JitState() {
if (extra_jit_context) {
// We likely do not hold the GIL, so we hand the Python object to the
// global reference manager to destroy.
// We likely do not hold the GIL if this JitState is thread-local, so we
// hand the Python object to the global reference manager to destroy.
pybind11::object o = std::move(*extra_jit_context);
xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1));
extra_jit_context = absl::nullopt;
}
}

absl::optional<bool> disable_jit;
absl::optional<bool> enable_x64;

// Extra context that should be included in the JIT cache key. Must be
// hashable and have an equality defined.
absl::optional<pybind11::object> extra_jit_context;

// A callback that, if present, is called when a JITted function is executed
// from cache. May be unset even in global state.
absl::optional<pybind11::function> post_hook;
};

// Returns the value for jax_enable_x64 (defined by a thread-local value if
// defined, defaulting to the value of the flag otherwise).
JitState& GetGlobalState();
JitState& GetLocalState();

// Getters for JitState fields that first look in thread-local state, then
// fallback to global state.
bool GetDisableJit();
bool GetEnableX64();
GlobalJitState& GetGlobalState();
ThreadLocalJitState& GetLocalState();
absl::optional<pybind11::function> GetPostHook();

// The signature of Python jitted function call, partitioned into:
// - dynamic positional arguments (i.e. positional args which are not static)
Expand Down Expand Up @@ -112,7 +107,7 @@ struct CallSignature {
bool jax_enable_x64;

// Opaque additional context that should be included as part of the cache key.
pybind11::object global_extra_jit_context;
absl::optional<pybind11::object> global_extra_jit_context;
absl::optional<pybind11::object> thread_local_extra_jit_context;

bool operator==(const CallSignature& other) const;
Expand Down
9 changes: 4 additions & 5 deletions tensorflow/compiler/xla/python/pmap_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,9 @@ xla::StatusOr<py::object> PmapFunction::Call(py::args args, py::kwargs kwargs) {
}

// Get dynamic argument signatures.
GlobalJitState& global_state = jax::GetGlobalState();
ThreadLocalJitState& tls = jax::GetLocalState();
const bool jax_enable_x64 = tls.enable_x64.value_or(global_state.enable_x64);
JitState& global_state = jax::GetGlobalState();
JitState& tls = jax::GetLocalState();
const bool jax_enable_x64 = GetEnableX64();
arguments.signature.jax_enable_x64 = jax_enable_x64;
for (py::handle arg : arguments.flat_dynamic_args) {
auto signature_or_error = xla::PyArgSignatureOfValue(arg, jax_enable_x64);
Expand Down Expand Up @@ -530,8 +530,7 @@ xla::StatusOr<py::object> PmapFunction::Call(py::args args, py::kwargs kwargs) {
cache_entry.out_pytree_def.Unflatten(flat_sharded_device_arrays);

// If there is a post-hook function, call it with the inputs and the outputs.
absl::optional<py::object> post_hook =
tls.post_hook.has_value() ? tls.post_hook : global_state.post_hook;
absl::optional<py::object> post_hook = GetPostHook();
if (post_hook) {
(*post_hook)(this->AsPyHandle(), args, kwargs, out);
}
Expand Down
18 changes: 7 additions & 11 deletions tensorflow/compiler/xla/python/xla_extension/jax_jit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,16 @@ Client = xla_extension.Client
CompiledFunctionCache = xla_extension.CompiledFunctionCache
CompiledFunction = xla_extension.CompiledFunction

class GlobalJitState:
disable_jit: bool
enable_x64: bool
class JitState:
disable_jit: Optional[bool]
enable_x64: Optional[bool]
extra_jit_context: Any
post_hook: Optional[Callable]

class ThreadLocalJitState:
disable_jit: bool
enable_x64: bool
extra_jit_context: Any

def global_state() -> GlobalJitState: ...
def thread_local_state() -> ThreadLocalJitState: ...
def global_state() -> JitState: ...
def thread_local_state() -> JitState: ...

def jit_is_enabled() -> bool: ...
def jit_is_disabled() -> bool: ...
def get_enable_x64() -> bool: ...

def set_disable_jit_cpp_flag(__arg: bool) -> None: ...
Expand Down

0 comments on commit ebdbd61

Please sign in to comment.