Remove XLA:CPU runtime support from tf2xla
PiperOrigin-RevId: 629895538
ezhulenev authored and tensorflower-gardener committed May 2, 2024
1 parent 3f6550f commit 872be62
Showing 6 changed files with 4 additions and 175 deletions.
2 changes: 0 additions & 2 deletions tensorflow/compiler/aot/codegen.cc
@@ -668,7 +668,6 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
data, StaticHloProfilePrinterData());
set_static_data_use_xla_runtime(data, {{USE_XLA_RUNTIME}});
{{ASSIGN_PROFILE_COUNTERS_SIZE}}
return data;
}();
@@ -822,7 +821,6 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{DECLS_FROM_OBJ_FILE}}",
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point},
{"{{USE_XLA_RUNTIME}}", opts.use_xla_runtime ? "true" : "false"},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim},
{"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
1 change: 0 additions & 1 deletion tensorflow/compiler/aot/codegen_test_h.golden
@@ -97,7 +97,6 @@ class MyClass final : public tensorflow::XlaCompiledCpuFunction {
set_static_data_program_shape(data, StaticProgramShape());
set_static_data_hlo_profile_printer_data(
data, StaticHloProfilePrinterData());
set_static_data_use_xla_runtime(data, false);

return data;
}();
2 changes: 0 additions & 2 deletions tensorflow/compiler/tf2xla/BUILD
@@ -218,7 +218,6 @@ filegroup(
"@local_tsl//tsl/framework/fixedpoint:xla_cpu_runtime_hdrs",
"@local_tsl//tsl/platform:xla_cpu_runtime_srcs",
"@local_xla//xla:cpu_runtime_hdrs",
"@local_xla//xla/runtime:aot_ffi_execution_context_hdrs",
"@local_xla//xla/service:custom_call_status_hdrs",
"@local_xla//xla/service/cpu:runtime_hdrs",
],
@@ -391,7 +390,6 @@ cc_library(
# binary produced by tfcompile.
"@local_xla//xla:cpu_function_runtime",
"@local_xla//xla:executable_run_options",
"@local_xla//xla/runtime:aot_ffi_execution_context",
"@local_xla//xla/service/cpu:buffer_desc",
"//tensorflow/core/platform:types",
],
116 changes: 1 addition & 115 deletions tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -20,44 +20,12 @@ limitations under the License.
#include <vector>

#include "xla/cpu_function_runtime.h"
#include "xla/runtime/aot_ffi_execution_context.h"

namespace tensorflow {

namespace {
// MemrefDesc's are part of the XLA Runtime ABI. Redefine them here (with a
// slightly different name to avoid confusion) because we cannot depend on
// XLA Runtime's headers.
// Note: this is an internal type, to be used exclusively in this file.
struct MemrefHolder {
MemrefHolder(const XlaCompiledCpuFunction::ShapeInfo& shape_info,
void* data_ptr)
: rank(shape_info.num_dimensions), data(data_ptr), offset(0) {
sizes.resize(shape_info.num_dimensions);
strides.resize(shape_info.num_dimensions);
int64_t multiplier = 1;
for (int i = shape_info.num_dimensions - 1; i >= 0; --i) {
int64_t size = shape_info.dimensions[i];
sizes[i] = size;
strides[i] = multiplier;
multiplier *= size;
}
}

unsigned rank = 0;
// Note: dtype is not needed here.
void* data = nullptr;
int64_t offset = 0;
std::vector<int64_t> sizes;
std::vector<int64_t> strides;
};
} // namespace

XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
AllocMode alloc_mode)
: raw_function_(static_data.raw_function_),
external_run_function_(static_data.external_run_function_),
cpu_executable_(static_data.cpu_executable_),
result_index_(static_data.result_index_),
buffer_table_(new void*[static_data.num_buffers_]),
buffer_infos_(static_data.buffer_infos_),
@@ -73,8 +41,7 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
variable_names_(static_data.variable_names_),
result_names_(static_data.result_names_),
program_shape_(static_data.program_shape_),
hlo_profile_printer_data_(static_data.hlo_profile_printer_data_),
use_xla_runtime_(static_data.use_xla_runtime_) {
hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
bool allocate_entry_params =
alloc_mode == AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS;
// Allocate arg and temp buffers.
@@ -92,94 +59,13 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
}
}

bool XlaCompiledCpuFunction::RunXlaRuntime() {
size_t num_memref_args = num_args_ + num_results_;
std::vector<MemrefHolder> memref_args;
memref_args.reserve(num_memref_args);

size_t num_ptrs = 1; // execution context.

// Append arguments.
for (int i = 0; i < num_args_; ++i) {
const ShapeInfo& shape_info = arg_shape_infos_[i];
memref_args.emplace_back(shape_info, buffer_table_[arg_index_table_[i]]);
num_ptrs += 3 + 2 * shape_info.num_dimensions;
}

// Append results.
for (int i = 0; i < num_results_; ++i) {
const ShapeInfo& shape_info = result_shape_infos_[i];
memref_args.emplace_back(shape_info, buffer_table_[result_index_table_[i]]);
num_ptrs += 3 + 2 * shape_info.num_dimensions;

// Point to this result from the "result" entry in the buffer table.
void** results = static_cast<void**>(buffer_table_[result_index_]);
results[i] = buffer_table_[result_index_table_[i]];
}

std::vector<void*> call_frame;
call_frame.resize(num_ptrs);
size_t ptr_index = 1;
for (const MemrefHolder& memref : memref_args) {
auto cast = [](const void* p) { return const_cast<void*>(p); };
call_frame[ptr_index + 0] = cast(&memref.data); // memref.basePtr
call_frame[ptr_index + 1] = cast(&memref.data); // memref.data
call_frame[ptr_index + 2] = cast(&memref.offset);
unsigned rank = memref.rank;
for (int64_t d = 0; d < rank; ++d) {
call_frame[ptr_index + 3 + d] = cast(&memref.sizes[d]);
call_frame[ptr_index + 3 + d + rank] = cast(&memref.strides[d]);
}
ptr_index += 3 + 2 * rank;
}

assert(num_ptrs == ptr_index);

xla::runtime::aot::ExecutionContext execution_context;
execution_context.custom_call_data = &run_options_;
xla::runtime::aot::ExecutionContext* execution_context_ptr =
&execution_context;
call_frame[0] = &execution_context_ptr;

auto xla_runtime_func =
reinterpret_cast<XlaRuntimeRawFunction>(raw_function_);
xla_runtime_func(call_frame.data());
if (execution_context.error) {
// No error support in XLA; dump error message to stderr.
std::cerr << "XLA AOT error: " << execution_context.error << ".\n";
return false;
}
return true;
}

bool XlaCompiledCpuFunction::Run() {
if (use_xla_runtime_) {
return RunXlaRuntime();
}
if (external_run_function_) {
std::vector<xla::cpu::BufferDesc> descriptor_table =
MakeXlaRuntimeDescriptorTable();
return external_run_function_(cpu_executable_, descriptor_table,
&run_options_);
}
XlaCustomCallStatus status;
raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
buffer_table_, &status, profile_counters_);
return !xla::CustomCallStatusGetMessage(&status).has_value();
}

std::vector<xla::cpu::BufferDesc>
XlaCompiledCpuFunction::MakeXlaRuntimeDescriptorTable() {
std::vector<xla::cpu::BufferDesc> descriptor_table;
descriptor_table.reserve(num_buffers_);
for (int32_t i = 0; i < num_buffers_; ++i) {
void* data = buffer_table_[i];
uint64_t size = buffer_infos_[i].size();
descriptor_table.emplace_back(data, size);
}
return descriptor_table;
}

XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
xla::cpu_function_runtime::FreeContiguous(alloc_buffer_table_);
delete[] buffer_table_;
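For context, a minimal sketch of the simplified execution path after this change, assembled from the context lines that remain in xla_compiled_cpu_function.cc above (it is not itself part of the diff): with the RunXlaRuntime and external-run branches gone, Run() falls through to the single legacy raw-function call.

bool XlaCompiledCpuFunction::Run() {
  // Invoke the compiled entry point over the preallocated buffer table.
  XlaCustomCallStatus status;
  raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
                buffer_table_, &status, profile_counters_);
  // A custom-call error message means the run failed.
  return !xla::CustomCallStatusGetMessage(&status).has_value();
}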
43 changes: 3 additions & 40 deletions tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -61,15 +61,6 @@ class XlaCompiledCpuFunction {
const void** args, void** temps,
XlaCustomCallStatus*, int64_t* profile_counters);

// Signature of the XLA Runtime raw function. Used only by XLA Runtime AOT.
using XlaRuntimeRawFunction = void (*)(void**);

// Signature of an external run function. Used only by XLA Runtime JIT.
using ExternalRunFunction =
bool (*)(const xla::cpu::CpuExecutable* cpu_executable,
const std::vector<xla::cpu::BufferDesc>& descriptor_table,
const xla::ExecutableRunOptions* run_options);

// Simple struct to describe a tensor's shape.
// Note: this is a poor man's substitute for xla::ShapeProto, but we cannot
// depend on protobuf's in this library.
@@ -90,9 +81,6 @@ class XlaCompiledCpuFunction {
// The raw function to call.
RawFunction raw_function_;

ExternalRunFunction external_run_function_ = nullptr;
const xla::cpu::CpuExecutable* cpu_executable_ = nullptr;

// Contains information about the buffers used by the XLA computation.
const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr;
int32_t num_buffers_ = 0;
@@ -139,8 +127,6 @@ class XlaCompiledCpuFunction {
// declared so we don't have access to that information here.
int64_t profile_counters_size_ = 0;

bool use_xla_runtime_ = false;

// Only XlaCompiledCpuFunction is allowed to read and write the above
// fields.
friend class XlaCompiledCpuFunction;
@@ -333,16 +319,6 @@ class XlaCompiledCpuFunction {
static_data->raw_function_ = raw_function;
}

static void set_static_data_external_run_function(
StaticData* static_data, ExternalRunFunction external_run_function) {
static_data->external_run_function_ = external_run_function;
}

static void set_static_data_cpu_executable(
StaticData* static_data, const xla::cpu::CpuExecutable* cpu_executable) {
static_data->cpu_executable_ = cpu_executable;
}

static void set_static_data_buffer_infos(
StaticData* static_data,
const xla::cpu_function_runtime::BufferInfo* buffer_infos) {
@@ -430,19 +406,13 @@ class XlaCompiledCpuFunction {
static_data->profile_counters_size_ = profile_counters_size;
}

static void set_static_data_use_xla_runtime(StaticData* static_data,
bool use_xla_runtime) {
static_data->use_xla_runtime_ = use_xla_runtime;
}
// TODO(ezhulenev): This is a no-op after removing xla runtime, however it is
// still required for building some targets. Figure out why and delete!
static void set_static_data_use_xla_runtime(StaticData* static_data, bool) {}

private:
const RawFunction raw_function_;

// [Optional] External Run() function.
const ExternalRunFunction external_run_function_;
// [Maybe Optional] CpuExecutable to be passed to external_run_function_.
const xla::cpu::CpuExecutable* cpu_executable_;

const size_t result_index_;

// Array containing pointers to argument and temp buffers (slots corresponding
@@ -490,13 +460,6 @@ class XlaCompiledCpuFunction {
const xla::ProgramShapeProto* program_shape_ = nullptr;
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;

const bool use_xla_runtime_ = false;

// Creates a descriptor table for XLA Runtime.
std::vector<xla::cpu::BufferDesc> MakeXlaRuntimeDescriptorTable();

bool RunXlaRuntime();

// Add `XlaJitCompiledCpuFunction` as a friend so that it can access the
// `set_static_data_*` static methods above.
friend class XlaJitCompiledCpuFunction;
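A hedged illustration of why the no-op set_static_data_use_xla_runtime overload is kept (see the TODO above): tfcompile-generated subclasses that still emit the call, like the line removed from codegen_test_h.golden earlier in this commit, continue to compile unchanged. Roughly:

  // Inside a generated XlaCompiledCpuFunction subclass' static-data initializer
  // (hypothetical leftover target; newly generated code no longer emits this line):
  set_static_data_hlo_profile_printer_data(data, StaticHloProfilePrinterData());
  set_static_data_use_xla_runtime(data, false);  // no-op after this change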
15 changes: 0 additions & 15 deletions tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -92,15 +92,6 @@ void CollectNames(const T& entries, std::vector<string>* nonempty_names,
name_ptrs->push_back(nullptr); // array terminator
}

bool RunXlaRuntime(const xla::cpu::CpuExecutable* cpu_executable,
const std::vector<xla::cpu::BufferDesc>& descriptor_table,
const xla::ExecutableRunOptions* run_options) {
assert(cpu_executable->IsXlaRuntime());
Status status =
cpu_executable->ExecuteXlaRuntime(descriptor_table, run_options);
return status.ok();
}

} // namespace

/*static*/ absl::StatusOr<std::unique_ptr<XlaJitCompiledCpuFunction>>
@@ -171,12 +162,6 @@ XlaJitCompiledCpuFunction::Compile(
std::make_unique<xla::ProgramShapeProto>(program_shape->ToProto());
XlaCompiledCpuFunction::set_static_data_raw_function(&jit->static_data_,
raw_function);
if (cpu_executable->IsXlaRuntime()) {
XlaCompiledCpuFunction::set_static_data_external_run_function(
&jit->static_data_, RunXlaRuntime);
XlaCompiledCpuFunction::set_static_data_cpu_executable(&jit->static_data_,
cpu_executable);
}
XlaCompiledCpuFunction::set_static_data_buffer_infos(
&jit->static_data_, jit->buffer_infos_.data());
XlaCompiledCpuFunction::set_static_data_num_buffers(
