Skip to content

Commit

Permalink
Add unbounded dynamism test for RecvOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621355707
  • Loading branch information
ghpvnist authored and tensorflower-gardener committed Apr 10, 2024
1 parent f810e02 commit fee0fd5
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 16 deletions.
87 changes: 87 additions & 0 deletions third_party/xla/xla/client/xla_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2352,6 +2352,28 @@ TEST(XlaBuilderTest, UnboundedOr) {
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedOutfeed) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape shape_with_layout,
ParseShape("f32[?, 10]"));
Outfeed(/*operand=*/Parameter(&b, 0, operand, "operand"),
/*shape_with_layout=*/shape_with_layout, /*outfeed_config=*/"");
EXPECT_IS_OK(BuildHloModule(b));
}

TEST(XlaBuilderTest, UnboundedOutfeedWithToken) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape shape_with_layout,
ParseShape("f32[?, 10]"));
OutfeedWithToken(/*operand=*/Parameter(&b, 0, operand, "operand"),
/*token=*/CreateToken(&b),
/*shape_with_layout=*/shape_with_layout,
/*outfeed_config=*/"");
EXPECT_IS_OK(BuildHloModule(b));
}

TEST(XlaBuilderTest, UnboundedPad) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
Expand All @@ -2370,6 +2392,36 @@ TEST(XlaBuilderTest, UnboundedPad) {
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedRecv) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]"));
ChannelHandle handle;
handle.set_handle(1);
handle.set_type(ChannelHandle::DEVICE_TO_DEVICE);
Recv(/*builder=*/&b, /*shape=*/shape, /*handle=*/handle);
EXPECT_IS_OK(BuildHloModule(b));
}

TEST(XlaBuilderTest, UnboundedRecvFromHost) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]"));
ChannelHandle handle;
handle.set_handle(1);
handle.set_type(ChannelHandle::HOST_TO_DEVICE);
RecvFromHost(/*token=*/CreateToken(&b), /*shape=*/shape, /*handle=*/handle);
EXPECT_IS_OK(BuildHloModule(b));
}

TEST(XlaBuilderTest, UnboundedRecvWithToken) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape shape, ParseShape("f32[?, 10]"));
ChannelHandle handle;
handle.set_handle(1);
handle.set_type(ChannelHandle::DEVICE_TO_DEVICE);
RecvWithToken(/*token=*/CreateToken(&b), /*shape=*/shape, /*handle=*/handle);
EXPECT_IS_OK(BuildHloModule(b));
}

TEST(XlaBuilderTest, UnboundedReduce) {
XlaBuilder b(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {7}, {false});
Expand Down Expand Up @@ -2673,6 +2725,41 @@ TEST(XlaBuilderTest, UnboundedSelectAndScatter) {
GmockMatch(m::Op().WithShapeEqualTo(&expected)));
}

TEST(XlaBuilderTest, UnboundedSend) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
ChannelHandle handle;
handle.set_handle(1);
handle.set_type(ChannelHandle::DEVICE_TO_DEVICE);
Send(/*operand=*/Parameter(&b, 0, operand, "operand"), /*handle=*/handle);
EXPECT_IS_OK(BuildHloModule(b));
}

TEST(XlaBuilderTest, UnboundedSendToHost) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
TF_ASSERT_OK_AND_ASSIGN(const Shape shape_with_layout,
ParseShape("f32[?, 10]"));
ChannelHandle handle;
handle.set_handle(1);
handle.set_type(ChannelHandle::DEVICE_TO_HOST);
SendToHost(/*operand=*/Parameter(&b, 0, operand, "operand"),
/*token=*/CreateToken(&b), /*shape_with_layout=*/shape_with_layout,
/*handle=*/handle);
EXPECT_IS_OK(BuildHloModule(b));
}

TEST(XlaBuilderTest, UnboundedSendWithToken) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[?, 10]"));
ChannelHandle handle;
handle.set_handle(1);
handle.set_type(ChannelHandle::DEVICE_TO_DEVICE);
SendWithToken(/*operand=*/Parameter(&b, 0, operand, "operand"),
/*token=*/CreateToken(&b), /*handle=*/handle);
EXPECT_IS_OK(BuildHloModule(b));
}

