diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 64d2da499db04b..2985835a8e1ba6 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3970,3 +3970,34 @@ tf_cc_test(
         "//tensorflow/core:test",
     ],
 )
+
+cc_library(
+    name = "conditional_to_select",
+    srcs = ["conditional_to_select.cc"],
+    hdrs = ["conditional_to_select.h"],
+    deps = [
+        ":call_graph",
+        ":call_inliner",
+        ":hlo",
+        ":hlo_creation_utils",
+        ":hlo_pass",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+tf_cc_test(
+    name = "conditional_to_select_test",
+    srcs = ["conditional_to_select_test.cc"],
+    deps = [
+        ":conditional_to_select",
+        ":hlo",
+        ":hlo_matchers",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
+        "@com_google_absl//absl/memory",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/conditional_to_select.cc b/tensorflow/compiler/xla/service/conditional_to_select.cc
new file mode 100644
index 00000000000000..d9b246bd6284bb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/conditional_to_select.cc
@@ -0,0 +1,89 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/conditional_to_select.h"
+
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/call_inliner.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+// Replaces `conditional` with a select over its two branches, both inlined.
+static StatusOr<bool> DoConditionalToSelect(HloInstruction* conditional) {
+  // Both branches run unconditionally after the rewrite, so bail out unless
+  // neither called computation has side effects.
+  if (conditional->true_computation()->HasSideEffect() ||
+      conditional->false_computation()->HasSideEffect()) {
+    VLOG(1) << "Not transforming conditional; branches have side effects:"
+            << conditional->ToString();
+    return false;
+  }
+
+  HloComputation* computation = conditional->parent();
+  // Turn each branch into a call on the operand the conditional passed to it.
+  HloInstruction* if_call_op =
+      computation->AddInstruction(HloInstruction::CreateCall(
+          conditional->shape(), {conditional->mutable_operand(1)},
+          conditional->true_computation()));
+  conditional->SetupDerivedInstruction(if_call_op);
+  HloInstruction* else_call_op =
+      computation->AddInstruction(HloInstruction::CreateCall(
+          conditional->shape(), {conditional->mutable_operand(2)},
+          conditional->false_computation()));
+  conditional->SetupDerivedInstruction(else_call_op);
+  HloInstruction* condition = conditional->mutable_operand(0);
+  TF_ASSIGN_OR_RETURN(
+      HloInstruction * select_op,
+      MakeSelectHlo(condition, if_call_op, else_call_op, conditional));
+  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, select_op));
+  TF_RETURN_IF_ERROR(CallInliner::Inline(if_call_op).status());
+  TF_RETURN_IF_ERROR(CallInliner::Inline(else_call_op).status());
+  return true;
+}
+
+StatusOr<bool> ConditionalToSelect::Run(HloModule* module) {
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  bool did_mutate = false;
+  VLOG(1) << "Running conditional-to-select pass";
+  TF_RETURN_IF_ERROR(
+      call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
+        // Only computations called in a parallel context are rewritten.
+        if (node.context() != CallContext::kParallel) {
+          return Status::OK();
+        }
+        for (const CallSite& callsite : node.callsites()) {
+          if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
+            VLOG(1) << "Visiting conditional: " << callsite.ToString();
+            HloInstruction* conditional = callsite.instruction();
+            TF_ASSIGN_OR_RETURN(bool result,
+                                DoConditionalToSelect(conditional));
+            did_mutate |= result;
+          }
+        }
+        return Status::OK();
+      }));
+  return did_mutate;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/conditional_to_select.h b/tensorflow/compiler/xla/service/conditional_to_select.h
new file mode 100644
index 00000000000000..3b99e8192e8823
--- /dev/null
+++ b/tensorflow/compiler/xla/service/conditional_to_select.h
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_TO_SELECT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_TO_SELECT_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// A pass which transforms conditionals to selects in places where conditionals
+// are legal, but not currently supported by the backends (e.g. inside kMap).
+class ConditionalToSelect : public HloModulePass {
+ public:
+  ~ConditionalToSelect() override = default;
+  absl::string_view name() const override { return "conditional-to-select"; }
+
+  // Runs conditional-to-select on the given module. Returns whether the
+  // module was changed.
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_TO_SELECT_H_
diff --git a/tensorflow/compiler/xla/service/conditional_to_select_test.cc b/tensorflow/compiler/xla/service/conditional_to_select_test.cc
new file mode 100644
index 00000000000000..c0c90e07453556
--- /dev/null
+++ b/tensorflow/compiler/xla/service/conditional_to_select_test.cc
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/conditional_to_select.h"
+
+#include <memory>
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace op = xla::testing::opcode_matchers;
+
+namespace xla {
+namespace {
+
+using ConditionalToSelectTest = HloTestBase;
+using ::testing::_;
+
+// Test that a conditional of simple constants is transformed to a select.
+TEST_F(ConditionalToSelectTest, MapConditionalConstants) {
+  const string hlo_text = R"(
+HloModule MapConditionalConstants
+
+if {
+  %pif = () parameter(0)
+  ROOT %cif = f32[] constant(0)
+}
+
+else {
+  %pelse = () parameter(0)
+  ROOT %celse = f32[] constant(1)
+}
+
+mapped {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  %lt = pred[] less-than(%a, %b)
+  %t = () tuple()
+  ROOT %conditional = f32[] conditional(%lt, %t, %t), true_computation=if, false_computation=else
+}
+
+ENTRY comp {
+  %p1 = f32[1000]{0} parameter(0)
+  %p2 = f32[1000]{0} parameter(1)
+  ROOT %mapped = f32[1000]{0} map(%p1, %p2), dimensions={0}, to_apply=mapped
+}
+)";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
+  ConditionalToSelect pass;
+  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_EQ(root->opcode(), HloOpcode::kMap);
+  HloComputation* mapped = root->called_computations()[0];
+  EXPECT_THAT(mapped->root_instruction(),
+              op::Select(op::Lt(op::Parameter(0), op::Parameter(1)),
+                         op::Constant(), op::Constant()));
+}
+
+// Test that the condition gets broadcast before it is fed into the
+// select when the output is non-scalar.
+TEST_F(ConditionalToSelectTest, MapConditionalNonScalar) {
+  const string hlo_text = R"(
+HloModule MapConditionalNonScalar
+
+if {
+  %pif = () parameter(0)
+  %zero = f32[] constant(0)
+  ROOT %zero_broadcasted = f32[2,2]{1,0} broadcast(%zero), dimensions={}
+}
+
+else {
+  %pelse = () parameter(0)
+  %one = f32[] constant(1)
+  ROOT %one_broadcasted = f32[2,2]{1,0} broadcast(%one), dimensions={}
+}
+
+add {
+  %add_lhs = f32[] parameter(0)
+  %add_rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%add_lhs, %add_rhs)
+}
+
+mapped {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  %lt = pred[] less-than(%a, %b)
+  %t = () tuple()
+  %conditional = f32[2,2]{1,0} conditional(%lt, %t, %t), true_computation=if, false_computation=else
+  %zero = f32[] constant(0)
+  ROOT %reduced = f32[] reduce(%conditional, %zero), dimensions={0,1}, to_apply=add
+}
+
+ENTRY comp {
+  %p1 = f32[1000]{0} parameter(0)
+  %p2 = f32[1000]{0} parameter(1)
+  ROOT %mapped = f32[1000]{0} map(%p1, %p2), dimensions={0}, to_apply=mapped
+}
+)";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
+  ConditionalToSelect pass;
+  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_EQ(root->opcode(), HloOpcode::kMap);
+  HloComputation* mapped = root->called_computations()[0];
+  EXPECT_THAT(
+      mapped->root_instruction(),
+      op::Reduce(
+          op::Select(op::Broadcast(op::Lt(op::Parameter(0), op::Parameter(1))),
+                     _, _),
+          _));
+}
+
+// Test that conditionals of tuple type get turned into kTupleSelect.
+TEST_F(ConditionalToSelectTest, MapConditionalTuples) {
+  const string hlo_text = R"(
+HloModule MapConditionalTuples
+
+if {
+  %pif = () parameter(0)
+  %zero = f32[] constant(0)
+  ROOT %tup = (f32[],f32[]) tuple(%zero, %zero)
+}
+
+else {
+  %pelse = () parameter(0)
+  %one = f32[] constant(1)
+  ROOT %tup = (f32[],f32[]) tuple(%one, %one)
+}
+
+add {
+  %add_lhs = f32[] parameter(0)
+  %add_rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(%add_lhs, %add_rhs)
+}
+
+mapped {
+  %a = f32[] parameter(0)
+  %b = f32[] parameter(1)
+  %lt = pred[] less-than(%a, %b)
+  %t = () tuple()
+  %conditional = (f32[], f32[]) conditional(%lt, %t, %t), true_computation=if, false_computation=else
+  %el1 = f32[] get-tuple-element(%conditional), index=0
+  %el2 = f32[] get-tuple-element(%conditional), index=1
+  ROOT %reduced = f32[] add(%el1, %el2)
+}
+
+ENTRY comp {
+  %p1 = f32[1000]{0} parameter(0)
+  %p2 = f32[1000]{0} parameter(1)
+  ROOT %mapped = f32[1000]{0} map(%p1, %p2), dimensions={0}, to_apply=mapped
+}
+)";
+
+  auto module = ParseAndReturnVerifiedModule(hlo_text).ValueOrDie();
+  ConditionalToSelect pass;
+  ASSERT_TRUE(pass.Run(&*module).ValueOrDie());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ASSERT_EQ(root->opcode(), HloOpcode::kMap);
+  HloComputation* mapped = root->called_computations()[0];
+  EXPECT_THAT(mapped->root_instruction(),
+              op::Add(op::GetTupleElement(op::TupleSelect(_, _, _)), _));
+}
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 66ceb57227cf20..e8a6ef3fb4acb5 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -98,6 +98,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:dump",
         "//tensorflow/compiler/xla/service:map_inliner",
         "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
