Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce the batch_stats module. #68601

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tensorflow/core/kernels/batching_util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ cc_library(
],
)

# Header-only library target: it exposes batch_stats.h and lists no srcs.
# The deps below correspond to the headers that batch_stats.h includes
# (TensorFlow platform headers, absl node_hash_map, absl check macros and
# absl time).
cc_library(
name = "batch_stats",
hdrs = ["batch_stats.h"],
deps = [
"//tensorflow/core:framework_lite",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/time",
],
)

# Unit test target for the :batch_stats library, built from
# batch_stats_test.cc. gtest_main is listed as a dep and the test source
# defines no main() of its own.
tf_cc_test(
name = "batch_stats_test",
srcs = ["batch_stats_test.cc"],
deps = [
":batch_stats",
"//tensorflow/core:test",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest_main",
],
)

cc_library(
name = "batch_input_task",
hdrs = ["batch_input_task.h"],
Expand Down
202 changes: 202 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_stats.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The API for reporting and querying batch statistics such as the average batch
// costs for in-process use.
//
// All these statistics can also be retrieved from metrics reported by various
// modules (e.g., batch_resource_base), but it would be slow. This API, on the
// other hand, was designed to be queried on every request.
//
// The classes defined here are not supposed to be instantiated by the user.
// Instead, this file provides a single entry point:
//
// BatchStats& GlobalBatchStats();
//
// For example, to register batch cost, do:
//
//   GlobalBatchStats()
//       .model(/* model_name= */ "m", /* op_name= */ "o")
//       .batch_size(4)
//       .tpu_cost()
//       .Register(cost);
//
// To get the mean cost later, do:
//
//   std::optional<absl::Duration> cost =
//       GlobalBatchStats()
//           .model(/* model_name= */ "m", /* op_name= */ "o")
//           .batch_size(4)
//           .tpu_cost()
//           .mean();
//
// It is allowed and safe to store references to intermediate objects here
// because all intermediate objects are guaranteed to never be destroyed.
//
// All operations supported by this API are thread-safe.

#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_STATS_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_STATS_H_

#include <cstdint>
#include <optional>
#include <string>
#include <tuple>
#include <utility>

#include "absl/container/node_hash_map.h"
#include "absl/log/check.h"
#include "absl/time/time.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tsl/platform/thread_annotations.h"

namespace tensorflow::serving {

// Tracks the average cost of registered samples.
//
// Thread-safe.
class CostTracker {
public:
// Registers a cost sample.
void Register(absl::Duration cost) {
DCHECK_GT(cost, absl::ZeroDuration());

mutex_lock l(mu_);
sample_count_++;
sample_sum_ += cost;
};

// Returns the average cost of all registered samples, giving each sample
// the same weight.
//
// Returns std::nullopt if no samples have been registered.
//
// TODO: b/325954758 - Switch this to an exponentially-decaying average. It's
// likely enough to set the half-life to the last 100-1000 samples.
std::optional<absl::Duration> mean() const {
int64_t count;
absl::Duration sum;

{
// We only hold the lock to read the values and release it before later
// performing a relatively slow division operation.
mutex_lock l(mu_);
count = sample_count_;
sum = sample_sum_;
}

if (count == 0) return std::nullopt;

return sum / count;
};

private:
mutable mutex mu_;

int64_t sample_count_ TF_GUARDED_BY(mu_) = 0;
absl::Duration sample_sum_ TF_GUARDED_BY(mu_);
};

// Holds the statistics gathered for one particular model and batch size.
//
// Thread-safe.
class BatchSizeStats {
 public:
  // Gives mutable access to the cost tracker stored in `tpu_cost_`. The
  // reference refers to a member of this object, so it is valid for as long
  // as 'this' is alive.
  CostTracker& tpu_cost() { return tpu_cost_; }

 private:
  CostTracker tpu_cost_;
};

// Holds the statistics gathered for one particular model.
//
// Here, "model" means a specific version of a model (we assume that the
// version is encoded in the op_name). In rare cases, when a model version has
// multiple BatchFunction operations, each such operation is also treated as a
// separate model in this context (they should have different op_names too).
//
// Thread-safe.
class ModelBatchStats {
 public:
  // Looks up the BatchSizeStats instance for `batch_size`, creating an empty
  // one on first use.
  //
  // The returned reference persists for as long as 'this' is alive.
  BatchSizeStats& batch_size(int32 batch_size) {
    mutex_lock lock(mu_);
    // try_emplace() leaves an existing entry untouched and otherwise
    // default-constructs the value in place; either way `.first` points at
    // the entry for `batch_size`.
    return batch_size_stats_by_batch_size_.try_emplace(batch_size)
        .first->second;
  }

 private:
  mutable mutex mu_;

  // Owns every BatchSizeStats instance, keyed by batch size.
  //
  // The mutex only guards finding and adding map entries. Access to the
  // entries themselves (once created) is not protected here. Entries are
  // never erased, because references to them are handed out and their
  // lifetime is not tracked. A node hash map is used so that entries, once
  // created, stay fixed in memory.
  absl::node_hash_map<int32, BatchSizeStats> batch_size_stats_by_batch_size_
      TF_GUARDED_BY(mu_);
};

// Tracks batch statistics for all models.
//
// Thread-safe.
class BatchStats {
 public:
  // Returns a reference to the ModelBatchStats for the provided model_name and
  // op_name.
  //
  // Upon invocation with not-yet-seen arguments, creates an empty
  // ModelBatchStats instance.
  //
  // The returned reference persists for as long as 'this' is alive.
  ModelBatchStats& model(const std::string& model_name,
                         const std::string& op_name) {
    // The key (and its two string copies) is built before the lock is taken
    // so that the copying happens outside the critical section.
    std::tuple key(model_name, op_name);
    mutex_lock l(mu_);
    // Hand the key over as an rvalue: when the entry does not exist yet the
    // map can take ownership of the strings already copied above instead of
    // copying them a second time.
    return model_batch_stats_by_model_and_op_names_[std::move(key)];
  }

  // TODO: b/325954758 - Add a public method for scanning
  // model_batch_stats_by_model_and_op_names_ and mention that it will always
  // return elements in the same order.

 private:
  mutable mutex mu_;

  // The storage of all ModelBatchStats instances, keyed by
  // (model_name, op_name).
  //
  // The mutex only protects adding/finding elements in the map. Access to the
  // elements themselves (after they were created) is not protected here. No
  // element deletion is possible because we return references to items in
  // this map and don't track their lifetime. We are using the node hash map
  // for element pointer stability.
  absl::node_hash_map<std::tuple<std::string, std::string>, ModelBatchStats>
      model_batch_stats_by_model_and_op_names_ TF_GUARDED_BY(mu_);
};

// Returns the process-wide BatchStats instance, to be used for all production
// purposes (the individual classes from this file should only be instantiated
// directly in order to test them).
//
// The instance is heap-allocated on first use and never deleted.
inline BatchStats& GlobalBatchStats() {
  static BatchStats& instance = *new BatchStats();
  return instance;
}

} // namespace tensorflow::serving

#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_STATS_H_
69 changes: 69 additions & 0 deletions tensorflow/core/kernels/batching_util/batch_stats_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batching_util/batch_stats.h"

#include <gtest/gtest.h>
#include "absl/time/time.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow::serving {
namespace {

// Two consecutive calls to GlobalBatchStats() must yield the same object.
TEST(BatchStatsTest, GlobalBatchStatsAlwaysReturnsTheSameInstance) {
  BatchStats* const first_call = &GlobalBatchStats();
  BatchStats* const second_call = &GlobalBatchStats();

  ASSERT_EQ(first_call, second_call);
}

// Registers one sample through the full accessor chain and reads it back
// through a second, independent walk of the same chain.
TEST(BatchStatsTest, BasicOperation) {
  BatchStats stats;

  CostTracker& tracker =
      stats.model(/* model_name= */ "m", /* op_name= */ "o")
          .batch_size(1)
          .tpu_cost();
  tracker.Register(absl::Hours(5));

  const auto observed_mean =
      stats.model(/* model_name= */ "m", /* op_name= */ "o")
          .batch_size(1)
          .tpu_cost()
          .mean();
  ASSERT_EQ(observed_mean, absl::Hours(5));
}

// Keys that share a model_name but differ in op_name must map to distinct
// ModelBatchStats objects.
TEST(BatchStatsTest, ModelBatchStatsAreUniqueForEachModel) {
  BatchStats stats;

  ModelBatchStats* const first =
      &stats.model(/* model_name= */ "m", /* op_name= */ "o");
  ModelBatchStats* const second =
      &stats.model(/* model_name= */ "m", /* op_name= */ "o2");

  ASSERT_NE(first, second);
}

// Different batch sizes must map to distinct BatchSizeStats objects.
TEST(BatchStatsTest, BatchSizeStatsAreUniqueForEachBatchSize) {
  ModelBatchStats stats;

  BatchSizeStats* const size_one = &stats.batch_size(1);
  BatchSizeStats* const size_two = &stats.batch_size(2);

  ASSERT_NE(size_one, size_two);
}

// A tracker with no registered samples reports no mean.
TEST(BatchStatsTest, CostTrackerStartsWithNoMean) {
  CostTracker tracker;

  const auto initial_mean = tracker.mean();

  ASSERT_FALSE(initial_mean.has_value());
}

// The mean of two registered samples is their arithmetic average.
TEST(BatchStatsTest, CostTrackerMeanIsCorrect) {
  CostTracker tracker;
  tracker.Register(absl::Hours(5));
  tracker.Register(absl::Hours(7));

  // Compare the optional itself rather than dereferencing it: if mean()
  // unexpectedly returned std::nullopt, the assertion fails cleanly instead
  // of reading an empty optional. This also matches how BasicOperation checks
  // the mean.
  ASSERT_EQ(tracker.mean(), absl::Hours(6));
}

} // namespace

} // namespace tensorflow::serving
Loading