Eliminate call to StreamExecutor::implementation() now that StreamExecutor inherits from StreamExecutorInterface.

PiperOrigin-RevId: 626685442
klucke authored and tensorflower-gardener committed Apr 20, 2024
1 parent 4858783 commit 083da14
Showing 9 changed files with 26 additions and 33 deletions.
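Every hunk below applies the same mechanical change: drop the executor->implementation() hop and operate on (or cast) the StreamExecutor pointer directly. The standalone sketch that follows is not part of the commit; the class bodies are simplified assumptions used only to illustrate why the direct cast becomes valid once StreamExecutor derives from StreamExecutorInterface.

// Simplified, hypothetical hierarchy -- the real classes live in
// xla/stream_executor and carry far more state and methods.
#include <iostream>

struct StreamExecutorInterface {
  virtual ~StreamExecutorInterface() = default;
  virtual const char* Name() const { return "StreamExecutorInterface"; }
};

// After this change, StreamExecutor itself derives from
// StreamExecutorInterface, so platform backends subclass StreamExecutor
// directly instead of living behind an implementation() accessor.
struct StreamExecutor : StreamExecutorInterface {
  const char* Name() const override { return "StreamExecutor"; }
};

// Stand-in for a platform backend such as TpuExecutor or GpuExecutor.
struct TpuExecutor : StreamExecutor {
  const char* Name() const override { return "TpuExecutor"; }
};

int main() {
  TpuExecutor tpu;
  StreamExecutor* executor = &tpu;

  // Old pattern (what the commit deletes), shown only as a comment:
  //   auto* t = down_cast<TpuExecutor*>(executor->implementation());
  //
  // New pattern: the StreamExecutor pointer already refers to the platform
  // object, so a single downcast suffices.
  auto* t = static_cast<TpuExecutor*>(executor);
  std::cout << t->Name() << "\n";  // prints "TpuExecutor"
  return 0;
}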
5 changes: 2 additions & 3 deletions tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.cc
@@ -56,7 +56,7 @@ Status FlushProgramMemory(se::Platform* platform, int device_ordinal) {
tpu::TpuNodeContext::Create(device_ordinal));

auto* executor = tensorflow::down_cast<tpu::TpuExecutorInterface*>(
- node_interfaces->stream_executor()->implementation());
+ node_interfaces->stream_executor());
return executor->UnloadAllPrograms();
}

@@ -214,8 +214,7 @@ absl::StatusOr<xla::ShapeTree<xla::MaybeOwningDeviceMemory>> BuildInputBuffers(
// Perform a compaction to reduce fragmentation.
Status PerformCompaction(stream_executor::Stream* stream) {
tsl::profiler::TraceMe trace_me("PerformCompaction", /*level=*/2);
- auto* ds_executor =
- down_cast<tpu::TpuExecutorInterface*>(stream->parent()->implementation());
+ auto* ds_executor = down_cast<tpu::TpuExecutorInterface*>(stream->parent());
TF_RETURN_IF_ERROR(ds_executor->EnqueueCompactionOnStreamForHbm(stream));
// LoadProgram and GetOrCreateConstantHandle are not managed by stream
// dependencies but they write to shared memory, so we need to block here to
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/command_buffer.cc
@@ -32,7 +32,7 @@ namespace stream_executor {

absl::StatusOr<std::unique_ptr<CommandBuffer>> CommandBuffer::Create(
StreamExecutor* executor, Mode mode) {
- return executor->implementation()->CreateCommandBuffer(mode);
+ return executor->CreateCommandBuffer(mode);
}

absl::StatusOr<std::unique_ptr<CommandBuffer>> CommandBuffer::Trace(
3 changes: 1 addition & 2 deletions third_party/xla/xla/stream_executor/event.cc
@@ -26,8 +26,7 @@ namespace stream_executor {

Event::Event(StreamExecutor* stream_exec)
: stream_exec_(stream_exec),
- implementation_(
- stream_exec_->implementation()->CreateEventImplementation()) {}
+ implementation_(stream_exec_->CreateEventImplementation()) {}

Event::~Event() {
// Deal with nullptr implementation_, as this event may have been std::moved.
10 changes: 5 additions & 5 deletions third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc
@@ -887,7 +887,7 @@ absl::Status GpuCommandBuffer::If(ExecutionScopeId execution_scope_id,
StreamExecutor* executor,
DeviceMemory<bool> predicate,
Builder then_builder) {
- DCHECK(executor->implementation() == parent_);
+ DCHECK(executor == parent_);

TF_ASSIGN_OR_RETURN(SetIfConditionKernel * set_if_condition,
GetSetIfConditionKernel(executor));
@@ -909,7 +909,7 @@ absl::Status GpuCommandBuffer::IfElse(ExecutionScopeId execution_scope_id,
DeviceMemory<bool> predicate,
Builder then_builder,
Builder else_builder) {
- DCHECK(executor->implementation() == parent_);
+ DCHECK(executor == parent_);

TF_ASSIGN_OR_RETURN(SetIfElseConditionKernel * set_if_else_condition,
GetSetIfElseConditionKernel(executor));
@@ -931,7 +931,7 @@ absl::Status GpuCommandBuffer::Case(ExecutionScopeId execution_scope_id,
StreamExecutor* executor,
DeviceMemory<int32_t> index,
std::vector<Builder> branches) {
- DCHECK(executor->implementation() == parent_);
+ DCHECK(executor == parent_);

// TODO(ezhulenev): Relax this constraint, we can launch multiple back to back
// kernels to update conditional handles in batches of size 8.
@@ -974,7 +974,7 @@ absl::Status GpuCommandBuffer::For(ExecutionScopeId execution_scope_id,
int32_t num_iteration,
DeviceMemory<int32_t> loop_counter,
Builder body_builder) {
- DCHECK(executor->implementation() == parent_);
+ DCHECK(executor == parent_);

TF_ASSIGN_OR_RETURN(SetForConditionKernel * set_for_condition,
GetSetForConditionKernel(executor));
@@ -1009,7 +1009,7 @@ absl::Status GpuCommandBuffer::While(ExecutionScopeId execution_scope_id,
DeviceMemory<bool> pred,
ExecutionScopeBuilder cond_builder,
Builder body_builder) {
- DCHECK(executor->implementation() == parent_);
+ DCHECK(executor == parent_);

TF_ASSIGN_OR_RETURN(SetWhileConditionKernel * set_while_condition,
GetSetWhileConditionKernel(executor));
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/gpu/gpu_executor.h
@@ -445,7 +445,7 @@ class GpuExecutor : public StreamExecutor {
};

inline GpuExecutor* ExtractGpuExecutor(StreamExecutor* stream_exec) {
- return static_cast<GpuExecutor*>(stream_exec->implementation());
+ return static_cast<GpuExecutor*>(stream_exec);
}

} // namespace gpu
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/kernel.cc
@@ -55,7 +55,7 @@ void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) {

absl::StatusOr<std::unique_ptr<Kernel>> Kernel::Create(
StreamExecutor *executor, const MultiKernelLoaderSpec &spec) {
- TF_ASSIGN_OR_RETURN(auto kernel, executor->implementation()->CreateKernel());
+ TF_ASSIGN_OR_RETURN(auto kernel, executor->CreateKernel());
TF_RETURN_IF_ERROR(executor->GetKernel(spec, kernel.get()));
return kernel;
}
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/stream.cc
@@ -51,7 +51,7 @@ absl::Status Stream::Initialize(
return absl::InternalError(
"stream appears to already have been initialized");
}
- implementation_ = parent_->implementation()->GetStreamImplementation();
+ implementation_ = parent_->GetStreamImplementation();
if (priority.has_value()) {
if (std::holds_alternative<StreamPriority>(*priority)) {
implementation_->SetPriority(std::get<StreamPriority>(*priority));
@@ -70,8 +70,7 @@ class TpuCompiler : public Compiler {
StatusHelper status;
ExecutorApiFn()->TpuCompiler_RunHloPassesFn(
compiler_, &hlo_module,
- static_cast<stream_executor::tpu::TpuExecutor*>(
- executor->implementation())
+ static_cast<stream_executor::tpu::TpuExecutor*>(executor)
->se_executor(),
&allocator, &result, status.c_status);
if (!status.ok()) {
@@ -100,8 +99,7 @@ class TpuCompiler : public Compiler {
StatusHelper status;
ExecutorApiFn()->TpuCompiler_RunBackendFn(
compiler_, &hlo_module,
- static_cast<stream_executor::tpu::TpuExecutor*>(
- executor->implementation())
+ static_cast<stream_executor::tpu::TpuExecutor*>(executor)
->se_executor(),
&allocator, &result, status.c_status);
if (!status.ok()) {
@@ -141,9 +139,9 @@ class TpuCompiler : public Compiler {
se_lists_storage.emplace_back(stream_exec[i].size());
se_lists[i].exec = se_lists_storage.back().data();
for (int j = 0; j < stream_exec[i].size(); ++j) {
- se_lists[i].exec[j] = static_cast<stream_executor::tpu::TpuExecutor*>(
- stream_exec[i][j]->implementation())
- ->se_executor();
+ se_lists[i].exec[j] =
+ static_cast<stream_executor::tpu::TpuExecutor*>(stream_exec[i][j])
+ ->se_executor();
}
}

21 changes: 9 additions & 12 deletions third_party/xla/xla/stream_executor/tpu/tpu_transfer_manager.cc
@@ -110,8 +110,8 @@ absl::Status TpuTransferManager::TransferLiteralToInfeed(
StatusHelper status;
XLA_Literal c_literal;
ApiConverter::ToC(literal, &c_literal);
- auto* tpu_executor = static_cast<stream_executor::tpu::TpuExecutor*>(
- executor->implementation());
+ auto* tpu_executor =
+ static_cast<stream_executor::tpu::TpuExecutor*>(executor);

stream_executor::tpu::ExecutorApiFn()
->TpuTransferManager_TransferLiteralToInfeedFn(
@@ -126,8 +126,8 @@ absl::Status TpuTransferManager::TransferBuffersToInfeed(
se::StreamExecutor* executor,
const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) {
StatusHelper status;
- auto* tpu_executor = static_cast<stream_executor::tpu::TpuExecutor*>(
- executor->implementation());
+ auto* tpu_executor =
+ static_cast<stream_executor::tpu::TpuExecutor*>(executor);

std::vector<int64_t> buffers_size;
std::vector<uint32_t*> buffers_array;
@@ -154,8 +154,8 @@ absl::Status TpuTransferManager::TransferLiteralFromOutfeed(
StatusHelper status;
XLA_Shape c_shape;
XLA_Literal c_literal;
- auto* tpu_executor = static_cast<stream_executor::tpu::TpuExecutor*>(
- executor->implementation());
+ auto* tpu_executor =
+ static_cast<stream_executor::tpu::TpuExecutor*>(executor);

ApiConverter::ToC(literal.shape(), &c_shape);
ApiConverter::ToC(literal, &c_literal);
@@ -177,8 +177,7 @@ absl::Status TpuTransferManager::ResetDevices(
std::vector<SE_StreamExecutor*> se;
se.reserve(executor.size());
for (int64_t i = 0; i < executor.size(); ++i) {
- se.push_back(static_cast<stream_executor::tpu::TpuExecutor*>(
- executor[i]->implementation())
+ se.push_back(static_cast<stream_executor::tpu::TpuExecutor*>(executor[i])
->se_executor());
}

@@ -272,8 +271,7 @@ StatusOr<xla::Shape> TpuTransferManager::ChooseCompactLayoutForShape(
bool TpuTransferManager::CanShapedBufferBeAccessedNow(
stream_executor::StreamExecutor* executor,
const xla::ShapedBuffer& device_buffer) const {
- auto* tpu_executor =
- down_cast<stream_executor::tpu::TpuExecutor*>(executor->implementation());
+ auto* tpu_executor = down_cast<stream_executor::tpu::TpuExecutor*>(executor);
XLA_ShapedBuffer c_device_buffer;
ApiConverter::ToC(device_buffer, &c_device_buffer);
absl::Cleanup cleanup = [&c_device_buffer]() {
@@ -287,8 +285,7 @@ bool TpuTransferManager::CanBufferBeAccessedNow(
bool TpuTransferManager::CanBufferBeAccessedNow(
se::StreamExecutor* executor,
const se::DeviceMemoryBase& device_buffer) const {
- auto* tpu_executor =
- down_cast<stream_executor::tpu::TpuExecutor*>(executor->implementation());
+ auto* tpu_executor = down_cast<stream_executor::tpu::TpuExecutor*>(executor);
SE_DeviceMemoryBase c_device_buffer{const_cast<void*>(device_buffer.opaque()),
device_buffer.size(),
device_buffer.payload()};
