Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA] More readable emitted LLVM code. #47164

Merged
merged 1 commit into from
Feb 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ namespace cpu {

StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
string function_name;
bool cast_result_to_fp16 = false;
switch (prim_type) {
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {

protected:
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) override;
llvm::Value* rhs,
absl::string_view name = "") override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;

Expand Down
53 changes: 29 additions & 24 deletions tensorflow/compiler/xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -828,15 +828,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
case HloOpcode::kComplex:
return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
return FAdd(lhs_value, rhs_value);
return FAdd(lhs_value, rhs_value, op->name());
case HloOpcode::kSubtract:
return FSub(lhs_value, rhs_value);
return FSub(lhs_value, rhs_value, op->name());
case HloOpcode::kMultiply:
return FMul(lhs_value, rhs_value);
return FMul(lhs_value, rhs_value, op->name());
case HloOpcode::kDivide:
return FDiv(lhs_value, rhs_value);
return FDiv(lhs_value, rhs_value, op->name());
case HloOpcode::kRemainder:
return FRem(lhs_value, rhs_value);
return FRem(lhs_value, rhs_value, op->name());
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
Expand All @@ -848,32 +848,32 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
switch (op->comparison_direction()) {
case ComparisonDirection::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kLt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kGt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kLe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kGe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
}
}
case HloOpcode::kMaximum:
return EmitFloatMax(lhs_value, rhs_value);
return EmitFloatMax(lhs_value, rhs_value, op->name());
case HloOpcode::kMinimum:
return EmitFloatMin(lhs_value, rhs_value);
return EmitFloatMin(lhs_value, rhs_value, op->name());
case HloOpcode::kPower:
return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
return EmitPow(op->shape().element_type(), lhs_value, rhs_value, op->name());
case HloOpcode::kAtan2:
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value, op->name());
default:
return Unimplemented("binary floating point op '%s'",
HloOpcodeString(op->opcode()));
Expand Down Expand Up @@ -1314,13 +1314,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
}

llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max());
llvm::Value* rhs_value,
absl::string_view name) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
}

llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max());
llvm::Value* rhs_value,
absl::string_view name) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
Expand Down Expand Up @@ -1404,9 +1406,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
llvm::Value* value) {
llvm::Value* value,
absl::string_view name) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
{value->getType()}, b_);
{value->getType()}, b_, name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
Expand Down Expand Up @@ -1438,9 +1441,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,

StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
{lhs->getType()}, b_);
{lhs->getType()}, b_, name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
Expand All @@ -1458,7 +1462,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,

StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
return Unimplemented("atan2");
}

Expand Down
15 changes: 10 additions & 5 deletions tensorflow/compiler/xla/service/elemental_ir_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
llvm::Value* rhs_value);

virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value);
llvm::Value* rhs_value,
absl::string_view name = "");

virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value);
llvm::Value* rhs_value,
absl::string_view name = "");

llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed);
Expand All @@ -117,7 +119,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
bool is_signed);

virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* lhs, llvm::Value* rhs,
absl::string_view name = "");

virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
llvm::Value* value);
Expand All @@ -141,13 +144,15 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
llvm::Value* value);

virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
llvm::Value* value);
llvm::Value* value,
absl::string_view name = "");

virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
llvm::Value* value);

virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* lhs, llvm::Value* rhs,
absl::string_view name = "");

virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value);
Expand Down
23 changes: 14 additions & 9 deletions tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ GpuElementalIrEmitter::GpuElementalIrEmitter(

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::string_view name) {
// Device functions dont have f16 math functions, so we convert the operands
// to f32 before calling the function and then convert the result back to f16.
bool cast_result_to_fp16 = false;
Expand Down Expand Up @@ -109,7 +110,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
const string& munged_callee =
ObtainDeviceFunctionName(funcid, output_type, b());
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
converted_input_types, output_type)
converted_input_types, output_type, name)
.ValueOrDie();
if (cast_result_to_fp16) {
result = FPCast(result, b()->getHalfTy());
Expand Down Expand Up @@ -142,7 +143,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall(

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
const string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type) {
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::string_view name) {
// Binary math functions transform are of type [T] -> T.
for (PrimitiveType input_type : input_types) {
if (output_type != input_type) {
Expand All @@ -154,7 +156,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(

return EmitDeviceFunctionCall(
callee_name, operands, input_types, output_type,
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b());
{llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b(), name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
Expand Down Expand Up @@ -222,7 +224,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type,
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type,
llvm::Value* value) {
llvm::Value* value,
absl::string_view name) {
return EmitDeviceMathCall(TargetDeviceFunctionID::kExp, {value}, {prim_type},
prim_type);
}
Expand All @@ -235,9 +238,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
return EmitDeviceMathCall(TargetDeviceFunctionID::kPow, {lhs, rhs},
{prim_type, prim_type}, prim_type);
{prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSqrt(PrimitiveType prim_type,
Expand All @@ -254,9 +258,10 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitRsqrt(PrimitiveType prim_type,

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
return EmitDeviceMathCall(TargetDeviceFunctionID::kAtan2, {lhs, rhs},
{prim_type, prim_type}, prim_type);
{prim_type, prim_type}, prim_type, name);
}

StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type,
Expand Down
13 changes: 8 additions & 5 deletions tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm::Value* value) override;

StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
llvm::Value* value) override;
llvm::Value* value,
absl::string_view name = "") override;

StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
llvm::Value* value) override;
Expand All @@ -77,10 +78,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
llvm::Value* value) override;

StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) override;
llvm::Value* rhs, absl::string_view name = "") override;

StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) override;
llvm::Value* rhs, absl::string_view name = "") override;

StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;
Expand Down Expand Up @@ -118,13 +119,15 @@ class GpuElementalIrEmitter : public ElementalIrEmitter {
// return value of the function.
StatusOr<llvm::Value*> EmitDeviceMathCall(
TargetDeviceFunctionID funcid, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::string_view name = "");

// Emits IR to call a function of type [T] -> T. Does not munge callee_name.
// Returns the IR value that represents the return value of the function.
StatusOr<llvm::Value*> EmitMathCall(
const string& callee_name, absl::Span<llvm::Value* const> operands,
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type);
absl::Span<const PrimitiveType> input_types, PrimitiveType output_type,
absl::string_view name = "");

const HloModuleConfig& hlo_module_config_;

Expand Down
5 changes: 3 additions & 2 deletions tensorflow/compiler/xla/service/gpu/ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
for (const HloInstruction* operand : hlo->operands()) {
operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_, operand->name());
};
}
return EmitTargetElementLoop(
Expand Down Expand Up @@ -688,7 +688,8 @@ void IrEmitter::BindFusionArguments(const HloInstruction* fusion,
fused_emitter->BindGenerator(
fusion->fused_parameter(i),
[this, operand, fusion](llvm_ir::IrArray::Index index) {
return GetIrArray(*operand, *fusion).EmitReadArrayElement(index, &b_);
return GetIrArray(*operand, *fusion).EmitReadArrayElement(
index, &b_, operand->name());
});
}
}
Expand Down