Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge pull request #43269 from rhdong:Mutablehashtable lookup support… #45896

Merged
merged 1 commit into from Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 10 additions & 3 deletions tensorflow/core/framework/lookup_interface.cc
Expand Up @@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
const Tensor& default_value) {
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
if (default_value.shape() != value_shape()) {
TensorShape fullsize_value_shape = key.shape();
for (int i = 0; i < key_shape().dims(); ++i) {
fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
}
fullsize_value_shape.AppendShape(value_shape());
if (default_value.shape() != value_shape() &&
default_value.shape() != fullsize_value_shape) {
return errors::InvalidArgument(
"Expected shape ", value_shape().DebugString(),
" for default value, got ", default_value.shape().DebugString());
"Expected shape ", value_shape().DebugString(), " or ",
fullsize_value_shape.DebugString(), " for default value, got ",
default_value.shape().DebugString());
}
return Status::OK();
}
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/framework/lookup_interface.h
Expand Up @@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
// requirements are satisfied, otherwise it returns InvalidArgument:
// - DataType of the tensor keys equals to the table key_dtype
// - DataType of the tensor default_value equals to the table value_dtype
// - the default_value tensor shape matches the table's value shape.
// - the default_value tensor has the required shape given keys and the
// table's value shape.
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);

string DebugString() const override {
Expand Down
30 changes: 26 additions & 4 deletions tensorflow/core/kernels/lookup_table_op.cc
Expand Up @@ -56,14 +56,25 @@ class MutableHashTableOfScalars final : public LookupInterface {

Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override {
const V default_val = default_value.flat<V>()(0);
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
const auto default_flat = default_value.flat<V>();

int64 total = value_values.size();
int64 default_total = default_flat.size();
bool is_full_size_default = (total == default_total);

tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
// is_full_size_default is true:
// Each key has an independent default value: key_values(i)
// uses default_flat(i) as its default value.
//
// is_full_size_default is false:
// All keys will share the default_flat(0) as default value.
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
table_, SubtleMustCopyIfIntegral(key_values(i)),
is_full_size_default ? default_flat(i) : default_flat(0));
}

return Status::OK();
Expand Down Expand Up @@ -173,11 +184,15 @@ class MutableHashTableOfTensors final : public LookupInterface {

Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override {
const auto default_flat = default_value.flat<V>();
const auto default_flat = default_value.flat_inner_dims<V, 2>();
const auto key_values = key.flat<K>();
auto value_values = value->flat_inner_dims<V, 2>();
int64 value_dim = value_shape_.dim_size(0);

int64 total = value_values.size();
int64 default_total = default_flat.size();
bool is_full_size_default = (total == default_total);

tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
ValueArray* value_vec =
Expand All @@ -187,8 +202,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
value_values(i, j) = value_vec->at(j);
}
} else {
// is_full_size_default is true:
// Each key has an independent default value: key_values(i)
// uses default_flat(i) as its default value.
//
// is_full_size_default is false:
// All keys will share the default_flat(0) as default value.
for (int64 j = 0; j < value_dim; j++) {
value_values(i, j) = default_flat(j);
value_values(i, j) =
is_full_size_default ? default_flat(i, j) : default_flat(0, j);
}
}
}
Expand Down
4 changes: 0 additions & 4 deletions tensorflow/core/ops/lookup_ops.cc
Expand Up @@ -169,10 +169,6 @@ REGISTER_OP("LookupTableFindV2")
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));

// Default value must be scalar or vector.
ShapeHandle keys;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));

