Skip to content

Commit

Permalink
[xla:ffi] Add XLA_FFI_ExecutionStage enum to call frame and add to to…
Browse files Browse the repository at this point in the history
… error logs

Add execution stage to error logs to be able to distinguish errors coming from different execution stages.

PiperOrigin-RevId: 636539611
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed May 24, 2024
1 parent b4afcb3 commit e2e5b22
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 44 deletions.
29 changes: 21 additions & 8 deletions third_party/xla/xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1278,20 +1278,33 @@ class Handler : public Ffi {
XLA_FFI_Error* FailedDecodeError(const XLA_FFI_CallFrame* call_frame,
std::array<bool, kSize> decoded,
const DiagnosticEngine& diagnostic) const {
std::string message =
"Failed to decode all FFI handler operands (bad operands at: ";
auto stage = [&] {
switch (call_frame->stage) {
case XLA_FFI_ExecutionStage_PREPARE:
return "prepare";
case XLA_FFI_ExecutionStage_INITIALIZE:
return "initialize";
case XLA_FFI_ExecutionStage_EXECUTE:
return "execute";
}
};

std::string str;
std::stringstream message(str);

message << "[" << stage() << "] "
<< "Failed to decode all FFI handler operands (bad operands at: ";
for (size_t cnt = 0, idx = 0; idx < kSize; ++idx) {
if (!decoded[idx]) {
if (cnt++) message.append(", ");
message.append(std::to_string(idx));
if (cnt++) message << ", ";
message << std::to_string(idx);
}
}
message.append(")");
message << ")";
if (auto s = std::move(diagnostic).Result(); !s.empty()) {
message.append("\nDiagnostics:\n");
message.append(s);
message << "\nDiagnostics:\n" << s;
}
return InvalidArgument(call_frame->api, message);
return InvalidArgument(call_frame->api, message.str());
}

template <typename...>
Expand Down
44 changes: 26 additions & 18 deletions third_party/xla/xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,30 @@ struct XLA_FFI_Array {
// Call frame
//===----------------------------------------------------------------------===//

// XLA runtime has multiple execution stages and it is possible to run
// different handlers for each stage:
//
// (1) Prepare - called before the execution to let FFI handlers to prepare
// for the execution and request resources from runtime, i.e. in XLA:GPU
// we use prepare stage to request collective cliques.
//
// (2) Initialize - called before the execution after acquiring all the
// resources requested in the prepare stage.
//
// (3) Execute - called when FFI handler is executed. Note that FFI handler
// can be called as a part of command buffer capture (CUDA graph capture
// on GPU backend) and argument buffers might contain uninitialized
// values in this case.
//
// It is undefined behavior to access argument buffers in prepare and
// initialize stages as they might not be initialized yet. However it is safe
// to use memory address as it is assigned ahead of time by buffer assignment.
typedef enum {
XLA_FFI_ExecutionStage_PREPARE = 0,
XLA_FFI_ExecutionStage_INITIALIZE = 1,
XLA_FFI_ExecutionStage_EXECUTE = 2,
} XLA_FFI_ExecutionStage;

struct XLA_FFI_Args {
size_t struct_size;
void* priv;
Expand Down Expand Up @@ -303,6 +327,7 @@ struct XLA_FFI_CallFrame {

const XLA_FFI_Api* api;
XLA_FFI_ExecutionContext* ctx;
XLA_FFI_ExecutionStage stage;
XLA_FFI_Args args;
XLA_FFI_Rets rets;
XLA_FFI_Attrs attrs;
Expand All @@ -317,24 +342,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_CallFrame, attrs);
// External functions registered with XLA as FFI handlers.
typedef XLA_FFI_Error* XLA_FFI_Handler(XLA_FFI_CallFrame* call_frame);

// XLA runtime has multiple execution stages and it is possible to run
// different handlers for each stage:
//
// (1) Prepare - called before the execution to let FFI handlers to prepare
// for the execution and request resources from runtime, i.e. in XLA:GPU
// we use prepare stage to request collective cliques.
//
// (2) Initialize - called before the execution after acquiring all the
// resources requested in the prepare stage.
//
// (3) Execute - called when FFI handler is executed. Note that FFI handler
// can be called as a part of command buffer capture (CUDA graph capture
// on GPU backend) and argument buffers might contain uninitialized
// values in this case.
//
// It is undefined behavior to access argument buffers in prepare and
// initialize stages as they might not be initialized yet. However it is safe
// to use memory address as it is assigned ahead of time by buffer assignment.
// XLA FFI handlers for execution stages (see XLA_FFI_ExecutionStage).
struct XLA_FFI_Handler_Bundle {
XLA_FFI_Handler* prepare; // optional
XLA_FFI_Handler* initialize; // optional
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,12 @@ CallFrame::CallFrame(absl::Span<const CallFrameBuilder::Buffer> args,
attributes_(InitAttrs(attrs)) {}

XLA_FFI_CallFrame CallFrame::Build(const XLA_FFI_Api* api,
XLA_FFI_ExecutionContext* ctx) {
XLA_FFI_ExecutionContext* ctx,
XLA_FFI_ExecutionStage stage) {
XLA_FFI_CallFrame call_frame = {XLA_FFI_CallFrame_STRUCT_SIZE, nullptr};
call_frame.api = api;
call_frame.ctx = ctx;
call_frame.stage = stage;
call_frame.args = arguments_->ffi_args;
call_frame.rets = results_->ffi_rets;
call_frame.attrs = attributes_->ffi_attrs;
Expand Down
5 changes: 3 additions & 2 deletions third_party/xla/xla/ffi/call_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ class CallFrame {
~CallFrame();

// Builds an XLA_FFI_CallFrame from owned arguments and attributes.
XLA_FFI_CallFrame Build(const XLA_FFI_Api* api,
XLA_FFI_ExecutionContext* ctx);
XLA_FFI_CallFrame Build(
const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE);

private:
friend class CallFrameBuilder;
Expand Down
10 changes: 6 additions & 4 deletions third_party/xla/xla/ffi/ffi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,18 @@ absl::Status TakeStatus(XLA_FFI_Error* error) {
}

absl::Status Call(Ffi& handler, CallFrame& call_frame,
const CallOptions& options) {
const CallOptions& options, XLA_FFI_ExecutionStage stage) {
XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options);
XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx);
XLA_FFI_CallFrame ffi_call_frame =
call_frame.Build(GetXlaFfiApi(), &ctx, stage);
return TakeStatus(handler.Call(&ffi_call_frame));
}

absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame,
const CallOptions& options) {
const CallOptions& options, XLA_FFI_ExecutionStage stage) {
XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options);
XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx);
XLA_FFI_CallFrame ffi_call_frame =
call_frame.Build(GetXlaFfiApi(), &ctx, stage);
return TakeStatus((*handler)(&ffi_call_frame));
}

Expand Down
13 changes: 8 additions & 5 deletions third_party/xla/xla/ffi/ffi_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,14 @@ struct CallOptions {
// `error` if it's not nullptr; returns OK status otherwise.
absl::Status TakeStatus(XLA_FFI_Error* error);

absl::Status Call(Ffi& handler, CallFrame& call_frame,
const CallOptions& options = {});

absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame,
const CallOptions& options = {});
absl::Status Call(
Ffi& handler, CallFrame& call_frame, const CallOptions& options = {},
XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE);

absl::Status Call(
XLA_FFI_Handler* handler, CallFrame& call_frame,
const CallOptions& options = {},
XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE);

namespace internal {
// This is an internal workaround to override FFI execution context for FFI
Expand Down
15 changes: 9 additions & 6 deletions third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
}

absl::Status CustomCallThunk::ExecuteFfiHandler(
XLA_FFI_Handler* handler, int32_t device_ordinal, se::Stream* stream,
XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage,
int32_t device_ordinal, se::Stream* stream,
se::DeviceMemoryAllocator* allocator,
const ffi::ExecutionContext* execution_context,
const BufferAllocations* buffer_allocations) {
Expand Down Expand Up @@ -155,7 +156,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler(

CallOptions options = {device_ordinal, stream, allocator, called_computation_,
execution_context};
return Call(bundle_->execute, call_frame, options);
return Call(bundle_->execute, call_frame, options, stage);
}

absl::Status CustomCallThunk::Prepare(const PrepareParams& params,
Expand All @@ -172,16 +173,18 @@ absl::Status CustomCallThunk::Initialize(const InitializeParams& params) {
}

return ExecuteFfiHandler(
bundle_->initialize, params.buffer_allocations->device_ordinal(),
params.stream, params.buffer_allocations->memory_allocator(),
bundle_->initialize, XLA_FFI_ExecutionStage_INITIALIZE,
params.buffer_allocations->device_ordinal(), params.stream,
params.buffer_allocations->memory_allocator(),
params.ffi_execution_context, params.buffer_allocations);
}

absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
if (bundle_.has_value()) {
return ExecuteFfiHandler(
bundle_->execute, params.buffer_allocations->device_ordinal(),
params.stream, params.buffer_allocations->memory_allocator(),
bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE,
params.buffer_allocations->device_ordinal(), params.stream,
params.buffer_allocations->memory_allocator(),
params.ffi_execution_context, params.buffer_allocations);
}
return ExecuteCustomCall(params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class CustomCallThunk : public Thunk {
absl::Status ExecuteCustomCall(const ExecuteParams& params);

absl::Status ExecuteFfiHandler(XLA_FFI_Handler* handler,
XLA_FFI_ExecutionStage stage,
int32_t device_ordinal, se::Stream* stream,
se::DeviceMemoryAllocator* allocator,
const ffi::ExecutionContext* execution_context,
Expand Down

0 comments on commit e2e5b22

Please sign in to comment.