Skip to content

Commit

Permalink
Merge pull request #47164 from nouiz:upstream-llvm_var_name
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 357932525
Change-Id: I7112d2a77f3af87b64f6d9ce115260031a67f594
  • Loading branch information
tensorflower-gardener committed Feb 17, 2021
2 parents 49614f5 + fb9882f commit 2f7e5cf
Show file tree
Hide file tree
Showing 16 changed files with 171 additions and 129 deletions.
6 changes: 3 additions & 3 deletions tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ using xla::llvm_ir::IrArray;
namespace xla {
namespace cpu {

StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(
PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs,
absl::string_view /*name*/) {
string function_name;
bool cast_result_to_fp16 = false;
switch (prim_type) {
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter {

protected:
// CPU-specific override of ElementalIrEmitter::EmitAtan2 (declared `override`).
// This is a diff hunk without +/- markers: two forms of the parameter-list
// tail are shown, one ending after `rhs` and one that also takes
// `absl::string_view name`.
// NOTE(review): the form without `name` appears to be the removed side of the
// diff — confirm against the commit.
StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs,
llvm::Value* rhs) override;
llvm::Value* rhs,
absl::string_view name) override;
StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value) override;

Expand Down
98 changes: 54 additions & 44 deletions tensorflow/compiler/xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
primitive_util::BitWidth(to_type));
}
case HloOpcode::kExp:
return EmitExp(op->shape().element_type(), operand_value);
return EmitExp(op->shape().element_type(), operand_value, "");
case HloOpcode::kExpm1:
return EmitExpm1(op->shape().element_type(), operand_value);
case HloOpcode::kLog:
Expand Down Expand Up @@ -528,7 +528,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
// log(a+bi) = log(abs(a+bi)) + i*atan2(b,a)
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a));
TF_ASSIGN_OR_RETURN(llvm::Value * angle,
EmitAtan2(component_type, b, a, ""));
TF_ASSIGN_OR_RETURN(llvm::Value * abs,
EmitComplexAbs(component_type, operand_value));
TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs));
Expand All @@ -543,7 +544,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto a_plus_one = FAdd(a, one);
auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
TF_ASSIGN_OR_RETURN(auto angle,
EmitAtan2(component_type, b, a_plus_one, ""));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle);
}
Expand All @@ -566,7 +568,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
TF_ASSIGN_OR_RETURN(
auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
auto exp_a,
EmitExp(component_type, EmitExtractReal(operand_value), ""));
TF_ASSIGN_OR_RETURN(
auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
TF_ASSIGN_OR_RETURN(
Expand All @@ -576,7 +579,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
case HloOpcode::kExpm1: {
// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
TF_ASSIGN_OR_RETURN(
auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
auto exp_a,
EmitExp(component_type, EmitExtractReal(operand_value), ""));
TF_ASSIGN_OR_RETURN(
auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
TF_ASSIGN_OR_RETURN(
Expand All @@ -597,7 +601,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
Expand All @@ -619,7 +623,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b, ""));
auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
Expand Down Expand Up @@ -828,15 +832,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
case HloOpcode::kComplex:
return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
return FAdd(lhs_value, rhs_value);
return FAdd(lhs_value, rhs_value, op->name());
case HloOpcode::kSubtract:
return FSub(lhs_value, rhs_value);
return FSub(lhs_value, rhs_value, op->name());
case HloOpcode::kMultiply:
return FMul(lhs_value, rhs_value);
return FMul(lhs_value, rhs_value, op->name());
case HloOpcode::kDivide:
return FDiv(lhs_value, rhs_value);
return FDiv(lhs_value, rhs_value, op->name());
case HloOpcode::kRemainder:
return FRem(lhs_value, rhs_value);
return FRem(lhs_value, rhs_value, op->name());
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
Expand All @@ -848,32 +852,34 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
switch (op->comparison_direction()) {
case ComparisonDirection::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kLt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kGt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kLe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
case ComparisonDirection::kGe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
rhs_value, b_);
rhs_value, b_, op->name());
}
}
case HloOpcode::kMaximum:
return EmitFloatMax(lhs_value, rhs_value);
return EmitFloatMax(lhs_value, rhs_value, op->name());
case HloOpcode::kMinimum:
return EmitFloatMin(lhs_value, rhs_value);
return EmitFloatMin(lhs_value, rhs_value, op->name());
case HloOpcode::kPower:
return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
return EmitPow(op->shape().element_type(), lhs_value, rhs_value,
op->name());
case HloOpcode::kAtan2:
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value,
op->name());
default:
return Unimplemented("binary floating point op '%s'",
HloOpcodeString(op->opcode()));
Expand Down Expand Up @@ -901,8 +907,8 @@ ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type,
llvm::Intrinsic::fabs, {real}, {real->getType()}, b_);
llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_);
llvm::Value* max = EmitFloatMax(abs_real, abs_imag);
llvm::Value* min = EmitFloatMin(abs_real, abs_imag);
llvm::Value* max = EmitFloatMax(abs_real, abs_imag, "");
llvm::Value* min = EmitFloatMin(abs_real, abs_imag, "");

