Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fixes shape inference of LookupTableImportV2 to handle scalar values.
PiperOrigin-RevId: 506126405
  • Loading branch information
wangpengmit authored and tensorflower-gardener committed Jan 31, 2023
1 parent 138c13f commit 980b225
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tensorflow/core/ops/lookup_ops.cc
Expand Up @@ -309,9 +309,10 @@ REGISTER_OP("LookupTableImportV2")

ShapeHandle keys;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
ShapeHandle values;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &values));
DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->Merge(c->Dim(keys, 0), c->Dim(c->input(2), 0), &unused));
TF_RETURN_IF_ERROR(c->Merge(c->Dim(keys, 0), c->Dim(values, 0), &unused));
return OkStatus();
});

Expand Down
15 changes: 15 additions & 0 deletions tensorflow/python/kernel_tests/data_structures/lookup_ops_test.py
Expand Up @@ -41,6 +41,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import variables
Expand Down Expand Up @@ -573,6 +574,20 @@ def false_fn():
self.evaluate(lookup_ops.tables_initializer())
self.assertAllEqual(grad, -10.)

def testImportShapeInference(self, is_anonymous):
v = variables.Variable(1)

@def_function.function(jit_compile=True)
def foo():
return gen_lookup_ops.lookup_table_import_v2(
table_handle=v.handle, keys=[1.1, 2.2], values=1
)

with self.assertRaisesRegex(
ValueError, r"Shape must be at least rank 1 but is rank 0"
):
foo()

def testExportShapeInference(self, is_anonymous):
table = self.getHashTable()(
lookup_ops.KeyValueTensorInitializer(
Expand Down

0 comments on commit 980b225

Please sign in to comment.