Skip to content

Commit

Permalink
Calculate min/max and histogram inside the CustomAggregator op
Browse files Browse the repository at this point in the history
This cl is the first part of removing the Calibration singleton:
- Part 1 (this cl): Move the min/max and histogram calculation inside CustomAggregator op.
- Part 2 (follow-up cl): Aggregate the statistics inside StatisticsSaver op and remove the singleton.

PiperOrigin-RevId: 622053138
  • Loading branch information
thaink authored and tensorflower-gardener committed Apr 5, 2024
1 parent 276dc80 commit a700cea
Show file tree
Hide file tree
Showing 28 changed files with 860 additions and 426 deletions.
Expand Up @@ -99,3 +99,20 @@ tf_cc_test(
"@local_tsl//tsl/platform:status_matchers",
],
)

# Header-only library (no srcs) exposing the helpers in calibration_parameters.h.
cc_library(
name = "calibration_parameters",
srcs = [],
hdrs = ["calibration_parameters.h"],
compatible_with = get_compatible_with_portable(),
# calibration_parameters.h includes quantization_config.pb.h; presumably this
# proto target is what provides that generated header.
deps = ["//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc"],
)

# Unit tests for the helpers declared in calibration_parameters.h.
tf_cc_test(
name = "calibration_parameters_test",
srcs = ["calibration_parameters_test.cc"],
deps = [
":calibration_parameters",
# The test source includes <gtest/gtest.h> and defines no main().
"@com_google_googletest//:gtest_main",
],
)
@@ -0,0 +1,78 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_

#include <algorithm>
#include <cmath>
#include <cstdint>

#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"

namespace stablehlo::quantization {

// Returns the histogram bin width for the range [`min_value`, `max_value`]
// when at most `num_bins` bins are requested. The raw width
// `(max_value - min_value) / num_bins` is rounded up to the nearest power of
// two (2^n), so the number of bins actually needed to cover the range may be
// smaller than `num_bins`.
inline float CalculateBinWidth(const float min_value, const float max_value,
                               const int32_t num_bins) {
  const float value_range = max_value - min_value;
  const float unrounded_width = value_range / num_bins;
  // Smallest integer exponent n such that 2^n >= unrounded_width.
  const float exponent = std::ceil(std::log2(unrounded_width));
  return std::pow(2, exponent);
}

// Returns the lower bound of the histogram: the largest multiple of
// `bin_width` (i.e. `N * bin_width` for an integer N) that does not exceed
// `min_value`.
inline float CalculateLowerBound(const float min_value, const float bin_width) {
  const float whole_bins_below_min = std::floor(min_value / bin_width);
  return whole_bins_below_min * bin_width;
}

// Returns how many bins of width `bin_width` are needed to cover the span from
// the histogram lower bound (see `CalculateLowerBound`) up to `max_value`.
inline int32_t CalculateActualNumBins(const float min_value,
                                      const float max_value,
                                      const float bin_width) {
  const float covered_range =
      max_value - CalculateLowerBound(min_value, bin_width);
  // A partially filled trailing bin still counts as a bin, hence the ceil.
  return static_cast<int32_t>(std::ceil(covered_range / bin_width));
}

// Returns the index of the bin that `value` falls into, counting bins of
// width `bin_width` upwards from `lower_bound`. The result is not clamped: it
// is negative for values below `lower_bound`.
inline int32_t CalculateBinIndex(const float value, const float lower_bound,
                                 const float bin_width) {
  const float offset_in_bins = (value - lower_bound) / bin_width;
  return static_cast<int32_t>(std::floor(offset_in_bins));
}

// Same as `CalculateBinIndex` but clamps the result to [0, `num_bins` - 1] to
// avoid out-of-bound indices.
//
// The clamping is done on the floating-point quotient *before* it is
// converted to an integer. Converting a float that is NaN, infinite or outside
// the `int32_t` range is undefined behavior, so clamping only after the
// conversion would not make extreme inputs (or a zero `bin_width`) safe.
//
// Returns 0 for values below `lower_bound`, for NaN, and when `num_bins` <= 1;
// returns `num_bins` - 1 for values at or beyond the last bin.
inline int32_t CalculateBinIndexSafe(const float value, const float lower_bound,
                                     const float bin_width,
                                     const int32_t num_bins) {
  const int32_t last_bin = num_bins - 1;
  const float raw_index = std::floor((value - lower_bound) / bin_width);

  // `!(raw_index > 0)` is true for negative indices, zero and NaN alike.
  if (last_bin <= 0 || !(raw_index > 0.0f)) {
    return 0;
  }
  // Compare in double so that every int32_t bound is represented exactly.
  if (static_cast<double>(raw_index) >= static_cast<double>(last_bin)) {
    return last_bin;
  }
  // Here 0 < raw_index < last_bin, so the conversion is well defined.
  return static_cast<int32_t>(raw_index);
}

// Returns true iff `method` is one of the histogram-based calibration methods
// (percentile, MSE brute-force, MSE max-frequency or MSE symmetric).
inline bool IsHistogramCalibration(
    const CalibrationOptions::CalibrationMethod method) {
  switch (method) {
    case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE:
    case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE:
    case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY:
    case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC:
      return true;
    default:
      return false;
  }
}

} // namespace stablehlo::quantization

