Skip to content

Commit

Permalink
Remove StreamExecutorInterface from all the non-static CommandBuffer …
Browse files Browse the repository at this point in the history
…methods.

Each CommandBuffer class is uniquely tied to a specific StreamExecutorInterface (parent_), which is tracked as member data.

PiperOrigin-RevId: 632593555
  • Loading branch information
klucke authored and tensorflower-gardener committed May 10, 2024
1 parent adf1068 commit 2ca68b7
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 166 deletions.
19 changes: 6 additions & 13 deletions third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ absl::Status CommandBufferCmdSequence::Record(
}
}

se::StreamExecutor* device = execute_params.stream->parent();
const ModuleAnnotations* annotations = GetCurrentModuleAnnotations();

// Track the number of commands recorded between barriers.
Expand All @@ -309,7 +308,7 @@ absl::Status CommandBufferCmdSequence::Record(
<< num_recorded_commands[execution_scope_id]
<< " recorded commands into the execution scope #"
<< execution_scope_id.value();
TF_RETURN_IF_ERROR(command_buffer->Barrier(device, execution_scope_id));
TF_RETURN_IF_ERROR(command_buffer->Barrier(execution_scope_id));
num_recorded_commands.erase(execution_scope_id);
}
VLOG(5) << " Record command buffer with scope id "
Expand Down Expand Up @@ -849,8 +848,7 @@ absl::Status IfCmd::Record(const Thunk::ExecuteParams& execute_params,
VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")";

return command_buffer->If(
execution_scope_id, execute_params.stream->parent(),
se::DeviceMemory<bool>(pred),
execution_scope_id, se::DeviceMemory<bool>(pred),
CreateBuilder(&then_commands_, &execute_params, &record_params));
}

