Skip to content

Commit

Permalink
Merge pull request #48410 from lgeiger:support-quantize-dequant-v4
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 367520045
Change-Id: I640cdaebaf8804020807698253e3238e9bfdbb08
  • Loading branch information
tensorflower-gardener committed Apr 8, 2021
2 parents 8c60112 + 01e4f05 commit d8fdfcb
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 2 deletions.
12 changes: 12 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2068,3 +2068,15 @@ func @all_i64axes(%arg0: tensor<8x16x16xi1>, %arg1: tensor<2xi64>) -> tensor<?xi
// CHECK: %[[V0:.*]] = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32>
// CHECK: "tfl.reduce_all"(%arg0, %[[V0]]) {keep_dims = false} : (tensor<8x16x16xi1>, tensor<2xi32>) -> tensor<?xi1>
}

func @quantize_dequantize_v4(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%cst = constant dense<0.0> : tensor<f32>
%cst_0 = constant dense<255.0> : tensor<f32>
%0 = "tf.QuantizeAndDequantizeV4"(%arg0, %cst, %cst_0) : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>

// CHECK-LABEL: quantize_dequantize_v4
// CHECK: %[[QUANT:.*]] = "tfl.quantize"(%arg0) {qtype = tensor<?x?x!quant.uniform<u8:f32, 1.000000e+00>>} : (tensor<?x?xf32>) -> tensor<?x?x!quant.uniform<u8:f32, 1.000000e+00>>
// CHECK: %[[DEQUANT:.*]] = "tfl.dequantize"(%[[QUANT]]) : (tensor<?x?x!quant.uniform<u8:f32, 1.000000e+00>>) -> tensor<?x?xf32>
// CHECK: return %[[DEQUANT]]
}
4 changes: 2 additions & 2 deletions tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,8 @@ def LegalizeFakeQuantWithMinMaxVars: Pat<

// TODO(rocky): Not all of the attributes are handled correctly. Make this
// more general if there is a need.
def LegalizeQuantizeAndDequantizeV2 : Pat<
(TF_QuantizeAndDequantizeV2Op $inputs, (ConstantOp F32ElementsAttr:$min),
def LegalizeQuantizeAndDequantizeV4 : Pat<
(TF_QuantizeAndDequantizeV4Op $inputs, (ConstantOp F32ElementsAttr:$min),
(ConstantOp F32ElementsAttr:$max),
$signed_input, $num_bits, $range_given, $round_mode, $narrow_range, $axis),
(TFL_DequantizeOp
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -9862,6 +9862,8 @@ tensor.}]>:$input_max,
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

let hasCanonicalizer = 1;
}

def TF_QuantizeAndDequantizeV3Op : TF_Op<"QuantizeAndDequantizeV3", [NoSideEffect]> {
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,15 @@ OpFoldResult PowOp::fold(ArrayRef<Attribute> operands) {
return {};
}

//===----------------------------------------------------------------------===//
// QuantizeAndDequantizeV2Op
//===----------------------------------------------------------------------===//

void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4>(context);
}

//===----------------------------------------------------------------------===//
// QrOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1695,3 +1695,11 @@ func @while_with_id_passthrough(%arg0: tensor<7xf32> {tf._user_specified_name =
%7 = "tf.Identity"(%6#2) {device = ""} : (tensor<?xf32>) -> tensor<?xf32>
return %7 : tensor<?xf32>
}

// CHECK-LABEL: testConvertQuantizeAndDequantizeV2ToQuantizeAndDequantizeV4
func @testConvertQuantizeAndDequantizeV2ToQuantizeAndDequantizeV4(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<?x?xf32> {
%0 = "tf.QuantizeAndDequantizeV2"(%arg0, %arg1, %arg2) : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
// CHECK: %[[QUANT:.*]] = "tf.QuantizeAndDequantizeV4"(%arg0, %arg1, %arg2) {axis = -1 : i64, narrow_range = false, num_bits = 8 : i64, range_given = false, round_mode = "HALF_TO_EVEN", signed_input = true} : (tensor<?x?xf32>, tensor<f32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: return %[[QUANT]] : tensor<?x?xf32>
}
10 changes: 10 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ def MatrixSetDiagV2ToV3 : Pat<(TF_MatrixSetDiagV2Op $input, $diag, $k),
(TF_MatrixSetDiagV3Op $input, $diag, $k,
(GetStrAttr<"LEFT_LEFT">))>;

//===----------------------------------------------------------------------===//
// QuantizeAndDequantizeV2 op patterns.
//===----------------------------------------------------------------------===//

def QuantizeAndDequantizeV2ToQuantizeAndDequantizeV4 : Pat<
(TF_QuantizeAndDequantizeV2Op $inputs, $min, $max, $signed_input, $num_bits,
$range_given, $round_mode, $narrow_range, $axis),
(TF_QuantizeAndDequantizeV4Op $inputs, $min, $max, $signed_input, $num_bits,
$range_given, $round_mode, $narrow_range, $axis)>;

//===----------------------------------------------------------------------===//
// RealDiv op patterns.
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit d8fdfcb

Please sign in to comment.