Lite: Concatenation Op Refactored #26134

Merged · 5 commits · May 3, 2019
Changes from all commits
88 changes: 40 additions & 48 deletions tensorflow/lite/kernels/concatenation.cc
@@ -112,72 +112,64 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 // allocate and populate these during Prepare().
 // TODO(ycling): Activation function parameter is ignored. For now we don't have
 // a model with a Concatenation with fused activation function.
-#define TF_LITE_CONCATENATION(type, scalar)                                \
-  {                                                                        \
-    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);           \
-    tflite::ConcatenationParams op_params;                                 \
-    op_params.axis = axis;                                                 \
-    op_params.inputs_count = node->inputs->size;                           \
-    type::Concatenation(op_params, all_inputs.shapes(), all_inputs.data(), \
-                        GetTensorShape(output),                            \
-                        GetTensorData<scalar>(output));                    \
-  }
-
-#define TF_LITE_CONCATENATION_QUANTIZED(type)                                 \
+#define TF_LITE_CONCATENATION(scalar)                                         \
   {                                                                           \
-    VectorOfQuantizedTensors all_inputs(*context, *node->inputs);             \
+    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);              \
     tflite::ConcatenationParams op_params;                                    \
     op_params.axis = axis;                                                    \
-    op_params.input_zeropoint = all_inputs.zero_point();                      \
-    op_params.input_scale = all_inputs.scale();                               \
     op_params.inputs_count = node->inputs->size;                              \
-    op_params.output_zeropoint = output->params.zero_point;                   \
-    op_params.output_scale = output->params.scale;                            \
-    type::ConcatenationWithScaling(op_params, all_inputs.shapes(),            \
+    if (kernel_type == kReference) {                                          \
+      reference_ops::Concatenation(op_params, all_inputs.shapes(),            \
                                    all_inputs.data(), GetTensorShape(output), \
-                                   GetTensorData<uint8>(output));             \
+                                   GetTensorData<scalar>(output));            \
+    } else {                                                                  \
+      optimized_ops::Concatenation(op_params, all_inputs.shapes(),            \
+                                   all_inputs.data(), GetTensorShape(output), \
+                                   GetTensorData<scalar>(output));            \
+    }                                                                         \
   }
 
+#define TF_LITE_CONCATENATION_QUANTIZED()                             \
+  {                                                                   \
+    VectorOfQuantizedTensors all_inputs(*context, *node->inputs);     \
+    tflite::ConcatenationParams op_params;                            \
+    op_params.axis = axis;                                            \
+    op_params.input_zeropoint = all_inputs.zero_point();              \
+    op_params.input_scale = all_inputs.scale();                       \
+    op_params.inputs_count = node->inputs->size;                      \
+    op_params.output_zeropoint = output->params.zero_point;           \
+    op_params.output_scale = output->params.scale;                    \
+    if (kernel_type == kReference) {                                  \
+      reference_ops::ConcatenationWithScaling(                        \
+          op_params, all_inputs.shapes(), all_inputs.data(),          \
+          GetTensorShape(output), GetTensorData<uint8>(output));      \
+    } else {                                                          \
+      optimized_ops::ConcatenationWithScaling(                        \
+          op_params, all_inputs.shapes(), all_inputs.data(),          \
+          GetTensorShape(output), GetTensorData<uint8>(output));      \
+    }                                                                 \
+  }
+
   switch (output->type) {  // Already know in/out types are same.
     case kTfLiteFloat32:
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION(reference_ops, float);
-      } else {
-        TF_LITE_CONCATENATION(optimized_ops, float);
-      }
+      TF_LITE_CONCATENATION(float);
       break;
     case kTfLiteInt32:
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION(reference_ops, int32);
-      } else {
-        TF_LITE_CONCATENATION(optimized_ops, int32);
-      }
+      TF_LITE_CONCATENATION(int32);
       break;
     case kTfLiteUInt8:
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
-      } else {
-        TF_LITE_CONCATENATION_QUANTIZED(optimized_ops);
-      }
+      TF_LITE_CONCATENATION_QUANTIZED();
       break;
+    case kTfLiteInt8:
+      TF_LITE_CONCATENATION(int8_t);
+      break;
-    case kTfLiteInt8: {
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION(reference_ops, int8_t);
-      } else {
-        TF_LITE_CONCATENATION(optimized_ops, int8_t);
-      }
-    } break;
     case kTfLiteInt64:
-      if (kernel_type == kReference) {
-        TF_LITE_CONCATENATION(reference_ops, int64_t);
-      } else {
-        TF_LITE_CONCATENATION(optimized_ops, int64_t);
-      }
+      TF_LITE_CONCATENATION(int64_t);
      break;
     default:
-      context->ReportError(context,
-                           "Only float32 and uint8 are currently supported.");
+      context->ReportError(context, "Type '%s' is not supported currently.",
+                           TfLiteTypeGetName(output->type));
       return kTfLiteError;
   }

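Note (context not shown in this diff): `kernel_type` is a compile-time template parameter of `Eval`, bound when the op's kernels are registered, so the new `if (kernel_type == kReference)` branches are resolved at compile time rather than per invocation. A minimal sketch of that registration pattern, assuming the file follows the usual TFLite kernel conventions; the registration code itself is untouched by this PR:

TfLiteRegistration* Register_CONCATENATION_REF() {
  // Binds Eval<kReference>: the macros above expand to reference_ops calls.
  static TfLiteRegistration r = {nullptr, nullptr, concatenation::Prepare,
                                 concatenation::Eval<concatenation::kReference>};
  return &r;
}

TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
  // Binds Eval<kGenericOptimized>: the macros select the optimized_ops branch.
  static TfLiteRegistration r = {
      nullptr, nullptr, concatenation::Prepare,
      concatenation::Eval<concatenation::kGenericOptimized>};
  return &r;
}

The net effect of the refactor is that each macro carries one decision point (`kernel_type`) instead of being instantiated twice at every call site, which is what removes the repeated if/else blocks from the switch.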
77 changes: 77 additions & 0 deletions tensorflow/lite/kernels/concatenation_test.cc
@@ -265,6 +265,83 @@ TEST(ConcatenationOpTest, FourInputsQuantizedMixedRangeClampingLogic) {
 }));
 }
 
+TEST(ConcatenationOpTest, ThreeDimensionalNonQuantizedOneInput) {
+  QuantizedConcatenationOpModel m0(
+      {TensorType_UINT8, {2, 1, 2}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/1,
+      /*num_inputs=*/1);
+  m0.SetInput<uint8_t>(0, {1.0f, 3.0f, 4.0f, 7.0f});
+  m0.Invoke();
+  EXPECT_THAT(m0.GetOutput<uint8_t>(),
+              ElementsAreArray(ArrayFloatNear({1.0f, 3.0f, 4.0f, 7.0f})));
+}
+
+TEST(ConcatenationOpTest, OneTrivialNonQuantizedInput) {
+  QuantizedConcatenationOpModel m0(
+      {TensorType_UINT8, {1}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/0,
+      /*num_inputs=*/1);
+  m0.SetInput<uint8_t>(0, {5.0f});
+  m0.Invoke();
+  EXPECT_THAT(m0.GetOutput<uint8_t>(), ::testing::ElementsAre(5));
+}
+
+TEST(ConcatenationOpTest, TwoDimensionalNonQuantizedOneInput) {
+  QuantizedConcatenationOpModel m0(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/0,
+      /*num_inputs=*/1);
+  m0.SetInput<uint8_t>(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+  m0.Invoke();
+  EXPECT_THAT(m0.GetOutput<uint8_t>(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxesNonQuantized) {
+  // We will concatenate two tensors along different dimensions.
+  auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
+
+  QuantizedConcatenationOpModel m0(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/0,
+      /*num_inputs=*/2);
+  m0.SetInput<uint8_t>(0, tensor0);
+  m0.SetInput<uint8_t>(1, tensor1);
+  m0.Invoke();
+  EXPECT_THAT(m0.GetOutput<uint8_t>(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+
+  QuantizedConcatenationOpModel m0_negative(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/-2,
+      /*num_inputs=*/2);
+  m0_negative.SetInput<uint8_t>(0, tensor0);
+  m0_negative.SetInput<uint8_t>(1, tensor1);
+  m0_negative.Invoke();
+  EXPECT_THAT(m0_negative.GetOutput<uint8_t>(),
+              ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
+
+  QuantizedConcatenationOpModel m1(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/1,
+      /*num_inputs=*/2);
+  m1.SetInput<uint8_t>(0, tensor0);
+  m1.SetInput<uint8_t>(1, tensor1);
+  m1.Invoke();
+  EXPECT_THAT(m1.GetOutput<uint8_t>(),
+              ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+
+  QuantizedConcatenationOpModel m1_negative(
+      {TensorType_UINT8, {2, 3}, 0, std::numeric_limits<uint8_t>::max()},
+      /*axis=*/-1,
+      /*num_inputs=*/2);
+  m1_negative.SetInput<uint8_t>(0, tensor0);
+  m1_negative.SetInput<uint8_t>(1, tensor1);
+  m1_negative.Invoke();
+  EXPECT_THAT(m1_negative.GetOutput<uint8_t>(),
+              ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12}));
+}
+
 }  // namespace
 }  // namespace tflite

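A note on the test parameters: every model above is built with the range {0, std::numeric_limits<uint8_t>::max()}, which makes the affine quantization numerically lossless for small integral inputs; that is why these uint8 tests carry "NonQuantized" in their names and can compare outputs against the original float values. A small self-contained sketch of the math, independent of the TFLite test harness (the rounding mode is an assumption, though for this degenerate range the round trip is exact either way):

#include <cassert>
#include <cmath>
#include <cstdint>

// Affine quantization: real ~ scale * (q - zero_point).
// With min = 0 and max = 255: scale = (max - min) / 255 = 1.0 and
// zero_point = 0, so float inputs like {1.0f, 3.0f, 4.0f, 7.0f} map to the
// exact uint8 values {1, 3, 4, 7} that the expectations check for.
int main() {
  const float min = 0.0f;
  const float max = 255.0f;
  const float scale = (max - min) / 255.0f;
  const int zero_point = static_cast<int>(std::round(-min / scale));

  const float inputs[] = {1.0f, 3.0f, 4.0f, 7.0f};
  for (float v : inputs) {
    const uint8_t q = static_cast<uint8_t>(std::round(v / scale) + zero_point);
    const float back = scale * static_cast<float>(q - zero_point);
    assert(back == v);  // lossless round trip at scale == 1, zero_point == 0
  }
  return 0;
}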