Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:TPU] Support output streaming and refactor TryOutputStreaming into a bottoms-up approach. #66989

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions third_party/xla/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6063,6 +6063,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
],
)
Expand Down
127 changes: 69 additions & 58 deletions third_party/xla/xla/service/host_offloader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"

namespace xla {
Expand Down Expand Up @@ -215,67 +216,41 @@ std::optional<HloInstruction*> FindDUSFromAnnotation(HloInstruction* hlo) {
return std::nullopt;
}

} // namespace

absl::StatusOr<bool> HostOffloader::TryOutputStreaming(
HloInstruction* custom_call) {
// Check if this custom call traces down to a dynamic-update-slice. If so, we
// must use HloAliasAnalysis on the buffer of that dynamic-update-slice.
std::optional<HloInstruction*> dus = FindDUSFromAnnotation(custom_call);
const HloBuffer& unique_buffer =
alias_analysis_->GetUniqueBufferAt(dus.value_or(custom_call));
bool is_used_as_output_with_host_memory_space = false;
const HloComputation* const entry_computation =
custom_call->GetModule()->entry_computation();
for (const HloValue* value : unique_buffer.values()) {
// Check if this is memory-only.
if (!AllPositionsAreAllowed(value)) {
// Found a position which is not allowed.
return false;
}

// Look for a value used as a output.
for (const auto& position : value->positions()) {
const HloInstruction* instruction = position.instruction;
const ShapeIndex& index = position.index;
if (instruction->parent() == entry_computation && instruction->IsRoot()) {
const Shape& output_shape =
ShapeUtil::GetSubshape(entry_computation->parent()
->entry_computation_layout()
.result_shape(),
index);
CHECK(output_shape.has_layout());

if (output_shape.layout().memory_space() != kHostMemorySpaceColor) {
return FailedPrecondition(
"Output buffer is annotated with %s but is not marked with host "
"memory space in the entry computation.",
custom_call->name());
}
is_used_as_output_with_host_memory_space = true;
}
// Starting from a dynamic-update-slice, trace the graph up reshapes,
// bitcasts and reduces to return the MoveToHost custom call that feeds into the
// DUS, if it exists. If no MoveToHost call is found, returns an empty optional.
std::optional<HloInstruction*> FindAnnotationFromDUS(HloInstruction* hlo) {
CHECK(hlo->opcode() == HloOpcode::kDynamicUpdateSlice)
<< "Expected a dynamic-update-slice as input.";
// We expect the custom call to come from the written slice, i.e. operand 1.
hlo = hlo->mutable_operand(1);
while (!hlo->IsCustomCall(kMoveToHostCustomCallTarget)) {
if (!CanTraverseOpBetweenAnnotation(hlo)) {
break;
}
hlo = hlo->mutable_operand(0);
}
if (!is_used_as_output_with_host_memory_space) {
VLOG(1) << "Buffer annotated by " << custom_call->name()
<< " is not used as an output with host memory space.";
return false;
if (hlo->IsCustomCall(kMoveToHostCustomCallTarget)) {
return hlo;
}

VLOG(3) << "Found an output buffer annotated with " << custom_call->name()
<< ". Expecting that we'll need to insert copies.";

annotations_for_copy_to_host_to_insert_.emplace(custom_call);
AddAllPositionsToBeMovedToHostMemory(unique_buffer);
return true;
return std::nullopt;
}

} // namespace

Status HostOffloader::HandleMoveToHostCustomCall(HloInstruction* custom_call) {
VLOG(2) << "Found a custom call annotating start-of-host-offload: "
<< custom_call->ToString();
// Save a pointer to this custom call for when we want to remove it later.
custom_calls_to_remove_.emplace(custom_call);

// Skip this custom call if we've already handled it in output streaming.
if (annotations_for_copy_to_host_to_insert_.contains(custom_call)) {
VLOG(4) << "Skipping MoveToHost custom call that was already handled: "
<< custom_call->name();
return OkStatus();
}

// We expect that either the custom call is the root or the DUS is the only
// user of this custom call.
if (!custom_call->IsRoot() && custom_call->user_count() != 1) {
Expand Down Expand Up @@ -305,11 +280,7 @@ Status HostOffloader::HandleMoveToHostCustomCall(HloInstruction* custom_call) {
} else if (consumer != nullptr && consumer->opcode() == HloOpcode::kCopy) {
TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithCopy(consumer));
} else {
TF_ASSIGN_OR_RETURN(bool did_output_streaming,
TryOutputStreaming(custom_call));
if (!did_output_streaming) {
TF_RETURN_IF_ERROR(MemoryOnlyOffloadInsertCopies(custom_call));
}
TF_RETURN_IF_ERROR(MemoryOnlyOffloadInsertCopies(custom_call));
}
return OkStatus();
}
Expand Down Expand Up @@ -668,6 +639,24 @@ Status HostOffloader::CreateCopyForInputStreaming(HloInstruction* custom_call) {
Status HostOffloader::HandleStreamedBuffer(const HloBuffer& unique_buffer) {
// Find all move-to-device custom calls that are using this buffer.
for (const HloValue* value : unique_buffer.values()) {
// First, handle the defining instruction of this buffer, as a potential
// move-to-host custom call.
if (value->defining_instruction()->IsCustomCall(
kMoveToHostCustomCallTarget)) {
annotations_for_copy_to_host_to_insert_.emplace(
value->defining_instruction());
AddAllPositionsToBeMovedToHostMemory(unique_buffer);
} else if (value->defining_instruction()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
std::optional<HloInstruction*> dus =
FindAnnotationFromDUS(value->defining_instruction());
if (dus.has_value()) {
annotations_for_copy_to_host_to_insert_.emplace(dus.value());
AddAllPositionsToBeMovedToHostMemory(unique_buffer);
}
}
// Next, handle uses of this buffer as potential move-to-device custom
// calls.
for (const HloUse& use : value->GetUses()) {
if (use.instruction->IsCustomCall(
host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) {
Expand Down Expand Up @@ -745,6 +734,28 @@ Status HostOffloader::HandleInputStreaming(HloComputation* computation) {
return OkStatus();
}

// Starts from the result of the entry computation and looks for a case of
// output streaming. This function will not change any hlo, it will only mark
// instructions to be converted to host memory space.
Status HostOffloader::HandleOutputStreaming(HloComputation* computation) {
const ComputationLayout& entry_computation_layout =
computation->parent()->entry_computation_layout();

ShapeUtil::ForEachSubshape(
entry_computation_layout.result_shape(),
[&](const Shape& subshape, const ShapeIndex& index) {
if (subshape.has_layout() &&
subshape.layout().memory_space() == kHostMemorySpaceColor) {
VLOG(4) << "Handling streamed element in result with shape: "
<< subshape.ToString(true);
const HloBuffer& unique_buffer = alias_analysis_->GetUniqueBufferAt(
computation->root_instruction(), {index});
TF_CHECK_OK(HandleStreamedBuffer(unique_buffer));
}
});
return OkStatus();
}

absl::StatusOr<bool> HostOffloader::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand All @@ -755,12 +766,12 @@ absl::StatusOr<bool> HostOffloader::Run(
// Run HloAliasAnalysis on module.
TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module));

TF_RETURN_IF_ERROR(HandleInputStreaming(module->entry_computation()));
TF_RETURN_IF_ERROR(HandleOutputStreaming(module->entry_computation()));

// Iterate over all instructions and look for XLA host offload annotations.
for (HloComputation* computation :
module->MakeNonfusionComputations(execution_threads)) {
if (computation->IsEntryComputation()) {
TF_RETURN_IF_ERROR(HandleInputStreaming(computation));
}
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() != HloOpcode::kCustomCall) {
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/host_offloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class HostOffloader : public HloModulePass {
// Process streamed inputs for the given computation, finding the relevant
// move-to-device custom calls and inserting the appropriate copies.
Status HandleInputStreaming(HloComputation* computation);
// Process streamed outputs for the given computation, finding the relevant
// move-to-host custom calls and inserting the appropriate copies.
Status HandleOutputStreaming(HloComputation* computation);
// From a unique buffer on host memory, finds move-to-device custom calls
// for this buffer and inserts the appropriate copies.
Status HandleStreamedBuffer(const HloBuffer& unique_buffer);
Expand Down
71 changes: 71 additions & 0 deletions third_party/xla/xla/service/host_offloader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2241,6 +2241,77 @@ TEST_F(HostOffloaderTest, OutputStreaming) {
EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get()));
}

TEST_F(HostOffloaderTest, OutputStreamingWithoutTuple) {
const std::string& hlo_string = R"(
HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->s32[2,1]{1,0:T(2,128)S(5)}}

ENTRY main {
param_0 = s32[2,1]{1,0} parameter(0)
param_1 = s32[2,1]{1,0} parameter(1)
constant_2 = s32[] constant(2)
constant_4 = s32[] constant(4)
broadcast_0 = s32[2,1]{1,0} broadcast(constant_2), dimensions={}
multiply_0 = s32[2,1]{1,0} multiply(param_1, broadcast_0)
multiply_1 = s32[2,1]{1,0} multiply(multiply_0, param_0)
broadcast_1 = s32[2,1]{1,0} broadcast(constant_4), dimensions={}
multiply_2 = s32[2,1]{1,0} multiply(multiply_1, broadcast_1)
ROOT custom_call = s32[2,1]{1,0} custom-call(multiply_2), custom_call_target="MoveToHost"
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));

TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));

EXPECT_TRUE(changed);

// Look for the following pattern:
// constant
// |
// param1 broadcast param0
// \ / /
// multiply /
// \ /
// \ /
// multiply constant
// | | |
// | ---+---broadcast
// | /
// multiply
// |
// copy

HloInstruction* param_1;
HloInstruction* broadcast_0;
HloInstruction* multiply_0;
HloInstruction* param_0;
HloInstruction* multiply_1;
HloInstruction* broadcast_1;
HloInstruction* multiply_2;
HloInstruction* copy;
auto multiplyPattern =
m::Multiply(&multiply_1,
m::Multiply(&multiply_0, m::Parameter(&param_1),
m::Broadcast(&broadcast_0, m::ConstantScalar(2))),
m::Parameter(&param_0));
ASSERT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(m::Copy(
&copy, m::Multiply(&multiply_2, multiplyPattern,
m::Broadcast(&broadcast_1,
m::ConstantScalar(4))))));
TestShapeHasMemorySpace(param_1->shape(), Layout::kDefaultMemorySpace);
TestShapeHasMemorySpace(broadcast_0->shape(), Layout::kDefaultMemorySpace);
TestShapeHasMemorySpace(multiply_0->shape(), Layout::kDefaultMemorySpace);
TestShapeHasMemorySpace(param_0->shape(), Layout::kDefaultMemorySpace);
TestShapeHasMemorySpace(multiply_1->shape(), Layout::kDefaultMemorySpace);
TestShapeHasMemorySpace(broadcast_1->shape(), Layout::kDefaultMemorySpace);
TestShapeHasMemorySpace(multiply_2->shape(), Layout::kDefaultMemorySpace);
TestShapeHasMemorySpace(copy->shape(), kHostMemorySpaceColor);

EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get()));
}

TEST_F(HostOffloaderTest, OutputStreamingCustomCallRoot) {
const std::string& hlo_string = R"(
HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->s32[2,1]{1,0:T(2,128)S(5)}}
Expand Down