From 16dd0d4e718ccfb980cb9d97a54e21a295a9c66d Mon Sep 17 00:00:00 2001 From: Josh Beal Date: Wed, 17 Apr 2019 11:48:55 -0700 Subject: [PATCH] Add Int32 and Int64 value support to SplitV. --- tensorflow/lite/kernels/split_v.cc | 12 ++++++++++- tensorflow/lite/kernels/split_v_test.cc | 28 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/tensorflow/lite/kernels/split_v.cc b/tensorflow/lite/kernels/split_v.cc index c95396c621b988..e7b7b95a15e211 100644 --- a/tensorflow/lite/kernels/split_v.cc +++ b/tensorflow/lite/kernels/split_v.cc @@ -126,7 +126,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto input_type = op_context.input->type; TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || - input_type == kTfLiteInt16); + input_type == kTfLiteInt16 || + input_type == kTfLiteInt32 || + input_type == kTfLiteInt64); for (int i = 0; i < NumOutputs(node); ++i) { GetOutput(context, node, i)->type = input_type; } @@ -182,6 +184,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_SPLIT_V(int16_t); break; } + case kTfLiteInt32: { + TF_LITE_SPLIT_V(int32_t); + break; + } + case kTfLiteInt64: { + TF_LITE_SPLIT_V(int64_t); + break; + } default: context->ReportError(context, "Type %s currently not supported.", TfLiteTypeGetName(op_context.input->type)); diff --git a/tensorflow/lite/kernels/split_v_test.cc b/tensorflow/lite/kernels/split_v_test.cc index 27fed63f0eea45..ced8f5341dad61 100644 --- a/tensorflow/lite/kernels/split_v_test.cc +++ b/tensorflow/lite/kernels/split_v_test.cc @@ -200,6 +200,34 @@ TEST(SplitVOpTest, TwoDimensionalInt16) { {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); } +TEST(SplitVOpTest, TwoDimensionalInt32) { + // Input shape: {4, 3} + // size_splits: {1, 1, 2} + // axis: 0 + // We should have 3 outpus with shapes respectively: + // output 1 : {1, 3} + // output 2 : {1, 3} + // output 3 : {2, 3} + Check( + /*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); +} + +TEST(SplitVOpTest, TwoDimensionalInt64) { + // Input shape: {4, 3} + // size_splits: {1, 1, 2} + // axis: 0 + // We should have 3 outpus with shapes respectively: + // output 1 : {1, 3} + // output 2 : {1, 3} + // output 3 : {2, 3} + Check( + /*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); +} + } // namespace } // namespace tflite