Commit c115b80
Adds performance enhancements for sparse embedding lookups.
philipphack committed Feb 16, 2023
1 parent 3d1096a commit c115b80
Showing 4 changed files with 92 additions and 68 deletions.
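
In short, this commit renames the opt-in flag for the single-tensor fast path from `allow_dense_grads` to `allow_fast_lookup` and threads the new name through the benchmarks, tests, and public APIs below. A minimal usage sketch of the updated signature (hypothetical table and ids, assuming a TF 2.x build that includes this commit):

import tensorflow as tf

# A single (non-partitioned) embedding table with max_norm=None: the case
# in which allow_fast_lookup can take the simplified lookup path.
params = tf.random.uniform([10000, 16])

# Two rows of sparse ids: row 0 holds ids {3, 7}, row 1 holds id {42}.
sp_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([3, 7, 42], dtype=tf.int64),
    dense_shape=[2, 2])

emb = tf.nn.embedding_lookup_sparse(
    params, sp_ids, sp_weights=None, combiner="mean",
    allow_fast_lookup=True)
print(emb.shape)  # (2, 16)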
16 changes: 8 additions & 8 deletions tensorflow/python/eager/benchmarks_test.py
@@ -1627,14 +1627,14 @@ def benchmark_tf_range_return_int64_GPU(self):
self._benchmark_tf_range_return(dtype=dtypes.int64, device=GPU)

def _benchmark_embedding_lookup_sparse_with_sparse_input(self,
- allow_dense_grads=True,
+ allow_fast_lookup=True,
batch_size=32000,
device=GPU):
def func(sp_ids):
return embedding_ops.embedding_lookup_sparse(self._m_10000_by_16,
sp_ids,
None,
- allow_dense_grads=allow_dense_grads)
+ allow_fast_lookup=allow_fast_lookup)

with context.device(device):
values = random_ops.random_uniform(shape=(batch_size,),
@@ -1651,21 +1651,21 @@ def func(sp_ids):

def benchmark_tf_embedding_lookup_sparse_with_sparse_input_sparse_grads(self):
self._benchmark_embedding_lookup_sparse_with_sparse_input(
- allow_dense_grads=False)
+ allow_fast_lookup=False)

def benchmark_tf_embedding_lookup_sparse_with_sparse_input_dense_grads(self):
self._benchmark_embedding_lookup_sparse_with_sparse_input(
- allow_dense_grads=True)
+ allow_fast_lookup=True)

def _benchmark_embedding_lookup_sparse_with_ragged_input(self,
- allow_dense_grads=True,
+ allow_fast_lookup=True,
batch_size=32000,
device=GPU):
def func(sp_ids):
return embedding_ops.embedding_lookup_sparse(self._m_10000_by_16,
sp_ids,
None,
- allow_dense_grads=allow_dense_grads)
+ allow_fast_lookup=allow_fast_lookup)

with context.device(device):
values = random_ops.random_uniform(shape=(batch_size,),
@@ -1681,11 +1681,11 @@ def func(sp_ids):

def benchmark_embedding_lookup_sparse_with_ragged_input_sparse_grads(self):
self._benchmark_embedding_lookup_sparse_with_ragged_input(
- allow_dense_grads=False)
+ allow_fast_lookup=False)

def benchmark_embedding_lookup_sparse_with_ragged_input_dense_grads(self):
self._benchmark_embedding_lookup_sparse_with_ragged_input(
- allow_dense_grads=True)
+ allow_fast_lookup=True)

if __name__ == "__main__":
test.main()
62 changes: 42 additions & 20 deletions tensorflow/python/kernel_tests/nn_ops/embedding_ops_test.py
@@ -698,8 +698,13 @@ def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
[True, False]
))
@test_util.run_deprecated_v1
- def testEmbeddingLookupSparse(self, num_shards, combiner, dtype,
- ignore_weights, ragged, allow_dense_grads):
+ def testEmbeddingLookupSparse(self,
+ num_shards,
+ combiner,
+ dtype,
+ ignore_weights,
+ ragged,
+ allow_fast_lookup):
vocab_size = 13
batch_size = 10
param_shape = [2, 5]
@@ -721,7 +726,7 @@ def testEmbeddingLookupSparse(self, num_shards, combiner, dtype,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner,
- allow_dense_grads=allow_dense_grads)
+ allow_fast_lookup=allow_fast_lookup)

self.assertEqual(embedding_sum.get_shape().as_list(),
expected_lookup_result_shape)
@@ -756,7 +761,7 @@ def testEmbeddingLookupSparse(self, num_shards, combiner, dtype,
[True, False],
[True, False]
))
- def testMissingInSparseIds(self, combiner, ragged, allow_dense_grads):
+ def testMissingInSparseIds(self, combiner, ragged, allow_fast_lookup):
# Github issue, 36359
with self.test_session():
x = array_ops.ones((4, 5))
@@ -776,7 +781,7 @@ def testMissingInSparseIds(self, combiner, ragged, allow_dense_grads):

embedding_sum = embedding_ops.embedding_lookup_sparse(
x, sp_ids, sp_weights, combiner=combiner,
- allow_dense_grads=allow_dense_grads)
+ allow_fast_lookup=allow_fast_lookup)

tf_embedding_sum = ops.convert_to_tensor(embedding_sum)
self.assertAllClose(tf_embedding_sum[0], np.zeros(5))
@@ -794,8 +799,13 @@ def testMissingInSparseIds(self, combiner, ragged, allow_dense_grads):
[True, False]
))
@test_util.run_deprecated_v1
- def testGradientsEmbeddingLookupSparse(self, num_shards, combiner, dtype, ignore_weights, ragged,
- allow_dense_grads):
+ def testGradientsEmbeddingLookupSparse(self,
+ num_shards,
+ combiner,
+ dtype,
+ ignore_weights,
+ ragged,
+ allow_fast_lookup):
vocab_size = 12
batch_size = 4
param_shape = [2, 3]
@@ -811,7 +821,7 @@ def testGradientsEmbeddingLookupSparse(self, num_shards, combiner, dtype, ignore
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner,
- allow_dense_grads=allow_dense_grads)
+ allow_fast_lookup=allow_fast_lookup)
x_name = [_PName(i) for i in range(num_shards)]
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]
@@ -826,7 +836,7 @@ def testGradientsEmbeddingLookupSparse(self, num_shards, combiner, dtype, ignore
[True, False]
))
@test_util.run_deprecated_v1
- def testIncompatibleShapes(self, ragged, allow_dense_grads):
+ def testIncompatibleShapes(self, ragged, allow_fast_lookup):
with self.cached_session():
x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
indices = [[0, 0], [0, 1], [1, 0]]
@@ -846,7 +856,7 @@ def testIncompatibleShapes(self, ragged, allow_dense_grads):
with self.assertRaises(ValueError):
embedding_ops.embedding_lookup_sparse(
x, sp_ids, sp_weights, combiner="mean",
- allow_dense_grads=allow_dense_grads)
+ allow_fast_lookup=allow_fast_lookup)

@test_util.run_deprecated_v1
def test_incompatible_types(self):
@@ -946,7 +956,9 @@ def _ids_and_weights_3d(self):
[True, False]
))
@test_util.run_deprecated_v1
- def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged, allow_dense_grads):
+ def test_safe_embedding_lookup_sparse_return_zero_vector(self,
+ ragged,
+ allow_fast_lookup):
with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d(ragged)
@@ -956,7 +968,7 @@ def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged, allow_den
embedding_weights,
sparse_ids,
sparse_weights,
- allow_dense_grads=allow_dense_grads))
+ allow_fast_lookup=allow_fast_lookup))

self.assertAllClose(
embedding_lookup_result,
@@ -969,15 +981,18 @@ def test_safe_embedding_lookup_sparse_return_zero_vector(self, ragged, allow_den
[True, False]
))
@test_util.run_deprecated_v1
- def test_safe_embedding_lookup_sparse_return_special_vector(self, ragged, allow_dense_grads):
+ def test_safe_embedding_lookup_sparse_return_special_vector(
+ self,
+ ragged,
+ allow_fast_lookup):
with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d(ragged)

embedding_lookup_result = (
embedding_ops.safe_embedding_lookup_sparse_v2(
embedding_weights, sparse_ids, sparse_weights, default_id=3,
- allow_dense_grads=allow_dense_grads))
+ allow_fast_lookup=allow_fast_lookup))

self.assertAllClose(
embedding_lookup_result,
@@ -991,7 +1006,9 @@ def test_safe_embedding_lookup_sparse_return_special_vector(self, ragged, allow_
[True, False]
))
@test_util.run_deprecated_v1
- def test_safe_embedding_lookup_sparse_no_weights(self, ragged, allow_dense_grads):
+ def test_safe_embedding_lookup_sparse_no_weights(self,
+ ragged,
+ allow_fast_lookup):
with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_2d(ragged)
@@ -1001,7 +1018,7 @@ def test_safe_embedding_lookup_sparse_no_weights(self, ragged, allow_dense_grads
embedding_weights,
sparse_ids,
None,
- allow_dense_grads=allow_dense_grads))
+ allow_fast_lookup=allow_fast_lookup))

self.assertAllClose(
embedding_lookup_result,
@@ -1015,7 +1032,9 @@ def test_safe_embedding_lookup_sparse_no_weights(self, ragged, allow_dense_grads
[True, False]
))
@test_util.run_deprecated_v1
- def test_safe_embedding_lookup_sparse_partitioned(self, ragged, allow_dense_grads):
+ def test_safe_embedding_lookup_sparse_partitioned(self,
+ ragged,
+ allow_fast_lookup):
with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_2d(ragged)
@@ -1025,7 +1044,7 @@ def test_safe_embedding_lookup_sparse_partitioned(self, ragged, allow_dense_grad
embedding_weights,
sparse_ids,
None,
- allow_dense_grads=allow_dense_grads))
+ allow_fast_lookup=allow_fast_lookup))

embedding_weights_list = list(itertools.chain(*embedding_weights))
self.assertAllClose(
@@ -1040,7 +1059,10 @@ def test_safe_embedding_lookup_sparse_partitioned(self, ragged, allow_dense_grad
[True, False]
))
@test_util.run_deprecated_v1
- def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self, ragged, allow_dense_grads):
+ def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(
+ self,
+ ragged,
+ allow_fast_lookup):
with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_2d(ragged)
@@ -1057,7 +1079,7 @@ def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self, rag
embedding_weights_constant,
sparse_ids,
sparse_weights,
- allow_dense_grads=allow_dense_grads)
+ allow_fast_lookup=allow_fast_lookup)

@test_util.run_deprecated_v1
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
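The safe-lookup tests above exercise invalid-id pruning and empty rows in combination with the new flag. A hedged sketch of the corresponding public API, `tf.nn.safe_embedding_lookup_sparse` (hypothetical table and ids, assuming a build with this commit):

import tensorflow as tf

embedding_weights = [tf.random.uniform([10, 4])]  # a single shard

# Row 1 is empty, and the id -1 in row 0 is invalid; the safe lookup
# prunes invalid ids and, with default_id=None, returns a zero vector
# for any row left without valid ids.
sparse_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [2, 0]],
    values=tf.constant([-1, 3], dtype=tf.int64),
    dense_shape=[3, 1])

result = tf.nn.safe_embedding_lookup_sparse(
    embedding_weights, sparse_ids, sparse_weights=None, combiner="mean",
    allow_fast_lookup=True)
print(result.shape)  # (3, 4); rows 0 and 1 come back as zero vectors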
52 changes: 26 additions & 26 deletions tensorflow/python/ops/embedding_ops.py
@@ -406,7 +406,7 @@ def embedding_lookup_sparse(params,
name=None,
combiner=None,
max_norm=None,
- allow_dense_grads=False):
+ allow_fast_lookup=False):
"""Looks up embeddings for the given ids and weights from a list of tensors.
This op assumes that there is at least one id for each row in the dense tensor
@@ -443,10 +443,10 @@ def embedding_lookup_sparse(params,
of the squares of the weights. Defaults to `mean`.
max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
than this value, before combining.
- allow_dense_grads: An optional boolean specifying whether to allow dense
- gradients during training. Setting this flag to `True` can improve
- performance when `params` is a single tensor and `max_norm` is `None` at
- the expense of higher memory usage.
+ allow_fast_lookup: An optional boolean specifying whether to allow
+ simplified embedding lookups when `params` is a single tensor and
+ `max_norm` is `None`. Setting this flag to `True` during training can
+ cause the use of dense gradients with increased memory footprint.
Returns:
A dense tensor representing the combined embeddings for the
@@ -521,7 +521,7 @@ def embedding_lookup_sparse(params,

return embedding_lookup_sparse_impl(params, segment_ids, sp_weights,
ids, combiner, ignore_weights, max_norm,
- allow_dense_grads, partition_strategy,
+ allow_fast_lookup, partition_strategy,
name)
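
The tradeoff described in the rewritten docstring above can be checked empirically. A sketch (assuming eager TF 2.x; the dense-versus-sparse gradient behavior is taken from the docstring, and the expected gradient types are an inference, not something this commit's tests print):

import tensorflow as tf

params = tf.Variable(tf.random.uniform([100, 8]))
sp_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [1, 0]],
    values=tf.constant([5, 9], dtype=tf.int64),
    dense_shape=[2, 1])

for fast in (False, True):
  with tf.GradientTape() as tape:
    out = tf.nn.embedding_lookup_sparse(
        params, sp_ids, None, combiner="sum", allow_fast_lookup=fast)
    loss = tf.reduce_sum(out)
  grad = tape.gradient(loss, params)
  # Expect a sparse IndexedSlices gradient when fast=False and a dense
  # Tensor (full table shape, larger memory footprint) when fast=True.
  print(fast, type(grad).__name__)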


@@ -532,7 +532,7 @@ def embedding_lookup_sparse_v2(params,
sp_weights,
combiner=None,
max_norm=None,
- allow_dense_grads=False,
+ allow_fast_lookup=False,
name=None):
"""Looks up embeddings for the given ids and weights from a list of tensors.
@@ -573,10 +573,10 @@ def embedding_lookup_sparse_v2(params,
of the squares of the weights. Defaults to `mean`.
max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
than this value, before combining.
- allow_dense_grads: An optional boolean specifying whether to allow dense
- gradients during training. Setting this flag to `True` can improve
- performance when `params` is a single tensor and `max_norm` is `None` at
- the expense of higher memory usage.
+ allow_fast_lookup: An optional boolean specifying whether to allow
+ simplified embedding lookups when `params` is a single tensor and
+ `max_norm` is `None`. Setting this flag to `True` during training can
+ cause the use of dense gradients with increased memory footprint.
name: Optional name for the op.
Returns:
@@ -620,7 +620,7 @@ def embedding_lookup_sparse_v2(params,
ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
"""
return embedding_lookup_sparse(params, sp_ids, sp_weights, "div", name,
- combiner, max_norm, allow_dense_grads)
+ combiner, max_norm, allow_fast_lookup)


@tf_export("nn.safe_embedding_lookup_sparse", v1=[])
@@ -631,7 +631,7 @@ def safe_embedding_lookup_sparse_v2(embedding_weights,
combiner="mean",
default_id=None,
max_norm=None,
- allow_dense_grads=False,
+ allow_fast_lookup=False,
name=None):
"""Lookup embedding results, accounting for invalid IDs and empty features.
@@ -675,10 +675,10 @@ def safe_embedding_lookup_sparse_v2(embedding_weights,
0-vector.
max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
combining.
- allow_dense_grads: An optional boolean specifying whether to allow dense
- gradients during training. Setting this flag to `True` can improve
- performance when `params` is a single tensor and `max_norm` is `None` at
- the expense of higher memory usage.
+ allow_fast_lookup: An optional boolean specifying whether to allow
+ simplified embedding lookups when `params` is a single tensor and
+ `max_norm` is `None`. Setting this flag to `True` during training can
+ cause the use of dense gradients with increased memory footprint.
name: A name for this operation (optional).
Returns:
@@ -730,7 +730,7 @@ def safe_embedding_lookup_sparse_v2(embedding_weights,
name=name,
partition_strategy="div",
max_norm=max_norm,
- allow_dense_grads=allow_dense_grads)
+ allow_fast_lookup=allow_fast_lookup)


@tf_export(v1=["nn.safe_embedding_lookup_sparse"])
@@ -743,7 +743,7 @@ def safe_embedding_lookup_sparse(embedding_weights,
name=None,
partition_strategy="div",
max_norm=None,
- allow_dense_grads=False):
+ allow_fast_lookup=False):
"""Lookup embedding results, accounting for invalid IDs and empty features.
The partitioned embedding in `embedding_weights` must all be the same shape
@@ -783,10 +783,10 @@ def safe_embedding_lookup_sparse(embedding_weights,
`"div"` and `"mod"` are supported. Default is `"div"`.
max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
combining.
- allow_dense_grads: An optional boolean specifying whether to allow dense
- gradients during training. Setting this flag to `True` can improve
- performance when `params` is a single tensor and `max_norm` is `None` at
- the expense of higher memory usage.
+ allow_fast_lookup: An optional boolean specifying whether to allow
+ simplified embedding lookups when `params` is a single tensor and
+ `max_norm` is `None`. Setting this flag to `True` during training can
+ cause the use of dense gradients with increased memory footprint.
Returns:
A dense tensor representing the combined embeddings for the
@@ -884,7 +884,7 @@ def safe_embedding_lookup_sparse(embedding_weights,
partition_strategy=partition_strategy,
name=None if default_id is None else scope,
max_norm=max_norm,
- allow_dense_grads=allow_dense_grads)
+ allow_fast_lookup=allow_fast_lookup)

if default_id is None:
# Broadcast is_row_empty to the same shape as embedding_lookup_result,
@@ -919,11 +919,11 @@ def embedding_lookup_sparse_impl(params,
combiner,
ignore_weights,
max_norm,
- allow_dense_grads,
+ allow_fast_lookup,
partition_strategy,
name):
"""Implementation of sparse embedding aggregation."""
- if len(params) == 1 and max_norm is None and allow_dense_grads:
+ if len(params) == 1 and max_norm is None and allow_fast_lookup:
idx = ids
embeddings = params[0]
else:
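For orientation, an illustrative simplification of the branch above using public TF ops ("sum" combiner, weights ignored; the helper name is hypothetical, not part of this change). The fast path hands the raw ids and the whole single table to the downstream segment op, so the table's gradient comes back dense; the default path deduplicates ids and gathers only the needed rows, which yields a sparse IndexedSlices gradient:

import tensorflow as tf

def combine_sum_sketch(params, ids, segment_ids,
                       allow_fast_lookup, max_norm=None):
  # Illustrative simplification; not the verbatim TensorFlow code.
  if len(params) == 1 and max_norm is None and allow_fast_lookup:
    # Fast path: index the single table directly with the raw ids.
    idx = ids
    embeddings = params[0]
  else:
    # Default path: deduplicate ids and gather just those rows.
    unique_ids, idx = tf.unique(ids)
    embeddings = tf.nn.embedding_lookup(params, unique_ids,
                                        max_norm=max_norm)
  return tf.sparse.segment_sum(embeddings, idx, segment_ids)

# Rows {0, 0, 1} combine ids {3, 7, 42} from one 100x8 table.
table = tf.random.uniform([100, 8])
out = combine_sum_sketch([table],
                         ids=tf.constant([3, 7, 42]),
                         segment_ids=tf.constant([0, 0, 1]),
                         allow_fast_lookup=True)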