TEST(XlaBuilderTest, UnboundedSlice) {
XlaBuilder b(TestName());
TF_ASSERT_OK_AND_ASSIGN(const Shape operand, ParseShape("f32[1, <=3, ?]"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,8 @@ absl::StatusOr<std::optional<int64_t>> GetOverriddenPreferredPrefetchTime(
}

bool DoesResultMatchFilter(const HloPositionMatcher& filter,
const ShapeIndex& index,
HloInstruction* instruction) {
const BufferInterval& buffer_interval) {
HloInstruction* instruction = buffer_interval.buffer->instruction();
if (filter.has_instruction_regex() &&
!RE2::FullMatch(instruction->ToString(), filter.instruction_regex())) {
return false;
Expand All @@ -478,8 +478,15 @@ bool DoesResultMatchFilter(const HloPositionMatcher& filter,
return false;
}
if (filter.has_tuple_index() &&
index != ShapeIndex(filter.tuple_index().index().begin(),
filter.tuple_index().index().end())) {
buffer_interval.buffer->index() !=
ShapeIndex(filter.tuple_index().index().begin(),
filter.tuple_index().index().end())) {
return false;
}
if (filter.has_size_gte() && filter.size_gte() > buffer_interval.size) {
return false;
}
if (filter.has_size_lte() && filter.size_lte() < buffer_interval.size) {
return false;
}
return true;
Expand All @@ -496,8 +503,7 @@ int64_t GetBufferIntervalOverridePriority(
for (int64_t i = 0; i < msa_sort_order_overrides.overrides_size(); ++i) {
const auto& override = msa_sort_order_overrides.overrides(i);
if (!DoesResultMatchFilter(override.hlo_position_matcher(),
buffer_interval.buffer->index(),
buffer_interval.buffer->instruction())) {
buffer_interval)) {
continue;
}
LOG(INFO) << "Override Sort Order Config " << i << " matches "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ message HloPositionMatcher {
// If output of an instruction is a tuple and indexing into the
// tuple is required.
optional TupleShapeIndex tuple_index = 3;
// Filters instructions with output size in bytes greater or equal to a value.
optional int64 size_gte = 4;
// Filters instructions with output size in bytes less or equal to a value.
optional int64 size_lte = 5;
}

// Options to override preferred prefetch time for an operand.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4546,8 +4546,6 @@ TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {

TEST_P(MemorySpaceAssignmentTest,
MemoryBoundednessOverrideSortOrderAssignFirst) {
// Override MSA sort order and try to assign all negates to alternate memory
// first.
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

Expand All @@ -4571,6 +4569,8 @@ TEST_P(MemorySpaceAssignmentTest,
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));

// Override MSA sort order and try to assign all negates to alternate memory
// first. Alternate memory size is enough to fit 2 f32[4,3] tensors at a time.
const std::string text_proto = R"pb(
overrides {
hlo_position_matcher { instruction_name_regex: "negate(.*)" }
Expand All @@ -4589,7 +4589,8 @@ TEST_P(MemorySpaceAssignmentTest,
EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* p1 = FindInstruction(module.get(), "p1");
EXPECT_EQ(p1->shape().layout().memory_space(), kDefaultMemorySpace);
// All negates are in alternate memory space except negate4.
// Check that all negates are in alternate memory space except negate4.
// negate4 is a program output, so it has to land in default memory.
HloInstruction* negate0 = FindInstruction(module.get(), "negate0");
EXPECT_EQ(negate0->shape().layout().memory_space(), kAlternateMemorySpace);
HloInstruction* negate1 = FindInstruction(module.get(), "negate1");
Expand All @@ -4614,8 +4615,6 @@ TEST_P(MemorySpaceAssignmentTest,

TEST_P(MemorySpaceAssignmentTest,
MemoryBoundednessOverrideSortOrderAssignLast) {
// Override MSA sort order and try to assign all negates to alternate memory
// last.
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

Expand All @@ -4639,9 +4638,11 @@ TEST_P(MemorySpaceAssignmentTest,
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));

// Override MSA sort order and try to assign all tanhs to alternate memory
// last. Alternate memory size is enough to fit 2 f32[4,3] tensors at a time.
const std::string text_proto = R"pb(
overrides {
hlo_position_matcher { instruction_name_regex: "negate(.*)" }
hlo_position_matcher { instruction_name_regex: "tanh(.*)" }
override_options { assign_last: true }
}
)pb";
Expand All @@ -4658,17 +4659,94 @@ TEST_P(MemorySpaceAssignmentTest,
EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* p1 = FindInstruction(module.get(), "p1");
EXPECT_EQ(p1->shape().layout().memory_space(), kDefaultMemorySpace);
// All negates are in default memory space except negate3.
HloInstruction* negate0 = FindInstruction(module.get(), "negate0");
EXPECT_EQ(negate0->shape().layout().memory_space(), kAlternateMemorySpace);
HloInstruction* negate1 = FindInstruction(module.get(), "negate1");
EXPECT_EQ(negate1->shape().layout().memory_space(), kAlternateMemorySpace);
HloInstruction* negate2 = FindInstruction(module.get(), "negate2");
EXPECT_EQ(negate2->shape().layout().memory_space(), kAlternateMemorySpace);
HloInstruction* negate3 = FindInstruction(module.get(), "negate3");
EXPECT_EQ(negate3->shape().layout().memory_space(), kAlternateMemorySpace);
HloInstruction* negate4 = FindInstruction(module.get(), "negate4");
// negate4 is a program output, so it has to land in default memory.
EXPECT_EQ(negate4->shape().layout().memory_space(), kDefaultMemorySpace);
// Check that all tanhs are in default memory space.
const HloInstruction* tanh0 = FindInstruction(module.get(), "tanh0");
EXPECT_EQ(tanh0->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* tanh1 = FindInstruction(module.get(), "tanh1");
EXPECT_EQ(tanh1->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* tanh2 = FindInstruction(module.get(), "tanh2");
EXPECT_EQ(tanh2->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* tanh3 = FindInstruction(module.get(), "tanh3");
EXPECT_EQ(tanh3->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* tanh4 = FindInstruction(module.get(), "tanh4");
EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace);
}

TEST_P(MemorySpaceAssignmentTest,
MemoryBoundednessOverrideSortOrderBySizeLteAssignFirst) {
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

ENTRY entry {
p0 = f32[3,4]{1,0} parameter(0)
p1 = f32[5,4]{1,0} parameter(1)
tanh0 = f32[3,4]{1,0} tanh(p0)
negate0 = f32[5,4]{1,0} negate(p1)
tanh1 = f32[3,4]{1,0} tanh(tanh0)
negate1 = f32[5,4]{1,0} negate(negate0)
tanh2 = f32[3,4]{1,0} tanh(tanh1)
negate2 = f32[5,4]{1,0} negate(negate1)
tanh3 = f32[3,4]{1,0} tanh(tanh2)
negate3 = f32[5,4]{1,0} negate(negate2)
tanh4 = f32[3,4]{1,0} tanh(tanh3)
negate4 = f32[5,4]{1,0} negate(negate3)
ROOT tuple = (f32[3,4]{1,0}, f32[5,4]{1,0}) tuple(tanh4, negate4)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));

// Override MSA sort order and try to assign all buffers with size lesser
// than or equal to 48 bytes to alternate memory first.
const std::string text_proto = R"pb(
overrides {
hlo_position_matcher { size_lte: 48 }
override_options { assign_first: true }
}
)pb";
TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides,
ParseTextProto<MsaSortOrderOverrides>(text_proto));

Options memory_space_options = DefaultMemorySpaceOptions();
// Set max size to 120 bytes, such that 2 f32[4,3] tensors can fit in
// alternate memory at the same time but not 1 f32[4,3] tensor and 1
// f32[4,5] tensor. If the max size was 128 bytes, negate3 would be assigned
// to alternate memory.
memory_space_options.max_size_in_bytes = 120;
AssignMemorySpaceUsingCostAnalysis(
module.get(), memory_space_options,
/*cost_analysis_options_override=*/std::nullopt,
/*hlo_cost_options_override=*/std::nullopt,
/*optional_msa_sort_order_overrides=*/msa_sort_order_overrides);
// Parameters are in the default memory space.
const HloInstruction* p0 = FindInstruction(module.get(), "p0");
EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* p1 = FindInstruction(module.get(), "p1");
EXPECT_EQ(p1->shape().layout().memory_space(), kDefaultMemorySpace);
HloInstruction* negate0 = FindInstruction(module.get(), "negate0");
EXPECT_EQ(negate0->shape().layout().memory_space(), kDefaultMemorySpace);
HloInstruction* negate1 = FindInstruction(module.get(), "negate1");
EXPECT_EQ(negate1->shape().layout().memory_space(), kDefaultMemorySpace);
HloInstruction* negate2 = FindInstruction(module.get(), "negate2");
EXPECT_EQ(negate2->shape().layout().memory_space(), kDefaultMemorySpace);
HloInstruction* negate3 = FindInstruction(module.get(), "negate3");
EXPECT_EQ(negate3->shape().layout().memory_space(), kAlternateMemorySpace);
EXPECT_EQ(negate3->shape().layout().memory_space(), kDefaultMemorySpace);
HloInstruction* negate4 = FindInstruction(module.get(), "negate4");
EXPECT_EQ(negate4->shape().layout().memory_space(), kDefaultMemorySpace);
// Check that all tanhs are in alternate memory space except tanh4. tanh4
// is a program output, so it has to land in default memory.
const HloInstruction* tanh0 = FindInstruction(module.get(), "tanh0");
EXPECT_EQ(tanh0->shape().layout().memory_space(), kAlternateMemorySpace);
const HloInstruction* tanh1 = FindInstruction(module.get(), "tanh1");
Expand All @@ -4681,6 +4759,82 @@ TEST_P(MemorySpaceAssignmentTest,
EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace);
}

TEST_P(MemorySpaceAssignmentTest,
MemoryBoundednessOverrideSortOrderBySizeGteAssignFirst) {
absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true

ENTRY entry {
p0 = f32[3,4]{1,0} parameter(0)
p1 = f32[5,4]{1,0} parameter(1)
tanh0 = f32[3,4]{1,0} tanh(p0)
negate0 = f32[5,4]{1,0} negate(p1)
tanh1 = f32[3,4]{1,0} tanh(tanh0)
negate1 = f32[5,4]{1,0} negate(negate0)
tanh2 = f32[3,4]{1,0} tanh(tanh1)
negate2 = f32[5,4]{1,0} negate(negate1)
tanh3 = f32[3,4]{1,0} tanh(tanh2)
negate3 = f32[5,4]{1,0} negate(negate2)
tanh4 = f32[3,4]{1,0} tanh(tanh3)
negate4 = f32[5,4]{1,0} negate(negate3)
ROOT tuple = (f32[3,4]{1,0}, f32[5,4]{1,0}) tuple(tanh4, negate4)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));

// Override MSA sort order and try to assign all buffers with size greater
// than or equal to 80 bytes to alternate memory first.
const std::string text_proto = R"pb(
overrides {
hlo_position_matcher { size_gte: 80 }
override_options { assign_first: true }
}
)pb";
TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides,
ParseTextProto<MsaSortOrderOverrides>(text_proto));

Options memory_space_options = DefaultMemorySpaceOptions();
// Set max size to 160 bytes to allow 2 f32[4,5] tensors to fit in alternate
// memory at the same time. tanh3 would not be prefetched because negate2 and
// negate3 would be in alternate memory at the same time leaving no space for
// tanh3.
memory_space_options.max_size_in_bytes = 160;
AssignMemorySpaceUsingCostAnalysis(
module.get(), memory_space_options,
/*cost_analysis_options_override=*/std::nullopt,
/*hlo_cost_options_override=*/std::nullopt,
/*optional_msa_sort_order_overrides=*/msa_sort_order_overrides);
// Parameters are in the default memory space.
const HloInstruction* p0 = FindInstruction(module.get(), "p0");
EXPECT_EQ(p0->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* p1 = FindInstruction(module.get(), "p1");
EXPECT_EQ(p1->shape().layout().memory_space(), kDefaultMemorySpace);
// Check that all negates are in alternate memory space except negate4.
// negate4 is a program output, so it has to land in default memory.
HloInstruction* negate0 = FindInstruction(module.get(), "negate0");
EXPECT_EQ(negate0->shape().layout().memory_space(), kAlternateMemorySpace);
HloInstruction* negate1 = FindInstruction(module.get(), "negate1");
EXPECT_EQ(negate1->shape().layout().memory_space(), kAlternateMemorySpace);
HloInstruction* negate2 = FindInstruction(module.get(), "negate2");
EXPECT_EQ(negate2->shape().layout().memory_space(), kAlternateMemorySpace);
HloInstruction* negate3 = FindInstruction(module.get(), "negate3");
EXPECT_EQ(negate3->shape().layout().memory_space(), kAlternateMemorySpace);
HloInstruction* negate4 = FindInstruction(module.get(), "negate4");
EXPECT_EQ(negate4->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* tanh0 = FindInstruction(module.get(), "tanh0");
EXPECT_EQ(tanh0->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* tanh1 = FindInstruction(module.get(), "tanh1");
EXPECT_EQ(tanh1->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* tanh2 = FindInstruction(module.get(), "tanh2");
EXPECT_EQ(tanh2->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* tanh3 = FindInstruction(module.get(), "tanh3");
EXPECT_EQ(tanh3->shape().layout().memory_space(), kDefaultMemorySpace);
const HloInstruction* tanh4 = FindInstruction(module.get(), "tanh4");
EXPECT_EQ(tanh4->shape().layout().memory_space(), kDefaultMemorySpace);
}

TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
Shape f32v1 = ShapeUtil::MakeShape(F32, {1});
Expand Down

0 comments on commit fee0fd5

Please sign in to comment.