Skip to content

Commit

Permalink
tf_host_callback in tfrt/ifrt use DeviceMgr instead of StaticDeviceMgr
Browse files Browse the repository at this point in the history
for better generality and not owning the DeviceMgr since that
can be owned/created in fallback_request

PiperOrigin-RevId: 632264827
  • Loading branch information
deqiangc authored and tensorflower-gardener committed May 14, 2024
1 parent c4f3a1d commit de043dc
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 14 deletions.
1 change: 1 addition & 0 deletions tensorflow/core/tfrt/fallback/fallback_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class FallbackState {
const SessionOptions &session_options() const { return session_options_; }

const DeviceMgr &device_manager() const { return device_manager_; }
DeviceMgr &device_manager() { return device_manager_; }

const DeviceSet &device_set() const { return device_set_; }

Expand Down
10 changes: 4 additions & 6 deletions tensorflow/core/tfrt/ifrt/ifrt_model_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ class IfrtModelContext {
std::shared_ptr<xla::ifrt::Client> client,
IfrtServingCoreSelector* ifrt_serving_core_selector,
const tsl::thread::ThreadPool* thread_pool,
std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr,
tensorflow::DeviceMgr* device_mgr,
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn)
: client_(std::move(client)),
ifrt_serving_core_selector_(ifrt_serving_core_selector),
thread_pool_(*thread_pool),
device_mgr_(std::move(device_mgr)),
device_mgr_(device_mgr),
shape_representation_fn_(shape_representation_fn) {}

void RegisterHandle(ServingExecutableRegistry::Handle handle) {
Expand Down Expand Up @@ -100,9 +100,7 @@ class IfrtModelContext {
return restore_tensor_registry_;
}

tensorflow::StaticDeviceMgr* GetDeviceMgr() const {
return device_mgr_.get();
}
tensorflow::DeviceMgr* GetDeviceMgr() const { return device_mgr_; }
IfrtServingCoreSelector* GetIfrtServingCoreSelector() const {
return ifrt_serving_core_selector_;
}
Expand All @@ -127,7 +125,7 @@ class IfrtModelContext {
IfrtServingCoreSelector* ifrt_serving_core_selector_; // May be nullptr
const tsl::thread::ThreadPool& thread_pool_;

std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
tensorflow::DeviceMgr* device_mgr_; // Not owned.
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_ =
tensorflow::IdentityShapeRepresentationFn();

Expand Down
6 changes: 3 additions & 3 deletions tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ IfrtServingExecutable::Create(
IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry,
const IfrtRestoreTensorRegistry* ifrt_restore,
tfrt::ConcurrentWorkQueue* checkpoint_loader_queue,
tensorflow::StaticDeviceMgr* device_mgr,
tensorflow::DeviceMgr* device_mgr,
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
IfrtServingCoreSelector* ifrt_serving_core_selector) {
TF_ASSIGN_OR_RETURN(
Expand Down Expand Up @@ -237,7 +237,7 @@ GroupHostCallbackByKey(const Tf2HloResult& tf2hlo_result) {
// TODO: shape propagation in module
absl::StatusOr<xla::HostCallback> BuildHostCallback(
absl::string_view key, const HostCallbackBuilderInfo& builder_info,
mlir::ModuleOp module, tensorflow::StaticDeviceMgr* device_mgr,
mlir::ModuleOp module, tensorflow::DeviceMgr* device_mgr,
std::vector<std::unique_ptr<TfHostCallback>>& tf_host_callbacks) {
VLOG(2) << "BuildHostCallback for key: " << key;

Expand Down Expand Up @@ -310,7 +310,7 @@ absl::StatusOr<xla::HostCallback> BuildHostCallback(

absl::StatusOr<std::vector<xla::HostCallback>> BuildHostCallbacks(
const Tf2HloResult& tf2hlo_result, mlir::ModuleOp module,
tensorflow::StaticDeviceMgr* device_mgr,
tensorflow::DeviceMgr* device_mgr,
std::vector<std::unique_ptr<TfHostCallback>>& tf_host_callbacks) {
TF_ASSIGN_OR_RETURN(auto host_callback_maps,
GroupHostCallbackByKey(tf2hlo_result));
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class IfrtServingExecutable {
IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry,
const IfrtRestoreTensorRegistry* ifrt_restore,
tfrt::ConcurrentWorkQueue* checkpoint_loader_queue,
tensorflow::StaticDeviceMgr* device_mgr,
tensorflow::DeviceMgr* device_mgr,
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
IfrtServingCoreSelector* ifrt_serving_core_selector);

Expand Down Expand Up @@ -140,7 +140,7 @@ class IfrtServingExecutable {
IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry,
const IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry,
tfrt::ConcurrentWorkQueue* checkpoint_loader_queue,
tensorflow::StaticDeviceMgr* device_mgr,
tensorflow::DeviceMgr* device_mgr,
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn,
IfrtServingCoreSelector* ifrt_serving_core_selector,
tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata)
Expand Down Expand Up @@ -176,7 +176,7 @@ class IfrtServingExecutable {
IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry_;
const IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry_;
tfrt::ConcurrentWorkQueue* checkpoint_loader_queue_;
tensorflow::StaticDeviceMgr* device_mgr_; // Not owned. For host callback.
tensorflow::DeviceMgr* device_mgr_; // Not owned. For host callback.
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_;
IfrtServingCoreSelector* ifrt_serving_core_selector_;

Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/tfrt/ifrt/tf_host_callback.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ absl::StatusOr<std::unique_ptr<TfHostCallback>> TfHostCallback::Create(
absl::string_view entry_function_name,
absl::Span<const DtypeAndShape> operand_type_and_shapes,
absl::Span<const DtypeAndShape> result_type_and_shapes,
tensorflow::StaticDeviceMgr* device_mgr) {
tensorflow::DeviceMgr* device_mgr) {
tensorflow::SessionOptions options;
// Explicitly disable non-CPU devices to avoid triggering TPU device
// initialization inside TF.
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/tfrt/ifrt/tf_host_callback.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class TfHostCallback {
absl::string_view entry_function_name,
absl::Span<const DtypeAndShape> operand_type_and_shapes,
absl::Span<const DtypeAndShape> result_type_and_shapes,
tensorflow::StaticDeviceMgr* device_mgr);
tensorflow::DeviceMgr* device_mgr);

// The host callback function takes two pointer arrays, each element of which
// points to allocated host buffer in host layout according to corresponding
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/core/tfrt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/statusor.h"
Expand Down Expand Up @@ -79,6 +81,11 @@ class ModelRuntimeContext {
flib_def_ = flib_def;
}

tensorflow::DeviceMgr* device_mgr() const { return device_mgr_; }
void set_device_mgr(tensorflow::DeviceMgr* device_mgr) {
device_mgr_ = device_mgr;
}

bool is_local_session() const { return is_local_session_; }

void set_is_local_session(bool is_local_session) {
Expand All @@ -104,6 +111,7 @@ class ModelRuntimeContext {
const GraphDef* graph_def_ = nullptr;
const CallableOptions* callable_options_ = nullptr;
tfrt::ResourceContext* resource_context_ = nullptr;
tensorflow::DeviceMgr* device_mgr_ = nullptr;

FunctionLibraryDefinition* flib_def_ = nullptr;

Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/tfrt/saved_model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime:device_mgr",
"//tensorflow/core/framework:function_proto_cc",
"//tensorflow/core/framework:graph_proto_cc",
"//tensorflow/core/framework:tensor_proto_cc",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/tfrt/saved_model/saved_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,8 @@ absl::StatusOr<std::unique_ptr<SavedModel>> SavedModelImpl::LoadSavedModel(
CombineSignatureDefs(meta_graph_def.signature_def());
model_context.set_graph_def(&meta_graph_def.graph_def());
model_context.set_callable_options(&callable_options);
model_context.set_device_mgr(&fallback_state->device_manager());

TF_RETURN_IF_ERROR(
options.graph_execution_options.runtime->CreateRuntimeResources(
model_context));
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/tfrt/tfrt_session/tfrt_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License.
#include "Eigen/ThreadPool" // from @eigen_archive
#include "llvm/ADT/STLExtras.h"
#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/local_session_selection.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/session_factory.h"
Expand Down Expand Up @@ -218,6 +219,7 @@ class TfrtSession : public tensorflow::Session {
&options, /*export_dir=*/"unknown_export_dir", resource_context.get());
// TODO(b/334641254): Offer a Session option that prunes the graph_def.
model_context.set_graph_def(&graph);
model_context.set_device_mgr(&fallback_state->device_manager());
// In the multi-host case, this prevents local Sessions from running
// global resource creation functions.
model_context.set_is_local_session(
Expand Down

0 comments on commit de043dc

Please sign in to comment.