llvm::Value* div = FDiv(min, max);
llvm::Value* div_sq = FMul(div, div);
Expand Down Expand Up @@ -939,7 +945,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitSqrtComplexAbs(
TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max));
TF_ASSIGN_OR_RETURN(llvm::Value * pow,
EmitPow(prim_type, one_p_div_sq,
llvm::ConstantFP::get(max->getType(), .25)));
llvm::ConstantFP::get(max->getType(), .25), ""));
llvm::Value* result = FMul(sqrt_max, pow);
// When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN.
// In such cases, we return `min` instead of `result`.
Expand Down Expand Up @@ -983,7 +989,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexSqrt(

llvm::Value* a = EmitExtractReal(operand_value);
llvm::Value* b = EmitExtractImag(operand_value);
TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));

llvm::Value* c = llvm::ConstantFP::get(type, 0.5);
llvm::Value* angle = FMul(t, c);
Expand Down Expand Up @@ -1039,7 +1045,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexRsqrt(

llvm::Value* a = EmitExtractReal(operand_value);
llvm::Value* b = EmitExtractImag(operand_value);
TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a));
TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a, ""));

llvm::Value* c = llvm::ConstantFP::get(type, -0.5);
llvm::Value* angle = FMul(t, c);
Expand Down Expand Up @@ -1116,13 +1122,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexPower(
auto half_c = FMul(one_half, c);

TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
EmitPow(component_type, aa_p_bb, half_c));
EmitPow(component_type, aa_p_bb, half_c, ""));

auto neg_d = FNeg(d);
TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a, ""));
auto neg_d_arg_lhs = FMul(neg_d, arg_lhs);
TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
EmitExp(component_type, neg_d_arg_lhs));
EmitExp(component_type, neg_d_arg_lhs, ""));
auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
auto half_d = FMul(one_half, d);
Expand Down Expand Up @@ -1314,13 +1320,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
}

// Emits a floating-point maximum of lhs_value and rhs_value by delegating to
// llvm_ir::EmitFloatMax, passing this emitter's builder (b_) and its
// fast_min_max() setting.
//
// This is a diff hunk without +/- markers, so two versions are interleaved:
// one taking (lhs_value, rhs_value) and one that additionally takes
// `absl::string_view name` and forwards it as the last argument.
// NOTE(review): `name` is presumably the name given to the emitted LLVM value
// (the merged branch is called "upstream-llvm_var_name") — confirm.
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
// NOTE(review): the next two lines appear to be the removed (pre-change) form.
llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max());
// The remaining lines are the form that carries the `name` parameter.
llvm::Value* rhs_value,
absl::string_view name) {
return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max(), name);
}

// Emits a floating-point minimum of lhs_value and rhs_value by delegating to
// llvm_ir::EmitFloatMin, passing this emitter's builder (b_) and its
// fast_min_max() setting.
//
// This is a diff hunk without +/- markers, so two versions are interleaved:
// one taking (lhs_value, rhs_value) and one that additionally takes
// `absl::string_view name` and forwards it as the last argument.
// NOTE(review): `name` is presumably the name given to the emitted LLVM value
// — confirm.
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
// NOTE(review): the next two lines appear to be the removed (pre-change) form.
llvm::Value* rhs_value) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max());
// The remaining lines are the form that carries the `name` parameter.
llvm::Value* rhs_value,
absl::string_view name) {
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max(), name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
Expand Down Expand Up @@ -1404,9 +1412,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
}

// Emits a call to the llvm::Intrinsic::exp intrinsic on `value` via
// llvm_ir::EmitCallToIntrinsic, passing {value->getType()} and the builder b_.
// `prim_type` is not referenced in the body shown here.
//
// This is a diff hunk without +/- markers, so two versions are interleaved:
// one taking (prim_type, value) and one that additionally takes
// `absl::string_view name` and passes it as the final argument to
// EmitCallToIntrinsic.
// NOTE(review): the lines without `name` appear to be the removed side of the
// diff — confirm against the commit.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
llvm::Value* value) {
llvm::Value* value,
absl::string_view name) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
{value->getType()}, b_);
{value->getType()}, b_, name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
Expand All @@ -1417,7 +1426,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
auto half = llvm::ConstantFP::get(type, 0.5);
// When the exponent is large, the naive evaluation of e^(x) - 1 is more
// accurate than the Taylor series.
TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value));
TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value, ""));
auto for_large_x = FSub(exp_x, one);
// The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
// We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
Expand All @@ -1438,9 +1447,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,

