Skip to content

Commit

Permalink
[XLA:TPU] Support output streaming and refactor TryOutputStreaming in…
Browse files Browse the repository at this point in the history
…to a bottom-up approach.

Previously, output streaming took a top-down approach which indiscriminately checked whether a MoveToHost custom call would trace down to an output marked with host memory space. This did not work when a dynamic-update-slice existed between the MTH call and the output. This CL fixes the problem by handling output streaming before other MTH calls, while also improving efficiency with the bottom-up approach so we only trace a single path in the graph.

PiperOrigin-RevId: 630885979
  • Loading branch information
jvstokes authored and tensorflower-gardener committed May 9, 2024
1 parent 4f025f0 commit b3d4af5
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 62 deletions.
13 changes: 13 additions & 0 deletions tensorflow/core/framework/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ load(
load("//tensorflow:tensorflow.default.bzl", "filegroup", "tf_cuda_cc_test", "tf_generate_proto_text_sources")
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_jspb_proto_library",
"tf_proto_library",
"tf_pyclif_proto_library",
)
Expand Down Expand Up @@ -1711,6 +1712,12 @@ tf_proto_library(
],
)

tf_jspb_proto_library(
name = "tensor_jspb_proto",
visibility = ["//visibility:public"],
deps = [":tensor_proto"],
)

