diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 422584c0eac6e7..3f158850d9405d 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -247,7 +247,9 @@ def generated_test_models():
         "local_response_norm",
         "log_softmax",
         "log",
+        "logical_and",
         "logical_or",
+        "logical_xor",
         "lstm",
         "max_pool",
         "maximum",
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index 59bab3c4ecd20b..e19779ea59d441 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -22,79 +22,118 @@ namespace tflite {
 namespace ops {
 namespace builtin {
 namespace elementwise {
+namespace {
+bool IsNumericSupportedType(const TfLiteType type) {
+  return type == kTfLiteFloat32;
+}
+
+bool IsLogicalSupportedType(const TfLiteType type) {
+  return type == kTfLiteBool;
+}
+
+typedef bool (*IsSupportedType)(TfLiteType);
+template <IsSupportedType IsSupportedType>
 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
   TF_LITE_ENSURE_EQ(context, input->type, output->type);
-  // Quantized float is not supported yet.
-  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  if (!IsSupportedType(input->type)) {
+    context->ReportError(context, "Current data type %d is not supported.",
+                         input->type);
+    return kTfLiteError;
+  }
   return context->ResizeTensor(context, output,
                                TfLiteIntArrayCopy(input->dims));
 }
 
-inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node,
-                         float float_func(float)) {
+template <typename T>
+inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
+                             T func(T), TfLiteType expected_type) {
   const TfLiteTensor* input = GetInput(context, node, 0);
   TfLiteTensor* output = GetOutput(context, node, 0);
-  switch (input->type) {
-    case kTfLiteFloat32: {
-      size_t elements = NumElements(input);
-      const float* in = GetTensorData<float>(input);
-      const float* in_end = in + elements;
-      float* out = output->data.f;
-      for (; in < in_end; in++, out++) *out = float_func(*in);
-      return kTfLiteOk;
-    }
-    default: {
-      context->ReportError(context, "Input type is %d, requires float32",
-                           input->type);
-      return kTfLiteError;
-    }
+  TF_LITE_ENSURE_EQ(context, input->type, expected_type);
+  const int64_t num_elements = NumElements(input);
+  const T* in_data = GetTensorData<T>(input);
+  T* out_data = GetTensorData<T>(output);
+  for (int64_t i = 0; i < num_elements; ++i) {
+    out_data[i] = func(in_data[i]);
   }
+  return kTfLiteOk;
+}
+
+inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
+                                float float_func(float)) {
+  return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
+}
+
+inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
+                                bool bool_func(bool)) {
+  return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
 }
 
 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
-  return Eval(context, node, std::sin);
+  return EvalNumeric(context, node, std::sin);
 }
 
 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
-  return Eval(context, node, std::log);
+  return EvalNumeric(context, node, std::log);
 }
 
 TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
-  return Eval(context, node, std::sqrt);
+  return EvalNumeric(context, node, std::sqrt);
 }
 
 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
-  return Eval(context, node, [](float f) { return 1.f / std::sqrt(f); });
+  return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
+}
+
+TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
+  return EvalLogical(context, node, [](bool v) { return !v; });
 }
 
+}  // namespace
 }  // namespace elementwise
 
 TfLiteRegistration* Register_SIN() {
-  static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
-                                 elementwise::SinEval};
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SinEval};
   return &r;
 }
 
 TfLiteRegistration* Register_LOG() {
-  static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
-                                 elementwise::LogEval};
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::LogEval};
   return &r;
 }
 
 TfLiteRegistration* Register_SQRT() {
-  static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
-                                 elementwise::SqrtEval};
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::SqrtEval};
   return &r;
 }
 
 TfLiteRegistration* Register_RSQRT() {
-  static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
-                                 elementwise::RsqrtEval};
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+      elementwise::RsqrtEval};
+  return &r;
+}
+
+TfLiteRegistration* Register_LOGICAL_NOT() {
+  static TfLiteRegistration r = {
+      /*init=*/nullptr, /*free=*/nullptr,
+      elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
+      elementwise::LogicalNotEval};
   return &r;
 }
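Note: the refactored `GenericPrepare` above takes its type check as a function-pointer non-type template parameter, so `GenericPrepare<elementwise::IsNumericSupportedType>` and `GenericPrepare<elementwise::IsLogicalSupportedType>` are two distinct instantiations sharing one body. A minimal standalone sketch of that dispatch pattern (all names below are illustrative stand-ins, not TFLite API):

```c++
#include <iostream>

// Stand-in for elementwise.cc's IsSupportedType typedef.
typedef bool (*TypePredicate)(int);

bool IsFloatTag(int tag) { return tag == 1; }  // stand-in for kTfLiteFloat32
bool IsBoolTag(int tag) { return tag == 6; }   // stand-in for kTfLiteBool

// One prepare body; the predicate is baked in at compile time, mirroring
// GenericPrepare<elementwise::IsNumericSupportedType> in the patch above.
template <TypePredicate pred>
bool Prepare(int input_type_tag) {
  if (!pred(input_type_tag)) {
    std::cerr << "Current data type " << input_type_tag
              << " is not supported.\n";
    return false;
  }
  return true;
}

int main() {
  std::cout << Prepare<IsFloatTag>(1)   // true: float accepted
            << Prepare<IsBoolTag>(1)    // false: the bool variant rejects float
            << "\n";                    // prints "10"
}
```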
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
index ce4c602ee5c788..b9d7d73c52862d 100644
--- a/tensorflow/contrib/lite/kernels/elementwise_test.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -24,26 +24,40 @@ namespace {
 
 using ::testing::ElementsAreArray;
 
-class ElementWiseOpModel : public SingleOpModel {
+class ElementWiseOpBaseModel : public SingleOpModel {
  public:
-  ElementWiseOpModel(BuiltinOperator op,
-                     std::initializer_list<int> input_shape) {
+  int input() const { return input_; }
+  int output() const { return output_; }
+
+ protected:
+  int input_;
+  int output_;
+};
+
+class ElementWiseOpFloatModel : public ElementWiseOpBaseModel {
+ public:
+  ElementWiseOpFloatModel(BuiltinOperator op,
+                          std::initializer_list<int> input_shape) {
     input_ = AddInput(TensorType_FLOAT32);
     output_ = AddOutput(TensorType_FLOAT32);
     SetBuiltinOp(op, BuiltinOptions_NONE, 0);
     BuildInterpreter({input_shape});
   }
+};
 
-  int input() const { return input_; }
-  int output() const { return output_; }
-
- private:
-  int input_;
-  int output_;
+class ElementWiseOpBoolModel : public ElementWiseOpBaseModel {
+ public:
+  ElementWiseOpBoolModel(BuiltinOperator op,
+                         std::initializer_list<int> input_shape) {
+    input_ = AddInput(TensorType_BOOL);
+    output_ = AddOutput(TensorType_BOOL);
+    SetBuiltinOp(op, BuiltinOptions_NONE, 0);
+    BuildInterpreter({input_shape});
+  }
 };
 
 TEST(ElementWise, Sin) {
-  ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1});
+  ElementWiseOpFloatModel m(BuiltinOperator_SIN, {1, 1, 4, 1});
   m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1});
   m.Invoke();
   EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -52,7 +66,7 @@ TEST(ElementWise, Sin) {
 }
 
 TEST(ElementWise, Log) {
-  ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
+  ElementWiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
   m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1});
   m.Invoke();
   EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -61,7 +75,7 @@ TEST(ElementWise, Log) {
 }
 
 TEST(ElementWise, Sqrt) {
-  ElementWiseOpModel m(BuiltinOperator_SQRT, {1, 1, 4, 1});
+  ElementWiseOpFloatModel m(BuiltinOperator_SQRT, {1, 1, 4, 1});
   m.PopulateTensor<float>(m.input(), {0, 1, 2, 4});
   m.Invoke();
   EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -70,7 +84,7 @@ TEST(ElementWise, Sqrt) {
 }
 
 TEST(ElementWise, Rsqrt) {
-  ElementWiseOpModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1});
+  ElementWiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1});
   m.PopulateTensor<float>(m.input(), {1, 2, 4, 9});
   m.Invoke();
   EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -78,6 +92,15 @@ TEST(ElementWise, Rsqrt) {
   EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
 }
 
+TEST(ElementWise, LogicalNot) {
+  ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
+  m.PopulateTensor<bool>(m.input(), {true, false, true, false});
+  m.Invoke();
+  EXPECT_THAT(m.ExtractVector<bool>(m.output()),
+              ElementsAreArray({false, true, false, true}));
+  EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
index 3dc39bf79a1c05..87c2fee667ccaf 100644
--- a/tensorflow/contrib/lite/kernels/logical.cc
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -105,6 +105,11 @@ TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
   return LogicalImpl(context, node, logical_or_func);
 }
 
+TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
+  const auto logical_and_func = std::logical_and<bool>();
+  return LogicalImpl(context, node, logical_and_func);
+}
+
 }  // namespace
 }  // namespace logical
 
@@ -116,6 +121,14 @@ TfLiteRegistration* Register_LOGICAL_OR() {
   return &r;
 }
 
+TfLiteRegistration* Register_LOGICAL_AND() {
+  // Init, Free, Prepare and Eval satisfy the interface required by
+  // TfLiteRegistration.
+  static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare,
+                                 logical::LogicalAndEval};
+  return &r;
+}
+
 }  // namespace builtin
 }  // namespace ops
 }  // namespace tflite
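`LogicalAndEval` mirrors `LogicalOrEval`: the shared `LogicalImpl` walks the two bool tensors (with broadcasting when the shapes differ) and applies the functor. With the TFLite plumbing and broadcast handling stripped away, the core computation reduces to an elementwise transform; a sketch using the same element values as the `LogicalAnd` test below:

```c++
#include <algorithm>
#include <functional>
#include <iostream>
#include <vector>

int main() {
  // Same values the LogicalAnd test feeds to input1/input2.
  std::vector<bool> input1 = {true, false, false, true};
  std::vector<bool> input2 = {true, false, true, false};
  std::vector<bool> output(input1.size());

  // std::logical_and<bool>() is the functor LogicalAndEval passes to
  // LogicalImpl; here it is applied pairwise, without broadcasting.
  std::transform(input1.begin(), input1.end(), input2.begin(), output.begin(),
                 std::logical_and<bool>());

  for (bool v : output) std::cout << v << " ";  // prints: 1 0 0 0
  std::cout << "\n";
}
```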
diff --git a/tensorflow/contrib/lite/kernels/logical_test.cc b/tensorflow/contrib/lite/kernels/logical_test.cc
index 382008245bf0b0..206cbde98fa48e 100644
--- a/tensorflow/contrib/lite/kernels/logical_test.cc
+++ b/tensorflow/contrib/lite/kernels/logical_test.cc
@@ -52,6 +52,11 @@ class LogicalOpModel : public SingleOpModel {
                      CreateLogicalOrOptions(builder_).Union());
         break;
       }
+      case BuiltinOperator_LOGICAL_AND: {
+        SetBuiltinOp(op, BuiltinOptions_LogicalAndOptions,
+                     CreateLogicalAndOptions(builder_).Union());
+        break;
+      }
       default: { FAIL() << "We shouldn't get here."; }
     }
   }
@@ -77,6 +82,26 @@ TEST(LogicalTest, BroadcastLogicalOr) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
 }
 
+TEST(LogicalTest, LogicalAnd) {
+  LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_AND);
+  model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+  model.PopulateTensor<bool>(model.input2(), {true, false, true, false});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(LogicalTest, BroadcastLogicalAnd) {
+  LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_AND);
+  model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
+  model.PopulateTensor<bool>(model.input2(), {true});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 5ad0f4d23266b6..8d2c108116e166 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -109,6 +109,8 @@ TfLiteRegistration* Register_FAKE_QUANT();
 TfLiteRegistration* Register_PACK();
 TfLiteRegistration* Register_ONE_HOT();
 TfLiteRegistration* Register_LOGICAL_OR();
+TfLiteRegistration* Register_LOGICAL_AND();
+TfLiteRegistration* Register_LOGICAL_NOT();
 
 TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
   context->ReportError(
@@ -228,6 +230,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_PACK, Register_PACK());
   AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
   AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
+  AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
+  AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
 
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
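In `register.cc`, each kernel is exposed through a `Register_*` factory returning a static `TfLiteRegistration`, and `BuiltinOpResolver` keys those registrations by the `BuiltinOperator` enum. A toy model of that wiring, with stub types standing in for the real TFLite ones:

```c++
#include <iostream>
#include <map>
#include <string>

// Stub for TfLiteRegistration; the real struct carries the
// init/free/prepare/eval function pointers seen in the kernels above.
struct Registration {
  std::string debug_name;
};
enum class Builtin { kLogicalOr, kLogicalAnd, kLogicalNot };

Registration* Register_LOGICAL_AND() {
  static Registration r = {"LOGICAL_AND"};  // static: one instance per op
  return &r;
}

int main() {
  // Analogue of AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()).
  std::map<Builtin, Registration*> resolver;
  resolver[Builtin::kLogicalAnd] = Register_LOGICAL_AND();
  std::cout << resolver.at(Builtin::kLogicalAnd)->debug_name << "\n";
}
```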
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 1bbf918fd7d8f4..3d1f8c07d28853 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2997,33 +2997,55 @@ def build_inputs(parameters, sess, inputs, outputs):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
 
+def _make_logical_tests(op):
+  """Make a set of tests to do logical operations."""
+
+  def logical(zip_path):
+    """Generate examples."""
+    test_parameters = [{
+        "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]),
+                             ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+                             ([5, 5], [1]), ([10], [2, 4, 10])],
+    }]
+
+    def build_graph(parameters):
+      """Build the logical testing graph."""
+      input_value1 = tf.placeholder(
+          dtype=tf.bool, name="input1", shape=parameters["input_shape_pair"][0])
+      input_value2 = tf.placeholder(
+          dtype=tf.bool, name="input2", shape=parameters["input_shape_pair"][1])
+      out = op(input_value1, input_value2)
+      return [input_value1, input_value2], [out]
+
+    def build_inputs(parameters, sess, inputs, outputs):
+      input_value1 = create_tensor_data(tf.bool,
+                                        parameters["input_shape_pair"][0])
+      input_value2 = create_tensor_data(tf.bool,
+                                        parameters["input_shape_pair"][1])
+      return [input_value1, input_value2], sess.run(
+          outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+    make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+  return logical
+
+
 def make_logical_or_tests(zip_path):
   """Make a set of tests to do logical_or."""
+  return _make_logical_tests(tf.logical_or)(zip_path)
 
-  test_parameters = [{
-      "input_shape_pair": [([], []), ([1, 1, 1, 3], [1, 1, 1, 3]),
-                           ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
-                           ([5, 5], [1]), ([10], [2, 4, 10])],
-  }]
 
-  def build_graph(parameters):
-    """Build the logical_or op testing graph."""
-    input_value1 = tf.placeholder(
-        dtype=tf.bool, name="input1", shape=parameters["input_shape_pair"][0])
-    input_value2 = tf.placeholder(
-        dtype=tf.bool, name="input2", shape=parameters["input_shape_pair"][1])
-    out = tf.logical_or(input_value1, input_value2)
-    return [input_value1, input_value2], [out]
+def make_logical_and_tests(zip_path):
+  """Make a set of tests to do logical_and."""
+  return _make_logical_tests(tf.logical_and)(zip_path)
 
-  def build_inputs(parameters, sess, inputs, outputs):
-    input_value1 = create_tensor_data(tf.bool,
-                                      parameters["input_shape_pair"][0])
-    input_value2 = create_tensor_data(tf.bool,
-                                      parameters["input_shape_pair"][1])
-    return [input_value1, input_value2], sess.run(
-        outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
 
-  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_logical_xor_tests(zip_path):
+  """Make a set of tests to do logical_xor.
+
+  Test logical_not as well.
+  """
+  return _make_logical_tests(tf.logical_xor)(zip_path)
 
 
 # Toco binary path provided by the generate rule.
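The Python change is a pure factory refactor: `_make_logical_tests(op)` closes over the TensorFlow op and returns the actual generator, so `make_logical_or_tests`, `make_logical_and_tests`, and `make_logical_xor_tests` shrink to one-liners sharing identical shape coverage. The same shape of refactor in C++ terms, purely to illustrate the pattern (nothing below is TFLite code):

```c++
#include <functional>
#include <iostream>
#include <string>

using BoolOp = std::function<bool(bool, bool)>;
using TestGenerator = std::function<void(const std::string&)>;

// Factory: captures the op and returns the generator, as _make_logical_tests
// captures `op` and returns `logical`.
TestGenerator MakeLogicalTests(const std::string& name, BoolOp op) {
  return [name, op](const std::string& zip_path) {
    // Stand-in for building graphs/inputs and writing the zip of tests.
    std::cout << name << " -> " << zip_path
              << ", op(true, false) = " << op(true, false) << "\n";
  };
}

int main() {
  auto make_logical_and_tests =
      MakeLogicalTests("logical_and", [](bool a, bool b) { return a && b; });
  auto make_logical_xor_tests =
      MakeLogicalTests("logical_xor", [](bool a, bool b) { return a != b; });
  make_logical_and_tests("and.zip");  // op(true, false) = 0
  make_logical_xor_tests("xor.zip");  // op(true, false) = 1
}
```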
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 4ece561e97643a..9ff89e9a653173 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1381,6 +1381,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
   ops.emplace_back(new SimpleOperator<PowOperator>("POW", OperatorType::kPow));
   ops.emplace_back(new SimpleOperator<LogicalOrOperator>(
       "LOGICAL_OR", OperatorType::kLogicalOr));
+  ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
+      "LOGICAL_AND", OperatorType::kLogicalAnd));
+  ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
+      "LOGICAL_NOT", OperatorType::kLogicalNot));
   // Element-wise operator
   ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
   ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 12fdbbf214a2b2..fc854461b4e816 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -129,6 +129,10 @@ TEST_F(OperatorTest, SimpleOperators) {
   CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow);
   CheckSimpleOperator<LogicalOrOperator>("LOGICAL_OR",
                                          OperatorType::kLogicalOr);
+  CheckSimpleOperator<LogicalAndOperator>("LOGICAL_AND",
+                                          OperatorType::kLogicalAnd);
+  CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
+                                          OperatorType::kLogicalNot);
 }
 
 TEST_F(OperatorTest, BuiltinAdd) {