From e2e5b2203dec4fd6ac9c8c96185e92f247cc6ec7 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 23 May 2024 07:09:23 -0700 Subject: [PATCH] [xla:ffi] Add XLA_FFI_ExecutionStage enum to call frame and add to to error logs Add execution stage to error logs to be able to distinguish errors coming from different execution stages. PiperOrigin-RevId: 636539611 --- third_party/xla/xla/ffi/api/api.h | 29 ++++++++---- third_party/xla/xla/ffi/api/c_api.h | 44 +++++++++++-------- third_party/xla/xla/ffi/call_frame.cc | 4 +- third_party/xla/xla/ffi/call_frame.h | 5 ++- third_party/xla/xla/ffi/ffi_api.cc | 10 +++-- third_party/xla/xla/ffi/ffi_api.h | 13 +++--- .../service/gpu/runtime/custom_call_thunk.cc | 15 ++++--- .../service/gpu/runtime/custom_call_thunk.h | 1 + 8 files changed, 77 insertions(+), 44 deletions(-) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 71da434822386b..df574e4643c027 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -1278,20 +1278,33 @@ class Handler : public Ffi { XLA_FFI_Error* FailedDecodeError(const XLA_FFI_CallFrame* call_frame, std::array decoded, const DiagnosticEngine& diagnostic) const { - std::string message = - "Failed to decode all FFI handler operands (bad operands at: "; + auto stage = [&] { + switch (call_frame->stage) { + case XLA_FFI_ExecutionStage_PREPARE: + return "prepare"; + case XLA_FFI_ExecutionStage_INITIALIZE: + return "initialize"; + case XLA_FFI_ExecutionStage_EXECUTE: + return "execute"; + } + }; + + std::string str; + std::stringstream message(str); + + message << "[" << stage() << "] " + << "Failed to decode all FFI handler operands (bad operands at: "; for (size_t cnt = 0, idx = 0; idx < kSize; ++idx) { if (!decoded[idx]) { - if (cnt++) message.append(", "); - message.append(std::to_string(idx)); + if (cnt++) message << ", "; + message << std::to_string(idx); } } - message.append(")"); + message << ")"; if (auto s = std::move(diagnostic).Result(); !s.empty()) { - message.append("\nDiagnostics:\n"); - message.append(s); + message << "\nDiagnostics:\n" << s; } - return InvalidArgument(call_frame->api, message); + return InvalidArgument(call_frame->api, message.str()); } template diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index 3670e7a9f8a276..bf9c2d80676caf 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -261,6 +261,30 @@ struct XLA_FFI_Array { // Call frame //===----------------------------------------------------------------------===// +// XLA runtime has multiple execution stages and it is possible to run +// different handlers for each stage: +// +// (1) Prepare - called before the execution to let FFI handlers to prepare +// for the execution and request resources from runtime, i.e. in XLA:GPU +// we use prepare stage to request collective cliques. +// +// (2) Initialize - called before the execution after acquiring all the +// resources requested in the prepare stage. +// +// (3) Execute - called when FFI handler is executed. Note that FFI handler +// can be called as a part of command buffer capture (CUDA graph capture +// on GPU backend) and argument buffers might contain uninitialized +// values in this case. +// +// It is undefined behavior to access argument buffers in prepare and +// initialize stages as they might not be initialized yet. However it is safe +// to use memory address as it is assigned ahead of time by buffer assignment. +typedef enum { + XLA_FFI_ExecutionStage_PREPARE = 0, + XLA_FFI_ExecutionStage_INITIALIZE = 1, + XLA_FFI_ExecutionStage_EXECUTE = 2, +} XLA_FFI_ExecutionStage; + struct XLA_FFI_Args { size_t struct_size; void* priv; @@ -303,6 +327,7 @@ struct XLA_FFI_CallFrame { const XLA_FFI_Api* api; XLA_FFI_ExecutionContext* ctx; + XLA_FFI_ExecutionStage stage; XLA_FFI_Args args; XLA_FFI_Rets rets; XLA_FFI_Attrs attrs; @@ -317,24 +342,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_CallFrame, attrs); // External functions registered with XLA as FFI handlers. typedef XLA_FFI_Error* XLA_FFI_Handler(XLA_FFI_CallFrame* call_frame); -// XLA runtime has multiple execution stages and it is possible to run -// different handlers for each stage: -// -// (1) Prepare - called before the execution to let FFI handlers to prepare -// for the execution and request resources from runtime, i.e. in XLA:GPU -// we use prepare stage to request collective cliques. -// -// (2) Initialize - called before the execution after acquiring all the -// resources requested in the prepare stage. -// -// (3) Execute - called when FFI handler is executed. Note that FFI handler -// can be called as a part of command buffer capture (CUDA graph capture -// on GPU backend) and argument buffers might contain uninitialized -// values in this case. -// -// It is undefined behavior to access argument buffers in prepare and -// initialize stages as they might not be initialized yet. However it is safe -// to use memory address as it is assigned ahead of time by buffer assignment. +// XLA FFI handlers for execution stages (see XLA_FFI_ExecutionStage). struct XLA_FFI_Handler_Bundle { XLA_FFI_Handler* prepare; // optional XLA_FFI_Handler* initialize; // optional diff --git a/third_party/xla/xla/ffi/call_frame.cc b/third_party/xla/xla/ffi/call_frame.cc index c23508782bc858..bd991f4e584700 100644 --- a/third_party/xla/xla/ffi/call_frame.cc +++ b/third_party/xla/xla/ffi/call_frame.cc @@ -220,10 +220,12 @@ CallFrame::CallFrame(absl::Span args, attributes_(InitAttrs(attrs)) {} XLA_FFI_CallFrame CallFrame::Build(const XLA_FFI_Api* api, - XLA_FFI_ExecutionContext* ctx) { + XLA_FFI_ExecutionContext* ctx, + XLA_FFI_ExecutionStage stage) { XLA_FFI_CallFrame call_frame = {XLA_FFI_CallFrame_STRUCT_SIZE, nullptr}; call_frame.api = api; call_frame.ctx = ctx; + call_frame.stage = stage; call_frame.args = arguments_->ffi_args; call_frame.rets = results_->ffi_rets; call_frame.attrs = attributes_->ffi_attrs; diff --git a/third_party/xla/xla/ffi/call_frame.h b/third_party/xla/xla/ffi/call_frame.h index 9bfd31cd4326ae..7270c63f2fe487 100644 --- a/third_party/xla/xla/ffi/call_frame.h +++ b/third_party/xla/xla/ffi/call_frame.h @@ -133,8 +133,9 @@ class CallFrame { ~CallFrame(); // Builds an XLA_FFI_CallFrame from owned arguments and attributes. - XLA_FFI_CallFrame Build(const XLA_FFI_Api* api, - XLA_FFI_ExecutionContext* ctx); + XLA_FFI_CallFrame Build( + const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, + XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE); private: friend class CallFrameBuilder; diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index 39eefa3077e426..b2ba4a8ef0a0dd 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -81,16 +81,18 @@ absl::Status TakeStatus(XLA_FFI_Error* error) { } absl::Status Call(Ffi& handler, CallFrame& call_frame, - const CallOptions& options) { + const CallOptions& options, XLA_FFI_ExecutionStage stage) { XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options); - XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx); + XLA_FFI_CallFrame ffi_call_frame = + call_frame.Build(GetXlaFfiApi(), &ctx, stage); return TakeStatus(handler.Call(&ffi_call_frame)); } absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, - const CallOptions& options) { + const CallOptions& options, XLA_FFI_ExecutionStage stage) { XLA_FFI_ExecutionContext ctx = CreateExecutionContext(options); - XLA_FFI_CallFrame ffi_call_frame = call_frame.Build(GetXlaFfiApi(), &ctx); + XLA_FFI_CallFrame ffi_call_frame = + call_frame.Build(GetXlaFfiApi(), &ctx, stage); return TakeStatus((*handler)(&ffi_call_frame)); } diff --git a/third_party/xla/xla/ffi/ffi_api.h b/third_party/xla/xla/ffi/ffi_api.h index f31f6a65095583..8c03428ec2471e 100644 --- a/third_party/xla/xla/ffi/ffi_api.h +++ b/third_party/xla/xla/ffi/ffi_api.h @@ -59,11 +59,14 @@ struct CallOptions { // `error` if it's not nullptr; returns OK status otherwise. absl::Status TakeStatus(XLA_FFI_Error* error); -absl::Status Call(Ffi& handler, CallFrame& call_frame, - const CallOptions& options = {}); - -absl::Status Call(XLA_FFI_Handler* handler, CallFrame& call_frame, - const CallOptions& options = {}); +absl::Status Call( + Ffi& handler, CallFrame& call_frame, const CallOptions& options = {}, + XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE); + +absl::Status Call( + XLA_FFI_Handler* handler, CallFrame& call_frame, + const CallOptions& options = {}, + XLA_FFI_ExecutionStage stage = XLA_FFI_ExecutionStage_EXECUTE); namespace internal { // This is an internal workaround to override FFI execution context for FFI diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc index 454fe759f791c0..d7922beed05427 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc @@ -112,7 +112,8 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) { } absl::Status CustomCallThunk::ExecuteFfiHandler( - XLA_FFI_Handler* handler, int32_t device_ordinal, se::Stream* stream, + XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage, + int32_t device_ordinal, se::Stream* stream, se::DeviceMemoryAllocator* allocator, const ffi::ExecutionContext* execution_context, const BufferAllocations* buffer_allocations) { @@ -155,7 +156,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( CallOptions options = {device_ordinal, stream, allocator, called_computation_, execution_context}; - return Call(bundle_->execute, call_frame, options); + return Call(bundle_->execute, call_frame, options, stage); } absl::Status CustomCallThunk::Prepare(const PrepareParams& params, @@ -172,16 +173,18 @@ absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { } return ExecuteFfiHandler( - bundle_->initialize, params.buffer_allocations->device_ordinal(), - params.stream, params.buffer_allocations->memory_allocator(), + bundle_->initialize, XLA_FFI_ExecutionStage_INITIALIZE, + params.buffer_allocations->device_ordinal(), params.stream, + params.buffer_allocations->memory_allocator(), params.ffi_execution_context, params.buffer_allocations); } absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { if (bundle_.has_value()) { return ExecuteFfiHandler( - bundle_->execute, params.buffer_allocations->device_ordinal(), - params.stream, params.buffer_allocations->memory_allocator(), + bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE, + params.buffer_allocations->device_ordinal(), params.stream, + params.buffer_allocations->memory_allocator(), params.ffi_execution_context, params.buffer_allocations); } return ExecuteCustomCall(params); diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h index 14ad33eff3ff02..3a88202b298168 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h @@ -104,6 +104,7 @@ class CustomCallThunk : public Thunk { absl::Status ExecuteCustomCall(const ExecuteParams& params); absl::Status ExecuteFfiHandler(XLA_FFI_Handler* handler, + XLA_FFI_ExecutionStage stage, int32_t device_ordinal, se::Stream* stream, se::DeviceMemoryAllocator* allocator, const ffi::ExecutionContext* execution_context,