Merge pull request #53553 from qingyunqu:add-normalize-arith
PiperOrigin-RevId: 419128218
Change-Id: If8d5c0c1de5cb988a147d7adf66cbb8f4c32c77d
tensorflower-gardener committed Dec 31, 2021
2 parents 698d84c + f06d914 commit 221417c
Showing 2 changed files with 106 additions and 9 deletions.
58 changes: 49 additions & 9 deletions tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -4449,27 +4449,67 @@ struct min<APInt> {
   }
 };
 
-#define BINARY_FOLDER(Op, Func)                                               \
-  OpFoldResult Op::fold(ArrayRef<Attribute> attrs) {                          \
-    if (getElementTypeOrSelf(getType()).isa<FloatType>())                     \
-      return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs);\
-    if (getElementTypeOrSelf(getType()).isa<IntegerType>())                   \
-      return BinaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs);  \
-    return {};                                                                \
+#define BINARY_FOLDER_INTERNAL(Op, Func)                                     \
+  if (getElementTypeOrSelf(getType()).isa<FloatType>())                      \
+    return BinaryFolder<Op, FloatType, APFloat, Func<APFloat>>(this, attrs); \
+  if (getElementTypeOrSelf(getType()).isa<IntegerType>())                    \
+    return BinaryFolder<Op, IntegerType, APInt, Func<APInt>>(this, attrs);   \
+  return {};
+
+#define BINARY_FOLDER(Op, Func)                      \
+  OpFoldResult Op::fold(ArrayRef<Attribute> attrs) { \
+    BINARY_FOLDER_INTERNAL(Op, Func)                 \
   }
 
 // Addition, subtraction and multiplication use the std:: versions of the ops.
 // Due to the other ops behaving differently in signed vs unsigned integers,
 // APInts need a special implementation. Currently, it replicates signed int
 // op behavior.
-BINARY_FOLDER(AddOp, std::plus);
 BINARY_FOLDER(SubOp, std::minus);
-BINARY_FOLDER(MulOp, std::multiplies);
 BINARY_FOLDER(DivOp, divide);
 BINARY_FOLDER(RemOp, remainder);
 BINARY_FOLDER(MaxOp, max);
 BINARY_FOLDER(MinOp, min);
 
+OpFoldResult AddOp::fold(ArrayRef<Attribute> attrs) {
+  if (attrs[0] && attrs[1]) {
+    BINARY_FOLDER_INTERNAL(AddOp, std::plus)
+  }
+  // Handle special case where one operand is 0: x + 0 => x
+  if (attrs[0] || attrs[1]) {
+    SplatElementsAttr attr = attrs[0] ? attrs[0].dyn_cast<SplatElementsAttr>()
+                                      : attrs[1].dyn_cast<SplatElementsAttr>();
+    if (!attr) return {};
+    Value result = attrs[0] ? rhs() : lhs();
+    if (attr.getElementType().isa<FloatType>()) {
+      if (attr.getSplatValue<APFloat>().isZero()) return result;
+    } else if (attr.getElementType().isa<IntegerType>()) {
+      if (attr.getSplatValue<APInt>().isZero()) return result;
+    }
+  }
+  return {};
+}
+
+OpFoldResult MulOp::fold(ArrayRef<Attribute> attrs) {
+  if (attrs[0] && attrs[1]) {
+    BINARY_FOLDER_INTERNAL(MulOp, std::multiplies);
+  }
+  // Handle special case where one operand is 1: x * 1 => x
+  if (attrs[0] || attrs[1]) {
+    SplatElementsAttr attr = attrs[0] ? attrs[0].dyn_cast<SplatElementsAttr>()
+                                      : attrs[1].dyn_cast<SplatElementsAttr>();
+    if (!attr) return {};
+    Value result = attrs[0] ? rhs() : lhs();
+    if (attr.getElementType().isa<FloatType>()) {
+      if (attr.getSplatValue<APFloat>().convertToDouble() == 1.0) return result;
+    } else if (attr.getElementType().isa<IntegerType>()) {
+      if (attr.getSplatValue<APInt>().getSExtValue() == 1) return result;
+    }
+  }
+  return {};
+}
+
+#undef BINARY_FOLDER_INTERNAL
 #undef BINARY_FOLDER
 
 //===----------------------------------------------------------------------===//
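The BINARY_FOLDER_INTERNAL split exists so that AddOp and MulOp can reuse the constant-constant folding path while adding identity special cases: MLIR's fold hook receives a null Attribute in attrs for each operand that is not a compile-time constant, hence the attrs[0] && attrs[1] guard before full folding and the attrs[0] || attrs[1] guard for the x + 0 and x * 1 cases. As a sketch (whitespace abbreviated, not literal preprocessor output), BINARY_FOLDER(SubOp, std::minus) now expands to roughly:

    OpFoldResult SubOp::fold(ArrayRef<Attribute> attrs) {
      // Body pasted in by BINARY_FOLDER_INTERNAL(SubOp, std::minus):
      if (getElementTypeOrSelf(getType()).isa<FloatType>())
        return BinaryFolder<SubOp, FloatType, APFloat, std::minus<APFloat>>(
            this, attrs);
      if (getElementTypeOrSelf(getType()).isa<IntegerType>())
        return BinaryFolder<SubOp, IntegerType, APInt, std::minus<APInt>>(
            this, attrs);
      return {};
    }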
57 changes: 57 additions & 0 deletions tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir

@@ -27,6 +27,22 @@ func @add_fold_float() -> tensor<4xf64> {
   return %2 : tensor<4xf64>
 }
 
+// CHECK-LABEL: add_zero_int_fold
+func @add_zero_int_fold(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
+  %0 = mhlo.constant dense<0> : tensor<2x2xi64>
+  %1 = "mhlo.add"(%arg0, %0) : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
+  // CHECK: return %arg0 : tensor<2x2xi64>
+  return %1 : tensor<2x2xi64>
+}
+
+// CHECK-LABEL: add_zero_float_fold
+func @add_zero_float_fold(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %0 = mhlo.constant dense<0.0> : tensor<2x2xf32>
+  %1 = "mhlo.add"(%0, %arg0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK: return %arg0 : tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+}
+
 // CHECK-LABEL: sub_scalar_fold
 func @sub_scalar_fold() -> tensor<4xi64> {
   %0 = mhlo.constant dense<5> : tensor<4xi64>
@@ -45,6 +61,47 @@ func @multiply_scalar_fold() -> tensor<4xi64> {
   return %2 : tensor<4xi64>
 }
 
+// CHECK-LABEL: mul_one_int_fold
+func @mul_one_int_fold(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
+  %0 = mhlo.constant dense<1> : tensor<2x2xi64>
+  %1 = "mhlo.multiply"(%arg0, %0) : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64>
+  // CHECK: return %arg0 : tensor<2x2xi64>
+  return %1 : tensor<2x2xi64>
+}
+
+// CHECK-LABEL: mul_one_int8_fold
+func @mul_one_int8_fold(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
+  %0 = mhlo.constant dense<1> : tensor<2x2xi8>
+  %1 = "mhlo.multiply"(%arg0, %0) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8>
+  // CHECK: return %arg0 : tensor<2x2xi8>
+  return %1 : tensor<2x2xi8>
+}
+
+// CHECK-LABEL: mul_one_float_fold
+func @mul_one_float_fold(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %0 = mhlo.constant dense<1.0> : tensor<2x2xf32>
+  %1 = "mhlo.multiply"(%0, %arg0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK: return %arg0 : tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+}
+
+// CHECK-LABEL: mul_one_fp16_fold
+func @mul_one_fp16_fold(%arg0: tensor<2x2xf16>) -> tensor<2x2xf16> {
+  %0 = mhlo.constant dense<1.0> : tensor<2x2xf16>
+  %1 = "mhlo.multiply"(%0, %arg0) : (tensor<2x2xf16>, tensor<2x2xf16>) -> tensor<2x2xf16>
+  // CHECK: return %arg0 : tensor<2x2xf16>
+  return %1 : tensor<2x2xf16>
+}
+
+// CHECK-LABEL: mul_one_bf16_fold
+func @mul_one_bf16_fold(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
+  %0 = mhlo.constant dense<1.0> : tensor<2x2xbf16>
+  %1 = "mhlo.multiply"(%0, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
+  // CHECK: return %arg0 : tensor<2x2xbf16>
+  return %1 : tensor<2x2xbf16>
+}
+
+
 // CHECK-LABEL: divide_scalar_fold
 func @divide_scalar_fold() -> tensor<4xi64> {
   %0 = mhlo.constant dense<7> : tensor<4xi64>
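These tests drive the new folds through the canonicalizer, which invokes each op's fold() hook and, when folding succeeds, replaces the op's result with the returned value. The file's RUN line lies above the hunks shown here; assuming the usual setup for this test file (an assumption, not part of the diff), it looks like:

    // RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s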