Expand Down Expand Up @@ -893,8 +891,7 @@ absl::Status IfElseCmd::Record(const Thunk::ExecuteParams& execute_params,
VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")";

return command_buffer->IfElse(
execution_scope_id, execute_params.stream->parent(),
se::DeviceMemory<bool>(pred),
execution_scope_id, se::DeviceMemory<bool>(pred),
CreateBuilder(&then_commands_, &execute_params, &record_params),
CreateBuilder(&else_commands_, &execute_params, &record_params));
}
Expand Down Expand Up @@ -939,7 +936,6 @@ absl::Status CaseCmd::Record(const Thunk::ExecuteParams& execute_params,
VLOG(5) << " index: " << index_ << " (" << index.opaque() << ")";

return command_buffer->Case(execution_scope_id,
execute_params.stream->parent(),
se::DeviceMemory<int32_t>(index),
CreateBuilders(absl::MakeSpan(branches_commands_),
&execute_params, &record_params));
Expand Down Expand Up @@ -985,7 +981,7 @@ absl::Status ForCmd::Record(const Thunk::ExecuteParams& execute_params,
<< loop_counter.opaque() << ")";

return command_buffer->For(
execution_scope_id, execute_params.stream->parent(), num_iterations_,
execution_scope_id, num_iterations_,
se::DeviceMemory<int32_t>(loop_counter),
CreateBuilder(&body_commands_, &execute_params, &record_params));
}
Expand Down Expand Up @@ -1030,8 +1026,7 @@ absl::Status WhileCmd::Record(const Thunk::ExecuteParams& execute_params,
VLOG(5) << " pred: " << pred_ << " (" << pred.opaque() << ")";

return command_buffer->While(
execution_scope_id, execute_params.stream->parent(),
se::DeviceMemory<bool>(pred),
execution_scope_id, se::DeviceMemory<bool>(pred),
CreateExecutionScopeBuilder(&cond_commands_, &execute_params,
&record_params),
CreateBuilder(&body_commands_, &execute_params, &record_params));
Expand Down Expand Up @@ -1340,7 +1335,6 @@ absl::Status BarrierCmd::Record(const Thunk::ExecuteParams& execute_params,
<< " to stream " << execution_stream_id().value();
if (from_stream_id_ != execution_stream_id()) {
TF_RETURN_IF_ERROR(command_buffer->Barrier(
execute_params.stream->parent(),
CommandBufferCmd::GetExecutionScope(record_params, from_stream_id_),
CommandBufferCmd::GetExecutionScope(record_params,
execution_stream_id())));
Expand All @@ -1367,8 +1361,7 @@ absl::Status CollectiveCmd::BarrierIfAsync(
const CommandBufferCmd::RecordParams& record_params) {
if (IsAsync()) {
TF_RETURN_IF_ERROR(
command_buffer->Barrier(executor,
CommandBufferCmd::GetExecutionScope(
command_buffer->Barrier(CommandBufferCmd::GetExecutionScope(
record_params, async_from_stream_id_),
CommandBufferCmd::GetExecutionScope(
record_params, execution_stream_id())));
Expand Down
43 changes: 14 additions & 29 deletions third_party/xla/xla/stream_executor/command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,25 +202,20 @@ class CommandBuffer {
// Adds an execution barrier to a given execution scope: all commands added
// before a barrier in a the execution scope will complete before any of the
// commands added after a barrier in the same execution scope.
virtual absl::Status Barrier(StreamExecutorInterface* executor,
ExecutionScopeId execution_scope_id) = 0;
virtual absl::Status Barrier(ExecutionScopeId execution_scope_id) = 0;

// Adds an execution barrier that synchronizes commands across multiple
// execution scopes. See example #2 in execution scope id documentation.
virtual absl::Status Barrier(
StreamExecutorInterface* executor,
absl::Span<const ExecutionScopeId> execution_scope_ids) = 0;

// Adds an execution barrier from execution scope `from_execution_scope_id` to
// execution scope `to_execution_scope_id`. See example #3 for details.
virtual absl::Status Barrier(StreamExecutorInterface* executor,
ExecutionScopeId from_execution_scope_id,
virtual absl::Status Barrier(ExecutionScopeId from_execution_scope_id,
ExecutionScopeId to_execution_scope_id) = 0;

// Adds an execution barrier to the default execution scope.
absl::Status Barrier(StreamExecutorInterface* executor) {
return Barrier(executor, kDefaulExecutionScope);
}
absl::Status Barrier() { return Barrier(kDefaulExecutionScope); }

// Adds a kernel launch command.
virtual absl::Status Launch(ExecutionScopeId execution_scope_id,
Expand Down Expand Up @@ -292,29 +287,24 @@ class CommandBuffer {
// Adds a conditional operation that will execute a command buffer constructed
// by `then_builder` if `pred` value is `true`.
virtual absl::Status If(ExecutionScopeId execution_scope_id,
StreamExecutorInterface* executor,
DeviceMemory<bool> pred, Builder then_builder) = 0;

// Adds a conditional If operation to default execution scope.
absl::Status If(StreamExecutorInterface* executor, DeviceMemory<bool> pred,
Builder then_builder) {
return If(kDefaulExecutionScope, executor, pred, then_builder);
absl::Status If(DeviceMemory<bool> pred, Builder then_builder) {
return If(kDefaulExecutionScope, pred, then_builder);
}

// Adds a conditional operation that will execute a command buffer constructed
// by `then_builder` if `pred` value is `true`, or a command buffer
// constructed by `else_builder` if `pred` is `false`.
virtual absl::Status IfElse(ExecutionScopeId execution_scope_id,
StreamExecutorInterface* executor,
DeviceMemory<bool> pred, Builder then_builder,
Builder else_builder) = 0;

// Adds a conditional IfElse operation to default execution scope.
absl::Status IfElse(StreamExecutorInterface* executor,
DeviceMemory<bool> pred, Builder then_builder,
absl::Status IfElse(DeviceMemory<bool> pred, Builder then_builder,
Builder else_builder) {
return IfElse(kDefaulExecutionScope, executor, pred, then_builder,
else_builder);
return IfElse(kDefaulExecutionScope, pred, then_builder, else_builder);
}

// Adds a conditional operation that will execute a command buffer constructed
Expand All @@ -323,31 +313,28 @@ class CommandBuffer {
//
// See: https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case
virtual absl::Status Case(ExecutionScopeId execution_scope_id,
StreamExecutorInterface* executor,
DeviceMemory<int32_t> index,
std::vector<Builder> branches) = 0;

// Adds a conditional Case operation to default execution scope.
absl::Status Case(StreamExecutorInterface* executor,
DeviceMemory<int32_t> index,
absl::Status Case(DeviceMemory<int32_t> index,
std::vector<Builder> branches) {
return Case(kDefaulExecutionScope, executor, index, branches);
return Case(kDefaulExecutionScope, index, branches);
}

// Adds a conditional operation that will execute a command buffer constructed
// by the `body_builder` exactly `num_iteration` times. This means the
// condition is known at compile time (`num_iteration` < `loop_counter`), and
// does not require a `cond_builder`.
virtual absl::Status For(ExecutionScopeId execution_scope_id,
StreamExecutorInterface* executor,
int32_t num_iteration,
DeviceMemory<int32_t> loop_counter,
Builder body_builder) = 0;

// Adds a conditional For operation to default execution scope.
absl::Status For(StreamExecutorInterface* executor, int32_t num_iteration,
DeviceMemory<int32_t> loop_counter, Builder body_builder) {
return For(kDefaulExecutionScope, executor, num_iteration, loop_counter,
absl::Status For(int32_t num_iteration, DeviceMemory<int32_t> loop_counter,
Builder body_builder) {
return For(kDefaulExecutionScope, num_iteration, loop_counter,
body_builder);
}

Expand All @@ -368,16 +355,14 @@ class CommandBuffer {
// condition twice: (1) before the conditional node in the scope defined by
// `execution_scope_id` (2) inside the loop body with default execution scope.
virtual absl::Status While(ExecutionScopeId execution_scope_id,
StreamExecutorInterface* executor,
DeviceMemory<bool> pred,
ExecutionScopeBuilder cond_builder,
Builder body_builder) = 0;

// Adds a conditional While operation to default execution scope.
absl::Status While(StreamExecutorInterface* executor, DeviceMemory<bool> pred,
absl::Status While(DeviceMemory<bool> pred,
ExecutionScopeBuilder cond_builder, Builder body_builder) {
return While(kDefaulExecutionScope, executor, pred, cond_builder,
body_builder);
return While(kDefaulExecutionScope, pred, cond_builder, body_builder);
}

//--------------------------------------------------------------------------//
Expand Down
Loading

0 comments on commit 2ca68b7

Please sign in to comment.