Skip to content

Commit

Permalink
Added Activation Scenarios in the file.
Browse files Browse the repository at this point in the history
This was one of the TODOs in the file.
  • Loading branch information
Amit Srivastava committed Mar 15, 2019
1 parent 6d3d026 commit fda15fd
Showing 1 changed file with 93 additions and 1 deletion.
94 changes: 93 additions & 1 deletion tensorflow/lite/kernels/fully_connected_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ static float fully_connected_golden_output[] = {

class BaseFullyConnectedOpModel : public SingleOpModel {
public:
// TODO(ahentz): test different activation types too.
BaseFullyConnectedOpModel(
TfLiteRegistration* registration, int units, int batches,
const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
Expand Down Expand Up @@ -428,6 +427,99 @@ TEST(FloatFullyConnectedOpTest, SimpleTestNoBias) {
EXPECT_THAT(m.GetOutput(), ElementsAre(10, 8));
}

TEST(FloatFullyConnectedOpTest, ActivationRelu6) {
// The optimized kernel assumes that the bias is specified.
FloatFullyConnectedOpModel m(
ops::builtin::Register_FULLY_CONNECTED_PIE(),
/*units=*/1, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 2}},
/*output=*/{TensorType_FLOAT32},
/*bias_tensor_optional=*/false,
/*ActivationFunctionType*/ ActivationFunctionType_RELU6);
m.SetWeights({
2, 4, // u = 0
});

m.SetInput({
1, 2, // b = 0
2, 1, // b = 1
});

m.Invoke();

EXPECT_THAT(m.GetOutput(), ElementsAre(6, 6));
}

TEST(FloatFullyConnectedOpTest, ActivationTanh) {
// The optimized kernel assumes that the bias is specified.
FloatFullyConnectedOpModel m(
ops::builtin::Register_FULLY_CONNECTED_PIE(),
/*units=*/1, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 2}},
/*output=*/{TensorType_FLOAT32},
/*bias_tensor_optional=*/false,
/*ActivationFunctionType*/ ActivationFunctionType_TANH);
m.SetWeights({
-2, 1, // u = 0
});

m.SetInput({
1, 4, // b = 0
2, 1, // b = 1
});

m.Invoke();

EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear({0.964028, -0.995055})));
}

TEST(FloatFullyConnectedOpTest, ActivationSign) {
FloatFullyConnectedOpModel m(
ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT(),
/*units=*/1, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 2}},
/*output=*/{TensorType_FLOAT32},
/*bias_tensor_optional=*/false,
/*ActivationFunctionType*/ ActivationFunctionType_SIGN_BIT);
m.SetWeights({
2, 4, // u = 0
});
m.SetBias({1});

m.SetInput({
1, -2, // b = 0
-2, 1, // b = 1
});

m.Invoke();

EXPECT_THAT(m.GetOutput(), ElementsAre(-5, 1));
}

TEST(FloatFullyConnectedOpTest, ActivationN1) {
FloatFullyConnectedOpModel m(
ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT(),
/*units=*/1, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 2}},
/*output=*/{TensorType_FLOAT32},
/*bias_tensor_optional=*/false,
/*ActivationFunctionType*/ ActivationFunctionType_RELU_N1_TO_1);
m.SetWeights({
2, 4, // u = 0
});
m.SetBias({1});

m.SetInput({
1, -2, // b = 0
-2, 1, // b = 1
});

m.Invoke();

EXPECT_THAT(m.GetOutput(), ElementsAre(-1, 1));
}

TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) {
QuantizedFullyConnectedOpModel m(
GetRegistration(), /*units=*/3, /*batches*/ 2,
Expand Down

0 comments on commit fda15fd

Please sign in to comment.