
Add the mixed_priority_policy attr to the fallback kernel and make the fallback kernel pass the policy and low priority params to the batch resource.

PiperOrigin-RevId: 622943741

eunjaekim-0 authored and tensorflower-gardener committed Apr 8, 2024
1 parent 6ae914d commit f259352
Showing 8 changed files with 95 additions and 24 deletions.
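Taken together, the change threads the new attr through three layers: the fallback kernel reads a mixed_priority_policy string attr, parses it into a MixedPriorityBatchingPolicy enum, bundles it with the high and low priority batching parameters into the new serving::BatchResourceOptions struct, and the batch resources forward the bundle into the shared batcher queue options. A condensed sketch of that flow, abridged from the fallback_batch_kernel.h hunk below (illustrative, not compilable on its own):

    // Inside the kernel's resource-creator lambda; names are from the diffs below.
    serving::BatchResourceOptions opts;
    // Parse the attr string; an unrecognized value surfaces as an error status.
    TF_ASSIGN_OR_RETURN(
        opts.mixed_priority_batching_policy,
        serving::GetMixedPriorityBatchingPolicy(mixed_priority_policy_));
    opts.num_batch_threads = num_batch_threads_;  // ...plus the other high priority attrs
    opts.low_priority_max_batch_size = low_priority_max_batch_size_;  // ...plus the
    // remaining low priority attrs. The resource then forwards the bundle:
    //   GetBatcherQueueOptions(..., options.low_priority_allowed_batch_sizes,
    //                          options.mixed_priority_batching_policy)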
14 changes: 14 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_resource_base.h
@@ -45,6 +45,20 @@ limitations under the License.
namespace tensorflow {
namespace serving {

// Options used to create a batch resource.
struct BatchResourceOptions {
int32_t num_batch_threads;
int32_t max_batch_size;
int32_t batch_timeout_micros;
int32_t max_enqueued_batches;
std::vector<int32_t> allowed_batch_sizes;
int32_t low_priority_max_batch_size;
int32_t low_priority_batch_timeout_micros;
int32_t low_priority_max_enqueued_batches;
std::vector<int32_t> low_priority_allowed_batch_sizes;
MixedPriorityBatchingPolicy mixed_priority_batching_policy;
};

// Base class for a resource that encapsulates the state and logic for
// batching tensors.
class BatchResourceBase : public ResourceBase {
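As context for the fields above: the struct declares no default member initializers, so callers are expected to assign every field before use. A hypothetical population sketch (values invented for illustration; the enumerator name is an assumption mirroring the policy strings registered later in this commit):

    serving::BatchResourceOptions options;
    options.num_batch_threads = 4;  // illustrative values only
    options.max_batch_size = 32;
    options.batch_timeout_micros = 1000;
    options.max_enqueued_batches = 100;
    options.allowed_batch_sizes = {8, 16, 32};
    options.low_priority_max_batch_size = 32;
    options.low_priority_batch_timeout_micros = 5000;
    options.low_priority_max_enqueued_batches = 10;
    options.low_priority_allowed_batch_sizes = {8, 16, 32};
    options.mixed_priority_batching_policy =
        serving::MixedPriorityBatchingPolicy::kLowPriorityPaddingWithMaxBatchSize;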
2 changes: 2 additions & 0 deletions tensorflow/core/runtime_fallback/runtime/BUILD
@@ -55,6 +55,7 @@ cc_library(
":op_logger",
":runtime_fallback_tensor",
"//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
"//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
"//tensorflow/core/platform:statusor",
"//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state",
"//tensorflow/core/runtime_fallback/kernel:kernel_fallback_execute_compat",
@@ -193,6 +194,7 @@ cc_library(
"//tensorflow/core/kernels:batch_kernels",
"//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
"//tensorflow/core/kernels/batching_util:batch_resource_base",
"//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
"//tensorflow/core/kernels/batching_util:bounded_executor",
"//tensorflow/core/kernels/batching_util:warmup",
"//tensorflow/core/lib/core:refcount",
2 changes: 2 additions & 0 deletions tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.cc
@@ -99,6 +99,8 @@ BatchFunctionFallbackKernelBase::BatchFunctionFallbackKernelBase(
&low_priority_allowed_batch_sizes_));
OP_REQUIRES_OK(c, c->GetAttr("low_priority_max_enqueued_batches",
&low_priority_max_enqueued_batches_));
OP_REQUIRES_OK(c,
c->GetAttr("mixed_priority_policy", &mixed_priority_policy_));

if (shared_name_.empty()) {
// If shared_name is not supplied, use name instead (prevent collisions by
23 changes: 21 additions & 2 deletions tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/kernels/batch_kernels.h"
#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/warmup.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
@@ -78,6 +79,7 @@ class BatchFunctionFallbackKernelBase : public AsyncOpKernel {
int32 low_priority_batch_timeout_micros_;
int32 low_priority_max_enqueued_batches_;
std::vector<int32> low_priority_allowed_batch_sizes_;
std::string mixed_priority_policy_;
bool enable_large_batch_splitting_;
bool has_attribute_enable_large_batch_splitting_;
bool disable_padding_;
@@ -203,10 +205,27 @@ void BatchFunctionFallbackKernel<BatchResourceType>::ComputeAsync(
} else {
creator = [this, c]()
-> absl::StatusOr<tensorflow::core::RefCountPtr<BatchResourceType>> {
+ serving::BatchResourceOptions batch_resource_options;
+ TF_ASSIGN_OR_RETURN(
+     batch_resource_options.mixed_priority_batching_policy,
+     serving::GetMixedPriorityBatchingPolicy(mixed_priority_policy_));
+ batch_resource_options.num_batch_threads = num_batch_threads_;
+ batch_resource_options.max_batch_size = max_batch_size_;
+ batch_resource_options.batch_timeout_micros = batch_timeout_micros_;
+ batch_resource_options.max_enqueued_batches = max_enqueued_batches_;
+ batch_resource_options.allowed_batch_sizes = allowed_batch_sizes_;
+ batch_resource_options.low_priority_max_batch_size =
+     low_priority_max_batch_size_;
+ batch_resource_options.low_priority_batch_timeout_micros =
+     low_priority_batch_timeout_micros_;
+ batch_resource_options.low_priority_max_enqueued_batches =
+     low_priority_max_enqueued_batches_;
+ batch_resource_options.low_priority_allowed_batch_sizes =
+     low_priority_allowed_batch_sizes_;
+
  std::unique_ptr<BatchResourceType> new_resource;
  auto status = BatchResourceType::Create(
- c, num_batch_threads_, max_batch_size_, batch_timeout_micros_,
- max_enqueued_batches_, allowed_batch_sizes_, batch_function_,
+ c, batch_resource_options, batch_function_,
enable_large_batch_splitting_, disable_padding_, &new_resource);
if (!status.ok()) return status;
if (c->session_metadata() != nullptr) {
45 changes: 34 additions & 11 deletions tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
- #include <cstdlib>
+ #include <cstdint>
#include <functional>
#include <memory>
#include <string>
@@ -109,18 +109,16 @@ class FallbackBatchResource : public tensorflow::serving::BatchResourceBase {
return batch_function->name();
}

- static Status Create(OpKernelContext* c, int32_t num_batch_threads,
-                      int32_t max_batch_size, int32_t batch_timeout_micros,
-                      int32_t max_enqueued_batches,
-                      ArrayRef<int32_t> allowed_batch_sizes,
+ static Status Create(OpKernelContext* c,
+                      const serving::BatchResourceOptions& options,
tsl::RCReference<const tfrt::Function> bef_func,
bool enable_large_batch_splitting, bool disable_padding,
std::unique_ptr<FallbackBatchResource>* resource) {
const tfrt::ExecutionContext* exec_ctx = nullptr;
TF_RETURN_IF_ERROR(GetTfrtExecutionContext(c, &exec_ctx));

BatcherT::Options batcher_options;
- batcher_options.num_batch_threads = num_batch_threads;
+ batcher_options.num_batch_threads = options.num_batch_threads;
std::shared_ptr<BatcherT> batcher;
TF_RETURN_IF_ERROR(BatcherT::Create(batcher_options, &batcher));

@@ -135,11 +133,16 @@ class FallbackBatchResource : public tensorflow::serving::BatchResourceBase {
resource->reset(new FallbackBatchResource(
*exec_ctx, *fallback_request_state, std::move(bef_func),
std::move(batcher),
- GetBatcherQueueOptions(num_batch_threads, max_batch_size,
-                        batch_timeout_micros, max_enqueued_batches,
-                        allowed_batch_sizes,
-                        enable_large_batch_splitting, disable_padding),
- allowed_batch_sizes));
+ GetBatcherQueueOptions(
+     options.num_batch_threads, options.max_batch_size,
+     options.batch_timeout_micros, options.max_enqueued_batches,
+     options.allowed_batch_sizes, enable_large_batch_splitting,
+     disable_padding, options.low_priority_max_batch_size,
+     options.low_priority_batch_timeout_micros,
+     options.low_priority_max_enqueued_batches,
+     options.low_priority_allowed_batch_sizes,
+     options.mixed_priority_batching_policy),
+ options.allowed_batch_sizes));
return OkStatus();
}

@@ -410,6 +413,26 @@ REGISTER_OP("_BatchFunctionFallback")
.Attr("low_priority_batch_timeout_micros: int = 0")
.Attr("low_priority_allowed_batch_sizes: list(int) = []")
.Attr("low_priority_max_enqueued_batches: int = 0")
// Policy that determines the mixed priority batching behavior when low
// priority batch parameters are present.
//
// low_priority_padding_with_next_allowed_batch_size: If high priority
// batches time out without reaching the max batch size, low priority inputs
// pad the high priority batches up to the next allowed batch size. A low
priority only batch gets scheduled only when the low priority input times
// out or reaches the max batch size while there is no high priority input
// waiting to be processed.
// low_priority_padding_with_max_batch_size: Same as above but pad up to the
// max batch size.
// priority_isolation: High priority and low priority inputs never share the
// same batch, i.e., no low priority input padding high priority batches.
// Low priority inputs get scheduled only as part of low priority only
// batches as described above.
.Attr(
"mixed_priority_policy: "
"{'low_priority_padding_with_max_batch_size', "
"'low_priority_padding_with_next_allowed_batch_size', "
"'priority_isolation'} = 'low_priority_padding_with_max_batch_size'")
.Attr("Tin: list(type)")
.Attr("Tcaptured: list(type) >= 0")
.Attr("Tout: list(type)")
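For illustration, the three registered strings map through serving::GetMixedPriorityBatchingPolicy, the helper the kernels above call; the enumerator names here are assumptions that mirror the strings:

    //   "low_priority_padding_with_max_batch_size"
    //       -> MixedPriorityBatchingPolicy::kLowPriorityPaddingWithMaxBatchSize
    //   "low_priority_padding_with_next_allowed_batch_size"
    //       -> MixedPriorityBatchingPolicy::kLowPriorityPaddingWithNextAllowedBatchSize
    //   "priority_isolation"
    //       -> MixedPriorityBatchingPolicy::kPriorityIsolation
    absl::StatusOr<serving::MixedPriorityBatchingPolicy> policy =
        serving::GetMixedPriorityBatchingPolicy("priority_isolation");
    // Any string outside the attr's allowed set returns an error status instead.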
3 changes: 2 additions & 1 deletion tensorflow/core/tfrt/mlrt/kernel/BUILD
@@ -94,8 +94,9 @@ cc_library(
":context",
":kernel_runner_utils",
"//tensorflow/core:framework",
"//tensorflow/core/kernels/batching_util:batch_resource_base",
"//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
"//tensorflow/core/platform:protobuf",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:statusor",
"//tensorflow/core/runtime_fallback/runtime:fallback_batch_kernel",
"//tensorflow/core/tfrt/fallback:op_kernel_runner_cache",
26 changes: 16 additions & 10 deletions tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/mlrt/kernel/batch_kernel.h"

#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
@@ -25,7 +26,9 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batching_util/batch_resource_base.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h"
@@ -215,25 +218,28 @@ class MlrtBatchResource : public tensorflow::serving::BatchResourceBase {
return batch_function.name();
}

- static Status Create(OpKernelContext* c, int32_t num_batch_threads,
-                      int32_t max_batch_size, int32_t batch_timeout_micros,
-                      int32_t max_enqueued_batches,
-                      const std::vector<int32_t>& allowed_batch_sizes,
+ static Status Create(OpKernelContext* c,
+                      const serving::BatchResourceOptions& options,
mlrt::bc::Function function,
bool enable_large_batch_splitting, bool disable_padding,
std::unique_ptr<MlrtBatchResource>* resource) {
BatcherT::Options batcher_options;
- batcher_options.num_batch_threads = num_batch_threads;
+ batcher_options.num_batch_threads = options.num_batch_threads;
std::shared_ptr<BatcherT> batcher;
TF_RETURN_IF_ERROR(BatcherT::Create(batcher_options, &batcher));

resource->reset(new MlrtBatchResource(
function, std::move(batcher),
- GetBatcherQueueOptions(num_batch_threads, max_batch_size,
-                        batch_timeout_micros, max_enqueued_batches,
-                        allowed_batch_sizes,
-                        enable_large_batch_splitting, disable_padding),
- allowed_batch_sizes));
+ GetBatcherQueueOptions(
+     options.num_batch_threads, options.max_batch_size,
+     options.batch_timeout_micros, options.max_enqueued_batches,
+     options.allowed_batch_sizes, enable_large_batch_splitting,
+     disable_padding, options.low_priority_max_batch_size,
+     options.low_priority_batch_timeout_micros,
+     options.low_priority_max_enqueued_batches,
+     options.low_priority_allowed_batch_sizes,
+     options.mixed_priority_batching_policy),
+ options.allowed_batch_sizes));
return OkStatus();
}

4 changes: 4 additions & 0 deletions tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc
@@ -1287,6 +1287,10 @@ mlrt::bc::Buffer CreateExecutableForBatchFunctionOp() {
key: "low_priority_max_enqueued_batches"
value { i: 1 }
}
attr {
key: "mixed_priority_policy"
value { s: "low_priority_padding_with_max_batch_size" }
}
attr {
key: "container"
value { s: "container" }
