Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xla:ffi] Add XLA_FFI_ExecutionStage enum to call frame and add to to error logs #68541

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions third_party/xla/xla/ffi/api/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1278,20 +1278,31 @@ class Handler : public Ffi {
XLA_FFI_Error* FailedDecodeError(const XLA_FFI_CallFrame* call_frame,
std::array<bool, kSize> decoded,
const DiagnosticEngine& diagnostic) const {
std::string message =
"Failed to decode all FFI handler operands (bad operands at: ";
auto stage = [&] {
switch (call_frame->stage) {
case XLA_FFI_ExecutionStage_PREPARE:
return "prepare";
case XLA_FFI_ExecutionStage_INITIALIZE:
return "initialize";
case XLA_FFI_ExecutionStage_EXECUTE:
return "execute";
}
};

std::stringstream message;
message << "[" << stage() << "] "
<< "Failed to decode all FFI handler operands (bad operands at: ";
for (size_t cnt = 0, idx = 0; idx < kSize; ++idx) {
if (!decoded[idx]) {
if (cnt++) message.append(", ");
message.append(std::to_string(idx));
if (cnt++) message << ", ";
message << std::to_string(idx);
}
}
message.append(")");
message << ")";
if (auto s = std::move(diagnostic).Result(); !s.empty()) {
message.append("\nDiagnostics:\n");
message.append(s);
message << "\nDiagnostics:\n" << s;
}
return InvalidArgument(call_frame->api, message);
return InvalidArgument(call_frame->api, message.str());
}

template <typename...>
Expand Down
44 changes: 26 additions & 18 deletions third_party/xla/xla/ffi/api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,30 @@ struct XLA_FFI_Array {
// Call frame
//===----------------------------------------------------------------------===//

// XLA runtime has multiple execution stages and it is possible to run
// different handlers for each stage:
//
// (1) Prepare - called before the execution to let FFI handlers to prepare
// for the execution and request resources from runtime, i.e. in XLA:GPU
// we use prepare stage to request collective cliques.
//
// (2) Initialize - called before the execution after acquiring all the
// resources requested in the prepare stage.
//
// (3) Execute - called when FFI handler is executed. Note that FFI handler
// can be called as a part of command buffer capture (CUDA graph capture
// on GPU backend) and argument buffers might contain uninitialized
// values in this case.
//
// It is undefined behavior to access argument buffers in prepare and
// initialize stages as they might not be initialized yet. However it is safe
// to use memory address as it is assigned ahead of time by buffer assignment.
typedef enum {
XLA_FFI_ExecutionStage_PREPARE = 0,
XLA_FFI_ExecutionStage_INITIALIZE = 1,
XLA_FFI_ExecutionStage_EXECUTE = 2,
} XLA_FFI_ExecutionStage;

struct XLA_FFI_Args {
size_t struct_size;
void* priv;
Expand Down Expand Up @@ -303,6 +327,7 @@ struct XLA_FFI_CallFrame {

const XLA_FFI_Api* api;
XLA_FFI_ExecutionContext* ctx;
XLA_FFI_ExecutionStage stage;
XLA_FFI_Args args;
XLA_FFI_Rets rets;
XLA_FFI_Attrs attrs;
Expand All @@ -317,24 +342,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_CallFrame, attrs);
// External functions registered with XLA as FFI handlers.
typedef XLA_FFI_Error* XLA_FFI_Handler(XLA_FFI_CallFrame* call_frame);

// XLA runtime has multiple execution stages and it is possible to run
// different handlers for each stage:
//
// (1) Prepare - called before the execution to let FFI handlers to prepare
// for the execution and request resources from runtime, i.e. in XLA:GPU
// we use prepare stage to request collective cliques.
//
// (2) Initialize - called before the execution after acquiring all the
// resources requested in the prepare stage.
//
// (3) Execute - called when FFI handler is executed. Note that FFI handler
// can be called as a part of command buffer capture (CUDA graph capture
// on GPU backend) and argument buffers might contain uninitialized
// values in this case.
//
// It is undefined behavior to access argument buffers in prepare and
// initialize stages as they might not be initialized yet. However it is safe
// to use memory address as it is assigned ahead of time by buffer assignment.
// XLA FFI handlers for execution stages (see XLA_FFI_ExecutionStage).
struct XLA_FFI_Handler_Bundle {
XLA_FFI_Handler* prepare; // optional
XLA_FFI_Handler* initialize; // optional
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/ffi/call_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,12 @@ CallFrame::CallFrame(absl::Span<const CallFrameBuilder::Buffer> args,
attributes_(InitAttrs(attrs)) {}

XLA_FFI_CallFrame CallFrame::Build(const XLA_FFI_Api* api,
XLA_FFI_ExecutionContext* ctx) {
XLA_FFI_ExecutionContext* ctx,
XLA_FFI_ExecutionStage stage) {
XLA_FFI_CallFrame call_frame = {XLA_FFI_CallFrame_STRUCT_SIZE, nullptr};
call_frame.api = api;
call_frame.ctx = ctx;
call_frame.stage = stage;
call_frame.args = arguments_->ffi_args;
call_frame.rets = results_->ffi_rets;
call_frame.attrs = attributes_->ffi_attrs;
Expand Down
5 changes: 3 additions & 2 deletions third_party/xla/xla/ffi/call_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,9 @@ class CallFrame {
~CallFrame();

// Builds an XLA_FFI_CallFrame from owned arguments and attributes.
XLA_FFI_CallFrame Build(const XLA_FFI_Api* api,
XLA_FFI_ExecutionContext* ctx);
XLA_FFI_CallFrame Build(
const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE);

private:
friend class CallFrameBuilder;
Expand Down
10 changes: 6 additions & 4 deletions third_party/xla/xla/ffi/ffi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,18 @@ absl::Status TakeStatus(XLA_FFI_Error* error) {
}

absl::Status Call(Ffi& handler, CallFrame& call_frame,
const CallOptions& options) {
const CallOptions& options, XLA_FFI_ExecutionStage stage) {
XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options);
XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx);
XLA_FFI_CallFrame ffi_call_frame =
call_frame.Build(GetXlaFfiApi(), &ctx, stage);
return TakeStatus(handler.Call(&ffi_call_frame));
}

absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame,
const CallOptions& options) {
const CallOptions& options, XLA_FFI_ExecutionStage stage) {
XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options);
XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx);
XLA_FFI_CallFrame ffi_call_frame =
call_frame.Build(GetXlaFfiApi(), &ctx, stage);
return TakeStatus((*handler)(&ffi_call_frame));
}

Expand Down
13 changes: 8 additions & 5 deletions third_party/xla/xla/ffi/ffi_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,14 @@ struct CallOptions {
// `error` if it's not nullptr; returns OK status otherwise.
absl::Status TakeStatus(XLA_FFI_Error* error);

absl::Status Call(Ffi& handler, CallFrame& call_frame,
const CallOptions& options = {});

absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame,
const CallOptions& options = {});
absl::Status Call(
Ffi& handler, CallFrame& call_frame, const CallOptions& options = {},
XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE);

absl::Status Call(
XLA_FFI_Handler* handler, CallFrame& call_frame,
const CallOptions& options = {},
XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE);

namespace internal {
// This is an internal workaround to override FFI execution context for FFI
Expand Down
15 changes: 9 additions & 6 deletions third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
}

absl::Status CustomCallThunk::ExecuteFfiHandler(
XLA_FFI_Handler* handler, int32_t device_ordinal, se::Stream* stream,
XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage,
int32_t device_ordinal, se::Stream* stream,
se::DeviceMemoryAllocator* allocator,
const ffi::ExecutionContext* execution_context,
const BufferAllocations* buffer_allocations) {
Expand Down Expand Up @@ -155,7 +156,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler(

CallOptions options = {device_ordinal, stream, allocator, called_computation_,
execution_context};
return Call(bundle_->execute, call_frame, options);
return Call(bundle_->execute, call_frame, options, stage);
}

absl::Status CustomCallThunk::Prepare(const PrepareParams& params,
Expand All @@ -172,16 +173,18 @@ absl::Status CustomCallThunk::Initialize(const InitializeParams& params) {
}

return ExecuteFfiHandler(
bundle_->initialize, params.buffer_allocations->device_ordinal(),
params.stream, params.buffer_allocations->memory_allocator(),
bundle_->initialize, XLA_FFI_ExecutionStage_INITIALIZE,
params.buffer_allocations->device_ordinal(), params.stream,
params.buffer_allocations->memory_allocator(),
params.ffi_execution_context, params.buffer_allocations);
}

absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
if (bundle_.has_value()) {
return ExecuteFfiHandler(
bundle_->execute, params.buffer_allocations->device_ordinal(),
params.stream, params.buffer_allocations->memory_allocator(),
bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE,
params.buffer_allocations->device_ordinal(), params.stream,
params.buffer_allocations->memory_allocator(),
params.ffi_execution_context, params.buffer_allocations);
}
return ExecuteCustomCall(params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class CustomCallThunk : public Thunk {
absl::Status ExecuteCustomCall(const ExecuteParams& params);

absl::Status ExecuteFfiHandler(XLA_FFI_Handler* handler,
XLA_FFI_ExecutionStage stage,
int32_t device_ordinal, se::Stream* stream,
se::DeviceMemoryAllocator* allocator,
const ffi::ExecutionContext* execution_context,
Expand Down
Loading