Add num_bins and calibration_data_dir to CalibrationOptions #66233

Merged 1 commit on Apr 25, 2024
@@ -23,10 +23,6 @@ limitations under the License.

namespace stablehlo::quantization {

// TODO: b/321158562 - Make the number of bins configurable.
// Default number of histogram bins for each batch sample.
constexpr int32_t kDefaultNumOfBins = 1 << 9;

// Calculates the bin width from the range and expected number of bins. The
// bin width is formalized to the form of 2^n. As a consequence, the actual
// number of bins might be smaller than the given `num_bins`.
@@ -70,8 +66,10 @@ inline bool IsHistogramCalibration(
}

// Gets the number of bins for the given calibration method.
inline int32_t GetNumBins(const CalibrationOptions::CalibrationMethod method) {
return IsHistogramCalibration(method) ? kDefaultNumOfBins : 0;
inline int32_t GetNumBins(const CalibrationOptions& calib_opts) {
return IsHistogramCalibration(calib_opts.calibration_method())
? calib_opts.calibration_parameters().num_bins()
: 0;
}

} // namespace stablehlo::quantization
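Note on the header change above: the removed constant kDefaultNumOfBins (1 << 9, i.e. 512) is superseded by the num_bins field on CalibrationOptions, which GetNumBins now reads directly; the 512 default is populated in config.cc further down. The retained comment says the bin width is formalized to the form 2^n, so the effective bin count can end up smaller than the requested num_bins. A minimal sketch of that idea, assuming the width is rounded up to the nearest power of two (the function names here are illustrative, not the library's API):

#include <cmath>
#include <cstdint>

// Illustrative only: derives a power-of-two bin width from a value range and a
// requested bin count. Assumes max_value > min_value and num_bins > 0.
inline float CalculateBinWidthSketch(const float min_value,
                                     const float max_value,
                                     const int32_t num_bins) {
  const float raw_width = (max_value - min_value) / num_bins;
  // Rounding the width up to 2^n means the resulting number of bins can be
  // smaller than num_bins, matching the comment in the header above.
  return std::pow(2.0f, std::ceil(std::log2(raw_width)));
}

inline int32_t ActualNumBinsSketch(const float min_value,
                                   const float max_value,
                                   const int32_t num_bins) {
  const float width = CalculateBinWidthSketch(min_value, max_value, num_bins);
  return static_cast<int32_t>(std::ceil((max_value - min_value) / width));
}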
@@ -137,8 +137,11 @@ absl::StatusOr<ModuleOp> CalibrationComponent::Run(
TF_ASSIGN_OR_RETURN(const std::string precalibrated_saved_model_dir,
CreateTmpDir());

// TODO: b/333809933 - Make the calibration statistics directory configurable.
TF_ASSIGN_OR_RETURN(const std::string calibration_data_dir, CreateTmpDir());
std::string calibration_data_dir =
config.calibration_options().calibration_data_dir();
if (calibration_data_dir.empty()) {
TF_ASSIGN_OR_RETURN(calibration_data_dir, CreateTmpDir());
}

TF_ASSIGN_OR_RETURN(ExportedModel exported_model,
ExportToSavedModel(module_op, calibration_data_dir,
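With the change above, CalibrationComponent::Run writes calibration statistics to calibration_options().calibration_data_dir() when it is set, and only falls back to a temporary directory when the field is empty. The same behavior extracted into a standalone sketch (QuantizationConfig comes from the quantization_config proto and CreateTmpDir is the component's existing helper; treat their availability here as an assumption):

#include <string>

#include "absl/status/statusor.h"

// Sketch of the directory-selection logic added above.
absl::StatusOr<std::string> ChooseCalibrationDataDirSketch(
    const ::stablehlo::quantization::QuantizationConfig& config) {
  const std::string& configured =
      config.calibration_options().calibration_data_dir();
  if (!configured.empty()) return configured;
  // Empty field: preserve the old behavior and create a fresh temp directory.
  return CreateTmpDir();
}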
64 changes: 20 additions & 44 deletions tensorflow/compiler/mlir/quantization/stablehlo/cc/config.cc
@@ -29,59 +29,35 @@ void PopulateDefaultCalibrationOptions(QuantizationConfig& quant_config) {
quant_config.mutable_calibration_options()->set_calibration_method(
CalibrationOptions::CALIBRATION_METHOD_MIN_MAX);
}

switch (quant_config.calibration_options().calibration_method()) {
case CalibrationOptions::CALIBRATION_METHOD_MIN_MAX:
break;
case CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX:
break;
case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE:
if (quant_config.calibration_options()
.calibration_parameters()
.initial_num_bins() == 0) {
quant_config.mutable_calibration_options()
->mutable_calibration_parameters()
->set_initial_num_bins(256);
}
if (quant_config.calibration_options()
.calibration_parameters()
.min_percentile() == 0) {
quant_config.mutable_calibration_options()
->mutable_calibration_parameters()
->set_min_percentile(0.001);
}
if (quant_config.calibration_options()
.calibration_parameters()
.max_percentile() == 0) {
quant_config.mutable_calibration_options()
->mutable_calibration_parameters()
->set_max_percentile(99.999);
}
break;
case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE:
if (quant_config.calibration_options()
.calibration_parameters()
.initial_num_bins() == 0) {
quant_config.mutable_calibration_options()
->mutable_calibration_parameters()
->set_initial_num_bins(256);
}
break;
case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY:
if (quant_config.calibration_options()
.calibration_parameters()
.initial_num_bins() == 0) {
quant_config.mutable_calibration_options()
->mutable_calibration_parameters()
->set_initial_num_bins(256);
}
break;
case CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC:
if (quant_config.calibration_options()
.calibration_parameters()
.initial_num_bins() == 0) {
.num_bins() == 0) {
quant_config.mutable_calibration_options()
->mutable_calibration_parameters()
->set_initial_num_bins(256);
->set_num_bins(512);
}
if (quant_config.calibration_options().calibration_method() ==
CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE) {
if (quant_config.calibration_options()
.calibration_parameters()
.min_percentile() == 0) {
quant_config.mutable_calibration_options()
->mutable_calibration_parameters()
->set_min_percentile(0.001);
}
if (quant_config.calibration_options()
.calibration_parameters()
.max_percentile() == 0) {
quant_config.mutable_calibration_options()
->mutable_calibration_parameters()
->set_max_percentile(99.999);
}
}
break;
default:
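Net effect of the config.cc change above: all four histogram calibration methods now share one defaulting path. num_bins defaults to 512 when left at 0, and the percentile defaults (0.001 and 99.999) are applied only for CALIBRATION_METHOD_HISTOGRAM_PERCENTILE; MIN_MAX and AVERAGE_MIN_MAX are untouched. A condensed sketch of that behavior (the helper name and include path are illustrative; the real logic stays inside PopulateDefaultCalibrationOptions):

// Assumes the generated header for quantization_config.proto is available.
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"

// Illustrative helper mirroring the consolidated switch body above.
void PopulateHistogramDefaultsSketch(
    ::stablehlo::quantization::CalibrationOptions& opts) {
  auto& params = *opts.mutable_calibration_parameters();
  if (params.num_bins() == 0) params.set_num_bins(512);
  if (opts.calibration_method() ==
      ::stablehlo::quantization::CalibrationOptions::
          CALIBRATION_METHOD_HISTOGRAM_PERCENTILE) {
    if (params.min_percentile() == 0) params.set_min_percentile(0.001);
    if (params.max_percentile() == 0) params.set_max_percentile(99.999);
  }
}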
27 changes: 11 additions & 16 deletions tensorflow/compiler/mlir/quantization/stablehlo/cc/config_test.cc
@@ -69,18 +69,16 @@ TEST(PopulateDefaultsTest, ExplicitCalibrationOptionsNotOverridden) {
*config.mutable_calibration_options();
calibration_options.set_calibration_method(
CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX);
calibration_options.mutable_calibration_parameters()->set_initial_num_bins(
512);
calibration_options.mutable_calibration_parameters()->set_num_bins(512);

// Test that if the user explicitly provided `calibration_options`, it is not
// overridden.
const QuantizationConfig new_config = PopulateDefaults(config);
EXPECT_THAT(new_config.calibration_options().calibration_method(),
Eq(CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX));
EXPECT_THAT(new_config.calibration_options()
.calibration_parameters()
.initial_num_bins(),
Eq(512));
EXPECT_THAT(
new_config.calibration_options().calibration_parameters().num_bins(),
Eq(512));
}

TEST(PopulateDefaultsTest, DefaultNumbersPopulatedForPartOfCalibrationOptions) {
@@ -89,18 +87,16 @@ TEST(PopulateDefaultsTest, DefaultNumbersPopulatedForPartOfCalibrationOptions) {
*config.mutable_calibration_options();
calibration_options.set_calibration_method(
CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE);
calibration_options.mutable_calibration_parameters()->set_initial_num_bins(
512);
calibration_options.mutable_calibration_parameters()->set_num_bins(512);

// Test that if the user explicitly provided part of the
// `calibration_options`, it is not overridden, rest of the data are default.
const QuantizationConfig new_config = PopulateDefaults(config);
EXPECT_THAT(new_config.calibration_options().calibration_method(),
Eq(CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE));
EXPECT_THAT(new_config.calibration_options()
.calibration_parameters()
.initial_num_bins(),
Eq(512));
EXPECT_THAT(
new_config.calibration_options().calibration_parameters().num_bins(),
Eq(512));
EXPECT_THAT(new_config.calibration_options()
.calibration_parameters()
.min_percentile(),
@@ -123,10 +119,9 @@ TEST(PopulateDefaultsTest,
EXPECT_THAT(
new_config.calibration_options().calibration_method(),
Eq(CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE));
EXPECT_THAT(new_config.calibration_options()
.calibration_parameters()
.initial_num_bins(),
Eq(256));
EXPECT_THAT(
new_config.calibration_options().calibration_parameters().num_bins(),
Eq(512));
EXPECT_THAT(new_config.calibration_options()
.calibration_parameters()
.min_percentile(),
@@ -931,31 +931,31 @@ class CalibrationOptionsTest(quantize_model_test_base.QuantizedModelTest):
'calibration_options': qc.CalibrationOptions(
calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE,
calibration_parameters=qc.CalibrationOptions.CalibrationParameters(
initial_num_bins=10,
num_bins=10,
),
),
},
{
'calibration_options': qc.CalibrationOptions(
calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE,
calibration_parameters=qc.CalibrationOptions.CalibrationParameters(
initial_num_bins=10,
num_bins=10,
),
),
},
{
'calibration_options': qc.CalibrationOptions(
calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY,
calibration_parameters=qc.CalibrationOptions.CalibrationParameters(
initial_num_bins=10,
num_bins=10,
),
),
},
{
'calibration_options': qc.CalibrationOptions(
calibration_method=_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC,
calibration_parameters=qc.CalibrationOptions.CalibrationParameters(
initial_num_bins=10,
num_bins=10,
),
),
},
@@ -278,7 +278,7 @@ message DebuggerConfig {
}

// Defines various calibration options.
// Next ID: 4
// Next ID: 5
message CalibrationOptions {
// Configurations for calibration methods.
// Next ID: 7
@@ -308,10 +308,8 @@ message CalibrationOptions {
// Parameters required for calibration.
// Next ID: 4
message CalibrationParameters {
// The number of bins when histogram is initialized. It can be increased
// because histogram is dynamically expanded by sample inputs.
// initial_num_bins is 256 by default.
int32 initial_num_bins = 1;
// The number of histogram bins. Defaults to 512.
int32 num_bins = 1;
// min_percentile is only used in HISTOGRAM_PERCENTILE.
// min_percentile is 0.001 by default.
float min_percentile = 2;
@@ -333,6 +331,9 @@ message CalibrationOptions {
// Configures representative dataset. Each item corresponds to a
// representative dataset used to calibrate a function.
repeated RepresentativeDatasetConfig representative_datasets = 3;

// The path to save calibration statistics data.
string calibration_data_dir = 4;
}

// Quantization configuration for StableHLO Quantizer. This is the primary
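For reference, the two new knobs from this proto change as a user would set them: CalibrationParameters.num_bins replaces initial_num_bins (0 means "let PopulateDefaults pick 512"), and CalibrationOptions.calibration_data_dir selects where calibration statistics are saved (empty means "use a temporary directory"). A hedged C++ example; the include path, namespace, and the /tmp path are assumptions for illustration:

#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"

stablehlo::quantization::QuantizationConfig MakeQuantizationConfigSketch() {
  stablehlo::quantization::QuantizationConfig config;
  auto& calib = *config.mutable_calibration_options();
  calib.set_calibration_method(
      stablehlo::quantization::CalibrationOptions::
          CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE);
  // Explicit bin count; leaving this at 0 would default to 512.
  calib.mutable_calibration_parameters()->set_num_bins(1024);
  // Persist statistics here instead of a throwaway temp directory.
  calib.set_calibration_data_dir("/tmp/stablehlo_calibration_stats");
  return config;
}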
@@ -8,9 +8,9 @@
// int ops.
func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> {
%0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32>
%1:4 = "tf.CustomAggregator"(%arg0) <{id = "1", calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {min = 7.547870e-07 : f32, max = 0.999992311 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
%1:4 = "tf.CustomAggregator"(%arg0) <{id = "1", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {min = 7.547870e-07 : f32, max = 0.999992311 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
%2 = "tf.XlaCallModule"(%1#0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
%3:4 = "tf.CustomAggregator"(%2) <{id = "2", calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {min = -17.5216827 : f32, max = 18.3033524 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
%3:4 = "tf.CustomAggregator"(%2) <{id = "2", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {min = -17.5216827 : f32, max = 18.3033524 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
return %3#0 : tensor<1x3xf32>
}
func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
@@ -36,9 +36,9 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1:

func.func @main_no_unpack(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> {
%0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32>
%1:4 = "tf.CustomAggregator"(%arg0) <{id = "1", calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {device = "", max = 0.999992311 : f32, min = 7.547870e-07 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
%1:4 = "tf.CustomAggregator"(%arg0) <{id = "1", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {device = "", max = 0.999992311 : f32, min = 7.547870e-07 : f32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
%2 = "tf.XlaCallModule"(%1#0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32>
%3:4 = "tf.CustomAggregator"(%2) <{id = "2", calibration_method = 1 : i32, initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {device = "", max = 18.3033524 : f32, min = -17.5216827 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
%3:4 = "tf.CustomAggregator"(%2) <{id = "2", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> {device = "", max = 18.3033524 : f32, min = -17.5216827 : f32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<*xi64>)
return %3#0 : tensor<1x3xf32>
}
func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
@@ -8,10 +8,10 @@ func.func @main(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> {
}
// CHECK: @main(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32>
// CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<f32>, tensor<f32>, tensor<0xi64>)
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<f32>, tensor<f32>, tensor<0xi64>)
// CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]])
// CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<0xi64>)
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<0xi64>)
// CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32>
// CHECK: }
// CHECK: }
@@ -28,10 +28,10 @@ func.func @serving_default(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> {
}
// CHECK: @serving_default(%[[ARG_0:.+]]: tensor<1x4xf32>) -> tensor<1x3xf32>
// CHECK-DAG: %[[CST:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32>
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{calibration_method = 1 : i32, id = "0", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<f32>, tensor<f32>, tensor<0xi64>)
// CHECK: %[[CUSTOM_AGGREGATOR_0:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG_0]]) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor<f32>, tensor<f32>, tensor<0xi64>)
// CHECK: %[[XLA_CALL_MODULE:.+]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[CST]])
// CHECK-SAME: _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1"
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{calibration_method = 1 : i32, id = "1", initial_num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<0xi64>)
// CHECK: %[[CUSTOM_AGGREGATOR_1:.+]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE]]) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor<f32>, tensor<f32>, tensor<0xi64>)
// CHECK: return %[[CUSTOM_AGGREGATOR_1]] : tensor<1x3xf32>
// CHECK: }
// CHECK: }