Store variable's dtype and shape in IfrtRestoreTensorRegistry and store the `IfrtRestoreTensorRegistry` in IfrtServingExecutable for looking up the dtype and shape.

PiperOrigin-RevId: 625527802
SiqiaoWu1993 authored and tensorflower-gardener committed Apr 17, 2024
1 parent 33cadc4 commit e0d6269
Showing 15 changed files with 147 additions and 127 deletions.
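The net effect of the change: each entry in the restore tensor registry now pairs the restored tensor future with its DtypeAndShape, and IfrtServingExecutable holds a pointer to that registry so it can resolve a variable's dtype and shape by runtime name rather than reading them off the VarHandle resource tensor. A rough sketch of the lookup side (the `registry` instance and the key "model__var_x" are illustrative, not taken from the diff):

  // Sketch only; assumes a registry that some restore op has already populated.
  absl::StatusOr<tensorflow::ifrt_serving::DtypeAndShape> dtype_and_shape =
      registry.GetDtypeAndShape("model__var_x");   // metadata, available immediately
  xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>> tensor_future =
      registry.GetRestoredTensor("model__var_x");  // value, resolves after restore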
@@ -92,6 +92,7 @@ CompileAndRegisterIfrtPrograms(absl::string_view model_name,
model_name, entry_function_name.str(), *std::move(submodule),
ifrt_model_context.GetClient(), &ifrt_model_context.GetThreadPool(),
&ifrt_model_context.GetLoadedVariableRegistry(),
&ifrt_model_context.GetRestoreTensorRegistry(),
ifrt_model_context.GetDeviceMgr(),
ifrt_model_context.GetShapeRepresentationFn());

12 changes: 6 additions & 6 deletions tensorflow/core/tfrt/ifrt/BUILD
@@ -53,6 +53,7 @@ cc_library(
hdrs = ["ifrt_serving_executable.h"],
deps = [
":ifrt_loaded_variable_registry",
":ifrt_restore_tensor_registry",
":ifrt_tensor_utils",
":sharding_utils",
":tf_host_callback",
@@ -66,8 +67,6 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
@@ -118,7 +117,10 @@ cc_library(
srcs = ["ifrt_restore_tensor_registry.cc"],
hdrs = ["ifrt_restore_tensor_registry.h"],
deps = [
"//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types",
"//tensorflow/core:framework",
"//tensorflow/core/framework:tensor",
"//tensorflow/core/framework:types_proto_cc",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
@@ -240,8 +242,6 @@ cc_library(
":sharding_utils",
"//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types",
"//tensorflow/core:framework",
"//tensorflow/core/tfrt/mlrt/interpreter:future",
"//tensorflow/core/tfrt/utils:fallback_tensor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
@@ -300,8 +300,6 @@ tf_cc_test(
"//tensorflow/core/framework:tensor_matcher",
"//tensorflow/core/framework:tensor_testutil",
"//tensorflow/core/framework:types_proto_cc",
"//tensorflow/core/tfrt/mlrt/interpreter:future",
"//tensorflow/core/tfrt/utils:fallback_tensor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_googletest//:gtest_main",
@@ -389,6 +387,7 @@ tf_cc_test(
tags = ["no_oss"],
deps = [
":ifrt_loaded_variable_registry",
":ifrt_restore_tensor_registry",
":ifrt_serving_executable",
":sharding_utils",
":tf_host_callback",
@@ -440,6 +439,7 @@ tf_cc_test(
deps = [
":ifrt_executable_registry",
":ifrt_loaded_variable_registry",
":ifrt_restore_tensor_registry",
":ifrt_serving_executable",
":tf_host_callback",
"//tensorflow/compiler/mlir/tensorflow",
6 changes: 4 additions & 2 deletions tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/resource_loader.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h"
#include "tensorflow/core/tfrt/ifrt/tf_host_callback.h"
#include "tsl/platform/env.h"
Expand Down Expand Up @@ -79,13 +80,14 @@ CreateIfrtServingExecutable(mlir::MLIRContext& context) {
xla::ifrt::test_util::GetClient());

IfrtLoadedVariableRegistry ifrt_loaded_variable_registry;
IfrtRestoreTensorRegistry ifrt_restore_tensor_registry;
TF_ASSIGN_OR_RETURN(std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr,
CreateTfStaticDeviceMgr());

return std::make_unique<IfrtServingExecutable>(
"test", "main", std::move(mlir_module), client, &GetThreadPool(),
&ifrt_loaded_variable_registry, device_mgr.get(),
tensorflow::IdentityShapeRepresentationFn());
&ifrt_loaded_variable_registry, &ifrt_restore_tensor_registry,
device_mgr.get(), tensorflow::IdentityShapeRepresentationFn());
}

TEST(IfrtExecutableRegistry, Basic) {
1 change: 0 additions & 1 deletion tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h
@@ -34,7 +34,6 @@ namespace ifrt_serving {
class IfrtLoadedVariableRegistry {
public:
struct LoadedVariable {
DtypeAndShape dtype_and_shape;
xla::ifrt::Future<absl::StatusOr<tsl::RCReference<xla::ifrt::Array>>> array;
};
using LoadedVariableConstructor =
29 changes: 9 additions & 20 deletions tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
#include "xla/hlo/ir/hlo_sharding.h"
@@ -32,12 +33,9 @@ limitations under the License.
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
#include "tensorflow/core/tfrt/ifrt/sharding_utils.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/future.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tsl/concurrency/ref_count.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
@@ -63,6 +61,8 @@ absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> LoadIfrtVariable(
thread_pool);
}

} // namespace

absl::StatusOr<ifrt_serving::DtypeAndShape> GetDtypeAndShape(
const ResourceHandle& resource_handle) {
const std::vector<DtypeAndPartialTensorShape>& dtype_and_partial_shapes =
@@ -84,31 +84,20 @@ absl::StatusOr<ifrt_serving::DtypeAndShape> GetDtypeAndShape(
return dtype_and_shape;
}

} // namespace

std::string GetRuntimeNameFromVarHandle(const ResourceHandle& handle) {
return absl::StrCat(handle.container(), "__", handle.name());
}

absl::Status LoadRestoredTensorAsIfrtLoadedVariable(
const tensorflow::Tensor& variable_handle_tensor,
absl::string_view runtime_name,
std::shared_ptr<xla::ifrt::Client> ifrt_client,
const tsl::thread::ThreadPool& thread_pool,
ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry,
ifrt_serving::IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry,
tfrt::ConcurrentWorkQueue* checkpoint_loader_queue,
const VariableDeviceShardingConfigProto& sharding_config) {
if (variable_handle_tensor.dtype() != DT_RESOURCE) {
return absl::InvalidArgumentError(
absl::StrCat("variable_handle_tensor is ",
DataTypeString(variable_handle_tensor.dtype()),
" but expected DT_RESOURCE"));
}
const ResourceHandle& handle =
variable_handle_tensor.scalar<ResourceHandle>()();
std::string runtime_name = GetRuntimeNameFromVarHandle(handle);
xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>> restored_tensor_future =
ifrt_restore_tensor_registry.Get(runtime_name);
ifrt_restore_tensor_registry.GetRestoredTensor(runtime_name);
if (!restored_tensor_future.IsValid()) {
return absl::InternalError(absl::StrCat(
"LoadVariableOp: failed to fetch variable tensor: ", runtime_name));
@@ -120,17 +109,17 @@ absl::Status LoadRestoredTensorAsIfrtLoadedVariable(
xla::ifrt::Future<absl::StatusOr<tsl::RCReference<xla::ifrt::Array>>>(
loaded_variable_promise);

TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape,
GetDtypeAndShape(handle));
TF_ASSIGN_OR_RETURN(
absl::StatusOr<ifrt_serving::DtypeAndShape> dtype_and_shape,
ifrt_restore_tensor_registry.GetDtypeAndShape(runtime_name));
// TODO(b/330360798) Load variable on devices from the result of core
// selection.
TF_RETURN_IF_ERROR(ifrt_loaded_variable_registry.TryRegisterLoadedVariable(
runtime_name,
[&]() -> absl::StatusOr<
ifrt_serving::IfrtLoadedVariableRegistry::LoadedVariable> {
return ifrt_serving::IfrtLoadedVariableRegistry::LoadedVariable(
{.dtype_and_shape = dtype_and_shape,
.array = loaded_variable_future});
{.array = loaded_variable_future});
}));
restored_tensor_future.OnReady(
[ifrt_client = ifrt_client, &thread_pool = thread_pool,
7 changes: 4 additions & 3 deletions tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h
@@ -22,17 +22,18 @@ limitations under the License.
#include "absl/status/status.h"
#include "xla/python/ifrt/client.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/future.h"
#include "tsl/platform/threadpool.h"
#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime

namespace tensorflow {
namespace ifrt_serving {

absl::StatusOr<ifrt_serving::DtypeAndShape> GetDtypeAndShape(
const ResourceHandle& resource_handle);

// Returns the runtime name from the resource handle. The name will be concat of
// handle's container name and handle's name.
std::string GetRuntimeNameFromVarHandle(const ResourceHandle& handle);
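As an aside, the runtime name produced here is simply the handle's container and name joined by a double underscore; a tiny sketch with made-up values:

  // Illustrative values; any container/name pair is handled the same way.
  tensorflow::ResourceHandle handle;
  handle.set_container("model");
  handle.set_name("var_x");
  std::string key = GetRuntimeNameFromVarHandle(handle);  // "model__var_x"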
@@ -44,7 +45,7 @@ std::string GetRuntimeNameFromVarHandle(const ResourceHandle& handle);
// can look for the actual loaded variable value in
// `ifrt_loaded_variable_registry`.
absl::Status LoadRestoredTensorAsIfrtLoadedVariable(
const tensorflow::Tensor& variable_handle_tensor,
absl::string_view runtime_name,
std::shared_ptr<xla::ifrt::Client> ifrt_client,
const tsl::thread::ThreadPool& thread_pool,
ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry,
41 changes: 19 additions & 22 deletions tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc
@@ -38,8 +38,6 @@ limitations under the License.
#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/future.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tsl/concurrency/ref_count.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/env.h"
@@ -86,17 +84,17 @@ TEST(ShardingUtilsTest, ShardTensorToIfrtLoadedVariableNotFoundWrongName) {
xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>>::CreatePromise();
auto future = xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>>(promise);

TF_ASSERT_OK(
restored_tensor_registry.TryRegister("var_x_wrong", std::move(future)));
IfrtRestoreTensorRegistry::RestoredTensorInfo restored_tensor_info = {
GetDtypeAndShape(variable_handle.scalar<ResourceHandle>()()).value(),
future};
TF_ASSERT_OK(restored_tensor_registry.TryRegister("var_x_wrong",
restored_tensor_info));
promise.Set(input_tensor);
TF_ASSERT_OK(LoadRestoredTensorAsIfrtLoadedVariable(
variable_handle, client, thread_pool, restored_tensor_registry,
loaded_variable_registry, restore_work_queue.get(), sharding_config));
TF_ASSERT_OK_AND_ASSIGN(
auto v,
loaded_variable_registry.GetLoadedVariable(GetRuntimeNameFromVarHandle(
variable_handle.scalar<ResourceHandle>()())));
EXPECT_THAT(v.array.Await().status(), StatusIs(absl::StatusCode::kNotFound));
EXPECT_THAT(
LoadRestoredTensorAsIfrtLoadedVariable(
"var_x", client, thread_pool, restored_tensor_registry,
loaded_variable_registry, restore_work_queue.get(), sharding_config),
StatusIs(absl::StatusCode::kNotFound));
}

TEST(ShardingUtilsTest, ShardTensorToIfrtLoadedVariableSucceed) {
@@ -129,21 +127,20 @@ TEST(ShardingUtilsTest, ShardTensorToIfrtLoadedVariableSucceed) {
xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>>::CreatePromise();
auto future = xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>>(promise);

TF_ASSERT_OK(restored_tensor_registry.TryRegister(
GetRuntimeNameFromVarHandle(variable_handle.scalar<ResourceHandle>()()),
std::move(future)));
IfrtRestoreTensorRegistry::RestoredTensorInfo restored_tensor_info = {
GetDtypeAndShape(variable_handle.scalar<ResourceHandle>()()).value(),
future};

TF_ASSERT_OK(
restored_tensor_registry.TryRegister("var_x", restored_tensor_info));
TF_ASSERT_OK(LoadRestoredTensorAsIfrtLoadedVariable(
variable_handle, client, thread_pool, restored_tensor_registry,
"var_x", client, thread_pool, restored_tensor_registry,
loaded_variable_registry, restore_work_queue.get(), sharding_config));
promise.Set(input_tensor);
TF_ASSERT_OK_AND_ASSIGN(
auto v,
loaded_variable_registry.GetLoadedVariable(GetRuntimeNameFromVarHandle(
variable_handle.scalar<ResourceHandle>()())));
TF_ASSERT_OK_AND_ASSIGN(auto v,
loaded_variable_registry.GetLoadedVariable("var_x"));
TF_ASSERT_OK_AND_ASSIGN(auto assembled_array, v.array.Await());

EXPECT_TRUE(v.dtype_and_shape.shape.IsSameSize(TensorShape({2, 2})));
EXPECT_EQ(v.dtype_and_shape.dtype, DT_INT32);
TF_ASSERT_OK_AND_ASSIGN(auto disassembled_arrays,
assembled_array->DisassembleIntoSingleDeviceArrays(
xla::ifrt::ArrayCopySemantics::kAlwaysCopy));
26 changes: 19 additions & 7 deletions tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.cc
@@ -24,35 +24,47 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
#include "xla/python/ifrt/future.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {
namespace ifrt_serving {

absl::Status IfrtRestoreTensorRegistry::TryRegister(
absl::string_view name,
xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>> tensor_future) {
absl::string_view name, RestoredTensorInfo restored_tensor_info) {
absl::MutexLock lock(&mutex_);
auto& variable = restored_tensors_[name];
if (variable.IsValid()) {
auto& info = restored_tensors_[name];
if (info.tensor_future.IsValid()) {
return absl::AlreadyExistsError(
absl::StrCat("Variable '", name, "' already registered."));
}
variable = std::move(tensor_future);
info = std::move(restored_tensor_info);
return absl::OkStatus();
}

xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>>
IfrtRestoreTensorRegistry::Get(absl::string_view name) const {
IfrtRestoreTensorRegistry::GetRestoredTensor(absl::string_view name) const {
absl::MutexLock lock(&mutex_);
auto it = restored_tensors_.find(name);
if (it == restored_tensors_.end()) {
return xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>>(
absl::NotFoundError(absl::StrCat("Variable '", name, "' not found.")));
}

return it->second;
return it->second.tensor_future;
}

absl::StatusOr<DtypeAndShape> IfrtRestoreTensorRegistry::GetDtypeAndShape(
absl::string_view name) const {
absl::MutexLock lock(&mutex_);
auto it = restored_tensors_.find(name);
if (it == restored_tensors_.end()) {
return absl::NotFoundError(
absl::StrCat("Variable '", name, "' not found."));
}

return it->second.dtype_and_shape;
}

} // namespace ifrt_serving
22 changes: 15 additions & 7 deletions tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h
@@ -23,30 +23,38 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h"
#include "xla/python/ifrt/future.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
namespace ifrt_serving {

// This class is thread safe.
class IfrtRestoreTensorRegistry {
public:
struct RestoredTensorInfo {
DtypeAndShape dtype_and_shape;
xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>> tensor_future;
};
// Tries to register a loaded variable with the given name.
// Returns an error if the named tensor already exists.
absl::Status TryRegister(
absl::string_view name,
xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>> tensor_future)
absl::Status TryRegister(absl::string_view name,
RestoredTensorInfo restored_tensor_info)
ABSL_LOCKS_EXCLUDED(mutex_);

xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>> Get(
xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>> GetRestoredTensor(
absl::string_view name) const ABSL_LOCKS_EXCLUDED(mutex_);

absl::StatusOr<DtypeAndShape> GetDtypeAndShape(absl::string_view name) const
ABSL_LOCKS_EXCLUDED(mutex_);

private:
mutable absl::Mutex mutex_;
absl::flat_hash_map<std::string,
xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>>>
restored_tensors_ ABSL_GUARDED_BY(mutex_);
absl::flat_hash_map<std::string, RestoredTensorInfo> restored_tensors_
ABSL_GUARDED_BY(mutex_);
};

} // namespace ifrt_serving
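Putting the pieces together, a restore path can populate this registry the way the updated ifrt_loaded_variable_utils_test.cc above does; this is a sketch under the assumption that `variable_handle`, `restored_tensor`, the `restore_tensor_registry` instance, and the key "var_x" exist in the surrounding code:

  // Sketch following the test above; the names are assumed, not prescribed.
  auto promise =
      xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>>::CreatePromise();
  IfrtRestoreTensorRegistry::RestoredTensorInfo info = {
      GetDtypeAndShape(variable_handle.scalar<ResourceHandle>()()).value(),
      xla::ifrt::Future<absl::StatusOr<tensorflow::Tensor>>(promise)};
  TF_RETURN_IF_ERROR(restore_tensor_registry.TryRegister("var_x", info));
  // Dtype/shape queries succeed immediately; the tensor future resolves once
  // the checkpoint restore sets the promise.
  promise.Set(restored_tensor);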