ShapeAndType value_shape_and_type;
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
c,
Expand Down
1 change: 0 additions & 1 deletion tensorflow/core/ops/lookup_ops_test.cc
Expand Up @@ -25,7 +25,6 @@ namespace {
TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) {
ShapeInferenceTestOp op("LookupTableFindV2");
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?");
INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[];?;[1,1]");
TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2")
.Input({"table_handle", 0, DT_RESOURCE})
.Input({"keys", 0, DT_INT64})
Expand Down
65 changes: 65 additions & 0 deletions tensorflow/python/kernel_tests/lookup_ops_test.py
Expand Up @@ -3375,6 +3375,71 @@ def testMutableHashTableFindHighRank(self):
result = self.evaluate(output)
self.assertAllEqual([[0, 1], [-1, -1]], result)

def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
  """Lookup must reject a default_value whose shape matches neither the
  table's value shape nor the full-size (keys-shape + value-shape) shape."""
  default_val = [-1, -1]
  table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)

  input_string = constant_op.constant([["brain", "salad"],
                                       ["tank", "tarkus"]])

  # Same total element count as the full-size shape [2,2,2], but a
  # different shape — must still be rejected.
  invalid_default_val = constant_op.constant(
      [[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)

  # Raw strings: "\[" is an invalid escape sequence in an ordinary string
  # literal (SyntaxWarning on recent Python versions); also escape the
  # closing bracket in [4,2] consistently with the other patterns.
  with self.assertRaisesRegex(
      (ValueError, errors_impl.InvalidArgumentError),
      r"Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2\]"):
    self.evaluate(table.lookup(input_string, invalid_default_val))

  # Correct rank but wrong leading dimension.
  invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
                                             dtypes.int64)
  with self.assertRaisesRegex(
      (ValueError, errors_impl.InvalidArgumentError),
      r"Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"):
    self.evaluate(table.lookup(input_string, invalid_default_val))

def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
  """Missing keys pick up per-key defaults from a full-size default tensor
  when the table holds scalar values."""
  table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, -1)

  # Populate the table with three scalar entries.
  known_keys = constant_op.constant(["brain", "salad", "surgery"])
  known_values = constant_op.constant([0, 1, 2], dtypes.int64)
  self.evaluate(table.insert(known_keys, known_values))
  self.assertAllEqual(3, self.evaluate(table.size()))

  # 2x2 batch of lookup keys; "tank" and "tarkus" are absent.
  lookup_keys = constant_op.constant([["brain", "salad"],
                                      ["tank", "tarkus"]])

  # One default per key, matching the lookup result's shape.
  per_key_defaults = constant_op.constant([[-2, -3], [-4, -5]],
                                          dtypes.int64)
  lookup_result = table.lookup(lookup_keys, per_key_defaults)
  self.assertAllEqual([2, 2], lookup_result.get_shape())

  # Present keys yield stored values; absent keys yield their own default.
  self.assertAllEqual([[0, 1], [-4, -5]], self.evaluate(lookup_result))

def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self):
  """Missing keys pick up per-key default vectors from a full-size default
  tensor when the table holds vector values."""
  table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, [-1, -1])

  # Populate the table with three length-2 vector entries.
  known_keys = constant_op.constant(["brain", "salad", "surgery"])
  known_values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
  self.evaluate(table.insert(known_keys, known_values))
  self.assertAllEqual(3, self.evaluate(table.size()))

  # 2x2 batch of lookup keys; "tank" and "tarkus" are absent.
  lookup_keys = constant_op.constant([["brain", "salad"],
                                      ["tank", "tarkus"]])

  # One default vector per key, matching the lookup result's shape.
  per_key_defaults = constant_op.constant(
      [[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
  lookup_result = table.lookup(lookup_keys, per_key_defaults)
  self.assertAllEqual([2, 2, 2], lookup_result.get_shape())

  # Present keys yield stored vectors; absent keys yield their own default.
  self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]],
                      self.evaluate(lookup_result))

def testMutableHashTableInsertHighRank(self):
default_val = -1
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
Expand Down
24 changes: 21 additions & 3 deletions tensorflow/python/ops/lookup_ops.py
Expand Up @@ -1849,14 +1849,31 @@ def remove(self, keys, name=None):

return op

def lookup(self, keys, name=None):
def lookup(self, keys, dynamic_default_values=None, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.

The `default_value` is used for keys not present in the table.

Args:
keys: Keys to look up. Can be a tensor of any shape. Must match the
table's key_dtype.
dynamic_default_values: The values to use if a key is missing in the
table. If None (by default), the `table.default_value` will be used.
The shape of `dynamic_default_values` must be the same as the shape
of `table.default_value` or of the lookup result tensor.
In the latter case, each key will have a different default value.

For example:

```python
keys = [0, 1, 3]
dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]

# The key '0' will use [1, 3, 4] as default value.
# The key '1' will use [2, 3, 9] as default value.
# The key '3' will use [8, 3, 0] as default value.
```

name: A name for the operation (optional).

Returns:
Expand All @@ -1870,8 +1887,9 @@ def lookup(self, keys, name=None):
(self.resource_handle, keys, self._default_value)):
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self.resource_handle):
values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
self._default_value)
values = gen_lookup_ops.lookup_table_find_v2(
self.resource_handle, keys, dynamic_default_values
if dynamic_default_values is not None else self._default_value)
return values

def insert(self, keys, values, name=None):
Expand Down