Skip to content

Commit

Permalink
Adds performance enhancements for sparse embedding lookups.
Browse files Browse the repository at this point in the history
  • Loading branch information
philipphack committed Feb 16, 2023
1 parent ee90704 commit 9e216e3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions tensorflow/python/ops/nn_test.py
Expand Up @@ -1853,7 +1853,7 @@ def testRaggedTensor(self):
ragged_ids = ragged_factory_ops.constant([[1, 2, 3], [0], [1, 2]],
ragged_rank=1)

embedded_ragged = nn.embedding_lookup_ragged(weights, ragged_ids)
embedded_ragged = nn.embedding_lookup(weights, ragged_ids)
expected_output = ragged_factory_ops.constant(
[[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[0, 0, 0]], [[1, 1, 1], [2, 2, 2]]
],
Expand All @@ -1869,7 +1869,7 @@ def testMultipleRaggedDimTensor(self):
]],
ragged_rank=2)

embedded_ragged = nn.embedding_lookup_ragged(weights, ragged_ids)
embedded_ragged = nn.embedding_lookup(weights, ragged_ids)
expected_output = ragged_factory_ops.constant(
[[[[[3, 3], [4, 4]], [[0, 0], [6, 6]]], []],
[[[[2, 2], [1, 1]], [[1, 1], [0, 0]]],
Expand All @@ -1882,23 +1882,23 @@ def testMissingWeights(self):
ragged_ids = ragged_factory_ops.constant([[1, 2, 3], [0], [1, 2]])

with self.assertRaisesRegex(ValueError,
"The embedding weights must be specified.*"):
nn.embedding_lookup_ragged(None, ragged_ids)
"params must be specified.*"):
nn.embedding_lookup(None, ragged_ids)

def testEmptyWeights(self):
ragged_ids = ragged_factory_ops.constant([[1, 2, 3], [0], [1, 2]])

with self.assertRaisesRegex(ValueError,
"The embedding weights should not be empty.*"):
nn.embedding_lookup_ragged([], ragged_ids)
"params should not be empty.*"):
nn.embedding_lookup([], ragged_ids)

def testInvalidIndicesType(self):
weights = constant_op.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
ragged_ids = ragged_factory_ops.constant([[1., 2., 3.], [1., 2.]])

with self.assertRaisesRegex(
ValueError, "The values contained by the inputs have type*"):
nn.embedding_lookup_ragged(weights, ragged_ids)
nn.embedding_lookup(weights, ragged_ids)

def testMaxNormForEmbeddings(self):
weights = constant_op.constant(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/tpu/tpu_embedding_for_serving.py
Expand Up @@ -192,7 +192,7 @@ def _ragged_embedding_lookup_with_reduce(table: tf_variables.Variable,
if weights is None:
weights = array_ops.ones_like(ragged, dtype=table.dtype)
weights = array_ops.expand_dims(weights, axis=2)
ragged_result = embedding_ops.embedding_lookup_ragged(table, ragged)
ragged_result = embedding_ops.embedding_lookup(table, ragged)
ragged_result = math_ops.reduce_sum(ragged_result * weights, axis=1)
if combiner == "mean":
ragged_result = math_ops.div_no_nan(ragged_result,
Expand Down

0 comments on commit 9e216e3

Please sign in to comment.