Skip to content

Commit

Permalink
Add some new fields to xla::CompiledMemoryStats to show host memory…
Browse files Browse the repository at this point in the history
… usage stats. Also update PjRt C API.

PiperOrigin-RevId: 610027375
  • Loading branch information
yueshengys authored and tensorflower-gardener committed Feb 24, 2024
1 parent 86d283d commit 5b29cd2
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 11 deletions.
2 changes: 0 additions & 2 deletions third_party/xla/xla/pjrt/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ cc_library(
"//xla:literal",
"//xla:shape_util",
"//xla:status",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
Expand Down Expand Up @@ -108,7 +107,6 @@ cc_library(
":pjrt_c_api_hdrs",
"//xla:shape_util",
"//xla:status",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/pjrt:pjrt_client",
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/pjrt/c/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# PJRT C API changelog

## 0.43
* Added some new fields to PJRT_Executable_GetCompiledMemoryStats

## 0.42
* Renamed all ``priv`` fields to ``extension_start``

Expand Down
17 changes: 13 additions & 4 deletions third_party/xla/xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ extern "C" {
// Changes include:
// * Adding a new field to the PJRT_Api or argument structs
// * Renaming a method or argument (doesn't affect ABI)
#define PJRT_API_MINOR 42
#define PJRT_API_MINOR 43

// The plugin should set the major_version and minor_version of
// PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in
Expand Down Expand Up @@ -1382,18 +1382,27 @@ struct PJRT_Executable_GetCompiledMemoryStats_Args {
PJRT_Executable* executable;

// Mirrors xla::CompiledMemoryStats.
// Device default memory (e.g., HBM for GPU/TPU) usage stats.
int64_t generated_code_size_in_bytes; // out
int64_t argument_size_in_bytes; // out
int64_t output_size_in_bytes; // out
// How much argument is reused for output.
int64_t alias_size_in_bytes; // out
int64_t temp_size_in_bytes; // out

// Host memory usage stats.
int64_t host_generated_code_size_in_bytes; // out
int64_t host_argument_size_in_bytes; // out
int64_t host_output_size_in_bytes; // out
int64_t host_alias_size_in_bytes; // out
int64_t host_temp_size_in_bytes; // out
};
PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_GetCompiledMemoryStats_Args,
temp_size_in_bytes);
host_temp_size_in_bytes);

// Return memory stats that allow callers to estimate device memory usage
// when running this executable.
// Return memory stats that allow callers to estimate memory usage when running
// this executable. The memory stats could contain usage info from different
// memory spaces, like default memory (e.g., HBM for GPU/TPU) and host memory.
typedef PJRT_Error* PJRT_Executable_GetCompiledMemoryStats(
PJRT_Executable_GetCompiledMemoryStats_Args* args);

Expand Down
7 changes: 6 additions & 1 deletion third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ limitations under the License.
#include "xla/primitive_util.h"
#include "xla/shape_util.h"
#include "xla/status.h"
#include "xla/statusor.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -963,6 +962,12 @@ absl::StatusOr<xla::CompiledMemoryStats> GetCompiledMemoryStats(
results.output_size_in_bytes = args.output_size_in_bytes;
results.alias_size_in_bytes = args.alias_size_in_bytes;
results.temp_size_in_bytes = args.temp_size_in_bytes;
results.host_generated_code_size_in_bytes =
args.host_generated_code_size_in_bytes;
results.host_argument_size_in_bytes = args.host_argument_size_in_bytes;
results.host_output_size_in_bytes = args.host_output_size_in_bytes;
results.host_alias_size_in_bytes = args.host_alias_size_in_bytes;
results.host_temp_size_in_bytes = args.host_temp_size_in_bytes;
return results;
}

Expand Down
7 changes: 6 additions & 1 deletion third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status.h"
#include "xla/statusor.h"
#include "xla/util.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
Expand Down Expand Up @@ -1509,6 +1508,12 @@ PJRT_Error* PJRT_Executable_GetCompiledMemoryStats(
args->output_size_in_bytes = memory_stats.output_size_in_bytes;
args->alias_size_in_bytes = memory_stats.alias_size_in_bytes;
args->temp_size_in_bytes = memory_stats.temp_size_in_bytes;
args->host_generated_code_size_in_bytes =
memory_stats.host_generated_code_size_in_bytes;
args->host_argument_size_in_bytes = memory_stats.host_argument_size_in_bytes;
args->host_output_size_in_bytes = memory_stats.host_output_size_in_bytes;
args->host_alias_size_in_bytes = memory_stats.host_alias_size_in_bytes;
args->host_temp_size_in_bytes = memory_stats.host_temp_size_in_bytes;
return nullptr;
}

Expand Down
8 changes: 8 additions & 0 deletions third_party/xla/xla/pjrt/executable_metadata.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@ import "xla/service/hlo.proto";

