This change adds the option `canonicalizing_inf_as_min_max_float`, which canonicalizes boundary values in the IR by replacing -Inf/+Inf with the MIN/MAX finite float values. A user-facing converter flag controls whether the pass is enabled.

PiperOrigin-RevId: 630372176
tensorflower-gardener committed May 24, 2024
1 parent f8fdab0 commit 9933d53
Showing 19 changed files with 338 additions and 13 deletions.
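At its core the change is one rewrite rule over float constants. A minimal sketch of the scalar rule, using only llvm::APFloat (the helper name here is hypothetical; the real pattern lives in canonicalize_boundary_value_pass.cc below):

#include "llvm/ADT/APFloat.h"

// Hypothetical illustration of the rewrite rule: +/-Inf becomes the
// largest finite value with the same float semantics and sign; finite
// values (and NaN) pass through unchanged, matching the pass's
// isInfinity() check.
llvm::APFloat CanonicalizeBoundary(const llvm::APFloat &val) {
  if (!val.isInfinity()) return val;
  return llvm::APFloat::getLargest(val.getSemantics(), val.isNegative());
}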
27 changes: 27 additions & 0 deletions tensorflow/compiler/mlir/lite/BUILD
@@ -644,6 +644,32 @@ tf_cc_test(
],
)

cc_library(
name = "canonicalize_boundary_value",
srcs = [
"transforms/canonicalize_boundary_value_pass.cc",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":tensorflow_lite",
":tensorflow_lite_passes_inc_gen",
"//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@stablehlo//:stablehlo_ops",
],
)

cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
@@ -1365,6 +1391,7 @@ cc_library(
"tf_tfl_passes.h",
],
deps = [
":canonicalize_boundary_value",
":common",
":fake_quant_utils",
":tensorflow_lite_d2s",
4 changes: 4 additions & 0 deletions tensorflow/compiler/mlir/lite/common/tfl_pass_config.h
@@ -102,6 +102,10 @@ struct PassConfig {

// Enables the attempt to directly lower composites into tflite ops.
bool enable_composite_direct_lowering = true;

// When set to true, replaces +Inf/-Inf in float constants with the MAX/MIN
// float value, so the converted output contains only finite values.
bool canonicalizing_inf_as_min_max_float = true;
};

inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
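A hypothetical usage sketch (not part of this commit) of setting the new field; the PassConfig constructor argument is assumed from the surrounding converter code rather than shown in this diff:

// Hypothetical sketch: opt in to boundary-value canonicalization before
// building the converter pass pipeline. QuantizationSpecs/PassConfig
// construction is an assumption, not taken from this diff.
mlir::quant::QuantizationSpecs quant_specs;
mlir::TFL::PassConfig pass_config(quant_specs);
pass_config.canonicalizing_inf_as_min_max_float = true;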
@@ -108,6 +108,8 @@ absl::Status ConvertGraphDefToTFLiteFlatBuffer(
pass_config.guarantee_all_funcs_one_use =
toco_flags.guarantee_all_funcs_one_use();
pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo();
pass_config.canonicalizing_inf_as_min_max_float =
toco_flags.canonicalizing_inf_as_min_max_float();

// StableHLO Quantizer is not supported for GraphDef inputs, so
// quantization_py_function_lib is set to nullptr.
@@ -212,6 +212,8 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
pass_config.enable_stablehlo_quantizer = toco_flags.has_quantization_config();
pass_config.enable_composite_direct_lowering =
toco_flags.enable_composite_direct_lowering();
pass_config.canonicalizing_inf_as_min_max_float =
toco_flags.canonicalizing_inf_as_min_max_float();

if (toco_flags.qdq_conversion_mode() == "STATIC") {
pass_config.quant_specs.qdq_conversion_mode =
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/BUILD
@@ -203,6 +203,7 @@ cc_library(
":tf_stablehlo",
":unfold_splat_constant_pass",
":unfuse_batch_norm_pass",
"//tensorflow/compiler/mlir/lite:canonicalize_boundary_value",
"//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes",
"//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes",
@@ -814,6 +815,7 @@ tf_cc_binary(
"//tensorflow/cc/saved_model:loader",
"//tensorflow/compiler/mlir:init_mlir",
"//tensorflow/compiler/mlir:passes",
"//tensorflow/compiler/mlir/lite:canonicalize_boundary_value",
"//tensorflow/compiler/mlir/lite:flatbuffer_export",
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess",
18 changes: 6 additions & 12 deletions tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc
@@ -69,6 +69,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/utils.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
@@ -1998,7 +1999,7 @@ class ConvertReduceOpToTfMax
if (mlir::isa<FloatType>(type)) {
APFloat const_value(.0);
if (failed(GetConstantSplatValue(init_value, const_value)) ||
!const_value.isInfinity() || !const_value.isNegative())
!TFL::IsMinFloatValue(const_value))
return failure();
} else if (mlir::isa<IntegerType>(type) && type.isSignlessInteger()) {
APInt const_value;
@@ -2023,7 +2024,7 @@ class ConvertReduceOpToTfMin
if (mlir::isa<FloatType>(type)) {
APFloat const_value(.0);
if (failed(GetConstantSplatValue(init_value, const_value)) ||
!const_value.isInfinity() || const_value.isNegative())
!TFL::IsMaxFloatValue(const_value))
return failure();
} else if (mlir::isa<IntegerType>(type) && type.isSignlessInteger()) {
APInt const_value;
@@ -2081,7 +2082,7 @@ class ConvertReduceOpToTfArgmax
return false;
if (mlir::isa<FloatType>(element_type)) {
auto value = *attr.value_begin<APFloat>();
return value.isNegative() && value.isInfinity();
return TFL::IsMinFloatValue(value);
} else if (element_type.isInteger(1)) {
auto value = *attr.value_begin<APInt>();
return value.isZero();
@@ -2105,7 +2106,7 @@ class ConvertReduceOpToTfArgmin
return false;
if (mlir::isa<FloatType>(element_type)) {
auto value = *attr.value_begin<APFloat>();
return !value.isNegative() && value.isInfinity();
return TFL::IsMaxFloatValue(value);
} else if (element_type.isInteger(1)) {
auto value = *attr.value_begin<APInt>();
return value.isZero();
@@ -2596,14 +2597,7 @@ class ConvertMaxPoolOp : public OpConversionPattern<mhlo::ReduceWindowOp> {
}

APFloat element = float_value.getValues<APFloat>()[0];
if (!element.isInfinity()) {
return false;
}
if (!element.isNegative()) {
return false;
}

return true;
return TFL::IsMinFloatValue(element);
}

LogicalResult replaceWithMaxPool(mhlo::ReduceWindowOp op, Value input,
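The hunks above replace explicit isInfinity()/isNegative() checks with TFL::IsMinFloatValue/TFL::IsMaxFloatValue from lite/utils/utils.h. Their implementations are not part of this diff; the call sites imply they must accept both the original infinity and its canonicalized largest-finite replacement, roughly:

#include "llvm/ADT/APFloat.h"

// Assumed semantics only -- the real helpers live in lite/utils/utils.h.
inline bool IsMinFloatValueSketch(const llvm::APFloat &value) {
  return (value.isInfinity() && value.isNegative()) ||
         value.bitwiseIsEqual(llvm::APFloat::getLargest(
             value.getSemantics(), /*Negative=*/true));
}

inline bool IsMaxFloatValueSketch(const llvm::APFloat &value) {
  return (value.isInfinity() && !value.isNegative()) ||
         value.bitwiseIsEqual(llvm::APFloat::getLargest(
             value.getSemantics(), /*Negative=*/false));
}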
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
@@ -0,0 +1,91 @@
// RUN: tf-opt %s --canonicalize-boundary-value --split-input-file | FileCheck %s

// CHECK-LABEL: func.func @clamp_neg_inf_f32() -> tensor<f32> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor<f32>
// CHECK: return %[[CONST]] : tensor<f32>

func.func @clamp_neg_inf_f32() -> tensor<f32> {
%ret = stablehlo.constant dense<0xFF800000> : tensor<f32>
return %ret : tensor<f32>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f32() -> tensor<f32> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor<f32>
// CHECK: return %[[CONST]] : tensor<f32>
func.func @clamp_pos_inf_f32() -> tensor<f32> {
%ret = stablehlo.constant dense<0x7F800000> : tensor<f32>
return %ret : tensor<f32>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f32_tensor() -> tensor<1x4xf32> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<{{\[\[}}3.40282347E+38, 1.000000e+01, 2.000000e+01, -3.40282347E+38]]> : tensor<1x4xf32>
// CHECK: return %[[CONST]] : tensor<1x4xf32>
func.func @clamp_pos_inf_f32_tensor() -> tensor<1x4xf32> {
%ret = stablehlo.constant dense<[[0x7F800000, 10.0, 20.0, 0xFF800000]]> : tensor<1x4xf32>
return %ret : tensor<1x4xf32>
}

// -----

// CHECK-LABEL: func.func @clamp_neg_inf_f16() -> tensor<f16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-6.550400e+04> : tensor<f16>
// CHECK: return %[[CONST]] : tensor<f16>
func.func @clamp_neg_inf_f16() -> tensor<f16> {
%ret = stablehlo.constant dense<0xFC00> : tensor<f16>
return %ret : tensor<f16>
}

// -----

// CHECK-LABEL: func.func @clamp_neg_inf_bf16() -> tensor<bf16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<-1.038460e+34> : tensor<bf16>
// CHECK: return %[[CONST]] : tensor<bf16>
func.func @clamp_neg_inf_bf16() -> tensor<bf16> {
%ret = stablehlo.constant dense<0xF800> : tensor<bf16>
return %ret : tensor<bf16>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16() -> tensor<f16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<6.550400e+04> : tensor<f16>
// CHECK: return %[[CONST]] : tensor<f16>
func.func @clamp_pos_inf_f16() -> tensor<f16> {
%ret = stablehlo.constant dense<0x7C00> : tensor<f16>
return %ret : tensor<f16>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16_tensor() -> tensor<1x4xf16> {
// CHECK: %[[CONST:.*]] = stablehlo.constant dense<{{\[\[}}6.550400e+04, 1.000000e+01, 2.000000e+01, -6.550400e+04]]> : tensor<1x4xf16>
// CHECK: return %[[CONST]] : tensor<1x4xf16>
func.func @clamp_pos_inf_f16_tensor() -> tensor<1x4xf16> {
%ret = stablehlo.constant dense<[[0x7C00, 10.0, 20.0, 0xFC00]]> : tensor<1x4xf16>
return %ret : tensor<1x4xf16>
}

// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16_tensor_tf_const() -> tensor<3xf16> {
// CHECK: %[[CONST:.*]] = "tf.Const"() <{value = dense<6.550400e+04> : tensor<3xf16>}> : () -> tensor<3xf16>
// CHECK: return %[[CONST]] : tensor<3xf16>
func.func @clamp_pos_inf_f16_tensor_tf_const() -> tensor<3xf16> {
%ret = "tf.Const"() <{value = dense<0x7C00> : tensor<3xf16>}> : () -> tensor<3xf16>
return %ret : tensor<3xf16>
}


// -----

// CHECK-LABEL: func.func @clamp_pos_inf_f16_arith_op() -> f16 {
// CHECK: %[[CONST:.*]] = arith.constant 6.550400e+04 : f16
// CHECK: return %[[CONST]] : f16
func.func @clamp_pos_inf_f16_arith_op() -> f16 {
%ret = arith.constant 0x7C00 : f16
return %ret : f16
}
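
The CHECK constants above are APFloat's largest finite values for each type. A standalone sanity check (assumption: built against LLVM support libraries only) that reproduces the f32 and f16 constants:

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"

// Prints the largest finite f32 and f16 values, which should match the
// 3.40282347E+38 and 6.550400e+04 constants expected above (exact
// formatting may differ from MLIR's printer).
int main() {
  for (const llvm::fltSemantics *sem :
       {&llvm::APFloat::IEEEsingle(), &llvm::APFloat::IEEEhalf()}) {
    llvm::SmallString<32> str;
    llvm::APFloat::getLargest(*sem).toString(str);
    llvm::outs() << str << "\n";
  }
  return 0;
}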

2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/tf_tfl_passes.cc
@@ -528,6 +528,8 @@ void AddPostVariableFreezingTFToTFLConversionPasses(
pass_manager->addNestedPass<mlir::func::FuncOp>(
mlir::TFL::CreateRuntimeVerifyPass());
}
if (pass_config.canonicalizing_inf_as_min_max_float)
pass_manager->addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());
}

void AddTFToTFLConversionPasses(llvm::StringRef saved_model_dir,
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -498,6 +498,8 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer(
MetadataForReducedPrecisionSupport(quant_specs.support_mask));
}
pass_manager.clear();
if (pass_config.canonicalizing_inf_as_min_max_float)
pass_manager.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());
pass_manager.addPass(mlir::odml::createLegalizeStablehloToVhloPass());
if (failed(pass_manager.run(module))) {
return status_handler.Combine(
120 changes: 120 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.cc
@@ -0,0 +1,120 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {
namespace {

#define DEBUG_TYPE "canonicalize-boundary-value"

#define GEN_PASS_DEF_CANONICALIZEBOUNDARYVALUEPASS
#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"

class CanonicalizeBoundaryValuePass
: public impl::CanonicalizeBoundaryValuePassBase<
CanonicalizeBoundaryValuePass> {
void runOnOperation() override;
};

// Clamp constant -Inf/Inf to MIN/MAX float value.
template <typename OpTy>
struct ClampInfToMinMaxFloat : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(OpTy const_op,
PatternRewriter& rewriter) const override {
Attribute attr = const_op.getValueAttr();
if (auto float_attr = llvm::dyn_cast<FloatAttr>(attr)) {
if (float_attr.getValue().isInfinity()) {
FloatType float_type = llvm::dyn_cast<FloatType>(const_op.getType());
if (!float_type) return failure();
rewriter.replaceOpWithNewOp<OpTy>(
const_op, rewriter.getFloatAttr(
float_type, APFloat::getLargest(
float_type.getFloatSemantics(),
float_attr.getValue().isNegative())));
return success();
}
}

ElementsAttr tensor_attr = llvm::dyn_cast<ElementsAttr>(attr);
if (!tensor_attr) return failure();

Type type = tensor_attr.getType();
ShapedType tensor_type = llvm::cast<ShapedType>(type);
auto float_type = dyn_cast<FloatType>(tensor_type.getElementType());
if (!float_type) return failure();

auto vals_orig = tensor_attr.getValues<APFloat>();
// If all values are finite, no need to rewrite.
if (llvm::all_of(vals_orig, [&](APFloat val) { return !val.isInfinity(); }))
return failure();

SmallVector<APFloat> vals_new(llvm::map_range(vals_orig, [&](APFloat val) {
return val.isInfinity()
? APFloat::getLargest(float_type.getFloatSemantics(),
val.isNegative())
: val;
}));
rewriter.replaceOpWithNewOp<OpTy>(
const_op, DenseElementsAttr::get(tensor_type, vals_new));
return success();
}
};

void CanonicalizeBoundaryValuePass::runOnOperation() {
auto* ctx = &getContext();

RewritePatternSet patterns(ctx);
patterns.add<ClampInfToMinMaxFloat<stablehlo::ConstantOp>,
ClampInfToMinMaxFloat<TF::ConstOp>,
ClampInfToMinMaxFloat<arith::ConstantOp>>(ctx);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}

} // end namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateCanonicalizeBoundaryValuePass() {
return std::make_unique<CanonicalizeBoundaryValuePass>();
}

} // end namespace TFL
} // end namespace mlir
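A minimal sketch of running just this pass on a loaded module, mirroring the tf_tfl_passes.cc wiring above (the wrapper function is hypothetical):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Run only the boundary-value canonicalization pass on `module`.
void RunCanonicalizeBoundaryValue(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass());
  if (mlir::failed(pm.run(module))) {
    // Handle pass failure, e.g. emit a diagnostic.
  }
}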
3 changes: 3 additions & 0 deletions tensorflow/compiler/mlir/lite/transforms/passes.h
@@ -237,6 +237,9 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateReduceTypePrecisionPass();
// so redundant ones may be grouped and removed.
std::unique_ptr<OperationPass<ModuleOp>> CreatePushTransposeThroughEwisePass();

// Creates a pass that canonicalizes boundary values.
std::unique_ptr<OperationPass<ModuleOp>> CreateCanonicalizeBoundaryValuePass();

// Creates a pass that brings operations into the same order as graph_info.cc.
std::unique_ptr<OperationPass<func::FuncOp>>
CreatePartitionedTopologicalSortPass();