Introduce the batch_stats module.
PiperOrigin-RevId: 630398218
tensorflower-gardener committed May 24, 2024
1 parent 3068334 commit 6c36398
Showing 8 changed files with 682 additions and 3 deletions.
33 changes: 33 additions & 0 deletions tensorflow/core/kernels/batching_util/BUILD
@@ -32,6 +32,30 @@ cc_library(
],
)

cc_library(
name = "batch_stats",
hdrs = ["batch_stats.h"],
deps = [
"//tensorflow/core:framework_lite",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/time",
],
)

tf_cc_test(
name = "batch_stats_test",
srcs = ["batch_stats_test.cc"],
deps = [
":batch_stats",
"//tensorflow/core:test",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "batch_input_task",
hdrs = ["batch_input_task.h"],
@@ -100,6 +124,7 @@ cc_library(
"//tensorflow/core/lib/core:status",
"//tensorflow/core/platform:thread_annotations",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@local_tsl//tsl/platform:criticality",
@@ -113,6 +138,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
@@ -126,7 +152,12 @@ cc_library(
srcs = ["batch_scheduler_utils.cc"],
hdrs = ["batch_scheduler_utils.h"],
deps = [
":batch_scheduler_hdrs",
"//tensorflow/core:portable_gif_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
],
)

@@ -160,7 +191,9 @@ tf_cc_test(
name = "batch_scheduler_utils_test",
srcs = ["batch_scheduler_utils_test.cc"],
deps = [
":batch_scheduler_hdrs",
":batch_scheduler_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest_main",
],
)
49 changes: 47 additions & 2 deletions tensorflow/core/kernels/batching_util/batch_scheduler.h
@@ -32,18 +32,18 @@ limitations under the License.
#include <atomic>
#include <cstddef>
#include <deque>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@@ -304,6 +304,15 @@ class Batch {
// Returns the TraceMe context id of this batch.
uint64 traceme_context_id() const;

// Attempts to trim this batch to a new, smaller size (not to be confused with
// the number of tasks in the batch). On success, the trimmed tasks go into
// 'out_trimmed_tasks' in the same order the tasks were in this batch.
//
// The method might not succeed if it needs to split a large task to hit the
// correct size.
void TryTrimToNewSize(
int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks);

private:
mutable mutex mu_;

@@ -505,6 +514,42 @@ uint64 Batch<TaskType>::traceme_context_id() const {
return traceme_context_id_;
}

template <typename TaskType>
void Batch<TaskType>::TryTrimToNewSize(
int new_size, std::vector<std::unique_ptr<TaskType>>& out_trimmed_tasks) {
mutex_lock l(mu_);
DCHECK_GT(new_size, 0);
DCHECK_LT(new_size, size_);
DCHECK(out_trimmed_tasks.empty());

// Index of the first task to trim away. It may point at a task of size
// larger than 1 that would have to be split in order to reach exactly
// new_size.
int32 first_task_to_move = 0;
// The sum of sizes of tasks i, where i < first_task_to_move.
int32 size_of_previous_tasks = 0;
while (size_of_previous_tasks + tasks_[first_task_to_move]->size() <=
new_size) {
size_of_previous_tasks += tasks_[first_task_to_move]->size();
first_task_to_move++;
}

// Check whether task 'first_task_to_move' will have to be split.
if (size_of_previous_tasks < new_size) {
// TODO: b/325954758 - Consider supporting splitting large tasks and then
// drop 'Try' from the method name.
return;
}
DCHECK_EQ(size_of_previous_tasks, new_size);

// Actually trim.
out_trimmed_tasks.reserve(tasks_.size() - first_task_to_move);
std::move(tasks_.begin() + first_task_to_move, tasks_.end(),
std::back_inserter(out_trimmed_tasks));
tasks_.resize(first_task_to_move);
size_ = new_size;
}

} // namespace serving
} // namespace tensorflow

51 changes: 50 additions & 1 deletion tensorflow/core/kernels/batching_util/batch_scheduler_test.cc
@@ -21,12 +21,13 @@ limitations under the License.
#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tsl/platform/criticality.h"
@@ -37,6 +38,7 @@ namespace {

using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Pointer;
using ::testing::Property;

TEST(MixedPriorityBatchingPolicyTest, InvalidAttrValueError) {
@@ -386,6 +388,53 @@ TEST(BatchTest, RemoveAllTasks) {
EXPECT_THAT(batch.RemoveAllTasks(), ::testing::IsEmpty()); // third call
}

TEST(BatchTest, TryTrimToNewSizeTrimsAndReturnsTrimmedElementsInOrder) {
Batch<FakeTask> batch;

auto task0 = new FakeTask(3);
batch.AddTask(std::unique_ptr<FakeTask>(task0));

auto task1 = new FakeTask(5);
batch.AddTask(std::unique_ptr<FakeTask>(task1));

auto task2 = new FakeTask(7);
batch.AddTask(std::unique_ptr<FakeTask>(task2));

auto task3 = new FakeTask(9);
batch.AddTask(std::unique_ptr<FakeTask>(task3));

std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
batch.TryTrimToNewSize(/* new_size= */ 8,
/* out_trimmed_tasks= */ trimmed_tasks);

EXPECT_EQ(batch.size(), 8);
EXPECT_EQ(batch.num_tasks(), 2);

EXPECT_THAT(trimmed_tasks, ElementsAre(Pointer(task2), Pointer(task3)));

batch.Close(); // Batch::~Batch blocks until the batch is closed.
}

TEST(BatchTest, TryTrimToNewSizeDoesNotTrimWhenItWouldNeedToSplitATask) {
Batch<FakeTask> batch;

auto task0 = new FakeTask(3);
batch.AddTask(std::unique_ptr<FakeTask>(task0));

auto task1 = new FakeTask(5);
batch.AddTask(std::unique_ptr<FakeTask>(task1));

std::vector<std::unique_ptr<FakeTask>> trimmed_tasks;
batch.TryTrimToNewSize(/* new_size= */ 4,
/* out_trimmed_tasks= */ trimmed_tasks);

EXPECT_EQ(batch.size(), 8);
EXPECT_EQ(batch.num_tasks(), 2);
EXPECT_TRUE(trimmed_tasks.empty());

batch.Close(); // Batch::~Batch blocks until the batch is closed.
}

} // namespace
} // namespace serving
} // namespace tensorflow
72 changes: 72 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_scheduler_utils.cc
@@ -15,11 +15,30 @@ limitations under the License.

#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"

#include <algorithm>
#include <string>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/flags/flag.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

ABSL_FLAG(tensorflow::serving::BatchPaddingPolicy,
tensorflow_batch_padding_policy,
tensorflow::serving::BatchPaddingPolicy::kPadUp,
"The policy that a batch schduler is using when deciding what to do "
"when, say, 18 requests need to be batched, but only 16 and 32 batch "
"sizes are allowed. The following options are available. PAD_UP: pad "
"to size 32. BATCH_DOWN: schedule a batch of size 16 and leave 2 "
"requests in the batch buffer. MINIMIZE_TPU_COST_PER_REQUEST: a "
"smarter greedy policy that chooses to either PAD_UP or BATCH_DOWN "
"so as to minimize the TPU costs per real request. In this case, it "
"would compare (batch_16_cost / 16) and (batch_32_cost / 18). "
"WARNING: not all batch schedulers might support this option.");

namespace tensorflow {
namespace serving {

@@ -40,5 +59,58 @@ int GetNextAllowedBatchSize(int batch_size,
return batch_size;
}

int32 GetPrevAllowedBatchSize(int batch_size,
const std::vector<int32>& allowed_batch_sizes,
bool disable_padding) {
if (disable_padding || allowed_batch_sizes.empty()) {
return batch_size;
}

DCHECK(absl::c_is_sorted(allowed_batch_sizes));
DCHECK_GT(batch_size, 0);

// Find the first allowed batch size from the end (i.e., the largest one)
// that is not larger than batch_size.
auto result = std::find_if(
allowed_batch_sizes.rbegin(), allowed_batch_sizes.rend(),
[&](int allowed_size) { return allowed_size <= batch_size; });

if (result == allowed_batch_sizes.rend()) {
// No such element exists.
return batch_size;
}

return *result;
}

bool AbslParseFlag(absl::string_view text, BatchPaddingPolicy* out,
std::string* error) {
if (text == "PAD_UP") {
*out = BatchPaddingPolicy::kPadUp;
return true;
}
if (text == "BATCH_DOWN") {
*out = BatchPaddingPolicy::kBatchDown;
return true;
}
if (text == "MINIMIZE_TPU_COST_PER_REQUEST") {
*out = BatchPaddingPolicy::kMinimizeTpuCostPerRequest;
return true;
}
*error = "unrecognized batching policy string";
return false;
}

string AbslUnparseFlag(BatchPaddingPolicy in) {
switch (in) {
case BatchPaddingPolicy::kPadUp:
return "PAD_UP";
case BatchPaddingPolicy::kBatchDown:
return "BATCH_DOWN";
case BatchPaddingPolicy::kMinimizeTpuCostPerRequest:
return "MINIMIZE_TPU_COST_PER_REQUEST";
}
CHECK(FATAL) << "Unrecognized BatchPaddingPolicy enum value."; // Crash OK
}

} // namespace serving
} // namespace tensorflow
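A few concrete calls, given here as a sketch rather than as part of the commit, illustrate the contract of the new GetPrevAllowedBatchSize; the values follow directly from the implementation above (allowed_batch_sizes must be sorted, as the DCHECK requires):

const std::vector<int32> allowed = {16, 32};
CHECK_EQ(GetPrevAllowedBatchSize(18, allowed, /* disable_padding= */ false), 16);
CHECK_EQ(GetPrevAllowedBatchSize(16, allowed, /* disable_padding= */ false), 16);
CHECK_EQ(GetPrevAllowedBatchSize(7, allowed, /* disable_padding= */ false), 7);   // No allowed size <= 7.
CHECK_EQ(GetPrevAllowedBatchSize(18, {}, /* disable_padding= */ false), 18);      // No allowed sizes at all.
CHECK_EQ(GetPrevAllowedBatchSize(18, allowed, /* disable_padding= */ true), 18);  // Padding disabled.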
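The worked example in the flag description (18 requests with allowed sizes 16 and 32) can be written out as the comparison a MINIMIZE_TPU_COST_PER_REQUEST scheduler would make. The sketch below is hypothetical: this change adds the flag and the batch-size helpers, and the cost estimator here is only an assumed callback, not an API from this commit:

// Hypothetical sketch: pick PAD_UP or BATCH_DOWN by comparing estimated TPU
// cost per real (unpadded) request.
BatchPaddingPolicy ChoosePolicyForBatch(
    int num_requests, const std::vector<int32>& allowed_batch_sizes,
    const std::function<double(int)>& estimated_batch_cost) {
  const int pad_up_size = GetNextAllowedBatchSize(
      num_requests, allowed_batch_sizes, /* disable_padding= */ false);
  const int batch_down_size = GetPrevAllowedBatchSize(
      num_requests, allowed_batch_sizes, /* disable_padding= */ false);
  // For 18 requests and {16, 32}: compare cost(32) / 18 with cost(16) / 16.
  const double pad_up_cost_per_request =
      estimated_batch_cost(pad_up_size) / num_requests;
  const double batch_down_cost_per_request =
      estimated_batch_cost(batch_down_size) / batch_down_size;
  return batch_down_cost_per_request < pad_up_cost_per_request
             ? BatchPaddingPolicy::kBatchDown
             : BatchPaddingPolicy::kPadUp;
}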
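Because AbslParseFlag and AbslUnparseFlag are defined for BatchPaddingPolicy, the enum works with the standard Abseil flag machinery. A short sketch, assuming the generic absl::ParseFlag/absl::UnparseFlag helpers from absl/flags/marshalling.h and the usual ABSL_DECLARE_FLAG in any other translation unit that reads the flag:

std::string error;
BatchPaddingPolicy policy = BatchPaddingPolicy::kPadUp;
// Dispatches to the AbslParseFlag overload defined above.
if (!absl::ParseFlag("BATCH_DOWN", &policy, &error)) {
  LOG(ERROR) << "Bad padding policy: " << error;
}
CHECK_EQ(absl::UnparseFlag(policy), "BATCH_DOWN");

// Reading the command-line flag itself, e.g. after passing
// --tensorflow_batch_padding_policy=MINIMIZE_TPU_COST_PER_REQUEST:
BatchPaddingPolicy from_flag =
    absl::GetFlag(FLAGS_tensorflow_batch_padding_policy);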