MutableHashTable lookup supports full-size dynamic default values.
This PR is one part of RFC tensorflow/community#237.
rhdong committed Oct 16, 2020
1 parent 5c42efe commit 9343aea
Showing 6 changed files with 116 additions and 15 deletions.
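
For orientation, a minimal usage sketch of what this commit enables (illustrative values; assumes a TensorFlow build that includes this change and the public `tf.lookup.experimental.MutableHashTable` export):

```python
import tensorflow as tf

# String -> int64 table with a single static default of -1.
table = tf.lookup.experimental.MutableHashTable(
    key_dtype=tf.string, value_dtype=tf.int64, default_value=-1)
table.insert(tf.constant(["brain", "salad"]), tf.constant([0, 1], tf.int64))

keys = tf.constant(["brain", "tank"])
# New in this commit: a "full size" default with one value per key, passed
# directly to lookup() instead of the table-wide default.
per_key_defaults = tf.constant([-7, -9], tf.int64)
print(table.lookup(keys, per_key_defaults))  # [ 0 -9]
```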
13 changes: 10 additions & 3 deletions tensorflow/core/framework/lookup_interface.cc
@@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
const Tensor& default_value) {
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
if (default_value.shape() != value_shape()) {
TensorShape fullsize_value_shape = key.shape();
for (int i = 0; i < key_shape().dims(); ++i) {
fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
}
fullsize_value_shape.AppendShape(value_shape());
if (default_value.shape() != value_shape() &&
default_value.shape() != fullsize_value_shape) {
return errors::InvalidArgument(
"Expected shape ", value_shape().DebugString(),
" for default value, got ", default_value.shape().DebugString());
"Expected shape ", value_shape().DebugString(), " or ",
fullsize_value_shape.DebugString(), " for default value, got ",
default_value.shape().DebugString());
}
return Status::OK();
}
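
The new shape rule can be mirrored in a few lines of Python. This is only an illustration of the C++ check above, with a hypothetical helper name:

```python
def allowed_default_shapes(keys_shape, table_key_shape, table_value_shape):
  """Returns the two default_value shapes CheckFindArguments now accepts.

  Illustrative only: keys_shape is the shape of the keys tensor passed to
  Find; the other two are the table's key and value shapes.
  """
  # Drop the trailing per-key dims from the keys' shape, then append the
  # table's value shape; this equals the shape of the lookup result.
  batch_dims = len(keys_shape) - len(table_key_shape)
  fullsize_shape = tuple(keys_shape[:batch_dims]) + tuple(table_value_shape)
  return tuple(table_value_shape), fullsize_shape

# For a table with scalar keys and vector values of size 2, looking up keys
# of shape (2, 2) accepts a default of shape (2,) or (2, 2, 2).
assert allowed_default_shapes((2, 2), (), (2,)) == ((2,), (2, 2, 2))
```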
3 changes: 2 additions & 1 deletion tensorflow/core/framework/lookup_interface.h
@@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
// requirements are satisfied, otherwise it returns InvalidArgument:
// - DataType of the tensor keys equals to the table key_dtype
// - DataType of the tensor default_value equals to the table value_dtype
// - the default_value tensor shape matches the table's value shape.
// - the default_value tensor has the required shape given keys and the
//   table's value shape.
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);

string DebugString() const override {
18 changes: 14 additions & 4 deletions tensorflow/core/kernels/lookup_table_op.cc
@@ -56,14 +56,19 @@ class MutableHashTableOfScalars final : public LookupInterface {

Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override {
const V default_val = default_value.flat<V>()(0);
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
const auto default_flat = default_value.flat<V>();

int64 total = value_values.size();
int64 default_total = default_flat.size();
bool is_full_default = (total == default_total);

tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
table_, SubtleMustCopyIfIntegral(key_values(i)),
is_full_default ? default_flat(i) : default_flat(0));
}

return Status::OK();
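
The scalar-value path above boils down to choosing a per-key default whenever the default tensor is "full size". A rough Python sketch of that logic (not the real kernel), using a plain dict as the table:

```python
def find_scalars(table, keys, default_values):
  """Sketch of MutableHashTableOfScalars::Find with full-size defaults."""
  # "Full size" means one default per key; otherwise the single entry
  # default_values[0] is used for every missing key.
  is_full_default = len(default_values) == len(keys)
  return [
      table.get(k, default_values[i] if is_full_default else default_values[0])
      for i, k in enumerate(keys)
  ]

print(find_scalars({"brain": 0}, ["brain", "tank"], [-7, -9]))  # [0, -9]
```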
@@ -173,11 +178,15 @@ class MutableHashTableOfTensors final : public LookupInterface {

Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override {
const auto default_flat = default_value.flat<V>();
const auto default_flat = default_value.flat_inner_dims<V, 2>();
const auto key_values = key.flat<K>();
auto value_values = value->flat_inner_dims<V, 2>();
int64 value_dim = value_shape_.dim_size(0);

int64 total = value_values.size();
int64 default_total = default_flat.size();
bool is_full_default = (total == default_total);

tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
ValueArray* value_vec =
@@ -188,7 +197,8 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
} else {
for (int64 j = 0; j < value_dim; j++) {
value_values(i, j) = default_flat(j);
value_values(i, j) =
is_full_default ? default_flat(i, j) : default_flat(0, j);
}
}
}
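
The vector-value path applies the same idea row by row. Again a rough Python sketch, not the real kernel:

```python
def find_vectors(table, keys, default_rows):
  """Sketch of MutableHashTableOfTensors::Find with full-size defaults.

  default_rows is either a single row (the static default) or one row per
  key (the new full-size form).
  """
  is_full_default = len(default_rows) == len(keys)
  out = []
  for i, k in enumerate(keys):
    row = table.get(k)
    if row is None:
      # Missing key: pick this key's own default row, or the shared one.
      row = default_rows[i] if is_full_default else default_rows[0]
    out.append(list(row))
  return out

print(find_vectors({"brain": [0, 1]}, ["brain", "tank"],
                   [[-2, -3], [-4, -5]]))  # [[0, 1], [-4, -5]]
```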
4 changes: 0 additions & 4 deletions tensorflow/core/ops/lookup_ops.cc
@@ -166,10 +166,6 @@ REGISTER_OP("LookupTableFindV2")
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));

// Default value must be scalar or vector.
ShapeHandle keys;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));

ShapeAndType value_shape_and_type;
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
c,
69 changes: 69 additions & 0 deletions tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -3309,6 +3309,75 @@ def testMutableHashTableFindHighRank(self):
result = self.evaluate(output)
self.assertAllEqual([[0, 1], [-1, -1]], result)

def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
with self.cached_session():
default_val = [-1, -1]
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
default_val)

input_string = constant_op.constant([["brain", "salad"],
["tank", "tarkus"]])

raised_error = ValueError
if context.executing_eagerly():
raised_error = errors_impl.InvalidArgumentError

invalid_default_val = constant_op.constant(
[[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)

with self.assertRaises(raised_error):
_ = table.lookup(input_string, invalid_default_val)

invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
dtypes.int64)

with self.assertRaises(raised_error):
_ = table.lookup(input_string, invalid_default_val)

def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
default_val)

self.evaluate(table.insert(keys, values))
self.assertAllEqual(3, self.evaluate(table.size()))

input_string = constant_op.constant([["brain", "salad"],
["tank", "tarkus"]])

dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]],
dtypes.int64)
output = table.lookup(input_string, dynamic_default_val)
self.assertAllEqual([2, 2], output.get_shape())

result = self.evaluate(output)
self.assertAllEqual([[0, 1], [-4, -5]], result)

def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self):
with self.cached_session():
default_val = [-1, -1]
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
default_val)

self.evaluate(table.insert(keys, values))
self.assertAllEqual(3, self.evaluate(table.size()))

input_string = constant_op.constant([["brain", "salad"],
["tank", "tarkus"]])

dynamic_default_val = constant_op.constant(
[[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
output = table.lookup(input_string, dynamic_default_val)
self.assertAllEqual([2, 2, 2], output.get_shape())

result = self.evaluate(output)
self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result)

def testMutableHashTableInsertHighRank(self):
with self.cached_session():
default_val = -1
24 changes: 21 additions & 3 deletions tensorflow/python/ops/lookup_ops.py
@@ -1810,14 +1810,31 @@ def remove(self, keys, name=None):

return op

def lookup(self, keys, name=None):
def lookup(self, keys, dynamic_default_values=None, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
The `default_value` is used for keys not present in the table.
Args:
keys: Keys to look up. Can be a tensor of any shape. Must match the
table's key_dtype.
dynamic_default_values: The values to use if a key is missing in the
table. If None (the default), `self._default_value` will be used. The
shape of `dynamic_default_values` must match either the shape of
`self._default_value` or the shape of the lookup result tensor; in the
latter case, each key gets its own default value.
For example:
```python
keys = [0, 1, 3]
dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
# The key '0' will use [1, 3, 4] as default value.
# The key '1' will use [2, 3, 9] as default value.
# The key '3' will use [8, 3, 0] as default value.
```
name: A name for the operation (optional).
Returns:
@@ -1831,8 +1848,9 @@ def lookup(self, keys, name=None):
(self.resource_handle, keys, self._default_value)):
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self.resource_handle):
values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
self._default_value)
values = gen_lookup_ops.lookup_table_find_v2(
self.resource_handle, keys, dynamic_default_values
if dynamic_default_values is not None else self._default_value)
return values
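
A follow-up usage sketch of the extended `lookup` signature (public `tf.lookup.experimental` names, illustrative values). Because the scalar-or-vector constraint was removed from the op's shape function in lookup_ops.cc above, a full-size default also passes shape inference inside a `tf.function`:

```python
import tensorflow as tf

# Vector-valued table: the static default [-1, -1] fixes the value shape to [2].
table = tf.lookup.experimental.MutableHashTable(tf.string, tf.int64, [-1, -1])
table.insert(tf.constant(["brain", "salad"]),
             tf.constant([[0, 1], [2, 3]], tf.int64))

@tf.function
def lookup_with_row_defaults(keys, row_defaults):
  # row_defaults has the shape of the lookup result: one default row per key.
  return table.lookup(keys, row_defaults)

print(lookup_with_row_defaults(
    tf.constant(["brain", "tank"]),
    tf.constant([[-2, -3], [-4, -5]], tf.int64)))  # [[ 0  1] [-4 -5]]
```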

def insert(self, keys, values, name=None):