
Commit 47e6ed5

LukeBoyer authored and tensorflower-gardener committed
Add float -> int and int -> float cast folding.
PiperOrigin-RevId: 663625013
1 parent a01c67c

File tree

2 files changed: +140 -31 lines changed

tensorflow/compiler/mlir/lite/ir/tfl_ops.cc

Lines changed: 104 additions & 31 deletions
@@ -34,9 +34,11 @@ limitations under the License.
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Casting.h"
@@ -3248,35 +3250,12 @@ void ConstOp::getCanonicalizationPatterns(RewritePatternSet& results,
 // CastOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
-  auto operands = adaptor.getOperands();
-  assert(operands.size() == 1);
-  if (getInput().getType() == getType()) {
-    return getInput();
-  }
-
-  // For now, only supports cast between integer types.
-  auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
-  if (!elements_attr) {
-    return nullptr;
-  }
-
-  auto result_element_type =
-      getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
-  auto operand_element_type = getInput()
-                                  .getType()
-                                  .cast<ShapedType>()
-                                  .getElementType()
-                                  .dyn_cast<IntegerType>();
-  // Returns nullptr if either result/operand element type is not integer.
-  if (!result_element_type || !operand_element_type) {
-    return nullptr;
-  }
-
-  const bool is_unsigned = operand_element_type.isUnsigned();
-  const bool involves_bool = operand_element_type.getWidth() == 1 ||
-                             result_element_type.getWidth() == 1;
-  const int output_bitwidth = result_element_type.getWidth();
+OpFoldResult CastIntToInt(DenseIntElementsAttr data, IntegerType in_type,
+                          IntegerType out_type) {
+  const bool is_unsigned = in_type.isUnsigned();
+  const bool involves_bool =
+      in_type.getWidth() == 1 || out_type.getWidth() == 1;
+  const int output_bitwidth = out_type.getWidth();
   // The integer cast op is the same as C integer cast. Depends on the operand
   // type's signedness, we will determine whether or not sign extension is
   // needed.
@@ -3287,13 +3266,107 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
       // true input should always be cast to 1 and not -1 as the sign extension
       // would do for signed outputs. Similarly, non-zero inputs should be cast
       // to true. Truncating even numbers to one bit will result in `false`.
-      return APInt(result_element_type.getWidth(), value != 0);
+      return APInt(out_type.getWidth(), value != 0);
     }
     return is_unsigned ? value.zextOrTrunc(output_bitwidth)
                        : value.sextOrTrunc(output_bitwidth);
   };
 
-  return elements_attr.mapValues(result_element_type, cast);
+  return data.mapValues(out_type, cast);
+}
+
+OpFoldResult CastFloatToInt(DenseFPElementsAttr data, FloatType in_type,
+                            IntegerType out_type) {
+  const bool from_f32 = in_type.isF32();
+  const bool to_i32 = out_type.isSignlessInteger(32);
+  if (!from_f32 || !to_i32) {
+    return {};
+  }
+
+  auto cast = [&](APFloat value) -> APInt {
+    APSInt result(32, false);
+    bool is_exact;
+    value.convertToInteger(result, llvm::RoundingMode::TowardZero, &is_exact);
+    return result;
+  };
+
+  return data.mapValues(out_type, cast);
+}
+
+template <typename InType, typename OutType>
+llvm::SmallVector<OutType> MapStaticCast(DenseElementsAttr data) {
+  return llvm::map_to_vector(data.getValues<InType>(),
+                             [](InType v) { return static_cast<OutType>(v); });
+}
+
+OpFoldResult CastIntToFloat(DenseIntElementsAttr data, IntegerType in_type,
+                            FloatType out_type) {
+  const bool from_i32 = in_type.isSignlessInteger(32);
+  const bool to_f32 = out_type.isF32();
+  if (!from_i32 || !to_f32) {
+    return {};
+  }
+
+  return DenseFPElementsAttr::get(data.getType().clone(out_type),
+                                  MapStaticCast<int32_t, float>(data));
+}
+
+OpFoldResult CastFloatToFloat(DenseFPElementsAttr data, FloatType in_type,
+                              FloatType out_type) {
+  auto result_type = data.getType().clone(out_type);
+  if (in_type.isF32() && out_type.isF64()) {
+    return DenseFPElementsAttr::get(result_type,
+                                    MapStaticCast<float, double>(data));
+  }
+
+  if (in_type.isF64() && out_type.isF32()) {
+    return DenseFPElementsAttr::get(result_type,
+                                    MapStaticCast<double, float>(data));
+  }
+  return {};
+}
+
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
+  auto operands = adaptor.getOperands();
+  if (operands.size() != 1) {
+    return {};
+  }
+  if (getInput().getType() == getType()) {
+    return getInput();
+  }
+
+  auto input = operands[0];
+
+  auto in_type = getInput().getType().getElementType();
+  auto out_type = getType().getElementType();
+
+  if (auto int_in_type = llvm::dyn_cast_or_null<IntegerType>(in_type)) {
+    auto in_data = llvm::dyn_cast_or_null<DenseIntElementsAttr>(input);
+    if (!in_data) {
+      return {};
+    }
+    if (auto float_out_type = llvm::dyn_cast_or_null<FloatType>(out_type)) {
+      return CastIntToFloat(in_data, int_in_type, float_out_type);
+    }
+    if (auto int_out_type = llvm::dyn_cast_or_null<IntegerType>(out_type)) {
+      return CastIntToInt(in_data, int_in_type, int_out_type);
+    }
+  }
+
+  if (auto float_in_type = llvm::dyn_cast_or_null<FloatType>(in_type)) {
+    auto in_data = llvm::dyn_cast_or_null<DenseFPElementsAttr>(input);
+    if (!in_data) {
+      return {};
+    }
+    if (auto float_out_type = llvm::dyn_cast_or_null<FloatType>(out_type)) {
+      return CastFloatToFloat(in_data, float_in_type, float_out_type);
+    }
+    if (auto int_out_type = llvm::dyn_cast_or_null<IntegerType>(out_type)) {
+      return CastFloatToInt(in_data, float_in_type, int_out_type);
+    }
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
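
Note on the float -> int path: APFloat::convertToInteger truncates toward zero, and LLVM saturates out-of-range inputs rather than wrapping, which is why the f32 tests below expect 2147483647 and -2147483648 for +/-FLT_MAX. A minimal sketch of that per-element conversion, assuming an LLVM build environment (the main() driver is illustrative and not part of this commit):

  #include "llvm/ADT/APFloat.h"
  #include "llvm/ADT/APSInt.h"
  #include "llvm/ADT/FloatingPointMode.h"
  #include "llvm/Support/raw_ostream.h"

  int main() {
    // Same conversion the new CastFloatToInt helper applies per element.
    llvm::APFloat value(3.402823466e+38f);  // ~FLT_MAX, out of i32 range.
    llvm::APSInt result(/*BitWidth=*/32, /*isUnsigned=*/false);
    bool is_exact;
    // Truncates toward zero; out-of-range values saturate to INT32_MAX/MIN.
    value.convertToInteger(result, llvm::RoundingMode::TowardZero, &is_exact);
    llvm::outs() << result.getSExtValue() << "\n";  // Prints 2147483647.
    return 0;
  }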

tensorflow/compiler/mlir/lite/tests/const-fold.mlir

Lines changed: 36 additions & 0 deletions
@@ -779,6 +779,42 @@ func.func @cast_ui8_to_i1() -> tensor<4xi1> {
   // CHECK: return %[[CST]]
 }
 
+// CHECK-LABEL: @cast_f32_to_i32
+func.func @cast_f32_to_i32() -> tensor<8xi32> {
+  %cst = arith.constant dense<[-1.0, 0.0, 1.5, 0.99, 1.175494351e-38, 3.402823466e+38, -3.402823466e+38, -1.175494351e-38]> : tensor<8xf32>
+  %0 = "tfl.cast"(%cst) : (tensor<8xf32>) -> tensor<8xi32>
+  func.return %0 : tensor<8xi32>
+}
+
+// CHECK: %cst = arith.constant dense<[-1, 0, 1, 0, 0, 2147483647, -2147483648, 0]> : tensor<8xi32>
+
+// CHECK-LABEL: @cast_i32_to_f32
+func.func @cast_i32_to_f32() -> tensor<5xf32> {
+  %cst = arith.constant dense<[-1, 0, 2, 2147483647, -2147483648]> : tensor<5xi32>
+  %0 = "tfl.cast"(%cst) : (tensor<5xi32>) -> tensor<5xf32>
+  func.return %0 : tensor<5xf32>
+}
+
+// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 2.000000e+00, 2.14748365E+9, -2.14748365E+9]> : tensor<5xf32>
+
+// CHECK-LABEL: @cast_f64_to_f32
+func.func @cast_f64_to_f32() -> tensor<4xf32> {
+  %cst = arith.constant dense<[-1.0, 0.0, 1.5, 100.0]> : tensor<4xf64>
+  %0 = "tfl.cast"(%cst) : (tensor<4xf64>) -> tensor<4xf32>
+  func.return %0 : tensor<4xf32>
+}
+
+// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.500000e+00, 1.000000e+02]> : tensor<4xf32>
+
+// CHECK-LABEL: @cast_f32_to_f64
+func.func @cast_f32_to_f64() -> tensor<4xf64> {
+  %cst = arith.constant dense<[-1.0, 0.0, 1.5, 100.0]> : tensor<4xf32>
+  %0 = "tfl.cast"(%cst) : (tensor<4xf32>) -> tensor<4xf64>
+  func.return %0 : tensor<4xf64>
+}
+
+// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.500000e+00, 1.000000e+02]> : tensor<4xf64>
+
 // CHECK-LABEL: @ConstantFoldFullyConnectedSmall
 func.func @ConstantFoldFullyConnectedSmall() -> tensor<3xf32> {
   %cst_input = arith.constant dense<[2.0, 3.0]> : tensor<2xf32>