Fix some issues in official tf.nn.topk() in lite #18896

Merged 1 commit on Apr 26, 2018

tensorflow/contrib/lite/kernels/topk_v2.cc (2 additions, 2 deletions)

@@ -25,8 +25,8 @@ namespace builtin {
 namespace topk_v2 {
 constexpr int kInputTensor = 0;
 constexpr int kInputTopK = 1;
-constexpr int kOutputIndexes = 0;
-constexpr int kOutputValues = 1;
+constexpr int kOutputValues = 0;
+constexpr int kOutputIndexes = 1;

 namespace {
 TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
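
With the constants swapped, output 0 now carries the top-k values and output 1 their indices, matching tf.nn.top_k's (values, indices) return order. As a self-contained illustration of that contract (plain C++, independent of the TFLite kernel API; TopK here is a hypothetical stand-in, not the kernel's implementation):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <utility>
    #include <vector>

    // Returns {values, indexes} for the k largest entries, values first,
    // mirroring the output order the kernel now exposes.
    std::pair<std::vector<float>, std::vector<int32_t>> TopK(
        const std::vector<float>& input, int k) {
      std::vector<int32_t> idx(input.size());
      std::iota(idx.begin(), idx.end(), 0);
      std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                        [&](int32_t a, int32_t b) { return input[a] > input[b]; });
      idx.resize(k);
      std::vector<float> values(k);
      for (int i = 0; i < k; ++i) values[i] = input[idx[i]];
      return {values, idx};
    }
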
tensorflow/contrib/lite/kernels/topk_v2_test.cc (1 addition, 1 deletion)

@@ -31,8 +31,8 @@ class TopKV2OpModel : public SingleOpModel {
                 int top_k) {
     input_ = AddInput(input_type);
     top_k_ = AddInput(TensorType_INT32);
-    output_indexes_ = AddOutput(TensorType_INT32);
     output_values_ = AddOutput(input_type);
+    output_indexes_ = AddOutput(TensorType_INT32);
     SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0);
     BuildInterpreter({input_shape, {1}});
     PopulateTensor<int32_t>(top_k_, {top_k});
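
The declaration order matters because AddOutput assigns output slots in call order, so the test model must now register the values tensor as output 0. A hedged sketch of a test exercising that order (the helper names SetInput, GetValues, and GetIndexes are illustrative, not necessarily the file's actual accessors):

    // Illustrative only: TopKV2OpModel's real helper names may differ.
    TEST(TopKV2OpTest, ValuesFirstThenIndexes) {
      TopKV2OpModel m({1, 4}, TensorType_FLOAT32, /*top_k=*/2);
      m.SetInput({-2.0, 0.2, 0.8, 0.1});
      m.Invoke();
      EXPECT_THAT(m.GetValues(), ElementsAreArray(ArrayFloatNear({0.8, 0.2})));  // output 0
      EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1}));                     // output 1
    }
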
toco graph transformation PropagateArrayDataTypes (9 additions)

@@ -124,6 +124,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
       SetDataTypeForAllOutputs(model, op, rand_op->dtype);
       break;
     }
+    case OperatorType::kTopK_V2: {
+      // topk(values: T, k: int32) -> values: T, indices: int32
+      CHECK_EQ(op->inputs.size(), 2);
+      CHECK_EQ(op->outputs.size(), 2);
+      CHECK(model->GetArray(op->inputs[1]).data_type == ArrayDataType::kInt32);
+      model->GetArray(op->outputs[0]).data_type = model->GetArray(op->inputs[0]).data_type;
+      model->GetArray(op->outputs[1]).data_type = ArrayDataType::kInt32;
+      break;
+    }
     case OperatorType::kTensorFlowUnsupported: {
       auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
       // Some output tensors from the op could be eliminated by optimization.
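
The rule this case encodes: the values output inherits the input tensor's element type, while the indices output is always int32, whatever the input type. The same typing rule in isolation (hypothetical enum and helper, not TOCO's API):

    #include <utility>

    enum class DType { kFloat32, kUint8, kInt32, kInt64 };

    // topk(values: T, k: int32) -> (values: T, indices: int32)
    std::pair<DType, DType> TopKOutputTypes(DType input_type) {
      return {input_type, DType::kInt32};
    }
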
toco shape propagation, ProcessTopkV2Operator (2 additions, 2 deletions)

@@ -1110,8 +1110,8 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) {
 void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
   const auto& input_values = model->GetArray(op->inputs[0]);
   const auto& input_k = model->GetArray(op->inputs[1]);
-  auto& output_indexes = model->GetArray(op->outputs[0]);
-  auto& output_values = model->GetArray(op->outputs[1]);
+  auto& output_values = model->GetArray(op->outputs[0]);
+  auto& output_indexes = model->GetArray(op->outputs[1]);

   // Bail if we already know the output shape.
   if (output_indexes.has_shape()) {
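
For context, both TopKV2 outputs share the input's shape with the last dimension replaced by k, so the swap above only changes which local name binds to which output slot. A minimal sketch of that shape rule (hypothetical helper, not TOCO code):

    #include <vector>

    // Both TopKV2 outputs take the input shape with the last dim set to k.
    std::vector<int> TopKV2OutputShape(std::vector<int> input_shape, int k) {
      input_shape.back() = k;
      return input_shape;
    }
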
tensorflow/contrib/lite/toco/import_tensorflow.cc (1 addition, 1 deletion)

@@ -1970,7 +1970,7 @@ void ConvertTopKV2Operator(const NodeDef& node,
     op->inputs.push_back(node.input(1));
   }
   // The op has two outputs.
-  op->outputs.push_back(node.name() + ":0");
+  op->outputs.push_back(node.name());
   op->outputs.push_back(node.name() + ":1");
   model->operators.emplace_back(op.release());
 }
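
This follows the GraphDef naming convention: an edge that consumes output 0 of node "n" refers to it plainly as "n", while output i > 0 is referred to as "n:i". Naming the first output "n:0" therefore would not match edges that reference it by the bare node name. A small illustration of the convention (hypothetical helper, not TOCO's API):

    #include <string>

    // Output 0 of node "n" is referenced as "n"; output i (i > 0) as "n:i".
    std::string TensorName(const std::string& node_name, int output_index) {
      if (output_index == 0) return node_name;
      return node_name + ":" + std::to_string(output_index);
    }
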
tensorflow/contrib/lite/toco/tooling_util.cc (4 additions, 5 deletions)

@@ -825,11 +825,6 @@ void FixNoOrphanedArray(Model* model) {
 void CheckEachArray(const Model& model) {
   for (const auto& array_entry : model.GetArrayMap()) {
     const auto& array = array_entry.second;
-    if (array->has_shape()) {
-      for (int d : array->shape().dims()) {
-        CHECK_GE(d, 1);
-      }
-    }
     // It's OK to have a buffer or an alloc, but not both.
     // (Since allocs are for transient arrays without a buffer).
     CHECK(!array->buffer || !array->alloc);
@@ -839,6 +834,10 @@
       // The presence of a fixed buffer should imply the presence of a fixed
       // shape.
       CHECK(array->has_shape());
+      // A constant buffer should have a valid shape.
+      for (int d : array->shape().dims()) {
+        CHECK_GE(d, 1);
+      }
       // The shape flat-size should agree with the buffer length.
       CHECK_EQ(array->buffer->Length(),
                RequiredBufferSizeForShape(array->shape()));
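
The moved check now applies only to arrays that carry a constant buffer: their dims must all be at least 1 so the flat size can agree with a real buffer length, while arrays without a buffer may still have unresolved shapes. A simplified stand-in for the flat-size relation being checked (assuming RequiredBufferSizeForShape is essentially the product of the dims):

    #include <cassert>
    #include <vector>

    // Simplified stand-in: required length is the product of the dims, so each
    // dim of a constant array must be >= 1 for the CHECK_EQ on length to hold.
    int RequiredBufferSize(const std::vector<int>& dims) {
      int size = 1;
      for (int d : dims) {
        assert(d >= 1);  // the check the diff restricts to constant arrays
        size *= d;
      }
      return size;
    }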