Skip to content

Commit

Permalink
1. More comprehensive after-all handling. Specifically, the pass now …
Browse files Browse the repository at this point in the history
…handles after-alls with a non-zero number of operands.

2. Add a case for merging two TupleOrToken labels. This case can arise when merging labels for multiple conditional branch computations, as they often return tuples.

Added tests for both these fixes.

PiperOrigin-RevId: 621254049
  • Loading branch information
tensorflower-gardener committed Apr 2, 2024
1 parent f1b19d0 commit e31a08d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
14 changes: 14 additions & 0 deletions third_party/xla/xla/service/hlo_value_semantics_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,13 @@ absl::Status EinsumDepthAnalysis::HandleCalledComputation(
}

// Visits an after-all instruction and pushes its depth down to its operands.
//
// The depth tree recorded for `after_all` is looked up via GetDepthTreeOrDie
// (presumably fatal if no entry exists, per its name -- confirm against its
// definition), the value reported by GetMaxDepth() for that tree is computed
// once, and that value is assigned to every operand through
// SetInstructionDepth(). Each operand is CHECKed to have a token shape.
//
// Returns the first non-OK status produced by SetInstructionDepth(), or OK
// when every operand (possibly none) has been processed.
absl::Status EinsumDepthAnalysis::HandleAfterAll(HloInstruction* after_all) {
  auto tree_entry = GetDepthTreeOrDie(after_all);
  const ShapeTree<int>& after_all_depths = tree_entry->second;
  // Computed once up front; the same value is applied to every operand.
  const int propagated_depth = GetMaxDepth(after_all_depths);
  for (HloInstruction* token_operand : after_all->mutable_operands()) {
    CHECK(token_operand->shape().IsToken());
    TF_RETURN_IF_ERROR(SetInstructionDepth(token_operand, propagated_depth));
  }
  return OkStatus();
}

Expand Down Expand Up @@ -1488,6 +1495,13 @@ HloValueSemanticsPropagation::MergeSemanticsForAnInstruction(
replace_operands_semantics_with(semantics);
continue;
}
if (operand_list[0].label() == HloValueSemanticLabel::kTupleOrToken &&
operand_list[1].label() == HloValueSemanticLabel::kTupleOrToken) {
HloValueSemantics semantics =
CopySemanticsWithNewOrigin(operand_list[0], instruction);
replace_operands_semantics_with(semantics);
continue;
}
LOG(FATAL) << "We don't expect to handle operands of label "
<< HloValueSemanticLabelToString(operand_list[0].label())
<< " and "
Expand Down
64 changes: 64 additions & 0 deletions third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ class HloValueSemanticsAnalysisTest : public HloTestBase {
return HasLabel(hlo_value_semantics_analysis, module, instruction_name,
HloValueSemanticLabel::kWeightGradient);
}
bool IsTupleOrToken(
const HloValueSemanticsAnalysis& hlo_value_semantics_analysis,
HloModule* module, absl::string_view instruction_name) {
return HasLabel(hlo_value_semantics_analysis, module, instruction_name,
HloValueSemanticLabel::kTupleOrToken);
}
};

TEST_F(HloValueSemanticsAnalysisTest, OneMatmul) {
Expand Down Expand Up @@ -244,6 +250,41 @@ ENTRY entry {
EXPECT_TRUE(IsWeight(*hlo_value_semantics_analysis, module.get(), "dot.2"));
}

// Runs the value-semantics analysis on a module whose entry computation is a
// conditional with two branch computations that both return a two-element
// tuple (one of them routing a value through an async custom call), and
// expects the conditional itself to be labeled kTupleOrToken.
TEST_F(HloValueSemanticsAnalysisTest, HandleConditional) {
  const std::string hlo_text = R"(
HloModule Module
branch0 {
tparam = f32[4] parameter(0)
tgte1 = f32[4] ceil(tparam)
ROOT tuple = (f32[4], f32[4]) tuple(tparam, tgte1)
}
branch1 {
fparam = f32[4] parameter(0)
%async-start = ((f32[4]), f32[4], s32[]) custom-call-start(f32[4] fparam), async_execution_thread="parallel_thread", custom_call_target="foo"
%async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start)
ROOT tuple = (f32[4], f32[4]) tuple(fparam, %async-done)
}
ENTRY entry {
p0 = f32[4] parameter(0)
b0 = s32[] parameter(1)
ROOT conditional = (f32[4], f32[4]) conditional(b0, p0, p0),
branch_computations={branch0, branch1}
}
)";

  // Parse and verify the module with 1 replica and 2 partitions.
  TF_ASSERT_OK_AND_ASSIGN(
      auto verified_module,
      ParseAndReturnVerifiedModule(hlo_text, /*replica_count=*/1,
                                   /*num_partitions=*/2));

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloValueSemanticsAnalysis> analysis,
      HloValueSemanticsAnalysis::Run(*verified_module));

  // The merged label of the two tuple-returning branches lands on the
  // conditional instruction.
  EXPECT_TRUE(
      IsTupleOrToken(*analysis, verified_module.get(), "conditional"));
}

TEST_F(HloValueSemanticsAnalysisTest, TwoMatmuls) {
const std::string module_str = R"(
HloModule TwoMatmuls
Expand Down Expand Up @@ -622,6 +663,29 @@ TEST_F(EinsumDepthAnalysisTest, HandleConditional) {
0);
}

// Runs the einsum depth analysis on an entry computation whose root is an
// after-all that takes a send-done token as its operand, and expects the
// depth recorded for that root after-all to be 0.
TEST_F(EinsumDepthAnalysisTest, HandleAfterAll) {
  const char* const hlo_text = R"(
ENTRY entry {
after-all.1 = token[] after-all()
parameter.1 = f32[] parameter(0)
send.1 = (f32[], u32[], token[]) send(parameter.1, after-all.1), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="rendezvous1"}
send-done.1 = token[] send-done(send.1), channel_id=1, is_host_transfer=true, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="rendezvous1"}
ROOT after-all.2 = token[] after-all(send-done.1), frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous",_xla_host_transfer_rendezvous="rendezvous1"}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Analyze only the entry computation.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<EinsumDepthAnalysis> depth_analysis,
      EinsumDepthAnalysis::Run(*parsed_module->entry_computation(),
                               SendRecvGroupMap(*parsed_module)));

  const EinsumDepthMap& depths = depth_analysis->GetEinsumDepthMap();
  HloComputation* entry = parsed_module->GetComputationWithName("entry");
  const int root_after_all_depth =
      GetInstructionDepth(depths, entry, "after-all.2");
  EXPECT_EQ(root_after_all_depth, 0);
}

class EinsumHeightAnalysisTest : public HloTestBase {
public:
int GetInstructionHeight(const EinsumHeightMap& height_map,
Expand Down

0 comments on commit e31a08d

Please sign in to comment.