diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index b81cc14459c609..b1d20bd7867ccf 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -96,9 +96,10 @@ namespace { using primitive_util::NativeTypeOf; template -StatusOr Compare(const Shape& shape, Comparison comparison, - LiteralSlice lhs_literal, LiteralSlice rhs_literal) { - auto populate = [&](auto compare_op) -> StatusOr { +absl::StatusOr Compare(const Shape& shape, Comparison comparison, + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { + auto populate = [&](auto compare_op) -> absl::StatusOr { Literal result(shape); TF_RETURN_IF_ERROR(result.PopulateParallel( [&](absl::Span multi_index, int /*thread_id*/) { @@ -147,7 +148,7 @@ StatusOr Compare(const Shape& shape, Comparison comparison, std::optional GetInstructionStaticValueAsBool( const HloInstruction* instruction) { HloEvaluator evaluator; - StatusOr static_value = evaluator.Evaluate( + absl::StatusOr static_value = evaluator.Evaluate( instruction, /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { return static_value->GetFirstElement(); @@ -251,7 +252,7 @@ struct DynamicOrStaticInteger { std::optional GetInstructionValueAsInteger( const HloInstruction* instruction) { HloEvaluator evaluator; - StatusOr static_value = evaluator.Evaluate( + absl::StatusOr static_value = evaluator.Evaluate( instruction, /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { if (instruction->shape().element_type() == PrimitiveType::PRED) { @@ -859,7 +860,7 @@ HloEvaluator::HloEvaluator(int64_t max_loop_iterations) }); } -StatusOr HloEvaluator::Evaluate( +absl::StatusOr HloEvaluator::Evaluate( const HloComputation& computation, absl::Span arg_literals) { CHECK(computation.parent() != nullptr); @@ -920,7 +921,7 @@ StatusOr HloEvaluator::Evaluate( return result.Clone(); } -StatusOr HloEvaluator::Evaluate( +absl::StatusOr HloEvaluator::Evaluate( const HloInstruction* instruction, bool recursively_evaluate_nonconstant_operands) { arg_literals_.clear(); @@ -955,7 +956,7 @@ bool HloEvaluator::TryEvaluate(const HloInstruction* instruction, return true; } -StatusOr HloEvaluator::EvaluateWithSubstitutions( +absl::StatusOr HloEvaluator::EvaluateWithSubstitutions( const HloInstruction* instruction, const absl::flat_hash_map& substitutions) { @@ -983,7 +984,7 @@ StatusOr HloEvaluator::EvaluateWithSubstitutions( return result; } -StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = HloInstruction::CreateConstant(lhs.Clone()); @@ -998,7 +999,7 @@ StatusOr HloEvaluator::EvaluateElementwiseBinaryOp( return result; } -StatusOr HloEvaluator::EvaluateElementwiseTernaryOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseTernaryOp( HloOpcode opcode, const Literal& lhs, const Literal& rhs, const Literal& ehs) { std::unique_ptr lhs_instr = @@ -1016,7 +1017,7 @@ StatusOr HloEvaluator::EvaluateElementwiseTernaryOp( return Evaluate(cloned_instruction.get()); } -StatusOr HloEvaluator::EvaluateElementwiseCompareOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseCompareOp( ComparisonDirection direction, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = HloInstruction::CreateConstant(lhs.Clone()); @@ -1032,7 +1033,7 @@ StatusOr HloEvaluator::EvaluateElementwiseCompareOp( return result; } -StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( +absl::StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( HloOpcode opcode, const Literal& operand) { std::unique_ptr operand_instr = HloInstruction::CreateConstant(operand.Clone()); @@ -1046,7 +1047,7 @@ StatusOr HloEvaluator::EvaluateElementwiseUnaryOp( return result; } -StatusOr HloEvaluator::EvaluateDotOp( +absl::StatusOr HloEvaluator::EvaluateDotOp( const DotDimensionNumbers& dim_numbers, const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { @@ -1189,7 +1190,7 @@ Status HloEvaluator::EvaluateInternal( } if (!tuple_points_to_analysis_cache_) { HloModule* module = instruction->GetModule(); - StatusOr> + absl::StatusOr> tuple_points_to_analysis = TuplePointsToAnalysis::Run(module); if (tuple_points_to_analysis.ok()) { tuple_points_to_analysis_cache_ = @@ -2347,7 +2348,7 @@ class OutputBatchIndexToInputIndex { // same storage for all invocations. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span output_index) { PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index); TF_RETURN_IF_ERROR(FetchIndexVector()); @@ -2467,7 +2468,7 @@ class OutputOffsetIndexToInputIndex { // result (input_index_), mutating it in place. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span output_index) { PropagateOutputIndexWindowDimsToInputIndex(output_index); return absl::Span(input_index_); @@ -2507,9 +2508,9 @@ class OutputOffsetIndexToInputIndex { // Reshapes the gather indices input to have a trailing degenerate `1` dimension // if necessary. Hands over the ownership of the newly created literal (if // there is one) to `reshaped_start_indices`. -static StatusOr> ReshapedGatherIndices( - int64_t index_vector_dim, const Literal& start_indices, - Literal* reshaped_start_indices) { +static absl::StatusOr> +ReshapedGatherIndices(int64_t index_vector_dim, const Literal& start_indices, + Literal* reshaped_start_indices) { if (start_indices.shape().dimensions_size() != index_vector_dim) { return std::cref(start_indices); } @@ -2574,7 +2575,8 @@ Status HloEvaluator::HandleGather(const HloInstruction* gather) { auto gather_inner_loop_body = [&](absl::Span output_window_index, absl::Span input_gather_index, - absl::Span output_gather_index) -> StatusOr { + absl::Span output_gather_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN( absl::Span input_window_index, output_offset_index_to_input_index(output_window_index)); @@ -2608,7 +2610,8 @@ Status HloEvaluator::HandleGather(const HloInstruction* gather) { }; auto gather_outer_loop_body = - [&](absl::Span output_gather_index) -> StatusOr { + [&](absl::Span output_gather_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN(absl::Span input_gather_index, output_batch_index_to_input_index(output_gather_index)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( @@ -2628,7 +2631,7 @@ namespace { // Reshapes the scatter indices input to have a trailing degenerate `1` // dimension if necessary. Hands over the ownership of the newly created // literal (if there is one) to `reshaped_indices`. -StatusOr> ReshapedScatterIndices( +absl::StatusOr> ReshapedScatterIndices( int64_t index_vector_dim, const Literal& indices, Literal* reshaped_indices) { if (indices.shape().dimensions_size() != index_vector_dim) { @@ -2750,7 +2753,7 @@ class UpdateScatterIndexToInputIndex { // same storage for all invocations. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span update_index) { PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index); TF_RETURN_IF_ERROR(FetchIndexVector()); @@ -2873,7 +2876,7 @@ class UpdateWindowIndexToInputIndex { // result (input_index_), mutating it in place. // // This returns a Span into memory owned by the class. - StatusOr> operator()( + absl::StatusOr> operator()( absl::Span update_index) { PropagateUpdateIndexWindowDimsToInputIndex(update_index); return absl::Span(input_index_); @@ -2966,7 +2969,8 @@ Status HloEvaluator::HandleScatter(const HloInstruction* hlo) { auto scatter_inner_loop_body = [&](absl::Span update_window_index, absl::Span input_scatter_index, - absl::Span update_scatter_index) -> StatusOr { + absl::Span update_scatter_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN( absl::Span input_window_index, update_window_index_to_input_index(update_window_index)); @@ -3018,7 +3022,8 @@ Status HloEvaluator::HandleScatter(const HloInstruction* hlo) { }; auto scatter_outer_loop_body = - [&](absl::Span update_scatter_index) -> StatusOr { + [&](absl::Span update_scatter_index) + -> absl::StatusOr { TF_ASSIGN_OR_RETURN( absl::Span input_scatter_index, update_scatter_index_to_input_index(update_scatter_index)); @@ -3416,10 +3421,10 @@ Status HloEvaluator::HandleSelect(const HloInstruction* select) { namespace { -StatusOr CreateScalarLiteral(int64_t value, - PrimitiveType element_type) { +absl::StatusOr CreateScalarLiteral(int64_t value, + PrimitiveType element_type) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsIntegralType(primitive_type_constant)) { return LiteralUtil::CreateR0( static_cast>(value)); @@ -3432,7 +3437,7 @@ StatusOr CreateScalarLiteral(int64_t value, // Parses the while loop if it matches one of the known patterns. Returns the // value of the loop induction variable after the loop execution if the loop is // static. -StatusOr TryParseAndEvaluateWhileInductionVar( +absl::StatusOr TryParseAndEvaluateWhileInductionVar( const HloInstruction* while_hlo) { std::optional parsed_while_loop = PatternMatchParseWhileLoop(while_hlo); @@ -3507,7 +3512,7 @@ Status HloEvaluator::HandleWhile(const HloInstruction* while_hlo) { dynamic_dimension_inference_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { - StatusOr result = + absl::StatusOr result = TryParseAndEvaluateWhileInductionVar(while_hlo); if (result.ok()) { lcv = std::move(result).value(); @@ -3546,11 +3551,11 @@ Literal ExtractLiteralFromIndexPositions(const Literal& from, return LiteralUtil::CreateR1(values); } -StatusOr ExtractFromIndexPositions(const Literal& from, - absl::Span indices) { +absl::StatusOr ExtractFromIndexPositions( + const Literal& from, absl::Span indices) { PrimitiveType type = from.shape().element_type(); return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { return ExtractLiteralFromIndexPositions< NativeTypeOf>(from, indices); @@ -3609,9 +3614,9 @@ void IterateThroughWindow( } template -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { std::function stochastic_convert_op = [](Fp operand, Uint random) -> ResultT { bool is_negative = static_cast(Eigen::numext::signbit(operand)); @@ -3673,9 +3678,9 @@ StatusOr StochasticConvertOp(const Literal& operand_literal, // Converts from primitive types to native types. template -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { return StochasticConvertOp< typename primitive_util::PrimitiveTypeToNative::type, typename primitive_util::PrimitiveTypeToNative::type, @@ -3685,11 +3690,11 @@ StatusOr StochasticConvertOp(const Literal& operand_literal, // Evaluates all possible paths of converting to different integers. template -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsSignedIntegralType( primitive_type_constant)) { return StochasticConvertOp StochasticConvertOp(const Literal& operand_literal, result_shape.element_type()); } -StatusOr StochasticConvertOp(const Literal& operand_literal, - const Literal& random_literal, - const Shape& result_shape) { +absl::StatusOr StochasticConvertOp(const Literal& operand_literal, + const Literal& random_literal, + const Shape& result_shape) { return primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> StatusOr { + [&](auto primitive_type_constant) -> absl::StatusOr { if constexpr (primitive_util::IsFloatingPointType( primitive_type_constant)) { return StochasticConvertOp< @@ -3925,9 +3930,9 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) { << " accessing increment of size " << increment.size(); increment[sort_dim] = sort_dim_elements; - auto comparator = [sort](absl::Span literals_to_sort, - int64_t a, int64_t b, - HloEvaluator* embedded_evaluator) -> StatusOr { + auto comparator = + [sort](absl::Span literals_to_sort, int64_t a, int64_t b, + HloEvaluator* embedded_evaluator) -> absl::StatusOr { absl::InlinedVector literals; literals.reserve(2 * sort->operand_count()); for (int64_t i = 0; i < sort->operand_count(); ++i) { @@ -3948,10 +3953,10 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) { embedded_evaluator->ResetVisitStates(); return computed_result.Get({}); }; - auto less_than = [&comparator]( - absl::Span literals_to_sort, int64_t a, - int64_t b, - HloEvaluator* embedded_evaluator) -> StatusOr { + auto less_than = + [&comparator](absl::Span literals_to_sort, int64_t a, + int64_t b, + HloEvaluator* embedded_evaluator) -> absl::StatusOr { TF_ASSIGN_OR_RETURN(bool a_is_smaller, comparator(literals_to_sort, a, b, embedded_evaluator)); #ifndef NDEBUG @@ -4101,7 +4106,7 @@ Status HloEvaluator::HandleSort(const HloInstruction* sort) { // Iterate through each dimension except 'sort_dim'. TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( key_shape, zero_base, key_shape.dimensions(), increment, - [&](absl::Span indices) -> StatusOr { + [&](absl::Span indices) -> absl::StatusOr { // Extract a slice from each operand literal that corresponds to // exactly the row in dimension 'sort_dim'. std::vector limit_indices(indices.begin(), indices.end()); @@ -4186,7 +4191,7 @@ static bool IsScalarAdd(HloComputation* computation) { // the user-provided computation on the accumulator and the output element // (until the reduction is completed, the output element is also used as // an accumulator). -static StatusOr PerformReductionStep( +static absl::StatusOr PerformReductionStep( bool is_tuple, absl::Span input_index, absl::Span output_index, absl::Span input_args, absl::Span results, @@ -4236,7 +4241,7 @@ static StatusOr PerformReductionStep( return true; } -static StatusOr GenerateReduceOutputElement( +static absl::StatusOr GenerateReduceOutputElement( bool is_tuple, absl::Span output_index, absl::Span init_values, diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h index 96004d73eea68d..ed6accfacf96e9 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h @@ -107,13 +107,13 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // // (Dummy template arg is to reduce the overloading priority of one overload // so that Evaluate(module, {}) resolves unambiguously.) - StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals) { + absl::StatusOr Evaluate( + const HloModule& module, absl::Span arg_literals) { return Evaluate(*module.entry_computation(), arg_literals); } template - StatusOr Evaluate(const HloModule& module, - absl::Span arg_literals) { + absl::StatusOr Evaluate(const HloModule& module, + absl::Span arg_literals) { return Evaluate(*module.entry_computation(), arg_literals); } @@ -136,11 +136,12 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // // (Dummy template arg is to reduce the overloading priority of one overload // so that Evaluate(module, {}) resolves unambiguously.) - StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals); + absl::StatusOr Evaluate( + const HloComputation& computation, + absl::Span arg_literals); template - StatusOr Evaluate(const HloComputation& computation, - absl::Span arg_literals) { + absl::StatusOr Evaluate(const HloComputation& computation, + absl::Span arg_literals) { std::vector arg_literal_ptrs; for (const auto& l : arg_literals) { arg_literal_ptrs.push_back(&l); @@ -154,7 +155,7 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // within its parent computation until it encounters something that cannot be // evaluated, such as an Infeed or a Parameter instruction. // It makes best effort to partially evaluate a dependency if possible. - StatusOr Evaluate( + absl::StatusOr Evaluate( const HloInstruction* instruction, bool recursively_evaluate_nonconstant_operands = false); @@ -168,30 +169,29 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // // For example, given instruction = op(A, B, C) and the map // {A = x, C = y}, this evaluates op(x, B, y). - StatusOr EvaluateWithSubstitutions( + absl::StatusOr EvaluateWithSubstitutions( const HloInstruction* instruction, const absl::flat_hash_map& substitutions); - StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, - const Literal& lhs, - const Literal& rhs); + absl::StatusOr EvaluateElementwiseBinaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs); - StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, - const Literal& operand); + absl::StatusOr EvaluateElementwiseUnaryOp(HloOpcode opcode, + const Literal& operand); - StatusOr EvaluateElementwiseTernaryOp(HloOpcode opcode, - const Literal& lhs, - const Literal& rhs, - const Literal& ehs); + absl::StatusOr EvaluateElementwiseTernaryOp(HloOpcode opcode, + const Literal& lhs, + const Literal& rhs, + const Literal& ehs); - StatusOr EvaluateElementwiseCompareOp(ComparisonDirection direction, - const Literal& lhs, - const Literal& rhs); + absl::StatusOr EvaluateElementwiseCompareOp( + ComparisonDirection direction, const Literal& lhs, const Literal& rhs); - StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, - const Literal& lhs, const Literal& rhs); + absl::StatusOr EvaluateDotOp(const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + const Literal& lhs, const Literal& rhs); void set_dynamic_dimension_inference( DynamicDimensionInference* dynamic_dimension_inference) { @@ -208,7 +208,7 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // Handles evaluation of a custom-call op. // Operand literals are provided in |operands| and implementations must // populate |output| before returning. - using CustomCallHandler = std::function( + using CustomCallHandler = std::function( const HloInstruction* custom_call, absl::Span operands)>; // Sets a handler that is called during evaluation for custom-call ops. @@ -436,7 +436,7 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { private: template - static StatusOr ElementWiseUnaryOpImpl( + static absl::StatusOr ElementWiseUnaryOpImpl( const HloInstruction* instruction, const std::function& unary_op, const Literal& operand_literal) { diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc index f5f3f6d6c2056b..ef3cb8ad8cee90 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -77,7 +77,7 @@ class HloEvaluatorTest : public HloTestBase { public: HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); } - StatusOr Evaluate( + absl::StatusOr Evaluate( absl::Span arg_literals = {}) { if (use_bfloat16_) { HloElementTypeConverter(F32, BF16).Run(m_.get()).value(); @@ -155,7 +155,7 @@ class HloEvaluatorTest : public HloTestBase { } void TestEvaluationFailure(HloInstruction* instruction) { - StatusOr result = evaluator_.Evaluate(instruction); + absl::StatusOr result = evaluator_.Evaluate(instruction); EXPECT_TRUE(!result.ok()); } @@ -170,7 +170,7 @@ class HloEvaluatorTest : public HloTestBase { } void TestRecursiveEvaluationFailure(HloInstruction* instruction) { - StatusOr result = evaluator_.Evaluate( + absl::StatusOr result = evaluator_.Evaluate( instruction, /*recursively_evaluate_nonconstant_operands=*/true); EXPECT_TRUE(!result.ok()); } diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index d08c824001d760..68b79d25bf5d24 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1605,7 +1605,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { } private: - StatusOr ElementWiseUnaryOp( + absl::StatusOr ElementWiseUnaryOp( const HloInstruction* instruction, const std::function& unary_op) { const Literal& operand_literal = @@ -1618,7 +1618,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { return std::move(result_literal); } - StatusOr ElementWiseBinaryOp( + absl::StatusOr ElementWiseBinaryOp( const HloInstruction* instruction, const std::function& binary_op) { @@ -1643,7 +1643,7 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { } template - StatusOr ElementwiseTernaryOp( + absl::StatusOr ElementwiseTernaryOp( const HloInstruction* instruction, const std::function& ternary_op) { const auto& shape = instruction->shape();