From fb9882f08903159632a7f7030f2609db254e079f Mon Sep 17 00:00:00 2001
From: Frederic Bastien
Date: Mon, 1 Feb 2021 06:28:45 -0800
Subject: [PATCH] [XLA] More readable emitted LLVM code.

---
 .../xla/service/cpu/elemental_ir_emitter.cc   |  3 +-
 .../xla/service/cpu/elemental_ir_emitter.h    |  3 +-
 .../xla/service/elemental_ir_emitter.cc       | 53 ++++++++++---------
 .../xla/service/elemental_ir_emitter.h        | 15 ++++--
 .../xla/service/gpu/elemental_ir_emitter.cc   | 23 ++++----
 .../xla/service/gpu/elemental_ir_emitter.h    | 13 +++--
 .../compiler/xla/service/gpu/ir_emitter.cc    |  5 +-
 .../xla/service/gpu/ir_emitter_unnested.cc    | 38 ++++++++-----
 .../compiler/xla/service/gpu/target_util.cc   |  4 +-
 .../compiler/xla/service/gpu/target_util.h    |  2 +-
 .../xla/service/llvm_ir/fused_ir_emitter.cc   |  2 +-
 .../compiler/xla/service/llvm_ir/ir_array.cc  |  2 +-
 .../compiler/xla/service/llvm_ir/llvm_util.cc | 25 +++++----
 .../compiler/xla/service/llvm_ir/llvm_util.h  | 11 ++--
 14 files changed, 118 insertions(+), 81 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index a4566b11a78817..ded2629b11f061 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -35,7 +35,8 @@ namespace cpu {
 
 StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                         llvm::Value* lhs,
-                                                        llvm::Value* rhs) {
+                                                        llvm::Value* rhs,
+                                                        absl::string_view name) {
   string function_name;
   bool cast_result_to_fp16 = false;
   switch (prim_type) {
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
index a002df25493c5e..23a446fb0df99c 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -37,7 +37,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {
 
  protected:
   StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
-                                   llvm::Value* rhs) override;
+                                   llvm::Value* rhs,
+                                   absl::string_view name = "") override;
 
   StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                   llvm::Value* value) override;
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 31ca1ab66d789e..e95c40cdc0f282 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -828,15 +828,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
     case HloOpcode::kComplex:
       return EmitComposeComplex(op, lhs_value, rhs_value);
     case HloOpcode::kAdd:
-      return FAdd(lhs_value, rhs_value);
+      return FAdd(lhs_value, rhs_value, op->name());
     case HloOpcode::kSubtract:
-      return FSub(lhs_value, rhs_value);
+      return FSub(lhs_value, rhs_value, op->name());
     case HloOpcode::kMultiply:
-      return FMul(lhs_value, rhs_value);
+      return FMul(lhs_value, rhs_value, op->name());
     case HloOpcode::kDivide:
-      return FDiv(lhs_value, rhs_value);
+      return FDiv(lhs_value, rhs_value, op->name());
    case HloOpcode::kRemainder:
-      return FRem(lhs_value, rhs_value);
+      return FRem(lhs_value, rhs_value, op->name());
     // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
     // comparisons always return false when one of the operands is NaN, whereas
     // unordered comparisons return true.
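[Note, not part of the patch: every hunk in this change follows the same recipe: thread the HLO instruction's name down to the LLVM IRBuilder Create* call, whose optional trailing argument names the resulting SSA value. Unnamed values print as auto-numbered temporaries such as %2; named ones keep a readable label. A minimal standalone C++ sketch of the difference; the module and function names here are made up:

    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/LLVMContext.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      llvm::LLVMContext ctx;
      llvm::Module module("demo", ctx);
      llvm::Type* f32 = llvm::Type::getFloatTy(ctx);
      auto* fn_ty = llvm::FunctionType::get(f32, {f32, f32}, /*isVarArg=*/false);
      auto* fn = llvm::Function::Create(
          fn_ty, llvm::Function::ExternalLinkage, "demo_fn", &module);
      llvm::IRBuilder<> b(llvm::BasicBlock::Create(ctx, "entry", fn));
      // No name: the result prints as "%2 = fadd float %0, %1".
      llvm::Value* unnamed = b.CreateFAdd(fn->getArg(0), fn->getArg(1));
      // Named after an HLO instruction: prints as
      // "%multiply.3 = fmul float %2, %1" and maps back to the HLO graph.
      llvm::Value* named = b.CreateFMul(unnamed, fn->getArg(1), "multiply.3");
      b.CreateRet(named);
      module.print(llvm::outs(), nullptr);
    }

Only the textual value names change, so the generated machine code is unaffected; the payoff is that IR dumps become much easier to correlate with the HLO instructions that produced them.]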
@@ -848,32 +848,32 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
       switch (op->comparison_direction()) {
         case ComparisonDirection::kEq:
           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
-                                         rhs_value, b_);
+                                         rhs_value, b_, op->name());
         case ComparisonDirection::kNe:
           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
-                                         rhs_value, b_);
+                                         rhs_value, b_, op->name());
         case ComparisonDirection::kLt:
           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
-                                         rhs_value, b_);
+                                         rhs_value, b_, op->name());
         case ComparisonDirection::kGt:
           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
-                                         rhs_value, b_);
+                                         rhs_value, b_, op->name());
         case ComparisonDirection::kLe:
           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
-                                         rhs_value, b_);
+                                         rhs_value, b_, op->name());
         case ComparisonDirection::kGe:
           return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
-                                         rhs_value, b_);
+                                         rhs_value, b_, op->name());
       }
     }
     case HloOpcode::kMaximum:
-      return EmitFloatMax(lhs_value, rhs_value);
+      return EmitFloatMax(lhs_value, rhs_value, op->name());
     case HloOpcode::kMinimum:
-      return EmitFloatMin(lhs_value, rhs_value);
+      return EmitFloatMin(lhs_value, rhs_value, op->name());
     case HloOpcode::kPower:
-      return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
+      return EmitPow(op->shape().element_type(), lhs_value, rhs_value, op->name());
     case HloOpcode::kAtan2:
-      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
+      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value, op->name());
     default:
       return Unimplemented("binary floating point op '%s'",
                            HloOpcodeString(op->opcode()));
@@ -1314,13 +1314,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
 }
 
 llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
-                                              llvm::Value* rhs_value) {
-  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max());
+                                              llvm::Value* rhs_value,
+                                              absl::string_view name) {
+  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
 }
 
 llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
-                                              llvm::Value* rhs_value) {
-  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max());
+                                              llvm::Value* rhs_value,
+                                              absl::string_view name) {
+  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
@@ -1404,9 +1406,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
-                                                   llvm::Value* value) {
+                                                   llvm::Value* value,
+                                                   absl::string_view name) {
   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
-                                      {value->getType()}, b_);
+                                      {value->getType()}, b_, name);
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
@@ -1438,9 +1441,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                    llvm::Value* lhs,
-                                                   llvm::Value* rhs) {
+                                                   llvm::Value* rhs,
+                                                   absl::string_view name) {
   return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
-                                      {lhs->getType()}, b_);
+                                      {lhs->getType()}, b_, name);
 }
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
@@ -1458,7 +1462,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
 
 StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                      llvm::Value* lhs,
-                                                     llvm::Value* rhs) {
+                                                     llvm::Value* rhs,
+                                                     absl::string_view name) {
   return Unimplemented("atan2");
 }
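[Note, not part of the patch: the header changes below all add the parameter as a trailing absl::string_view name = "", so existing call sites keep compiling while new ones can pass op->name(). A reduced sketch of the virtual-override pattern, with hypothetical class names and std::string_view standing in for absl::string_view:

    #include <string_view>

    class Emitter {
     public:
      virtual ~Emitter() = default;
      // Trailing defaulted parameter: old two-argument calls still compile.
      virtual int EmitExp(int prim_type, int value,
                          std::string_view name = "") {
        return value;
      }
    };

    class GpuEmitter : public Emitter {
     public:
      // Overrides must repeat the whole signature (defaults are not
      // inherited, and a mismatch would fail to override), which is why the
      // CPU and GPU headers in this patch change in lockstep with the base.
      int EmitExp(int prim_type, int value,
                  std::string_view name = "") override {
        return value + 1;
      }
    };

    int main() {
      GpuEmitter g;
      Emitter& e = g;
      int a = e.EmitExp(0, 1);           // old call site, unchanged
      int b = e.EmitExp(0, 1, "exp.5");  // new call site carries a name
      return a + b;
    }

One spot a reviewer might flag: in GpuElementalIrEmitter::EmitExp further below, the new name parameter is accepted but not forwarded to EmitDeviceMathCall, so exp calls still come out unnamed.]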
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 7eff80d9f6c6bb..bcdbfc830d4aaa 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -105,10 +105,12 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
                                     llvm::Value* rhs_value);
 
   virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
-                                    llvm::Value* rhs_value);
+                                    llvm::Value* rhs_value,
+                                    absl::string_view name = "");
 
   virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
-                                    llvm::Value* rhs_value);
+                                    llvm::Value* rhs_value,
+                                    absl::string_view name = "");
 
   llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                                bool is_signed);
@@ -117,7 +119,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
                                bool is_signed);
 
   virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
-                                           llvm::Value* lhs, llvm::Value* rhs);
+                                           llvm::Value* lhs, llvm::Value* rhs,
+                                           absl::string_view name = "");
 
   virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
                                          llvm::Value* value);
@@ -141,13 +144,15 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
                                           llvm::Value* value);
 
   virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
-                                         llvm::Value* value);
+                                         llvm::Value* value,
+                                         absl::string_view name = "");
 
   virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
                                            llvm::Value* value);
 
   virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
-                                         llvm::Value* lhs, llvm::Value* rhs);
+                                         llvm::Value* lhs, llvm::Value* rhs,
+                                         absl::string_view name = "");
 
   virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                           llvm::Value* value);
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index e72c12813b71d4..a3f02164f1e2d1 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -78,7 +78,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter(
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
     TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
-    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
+    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+    absl::string_view name) {
   // Device functions dont have f16 math functions, so we convert the operands
   // to f32 before calling the function and then convert the result back to f16.
   bool cast_result_to_fp16 = false;
@@ -109,7 +110,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
   const string& munged_callee =
       ObtainDeviceFunctionName(funcid, output_type, b());
   llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
-                                     converted_input_types, output_type)
+                                     converted_input_types, output_type, name)
                             .ValueOrDie();
   if (cast_result_to_fp16) {
     result = FPCast(result, b()->getHalfTy());
@@ -142,7 +143,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
     const string& callee_name, absl::Span<llvm::Value* const> operands,
-    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
+    absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+    absl::string_view name) {
   // Binary math functions transform are of type [T] -> T.
   for (PrimitiveType input_type : input_types) {
     if (output_type != input_type) {
@@ -154,7 +156,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
 
   return EmitDeviceFunctionCall(
       callee_name, operands, input_types, output_type,
-      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
+      {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
 }
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
@@ -222,7 +224,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
 }
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
-                                                      llvm::Value* value) {
+                                                      llvm::Value* value,
+                                                      absl::string_view name) {
   return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
                             prim_type);
 }
@@ -235,9 +238,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                       llvm::Value* lhs,
-                                                      llvm::Value* rhs) {
+                                                      llvm::Value* rhs,
+                                                      absl::string_view name) {
   return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
-                            {prim_type, prim_type}, prim_type);
+                            {prim_type, prim_type}, prim_type, name);
 }
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
@@ -254,9 +258,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                         llvm::Value* lhs,
-                                                        llvm::Value* rhs) {
+                                                        llvm::Value* rhs,
+                                                        absl::string_view name) {
   return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
-                            {prim_type, prim_type}, prim_type);
+                            {prim_type, prim_type}, prim_type, name);
 }
 
 StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
index 0303ea47e8d607..17e2db5a8cc3f6 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -65,7 +65,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
                                  llvm::Value* value) override;
 
   StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
-                                 llvm::Value* value) override;
+                                 llvm::Value* value,
+                                 absl::string_view name = "") override;
 
   StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
                                    llvm::Value* value) override;
@@ -77,10 +78,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
                                  llvm::Value* value) override;
 
   StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
-                                 llvm::Value* rhs) override;
+                                 llvm::Value* rhs, absl::string_view name = "") override;
 
   StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
-                                   llvm::Value* rhs) override;
+                                   llvm::Value* rhs, absl::string_view name = "") override;
 
   StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
                                   llvm::Value* value) override;
@@ -118,13 +119,15 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
   // return value of the function.
   StatusOr<llvm::Value*> EmitDeviceMathCall(
       TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
-      absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
+      absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+      absl::string_view name = "");
 
   // Emits IR to call a function of type [T] -> T. Does not munge callee_name.
   // Returns the IR value that represents the return value of the function.
   StatusOr<llvm::Value*> EmitMathCall(
       const string& callee_name, absl::Span<llvm::Value* const> operands,
-      absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
+      absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
+      absl::string_view name = "");
 
   const HloModuleConfig& hlo_module_config_;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 613696a6fc423c..54804d700309a2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -91,7 +91,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
   ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
   for (const HloInstruction* operand : hlo->operands()) {
     operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
-      return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
+      return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_, operand->name());
     };
   }
   return EmitTargetElementLoop(
@@ -688,7 +688,8 @@ void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
     fused_emitter->BindGenerator(
         fusion->fused_parameter(i),
         [this, operand, fusion](llvm_ir::IrArray::Index index) {
-          return GetIrArray(*operand, *fusion).EmitReadArrayElement(index, &b_);
+          return GetIrArray(*operand, *fusion).EmitReadArrayElement(
+              index, &b_, operand->name());
         });
   }
 }
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 93316268ebde44..9e9370c29304d2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1900,10 +1900,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
                                                  GetNestedComputer());
     FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
     for (int i = 0; i < fused_computation->num_parameters(); i++) {
+      auto fused_operand = fused_computation->parameter_instruction(i);
       operand_fused_emitter.BindGenerator(
-          fused_computation->parameter_instruction(i),
-          [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
-            return ir_arrays[i].EmitReadArrayElement(index, &b_);
+          fused_operand,
+          [this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
+            return ir_arrays[i].EmitReadArrayElement(
+                index, &b_, fused_operand->name());
           });
     }
     TF_ASSIGN_OR_RETURN(
@@ -1942,10 +1944,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
                                            GetNestedComputer());
   FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
   for (int i = 0; i < fused_computation->num_parameters(); i++) {
+    auto fused_operand = fused_computation->parameter_instruction(i);
     scatter_fused_emitter.BindGenerator(
-        fused_computation->parameter_instruction(i),
-        [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
-          return ir_arrays[i].EmitReadArrayElement(index, &b_);
+        fused_operand,
+        [this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
+          return ir_arrays[i].EmitReadArrayElement(
+              index, &b_, fused_operand->name());
         });
   }
 
@@ -2049,10 +2053,12 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
                           /*is_fusion=*/true));
 
   for (int i = 0; i < fused_computation->num_parameters(); i++) {
+    auto fused_operand = fused_computation->parameter_instruction(i);
     fused_emitter.BindGenerator(
-        fused_computation->parameter_instruction(i),
-        [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
-          return ir_arrays[i].EmitReadArrayElement(index, &b_);
+        fused_operand,
+        [this, &ir_arrays, i, fused_operand](llvm_ir::IrArray::Index index) {
+          return ir_arrays[i].EmitReadArrayElement(
+              index, &b_, fused_operand->name());
         });
   }
 
@@ -4165,8 +4171,10 @@ void IrEmitterUnnested::EmitTileElementForFusion(
       };
     } else {
       auto array = operand_arrays[i];
-      gen = [this, array](llvm_ir::IrArray::Index index) {
-        return array.EmitReadArrayElement(index, &b_);
+      auto name = fused_computation->parameter_instruction(i)->name();
+      gen = [this, array, name](llvm_ir::IrArray::Index index) {
+        return array.EmitReadArrayElement(
+            index, &b_, name);
       };
     }
     fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
@@ -5621,10 +5629,12 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
     CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
     for (int i = 0; i < fused_computation->num_parameters(); i++) {
       auto ir_array = ir_arrays[i];
+      auto fused_operand = fused_computation->parameter_instruction(i);
       fused_emitter->BindGenerator(
-          fused_computation->parameter_instruction(i),
-          [this, ir_array](llvm_ir::IrArray::Index index) {
-            return ir_array.EmitReadArrayElement(index, &b_);
+          fused_operand,
+          [this, ir_array, fused_operand](llvm_ir::IrArray::Index index) {
+            return ir_array.EmitReadArrayElement(
+                index, &b_, fused_operand->name());
           });
     }
     result_ir_arrays = absl::MakeSpan(ir_arrays).subspan(
diff --git a/tensorflow/compiler/xla/service/gpu/target_util.cc b/tensorflow/compiler/xla/service/gpu/target_util.cc
index 31b590a19ffdee..978419989ea08d 100644
--- a/tensorflow/compiler/xla/service/gpu/target_util.cc
+++ b/tensorflow/compiler/xla/service/gpu/target_util.cc
@@ -194,7 +194,7 @@ llvm::CallInst* EmitDeviceFunctionCall(
     const string& callee_name, absl::Span<llvm::Value* const> operands,
     absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
     absl::Span<const llvm::Attribute::AttrKind> attributes,
-    llvm::IRBuilder<>* b) {
+    llvm::IRBuilder<>* b, absl::string_view name) {
   std::vector<llvm::Type*> ir_input_types;
   llvm::Module* module = b->GetInsertBlock()->getModule();
   for (PrimitiveType input_type : input_types) {
@@ -217,7 +217,7 @@ llvm::CallInst* EmitDeviceFunctionCall(
     callee->addFnAttr(attribute);
   }
 
-  return b->CreateCall(callee, llvm_ir::AsArrayRef(operands));
+  return b->CreateCall(callee, llvm_ir::AsArrayRef(operands), name.data());
 }
 
 llvm::CallInst* EmitCallToTargetIntrinsic(
diff --git a/tensorflow/compiler/xla/service/gpu/target_util.h b/tensorflow/compiler/xla/service/gpu/target_util.h
index 2bdaea7734ace6..115609d18c2ef8 100644
--- a/tensorflow/compiler/xla/service/gpu/target_util.h
+++ b/tensorflow/compiler/xla/service/gpu/target_util.h
@@ -69,7 +69,7 @@ llvm::CallInst* EmitDeviceFunctionCall(
     const std::string& callee_name, absl::Span<llvm::Value* const> operands,
     absl::Span<const PrimitiveType> input_type, PrimitiveType output_type,
     absl::Span<const llvm::Attribute::AttrKind> attributes,
-    llvm::IRBuilder<>* b);
+    llvm::IRBuilder<>* b, absl::string_view name = "");
 
 // Emits a call to the specified target intrinsic with the given operands.
 // Overloaded intrinsics (for example, "minnum") must include a type
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index 0a26a2bb7ce545..135fd4ee9dd7ea 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -102,7 +102,7 @@ Status FusedIrEmitter::HandleConstant(const HloInstruction* constant) {
         global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
     return IrArray(shape_constant, constant->shape())
-        .EmitReadArrayElement(index, b_);
+        .EmitReadArrayElement(index, b_, constant->name());
   };
 
   return Status::OK();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index a21e7fafb0a745..beb06c3184f50e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -504,7 +504,7 @@ llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
                                            bool use_linear_index) const {
   llvm::Value* element_address =
       EmitArrayElementAddress(index, b, name, use_linear_index);
-  llvm::LoadInst* load = b->CreateLoad(element_address);
+  llvm::LoadInst* load = b->CreateLoad(element_address, name.data());
   AnnotateLoadStoreInstructionWithMetadata(load);
   return load;
 }
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index a00156a0e4a5a9..aaaec349d6aaf1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -83,36 +83,39 @@ string DumpModuleToString(const llvm::Module& module) {
 
 llvm::CallInst* EmitCallToIntrinsic(
     llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
-    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b) {
+    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
+    absl::string_view name) {
   llvm::Module* module = ModuleFromIRBuilder(b);
   llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
       module, intrinsic_id, AsArrayRef(overloaded_types));
-  return b->CreateCall(intrinsic, AsArrayRef(operands));
+  return b->CreateCall(intrinsic, AsArrayRef(operands), name.data());
 }
 
 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
-                          llvm::IRBuilder<>* b, bool enable_fast_min_max) {
+                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
+                          absl::string_view name) {
   if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
     auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value);
-    return b->CreateSelect(cmp, lhs_value, rhs_value);
+    return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
   } else {
     auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value);
    auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
     auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan);
-    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
+    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
   }
 }
 
 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
-                          llvm::IRBuilder<>* b, bool enable_fast_min_max) {
+                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
+                          absl::string_view name) {
   if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) {
     auto cmp = b->CreateFCmpULE(lhs_value, rhs_value);
-    return b->CreateSelect(cmp, lhs_value, rhs_value);
+    return b->CreateSelect(cmp, lhs_value, rhs_value, name.data());
   } else {
     auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value);
     auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value);
     auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan);
-    return b->CreateSelect(sel_lhs, lhs_value, rhs_value);
+    return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data());
   }
 }
 
@@ -351,12 +354,12 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
 
 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                             llvm::Value* lhs_value, llvm::Value* rhs_value,
-                            llvm::IRBuilder<>* b) {
+                            llvm::IRBuilder<>* b, absl::string_view name) {
   llvm::Value* comparison_result;
   if (lhs_value->getType()->isIntegerTy()) {
-    comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value);
+    comparison_result = b->CreateICmp(predicate, lhs_value, rhs_value, name.data());
   } else {
-    comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value);
+    comparison_result = b->CreateFCmp(predicate, lhs_value, rhs_value, name.data());
   }
   // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
   // arrays. So we extend it to i8 so that it's addressable.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 3a3b4b77d702e5..1171d29c154c49 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -103,17 +103,20 @@ string SanitizeFunctionName(string function_name);
 // overloaded type.
 llvm::CallInst* EmitCallToIntrinsic(
     llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
-    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b);
+    absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
+    absl::string_view name = "");
 
 // Emit float max. Emit maxnum intrinsic is fast math is disabled, or
 // fcmp+select otherwise
 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
-                          llvm::IRBuilder<>* b, bool enable_fast_min_max);
+                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
+                          absl::string_view name = "");
 
 // Emit float min. Emit minnum intrinsic is fast math is disabled, or
 // fcmp+select otherwise
 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
-                          llvm::IRBuilder<>* b, bool enable_fast_min_max);
+                          llvm::IRBuilder<>* b, bool enable_fast_min_max,
+                          absl::string_view name = "");
 
 // Convenience methods for emitting a GEP instruction that indexes into a buffer
 // (1-dimensional array), equivalent to array[index]. The type is automatically
@@ -214,7 +217,7 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
 // and then converts the result to i8 so that it is addressable.
 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
                             llvm::Value* lhs, llvm::Value* rhs,
-                            llvm::IRBuilder<>* b);
+                            llvm::IRBuilder<>* b, absl::string_view name = "");
 
 // Emits a call that logs the given value with the given tag as a prefix.
 // The provided tag and value are passed to a runtime logging call that is
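[Note, not part of the patch: all of the new call sites pass the name as name.data(), which hands LLVM a const char* and therefore assumes the string_view is NUL-terminated right at its end. That holds here because every name is either a whole std::string (HloInstruction::name()) or a literal, but a view into the middle of a string would silently leak trailing characters, as this small standalone demo shows:

    #include <cstdio>
    #include <string>
    #include <string_view>

    int main() {
      std::string backing = "multiply.3";
      std::string_view name = std::string_view(backing).substr(0, 8);
      // data() ignores the view's length: prints "multiply.3".
      std::printf("%s\n", name.data());
      // Length-aware formatting respects it: prints "multiply".
      std::printf("%.*s\n", static_cast<int>(name.size()), name.data());
    }

A length-aware conversion such as the existing llvm_ir::AsStringRef helper in llvm_util.h (llvm::StringRef converts implicitly to the llvm::Twine these builder methods accept) would avoid relying on that invariant.]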