Add a low priority task queue to the SharedBatchScheduler's internal::Queue. Update internal::Queue to accept a callback with an additional task vector argument, and SharedBatchScheduler::ThreadLogic to pass tasks from the low priority queue to that callback. #63224

Merged: 1 commit, Mar 7, 2024
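
In short: SharedBatchScheduler::AddQueue now accepts either the existing single-argument callback or a new two-argument callback that additionally receives low priority tasks for padding. A minimal sketch of registering the new form is below; the scheduler setup and the ProcessCombined helper are illustrative assumptions, not code from this PR (FakeTask is the test's task type):

// Sketch only. Assumes `scheduler` came from SharedBatchScheduler<FakeTask>::Create
// and that this code runs inside a function returning Status.
auto process_batch = [](std::unique_ptr<Batch<FakeTask>> batch,
                        std::vector<std::unique_ptr<FakeTask>> low_priority_tasks) {
  // `batch` holds the scheduled high priority tasks; `low_priority_tasks`
  // holds tasks drained from the low priority queue to pad the batch.
  ProcessCombined(std::move(batch), std::move(low_priority_tasks));  // hypothetical helper
};

SharedBatchScheduler<FakeTask>::QueueOptions options;
options.enable_priority_queue = true;  // opt in to low priority padding
std::unique_ptr<BatchScheduler<FakeTask>> queue;
TF_RETURN_IF_ERROR(scheduler->AddQueue(options, process_batch, &queue));

Because the two-argument lambda is only convertible to the two-argument std::function alternative, the std::variant conversion is unambiguous and existing single-argument callers keep compiling unchanged.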
3 changes: 1 addition & 2 deletions tensorflow/core/kernels/batching_util/BUILD
@@ -187,6 +187,7 @@ tf_cc_test(
size = "small",
srcs = ["shared_batch_scheduler_test.cc"],
deps = [
":batch_scheduler",
":fake_clock_env",
":shared_batch_scheduler",
"//tensorflow/core:lib",
@@ -196,9 +197,7 @@ tf_cc_test(
"//tensorflow/core/platform:status_matchers",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_absl//absl/utility",
],
)

92 changes: 75 additions & 17 deletions tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <stddef.h>

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
@@ -121,6 +122,11 @@ class SharedBatchScheduler
using BatchTaskUniqueptr = std::unique_ptr<Batch<TaskType>>;
using BatchUniquePtr =
std::variant<BatchTaskUniqueptr, BatchTaskHandleUniquePtr>;

using ProcessBatchCallback =
std::variant<std::function<void(std::unique_ptr<Batch<TaskType>>)>,
std::function<void(std::unique_ptr<Batch<TaskType>>,
std::vector<std::unique_ptr<TaskType>>)>>;
// TODO(b/25089730): Tune defaults based on best practices as they develop.
struct Options {
// The name to use for the pool of batch threads.
@@ -253,8 +259,7 @@ class SharedBatchScheduler
PriorityQueueOptions low_priority_queue_options;
};
Status AddQueue(const QueueOptions& options,
std::function<void(std::unique_ptr<Batch<TaskType>>)>
process_batch_callback,
ProcessBatchCallback process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue);

private:
@@ -272,9 +277,7 @@

// Called by `AddQueue`.
Status AddQueueAfterRewritingOptions(
const QueueOptions& options,
std::function<void(std::unique_ptr<Batch<TaskType>>)>
process_batch_callback,
const QueueOptions& options, ProcessBatchCallback process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue);

static bool BatchExists(const BatchUniquePtr& batch_to_process);
Expand Down Expand Up @@ -334,8 +337,15 @@ namespace internal {
template <typename TaskType>
class Queue {
public:
using ProcessBatchCallback =
using ProcessBatchCallbackWithoutPaddingTasks =
std::function<void(std::unique_ptr<Batch<TaskType>>)>;
using ProcessBatchCallbackWithPaddingTasks =
std::function<void(std::unique_ptr<Batch<TaskType>>,
std::vector<std::unique_ptr<TaskType>>)>;
using ProcessBatchCallback =
std::variant<ProcessBatchCallbackWithoutPaddingTasks,
ProcessBatchCallbackWithPaddingTasks>;

using SchedulableBatchCallback = std::function<void()>;
using SplitInputTaskIntoSubtasksCallback = std::function<Status(
std::unique_ptr<TaskType>* input_task, int open_batch_remaining_slot,
@@ -388,8 +398,14 @@ class Queue {
// Batches are guaranteed to form at task enqueue time.
std::unique_ptr<Batch<TaskType>> ScheduleBatchWithEagerSplit();

// Retrieves tasks from the low priority task queue, up to the specified
// total size. Immediately returns an empty vector when
// `enable_priority_queue` is false.
std::vector<std::unique_ptr<TaskType>> GetLowPriorityTasks(size_t size);

// Processes a batch that has been returned earlier by ScheduleBatch().
void ProcessBatch(std::unique_ptr<Batch<TaskType>> batch);
void ProcessBatch(std::unique_ptr<Batch<TaskType>> batch,
std::vector<std::unique_ptr<TaskType>> padding_task);

// Determines whether the queue is empty, i.e. has no tasks waiting or being
// processed.
@@ -463,6 +479,10 @@ class Queue {
const std::deque<std::unique_ptr<Batch<TaskType>>>& GetBatches() const
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

// Gets the low priority task queue.
TaskQueue<TaskType>& GetLowPriorityTaskQueue()
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

const typename SharedBatchScheduler<TaskType>::QueueOptions options_;

// The environment to use.
@@ -497,6 +517,14 @@ class Queue {
std::deque<std::unique_ptr<Batch<BatchInputTaskHandle<TaskType>>>>
task_handle_batches_ TF_GUARDED_BY(mu_);

// The enqueued tasks for low priority inputs.
// Each element corresponds to a task to be dequeued. These tasks are
// consumed by `Queue<TaskType>::ProcessBatch` to either pad the high priority
// batches below or form their own batch to be executed.
//
// Used iff `QueueOptions.enable_lazy_split` is false.
TaskQueue<TaskType> low_priority_tasks_ TF_GUARDED_BY(mu_);

// The enqueued batches for low priority input.
// Each element corresponds to a task to be dequeued and processed by
// `Queue<TaskType>::ProcessBatch`.
@@ -602,9 +630,7 @@ SharedBatchScheduler<TaskType>::~SharedBatchScheduler() {

template <typename TaskType>
Status SharedBatchScheduler<TaskType>::AddQueue(
const QueueOptions& options,
std::function<void(std::unique_ptr<Batch<TaskType>>)>
process_batch_callback,
const QueueOptions& options, ProcessBatchCallback process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue) {
QueueOptions rewrite_options = options;
if ((!rewrite_options.enable_large_batch_splitting) &&
@@ -623,9 +649,7 @@ Status SharedBatchScheduler<TaskType>::AddQueue(

template <typename TaskType>
Status SharedBatchScheduler<TaskType>::AddQueueAfterRewritingOptions(
const QueueOptions& options,
std::function<void(std::unique_ptr<Batch<TaskType>>)>
process_batch_callback,
const QueueOptions& options, ProcessBatchCallback process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue) {
if (options.input_batch_size_limit == 0) {
return errors::InvalidArgument(
@@ -793,15 +817,20 @@ void SharedBatchScheduler<TaskType>::ThreadLogic() {
batch_to_schedule->AddTask(std::move(task_handles[i]->GetSplitTask()));
}
batch_to_schedule->Close();

} else {
// The corresponding `queue_for_batch` must be created with
// `enable_lazy_split=false`.
batch_to_schedule =
std::move(absl::get<BatchTaskUniqueptr>(batch_to_process));
}

queue_for_batch->ProcessBatch(std::move(batch_to_schedule));
// TODO(b/316379576): Make the policy choose between padding up to the max
// batch size and padding up to the next smallest allowed batch size.
size_t low_priority_task_padding_size =
queue_for_batch->max_execution_batch_size() - batch_to_schedule->size();
queue_for_batch->ProcessBatch(
std::move(batch_to_schedule),
queue_for_batch->GetLowPriorityTasks(low_priority_task_padding_size));
}

namespace internal {
@@ -1135,7 +1164,23 @@ Queue<TaskType>::ScheduleBatch() {
}

template <typename TaskType>
void Queue<TaskType>::ProcessBatch(std::unique_ptr<Batch<TaskType>> batch) {
std::vector<std::unique_ptr<TaskType>> Queue<TaskType>::GetLowPriorityTasks(
size_t size) {
std::vector<std::unique_ptr<TaskType>> low_priority_tasks_to_pad;
// If the priority queue is not enabled, return immediately instead of
// attempting to acquire the lock.
if (!options_.enable_priority_queue) return low_priority_tasks_to_pad;
{
mutex_lock l(mu_);
low_priority_tasks_to_pad = GetLowPriorityTaskQueue().RemoveTask(size);
}
return low_priority_tasks_to_pad;
}

template <typename TaskType>
void Queue<TaskType>::ProcessBatch(
std::unique_ptr<Batch<TaskType>> batch,
std::vector<std::unique_ptr<TaskType>> padding_task) {
profiler::TraceMeConsumer trace_me(
[&] {
return profiler::TraceMeEncode(
@@ -1144,7 +1189,15 @@ void Queue<TaskType>::ProcessBatch(std::unique_ptr<Batch<TaskType>> batch) {
},
profiler::ContextType::kSharedBatchScheduler,
batch->traceme_context_id());
process_batch_callback_(std::move(batch));

if (std::holds_alternative<ProcessBatchCallbackWithoutPaddingTasks>(
process_batch_callback_)) {
std::get<ProcessBatchCallbackWithoutPaddingTasks>(process_batch_callback_)(
std::move(batch));
} else {
std::get<ProcessBatchCallbackWithPaddingTasks>(process_batch_callback_)(
std::move(batch), std::move(padding_task));
}

{
mutex_lock l(mu_);
@@ -1267,6 +1320,11 @@ Queue<TaskType>::GetBatches() const {
return high_priority_batches_;
}

template <typename TaskType>
TaskQueue<TaskType>& Queue<TaskType>::GetLowPriorityTaskQueue() {
return low_priority_tasks_;
}

template <typename TaskType>
QueueHandle<TaskType>::QueueHandle(
std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler,
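
Putting the header changes together: ThreadLogic computes a padding budget of max_execution_batch_size() minus the scheduled batch's size (e.g. a budget of 4 when a batch of size 6 runs with max_execution_batch_size = 10), drains up to that many size units from the low priority queue, and ProcessBatch dispatches on whichever callback alternative the queue was created with. A standalone sketch of that dispatch, using stand-in Batch and Task template parameters rather than the real TensorFlow classes:

#include <functional>
#include <memory>
#include <variant>
#include <vector>

// Standalone sketch of the variant dispatch in Queue::ProcessBatch.
template <typename Batch, typename Task>
void DispatchCallback(
    std::variant<std::function<void(std::unique_ptr<Batch>)>,
                 std::function<void(std::unique_ptr<Batch>,
                                    std::vector<std::unique_ptr<Task>>)>>&
        callback,
    std::unique_ptr<Batch> batch,
    std::vector<std::unique_ptr<Task>> padding_tasks) {
  using WithoutPadding = std::function<void(std::unique_ptr<Batch>)>;
  if (std::holds_alternative<WithoutPadding>(callback)) {
    // Legacy single-argument callback: padding tasks are simply not forwarded.
    std::get<WithoutPadding>(callback)(std::move(batch));
  } else {
    std::get<1>(callback)(std::move(batch), std::move(padding_tasks));
  }
}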
113 changes: 110 additions & 3 deletions tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc
@@ -20,10 +20,12 @@ limitations under the License.
#include <thread> // NOLINT(build/c++11)
#include <tuple>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "absl/container/fixed_array.h"
#include "absl/time/time.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -121,14 +123,16 @@ QueueOptions CreateQueueOptions(size_t max_execution_batch_size,
size_t batch_timeout_micros,
size_t max_enqueued_batches,
bool enable_large_batch_splitting,
bool enable_lazy_split, SplitFunc split_func) {
bool enable_lazy_split, SplitFunc split_func,
bool enable_priority_queue = false) {
QueueOptions queue_options;
queue_options.max_enqueued_batches = max_enqueued_batches;
queue_options.max_execution_batch_size = max_execution_batch_size;
queue_options.input_batch_size_limit = input_batch_size_limit;
queue_options.batch_timeout_micros = batch_timeout_micros;
queue_options.enable_large_batch_splitting = enable_large_batch_splitting;
queue_options.enable_lazy_split = enable_lazy_split;
queue_options.enable_priority_queue = enable_priority_queue;
if (enable_large_batch_splitting) {
queue_options.split_input_task_func = split_func;
}
@@ -141,11 +145,12 @@ class SharedBatchSchedulerTest
QueueOptions CreateQueueOptions(size_t max_execution_batch_size,
size_t input_batch_size_limit,
size_t batch_timeout_micros,
size_t max_enqueued_batches) {
size_t max_enqueued_batches,
bool enable_priority_queue = false) {
return tensorflow::serving::CreateQueueOptions(
max_execution_batch_size, input_batch_size_limit, batch_timeout_micros,
max_enqueued_batches, enable_input_batch_split(), enable_lazy_split(),
get_split_func());
get_split_func(), enable_priority_queue);
}
bool enable_input_batch_split() const { return std::get<0>(GetParam()); }

@@ -242,6 +247,108 @@ TEST_P(SharedBatchSchedulerTest, Basic) {
}
}

TEST_P(SharedBatchSchedulerTest,
CallbackWithTaskVectorOkWithPriorityQueueEnabled) {
bool queue_0_callback_called = false;
auto queue_0_callback = [&queue_0_callback_called](
std::unique_ptr<Batch<FakeTask>> batch,
std::vector<std::unique_ptr<FakeTask>> tasks) {
queue_0_callback_called = true;
ASSERT_TRUE(batch->IsClosed());
ASSERT_EQ(3, batch->num_tasks());
EXPECT_EQ(1, batch->task(0).size());
EXPECT_EQ(3, batch->task(1).size());
EXPECT_EQ(5, batch->task(2).size());
EXPECT_EQ(0, tasks.size());
};
bool queue_1_callback_called = false;
auto queue_1_callback = [&queue_1_callback_called](
std::unique_ptr<Batch<FakeTask>> batch,
std::vector<std::unique_ptr<FakeTask>> tasks) {
queue_1_callback_called = true;
ASSERT_TRUE(batch->IsClosed());
ASSERT_EQ(2, batch->num_tasks());
EXPECT_EQ(2, batch->task(0).size());
EXPECT_EQ(4, batch->task(1).size());
EXPECT_EQ(0, tasks.size());
};
{
std::shared_ptr<Scheduler> scheduler =
CreateSharedBatchScheduler(/*num_batch_threads=*/3);

// Create two queues.
const QueueOptions queue_options = CreateQueueOptions(
/*max_execution_batch_size=*/10, /*input_batch_size_limit=*/10,
/*batch_timeout_micros=*/1 * 1000 * 1000, /*max_enqueued_batches=*/2,
/*enable_priority_queue=*/true);
std::unique_ptr<Queue> queue_0 =
CreateQueue(scheduler, queue_options, queue_0_callback);
std::unique_ptr<Queue> queue_1 =
CreateQueue(scheduler, queue_options, queue_1_callback);

// Submit tasks to the two queues.
TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
}
EXPECT_TRUE(queue_0_callback_called);
EXPECT_TRUE(queue_1_callback_called);
}

// For now there shouldn't be much difference from the enabled case above,
// since nothing currently inserts tasks into the low priority task queue.
TEST_P(SharedBatchSchedulerTest,
CallbackWithTaskVectorOkWithPriorityQueueDisabled) {
bool queue_0_callback_called = false;
auto queue_0_callback = [&queue_0_callback_called](
std::unique_ptr<Batch<FakeTask>> batch,
std::vector<std::unique_ptr<FakeTask>> tasks) {
queue_0_callback_called = true;
ASSERT_TRUE(batch->IsClosed());
ASSERT_EQ(3, batch->num_tasks());
EXPECT_EQ(1, batch->task(0).size());
EXPECT_EQ(3, batch->task(1).size());
EXPECT_EQ(5, batch->task(2).size());
EXPECT_EQ(0, tasks.size());
};
bool queue_1_callback_called = false;
auto queue_1_callback = [&queue_1_callback_called](
std::unique_ptr<Batch<FakeTask>> batch,
std::vector<std::unique_ptr<FakeTask>> tasks) {
queue_1_callback_called = true;
ASSERT_TRUE(batch->IsClosed());
ASSERT_EQ(2, batch->num_tasks());
EXPECT_EQ(2, batch->task(0).size());
EXPECT_EQ(4, batch->task(1).size());
EXPECT_EQ(0, tasks.size());
};
{
std::shared_ptr<Scheduler> scheduler =
CreateSharedBatchScheduler(/*num_batch_threads=*/3);

// Create two queues.
const QueueOptions queue_options = CreateQueueOptions(
/*max_execution_batch_size=*/10, /*input_batch_size_limit=*/10,
/*batch_timeout_micros=*/1 * 1000 * 1000, /*max_enqueued_batches=*/2,
/*enable_priority_queue=*/false);
std::unique_ptr<Queue> queue_0 =
CreateQueue(scheduler, queue_options, queue_0_callback);
std::unique_ptr<Queue> queue_1 =
CreateQueue(scheduler, queue_options, queue_1_callback);

// Submit tasks to the two queues.
TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
}
EXPECT_TRUE(queue_0_callback_called);
EXPECT_TRUE(queue_1_callback_called);
}

TEST_P(SharedBatchSchedulerTest, ObeyBatchSizeConstraint) {
// Set up a fake clock, which only advances when we explicitly tell it to.
test_util::FakeClockEnv env(Env::Default());
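
Both new tests expect an empty task vector because nothing yet enqueues into low_priority_tasks_; the padding path only becomes observable once low priority tasks are admitted. For reference, the RemoveTask(size) call inside GetLowPriorityTasks presumably pops queued tasks while their cumulative size fits the budget. A sketch of those assumed semantics (the real implementation lives in batch_scheduler.h, which the test's BUILD target now depends on):

#include <cstddef>
#include <deque>
#include <memory>
#include <vector>

// Assumed semantics of TaskQueue::RemoveTask(size), sketched for clarity only.
// TaskType is anything exposing a size() method, per the BatchTask interface.
template <typename TaskType>
std::vector<std::unique_ptr<TaskType>> RemoveUpTo(
    std::deque<std::unique_ptr<TaskType>>& queue, size_t size_budget) {
  std::vector<std::unique_ptr<TaskType>> removed;
  size_t used = 0;
  // Pop from the front while the next task still fits in the remaining budget.
  while (!queue.empty() && used + queue.front()->size() <= size_budget) {
    used += queue.front()->size();
    removed.push_back(std::move(queue.front()));
    queue.pop_front();
  }
  return removed;
}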