tf_proto_library(
name = "api_def_proto",
srcs = ["api_def.proto"],
Expand Down Expand Up @@ -1761,6 +1768,12 @@ tf_proto_library(
make_default_target_header_only = True,
)

tf_jspb_proto_library(
name = "types_jspb_proto",
visibility = ["//visibility:public"],
deps = [":types_proto"],
)

tf_proto_library(
name = "cost_graph_proto",
srcs = ["cost_graph.proto"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ Shape ResolveShapeIndex(const xla::ShapeProto& shape_proto,
// Choosing the last subshape to maintain historical behavior.
int64_t i = shape_index.back();
if (i >= shape_proto.tuple_shapes_size()) {
LOG(WARNING) << "shape_index out of tuple_shapes range.";
return Shape(shape_proto);
}
return Shape(shape_proto.tuple_shapes(i));
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/go/op/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -19563,6 +19563,10 @@ func GatherV2BatchDims(value int64) GatherV2Attr {
// On GPU, if an out of bound index is found, a 0 is stored in the
// corresponding output value.
//
// Note that on TPU, if any dimension of `params` is of size 0 then the output will
// be the expected shape filled with zeros. On CPU and GPU an error will be
// returned.
//
// See also `tf.batch_gather` and `tf.gather_nd`.
//
// Arguments:
Expand Down
1 change: 1 addition & 0 deletions third_party/xla/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6063,6 +6063,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:statusor",
],
)
Expand Down
1 change: 0 additions & 1 deletion third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4186,7 +4186,6 @@ cc_library(
hdrs = ["gpu_schedule_postprocessing.h"],
deps = [
":backend_configs_cc",
"//xla:statusor",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/service:hlo_pass",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ absl::StatusOr<bool> IsRelevantAsynchronousStart(const HloInstruction* hlo) {
}
TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
hlo->backend_config<GpuBackendConfig>());
CollectiveBackendConfig collective_backend_config =
const CollectiveBackendConfig& collective_backend_config =
gpu_config.collective_backend_config();
return !collective_backend_config.is_sync();
}
Expand Down Expand Up @@ -96,7 +96,8 @@ absl::StatusOr<bool> ProcessComputation(
// attribute no_parallel_custom_call to true. When we see a custom-call, clear
// the start ops from the collection and keep their attribute
// no_parallel_custom_call as false.
const std::vector<HloInstruction*> all_instructions = sequence.instructions();
const std::vector<HloInstruction*>& all_instructions =
sequence.instructions();
for (HloInstruction* hlo : all_instructions) {
if (MayInvokeCustomCall(hlo, custom_call_in_computation)) {
async_starts.clear();
Expand Down
127 changes: 69 additions & 58 deletions third_party/xla/xla/service/host_offloader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"

namespace xla {
Expand Down Expand Up @@ -215,67 +216,41 @@ std::optional<HloInstruction*> FindDUSFromAnnotation(HloInstruction* hlo) {
return std::nullopt;
}

} // namespace

absl::StatusOr<bool> HostOffloader::TryOutputStreaming(
HloInstruction* custom_call) {
// Check if this custom call traces down to a dynamic-update-slice. If so, we
// must use HloAliasAnalysis on the buffer of that dynamic-update-slice.
std::optional<HloInstruction*> dus = FindDUSFromAnnotation(custom_call);
const HloBuffer& unique_buffer =
alias_analysis_->GetUniqueBufferAt(dus.value_or(custom_call));
bool is_used_as_output_with_host_memory_space = false;
const HloComputation* const entry_computation =
custom_call->GetModule()->entry_computation();
for (const HloValue* value : unique_buffer.values()) {
// Check if this is memory-only.
if (!AllPositionsAreAllowed(value)) {
// Found a position which is not allowed.
return false;
}

// Look for a value used as a output.
for (const auto& position : value->positions()) {
const HloInstruction* instruction = position.instruction;
const ShapeIndex& index = position.index;
if (instruction->parent() == entry_computation && instruction->IsRoot()) {
const Shape& output_shape =
ShapeUtil::GetSubshape(entry_computation->parent()
->entry_computation_layout()
.result_shape(),
index);
CHECK(output_shape.has_layout());

if (output_shape.layout().memory_space() != kHostMemorySpaceColor) {
return FailedPrecondition(
"Output buffer is annotated with %s but is not marked with host "
"memory space in the entry computation.",
custom_call->name());
}
is_used_as_output_with_host_memory_space = true;
}
// Starting from a dynamic-update-slice, trace the graph up reshapes,
// bitcasts and reduces to return the MoveToHost custom call that feeds into the
// DUS, if it exists. If no MoveToHost call is found, returns an empty optional.
std::optional<HloInstruction*> FindAnnotationFromDUS(HloInstruction* hlo) {
CHECK(hlo->opcode() == HloOpcode::kDynamicUpdateSlice)
<< "Expected a dynamic-update-slice as input.";
// We expect the custom call to come from the written slice, i.e. operand 1.
hlo = hlo->mutable_operand(1);
while (!hlo->IsCustomCall(kMoveToHostCustomCallTarget)) {
if (!CanTraverseOpBetweenAnnotation(hlo)) {
break;
}
hlo = hlo->mutable_operand(0);
}
if (!is_used_as_output_with_host_memory_space) {
VLOG(1) << "Buffer annotated by " << custom_call->name()
<< " is not used as an output with host memory space.";
return false;
if (hlo->IsCustomCall(kMoveToHostCustomCallTarget)) {
return hlo;
}

VLOG(3) << "Found an output buffer annotated with " << custom_call->name()
<< ". Expecting that we'll need to insert copies.";

annotations_for_copy_to_host_to_insert_.emplace(custom_call);
AddAllPositionsToBeMovedToHostMemory(unique_buffer);
return true;
return std::nullopt;
}

} // namespace

Status HostOffloader::HandleMoveToHostCustomCall(HloInstruction* custom_call) {
VLOG(2) << "Found a custom call annotating start-of-host-offload: "
<< custom_call->ToString();
// Save a pointer to this custom call for when we want to remove it later.
custom_calls_to_remove_.emplace(custom_call);

// Skip this custom call if we've already handled it in output streaming.
if (annotations_for_copy_to_host_to_insert_.contains(custom_call)) {
VLOG(4) << "Skipping MoveToHost custom call that was already handled: "
<< custom_call->name();
return OkStatus();
}

// We expect that either the custom call is the root or the DUS is the only
// user of this custom call.
if (!custom_call->IsRoot() && custom_call->user_count() != 1) {
Expand Down Expand Up @@ -305,11 +280,7 @@ Status HostOffloader::HandleMoveToHostCustomCall(HloInstruction* custom_call) {
} else if (consumer != nullptr && consumer->opcode() == HloOpcode::kCopy) {
TF_RETURN_IF_ERROR(MemoryOnlyOffloadStartingWithCopy(consumer));
} else {
TF_ASSIGN_OR_RETURN(bool did_output_streaming,
TryOutputStreaming(custom_call));
if (!did_output_streaming) {
TF_RETURN_IF_ERROR(MemoryOnlyOffloadInsertCopies(custom_call));
}
TF_RETURN_IF_ERROR(MemoryOnlyOffloadInsertCopies(custom_call));
}
return OkStatus();
}
Expand Down Expand Up @@ -668,6 +639,24 @@ Status HostOffloader::CreateCopyForInputStreaming(HloInstruction* custom_call) {
Status HostOffloader::HandleStreamedBuffer(const HloBuffer& unique_buffer) {
// Find all move-to-device custom calls that are using this buffer.
for (const HloValue* value : unique_buffer.values()) {
// First, handle the defining instruction of this buffer, as a potential
// move-to-host custom call.
if (value->defining_instruction()->IsCustomCall(
kMoveToHostCustomCallTarget)) {
annotations_for_copy_to_host_to_insert_.emplace(
value->defining_instruction());
AddAllPositionsToBeMovedToHostMemory(unique_buffer);
} else if (value->defining_instruction()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
std::optional<HloInstruction*> dus =
FindAnnotationFromDUS(value->defining_instruction());
if (dus.has_value()) {
annotations_for_copy_to_host_to_insert_.emplace(dus.value());
AddAllPositionsToBeMovedToHostMemory(unique_buffer);
}
}
// Next, handle uses of this buffer as potential move-to-device custom
// calls.
for (const HloUse& use : value->GetUses()) {
if (use.instruction->IsCustomCall(
host_memory_offload_annotations::kMoveToDeviceCustomCallTarget)) {
Expand Down Expand Up @@ -745,6 +734,28 @@ Status HostOffloader::HandleInputStreaming(HloComputation* computation) {
return OkStatus();
}

// Starts from the result of the entry computation and looks for a case of
// output streaming. This function will not change any hlo, it will only mark
// instructions to be converted to host memory space.
Status HostOffloader::HandleOutputStreaming(HloComputation* computation) {
  const ComputationLayout& entry_computation_layout =
      computation->parent()->entry_computation_layout();

  // Visit every (sub)shape of the entry computation's result; a leaf whose
  // layout carries the host memory space color is a streamed output.
  ShapeUtil::ForEachSubshape(
      entry_computation_layout.result_shape(),
      [&](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.has_layout() &&
            subshape.layout().memory_space() == kHostMemorySpaceColor) {
          VLOG(4) << "Handling streamed element in result with shape: "
                  << subshape.ToString(true);
          // Resolve the buffer backing this output element via alias
          // analysis so the whole aliased set can be marked for host memory.
          const HloBuffer& unique_buffer = alias_analysis_->GetUniqueBufferAt(
              computation->root_instruction(), {index});
          // NOTE(review): ForEachSubshape's visitor returns void, so an error
          // Status from HandleStreamedBuffer cannot be propagated to the
          // caller; TF_CHECK_OK turns it into a fatal check failure instead.
          TF_CHECK_OK(HandleStreamedBuffer(unique_buffer));
        }
      });
  return OkStatus();
}

absl::StatusOr<bool> HostOffloader::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand All @@ -755,12 +766,12 @@ absl::StatusOr<bool> HostOffloader::Run(
// Run HloAliasAnalysis on module.
TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module));

TF_RETURN_IF_ERROR(HandleInputStreaming(module->entry_computation()));
TF_RETURN_IF_ERROR(HandleOutputStreaming(module->entry_computation()));

// Iterate over all instructions and look for XLA host offload annotations.
for (HloComputation* computation :
module->MakeNonfusionComputations(execution_threads)) {
if (computation->IsEntryComputation()) {
TF_RETURN_IF_ERROR(HandleInputStreaming(computation));
}
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() != HloOpcode::kCustomCall) {
Expand Down
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/host_offloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class HostOffloader : public HloModulePass {
// Process streamed inputs for the given computation, finding the relevant
// move-to-device custom calls and inserting the appropriate copies.
Status HandleInputStreaming(HloComputation* computation);
// Process streamed outputs for the given computation, finding the relevant
// move-to-host custom calls and inserting the appropriate copies.
Status HandleOutputStreaming(HloComputation* computation);
// From a unique buffer on host memory, finds move-to-device custom calls
// for this buffer and inserts the appropriate copies.
Status HandleStreamedBuffer(const HloBuffer& unique_buffer);
Expand Down
71 changes: 71 additions & 0 deletions third_party/xla/xla/service/host_offloader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2241,6 +2241,77 @@ TEST_F(HostOffloaderTest, OutputStreaming) {
EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get()));
}

// Verifies output streaming for a non-tuple entry result: the entry layout
// annotates the root with S(5) (host memory space), and the root is a
// MoveToHost custom call fed by plain elementwise math (no
// dynamic-update-slice). The pass should replace the annotation with a copy
// into host memory space and leave all upstream ops in device memory.
TEST_F(HostOffloaderTest, OutputStreamingWithoutTuple) {
  const std::string& hlo_string = R"(
HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->s32[2,1]{1,0:T(2,128)S(5)}}
ENTRY main {
    param_0 = s32[2,1]{1,0} parameter(0)
    param_1 = s32[2,1]{1,0} parameter(1)
    constant_2 = s32[] constant(2)
    constant_4 = s32[] constant(4)
    broadcast_0 = s32[2,1]{1,0} broadcast(constant_2), dimensions={}
    multiply_0 = s32[2,1]{1,0} multiply(param_1, broadcast_0)
    multiply_1 = s32[2,1]{1,0} multiply(multiply_0, param_0)
    broadcast_1 = s32[2,1]{1,0} broadcast(constant_4), dimensions={}
    multiply_2 = s32[2,1]{1,0} multiply(multiply_1, broadcast_1)
    ROOT custom_call = s32[2,1]{1,0} custom-call(multiply_2), custom_call_target="MoveToHost"
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));

  EXPECT_TRUE(changed);

  // Look for the following pattern:
  //         constant
  //            |
  // param1 broadcast  param0
  //     \  /         /
  //   multiply      /
  //       \        /
  //        \      /
  //        multiply   constant
  //           |          |
  //           |      ---+---broadcast
  //           |     /
  //         multiply
  //            |
  //           copy

  HloInstruction* param_1;
  HloInstruction* broadcast_0;
  HloInstruction* multiply_0;
  HloInstruction* param_0;
  HloInstruction* multiply_1;
  HloInstruction* broadcast_1;
  HloInstruction* multiply_2;
  HloInstruction* copy;
  auto multiplyPattern =
      m::Multiply(&multiply_1,
                  m::Multiply(&multiply_0, m::Parameter(&param_1),
                              m::Broadcast(&broadcast_0, m::ConstantScalar(2))),
                  m::Parameter(&param_0));
  // The MoveToHost custom call must have been replaced by a copy at the root.
  ASSERT_THAT(module->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(
                  &copy, m::Multiply(&multiply_2, multiplyPattern,
                                     m::Broadcast(&broadcast_1,
                                                  m::ConstantScalar(4))))));
  // Only the final copy lives in host memory space; everything upstream
  // stays in the default (device) memory space.
  TestShapeHasMemorySpace(param_1->shape(), Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(broadcast_0->shape(), Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(multiply_0->shape(), Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(param_0->shape(), Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(multiply_1->shape(), Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(broadcast_1->shape(), Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(multiply_2->shape(), Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(copy->shape(), kHostMemorySpaceColor);

  EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get()));
}

TEST_F(HostOffloaderTest, OutputStreamingCustomCallRoot) {
const std::string& hlo_string = R"(
HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->s32[2,1]{1,0:T(2,128)S(5)}}
Expand Down

0 comments on commit b3d4af5

Please sign in to comment.