Add a MeanStddevNormalization test case with large vectors for three implementations: float, experimental/FP16 and GPU delegate.

The GPU delegate version didn't support non-power-of-two vector sizes; fix it. (Also add some comments.)

PiperOrigin-RevId: 330038282
Change-Id: I380ce7276e42f41e54cdddfa35ad38421da89b15
lrdxgm authored and tensorflower-gardener committed Sep 4, 2020
1 parent 6a00b43 commit bd69906
Showing 3 changed files with 99 additions and 12 deletions.
@@ -45,19 +45,30 @@ std::string GetReduceCode() {
 #endif
 #ifdef __opencl_c_work_group_collective_functions
-#define local_reduce(input, tmp) work_group_reduce_add(input)
+#define local_reduce(item, tmp) work_group_reduce_add(item)
 #else  // !defined(__opencl_c_work_group_collective_functions)
-static inline float local_reduce(float input, __local float* tmp) {
+static inline float local_reduce(float item, __local float* tmp) {
   const int local_id = get_local_id(0);
-  tmp[local_id] = input;
+  tmp[local_id] = item;
   barrier(CLK_LOCAL_MEM_FENCE);
-  int reduction_size = get_local_size(0) / 2;
-  while (reduction_size > 0) {
-    if (local_id < reduction_size) {
-      tmp[local_id] += tmp[local_id + reduction_size];
+  // The number of items that still need to be summed.
+  int reduction_size = get_local_size(0);
+  while (reduction_size > 1) {
+    // Reduction step: add the upper half of the still-to-be-summed vector to
+    // the lower half, taking care of odd sizes and rounding. E.g.:
+    //   Number of items still to be summed before: 5
+    //   Local memory before: [a, b, c, d, e];
+    //   Local memory after:  [a+d, b+e, c, d, e];
+    //   Threads doing work: id < 2 = floor(5/2)
+    //   Offset to the added items: 3 = ceil(5/2)
+    //   Number of items still to be summed after: 3 = ceil(5/2)
+    const int active_thread_limit = reduction_size / 2;
+    const int offset = (reduction_size + 1) / 2;
+    if (local_id < active_thread_limit) {
+      tmp[local_id] += tmp[local_id + offset];
     }
     barrier(CLK_LOCAL_MEM_FENCE);
-    reduction_size /= 2;
+    reduction_size = offset;
   }
   return tmp[0];
 }
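The comment block above describes the fix in prose; to make the rounding behavior concrete, here is a minimal host-side C++ model of the new reduction loop (a sketch for illustration only; LocalReduceModel is a hypothetical name and is not part of the TFLite sources). Because each pass writes only the lower part of the buffer, running the per-thread additions sequentially gives the same result as the parallel kernel.

#include <cstdio>
#include <vector>

// Sequential model of the work-group reduction above. Each pass adds the top
// floor(n/2) items onto the bottom floor(n/2) items at stride ceil(n/2),
// leaving ceil(n/2) partial sums, so any size works, not just powers of two.
float LocalReduceModel(std::vector<float> tmp) {
  int reduction_size = static_cast<int>(tmp.size());
  while (reduction_size > 1) {
    const int active_thread_limit = reduction_size / 2;  // floor(n/2) adders
    const int offset = (reduction_size + 1) / 2;         // ceil(n/2) stride
    for (int local_id = 0; local_id < active_thread_limit; ++local_id) {
      tmp[local_id] += tmp[local_id + offset];
    }
    reduction_size = offset;  // ceil(n/2) partial sums remain
  }
  return tmp[0];
}

int main() {
  // Odd size 5: [1,2,3,4,5] -> [5,7,3] -> [8,7] -> [15].
  std::printf("%g\n", LocalReduceModel({1, 2, 3, 4, 5}));  // prints 15
  return 0;
}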
@@ -55,7 +55,7 @@ TEST_P(MeanStddevNormalizationTest, SeparateBatches) {
       op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
       TensorFloat32 dst_tensor;
       auto operation =
-          CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_);
+          CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_, 1);
       ASSERT_OK(ExecuteGPUOperation({src_tensor}, creation_context_, &operation,
                                     BHWC(1, 1, 1, 4), &dst_tensor));

@@ -88,8 +88,6 @@ INSTANTIATE_TEST_SUITE_P(
        std::make_tuple(100.0f, 100.0f, 2.63e-4f)  // large mean, large variance
    ));
 
-GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(MeanStddevNormalizationTest);
-
 TEST_F(OpenCLOperationTest, MeanStddevNormalizationAllBatches) {
   TensorFloat32 src_tensor;
   src_tensor.shape = BHWC(9, 1, 1, 4);
@@ -113,7 +111,7 @@ TEST_F(OpenCLOperationTest, MeanStddevNormalizationAllBatches) {
       op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
       TensorFloat32 dst_tensor;
       auto operation =
-          CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_);
+          CreateMeanStdDevNormalization(op_def, env_.GetDevicePtr()->info_, 1);
       ASSERT_OK(ExecuteGPUOperation({src_tensor}, creation_context_, &operation,
                                     BHWC(9, 1, 1, 4), &dst_tensor));

@@ -136,6 +134,53 @@ TEST_F(OpenCLOperationTest, MeanStddevNormalizationAllBatches) {
   }
 }
 
+TEST_F(OpenCLOperationTest, MeanStddevNormalizationLargeVector) {
+  const float mean = 100.0f;
+  const float diff = 1.0f;
+  // Some large vector that is not a round multiple of any SIMD vector sizes.
+  constexpr int kVectorSize = 16 * 16 + 16 + 1;
+
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 1, 1, kVectorSize);
+  src_tensor.data.resize(kVectorSize);
+  // First input is mean.
+  src_tensor.data[0] = mean;
+  // Rest is alternating between mean + diff and mean - diff.
+  for (int i = 1; i < kVectorSize - 1; i += 2) {
+    src_tensor.data[i + 0] = mean + diff;
+    src_tensor.data[i + 1] = mean - diff;
+  }
+
+  for (auto storage : env_.GetSupportedStorages()) {
+    for (auto precision : env_.GetSupportedPrecisions()) {
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::BHWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::BHWC});
+      TensorFloat32 dst_tensor;
+      auto operation = CreateMeanStdDevNormalization(
+          op_def, env_.GetDevicePtr()->info_, (kVectorSize + 3) / 4);
+      ASSERT_OK(ExecuteGPUOperation({src_tensor}, creation_context_, &operation,
+                                    BHWC(1, 1, 1, kVectorSize), &dst_tensor));
+
+      float expected_output[kVectorSize];
+      // First output should be 0.
+      expected_output[0] = 0.0;
+      // Rest should be alternating between ±√(N/(N-1)).
+      const float expected_elem =
+          std::sqrt(static_cast<double>(kVectorSize) /
+                    static_cast<double>(kVectorSize - 1));
+      for (int i = 1; i < kVectorSize - 1; i += 2) {
+        expected_output[i + 0] = +expected_elem;
+        expected_output[i + 1] = -expected_elem;
+      }
+      EXPECT_THAT(dst_tensor.data,
+                  Pointwise(FloatNear(1.17e-4f), expected_output));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace cl
 }  // namespace gpu
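A short derivation of the expected values (added here for clarity; it is not part of the commit). With N = kVectorSize, m = mean and d = diff, one element sits at the mean and the remaining N-1 alternate between m+d and m-d:

\mu = \frac{1}{N}\Big(m + \tfrac{N-1}{2}(m+d) + \tfrac{N-1}{2}(m-d)\Big) = m
\sigma^2 = \frac{1}{N}\sum_i (x_i - \mu)^2 = \frac{(N-1)\,d^2}{N}
\frac{x_i - \mu}{\sigma} \in \Big\{ 0,\ \pm\sqrt{\tfrac{N}{N-1}} \Big\}

So the first output is 0 and the rest alternate between ±√(N/(N-1)), independent of m and d, which is exactly the expected_elem computed in the test. Note also the new third argument to CreateMeanStdDevNormalization: (kVectorSize + 3) / 4 looks like the channel count rounded up to whole 4-element slices, i.e. a ceiling division.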
31 changes: 31 additions & 0 deletions tensorflow/lite/kernels/internal/tensor_utils_test.cc
@@ -2051,6 +2051,37 @@ TEST(uKernels, MeanStddevNormalizationAllBatches) {
                   ArrayFloatNear(expected_output, 1.81e-4f)));
 }
 
+TEST(uKernels, MeanStddevNormalizationLargeVector) {
+  const float mean = 100.0f;
+  const float diff = 1.0f;
+  // Some large vector that is not a round multiple of any SIMD vector sizes.
+  // Note this is odd.
+  constexpr int kVectorSize = 16 * 16 + 16 + 1;
+
+  float input[kVectorSize];
+  // First input is mean.
+  input[0] = mean;
+  // Rest is alternating between mean + diff and mean - diff.
+  for (int i = 1; i < kVectorSize - 1; i += 2) {
+    input[i + 0] = mean + diff;
+    input[i + 1] = mean - diff;
+  }
+  float output[kVectorSize];
+  MeanStddevNormalization(input, output, kVectorSize, 1);
+
+  float expected_output[kVectorSize];
+  // First output should be 0.
+  expected_output[0] = 0.0;
+  // Rest should be alternating between ±√(N/(N-1)).
+  const float expected_elem = std::sqrt(static_cast<double>(kVectorSize) /
+                                        static_cast<double>(kVectorSize - 1));
+  for (int i = 1; i < kVectorSize - 1; i += 2) {
+    expected_output[i + 0] = +expected_elem;
+    expected_output[i + 1] = -expected_elem;
+  }
+  EXPECT_THAT(output, testing::Pointwise(testing::FloatEq(), expected_output));
+}
+
 }  // namespace tensor_utils
 }  // namespace tflite

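For readers unfamiliar with the op under test, here is a minimal scalar sketch of what MeanStddevNormalization computes, assuming each batch row of v_size values is normalized independently using the population standard deviation and no epsilon guard. This is an illustration of the math the tests check, not the TFLite implementation, which is vectorized and may differ in edge-case handling.

#include <cmath>

// Hypothetical scalar reference: normalize each batch row to zero mean and
// unit (population) stddev. Not the TFLite kernel; for illustration only.
void MeanStddevNormalizationRef(const float* input, float* output, int v_size,
                                int n_batch) {
  for (int b = 0; b < n_batch; ++b) {
    const float* in = input + b * v_size;
    float* out = output + b * v_size;
    float sum = 0.0f;
    for (int i = 0; i < v_size; ++i) sum += in[i];
    const float mean = sum / v_size;
    float sum_sq_diff = 0.0f;
    for (int i = 0; i < v_size; ++i) {
      const float d = in[i] - mean;
      sum_sq_diff += d * d;
    }
    const float stddev = std::sqrt(sum_sq_diff / v_size);
    // Guard against constant rows (an assumption; the real kernel may instead
    // add a small epsilon before dividing).
    const float inv_stddev = stddev == 0.0f ? 0.0f : 1.0f / stddev;
    for (int i = 0; i < v_size; ++i) out[i] = (in[i] - mean) * inv_stddev;
  }
}

Feeding it the alternating input above reproduces the ±√(N/(N-1)) pattern that the test asserts with FloatEq.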
