Implementation of logical_and and logical_not
PiperOrigin-RevId: 207642985
tensorflower-gardener committed Aug 7, 2018
1 parent 56a82b0 commit 0fc1de7
Showing 9 changed files with 200 additions and 64 deletions.
2 changes: 2 additions & 0 deletions tensorflow/contrib/lite/build_def.bzl
@@ -247,7 +247,9 @@ def generated_test_models():
"local_response_norm",
"log_softmax",
"log",
"logical_and",
"logical_or",
"logical_xor",
"lstm",
"max_pool",
"maximum",
99 changes: 69 additions & 30 deletions tensorflow/contrib/lite/kernels/elementwise.cc
@@ -22,79 +22,118 @@ namespace tflite {
namespace ops {
namespace builtin {
namespace elementwise {
namespace {

bool IsNumericSupportedType(const TfLiteType type) {
return type == kTfLiteFloat32;
}

bool IsLogicalSupportedType(const TfLiteType type) {
return type == kTfLiteBool;
}

typedef bool (*IsSupportedType)(TfLiteType);
template <IsSupportedType is_supported_type>
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, output->type);
// Quantized float is not supported yet.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
if (!is_supported_type(input->type)) {
context->ReportError(context, "Current data type %d is not supported.",
input->type);
return kTfLiteError;
}
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
}
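
The type predicate travels as a non-type template parameter, so each Register_* call at the bottom of the file stamps out a GenericPrepare instantiation with its validity check fixed at compile time. Note that the parameter has to be named (is_supported_type above): with an unnamed parameter, IsSupportedType(input->type) in the body would resolve to the typedef and compile as a functional-style cast rather than a call to the supplied predicate. A standalone sketch of the pattern, with illustrative names outside TFLite:

#include <cstdio>

enum Type { kFloat32 = 1, kBool = 6 };

bool IsNumeric(Type t) { return t == kFloat32; }
bool IsLogical(Type t) { return t == kBool; }

typedef bool (*Predicate)(Type);

// One shared body; each instantiation bakes in its own check.
template <Predicate is_supported>
bool Prepare(Type input_type) {
  if (!is_supported(input_type)) {
    std::printf("type %d is not supported\n", input_type);
    return false;
  }
  return true;
}

int main() {
  Prepare<IsNumeric>(kFloat32);  // check passes
  Prepare<IsLogical>(kFloat32);  // takes the error path
  return 0;
}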

inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node,
float float_func(float)) {
template <typename T>
inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
T func(T), TfLiteType expected_type) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32: {
size_t elements = NumElements(input);
const float* in = GetTensorData<float>(input);
const float* in_end = in + elements;
float* out = output->data.f;
for (; in < in_end; in++, out++) *out = float_func(*in);
return kTfLiteOk;
}
default: {
context->ReportError(context, "Input type is %d, requires float32",
input->type);
return kTfLiteError;
}
TF_LITE_ENSURE_EQ(context, input->type, expected_type);
const int64_t num_elements = NumElements(input);
const T* in_data = GetTensorData<T>(input);
T* out_data = GetTensorData<T>(output);
for (int64_t i = 0; i < num_elements; ++i) {
out_data[i] = func(in_data[i]);
}
return kTfLiteOk;
}

inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
float float_func(float)) {
return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
}

inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
bool bool_func(bool)) {
return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
}
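
EvalImpl replaces the old float-only switch with a single type-generic loop; EvalNumeric and EvalLogical are thin wrappers that pin down the element type and the TfLiteType tag it must match. A minimal sketch of the same idea over plain vectors (illustrative, not the TFLite API):

#include <cmath>
#include <cstddef>
#include <vector>

// One loop body shared by every unary op; element type and op vary
// per instantiation, mirroring EvalImpl<T>.
template <typename T>
void Apply(const std::vector<T>& in, std::vector<T>& out, T func(T)) {
  out.resize(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) out[i] = func(in[i]);
}

int main() {
  std::vector<float> x = {1.f, 4.f, 9.f}, y;
  Apply<float>(x, y, [](float v) { return std::sqrt(v); });  // numeric path

  std::vector<bool> b = {true, false}, nb;
  Apply<bool>(b, nb, [](bool v) { return !v; });  // logical path
  return 0;
}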

TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
return Eval(context, node, std::sin);
return EvalNumeric(context, node, std::sin);
}

TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
return Eval(context, node, std::log);
return EvalNumeric(context, node, std::log);
}

TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
return Eval(context, node, std::sqrt);
return EvalNumeric(context, node, std::sqrt);
}

TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
return Eval(context, node, [](float f) { return 1.f / std::sqrt(f); });
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
}

TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
return EvalLogical(context, node, [](bool v) { return !v; });
}

} // namespace
} // namespace elementwise

TfLiteRegistration* Register_SIN() {
static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
elementwise::SinEval};
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SinEval};
return &r;
}

TfLiteRegistration* Register_LOG() {
static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
elementwise::LogEval};
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::LogEval};
return &r;
}

TfLiteRegistration* Register_SQRT() {
static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
elementwise::SqrtEval};
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::SqrtEval};
return &r;
}

TfLiteRegistration* Register_RSQRT() {
static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
elementwise::RsqrtEval};
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
elementwise::RsqrtEval};
return &r;
}

TfLiteRegistration* Register_LOGICAL_NOT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
elementwise::LogicalNotEval};
return &r;
}
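
The /*init=*/ and /*free=*/ markers document positional brace-initialization: a TfLiteRegistration is essentially a bundle of function pointers, and since these kernels keep no per-node state, only the prepare and invoke slots are populated. Roughly the shape being filled in — a simplified, illustrative mirror; the real struct also carries profiling and identification fields such as builtin_code and version:

#include <cstddef>

struct TfLiteContext;  // opaque for this sketch
struct TfLiteNode;     // opaque for this sketch
enum TfLiteStatus { kTfLiteOk, kTfLiteError };

struct Registration {
  void* (*init)(TfLiteContext*, const char* buffer, std::size_t length);
  void (*free)(TfLiteContext*, void* buffer);
  TfLiteStatus (*prepare)(TfLiteContext*, TfLiteNode*);  // GenericPrepare<...> goes here
  TfLiteStatus (*invoke)(TfLiteContext*, TfLiteNode*);   // the *Eval functions go here
};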

49 changes: 36 additions & 13 deletions tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -24,26 +24,40 @@ namespace {

using ::testing::ElementsAreArray;

class ElementWiseOpModel : public SingleOpModel {
class ElementWiseOpBaseModel : public SingleOpModel {
public:
ElementWiseOpModel(BuiltinOperator op,
std::initializer_list<int> input_shape) {
int input() const { return input_; }
int output() const { return output_; }

protected:
int input_;
int output_;
};

class ElementWiseOpFloatModel : public ElementWiseOpBaseModel {
public:
ElementWiseOpFloatModel(BuiltinOperator op,
std::initializer_list<int> input_shape) {
input_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(op, BuiltinOptions_NONE, 0);
BuildInterpreter({input_shape});
}
};

int input() const { return input_; }
int output() const { return output_; }

private:
int input_;
int output_;
class ElementWiseOpBoolModel : public ElementWiseOpBaseModel {
public:
ElementWiseOpBoolModel(BuiltinOperator op,
std::initializer_list<int> input_shape) {
input_ = AddInput(TensorType_BOOL);
output_ = AddOutput(TensorType_BOOL);
SetBuiltinOp(op, BuiltinOptions_NONE, 0);
BuildInterpreter({input_shape});
}
};

TEST(ElementWise, Sin) {
ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1});
ElementWiseOpFloatModel m(BuiltinOperator_SIN, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -52,7 +66,7 @@ TEST(ElementWise, Sin) {
}

TEST(ElementWise, Log) {
ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
ElementWiseOpFloatModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -61,7 +75,7 @@ TEST(ElementWise, Log) {
}

TEST(ElementWise, Sqrt) {
ElementWiseOpModel m(BuiltinOperator_SQRT, {1, 1, 4, 1});
ElementWiseOpFloatModel m(BuiltinOperator_SQRT, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {0, 1, 2, 4});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -70,14 +84,23 @@ TEST(ElementWise, Sqrt) {
}

TEST(ElementWise, Rsqrt) {
ElementWiseOpModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1});
ElementWiseOpFloatModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {1, 2, 4, 9});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
ElementsAreArray(ArrayFloatNear({1, 0.7071, 0.5, 0.33333})));
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}

TEST(ElementWise, LogicalNot) {
ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
m.PopulateTensor<bool>(m.input(), {true, false, true, false});
m.Invoke();
EXPECT_THAT(m.ExtractVector<bool>(m.output()),
ElementsAreArray({false, true, false, true}));
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}

} // namespace
} // namespace tflite
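
The test refactor mirrors the kernel split: ElementWiseOpBaseModel owns the tensor indices, and each concrete model only chooses the tensor type. A further element type would slot in the same way; for instance, a hypothetical int32 variant (not part of this commit):

class ElementWiseOpInt32Model : public ElementWiseOpBaseModel {
 public:
  ElementWiseOpInt32Model(BuiltinOperator op,
                          std::initializer_list<int> input_shape) {
    input_ = AddInput(TensorType_INT32);
    output_ = AddOutput(TensorType_INT32);
    SetBuiltinOp(op, BuiltinOptions_NONE, 0);
    BuildInterpreter({input_shape});
  }
};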

13 changes: 13 additions & 0 deletions tensorflow/contrib/lite/kernels/logical.cc
@@ -105,6 +105,11 @@ TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
return LogicalImpl(context, node, logical_or_func);
}

TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
const auto logical_and_func = std::logical_and<bool>();
return LogicalImpl(context, node, logical_and_func);
}

} // namespace
} // namespace logical

@@ -116,6 +121,14 @@ TfLiteRegistration* Register_LOGICAL_OR() {
return &r;
}

TfLiteRegistration* Register_LOGICAL_AND() {
// Init, Free, Prepare, and Eval satisfy the interface required by
// TfLiteRegistration.
static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare,
logical::LogicalAndEval};
return &r;
}

} // namespace builtin
} // namespace ops
} // namespace tflite
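
LogicalAndEval reuses the shared LogicalImpl plumbing, only swapping std::logical_and<bool> in for std::logical_or<bool>. A hedged sketch of what that reduces to, including the one-element broadcast that the BroadcastLogicalAnd test below exercises (real TFLite broadcasting is rank-aware and more general):

#include <cstddef>
#include <functional>
#include <vector>

std::vector<bool> LogicalAnd(const std::vector<bool>& a,
                             const std::vector<bool>& b) {
  const auto func = std::logical_and<bool>();
  const std::size_t n = a.size() > b.size() ? a.size() : b.size();
  std::vector<bool> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    // Indexing modulo size handles the single-element broadcast case.
    out[i] = func(a[i % a.size()], b[i % b.size()]);
  }
  return out;
}

int main() {
  LogicalAnd({true, false, false, true}, {true});  // -> {true, false, false, true}
  return 0;
}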
25 changes: 25 additions & 0 deletions tensorflow/contrib/lite/kernels/logical_test.cc
@@ -52,6 +52,11 @@ class LogicalOpModel : public SingleOpModel {
CreateLogicalOrOptions(builder_).Union());
break;
}
case BuiltinOperator_LOGICAL_AND: {
SetBuiltinOp(op, BuiltinOptions_LogicalAndOptions,
CreateLogicalAndOptions(builder_).Union());
break;
}
default: { FAIL() << "We shouldn't get here."; }
}
}
@@ -77,6 +82,26 @@ TEST(LogicalTest, BroadcastLogicalOr) {
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}

TEST(LogicalTest, LogicalAnd) {
LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, BuiltinOperator_LOGICAL_AND);
model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
model.PopulateTensor<bool>(model.input2(), {true, false, true, false});
model.Invoke();

EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false));
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}

TEST(LogicalTest, BroadcastLogicalAnd) {
LogicalOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, BuiltinOperator_LOGICAL_AND);
model.PopulateTensor<bool>(model.input1(), {true, false, false, true});
model.PopulateTensor<bool>(model.input2(), {true});
model.Invoke();

EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}

} // namespace
} // namespace tflite

4 changes: 4 additions & 0 deletions tensorflow/contrib/lite/kernels/register.cc
@@ -109,6 +109,8 @@ TfLiteRegistration* Register_FAKE_QUANT();
TfLiteRegistration* Register_PACK();
TfLiteRegistration* Register_ONE_HOT();
TfLiteRegistration* Register_LOGICAL_OR();
TfLiteRegistration* Register_LOGICAL_AND();
TfLiteRegistration* Register_LOGICAL_NOT();

TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@@ -228,6 +230,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_PACK, Register_PACK());
AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());

// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
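
With the AddBuiltin calls in place, any interpreter constructed with BuiltinOpResolver can execute graphs containing LOGICAL_AND and LOGICAL_NOT. A hedged end-to-end sketch using the standard (contrib-era) TFLite flow; the model path is illustrative:

#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

int main() {
  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;  // now resolves the logical ops
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  interpreter->AllocateTensors();
  // ... write bool data into the input tensors, then:
  interpreter->Invoke();
  return 0;
}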
