Add folders for tfl.maximum/tfl.minimum with +-FLT_MAX arg
This commit adds folders for the ops `tfl.maximum` and `tfl.minimum` for the case where one of the arguments is a splat constant equal to `-FLT_MAX` or `FLT_MAX`, respectively.
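For example (mirroring the tests added to const-fold.mlir below), a `tfl.maximum` whose second operand is a splat of the most negative finite f32 value now folds to its other operand:

    %neg_f32_max = arith.constant dense<-3.40282347E+38> : tensor<f32>
    %0 = "tfl.maximum"(%arg0, %neg_f32_max) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    // %0 folds to %arg0, since max(x, -FLT_MAX) == x for any finite x.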

PiperOrigin-RevId: 625494071
tensorflower-gardener committed Apr 16, 2024
1 parent 3f9e4d2 commit 0e0bca7
Showing 3 changed files with 96 additions and 0 deletions.
52 changes: 52 additions & 0 deletions tensorflow/compiler/mlir/lite/ir/tfl_ops.cc
@@ -578,6 +578,14 @@ inline bool IsBF16ShapedType(Type t) {
  return false;
}

// Returns true if `t` is a shaped type with FloatType elements.
inline bool IsFloatShapedType(Type t) {
  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
    return shaped_type.getElementType().isa<FloatType>();
  }
  return false;
}

// Returns new shape with rank 'new_dims' with padded ones on the
// left if needed.
inline std::vector<int64_t> GetPaddedShape(ArrayRef<int64_t> old_shape,
@@ -3069,6 +3077,50 @@ OpFoldResult SquareOp::fold(FoldAdaptor adaptor) {
  return ConstFoldUnaryOp(result_type, operands[0], compute);
}

//===----------------------------------------------------------------------===//
// MaximumOp
//===----------------------------------------------------------------------===//

OpFoldResult MaximumOp::fold(FoldAdaptor adaptor) {
  auto lhs_type = getLhs().getType();
  auto rhs_type = getRhs().getType();
  // Constant folding is implemented only for float tensors of the same type.
  if (lhs_type != rhs_type || !IsFloatShapedType(lhs_type)) return nullptr;

  auto lhs = adaptor.getLhs().dyn_cast_or_null<DenseElementsAttr>();
  auto rhs = adaptor.getRhs().dyn_cast_or_null<DenseElementsAttr>();
  // max(x, -FLT_MAX) == x for any finite x, so a splat of the most negative
  // finite value folds to the other operand.
  if (lhs && lhs.isSplat()) {
    APFloat lhs_value = lhs.getSplatValue<APFloat>();
    lhs_value.changeSign();
    // APFloat::isLargest() only checks the finite magnitude, so verify the
    // sign separately to avoid also folding max(x, +FLT_MAX).
    if (!lhs_value.isNegative() && lhs_value.isLargest()) return getRhs();
  }
  if (rhs && rhs.isSplat()) {
    APFloat rhs_value = rhs.getSplatValue<APFloat>();
    rhs_value.changeSign();
    if (!rhs_value.isNegative() && rhs_value.isLargest()) return getLhs();
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// MinimumOp
//===----------------------------------------------------------------------===//

OpFoldResult MinimumOp::fold(FoldAdaptor adaptor) {
  auto lhs_type = getLhs().getType();
  auto rhs_type = getRhs().getType();
  // Constant folding is implemented only for float tensors of the same type.
  if (lhs_type != rhs_type || !IsFloatShapedType(lhs_type)) return nullptr;

  auto lhs = adaptor.getLhs().dyn_cast_or_null<DenseElementsAttr>();
  auto rhs = adaptor.getRhs().dyn_cast_or_null<DenseElementsAttr>();
  // min(x, FLT_MAX) == x for any finite x, so a splat of the largest finite
  // value folds to the other operand. As above, isLargest() only checks the
  // magnitude, so verify the sign to avoid also folding min(x, -FLT_MAX).
  if (lhs && lhs.isSplat()) {
    const APFloat lhs_value = lhs.getSplatValue<APFloat>();
    if (!lhs_value.isNegative() && lhs_value.isLargest()) return getRhs();
  }
  if (rhs && rhs.isSplat()) {
    const APFloat rhs_value = rhs.getSplatValue<APFloat>();
    if (!rhs_value.isNegative() && rhs_value.isLargest()) return getLhs();
  }
  return nullptr;
}

//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
4 changes: 4 additions & 0 deletions tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -2269,6 +2269,8 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$max
  );

  let hasFolder = 1;

  let builders = [TFL_BroadcastableBinaryBuilder];

  let hasOptions = 0;
@@ -2528,6 +2530,8 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
    TFL_TensorOf<[F32, TFL_Int32Or64, QI8, QUI8, QI16]>:$min
  );

  let hasFolder = 1;

  let builders = [TFL_BroadcastableBinaryBuilder];

  let hasOptions = 0;
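In the ODS definitions above, `let hasFolder = 1;` declares the `OpFoldResult fold(FoldAdaptor adaptor)` hook on each generated op class; the implementations added to tfl_ops.cc above provide it.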
40 changes: 40 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/const-fold.mlir
@@ -181,6 +181,46 @@ func.func @elementwise_unary_ops() -> (tensor<f32>, tensor<f32>, tensor<f32>, te
  func.return %7, %8, %9, %10, %11, %12, %13 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>
}

// CHECK-LABEL: @max_with_neg_f32_max_val
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>)
func.func @max_with_neg_f32_max_val(%arg0 : tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %neg_f32_max = arith.constant dense<-3.40282347E+38> : tensor<f32>
  %0 = "tfl.maximum"(%arg0, %neg_f32_max) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %1 = "tfl.maximum"(%neg_f32_max, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  func.return %0, %1 : tensor<f32>, tensor<f32>
  // CHECK: return %[[ARG0]], %[[ARG0]]
}

// CHECK-LABEL: @min_with_f32_max_val
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>)
func.func @min_with_f32_max_val(%arg0 : tensor<f32>) -> (tensor<f32>, tensor<f32>) {
  %f32_max = arith.constant dense<3.40282347E+38> : tensor<f32>
  %0 = "tfl.minimum"(%arg0, %f32_max) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %1 = "tfl.minimum"(%f32_max, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  func.return %0, %1 : tensor<f32>, tensor<f32>
  // CHECK: return %[[ARG0]], %[[ARG0]]
}

// CHECK-LABEL: @max_with_neg_f64_max_val
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f64>)
func.func @max_with_neg_f64_max_val(%arg0 : tensor<f64>) -> (tensor<f64>, tensor<f64>) {
  %neg_f64_max = arith.constant dense<-1.7976931348623157E+308> : tensor<f64>
  %0 = "tfl.maximum"(%arg0, %neg_f64_max) : (tensor<f64>, tensor<f64>) -> tensor<f64>
  %1 = "tfl.maximum"(%neg_f64_max, %arg0) : (tensor<f64>, tensor<f64>) -> tensor<f64>
  func.return %0, %1 : tensor<f64>, tensor<f64>
  // CHECK: return %[[ARG0]], %[[ARG0]]
}

// CHECK-LABEL: @min_with_f64_max_val
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f64>)
func.func @min_with_f64_max_val(%arg0 : tensor<f64>) -> (tensor<f64>, tensor<f64>) {
  %f64_max = arith.constant dense<1.7976931348623157E+308> : tensor<f64>
  %0 = "tfl.minimum"(%arg0, %f64_max) : (tensor<f64>, tensor<f64>) -> tensor<f64>
  %1 = "tfl.minimum"(%f64_max, %arg0) : (tensor<f64>, tensor<f64>) -> tensor<f64>
  func.return %0, %1 : tensor<f64>, tensor<f64>
  // CHECK: return %[[ARG0]], %[[ARG0]]
}

// CHECK-LABEL: @mul_int
func.func @mul_int() -> (tensor<i32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
  %0 = arith.constant dense<8> : tensor<i32>