// Mirror of xla::CompiledMemoryStats.
message CompiledMemoryStatsProto {
// Device default memory (e.g., HBM for GPU/TPU) usage stats.
int64 generated_code_size_in_bytes = 1;
int64 argument_size_in_bytes = 2;
int64 output_size_in_bytes = 3;
int64 alias_size_in_bytes = 4;
int64 temp_size_in_bytes = 5;
xla.HloProto hlo_proto = 6;

// Host memory usage stats.
int64 host_generated_code_size_in_bytes = 7;
int64 host_argument_size_in_bytes = 8;
int64 host_output_size_in_bytes = 9;
int64 host_alias_size_in_bytes = 10;
int64 host_temp_size_in_bytes = 11;
}
12 changes: 10 additions & 2 deletions third_party/xla/xla/pjrt/pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,17 @@ std::string CompiledMemoryStats::DebugString() const {
"argument_size_in_bytes=$1, "
"output_size_in_bytes=$2, "
"alias_size_in_bytes=$3, "
"temp_size_in_bytes=$4)",
"temp_size_in_bytes=$4, "
"host_generated_code_size_in_bytes=$5, "
"host_argument_size_in_bytes=$6, "
"host_output_size_in_bytes=$7, "
"host_alias_size_in_bytes=$8, "
"host_temp_size_in_bytes=$9)",
generated_code_size_in_bytes, argument_size_in_bytes,
output_size_in_bytes, alias_size_in_bytes, temp_size_in_bytes);
output_size_in_bytes, alias_size_in_bytes, temp_size_in_bytes,
host_generated_code_size_in_bytes, host_argument_size_in_bytes,
host_output_size_in_bytes, host_alias_size_in_bytes,
host_temp_size_in_bytes);
}

// Defining the first virtual non-pure method, which is usually the virtual
Expand Down
23 changes: 22 additions & 1 deletion third_party/xla/xla/pjrt/pjrt_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/client/executable_build_options.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/layout.h"
#include "xla/pjrt/compile_options.pb.h"
#include "xla/pjrt/executable_metadata.pb.h"
#include "xla/pjrt/execute_options.pb.h"
Expand Down Expand Up @@ -260,19 +261,27 @@ struct ExecuteOptions {
const ExecuteOptionsProto& proto);
};

// Static device memory usage for a compiled program.
// Static memory usage for a compiled program.
// The on-device memory needed to run an executable is at least
// generated_code_size_in_bytes
// + argument_size_in_bytes + output_size_in_bytes - alias_size_in_bytes
// + temp_size_in_bytes.
struct CompiledMemoryStats {
// Device default memory (e.g., HBM for GPU/TPU) usage stats.
int64_t generated_code_size_in_bytes = 0;
int64_t argument_size_in_bytes = 0;
int64_t output_size_in_bytes = 0;
// How much argument is reused for output.
int64_t alias_size_in_bytes = 0;
int64_t temp_size_in_bytes = 0;

// Host memory usage stats.
int64_t host_generated_code_size_in_bytes = 0;
int64_t host_argument_size_in_bytes = 0;
int64_t host_output_size_in_bytes = 0;
int64_t host_alias_size_in_bytes = 0;
int64_t host_temp_size_in_bytes = 0;

std::string serialized_hlo_proto = "";
std::string DebugString() const;

Expand All @@ -284,6 +293,12 @@ struct CompiledMemoryStats {
proto.set_alias_size_in_bytes(alias_size_in_bytes);
proto.set_temp_size_in_bytes(temp_size_in_bytes);
proto.mutable_hlo_proto()->ParseFromString(serialized_hlo_proto);
proto.set_host_generated_code_size_in_bytes(
host_generated_code_size_in_bytes);
proto.set_host_argument_size_in_bytes(host_argument_size_in_bytes);
proto.set_host_output_size_in_bytes(host_output_size_in_bytes);
proto.set_host_alias_size_in_bytes(host_alias_size_in_bytes);
proto.set_host_temp_size_in_bytes(host_temp_size_in_bytes);
return proto;
}

Expand All @@ -295,6 +310,12 @@ struct CompiledMemoryStats {
stats.alias_size_in_bytes = proto.alias_size_in_bytes();
stats.temp_size_in_bytes = proto.temp_size_in_bytes();
stats.serialized_hlo_proto = proto.hlo_proto().SerializeAsString();
stats.host_generated_code_size_in_bytes =
proto.host_generated_code_size_in_bytes();
stats.host_argument_size_in_bytes = proto.host_argument_size_in_bytes();
stats.host_output_size_in_bytes = proto.host_output_size_in_bytes();
stats.host_alias_size_in_bytes = proto.host_alias_size_in_bytes();
stats.host_temp_size_in_bytes = proto.host_temp_size_in_bytes();
return stats;
}
};
Expand Down

0 comments on commit 5b29cd2

Please sign in to comment.