Skip to content

Commit

Permalink
Add DHT versions for fallback resource support.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 528945537
  • Loading branch information
rohitju authored and tensorflower-gardener committed May 3, 2023
1 parent 3c3ef70 commit 03286f5
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 9 deletions.
42 changes: 42 additions & 0 deletions tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@ def SetResourceOp : FallbackSync_Op<"set_resource", [CoreRT_TypedAttributeTrait]
let assemblyFormat = "operands attr-dict";
}

def SetResourceDhtOp : FallbackSync_Op<"set_resource_dht", [CoreRT_TypedAttributeTrait]> {
  let summary = "Set a DHT in resource array";

  let description = [{
    Set a DHT (DenseHostTensor) in the resource array.

    arg: the DHT to be set in the resource array.
    index: the index in the resource array.
  }];

  let arguments = (ins
    TensorType:$arg,
    I64Attr:$index
  );

  let results = (outs);

  let assemblyFormat = "operands attr-dict";
}

def GetResourceOp : FallbackSync_Op<"get_resource",
[CoreRT_TypedAttributeTrait]> {
let summary = "get a tensor in resource array";
Expand All @@ -82,6 +102,28 @@ def GetResourceOp : FallbackSync_Op<"get_resource",
let assemblyFormat = "attr-dict `:` type($results)";
}

def GetResourceDhtOp : FallbackSync_Op<"get_resource_dht",
    [CoreRT_TypedAttributeTrait]> {
  let summary = "get a DHT in resource array";

  let description = [{
    Get a DHT (DenseHostTensor) in the resource array.

    indices: the indices in the resource array.
    results: the DHT values for the corresponding indices.
  }];

  let arguments = (ins
    I64ArrayAttr:$indices
  );

  let results = (outs
    Variadic<TensorType>:$results
  );

  let assemblyFormat = "attr-dict `:` type($results)";
}

