Merge pull request #18896 from KikaTech/fix_lite_topk
Fix some issues in the official tf.nn.top_k() in TensorFlow Lite
ekelsen committed Apr 26, 2018
2 parents 1bf9ec7 + afa4748 commit d0b9613
Showing 6 changed files with 19 additions and 11 deletions.
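
What the commit fixes: tf.nn.top_k in TensorFlow returns values as its first output and indices as its second, but the Lite kernel, its test, and the toco converter had the two outputs swapped. After this change a client reads TopKV2 outputs in TensorFlow's order. A minimal sketch of that contract, assuming a prepared tflite::Interpreter whose single op is TOPK_V2 with a float input; the function name is illustrative:

#include <cstdint>
#include <cstdio>
#include "tensorflow/contrib/lite/interpreter.h"

// Reads TopKV2 results in the order this commit establishes:
// output 0 = values (same type as the input), output 1 = int32 indices.
void ReadTopKOutputs(tflite::Interpreter* interpreter, int k) {
  const float* values = interpreter->typed_output_tensor<float>(0);
  const int32_t* indexes = interpreter->typed_output_tensor<int32_t>(1);
  for (int i = 0; i < k; ++i) {
    // values[i] is the i-th largest input element; indexes[i] is its position.
    std::printf("%d: %f\n", indexes[i], values[i]);
  }
}
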
4 changes: 2 additions & 2 deletions tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -25,8 +25,8 @@ namespace builtin {
 namespace topk_v2 {
 constexpr int kInputTensor = 0;
 constexpr int kInputTopK = 1;
-constexpr int kOutputIndexes = 0;
-constexpr int kOutputValues = 1;
+constexpr int kOutputValues = 0;
+constexpr int kOutputIndexes = 1;

 namespace {
 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
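
Swapping the two constants is the whole kernel fix because the output tensors are resolved through them. A sketch of the lookup pattern, assuming the standard GetOutput helper from kernel_util.h used by the builtin kernels; Example is a stand-in name, not the kernel's real entry point:

#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"

// With kOutputValues = 0 and kOutputIndexes = 1, the values tensor now
// occupies output slot 0 and the indices tensor slot 1.
TfLiteStatus Example(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output_values =
      tflite::GetOutput(context, node, /*kOutputValues=*/0);
  TfLiteTensor* output_indexes =
      tflite::GetOutput(context, node, /*kOutputIndexes=*/1);
  // ... fill output_values with the top-k values and output_indexes
  // with their source positions ...
  return kTfLiteOk;
}
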
2 changes: 1 addition & 1 deletion tensorflow/contrib/lite/kernels/topk_v2_test.cc
@@ -31,8 +31,8 @@ class TopKV2OpModel : public SingleOpModel {
                int top_k) {
     input_ = AddInput(input_type);
     top_k_ = AddInput(TensorType_INT32);
-    output_indexes_ = AddOutput(TensorType_INT32);
     output_values_ = AddOutput(input_type);
+    output_indexes_ = AddOutput(TensorType_INT32);
     SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0);
     BuildInterpreter({input_shape, {1}});
     PopulateTensor<int32_t>(top_k_, {top_k});
9 changes: 9 additions & 0 deletions tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -124,6 +124,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
       SetDataTypeForAllOutputs(model, op, rand_op->dtype);
       break;
     }
+    case OperatorType::kTopK_V2: {
+      // topk(values: T, k: int32) -> values: T, indices: int32
+      CHECK_EQ(op->inputs.size(), 2);
+      CHECK_EQ(op->outputs.size(), 2);
+      CHECK(model->GetArray(op->inputs[1]).data_type == ArrayDataType::kInt32);
+      model->GetArray(op->outputs[0]).data_type = model->GetArray(op->inputs[0]).data_type;
+      model->GetArray(op->outputs[1]).data_type = ArrayDataType::kInt32;
+      break;
+    }
     case OperatorType::kTensorFlowUnsupported: {
       auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
       // Some output tensors from the op could be eliminated by optimization.
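
The new case gives toco's type propagation an explicit rule for TopKV2: the values output inherits the input element type, and the indices output is always int32. Restated as a hypothetical standalone helper; the enum is a stand-in for toco's ArrayDataType, not library code:

// TopKV2(values: T, k: int32) -> (values: T, indices: int32).
enum class DType { kFloat, kUint8, kInt32 };  // stand-in for ArrayDataType

DType TopKV2OutputType(int output_index, DType input_type) {
  // Output 0 carries the selected values, output 1 their positions.
  return output_index == 0 ? input_type : DType::kInt32;
}
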
4 changes: 2 additions & 2 deletions tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1086,8 +1086,8 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) {
 void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
   const auto& input_values = model->GetArray(op->inputs[0]);
   const auto& input_k = model->GetArray(op->inputs[1]);
-  auto& output_indexes = model->GetArray(op->outputs[0]);
-  auto& output_values = model->GetArray(op->outputs[1]);
+  auto& output_values = model->GetArray(op->outputs[0]);
+  auto& output_indexes = model->GetArray(op->outputs[1]);

   // Bail if we already know the output shape.
   if (output_indexes.has_shape()) {
2 changes: 1 addition & 1 deletion tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1981,7 +1981,7 @@ void ConvertTopKV2Operator(const NodeDef& node,
     op->inputs.push_back(node.input(1));
   }
   // The op has two outputs.
-  op->outputs.push_back(node.name() + ":0");
+  op->outputs.push_back(node.name());
   op->outputs.push_back(node.name() + ":1");
   model->operators.emplace_back(op.release());
 }
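
Context for the one-line importer change: in a GraphDef, consumers reference a producer's first output by the bare node name (the ":0" suffix is implicit), which is presumably why the importer now records output 0 under the bare name so toco's string matching can connect the graph. A hypothetical helper expressing the convention:

#include <string>

// GraphDef naming convention: output 0 of node "topk" is referenced as
// "topk" (":0" is implicit), while output i > 0 is "topk:i".
std::string TensorName(const std::string& node_name, int output_index) {
  if (output_index == 0) return node_name;
  return node_name + ":" + std::to_string(output_index);
}
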
9 changes: 4 additions & 5 deletions tensorflow/contrib/lite/toco/tooling_util.cc
@@ -825,11 +825,6 @@ void FixNoOrphanedArray(Model* model) {
 void CheckEachArray(const Model& model) {
   for (const auto& array_entry : model.GetArrayMap()) {
     const auto& array = array_entry.second;
-    if (array->has_shape()) {
-      for (int d : array->shape().dims()) {
-        CHECK_GE(d, 1);
-      }
-    }
     // It's OK to have a buffer or an alloc, but not both.
     // (Since allocs are for transient arrays without a buffer).
     CHECK(!array->buffer || !array->alloc);
@@ -839,6 +834,10 @@
       // The presence of a fixed buffer should imply the presence of a fixed
       // shape.
       CHECK(array->has_shape());
+      // A constant buffer should have a valid shape.
+      for (int d : array->shape().dims()) {
+        CHECK_GE(d, 1);
+      }
       // The shape flat-size should agree with the buffer length.
       CHECK_EQ(array->buffer->Length(),
                RequiredBufferSizeForShape(array->shape()));
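
Net effect of the two hunks: the every-dimension-at-least-one check, previously applied to any array that had a shape, now runs only for arrays backed by a constant buffer, so shaped transient arrays no longer trip it. A sketch of the resulting invariant as a hypothetical predicate over toco's Array type; this is not library code:

#include "tensorflow/contrib/lite/toco/model.h"

// Invariant after this commit: only constant-buffer arrays must have a
// shape whose dimensions are all >= 1.
bool ConstantArrayIsWellFormed(const toco::Array& array) {
  if (!array.buffer) return true;  // transient arrays are exempt
  if (!array.has_shape()) return false;
  for (int d : array.shape().dims()) {
    if (d < 1) return false;
  }
  return true;
}
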
