Skip to content

Commit

Permalink
Fixing per axis quantization bug in flatbuffer importer.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 281367969
Change-Id: Iade9732ca349b81ddd7601717e6c558dfd32c723
  • Loading branch information
Abdurrahman Akkas authored and tensorflower-gardener committed Nov 20, 2019
1 parent 0a793b9 commit be937a3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tensorflow/compiler/mlir/lite/flatbuffer_import.cc
Expand Up @@ -155,7 +155,8 @@ StatusOr<QuantizedType> GetQuantizedType(const TensorT& tensor, Builder builder,
uint32_t flags =
is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0;

if (0 != quant_params.quantized_dimension) {
// Scale size can't be zero as it is checked before.
if (quant_params.scale.size() != 1) {
llvm::SmallVector<double, 4> scales(quant_params.scale.begin(),
quant_params.scale.end());
return mlir::quant::UniformQuantizedPerAxisType::get(
Expand Down
Expand Up @@ -75,6 +75,13 @@ func @qi32_per_axis() -> tensor<3x3x!quant.uniform<i32:f32:1, {1.0, 0.5:1, 0.25:
return %0 : tensor<3x3x!quant.uniform<i32:f32:1, {1.0, 0.5:1, 0.25:1}>>
}

func @qi32_per_axis_zero() -> tensor<3x3x!quant.uniform<i32:f32:0, {1.0, 0.5:1, 0.25:1}>> {
// CHECK-LABEL: @qi32_per_axis_zero
// CHECK: {qtype = tensor<3x3x!quant.uniform<i32:f32:0, {1.000000e+00,5.000000e-01:1,2.500000e-01:1}>>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform<i32:f32:0, {1.000000e+00,5.000000e-01:1,2.500000e-01:1}>>
%0 = "tfl.pseudo_qconst"() { qtype = tensor<3x3x!quant.uniform<i32:f32:0, {1.0, 0.5:1, 0.25:1}>>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform<i32:f32:0, {1.0, 0.5:1, 0.25:1}>>
return %0 : tensor<3x3x!quant.uniform<i32:f32:0, {1.0, 0.5:1, 0.25:1}>>
}

func @qu8() -> tensor<3x!quant.uniform<u8<1:255>:f32, 1.0>> {
// CHECK-LABEL: @qu8
// CHECK: {qtype = tensor<3x!quant.uniform<u8<1:255>:f32, 1.000000e+00>>, value = dense<1> : tensor<3xi8>} : () -> tensor<3x!quant.uniform<u8<1:255>:f32, 1.000000e+00>>
Expand Down

0 comments on commit be937a3

Please sign in to comment.