Skip to content

Commit

Permalink
Merge pull request #45896 from rhdong/github/rhdong/r2.4
Browse files Browse the repository at this point in the history
Merge pull request #43269 from rhdong:Mutablehashtable lookup support…
  • Loading branch information
mihaimaruseac committed Jan 19, 2021
2 parents 302112b + a26eed2 commit 6d18335
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 16 deletions.
13 changes: 10 additions & 3 deletions tensorflow/core/framework/lookup_interface.cc
Expand Up @@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
const Tensor& default_value) {
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
if (default_value.shape() != value_shape()) {
TensorShape fullsize_value_shape = key.shape();
for (int i = 0; i < key_shape().dims(); ++i) {
fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
}
fullsize_value_shape.AppendShape(value_shape());
if (default_value.shape() != value_shape() &&
default_value.shape() != fullsize_value_shape) {
return errors::InvalidArgument(
"Expected shape ", value_shape().DebugString(),
" for default value, got ", default_value.shape().DebugString());
"Expected shape ", value_shape().DebugString(), " or ",
fullsize_value_shape.DebugString(), " for default value, got ",
default_value.shape().DebugString());
}
return Status::OK();
}
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/framework/lookup_interface.h
Expand Up @@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
// requirements are satisfied, otherwise it returns InvalidArgument:
// - DataType of the tensor keys equals to the table key_dtype
// - DataType of the tensor default_value equals to the table value_dtype
// - the default_value tensor shape matches the table's value shape.
// - the default_value tensor has the required shape given keys and the
//   table's value shape.
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);

string DebugString() const override {
Expand Down
30 changes: 26 additions & 4 deletions tensorflow/core/kernels/lookup_table_op.cc
Expand Up @@ -56,14 +56,25 @@ class MutableHashTableOfScalars final : public LookupInterface {

Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override {
const V default_val = default_value.flat<V>()(0);
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
const auto default_flat = default_value.flat<V>();

int64 total = value_values.size();
int64 default_total = default_flat.size();
bool is_full_size_default = (total == default_total);

tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
// is_full_size_default is true:
// Each key has an independent default value; key_values(i)
// uses the corresponding default_flat(i) as its default value.
//
// is_full_size_default is false:
// All keys will share the default_flat(0) as default value.
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
table_, SubtleMustCopyIfIntegral(key_values(i)),
is_full_size_default ? default_flat(i) : default_flat(0));
}

return Status::OK();
Expand Down Expand Up @@ -173,11 +184,15 @@ class MutableHashTableOfTensors final : public LookupInterface {

Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override {
const auto default_flat = default_value.flat<V>();
const auto default_flat = default_value.flat_inner_dims<V, 2>();
const auto key_values = key.flat<K>();
auto value_values = value->flat_inner_dims<V, 2>();
int64 value_dim = value_shape_.dim_size(0);

int64 total = value_values.size();
int64 default_total = default_flat.size();
bool is_full_size_default = (total == default_total);

tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
ValueArray* value_vec =
Expand All @@ -187,8 +202,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
value_values(i, j) = value_vec->at(j);
}
} else {
// is_full_size_default is true:
// Each key has an independent default value, key_values(i)
// corresponding uses default_flat(i) as its default value.
//
// is_full_size_default is false:
// All keys will share the default_flat(0) as default value.
for (int64 j = 0; j < value_dim; j++) {
value_values(i, j) = default_flat(j);
value_values(i, j) =
is_full_size_default ? default_flat(i, j) : default_flat(0, j);
}
}
}
Expand Down
4 changes: 0 additions & 4 deletions tensorflow/core/ops/lookup_ops.cc
Expand Up @@ -169,10 +169,6 @@ REGISTER_OP("LookupTableFindV2")
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));

// Default value must be scalar or vector.
ShapeHandle keys;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));