// Emits a call to the llvm::Intrinsic::pow intrinsic with operands {lhs, rhs}
// via llvm_ir::EmitCallToIntrinsic, passing {lhs->getType()} and the builder
// b_. `prim_type` is not referenced in the body shown here.
//
// This is a diff hunk without +/- markers, so two versions are interleaved:
// one taking (prim_type, lhs, rhs) and one that additionally takes
// `absl::string_view name` and passes it as the final argument to
// EmitCallToIntrinsic.
// NOTE(review): the lines without `name` appear to be the removed side of the
// diff — confirm against the commit.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
llvm::Value* rhs,
absl::string_view name) {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
{lhs->getType()}, b_);
{lhs->getType()}, b_, name);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
Expand All @@ -1450,15 +1460,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitCbrt(PrimitiveType prim_type,
auto abs_value =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * abs_res,
EmitPow(prim_type, abs_value, third));
EmitPow(prim_type, abs_value, third, ""));
auto signed_res = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign,
{abs_res, value}, {type}, b_);
return signed_res;
}

// Base-class EmitAtan2: ignores its arguments and returns
// Unimplemented("atan2"). CpuElementalIrEmitter declares an override of this
// method elsewhere in this diff.
//
// This is a diff hunk without +/- markers, so two signatures are shown back to
// back: a three-parameter form (prim_type, lhs, rhs) and a four-parameter form
// that adds `absl::string_view name`; in the latter the `rhs` and `name`
// parameter names are commented out. The return statement is shared by both.
// NOTE(review): the three-parameter form appears to be the removed side of the
// diff — confirm against the commit.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) {
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(
PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* /*rhs*/,
absl::string_view /*name*/) {
return Unimplemented("atan2");
}

Expand Down Expand Up @@ -1728,7 +1738,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
operand_to_generator.at(hlo->operand(2))(index));
PrimitiveType prim_type = hlo->shape().element_type();
if (primitive_util::IsFloatingPointType(prim_type)) {
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value, ""), "");
} else if (primitive_util::IsIntegralType(prim_type)) {
bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
return EmitIntegralMin(
Expand Down
15 changes: 10 additions & 5 deletions tensorflow/compiler/xla/service/elemental_ir_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
llvm::Value* rhs_value);

// Virtual hook that emits a floating-point maximum of the two operands.
// Diff hunk without +/- markers: two forms of the parameter-list tail are
// shown, one ending after `rhs_value` and one that also takes
// `absl::string_view name`.
// NOTE(review): the form without `name` appears to be the removed side of the
// diff — confirm against the commit.
virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value);
llvm::Value* rhs_value,
absl::string_view name);

// Virtual hook that emits a floating-point minimum of the two operands.
// Diff hunk without +/- markers: two forms of the parameter-list tail are
// shown, one ending after `rhs_value` and one that also takes
// `absl::string_view name`.
// NOTE(review): the form without `name` appears to be the removed side of the
// diff — confirm against the commit.
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value);
llvm::Value* rhs_value,
absl::string_view name);

llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed);
Expand All @@ -117,7 +119,8 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
bool is_signed);

// Virtual hook for emitting atan2(lhs, rhs) for element type `prim_type`.
// The base definition shown elsewhere in this diff returns
// Unimplemented("atan2"), and CpuElementalIrEmitter declares an override.
// Diff hunk without +/- markers: two forms of the parameter-list tail are
// shown, one ending after `rhs` and one that also takes
// `absl::string_view name`.
// NOTE(review): the form without `name` appears to be the removed side of the
// diff — confirm against the commit.
virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type,
llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* lhs, llvm::Value* rhs,
absl::string_view name);

virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type,
llvm::Value* value);
Expand All @@ -141,13 +144,15 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
llvm::Value* value);

// Virtual hook for emitting exp(value) for element type `prim_type`.
// Diff hunk without +/- markers: two forms of the parameter-list tail are
// shown, one ending after `value` and one that also takes
// `absl::string_view name`.
// NOTE(review): the form without `name` appears to be the removed side of the
// diff — confirm against the commit.
virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type,
llvm::Value* value);
llvm::Value* value,
absl::string_view name);

virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type,
llvm::Value* value);

// Virtual hook for emitting pow(lhs, rhs) for element type `prim_type`.
// Diff hunk without +/- markers: two forms of the parameter-list tail are
// shown, one ending after `rhs` and one that also takes
// `absl::string_view name`.
// NOTE(review): the form without `name` appears to be the removed side of the
// diff — confirm against the commit.
virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type,
llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* lhs, llvm::Value* rhs,
absl::string_view name);

virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type,
llvm::Value* value);
Expand Down
Loading

0 comments on commit 2f7e5cf

Please sign in to comment.