+        "//tensorflow/compiler/xla/service:conditional_to_select",
         "//tensorflow/compiler/xla/service:scatter_expander",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:protobuf_util",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 7de159cf647190..acb76038053d87 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++
b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -52,6 +52,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/call_inliner.h"
 #include "tensorflow/compiler/xla/service/cholesky_expander.h"
 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/conditional_to_select.h"
 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
 #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
@@ -257,6 +258,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
       &pipeline, module->config().debug_options(),
       ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
 
+  pipeline.AddPass<ConditionalToSelect>();
   pipeline.AddPass<MapInliner>();
 
   pipeline.AddPass<CholeskyExpander>();
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index d9c5f7c66de03a..e7568ea76ddfe6 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -274,15 +274,36 @@ StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
 
 StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                         HloInstruction* on_true,
-                                        HloInstruction* on_false) {
+                                        HloInstruction* on_false,
+                                        HloInstruction* derived_from) {
   HloComputation* computation = pred->parent();
   DCHECK_EQ(computation, on_true->parent());
   DCHECK_EQ(computation, on_false->parent());
+  const Shape& op_shape = on_true->shape();
+  const bool is_tuple = ShapeUtil::IsTuple(op_shape);
+  if (!is_tuple && ShapeUtil::IsScalar(pred->shape()) &&
+      !ShapeUtil::IsScalar(op_shape)) {
+    // If the output is not scalar, the condition has to be broadcast to
+    // match the contract of kSelect. Tuples use kTupleSelect instead, which
+    // expects the condition to be a scalar.
+    pred = computation->AddInstruction(HloInstruction::CreateBroadcast(
+        ShapeUtil::ChangeElementType(op_shape, PrimitiveType::PRED), pred, {}));
+    if (derived_from != nullptr) {
+      derived_from->SetupDerivedInstruction(pred);
+    }
+  }
+  const HloOpcode select_op_code =
+      is_tuple ? HloOpcode::kTupleSelect : HloOpcode::kSelect;
   TF_ASSIGN_OR_RETURN(Shape select_shape,
-                      ShapeInference::InferTernaryOpShape(
-                          HloOpcode::kSelect, pred, on_true, on_false));
-  return computation->AddInstruction(HloInstruction::CreateTernary(
-      select_shape, HloOpcode::kSelect, pred, on_true, on_false));
+                      ShapeInference::InferTernaryOpShape(select_op_code, pred,
+                                                          on_true, on_false));
+  HloInstruction* select =
+      computation->AddInstruction(HloInstruction::CreateTernary(
+          select_shape, select_op_code, pred, on_true, on_false));
+  if (derived_from != nullptr) {
+    derived_from->SetupDerivedInstruction(select);
+  }
+  return select;
 }
 
 StatusOr<HloInstruction*> MakeSortHlo(
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index f163112f7ff54b..61df5fb328fe20 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -124,10 +124,14 @@ StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
 
 // Creates a Select HLO instruction and adds it to the computation containing
 // the predicate. The on_true and on_false instructions must also be contained
-// in the same computation.
+// in the same computation. If on_true and on_false are tuples, creates a tuple
+// select instead. A scalar `pred` is broadcast if the operands are non-scalar
+// arrays. If `derived_from` is non-null, the created instructions are set up
+// as derived from it.
 StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                         HloInstruction* on_true,
-                                        HloInstruction* on_false);
+                                        HloInstruction* on_false,
+                                        HloInstruction* derived_from = nullptr);
 
 // Creates a Sort HLO instruction and adds it to the computation containing the
 // operands. All operands must be in the same computation. Also creates a