diff --git a/tensorflow/lite/kernels/split_v.cc b/tensorflow/lite/kernels/split_v.cc index 060e3c5f79c808..c95396c621b988 100644 --- a/tensorflow/lite/kernels/split_v.cc +++ b/tensorflow/lite/kernels/split_v.cc @@ -183,10 +183,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; } default: - context->ReportError( - context, - "Only float32, uint8 and int16 are currently supported, got %d.", - op_context.input->type); + context->ReportError(context, "Type %s currently not supported.", + TfLiteTypeGetName(op_context.input->type)); return kTfLiteError; } #undef TF_LITE_SPLIT_V diff --git a/tensorflow/lite/kernels/split_v_test.cc b/tensorflow/lite/kernels/split_v_test.cc index 2d1d36d6851c12..27fed63f0eea45 100644 --- a/tensorflow/lite/kernels/split_v_test.cc +++ b/tensorflow/lite/kernels/split_v_test.cc @@ -50,16 +50,18 @@ class SplitVOpModel : public SingleOpModel { } } - void SetInput(std::initializer_list data) { - PopulateTensor(input_, data); + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); } void SetSizeSplits(std::initializer_list data) { PopulateTensor(size_splits_, data); } void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } - std::vector GetOutput(int i) { - return ExtractVector(outputs_[i]); + template + std::vector GetOutput(int i) { + return ExtractVector(outputs_[i]); } std::vector GetOutputShape(int i) { return GetTensorShape(outputs_[i]); } @@ -70,99 +72,132 @@ class SplitVOpModel : public SingleOpModel { std::vector outputs_; }; -// TODO(ruic): Add tests to test quantized values. b/119638735 -using TensorValues = std::initializer_list; - +template void Check(int axis, std::initializer_list input_shape, std::initializer_list size_splits_shape, std::vector> output_shapes, - const TensorValues& input_data, + const std::initializer_list& input_data, const std::initializer_list& size_splits_data, - const std::vector& output_data) { + const std::vector>& output_data) { int num_splits = size_splits_data.size(); - SplitVOpModel m({TensorType_FLOAT32, input_shape}, - {TensorType_INT32, size_splits_shape}, num_splits, - kAxisIsATensor); - m.SetInput(input_data); + SplitVOpModel m({T1, input_shape}, {TensorType_INT32, size_splits_shape}, + num_splits, kAxisIsATensor); + m.SetInput(input_data); m.SetSizeSplits(size_splits_data); m.SetAxis(axis); m.Invoke(); for (int i = 0; i < num_splits; ++i) { - EXPECT_THAT(m.GetOutput(i), ElementsAreArray(output_data[i])); + EXPECT_THAT(m.GetOutput(i), ElementsAreArray(output_data[i])); EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shapes[i])); } - SplitVOpModel const_m({TensorType_FLOAT32, input_shape}, + SplitVOpModel const_m({T1, input_shape}, {TensorType_INT32, size_splits_shape}, num_splits, axis); - const_m.SetInput(input_data); + const_m.SetInput(input_data); const_m.SetSizeSplits(size_splits_data); const_m.Invoke(); for (int i = 0; i < num_splits; ++i) { - EXPECT_THAT(const_m.GetOutput(i), ElementsAreArray(output_data[i])); + EXPECT_THAT(const_m.GetOutput(i), ElementsAreArray(output_data[i])); EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shapes[i])); } } TEST(SplitVOpTest, TwoDimensional) { // Input shape: {4, 3} - // size_splits: {1, 1, 3} + // size_splits: {1, 1, 2} // axis: 0 // We should have 3 outpus with shapes respectively: - // output 0 : {1, 3} // output 1 : {1, 3} - // output 1 : {2, 3} - Check(/*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, - {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); + // output 2 : {1, 3} + // output 3 : {2, 3} + Check( + /*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); } TEST(SplitVOpTest, FourDimensional) { - Check(/*axis=*/0, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, - { - {1, 2, 3, 4, 5, 6, 7, 8}, - {9, 10, 11, 12, 13, 14, 15, 16}, - }); - Check(/*axis=*/1, {2, 2, 2, 2}, {2}, {{2, 1, 2, 2}, {2, 1, 2, 2}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, -1}, - { - {1, 2, 3, 4, 9, 10, 11, 12}, - {5, 6, 7, 8, 13, 14, 15, 16}, - }); - Check(/*axis=*/2, {2, 2, 2, 2}, {2}, {{2, 2, 1, 2}, {2, 2, 1, 2}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, - { - {1, 2, 5, 6, 9, 10, 13, 14}, - {3, 4, 7, 8, 11, 12, 15, 16}, - }); - Check(/*axis=*/3, {2, 2, 2, 2}, {2}, {{2, 2, 2, 1}, {2, 2, 2, 1}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, - { - {1, 3, 5, 7, 9, 11, 13, 15}, - {2, 4, 6, 8, 10, 12, 14, 16}, - }); + Check( + /*axis=*/0, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }); + Check( + /*axis=*/1, {2, 2, 2, 2}, {2}, {{2, 1, 2, 2}, {2, 1, 2, 2}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, -1}, + { + {1, 2, 3, 4, 9, 10, 11, 12}, + {5, 6, 7, 8, 13, 14, 15, 16}, + }); + Check( + /*axis=*/2, {2, 2, 2, 2}, {2}, {{2, 2, 1, 2}, {2, 2, 1, 2}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, + { + {1, 2, 5, 6, 9, 10, 13, 14}, + {3, 4, 7, 8, 11, 12, 15, 16}, + }); + Check( + /*axis=*/3, {2, 2, 2, 2}, {2}, {{2, 2, 2, 1}, {2, 2, 2, 1}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, + { + {1, 3, 5, 7, 9, 11, 13, 15}, + {2, 4, 6, 8, 10, 12, 14, 16}, + }); } TEST(SplitVOpTest, OneDimensional) { - Check(/*axis=*/0, {8}, {8}, {{1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}}, - {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 1, 1, 1, 1}, - {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}); + Check( + /*axis=*/0, {8}, {8}, {{1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}}, + {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 1, 1, 1, 1}, + {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}); } TEST(SplitVOpTest, OneDimensional2) { - Check(/*axis=*/0, {8}, {8}, {{1}, {1}, {1}, {1}, {1}, {1}, {2}, {0}}, - {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 1, 1, 2, -1}, - {{1}, {2}, {3}, {4}, {5}, {6}, {7, 8}, {}}); + Check( + /*axis=*/0, {8}, {8}, {{1}, {1}, {1}, {1}, {1}, {1}, {2}, {0}}, + {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 1, 1, 2, -1}, + {{1}, {2}, {3}, {4}, {5}, {6}, {7, 8}, {}}); } TEST(SplitVOpTest, NegativeAxis) { - Check(/*axis=*/-4, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, - { - {1, 2, 3, 4, 5, 6, 7, 8}, - {9, 10, 11, 12, 13, 14, 15, 16}, - }); + Check( + /*axis=*/-4, {2, 2, 2, 2}, {2}, {{1, 2, 2, 2}, {1, 2, 2, 2}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 1}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }); +} + +TEST(SplitVOpTest, TwoDimensionalUint8) { + // Input shape: {4, 3} + // size_splits: {1, 1, 2} + // axis: 0 + // We should have 3 outpus with shapes respectively: + // output 1 : {1, 3} + // output 2 : {1, 3} + // output 3 : {2, 3} + Check( + /*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); +} + +TEST(SplitVOpTest, TwoDimensionalInt16) { + // Input shape: {4, 3} + // size_splits: {1, 1, 2} + // axis: 0 + // We should have 3 outpus with shapes respectively: + // output 1 : {1, 3} + // output 2 : {1, 3} + // output 3 : {2, 3} + Check( + /*axis=*/0, {4, 3}, {3}, {{1, 3}, {1, 3}, {2, 3}}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 1, 2}, + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9, 10, 11, 12}}); } } // namespace