#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_
@@ -0,0 +1,92 @@
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h"

#include <cstdint>

#include <gtest/gtest.h>

namespace stablehlo::quantization {
namespace {

// Range [0, 25] with 256 requested bins: the raw width (25 / 256) is below one
// and is expected to round up to the power of two 0.125.
TEST(CalibrationParametersTest, CalculateBinWidthSmallerThanOne) {
  constexpr double kMinValue = 0.0;
  constexpr double kMaxValue = 25.0;

  const float bin_width =
      CalculateBinWidth(kMinValue, kMaxValue, /*num_bins=*/256);
  EXPECT_FLOAT_EQ(bin_width, 0.125);

  const int32_t actual_num_bins =
      CalculateActualNumBins(kMinValue, kMaxValue, bin_width);
  EXPECT_EQ(actual_num_bins, 200);

  // Dividing the range by the actual bin count must reproduce the bin width.
  const float width_from_actual_bins = kMaxValue / actual_num_bins;
  EXPECT_FLOAT_EQ(bin_width, width_from_actual_bins);
}

// Range [0, 360] with 256 requested bins: the raw width (360 / 256) is above
// one and is expected to round up to the power of two 2.0.
TEST(CalibrationParametersTest, CalculateBinWidthLargerThanOne) {
  constexpr double kMinValue = 0.0;
  constexpr double kMaxValue = 360.0;

  const float bin_width =
      CalculateBinWidth(kMinValue, kMaxValue, /*num_bins=*/256);
  EXPECT_FLOAT_EQ(bin_width, 2.0);

  const int32_t actual_num_bins =
      CalculateActualNumBins(kMinValue, kMaxValue, bin_width);
  EXPECT_EQ(actual_num_bins, 180);

  // Dividing the range by the actual bin count must reproduce the bin width.
  const float width_from_actual_bins = kMaxValue / actual_num_bins;
  EXPECT_FLOAT_EQ(bin_width, width_from_actual_bins);
}

// Range [0, 256] with 256 requested bins: the raw width is exactly 1.0, which
// is already a power of two, so all 256 bins are kept.
TEST(CalibrationParametersTest, CalculateBinWidthDivisible) {
  constexpr double kMinValue = 0.0;
  constexpr double kMaxValue = 256.0;

  const float bin_width =
      CalculateBinWidth(kMinValue, kMaxValue, /*num_bins=*/256);
  EXPECT_FLOAT_EQ(bin_width, 1.0);

  const int32_t actual_num_bins =
      CalculateActualNumBins(kMinValue, kMaxValue, bin_width);
  EXPECT_EQ(actual_num_bins, 256);

  // Dividing the range by the actual bin count must reproduce the bin width.
  const float width_from_actual_bins = kMaxValue / actual_num_bins;
  EXPECT_FLOAT_EQ(bin_width, width_from_actual_bins);
}

// The range [0, 4] is an exact multiple of the bin width 2.
TEST(CalibrationParametersTest, CalculateNumBinsDivisible) {
  // Two bins are expected: [0, 2) and [2, 4].
  EXPECT_EQ(CalculateActualNumBins(/*min_value=*/0.0, /*max_value=*/4.0,
                                   /*bin_width=*/2.0),
            2);
}

// The range [0, 5] is not a multiple of the bin width 2, so an extra bin is
// needed for the remainder.
TEST(CalibrationParametersTest, CalculateNumBinsNotDivisible) {
  // Three bins are expected: [0, 2), [2, 4) and [4, 6].
  EXPECT_EQ(CalculateActualNumBins(/*min_value=*/0.0, /*max_value=*/5.0,
                                   /*bin_width=*/2.0),
            3);
}

// A value strictly inside the range maps to its natural bin: 3.0 lies in the
// second bin, [2, 4), i.e. index 1.
TEST(CalibrationParametersTest, CalculateBinIndex) {
  EXPECT_EQ(CalculateBinIndexSafe(/*value=*/3.0, /*lower_bound=*/0.0,
                                  /*bin_width=*/2.0, /*num_bins=*/2),
            1);
}

// The maximum value 4.0 would land at index 2, one past the end of a 2-bin
// histogram; the safe variant must clamp it into the last bin (index 1).
TEST(CalibrationParametersTest, CalculateBinIndexMaxValue) {
  EXPECT_EQ(CalculateBinIndexSafe(/*value=*/4.0, /*lower_bound=*/0.0,
                                  /*bin_width=*/2.0, /*num_bins=*/2),
            1);
}

} // namespace
} // namespace stablehlo::quantization
Expand Up @@ -8,10 +8,10 @@
// int ops.
func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> {
%0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32>
%1 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
%2 = "tf.XlaCallModule"(%1, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
%3 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 18.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -17.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32>
return %3 : tensor<1x3xf32>
%1:4 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
%2 = "tf.XlaCallModule"(%1#0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
%3:4 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 18.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -17.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
return %3#0 : tensor<1x3xf32>
}
func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
Expand All @@ -36,10 +36,10 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1:

func.func @main_no_unpack(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> {
%0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32>
%1 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
%2 = "tf.XlaCallModule"(%1, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
%3 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 18.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -17.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> tensor<1x3xf32>
return %3 : tensor<1x3xf32>
%1:4 = "tf.CustomAggregator"(%arg0) <{id = "1"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 0.999992311 : f32, max_percentile = 0.000000e+00 : f32, min = 7.547870e-07 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
%2 = "tf.XlaCallModule"(%1#0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
%3:4 = "tf.CustomAggregator"(%2) <{id = "2"}> {calibration_method = 1 : i64, device = "", initial_num_bins = 0 : i64, max = 18.3033524 : f32, max_percentile = 0.000000e+00 : f32, min = -17.5216827 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
return %3#0 : tensor<1x3xf32>
}
func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
Expand Down Expand Up @@ -67,7 +67,7 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1:
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
return %0 : tensor<1x3xf32>
}
// CHECK-LABEL: func.func @main
// CHECK: func.func @main
// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x1024xf32>) -> tensor<1x3xf32>
// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<{{.*}}> : tensor<1024x3xf32>
// CHECK: stablehlo.dot_general %[[ARG_0]], %[[CONST_0]]
Expand Down
Expand Up @@ -8,10 +8,10 @@ func.func @main(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> {
}
// CHECK: @main(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32>
// CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> tensor<1x4xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
// CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]])
// CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
// CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32>
// CHECK: }
// CHECK: }
Expand All @@ -28,10 +28,10 @@ func.func @serving_default(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> {
}
// CHECK: @serving_default(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32>
// CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> tensor<1x4xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{id = "0"}> {{.*}} : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
// CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]])
// CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> tensor<1x3xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{id = "1"}> {{.*}} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
// CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32>
// CHECK: }
// CHECK: }
Expand All @@ -51,12 +51,12 @@ func.func @main(%arg0: tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> {
// [b, 0, 1, f]). The weight constant is folded into [0, 1, i, o] format.
// CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<3.000000e+00> : tensor<3x3x8x8xf32>
// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %arg0, dims = [0, 2, 3, 1] : (tensor<1x8x4x4xf32>) -> tensor<1x4x4x8xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]] = "tf.CustomAggregator"(%[[TRANSPOSE_1]]) {{.*}} : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[TRANSPOSE_1]]) {{.*}} : (tensor<1x4x4x8xf32>) -> (tensor<1x4x4x8xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)

// Corresponds to the converted `stablehlo.convolution`. Note that the shapes
// correspond to the dimension numbers of: [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
// CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]]) {{.*}} : (tensor<1x4x4x8xf32>, tensor<3x3x8x8xf32>) -> tensor<1x4x4x8xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]] = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) {{.*}} : (tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) {{.*}} : (tensor<1x4x4x8xf32>) -> (tensor<1x4x4x8xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)

// CHECK: %[[TRANSPOSE_2:.+]] = stablehlo.transpose %[[CUSTOM_AGGREGATOR_1]], dims = [0, 3, 1, 2] : (tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32>
// CHECK: return %[[TRANSPOSE_2]] : tensor<1x8x4x4xf32>

0 comments on commit a700cea

Please sign in to comment.