Add broadcasting functionality for Div and Sub ops. #17123

Merged
merged 9 commits on Apr 12, 2018
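The diff below wires broadcasting into the TFLite Div and Sub kernels. As a quick illustration of what that means (a minimal standalone sketch, not the TFLite code; the arrays and shapes here are invented for the example), broadcasting lets a size-1 dimension in one input be reused against every index of the matching dimension in the other input:

```cpp
#include <cstdio>

// Standalone sketch of the broadcasting rule: input2's leading dimension
// has size 1, so its single row is reused (broadcast) against every row
// of input1.
int main() {
  const float input1[2][3] = {{2.f, 4.f, 6.f}, {8.f, 10.f, 12.f}};  // shape [2,3]
  const float input2[1][3] = {{2.f, 2.f, 3.f}};                     // shape [1,3]
  float output[2][3];
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 3; ++c) {
      // Index 0 on input2's first dimension is the broadcast.
      output[r][c] = input1[r][c] / input2[0][c];
      std::printf("%g ", output[r][c]);
    }
    std::printf("\n");
  }
  return 0;  // prints: 1 2 2 / 4 5 4
}
```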
5 changes: 4 additions & 1 deletion tensorflow/contrib/lite/kernels/div.cc
@@ -106,6 +106,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_DIV
}



template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
@@ -118,7 +120,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (output->type == kTfLiteFloat32) {
EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
} else {
context->ReportError(context, "Inputs and outputs not all float types.");
context->ReportError(context,
"Div only supports FLOAT32 and quantized UINT8 now.");
return kTfLiteError;
}

@@ -3938,7 +3938,7 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;

gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
39 changes: 27 additions & 12 deletions tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -1255,6 +1255,33 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
output_data, output_dims);
}

inline void Div(const float* input1_data, const Dims<4>& input1_dims,
const float* input2_data, const Dims<4>& input2_dims,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims) {
const int batches =
MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
const int height =
MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
const int width =
MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
const int depth =
MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
for (int b = 0; b < batches; ++b) {
for (int y = 0; y < height; ++y) {
for (int x = 0; x < width; ++x) {
for (int c = 0; c < depth; ++c) {
output_data[Offset(output_dims, c, x, y, b)] =
ActivationFunctionWithMinMax(
input1_data[Offset(input1_dims, c, x, y, b)] /
input2_data[Offset(input2_dims, c, x, y, b)],
output_activation_min, output_activation_max);
}
}
}
}
}
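For context, the loop above indexes through MatchingArraySize and Offset. A rough standalone model of those helpers (an assumption for illustration; the real definitions live in TFLite's internal headers and may differ in detail) is that each Dims<4> carries per-dimension strides, Offset is a dot product of coordinates with strides, and MatchingArraySize checks that a given dimension agrees across all operands before returning the common extent:

```cpp
#include <cassert>

// Illustrative model only -- not the TFLite source.
struct Dims4 {
  int sizes[4];    // dimension extents, innermost first
  int strides[4];  // element stride for each dimension
};

// Flat buffer offset: dot product of coordinates and strides.
int Offset(const Dims4& d, int i0, int i1, int i2, int i3) {
  return i0 * d.strides[0] + i1 * d.strides[1] +
         i2 * d.strides[2] + i3 * d.strides[3];
}

// The named dimension must agree across all three operands; return it.
int MatchingArraySize(const Dims4& a, int da, const Dims4& b, int db,
                      const Dims4& c, int dc) {
  assert(a.sizes[da] == b.sizes[db] && b.sizes[db] == c.sizes[dc]);
  return a.sizes[da];
}
```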

// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
@@ -1296,18 +1323,6 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
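The BroadcastDiv body is collapsed in the diff above. In the spirit of the TODO, here is a hedged standalone sketch (assumptions: dense row-major [d0,d1,d2,d3] buffers, and shapes already validated as broadcast-compatible, i.e. each pair of extents is equal or one of them is 1) of division where a size-1 input dimension pins its index back to 0:

```cpp
// Standalone sketch, not the TFLite kernel. A size-1 extent makes the
// modulo collapse that input's index to 0, which is the broadcast.
void BroadcastDivSketch(const float* in1, const int s1[4],
                        const float* in2, const int s2[4],
                        float* out, const int so[4]) {
  // Row-major flat index for a [d0,d1,d2,d3] buffer.
  auto flat = [](const int s[4], int i0, int i1, int i2, int i3) {
    return ((i0 * s[1] + i1) * s[2] + i2) * s[3] + i3;
  };
  for (int a = 0; a < so[0]; ++a)
    for (int b = 0; b < so[1]; ++b)
      for (int c = 0; c < so[2]; ++c)
        for (int d = 0; d < so[3]; ++d)
          out[flat(so, a, b, c, d)] =
              in1[flat(s1, a % s1[0], b % s1[1], c % s1[2], d % s1[3])] /
              in2[flat(s2, a % s2[0], b % s2[1], c % s2[2], d % s2[3])];
}
```

The real kernels reach the same effect with precomputed stride descriptors rather than per-element modulo, which keeps the inner loop free of integer division.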

inline void Div(const float* input1_data, const Dims<4>& input1_dims,
const float* input2_data, const Dims<4>& input2_dims,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims) {
const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] / input2_data[i], output_activation_min,
output_activation_max);
}
}

inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
const float* input2_data, const Dims<4>& input2_dims,
float output_activation_min, float output_activation_max,
3 changes: 2 additions & 1 deletion tensorflow/contrib/lite/kernels/sub.cc
@@ -174,7 +174,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
output);
} else {
context->ReportError(context, "Inputs and outputs not all float types.");
context->ReportError(context,
"Inputs and outputs not all float|unit8 types.");
return kTfLiteError;
}
