Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Calculate min/max and histogram inside the CustomAggregator op
This cl is the first part of removing the Calibration singleton: - Part 1 (this cl): Move the min/max and histogram calculation inside CustomAggregator op. - Part 2 (follow-up cl): Aggregate the statistics inside StatisticsSaver op and remove the singleton. PiperOrigin-RevId: 622053138
- Loading branch information
1 parent
276dc80
commit a700cea
Showing
28 changed files
with
860 additions
and
426 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_ | ||
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_ | ||
|
||
#include <algorithm> | ||
#include <cmath> | ||
#include <cstdint> | ||
|
||
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" | ||
|
||
namespace stablehlo::quantization { | ||
|
||
// Calculates the bin width from the range and expected number of bins. The | ||
// bin width is formalized to the form of 2^n. As a consequence, the actual | ||
// number of bins might be smaller than the given `num_bins`. | ||
inline float CalculateBinWidth(const float min_value, const float max_value, | ||
const int32_t num_bins) { | ||
const float raw_bin_width = (max_value - min_value) / num_bins; | ||
return std::pow(2, std::ceil(std::log2(raw_bin_width))); | ||
} | ||
|
||
// Calculates the lower bound of the histogram. The lower bound is in form of | ||
// `N * bin_width`. | ||
inline float CalculateLowerBound(const float min_value, const float bin_width) { | ||
return std::floor(min_value / bin_width) * bin_width; | ||
} | ||
|
||
// Calculates the number of bins from the range and bin width. | ||
inline int32_t CalculateActualNumBins(const float min_value, | ||
const float max_value, | ||
const float bin_width) { | ||
const float lower_bound = CalculateLowerBound(min_value, bin_width); | ||
return std::ceil((max_value - lower_bound) / bin_width); | ||
} | ||
|
||
// Calculates the bin index of the current value. | ||
inline int32_t CalculateBinIndex(const float value, const float lower_bound, | ||
const float bin_width) { | ||
return std::floor((value - lower_bound) / bin_width); | ||
} | ||
|
||
// Same as `CalculateBinIndex` but clamps to avoid out-of-bound. | ||
inline int32_t CalculateBinIndexSafe(const float value, const float lower_bound, | ||
const float bin_width, | ||
const int32_t num_bins) { | ||
const int32_t bin_index = CalculateBinIndex(value, lower_bound, bin_width); | ||
return std::clamp(bin_index, 0, num_bins - 1); | ||
} | ||
|
||
// Checks if the given method is a histogram-based calibration method. | ||
inline bool IsHistogramCalibration( | ||
const CalibrationOptions::CalibrationMethod method) { | ||
return method == | ||
CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE || | ||
method == | ||
CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE || | ||
method == CalibrationOptions:: | ||
CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY || | ||
method == | ||
CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC; | ||
} | ||
|
||
} // namespace stablehlo::quantization | ||
|
||
#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_ |
92 changes: 92 additions & 0 deletions
92
...orflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters_test.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h" | ||
|
||
#include <cstdint> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
namespace stablehlo::quantization { | ||
namespace { | ||
|
||
TEST(CalibrationParametersTest, CalculateBinWidthSmallerThanOne) { | ||
float bin_width = CalculateBinWidth(/*min_value=*/0.0, /*max_value=*/25.0, | ||
/*num_bins=*/256); | ||
EXPECT_FLOAT_EQ(bin_width, 0.125); | ||
int32_t actual_num_bins = | ||
CalculateActualNumBins(/*min_value=*/0.0, /*max_value=*/25.0, bin_width); | ||
EXPECT_EQ(actual_num_bins, 200); | ||
|
||
// Calculate the bin width with the actual num bins. | ||
float raw_bin_width = 25.0 / actual_num_bins; | ||
EXPECT_FLOAT_EQ(bin_width, raw_bin_width); | ||
} | ||
|
||
TEST(CalibrationParametersTest, CalculateBinWidthLargerThanOne) { | ||
float bin_width = CalculateBinWidth(/*min_value=*/0.0, /*max_value=*/360.0, | ||
/*num_bins=*/256); | ||
EXPECT_FLOAT_EQ(bin_width, 2.0); | ||
int32_t actual_num_bins = | ||
CalculateActualNumBins(/*min_value=*/0.0, /*max_value=*/360.0, bin_width); | ||
EXPECT_EQ(actual_num_bins, 180); | ||
|
||
// Calculate the bin width with the actual num bins. | ||
float raw_bin_width = 360.0 / actual_num_bins; | ||
EXPECT_FLOAT_EQ(bin_width, raw_bin_width); | ||
} | ||
|
||
TEST(CalibrationParametersTest, CalculateBinWidthDivisible) { | ||
float bin_width = CalculateBinWidth(/*min_value=*/0.0, /*max_value=*/256.0, | ||
/*num_bins=*/256); | ||
EXPECT_FLOAT_EQ(bin_width, 1.0); | ||
int32_t actual_num_bins = | ||
CalculateActualNumBins(/*min_value=*/0.0, /*max_value=*/256.0, bin_width); | ||
EXPECT_EQ(actual_num_bins, 256); | ||
|
||
// Calculate the bin width with the actual num bins. | ||
float raw_bin_width = 256.0 / actual_num_bins; | ||
EXPECT_FLOAT_EQ(bin_width, raw_bin_width); | ||
} | ||
|
||
TEST(CalibrationParametersTest, CalculateNumBinsDivisible) { | ||
int32_t num_bins = CalculateActualNumBins( | ||
/*min_value=*/0.0, /*max_value=*/4.0, /*bin_width=*/2.0); | ||
|
||
// Expect 2 bins: [0, 2), [2, 4]. | ||
EXPECT_EQ(num_bins, 2); | ||
} | ||
|
||
TEST(CalibrationParametersTest, CalculateNumBinsNotDivisible) { | ||
int32_t num_bins = CalculateActualNumBins( | ||
/*min_value=*/0.0, /*max_value=*/5.0, /*bin_width=*/2.0); | ||
|
||
// Expect 3 bins: [0, 2), [2, 4), [4, 6]. | ||
EXPECT_EQ(num_bins, 3); | ||
} | ||
|
||
TEST(CalibrationParametersTest, CalculateBinIndex) { | ||
int32_t bin_index = CalculateBinIndexSafe(/*value=*/3.0, /*lower_bound=*/0.0, | ||
/*bin_width=*/2.0, /*num_bins=*/2); | ||
EXPECT_EQ(bin_index, 1); | ||
} | ||
|
||
TEST(CalibrationParametersTest, CalculateBinIndexMaxValue) { | ||
int32_t bin_index = CalculateBinIndexSafe(/*value=*/4.0, /*lower_bound=*/0.0, | ||
/*bin_width=*/2.0, /*num_bins=*/2); | ||
EXPECT_EQ(bin_index, 1); | ||
} | ||
|
||
} // namespace | ||
} // namespace stablehlo::quantization |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.