Skip to content

Commit

Permalink
Adds functionality to send TSL metrics over model_service RPC.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 613337350
  • Loading branch information
CMinge authored and tensorflow-copybara committed Mar 6, 2024
1 parent 72acbaf commit 9564ef6
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tensorflow_serving/apis/model_management.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,17 @@ option cc_enable_arenas = true;

message ReloadConfigRequest {
ModelServerConfig config = 1;
repeated string metric_names = 2;
}

message ReloadConfigResponse {
StatusProto status = 1;
repeated Metric metric = 2;
}

message Metric {
string name = 1;
oneof value_increase {
int64 int64_value_increase = 2;
}
}
2 changes: 2 additions & 0 deletions tensorflow_serving/model_servers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ cc_library(
"//tensorflow_serving/apis:model_management_cc_proto",
"//tensorflow_serving/apis:model_service_cc_proto",
"//tensorflow_serving/util:status_util",
"@com_google_absl//absl/container:flat_hash_map",
"@org_tensorflow//tensorflow/core:lib",
],
)

Expand Down
50 changes: 50 additions & 0 deletions tensorflow_serving/model_servers/model_service_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ limitations under the License.

#include "tensorflow_serving/model_servers/model_service_impl.h"

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tsl/lib/monitoring/collected_metrics.h"
#include "tsl/lib/monitoring/collection_registry.h"
#include "tensorflow_serving/model_servers/get_model_status_impl.h"
#include "tensorflow_serving/model_servers/grpc_status_util.h"
#include "tensorflow_serving/util/status_util.h"
Expand All @@ -38,6 +46,8 @@ ::grpc::Status ModelServiceImpl::HandleReloadConfigRequest(
ReloadConfigResponse *response) {
ModelServerConfig server_config = request->config();
Status status;
const absl::flat_hash_map<std::string, int64_t> old_metric_values =
GetMetrics(request);
switch (server_config.config_case()) {
case ModelServerConfig::kModelConfigList: {
const ModelConfigList list = server_config.model_config_list();
Expand All @@ -62,11 +72,51 @@ ::grpc::Status ModelServiceImpl::HandleReloadConfigRequest(
if (!status.ok()) {
LOG(ERROR) << "ReloadConfig failed: " << status.message();
}
const absl::flat_hash_map<std::string, int64_t> new_metric_values =
GetMetrics(request);
RecordMetricsIncrease(old_metric_values, new_metric_values, response);

const StatusProto status_proto = ToStatusProto(status);
*response->mutable_status() = status_proto;
return ToGRPCStatus(status);
}

absl::flat_hash_map<std::string, int64_t> ModelServiceImpl::GetMetrics(
const ReloadConfigRequest *request) {
absl::flat_hash_map<std::string, int64_t> metric_values = {};
const tsl::monitoring::CollectionRegistry::CollectMetricsOptions options;
tsl::monitoring::CollectionRegistry *collection_registry =
tsl::monitoring::CollectionRegistry::Default();
std::unique_ptr<tsl::monitoring::CollectedMetrics> collected_metrics =
collection_registry->CollectMetrics(options);

for (const std::string &metric_name : request->metric_names()) {
int64_t metric_value = 0;
if (collected_metrics->point_set_map.contains(metric_name)) {
std::vector<std::unique_ptr<tsl::monitoring::Point>> *points =
&collected_metrics->point_set_map[metric_name]->points;
if (!points->empty()) {
metric_value = (*points)[0]->int64_value;
}
}
metric_values.insert({metric_name, metric_value});
}
return metric_values;
}

void ModelServiceImpl::RecordMetricsIncrease(
const absl::flat_hash_map<std::string, int64_t> &old_metric_values,
const absl::flat_hash_map<std::string, int64_t> &new_metric_values,
ReloadConfigResponse *response) {
for (const auto &[metric_name, metric_value] : new_metric_values) {
Metric metric;
metric.set_name(metric_name);
int64_t old_metric_value = old_metric_values.contains(metric_name)
? old_metric_values.at(metric_name)
: 0;
metric.set_int64_value_increase(metric_value - old_metric_value);
*response->add_metric() = metric;
}
}
} // namespace serving
} // namespace tensorflow
11 changes: 11 additions & 0 deletions tensorflow_serving/model_servers/model_service_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ class ModelServiceImpl final : public ModelService::Service {

private:
ServerCore *core_;

// Obtains values for metrics provided in request.
absl::flat_hash_map<std::string, int64_t> GetMetrics(
const ReloadConfigRequest *request);

// Compares old_metric_values and new_metric_values, storing the increases in
// response
void RecordMetricsIncrease(
const absl::flat_hash_map<std::string, int64_t> &old_metric_values,
const absl::flat_hash_map<std::string, int64_t> &new_metric_values,
ReloadConfigResponse *response);
};

} // namespace serving
Expand Down

0 comments on commit 9564ef6

Please sign in to comment.