[CHLO] Add erf_inv and lowering to mhlo
PiperOrigin-RevId: 513183138
atondwal authored and TensorFlow MLIR Team committed Mar 1, 2023
1 parent b5dfbff commit 221ac0e
Showing 6 changed files with 448 additions and 0 deletions.
180 changes: 180 additions & 0 deletions mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc
@@ -19,6 +19,7 @@ limitations under the License.
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/math-constants
#define _USE_MATH_DEFINES
#include <algorithm>
#include <array>
#include <cmath>
#include <numeric>
#include <vector>
@@ -633,6 +634,184 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
}
};

Value erfInv32(ConversionPatternRewriter &b, Location loc, ValueRange args) {
constexpr int kDegree = 9;
constexpr std::array<float, 9> wLessThan5Constants = {
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
-4.39150654e-06f, 0.00021858087f, -0.00125372503f,
-0.00417768164f, 0.246640727f, 1.50140941f};
constexpr std::array<float, 9> wGreaterThan5Constants = {
-0.000200214257f, 0.000100950558f, 0.00134934322f,
-0.00367342844f, 0.00573950773f, -0.0076224613f,
0.00943887047f, 1.00167406f, 2.83297682f};

Value x = args[0];
// Compute logarithm of (1+arg) using log1p(arg) which is more precise than
// log(1+arg) when arg is close to zero. For more details, see
// https://en.cppreference.com/w/cpp/numeric/math/log1p
Value minusXSquared =
b.create<mhlo::MulOp>(loc, x, b.create<mhlo::NegOp>(loc, x));
Value w =
b.create<mhlo::NegOp>(loc, b.create<mhlo::Log1pOp>(loc, minusXSquared));

Value lt = b.create<mhlo::CompareOp>(loc, w, getConstantLike(b, loc, 5.0, x),
mhlo::ComparisonDirection::LT);
auto coefficient = [&](int i) {
return b.create<mhlo::SelectOp>(
loc, lt, getConstantLike(b, loc, wLessThan5Constants[i], x),
getConstantLike(b, loc, wGreaterThan5Constants[i], x));
};
w = b.create<mhlo::SelectOp>(
loc, lt,
b.create<mhlo::SubtractOp>(loc, w, getConstantLike(b, loc, 2.5, x)),
b.create<mhlo::SubtractOp>(loc, b.create<mhlo::SqrtOp>(loc, w),
getConstantLike(b, loc, 3.0, x)));
Value p = coefficient(0);
for (int i = 1; i < kDegree; ++i) {
p = b.create<mhlo::AddOp>(loc, coefficient(i),
b.create<mhlo::MulOp>(loc, p, w));
}

// Result modulo edge cases.
Value result = b.create<mhlo::MulOp>(loc, p, x);

// Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
// indeterminate, and can give nan or -/+inf.)
return b.create<mhlo::SelectOp>(
loc,
b.create<mhlo::CompareOp>(loc, b.create<mhlo::AbsOp>(loc, x),
getConstantLike(b, loc, 1, x),
mhlo::ComparisonDirection::EQ),
b.create<mhlo::MulOp>(loc, x, getConstantLikeMaxFiniteValue(b, loc, x)),
result);
}
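// Reference sketch (not part of this commit): a plain scalar version of the
// single-precision approximation built above, using the same coefficients,
// the same w = -log1p(-x*x) transform, and the same split at w == 5. The name
// is illustrative only, and it additionally assumes <limits> is available.
float erfInv32Reference(float x) {
  constexpr std::array<float, 9> lt5 = {
      2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
      -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
      -0.00417768164f,  0.246640727f,    1.50140941f};
  constexpr std::array<float, 9> ge5 = {
      -0.000200214257f, 0.000100950558f, 0.00134934322f,
      -0.00367342844f,  0.00573950773f,  -0.0076224613f,
      0.00943887047f,   1.00167406f,     2.83297682f};
  float w = -std::log1p(-x * x);
  bool lt = w < 5.0f;
  w = lt ? w - 2.5f : std::sqrt(w) - 3.0f;
  const std::array<float, 9> &c = lt ? lt5 : ge5;
  float p = c[0];
  for (int i = 1; i < 9; ++i) p = c[i] + p * w;  // Horner evaluation.
  // erfinv(+/-1) is +/-inf; like the lowering, clamp to the largest finite
  // value instead of relying on the indeterminate polynomial result.
  if (std::fabs(x) == 1.0f) return x * std::numeric_limits<float>::max();
  return p * x;
}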

Value erfInv64(ConversionPatternRewriter &b, Location loc, ValueRange args) {
constexpr std::array<double, 23> wLessThan625Constants = {
-3.6444120640178196996e-21, -1.685059138182016589e-19,
1.2858480715256400167e-18, 1.115787767802518096e-17,
-1.333171662854620906e-16, 2.0972767875968561637e-17,
6.6376381343583238325e-15, -4.0545662729752068639e-14,
-8.1519341976054721522e-14, 2.6335093153082322977e-12,
-1.2975133253453532498e-11, -5.4154120542946279317e-11,
1.051212273321532285e-09, -4.1126339803469836976e-09,
-2.9070369957882005086e-08, 4.2347877827932403518e-07,
-1.3654692000834678645e-06, -1.3882523362786468719e-05,
0.0001867342080340571352, -0.00074070253416626697512,
-0.0060336708714301490533, 0.24015818242558961693,
1.6536545626831027356};
constexpr std::array<double, 19> wLessThan16Constants = {
2.2137376921775787049e-09, 9.0756561938885390979e-08,
-2.7517406297064545428e-07, 1.8239629214389227755e-08,
1.5027403968909827627e-06, -4.013867526981545969e-06,
2.9234449089955446044e-06, 1.2475304481671778723e-05,
-4.7318229009055733981e-05, 6.8284851459573175448e-05,
2.4031110387097893999e-05, -0.0003550375203628474796,
0.00095328937973738049703, -0.0016882755560235047313,
0.0024914420961078508066, -0.0037512085075692412107,
0.005370914553590063617, 1.0052589676941592334,
3.0838856104922207635,
};
constexpr std::array<double, 17> wGreaterThan16Constants = {
-2.7109920616438573243e-11, -2.5556418169965252055e-10,
1.5076572693500548083e-09, -3.7894654401267369937e-09,
7.6157012080783393804e-09, -1.4960026627149240478e-08,
2.9147953450901080826e-08, -6.7711997758452339498e-08,
2.2900482228026654717e-07, -9.9298272942317002539e-07,
4.5260625972231537039e-06, -1.9681778105531670567e-05,
7.5995277030017761139e-05, -0.00021503011930044477347,
-0.00013871931833623122026, 1.0103004648645343977,
4.8499064014085844221,
};

Value x = args[0];
// Compute logarithm of (1+arg) using log1p(arg) which is more precise than
// log(1+arg) when arg is close to zero. For more details, see
// https://en.cppreference.com/w/cpp/numeric/math/log1p
Value minusXSquared =
b.create<mhlo::MulOp>(loc, x, b.create<mhlo::NegOp>(loc, x));
Value w =
b.create<mhlo::NegOp>(loc, b.create<mhlo::Log1pOp>(loc, minusXSquared));

Value lt625 = b.create<mhlo::CompareOp>(
loc, w, getConstantLike(b, loc, 6.25, x), mhlo::ComparisonDirection::LT);
Value lt16 = b.create<mhlo::CompareOp>(loc, w, getConstantLike(b, loc, 16, x),
mhlo::ComparisonDirection::LT);

auto coefficient = [&](int i) {
Value c = getConstantLike(b, loc, wLessThan625Constants[i], x);
if (i < 19) {
c = b.create<mhlo::SelectOp>(
loc, lt625, c, getConstantLike(b, loc, wLessThan16Constants[i], x));
}
if (i < 17) {
c = b.create<mhlo::SelectOp>(
loc, lt16, c, getConstantLike(b, loc, wGreaterThan16Constants[i], x));
}
return c;
};

Value sqrtW = b.create<mhlo::SqrtOp>(loc, w);
Value wMinus3125 =
b.create<mhlo::SubtractOp>(loc, w, getConstantLike(b, loc, 3.125, x));
Value select2 =
b.create<mhlo::SelectOp>(loc, lt16, getConstantLike(b, loc, 3.25, w),
getConstantLike(b, loc, 5.0, w));
Value select2Result = b.create<mhlo::SubtractOp>(loc, sqrtW, select2);
w = b.create<mhlo::SelectOp>(loc, lt625, wMinus3125, select2Result);

Value p = coefficient(0);
for (int i = 1; i < 17; ++i) {
p = b.create<mhlo::AddOp>(loc, coefficient(i),
b.create<mhlo::MulOp>(loc, p, w));
}
for (int i = 17; i < 19; ++i) {
p = b.create<mhlo::SelectOp>(
loc, lt16,
b.create<mhlo::AddOp>(loc, coefficient(i),
b.create<mhlo::MulOp>(loc, p, w)),
p);
}
for (int i = 19; i < 23; ++i) {
p = b.create<mhlo::SelectOp>(
loc, lt625,
b.create<mhlo::AddOp>(loc, coefficient(i),
b.create<mhlo::MulOp>(loc, p, w)),
p);
}

// Result modulo edge cases.
Value result = b.create<mhlo::MulOp>(loc, p, x);

// Handle edge cases, namely erfinv(+/-1) = +/-inf. (The above computation is
// indeterminate, and can give nan or -/+inf.)
return b.create<mhlo::SelectOp>(
loc,
b.create<mhlo::CompareOp>(loc, b.create<mhlo::AbsOp>(loc, x),
getConstantLike(b, loc, 1, x),
mhlo::ComparisonDirection::EQ),
b.create<mhlo::MulOp>(loc, x, getConstantLikeMaxFiniteValue(b, loc, x)),
result);
}
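// Scalar sketch (not part of this commit; the function name is illustrative)
// of the region handling above: each region re-centers w before the shared
// Horner recurrence, and the extra high-order terms only contribute in the
// tighter regions, which is what the select-guarded tail loops implement.
double erfInv64RegionArg(double w) {
  if (w < 6.25) return w - 3.125;            // 23-term polynomial region.
  if (w < 16.0) return std::sqrt(w) - 3.25;  // 19-term polynomial region.
  return std::sqrt(w) - 5.0;                 // 17-term polynomial region.
}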

struct ConvertErfInvOp : public OpConversionPattern<ErfInvOp> {
using OpConversionPattern<ErfInvOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ErfInvOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
if (op.getResult().getType().getElementType().isF64()) {
rewriter.replaceOp(op, erfInv64(rewriter, loc, adaptor.getOperands()));
return success();
}
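// For f32 and narrower float element types, fall back to the f32
// approximation; materializeWithUpcast presumably converts sub-f32 operands
// to f32 before invoking erfInv32 and casts the result back afterwards.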
FloatType minPrecisionTy = rewriter.getF32Type();
rewriter.replaceOp(
op, materializeWithUpcast(rewriter, loc, adaptor.getOperands(),
minPrecisionTy, &erfInv32));
return success();
}
};

// Coefficients for the Lanczos approximation of the gamma function. The
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
@@ -1705,6 +1884,7 @@ void populateDecomposeChloPatterns(MLIRContext *context,
ConvertDigammaOp,
ConvertErfOp,
ConvertErfcOp,
ConvertErfInvOp,
ConvertLgammaOp,
ConvertNextAfterOp,
ConvertPolygammaOp,
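For context, a hedged sketch (not part of the diff) of how a conversion pass picks up the new pattern; the helper name is hypothetical and the second parameter of populateDecomposeChloPatterns is assumed to be RewritePatternSet *, following MLIR convention.

void collectChloDecompositions(MLIRContext *context,
                               RewritePatternSet *patterns) {
  populateDecomposeChloPatterns(context, patterns);
  // The populated set now includes ConvertErfInvOp, so driving these patterns
  // (for example via applyPartialConversion) rewrites chlo.erf_inv into the
  // mhlo ops constructed by erfInv32 / erfInv64 above.
}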
1 change: 1 addition & 0 deletions stablehlo/stablehlo/dialect/ChloOps.cpp
@@ -66,6 +66,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CoshOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(DigammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfInvOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LgammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NextAfterOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PolygammaOp)
9 changes: 9 additions & 0 deletions stablehlo/stablehlo/dialect/ChloOps.td
@@ -673,6 +673,15 @@ def CHLO_ErfOp : CHLO_UnaryElementwiseOp<"erf",
}];
}

def CHLO_ErfInvOp : CHLO_UnaryElementwiseOp<"erf_inv",
[HLO_CompatibleOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
let summary = "Inverse Erf";
let description = [{
Returns `ErfInv(operand)` element-wise.
}];
}


def CHLO_ErfcOp : CHLO_UnaryElementwiseOp<"erfc",
[HLO_CompatibleOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
let summary = "Erfc operator";
7 changes: 7 additions & 0 deletions stablehlo/stablehlo/tests/ops_chlo.mlir
@@ -133,3 +133,10 @@ func.func @top_k(%arg0 : tensor<16x16xf32>) {
%0:2 = chlo.top_k(%arg0, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>)
return
}

// -----

func.func @erf_inv(%arg0 : tensor<16x16xf32>) {
%0 = chlo.erf_inv %arg0 : tensor<16x16xf32> -> tensor<16x16xf32>
return
}
8 changes: 8 additions & 0 deletions stablehlo/stablehlo/tests/ops_chlo_roundtrip.mlir
@@ -432,3 +432,11 @@ func.func @chlo_rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tenso
}) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

// CHECK-LABEL: func @chlo_erf_inv
// CHECK-SAME: %[[A0:.*0]]: tensor<16x16xf32>)
// CHECK: chlo.erf_inv %[[A0]] : tensor<16x16xf32> -> tensor<16x16xf32>
func.func @chlo_erf_inv(%arg0 : tensor<16x16xf32>) {
%0 = "chlo.erf_inv"(%arg0) : (tensor<16x16xf32>) -> tensor<16x16xf32>
return
}