Skip to content

Commit

Permalink
host_offloader: support unnested input parameter tuples
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625180012
  • Loading branch information
cota authored and tensorflower-gardener committed Apr 16, 2024
1 parent 5cc56a2 commit 9a575cb
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 6 deletions.
1 change: 1 addition & 0 deletions third_party/xla/xla/service/BUILD
Expand Up @@ -5961,6 +5961,7 @@ cc_library(
"//xla:literal_util",
"//xla:shape_util",
"//xla:status",
"//xla:status_macros",
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
Expand Down
22 changes: 16 additions & 6 deletions third_party/xla/xla/service/host_offloader.cc
Expand Up @@ -40,6 +40,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
Expand Down Expand Up @@ -558,12 +559,21 @@ absl::StatusOr<bool> HostOffloader::TryParameterStreaming(
value->defining_position().instruction;
if (defining_instruction->opcode() == HloOpcode::kParameter) {
if (defining_instruction->parent() == entry_computation) {
const Shape& param_shape =
entry_computation->parent()
->entry_computation_layout()
.parameter_shape(defining_instruction->parameter_number());
CHECK(param_shape.has_layout());
if (param_shape.layout().memory_space() == kHostMemorySpaceColor) {
const Shape* param_shape =
&entry_computation->parent()
->entry_computation_layout()
.parameter_shape(defining_instruction->parameter_number());
if (param_shape->IsTuple()) {
// Fetch the memory space annotation from the tuple element's layout.
TF_RET_CHECK(value->index().size() == 1)
<< value->index().size()
<< " != 1: nested parameter tuples aren't supported";
int tuple_index = value->index()[0];
TF_RET_CHECK(tuple_index < param_shape->tuple_shapes_size());
param_shape = &param_shape->tuple_shapes(tuple_index);
}
TF_RET_CHECK(param_shape->has_layout());
if (param_shape->layout().memory_space() == kHostMemorySpaceColor) {
is_defined_by_entry_param_with_host_memory_space = true;
}
}
Expand Down
48 changes: 48 additions & 0 deletions third_party/xla/xla/service/host_offloader_test.cc
Expand Up @@ -1954,6 +1954,54 @@ ENTRY main {
EXPECT_FALSE(HaveRemainingOffloadAnnotations(module.get()));
}

TEST_F(HostOffloaderTest, TupleParameterStreaming) {
  // The entry layout marks element 1 of the (unnested) tuple parameter with
  // S(5), i.e. host memory space; element 0 stays in device memory. The pass
  // should therefore replace the MoveToDevice custom call with a copy and
  // keep the host annotation on the tuple element and its GTE.
  const std::string& hlo_string = R"(
HloModule ParameterStreaming, entry_computation_layout={((s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)S(5)}))->s32[2,1]{1,0:T(2,128)}}
ENTRY main {
param_tuple = (s32[2,1], s32[2,1]) parameter(0)
x = get-tuple-element(param_tuple), index=0
y_host = get-tuple-element(param_tuple), index=1
y = s32[2,1] custom-call(y_host), custom_call_target="MoveToDevice"
ROOT crs = s32[2,1] add(x, y)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));
  EXPECT_TRUE(changed);

  // Expected post-pass structure:
  //   param0: tuple(x , y)
  //        /            \
  // get-tuple-element   get-tuple-element
  //        \             |
  //         \           copy
  //          \          /
  //              add
  HloInstruction* tuple_param;
  HloInstruction* device_gte;
  HloInstruction* host_gte;
  HloInstruction* copy_to_device;
  HloInstruction* root_add;
  auto param_matcher = m::Parameter(&tuple_param, 0);
  ASSERT_THAT(
      module->entry_computation()->root_instruction(),
      GmockMatch(m::Add(
          &root_add, m::GetTupleElement(&device_gte, param_matcher),
          m::Copy(&copy_to_device, m::GetTupleElement(&host_gte, param_matcher)))));
  // Host side: tuple element 1 and the GTE reading it keep the host space.
  TestShapeHasMemorySpace(tuple_param->shape().tuple_shapes(1),
                          kHostMemorySpaceColor);
  TestShapeHasMemorySpace(host_gte->shape(), kHostMemorySpaceColor);
  // Device side: everything else lives in the default memory space.
  TestShapeHasMemorySpace(tuple_param->shape().tuple_shapes(0),
                          Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(device_gte->shape(), Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(copy_to_device->shape(), Layout::kDefaultMemorySpace);
  TestShapeHasMemorySpace(root_add->shape(), Layout::kDefaultMemorySpace);
}

TEST_F(HostOffloaderTest, OutputStreaming) {
const std::string& hlo_string = R"(
HloModule ParameterStreaming, entry_computation_layout={(s32[2,1]{1,0:T(2,128)}, s32[2,1]{1,0:T(2,128)})->(s32[2,1]{1,0:T(2,128)S(5)}, s32[2,1]{1,0:T(2,128)})}
Expand Down

0 comments on commit 9a575cb

Please sign in to comment.