def CreateOp: FallbackSync_Op<"createop", []> {
let summary = "The Fallback CreateOp";

Expand Down
11 changes: 11 additions & 0 deletions tensorflow/core/tfrt/graph_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ cc_library(
tags = ["no_oss"],
deps = [
":graph_execution_options",
":sync_resource_state",
"//learning/brain/experimental/tfrt/mlrt/application/tensorflow/compiler/transforms:import_model",
"//learning/brain/experimental/tfrt/mlrt/application/tensorflow/kernel:context",
"//learning/brain/experimental/tfrt/native_lowering/kernels:sync_context",
Expand Down Expand Up @@ -130,6 +131,7 @@ cc_library(
"//learning/brain/experimental/tfrt/mlrt/application/tensorflow/kernel",
"//learning/brain/experimental/tfrt/native_lowering/kernels",
"//learning/brain/experimental/tfrt/native_lowering/kernels:kernels_alwayslink",
"//learning/brain/tfrt/mlrt/application/vrooml:kernel",
"//learning/infra/mira/mlrt/interpreter:context",
"//learning/infra/mira/mlrt/interpreter:value",
"//tensorflow/core/framework:graph_proto_cc",
Expand Down Expand Up @@ -195,3 +197,12 @@ tf_cc_test(
"@com_google_googletest//:gtest_main",
],
)

# Header-only library holding per-graph DHT (DenseHostTensor) resource
# state used by sync execution; see sync_resource_state.h.
cc_library(
    name = "sync_resource_state",
    hdrs = ["sync_resource_state.h"],
    visibility = ["//visibility:public"],
    deps = [
        "@tf_runtime//:tensor",
    ],
)
18 changes: 12 additions & 6 deletions tensorflow/core/tfrt/graph_executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ limitations under the License.
#include "tensorflow/core/tfrt/fallback/cost_recorder.h"
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
Expand Down Expand Up @@ -104,7 +105,8 @@ tensorflow::Status RunMlrtFunction(
const tsl::RCReference<tfrt::RequestContext>& request_context,
tfrt::ConcurrentWorkQueue& work_queue,
absl::Span<const tensorflow::Tensor> inputs,
std::vector<tensorflow::Tensor>* outputs) {
std::vector<tensorflow::Tensor>* outputs,
SyncResourceState* sync_resource_state) {
DCHECK(function);
const auto* fallback_request_state =
request_context->GetDataIfExists<tfd::KernelFallbackCompatRequestState>();
Expand All @@ -118,7 +120,7 @@ tensorflow::Status RunMlrtFunction(
// TODO(chky, rohitju): Unify tfrt::SyncContext with tf_mlrt::Context.
tfrt::ExecutionContext exec_ctx(request_context);
execution_context.AddUserContext(
std::make_unique<tfrt::SyncContext>(&exec_ctx));
std::make_unique<tfrt::SyncContext>(&exec_ctx, sync_resource_state));

// Set up tf_mlrt::Context which is used for executing tensorflow::OpKernel.
execution_context.AddUserContext(std::make_unique<tf_mlrt::Context>(
Expand Down Expand Up @@ -306,7 +308,8 @@ tensorflow::Status GraphExecutionRunOnFunction(

return RunMlrtFunction(function, *loaded_executable,
request_info->tfrt_request_context,
*request_info->request_queue, inputs, outputs);
*request_info->request_queue, inputs, outputs,
/*sync_resource_state=*/nullptr);
}

DCHECK(func);
Expand Down Expand Up @@ -757,13 +760,15 @@ tensorflow::Status GraphExecutor::InitBytecode(
if (auto function = loaded_executable->GetFunction(kFallbackInitFunction)) {
TF_RETURN_IF_ERROR(RunMlrtFunction(
function, *loaded_executable, request_info->tfrt_request_context,
*request_info->request_queue, {}, &outputs));
*request_info->request_queue, {}, &outputs,
&loaded_graph->sync_resource_state()));
}

if (auto function = loaded_executable->GetFunction(kResourceInitFunction)) {
TF_RETURN_IF_ERROR(RunMlrtFunction(
function, *loaded_executable, request_info->tfrt_request_context,
*request_info->request_queue, {}, &outputs));
*request_info->request_queue, {}, &outputs,
&loaded_graph->sync_resource_state()));
}

return OkStatus();
Expand Down Expand Up @@ -861,7 +866,8 @@ tensorflow::Status GraphExecutor::RunWithSyncInterpreter(
mlrt::ExecutionContext execution_context(
executable_context->bytecode_executable.get());

auto sync_context = std::make_unique<tfrt::SyncContext>(&exec_ctx);
auto sync_context = std::make_unique<tfrt::SyncContext>(
&exec_ctx, &loaded_client_graph.sync_resource_state());
execution_context.AddUserContext(std::move(sync_context));

auto tf_context = std::make_unique<tensorflow::tf_mlrt::Context>(
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/core/tfrt/graph_executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/tfrt/fallback/fallback_state.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h"
Expand Down Expand Up @@ -105,7 +106,8 @@ tensorflow::Status RunMlrtFunction(
const tsl::RCReference<tfrt::RequestContext>& request_context,
tfrt::ConcurrentWorkQueue& work_queue,
absl::Span<const tensorflow::Tensor> inputs,
std::vector<tensorflow::Tensor>* outputs);
std::vector<tensorflow::Tensor>* outputs,
SyncResourceState* sync_resource_state);

// Loads (if not yet) and runs a subgraph in a graph as per each request.
class GraphExecutor {
Expand Down Expand Up @@ -178,6 +180,7 @@ class GraphExecutor {

OpKernelRunnerTable& runner_table() { return runner_table_; }
tfd::FallbackResourceArray& resource_array() { return resource_array_; }
SyncResourceState& sync_resource_state() { return sync_resource_state_; }

private:
std::string name_;
Expand All @@ -194,6 +197,7 @@ class GraphExecutor {
std::shared_ptr<ExecutableContext> executable_context_
TF_GUARDED_BY(executable_context_mu_);
mutable absl::once_flag create_cost_recorder_once_;
SyncResourceState sync_resource_state_;
};

// A subgraph constructed by specifying input/output tensors.
Expand Down
48 changes: 48 additions & 0 deletions tensorflow/core/tfrt/graph_executor/sync_resource_state.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_SYNC_RESOURCE_STATE_H_
#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_SYNC_RESOURCE_STATE_H_

#include <cassert>
#include <utility>
#include <vector>

#include "tfrt/tensor/dense_host_tensor.h"  // from @tf_runtime
namespace tensorflow {
namespace tfrt_stub {

class SyncResourceState {
public:
// Sets `dht` in the array at `index`. `index` should be dense and
// duplicate indices are not allowed.
void SetResourceDht(int index, tfrt::DenseHostTensor dht) {
if (resource_dht_.size() <= index) {
resource_dht_.resize(index + 1);
}

resource_dht_[index] = std::move(dht);
}

tfrt::DenseHostTensor GetResourceDht(int index) const {
return resource_dht_.at(index).CopyRef();
}

private:
std::vector<tfrt::DenseHostTensor> resource_dht_;
};

} // namespace tfrt_stub
} // namespace tensorflow

#endif // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_SYNC_RESOURCE_STATE_H_
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "learning/brain/experimental/tfrt/mlrt/application/tensorflow/kernel/kernel.h"
#include "learning/brain/experimental/tfrt/native_lowering/kernels/math_kernels.h"
#include "learning/brain/experimental/tfrt/native_lowering/kernels/sync_fallback_kernels.h"
#include "learning/brain/tfrt/mlrt/application/vrooml/kernel.h"
#include "absl/status/statusor.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/status.h"
Expand Down Expand Up @@ -72,6 +73,7 @@ SynchronousGraphExecutor::Create(
tensorflow::tf_mlrt::RegisterTfMlrtKernels(*kernel_registry);
tfrt::cpu::RegisterMlrtMathKernels(kernel_registry.get());
tfrt::cpu::RegisterMlrtFallbackCompatKernels(kernel_registry.get());
tensorflow::vrooml_mlrt::RegisterDhtResourceKernels(*kernel_registry);

tensorflow::StatusOr<std::unique_ptr<tensorflow::tfrt_stub::GraphExecutor>>
graph_executor = tensorflow::tfrt_stub::GraphExecutor::Create(
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/core/tfrt/saved_model/saved_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ tensorflow::Status RunBytecodeInitializers(
if (auto function = loaded_executable.GetFunction("_tfrt_fallback_init")) {
TF_RETURN_IF_ERROR(RunMlrtFunction(
function, loaded_executable, request_info->tfrt_request_context,
*request_info->request_queue, {}, &outputs));
*request_info->request_queue, {}, &outputs,
/*sync_resource_state=*/nullptr));
}

for (const auto& p : initializers_and_signatures.initializers) {
Expand All @@ -275,7 +276,8 @@ tensorflow::Status RunBytecodeInitializers(
if (auto function = loaded_executable.GetFunction("_tfrt_resource_init")) {
TF_RETURN_IF_ERROR(RunMlrtFunction(
function, loaded_executable, request_info->tfrt_request_context,
*request_info->request_queue, {}, &outputs));
*request_info->request_queue, {}, &outputs,
/*sync_resource_state=*/nullptr));
}

return OkStatus();
Expand Down

0 comments on commit 03286f5

Please sign in to comment.