Skip to content

Commit

Permalink
[stream_executor] NFC: Port all commands to GetBarrier() with executi…
Browse files Browse the repository at this point in the history
…on scope arg

In preparation for adding execution scope id to all commands remove functions that automatically assume default scope.

PiperOrigin-RevId: 609181938
  • Loading branch information
ezhulenev authored and tensorflower-gardener committed Feb 22, 2024
1 parent 8688c76 commit 8e4418e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
30 changes: 18 additions & 12 deletions third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,8 @@ absl::Status GpuCommandBuffer::Barrier(StreamExecutor* executor,
absl::Status GpuCommandBuffer::LaunchWithPackedArgs(
const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel,
const KernelArgsPackedArrayBase& packed_args) {
ExecutionScope& execution_scope = execution_scopes_[kDefaulExecutionScope];
ExecutionScopeId execution_scope_id = kDefaulExecutionScope;
ExecutionScope& execution_scope = execution_scopes_[execution_scope_id];

CHECK_EQ(kernel.Arity() + (packed_args.number_of_shared_bytes() > 0),
packed_args.number_of_arguments());
Expand All @@ -503,7 +504,7 @@ absl::Status GpuCommandBuffer::LaunchWithPackedArgs(

// Adds a new kernel node to the graph under construction.
if (state_ == State::kCreate) {
Dependencies barrier = GetBarrier();
Dependencies barrier = GetBarrier(execution_scope_id);
GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back();
return GpuDriver::GraphAddKernelNode(
&node_info.handle, graph_, barrier, kernel.name(), gpu_func, blocks.x,
Expand Down Expand Up @@ -553,15 +554,16 @@ absl::Status GpuCommandBuffer::Launch(const ThreadDim& threads,

absl::Status GpuCommandBuffer::AddNestedCommandBuffer(
const CommandBuffer& nested) {
ExecutionScope& execution_scope = execution_scopes_[kDefaulExecutionScope];
ExecutionScopeId execution_scope_id = kDefaulExecutionScope;
ExecutionScope& execution_scope = execution_scopes_[execution_scope_id];

TF_RETURN_IF_ERROR(CheckNotFinalized());

GpuGraphHandle child_graph = GpuCommandBuffer::Cast(&nested)->graph();

// Adds a child graph node to the graph under construction.
if (state_ == State::kCreate) {
Dependencies barrier = GetBarrier();
Dependencies barrier = GetBarrier(execution_scope_id);
GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back();
return GpuDriver::GraphAddChildNode(&node_info.handle, graph_, barrier,
child_graph);
Expand All @@ -580,12 +582,13 @@ absl::Status GpuCommandBuffer::AddNestedCommandBuffer(
absl::Status GpuCommandBuffer::MemcpyDeviceToDevice(DeviceMemoryBase* dst,
const DeviceMemoryBase& src,
uint64_t size) {
ExecutionScope& execution_scope = execution_scopes_[kDefaulExecutionScope];
ExecutionScopeId execution_scope_id = kDefaulExecutionScope;
ExecutionScope& execution_scope = execution_scopes_[execution_scope_id];

TF_RETURN_IF_ERROR(CheckNotFinalized());

if (state_ == State::kCreate) {
Dependencies barrier = GetBarrier();
Dependencies barrier = GetBarrier(execution_scope_id);
GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back();
return GpuDriver::GraphAddMemcpyD2DNode(
parent_->gpu_context(), &node_info.handle, graph_, barrier,
Expand Down Expand Up @@ -631,13 +634,14 @@ absl::Status GpuCommandBuffer::Memset(ExecutionScopeId execution_scope_id,
}

absl::StatusOr<DeviceMemoryBase> GpuCommandBuffer::Allocate(size_t bytes) {
ExecutionScope& execution_scope = execution_scopes_[kDefaulExecutionScope];
ExecutionScopeId execution_scope_id = kDefaulExecutionScope;
ExecutionScope& execution_scope = execution_scopes_[execution_scope_id];

TF_RETURN_IF_ERROR(CheckNotFinalized());

// Adds a new memory allocation node to the graph under construction.
if (state_ == State::kCreate) {
Dependencies barrier = GetBarrier();
Dependencies barrier = GetBarrier(execution_scope_id);
GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back();

GpuDevicePtr ptr;
Expand Down Expand Up @@ -671,13 +675,14 @@ absl::StatusOr<DeviceMemoryBase> GpuCommandBuffer::Allocate(size_t bytes) {
}

absl::Status GpuCommandBuffer::Free(DeviceMemoryBase dst) {
ExecutionScope& execution_scope = execution_scopes_[kDefaulExecutionScope];
ExecutionScopeId execution_scope_id = kDefaulExecutionScope;
ExecutionScope& execution_scope = execution_scopes_[execution_scope_id];

TF_RETURN_IF_ERROR(CheckNotFinalized());

// Adds a new memfree node to the graph under construction.
if (state_ == State::kCreate) {
Dependencies barrier = GetBarrier();
Dependencies barrier = GetBarrier(execution_scope_id);
GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back();
GpuDevicePtr gpu_dptr = AsDevicePtr(dst);
TF_RETURN_IF_ERROR(GpuDriver::GraphAddMemFreeNode(&node_info.handle, graph_,
Expand Down Expand Up @@ -721,15 +726,16 @@ GpuCommandBuffer::CreateConditionalHandles(size_t num_handles) {
absl::StatusOr<std::vector<GpuGraphHandle>>
GpuCommandBuffer::CreateConditionalNodes(
ConditionType type, absl::Span<const GpuGraphConditionalHandle> handles) {
ExecutionScope& execution_scope = execution_scopes_[kDefaulExecutionScope];
ExecutionScopeId execution_scope_id = kDefaulExecutionScope;
ExecutionScope& execution_scope = execution_scopes_[execution_scope_id];

std::vector<GpuGraphHandle> conditional_graphs;

using ConditionalParams = GpuDriver::GpuGraphConditionalNodeParams;
using ConditionalResult = GpuDriver::GpuGraphConditionalNodeParams::Result;

for (GpuGraphConditionalHandle handle : handles) {
Dependencies barrier = GetBarrier();
Dependencies barrier = GetBarrier(execution_scope_id);
GpuGraphNodeInfo& node_info = execution_scope.nodes.emplace_back();

ConditionalParams params;
Expand Down
2 changes: 0 additions & 2 deletions third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,6 @@ class GpuCommandBuffer : public CommandBuffer {
absl::Span<const ConditionBuilder> builders);

Dependencies GetBarrier(ExecutionScopeId execution_scope_id);
// TODO(ezhulenev): Remove this once all commands migrated to scopes.
Dependencies GetBarrier() { return GetBarrier(kDefaulExecutionScope); }

// Returns loaded auxiliary kernels, or loads them on a given stream executor.
// Loaded kernels owned by a current command buffer.
Expand Down

0 comments on commit 8e4418e

Please sign in to comment.