diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc index 117adbf65c42cd..77d3314b3e8440 100644 --- a/tensorflow/core/framework/lookup_interface.cc +++ b/tensorflow/core/framework/lookup_interface.cc @@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key, const Tensor& default_value) { TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value)); TF_RETURN_IF_ERROR(CheckKeyShape(key.shape())); - if (default_value.shape() != value_shape()) { + TensorShape fullsize_value_shape = key.shape(); + for (int i = 0; i < key_shape().dims(); ++i) { + fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1); + } + fullsize_value_shape.AppendShape(value_shape()); + if (default_value.shape() != value_shape() && + default_value.shape() != fullsize_value_shape) { return errors::InvalidArgument( - "Expected shape ", value_shape().DebugString(), - " for default value, got ", default_value.shape().DebugString()); + "Expected shape ", value_shape().DebugString(), " or ", + fullsize_value_shape.DebugString(), " for default value, got ", + default_value.shape().DebugString()); } return Status::OK(); } diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h index 7e5dbe5632becb..04c5b0e4dc60a4 100644 --- a/tensorflow/core/framework/lookup_interface.h +++ b/tensorflow/core/framework/lookup_interface.h @@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase { // requirements are satisfied, otherwise it returns InvalidArgument: // - DataType of the tensor keys equals to the table key_dtype // - DataType of the tensor default_value equals to the table value_dtype - // - the default_value tensor shape matches the table's value shape. + // - the default_value tensor has the required shape given keys and the + // table's value shape. 
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value); string DebugString() const override { diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index f269aa65b4e911..09d9d32d2ae9ed 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -56,14 +56,25 @@ class MutableHashTableOfScalars final : public LookupInterface { Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, const Tensor& default_value) override { - const V default_val = default_value.flat<V>()(0); const auto key_values = key.flat<K>(); auto value_values = value->flat<V>(); + const auto default_flat = default_value.flat<V>(); + + int64 total = value_values.size(); + int64 default_total = default_flat.size(); + bool is_full_size_default = (total == default_total); tf_shared_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { + // is_full_size_default is true: + // Each key has an independent default value, key_values(i) + // correspondingly uses default_flat(i) as its default value. + // + // is_full_size_default is false: + // All keys will share the default_flat(0) as default value. value_values(i) = gtl::FindWithDefault( - table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); + table_, SubtleMustCopyIfIntegral(key_values(i)), + is_full_size_default ? 
default_flat(i) : default_flat(0)); } return Status::OK(); } @@ -173,11 +184,15 @@ class MutableHashTableOfTensors final : public LookupInterface { Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, const Tensor& default_value) override { - const auto default_flat = default_value.flat<V>(); + const auto default_flat = default_value.flat_inner_dims<V, 2>(); const auto key_values = key.flat<K>(); auto value_values = value->flat_inner_dims<V, 2>(); int64 value_dim = value_shape_.dim_size(0); + int64 total = value_values.size(); + int64 default_total = default_flat.size(); + bool is_full_size_default = (total == default_total); + tf_shared_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { ValueArray* value_vec = @@ -187,8 +202,15 @@ class MutableHashTableOfTensors final : public LookupInterface { value_values(i, j) = value_vec->at(j); } } else { + // is_full_size_default is true: + // Each key has an independent default value, key_values(i) + // correspondingly uses default_flat(i) as its default value. + // + // is_full_size_default is false: + // All keys will share the default_flat(0) as default value. for (int64 j = 0; j < value_dim; j++) { - value_values(i, j) = default_flat(j); + value_values(i, j) = + is_full_size_default ? default_flat(i, j) : default_flat(0, j); } } } diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index 8948df2cef361a..05aa229336d46f 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -169,10 +169,6 @@ REGISTER_OP("LookupTableFindV2") ShapeHandle handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); - // Default value must be scalar or vector. 
- ShapeHandle keys; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys)); - ShapeAndType value_shape_and_type; TF_RETURN_IF_ERROR(ValidateTableResourceHandle( c, diff --git a/tensorflow/core/ops/lookup_ops_test.cc b/tensorflow/core/ops/lookup_ops_test.cc index ac899d59993381..904099f1813a4a 100644 --- a/tensorflow/core/ops/lookup_ops_test.cc +++ b/tensorflow/core/ops/lookup_ops_test.cc @@ -25,7 +25,6 @@ namespace { TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) { ShapeInferenceTestOp op("LookupTableFindV2"); INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?"); - INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[];?;[1,1]"); TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2") .Input({"table_handle", 0, DT_RESOURCE}) .Input({"keys", 0, DT_INT64}) diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py index d564da12c27260..c2ff260d7d5c68 100644 --- a/tensorflow/python/kernel_tests/lookup_ops_test.py +++ b/tensorflow/python/kernel_tests/lookup_ops_test.py @@ -3375,6 +3375,71 @@ def testMutableHashTableFindHighRank(self): result = self.evaluate(output) self.assertAllEqual([[0, 1], [-1, -1]], result) + def testMutableHashTableFindWithInvalidShapeDefaultValue(self): + default_val = [-1, -1] + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + + input_string = constant_op.constant([["brain", "salad"], + ["tank", "tarkus"]]) + + invalid_default_val = constant_op.constant( + [[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64) + + with self.assertRaisesRegex( + (ValueError, errors_impl.InvalidArgumentError), + "Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]"): + self.evaluate(table.lookup(input_string, invalid_default_val)) + + invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]], + dtypes.int64) + with self.assertRaisesRegex( + (ValueError, errors_impl.InvalidArgumentError), + "Expected shape \[2\] or \[2,2,2\] 
for default value, got \[1,2,2\]"): + self.evaluate(table.lookup(input_string, invalid_default_val)) + + def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self): + default_val = -1 + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + input_string = constant_op.constant([["brain", "salad"], + ["tank", "tarkus"]]) + + dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]], + dtypes.int64) + output = table.lookup(input_string, dynamic_default_val) + self.assertAllEqual([2, 2], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual([[0, 1], [-4, -5]], result) + + def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self): + default_val = [-1, -1] + keys = constant_op.constant(["brain", "salad", "surgery"]) + values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) + table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64, + default_val) + + self.evaluate(table.insert(keys, values)) + self.assertAllEqual(3, self.evaluate(table.size())) + + input_string = constant_op.constant([["brain", "salad"], + ["tank", "tarkus"]]) + + dynamic_default_val = constant_op.constant( + [[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64) + output = table.lookup(input_string, dynamic_default_val) + self.assertAllEqual([2, 2, 2], output.get_shape()) + + result = self.evaluate(output) + self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result) + def testMutableHashTableInsertHighRank(self): default_val = -1 keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index f99102fee52ac4..145c2b0195cdfb 100644 --- a/tensorflow/python/ops/lookup_ops.py 
+++ b/tensorflow/python/ops/lookup_ops.py @@ -1849,7 +1849,7 @@ def remove(self, keys, name=None): return op - def lookup(self, keys, name=None): + def lookup(self, keys, dynamic_default_values=None, name=None): """Looks up `keys` in a table, outputs the corresponding values. The `default_value` is used for keys not present in the table. @@ -1857,6 +1857,23 @@ def lookup(self, keys, name=None): Args: keys: Keys to look up. Can be a tensor of any shape. Must match the table's key_dtype. + dynamic_default_values: The values to use if a key is missing in the + table. If None (by default), the `table.default_value` will be used. + Shape of `dynamic_default_values` must be the same as + `table.default_value` or the lookup result tensor. + In the latter case, each key will have a different default value. + + For example: + + ```python + keys = [0, 1, 3] + dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]] + + # The key '0' will use [1, 3, 4] as default value. + # The key '1' will use [2, 3, 9] as default value. + # The key '3' will use [8, 3, 0] as default value. + ``` + name: A name for the operation (optional). Returns: @@ -1870,8 +1887,9 @@ def lookup(self, keys, name=None): (self.resource_handle, keys, self._default_value)): keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys") with ops.colocate_with(self.resource_handle): - values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys, - self._default_value) + values = gen_lookup_ops.lookup_table_find_v2( + self.resource_handle, keys, dynamic_default_values + if dynamic_default_values is not None else self._default_value) return values def insert(self, keys, values, name=None):