ShapeAndType value_shape_and_type;
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
c,
Expand Down
1 change: 0 additions & 1 deletion tensorflow/core/ops/lookup_ops_test.cc
Expand Up @@ -25,7 +25,6 @@ namespace {
TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) {
ShapeInferenceTestOp op("LookupTableFindV2");
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?");
INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[];?;[1,1]");
TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2")
.Input({"table_handle", 0, DT_RESOURCE})
.Input({"keys", 0, DT_INT64})
Expand Down
65 changes: 65 additions & 0 deletions tensorflow/python/kernel_tests/lookup_ops_test.py
Expand Up @@ -3375,6 +3375,71 @@ def testMutableHashTableFindHighRank(self):
result = self.evaluate(output)
self.assertAllEqual([[0, 1], [-1, -1]], result)

def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
  """Rejects a default_value whose shape is neither the table's value
  shape nor the full-size (keys shape + value shape) lookup shape.
  """
  default_val = [-1, -1]
  table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                      default_val)

  input_string = constant_op.constant([["brain", "salad"],
                                       ["tank", "tarkus"]])

  # Rank-2 shape [4, 2]: neither [2] (value shape) nor [2, 2, 2]
  # (full-size shape for 2x2 keys with value shape [2]).
  invalid_default_val = constant_op.constant(
      [[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)

  # Raw strings: the patterns contain regex escapes like \[ which are
  # invalid escape sequences in ordinary Python string literals.
  with self.assertRaisesRegex(
      (ValueError, errors_impl.InvalidArgumentError),
      r"Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2\]"):
    self.evaluate(table.lookup(input_string, invalid_default_val))

  # Rank-3 shape [1, 2, 2]: leading dimensions do not match the keys.
  invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
                                             dtypes.int64)
  with self.assertRaisesRegex(
      (ValueError, errors_impl.InvalidArgumentError),
      r"Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"):
    self.evaluate(table.lookup(input_string, invalid_default_val))

def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
  """Per-key (full-size) default values for a scalar-valued table lookup."""
  # Build a string->int64 table with a scalar fallback default of -1.
  scalar_default = -1
  table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                      scalar_default)

  # Populate three entries and confirm the table size.
  insert_keys = constant_op.constant(["brain", "salad", "surgery"])
  insert_values = constant_op.constant([0, 1, 2], dtypes.int64)
  self.evaluate(table.insert(insert_keys, insert_values))
  self.assertAllEqual(3, self.evaluate(table.size()))

  # 2x2 query where the second row contains only missing keys.
  query = constant_op.constant([["brain", "salad"],
                                ["tank", "tarkus"]])
  # One default value per key, same shape as the lookup result.
  per_key_defaults = constant_op.constant([[-2, -3], [-4, -5]],
                                          dtypes.int64)

  lookup_result = table.lookup(query, per_key_defaults)
  self.assertAllEqual([2, 2], lookup_result.get_shape())

  # Present keys yield stored values; missing keys yield their own defaults.
  self.assertAllEqual([[0, 1], [-4, -5]], self.evaluate(lookup_result))

def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self):
  """Per-key (full-size) default values for a vector-valued table lookup."""
  # Build a string->int64 table whose values are length-2 vectors.
  vector_default = [-1, -1]
  table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
                                      vector_default)

  # Populate three entries and confirm the table size.
  insert_keys = constant_op.constant(["brain", "salad", "surgery"])
  insert_values = constant_op.constant([[0, 1], [2, 3], [4, 5]],
                                       dtypes.int64)
  self.evaluate(table.insert(insert_keys, insert_values))
  self.assertAllEqual(3, self.evaluate(table.size()))

  # 2x2 query where the second row contains only missing keys.
  query = constant_op.constant([["brain", "salad"],
                                ["tank", "tarkus"]])
  # One default vector per key, same shape as the lookup result.
  per_key_defaults = constant_op.constant(
      [[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)

  lookup_result = table.lookup(query, per_key_defaults)
  self.assertAllEqual([2, 2, 2], lookup_result.get_shape())

  # Present keys yield stored vectors; missing keys yield their own defaults.
  self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]],
                      self.evaluate(lookup_result))

def testMutableHashTableInsertHighRank(self):
default_val = -1
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
Expand Down
24 changes: 21 additions & 3 deletions tensorflow/python/ops/lookup_ops.py
Expand Up @@ -1849,14 +1849,31 @@ def remove(self, keys, name=None):

return op

def lookup(self, keys, name=None):
def lookup(self, keys, dynamic_default_values=None, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
The `default_value` is used for keys not present in the table.
Args:
keys: Keys to look up. Can be a tensor of any shape. Must match the
table's key_dtype.
dynamic_default_values: The values to use if a key is missing in the
table. If None (by default), the `table.default_value` will be used.
The shape of `dynamic_default_values` must be the same as the shape of
`table.default_value` or the same as the shape of the lookup result tensor.
In the latter case, each key will have a different default value.
For example:
```python
keys = [0, 1, 3]
dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
# The key '0' will use [1, 3, 4] as default value.
# The key '1' will use [2, 3, 9] as default value.
# The key '3' will use [8, 3, 0] as default value.
```
name: A name for the operation (optional).
Returns:
Expand All @@ -1870,8 +1887,9 @@ def lookup(self, keys, name=None):
(self.resource_handle, keys, self._default_value)):
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self.resource_handle):
values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
self._default_value)
values = gen_lookup_ops.lookup_table_find_v2(
self.resource_handle, keys, dynamic_default_values
if dynamic_default_values is not None else self._default_value)
return values

def insert(self, keys, values, name=None):
Expand Down

0 comments on commit 6d18335

Please sign in to comment.