Skip to content

Commit

Permalink
Merge pull request #43426 from patriklaurell:tflu-resize-bilinear
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 368022694
Change-Id: I7d782318733650e75d0a32d5161b37faa88d3817
  • Loading branch information
tensorflower-gardener committed Apr 12, 2021
2 parents 9d28bf4 + f0e53e2 commit dfd36cf
Show file tree
Hide file tree
Showing 11 changed files with 810 additions and 194 deletions.
1 change: 1 addition & 0 deletions tensorflow/lite/kernels/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ cc_library(
"reference/quantize.h",
"reference/reduce.h",
"reference/requantize.h",
"reference/resize_bilinear.h",
"reference/resize_nearest_neighbor.h",
"reference/round.h",
"reference/softmax.h",
Expand Down
193 changes: 1 addition & 192 deletions tensorflow/lite/kernels/internal/reference/reference_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ limitations under the License.
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/reference/resize_bilinear.h"
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
#include "tensorflow/lite/kernels/internal/reference/round.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
Expand Down Expand Up @@ -1055,198 +1056,6 @@ inline void ScatterNd(const RuntimeShape& indices_shape,
}
}

inline void ComputeInterpolationValues(const float value, const float scale,
const bool half_pixel_centers,
int32 input_size, float* scaled_value,
int32* lower_bound, int32* upper_bound) {
if (half_pixel_centers) {
*scaled_value = (value + 0.5f) * scale - 0.5f;
} else {
*scaled_value = value * scale;
}
float scaled_value_floor = std::floor(*scaled_value);
*lower_bound =
std::max(static_cast<int32>(scaled_value_floor), static_cast<int32>(0));
*upper_bound =
std::min(static_cast<int32>(std::ceil(*scaled_value)), input_size - 1);
}

template <typename T>
inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape,
const T* input_data,
const RuntimeShape& unextended_output_size_shape,
const int32* output_size_data,
const RuntimeShape& unextended_output_shape,
T* output_data) {
// If half_pixel_centers is True, align_corners must be False.
TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_size_shape =
RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);

int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
int32 input_height = input_shape.Dims(1);
int32 input_width = input_shape.Dims(2);
int32 depth = MatchingDim(input_shape, 3, output_shape, 3);

TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];

float height_scale = static_cast<float>(input_height) / output_height;
float width_scale = static_cast<float>(input_width) / output_width;
if (op_params.align_corners && output_height > 1) {
height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
}
if (op_params.align_corners && output_width > 1) {
width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
}
const float rounding_offset = std::numeric_limits<T>::is_integer ? .5f : .0f;

for (int b = 0; b < batches; ++b) {
for (int y = 0; y < output_height; ++y) {
float input_y;
int32 y0, y1;
ComputeInterpolationValues(y, height_scale, op_params.half_pixel_centers,
input_height, &input_y, &y0, &y1);
for (int x = 0; x < output_width; ++x) {
float input_x;
int32 x0, x1;
ComputeInterpolationValues(x, width_scale, op_params.half_pixel_centers,
input_width, &input_x, &x0, &x1);
for (int c = 0; c < depth; ++c) {
T interpolation =
static_cast<T>(input_data[Offset(input_shape, b, y0, x0, c)] *
(1 - (input_y - y0)) * (1 - (input_x - x0)) +
input_data[Offset(input_shape, b, y1, x0, c)] *
(input_y - y0) * (1 - (input_x - x0)) +
input_data[Offset(input_shape, b, y0, x1, c)] *
(1 - (input_y - y0)) * (input_x - x0) +
input_data[Offset(input_shape, b, y1, x1, c)] *
(input_y - y0) * (input_x - x0) +
rounding_offset);
output_data[Offset(output_shape, b, y, x, c)] = interpolation;
}
}
}
}
}

inline void ComputeInterpolationValuesInteger(
const int32 value, const int32 scale_10, const bool half_pixel_centers,
int32 input_size, int32* scaled_value, int32* lower_bound,
int32* upper_bound) {
if (half_pixel_centers) {
*scaled_value = value * scale_10 + scale_10 / 2 - (1 << 9);
} else {
*scaled_value = value * scale_10;
}
*lower_bound = std::max(*scaled_value / (1 << 10), 0);
*upper_bound =
std::min((*scaled_value + (1 << 10) - 1) / (1 << 10), input_size - 1);
}

// Same as above but doesn't use any floating-point for the resize
template <typename T>
inline void ResizeBilinearInteger(
const tflite::ResizeBilinearParams& op_params,
const RuntimeShape& unextended_input_shape, const T* input_data,
const RuntimeShape& unextended_output_size_shape,
const int32* output_size_data, const RuntimeShape& unextended_output_shape,
T* output_data) {
// If half_pixel_centers is True, align_corners must be False.
TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners);
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_size_shape =
RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);

const int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
const int32 input_height = input_shape.Dims(1);
const int32 input_width = input_shape.Dims(2);
const int32 depth = MatchingDim(input_shape, 3, output_shape, 3);

TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
const int32 output_height =
output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
const int32 output_width =
output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];

int32 height_scale_10 =
((1 << 10) * input_height + output_height / 2) / output_height;
int32 width_scale_10 =
((1 << 10) * input_width + output_width / 2) / output_width;
if (op_params.align_corners && output_height > 1) {
height_scale_10 =
((1 << 10) * (input_height - 1) + (output_height - 1) / 2) /
(output_height - 1);
}
if (op_params.align_corners && output_width > 1) {
width_scale_10 = ((1 << 10) * (input_width - 1) + (output_width - 1) / 2) /
(output_width - 1);
}

for (int b = 0; b < batches; ++b) {
for (int y = 0; y < output_height; ++y) {
int32 input_y, y0, y1;
ComputeInterpolationValuesInteger(y, height_scale_10,
op_params.half_pixel_centers,
input_height, &input_y, &y0, &y1);
for (int x = 0; x < output_width; ++x) {
int32 input_x, x0, x1;
ComputeInterpolationValuesInteger(x, width_scale_10,
op_params.half_pixel_centers,
input_width, &input_x, &x0, &x1);
for (int c = 0; c < depth; ++c) {
const int64_t output_20_ll =
static_cast<int64_t>(
input_data[Offset(input_shape, b, y0, x0, c)]) *
((1 << 10) - (input_y - (1 << 10) * y0)) *
((1 << 10) - (input_x - (1 << 10) * x0));
const int64_t output_20_lu =
static_cast<int64_t>(
input_data[Offset(input_shape, b, y1, x0, c)]) *
(input_y - (1 << 10) * y0) *
((1 << 10) - (input_x - (1 << 10) * x0));
const int64_t output_20_rl =
static_cast<int64_t>(
input_data[Offset(input_shape, b, y0, x1, c)]) *
((1 << 10) - (input_y - (1 << 10) * y0)) *
(input_x - (1 << 10) * x0);
const int64_t output_20_ru =
static_cast<int64_t>(
input_data[Offset(input_shape, b, y1, x1, c)]) *
(input_y - (1 << 10) * y0) * (input_x - (1 << 10) * x0);
const int64_t output_20 =
output_20_ll + output_20_lu + output_20_rl + output_20_ru;
const int64_t round = (output_20 > 0) ? (1 << 19) : -(1 << 19);
const T interpolation =
static_cast<T>((output_20 + round) / (1 << 20));
output_data[Offset(output_shape, b, y, x, c)] = interpolation;
}
}
}
}
}

template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape,
Expand Down

0 comments on commit dfd36cf

Please sign in to comment.