@@ -34,9 +34,11 @@ limitations under the License.
3434#include " llvm/ADT/APFloat.h"
3535#include " llvm/ADT/APInt.h"
3636#include " llvm/ADT/ArrayRef.h"
37+ #include " llvm/ADT/FloatingPointMode.h"
3738#include " llvm/ADT/STLExtras.h"
3839#include " llvm/ADT/SetVector.h"
3940#include " llvm/ADT/SmallVector.h"
41+ #include " llvm/ADT/SmallVectorExtras.h"
4042#include " llvm/ADT/StringExtras.h"
4143#include " llvm/ADT/TypeSwitch.h"
4244#include " llvm/Support/Casting.h"
@@ -3248,35 +3250,12 @@ void ConstOp::getCanonicalizationPatterns(RewritePatternSet& results,
32483250// CastOp
32493251// ===----------------------------------------------------------------------===//
32503252
3251- OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
3252- auto operands = adaptor.getOperands ();
3253- assert (operands.size () == 1 );
3254- if (getInput ().getType () == getType ()) {
3255- return getInput ();
3256- }
3257-
3258- // For now, only supports cast between integer types.
3259- auto elements_attr = operands[0 ].dyn_cast_or_null <DenseIntElementsAttr>();
3260- if (!elements_attr) {
3261- return nullptr ;
3262- }
3263-
3264- auto result_element_type =
3265- getType ().cast <ShapedType>().getElementType ().dyn_cast <IntegerType>();
3266- auto operand_element_type = getInput ()
3267- .getType ()
3268- .cast <ShapedType>()
3269- .getElementType ()
3270- .dyn_cast <IntegerType>();
3271- // Returns nullptr if either result/operand element type is not integer.
3272- if (!result_element_type || !operand_element_type) {
3273- return nullptr ;
3274- }
3275-
3276- const bool is_unsigned = operand_element_type.isUnsigned ();
3277- const bool involves_bool = operand_element_type.getWidth () == 1 ||
3278- result_element_type.getWidth () == 1 ;
3279- const int output_bitwidth = result_element_type.getWidth ();
3253+ OpFoldResult CastIntToInt (DenseIntElementsAttr data, IntegerType in_type,
3254+ IntegerType out_type) {
3255+ const bool is_unsigned = in_type.isUnsigned ();
3256+ const bool involves_bool =
3257+ in_type.getWidth () == 1 || out_type.getWidth () == 1 ;
3258+ const int output_bitwidth = out_type.getWidth ();
32803259 // The integer cast op is the same as C integer cast. Depends on the operand
32813260 // type's signedness, we will determine whether or not sign extension is
32823261 // needed.
@@ -3287,13 +3266,107 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
32873266 // true input should always be cast to 1 and not -1 as the sign extension
32883267 // would do for signed outputs. Similarly, non-zero inputs should be cast
32893268 // to true. Truncating even numbers to one bit will result in `false`.
3290- return APInt (result_element_type .getWidth (), value != 0 );
3269+ return APInt (out_type .getWidth (), value != 0 );
32913270 }
32923271 return is_unsigned ? value.zextOrTrunc (output_bitwidth)
32933272 : value.sextOrTrunc (output_bitwidth);
32943273 };
32953274
3296- return elements_attr.mapValues (result_element_type, cast);
3275+ return data.mapValues (out_type, cast);
3276+ }
3277+
3278+ OpFoldResult CastFloatToInt (DenseFPElementsAttr data, FloatType in_type,
3279+ IntegerType out_type) {
3280+ const bool from_f32 = in_type.isF32 ();
3281+ const bool to_i32 = out_type.isSignlessInteger (32 );
3282+ if (!from_f32 || !to_i32) {
3283+ return {};
3284+ }
3285+
3286+ auto cast = [&](APFloat value) -> APInt {
3287+ APSInt result (32 , false );
3288+ bool is_exact;
3289+ value.convertToInteger (result, llvm::RoundingMode::TowardZero, &is_exact);
3290+ return result;
3291+ };
3292+
3293+ return data.mapValues (out_type, cast);
3294+ }
3295+
3296+ template <typename InType, typename OutType>
3297+ llvm::SmallVector<OutType> MapStaticCast (DenseElementsAttr data) {
3298+ return llvm::map_to_vector (data.getValues <InType>(),
3299+ [](InType v) { return static_cast <OutType>(v); });
3300+ }
3301+
3302+ OpFoldResult CastIntToFloat (DenseIntElementsAttr data, IntegerType in_type,
3303+ FloatType out_type) {
3304+ const bool from_i32 = in_type.isSignlessInteger (32 );
3305+ const bool to_f32 = out_type.isF32 ();
3306+ if (!from_i32 || !to_f32) {
3307+ return {};
3308+ }
3309+
3310+ return DenseFPElementsAttr::get (data.getType ().clone (out_type),
3311+ MapStaticCast<int32_t , float >(data));
3312+ }
3313+
3314+ OpFoldResult CastFloatToFloat (DenseFPElementsAttr data, FloatType in_type,
3315+ FloatType out_type) {
3316+ auto result_type = data.getType ().clone (out_type);
3317+ if (in_type.isF32 () && out_type.isF64 ()) {
3318+ return DenseFPElementsAttr::get (result_type,
3319+ MapStaticCast<float , double >(data));
3320+ }
3321+
3322+ if (in_type.isF64 () && out_type.isF32 ()) {
3323+ return DenseFPElementsAttr::get (result_type,
3324+ MapStaticCast<double , float >(data));
3325+ }
3326+ return {};
3327+ }
3328+
3329+ OpFoldResult CastOp::fold (FoldAdaptor adaptor) {
3330+ auto operands = adaptor.getOperands ();
3331+ if (operands.size () != 1 ) {
3332+ return {};
3333+ }
3334+ if (getInput ().getType () == getType ()) {
3335+ return getInput ();
3336+ }
3337+
3338+ auto input = operands[0 ];
3339+
3340+ auto in_type = getInput ().getType ().getElementType ();
3341+ auto out_type = getType ().getElementType ();
3342+
3343+ if (auto int_in_type = llvm::dyn_cast_or_null<IntegerType>(in_type)) {
3344+ auto in_data = llvm::dyn_cast_or_null<DenseIntElementsAttr>(input);
3345+ if (!in_data) {
3346+ return {};
3347+ }
3348+ if (auto float_out_type = llvm::dyn_cast_or_null<FloatType>(out_type)) {
3349+ return CastIntToFloat (in_data, int_in_type, float_out_type);
3350+ }
3351+ if (auto int_out_type = llvm::dyn_cast_or_null<IntegerType>(out_type)) {
3352+ return CastIntToInt (in_data, int_in_type, int_out_type);
3353+ }
3354+ }
3355+
3356+ if (auto float_in_type = llvm::dyn_cast_or_null<FloatType>(in_type)) {
3357+ auto in_data = llvm::dyn_cast_or_null<DenseFPElementsAttr>(input);
3358+ if (!in_data) {
3359+ return {};
3360+ }
3361+ if (auto float_out_type = llvm::dyn_cast_or_null<FloatType>(out_type)) {
3362+ return CastFloatToFloat (in_data, float_in_type, float_out_type);
3363+ }
3364+ if (auto int_out_type = llvm::dyn_cast_or_null<IntegerType>(out_type)) {
3365+ return CastFloatToInt (in_data, float_in_type, int_out_type);
3366+ }
3367+ }
3368+
3369+ return {};
32973370}
32983371
32993372// ===----------------------------------------------------------------------===//
0 commit comments