Implement TensorFlow's hash table V2 kernels as TFLite custom ops
PiperOrigin-RevId: 282478716
Change-Id: I45efb39d5cc39ddaf90662f287937b195c8dd61a
abattery authored and tensorflower-gardener committed Nov 26, 2019
1 parent ba1765a commit 3cfd65b
Showing 7 changed files with 1,381 additions and 1 deletion.
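
For context, here is a minimal sketch (not part of this commit) of how these kernels might be hooked up to an interpreter. The op-name strings ("HashTableV2", "LookupTableImportV2", "LookupTableFindV2") and the forward declarations are assumptions chosen to match the TensorFlow ops these kernels mirror; the size kernel (hashtable_size.cc) would be registered the same way.

// Sketch only: registering the custom hashtable kernels with a resolver.
// The op-name strings and declarations below are illustrative assumptions.
#include "tensorflow/lite/mutable_op_resolver.h"

namespace tflite {
namespace ops {
namespace custom {
// Provided by the kernels added in this commit.
TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_LOOKUP();
}  // namespace custom
}  // namespace ops
}  // namespace tflite

void AddHashtableOps(tflite::MutableOpResolver* resolver) {
  resolver->AddCustom("HashTableV2",
                      tflite::ops::custom::Register_HASHTABLE());
  resolver->AddCustom("LookupTableImportV2",
                      tflite::ops::custom::Register_HASHTABLE_IMPORT());
  resolver->AddCustom("LookupTableFindV2",
                      tflite::ops::custom::Register_HASHTABLE_LOOKUP());
}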
42 changes: 42 additions & 0 deletions tensorflow/lite/experimental/kernels/BUILD
@@ -125,3 +125,45 @@ cc_test(
"@com_google_googletest//:gtest",
],
)

cc_library(
name = "hashtable_op_kernels",
srcs = [
"hashtable.cc",
"hashtable_import.cc",
"hashtable_lookup.cc",
"hashtable_size.cc",
],
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/kernels:kernel_util",
"//tensorflow/lite/kernels:op_macros",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/schema:schema_fbs",
"@flatbuffers",
],
)

cc_test(
name = "hashtable_op_test",
size = "small",
srcs = [
"hashtable_ops_test.cc",
],
deps = [
":hashtable_op_kernels", # buildcleaner: keep
"//tensorflow/lite:framework",
"//tensorflow/lite/core/api",
"//tensorflow/lite/experimental/resource",
"//tensorflow/lite/kernels:test_main",
"//tensorflow/lite/kernels:test_util",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/testing:util",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
)
113 changes: 113 additions & 0 deletions tensorflow/lite/experimental/kernels/hashtable.cc
@@ -0,0 +1,113 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>

#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace ops {
namespace custom {
namespace hashtable {

constexpr int kResourceHandleTensor = 0;

// TODO(b/144728911): The following structure should be moved to
// builtin_op_data.h when it is ready to become a builtin op.
typedef struct {
std::string table_name;
TfLiteType key_dtype;
TfLiteType value_dtype;
} TfLiteHashtableParams;

void* InitHashtable(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_CHECK(buffer != nullptr);

const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();

TfLiteHashtableParams* option = new TfLiteHashtableParams;
option->table_name = m["table_name"].AsString().str();
option->key_dtype = static_cast<TfLiteType>(m["key_dtype"].AsInt32());
option->value_dtype = static_cast<TfLiteType>(m["value_dtype"].AsInt32());

return option;
}

void FreeHashtable(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<TfLiteHashtableParams*>(buffer);
}

TfLiteStatus PrepareHashtable(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 0);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

TF_LITE_ENSURE(context, node->user_data != nullptr);
const auto* params =
reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);
TF_LITE_ENSURE(context, !params->table_name.empty());
TF_LITE_ENSURE(context, (params->key_dtype == kTfLiteInt32 ||
params->key_dtype == kTfLiteString));
TF_LITE_ENSURE(context, (params->value_dtype == kTfLiteInt32 ||
params->value_dtype == kTfLiteString ||
params->value_dtype == kTfLiteFloat32));

TfLiteTensor* resource_handle_tensor =
GetOutput(context, node, kResourceHandleTensor);
TF_LITE_ENSURE(context, resource_handle_tensor != nullptr);
TF_LITE_ENSURE_EQ(context, resource_handle_tensor->type, kTfLiteInt32);
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1);
outputSize->data[0] = 1;
return context->ResizeTensor(context, resource_handle_tensor, outputSize);
}

TfLiteStatus EvalHashtable(TfLiteContext* context, TfLiteNode* node) {
const auto* params =
reinterpret_cast<const TfLiteHashtableParams*>(node->user_data);

// The resource id is generated based on the given table name.
const int resource_id = std::hash<std::string>{}(params->table_name);

TfLiteTensor* resource_handle_tensor =
GetOutput(context, node, kResourceHandleTensor);
auto* resource_handle_data = GetTensorData<int32>(resource_handle_tensor);
resource_handle_data[0] = resource_id;

Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources();
resource::CreateHashtableResourceIfNotAvailable(
&resources, resource_id, params->key_dtype, params->value_dtype);
return kTfLiteOk;
}

} // namespace hashtable

TfLiteRegistration* Register_HASHTABLE() {
static TfLiteRegistration r = {hashtable::InitHashtable,
hashtable::FreeHashtable,
hashtable::PrepareHashtable,
hashtable::EvalHashtable,
nullptr,
BuiltinOperator_CUSTOM};
return &r;
}

} // namespace custom
} // namespace ops
} // namespace tflite
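
InitHashtable above parses its custom-options buffer as a flexbuffer map with "table_name", "key_dtype", and "value_dtype" entries. A minimal sketch of producing such a buffer (for example, from a test or a converter), using only the flexbuffers::Builder API; the table name is a placeholder:

// Sketch only: serializing the options that InitHashtable expects.
#include <cstdint>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/c/c_api_internal.h"

std::vector<uint8_t> BuildHashtableOptions() {
  flexbuffers::Builder fbb;
  fbb.Map([&]() {
    fbb.String("table_name", "my_table");   // Placeholder table name.
    fbb.Int("key_dtype", kTfLiteInt32);     // Keys are int32.
    fbb.Int("value_dtype", kTfLiteString);  // Values are strings.
  });
  fbb.Finish();
  return fbb.GetBuffer();
}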
89 changes: 89 additions & 0 deletions tensorflow/lite/experimental/kernels/hashtable_import.cc
@@ -0,0 +1,89 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace ops {
namespace custom {

namespace hashtable {

constexpr int kInputResourceIdTensor = 0;
constexpr int kKeyTensor = 1;
constexpr int kValueTensor = 2;

TfLiteStatus PrepareHashtableImport(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);

const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);

const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 ||
key_tensor->type == kTfLiteString));

const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);
TF_LITE_ENSURE(context, (value_tensor->type == kTfLiteInt32 ||
value_tensor->type == kTfLiteString ||
value_tensor->type == kTfLiteFloat32));
// TODO(b/144731295): TensorFlow lookup ops support 1-D vectors for storing
// values.
TF_LITE_ENSURE(context, HaveSameShapes(key_tensor, value_tensor));
return kTfLiteOk;
}

TfLiteStatus EvalHashtableImport(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
const int resource_id = input_resource_id_tensor->data.i32[0];

const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
const TfLiteTensor* value_tensor = GetInput(context, node, kValueTensor);

Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources();
auto* lookup = resource::GetHashtableResource(&resources, resource_id);
TF_LITE_ENSURE(context, lookup != nullptr);
TF_LITE_ENSURE_STATUS(
lookup->CheckKeyAndValueTypes(context, key_tensor, value_tensor));
// The hashtable resource will only be initialized once; attempting to
// initialize it multiple times is a no-op.
return lookup->Import(context, key_tensor, value_tensor);
}

} // namespace hashtable

TfLiteRegistration* Register_HASHTABLE_IMPORT() {
static TfLiteRegistration r = {/*init=*/nullptr,
/*free=*/nullptr,
hashtable::PrepareHashtableImport,
hashtable::EvalHashtableImport,
nullptr,
BuiltinOperator_CUSTOM};
return &r;
}

} // namespace custom
} // namespace ops
} // namespace tflite
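
Behaviorally, the import kernel fills the shared hashtable resource from a pair of equally shaped key/value tensors, and only the first import takes effect. A standalone sketch of that contract using a plain std::unordered_map (an illustration of the semantics, not the resource::LookupInterface implementation; int32 keys and float values are assumed):

// Sketch only: the import contract modeled with a plain map.
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

void ImportOnce(std::unordered_map<int32_t, float>* table, bool* initialized,
                const std::vector<int32_t>& keys,
                const std::vector<float>& values) {
  if (*initialized) return;  // Re-importing an initialized table is a no-op.
  for (size_t i = 0; i < keys.size() && i < values.size(); ++i) {
    (*table)[keys[i]] = values[i];
  }
  *initialized = true;
}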
98 changes: 98 additions & 0 deletions tensorflow/lite/experimental/kernels/hashtable_lookup.cc
@@ -0,0 +1,98 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/lookup_interfaces.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
namespace ops {
namespace custom {

namespace hashtable {

constexpr int kInputResourceIdTensor = 0;
constexpr int kKeyTensor = 1;
constexpr int kDefaultValueTensor = 2;
constexpr int kOutputTensor = 0;

TfLiteStatus PrepareHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
TF_LITE_ENSURE_EQ(context, input_resource_id_tensor->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_resource_id_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_resource_id_tensor, 0), 1);

const TfLiteTensor* default_value_tensor =
GetInput(context, node, kDefaultValueTensor);
TF_LITE_ENSURE_EQ(context, NumDimensions(default_value_tensor), 1);
TF_LITE_ENSURE_EQ(context, SizeOfDimension(default_value_tensor, 0), 1);

TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, default_value_tensor->type, output_tensor->type);
TF_LITE_ENSURE(context, (output_tensor->type == kTfLiteInt32 ||
output_tensor->type == kTfLiteString ||
output_tensor->type == kTfLiteFloat32));

const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
TF_LITE_ENSURE(context, (key_tensor->type == kTfLiteInt32 ||
key_tensor->type == kTfLiteString));
if (output_tensor->type != kTfLiteString) {
return context->ResizeTensor(context, output_tensor,
TfLiteIntArrayCopy(key_tensor->dims));
}
return kTfLiteOk;
}

TfLiteStatus EvalHashtableLookup(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_resource_id_tensor =
GetInput(context, node, kInputResourceIdTensor);
int resource_id = input_resource_id_tensor->data.i32[0];

const TfLiteTensor* key_tensor = GetInput(context, node, kKeyTensor);
const TfLiteTensor* default_value_tensor =
GetInput(context, node, kDefaultValueTensor);
TfLiteTensor* output_tensor = GetOutput(context, node, kOutputTensor);

Subgraph* subgraph = reinterpret_cast<Subgraph*>(context->impl_);
auto& resources = subgraph->resources();
auto* lookup = resource::GetHashtableResource(&resources, resource_id);
TF_LITE_ENSURE(context, lookup != nullptr);
TF_LITE_ENSURE_STATUS(
lookup->CheckKeyAndValueTypes(context, key_tensor, output_tensor));
return lookup->Lookup(context, key_tensor, output_tensor,
default_value_tensor);
}

} // namespace hashtable

TfLiteRegistration* Register_HASHTABLE_LOOKUP() {
static TfLiteRegistration r = {/*init=*/nullptr,
/*free=*/nullptr,
hashtable::PrepareHashtableLookup,
hashtable::EvalHashtableLookup,
nullptr,
BuiltinOperator_CUSTOM};
return &r;
}

} // namespace custom
} // namespace ops
} // namespace tflite
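
For reference, the lookup kernel resolves each key against the table and falls back to the single default value on a miss. A standalone sketch of that contract under the same simplified int32-key/float-value model as above (illustration only):

// Sketch only: lookup with a scalar default for missing keys.
#include <cstdint>
#include <unordered_map>
#include <vector>

std::vector<float> LookupWithDefault(
    const std::unordered_map<int32_t, float>& table,
    const std::vector<int32_t>& keys, float default_value) {
  std::vector<float> out;
  out.reserve(keys.size());
  for (int32_t key : keys) {
    auto it = table.find(key);
    out.push_back(it != table.end() ? it->second : default_value);
  }
  return out;
}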
