Skip to content

Commit

Permalink
Merge pull request #26737 from amitsrivastava78:fully
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 238734455
  • Loading branch information
tensorflower-gardener committed Mar 15, 2019
2 parents d8c8f77 + fda15fd commit e117200
Showing 1 changed file with 93 additions and 1 deletion.
94 changes: 93 additions & 1 deletion tensorflow/lite/kernels/fully_connected_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ static float fully_connected_golden_output[] = {

class BaseFullyConnectedOpModel : public SingleOpModel {
public:
// TODO(ahentz): test different activation types too.
BaseFullyConnectedOpModel(
TfLiteRegistration* registration, int units, int batches,
const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
Expand Down Expand Up @@ -428,6 +427,99 @@ TEST(FloatFullyConnectedOpTest, SimpleTestNoBias) {
EXPECT_THAT(m.GetOutput(), ElementsAre(10, 8));
}

TEST(FloatFullyConnectedOpTest, ActivationRelu6) {
// The optimized kernel assumes that the bias is specified.
FloatFullyConnectedOpModel m(
ops::builtin::Register_FULLY_CONNECTED_PIE(),
/*units=*/1, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 2}},
/*output=*/{TensorType_FLOAT32},
/*bias_tensor_optional=*/false,
/*ActivationFunctionType*/ ActivationFunctionType_RELU6);
m.SetWeights({
2, 4, // u = 0
});

m.SetInput({
1, 2, // b = 0
2, 1, // b = 1
});

m.Invoke();

EXPECT_THAT(m.GetOutput(), ElementsAre(6, 6));
}

TEST(FloatFullyConnectedOpTest, ActivationTanh) {
// The optimized kernel assumes that the bias is specified.
FloatFullyConnectedOpModel m(
ops::builtin::Register_FULLY_CONNECTED_PIE(),
/*units=*/1, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 2}},
/*output=*/{TensorType_FLOAT32},
/*bias_tensor_optional=*/false,
/*ActivationFunctionType*/ ActivationFunctionType_TANH);
m.SetWeights({
-2, 1, // u = 0
});

m.SetInput({
1, 4, // b = 0
2, 1, // b = 1
});

m.Invoke();

EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear({0.964028, -0.995055})));
}

TEST(FloatFullyConnectedOpTest, ActivationSign) {
FloatFullyConnectedOpModel m(
ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT(),
/*units=*/1, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 2}},
/*output=*/{TensorType_FLOAT32},
/*bias_tensor_optional=*/false,
/*ActivationFunctionType*/ ActivationFunctionType_SIGN_BIT);
m.SetWeights({
2, 4, // u = 0
});
m.SetBias({1});

m.SetInput({
1, -2, // b = 0
-2, 1, // b = 1
});

m.Invoke();

EXPECT_THAT(m.GetOutput(), ElementsAre(-5, 1));
}

TEST(FloatFullyConnectedOpTest, ActivationN1) {
FloatFullyConnectedOpModel m(
ops::builtin::Register_FULLY_CONNECTED_GENERIC_OPT(),
/*units=*/1, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {2, 2}},
/*output=*/{TensorType_FLOAT32},
/*bias_tensor_optional=*/false,
/*ActivationFunctionType*/ ActivationFunctionType_RELU_N1_TO_1);
m.SetWeights({
2, 4, // u = 0
});
m.SetBias({1});

m.SetInput({
1, -2, // b = 0
-2, 1, // b = 1
});

m.Invoke();

EXPECT_THAT(m.GetOutput(), ElementsAre(-1, 1));
}

TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) {
QuantizedFullyConnectedOpModel m(
GetRegistration(), /*units=*/3, /*batches*/ 2,
Expand Down

0 comments on commit e117200

Please sign in to comment.