Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make safe_embedding_lookup_sparse method public and clean duplicate codes #18771

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import itertools
import math
import sys

import numpy as np

Expand Down
2 changes: 2 additions & 0 deletions tensorflow/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,8 @@ py_library(
":math_ops",
":platform",
":resource_variable_ops",
":sparse_ops",
":tensor_shape",
":variables",
],
)
Expand Down
161 changes: 3 additions & 158 deletions tensorflow/python/feature_column/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2058,7 +2058,7 @@ def _create_categorical_column_weighted_sum(column,
initializer=init_ops.zeros_initializer(),
trainable=trainable,
collections=weight_collections)
return _safe_embedding_lookup_sparse(
return embedding_ops.safe_embedding_lookup_sparse(
weight,
id_tensor,
sparse_weights=weight_tensor,
Expand Down Expand Up @@ -2479,7 +2479,7 @@ def _get_dense_tensor_internal(self,
})

# Return embedding lookup result.
return _safe_embedding_lookup_sparse(
return embedding_ops.safe_embedding_lookup_sparse(
embedding_weights=embedding_weights,
sparse_ids=sparse_ids,
sparse_weights=sparse_weights,
Expand Down Expand Up @@ -2612,7 +2612,7 @@ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
})

# Return embedding lookup result.
return _safe_embedding_lookup_sparse(
return embedding_ops.safe_embedding_lookup_sparse(
embedding_weights=embedding_weights,
sparse_ids=sparse_ids,
sparse_weights=sparse_weights,
Expand Down Expand Up @@ -3065,161 +3065,6 @@ def _collect_leaf_level_keys(cross):
return leaf_level_keys


# TODO(zakaria): Move this to embedding_ops and make it public.
def _safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
sparse_weights=None,
combiner='mean',
default_id=None,
name=None,
partition_strategy='div',
max_norm=None):
"""Lookup embedding results, accounting for invalid IDs and empty features.

The partitioned embedding in `embedding_weights` must all be the same shape
except for the first dimension. The first dimension is allowed to vary as the
vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
partitioner.

Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
with non-positive weight. For an entry with no features, the embedding vector
for `default_id` is returned, or the 0-vector if `default_id` is not supplied.

The ids and weights may be multi-dimensional. Embeddings are always aggregated
along the last dimension.

Args:
embedding_weights: A list of `P` float `Tensor`s or values representing
partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
created by partitioning along dimension 0. The total unpartitioned
shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
vocab size and `e_1, ..., e_m` are the embedding dimensions.
sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
ids. `d_0` is typically batch size.
sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
float weights corresponding to `sparse_ids`, or `None` if all weights
are be assumed to be 1.0.
combiner: A string specifying how to combine embedding results for each
entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
the default.
default_id: The id to use for an entry with no features.
name: A name for this operation (optional).
partition_strategy: A string specifying the partitioning strategy.
Currently `"div"` and `"mod"` are supported. Default is `"div"`.
max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
combining.


Returns:
Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

Raises:
ValueError: if `embedding_weights` is empty.
"""
if embedding_weights is None:
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
if isinstance(embedding_weights, variables.PartitionedVariable):
embedding_weights = list(embedding_weights) # get underlying Variables.
if not isinstance(embedding_weights, list):
embedding_weights = [embedding_weights]
if len(embedding_weights) < 1:
raise ValueError('Missing embedding_weights %s.' % embedding_weights)

dtype = sparse_weights.dtype if sparse_weights is not None else None
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]

with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
sparse_weights]) as scope:
# Reshape higher-rank sparse ids and weights to linear segment ids.
original_shape = sparse_ids.dense_shape
original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
original_rank = (
array_ops.size(original_shape)
if original_rank_dim.value is None
else original_rank_dim.value)
sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
math_ops.reduce_prod(
array_ops.slice(original_shape, [0], [original_rank - 1])),
array_ops.gather(original_shape, original_rank - 1)])
if sparse_weights is not None:
sparse_weights = sparse_tensor_lib.SparseTensor(
sparse_ids.indices,
sparse_weights.values, sparse_ids.dense_shape)

# Prune invalid ids and weights.
sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
if combiner != 'sum':
sparse_ids, sparse_weights = _prune_invalid_weights(
sparse_ids, sparse_weights)

# Fill in dummy values for empty features, if necessary.
sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
default_id or
0)
if sparse_weights is not None:
sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

result = embedding_ops.embedding_lookup_sparse(
embedding_weights,
sparse_ids,
sparse_weights,
combiner=combiner,
partition_strategy=partition_strategy,
name=None if default_id is None else scope,
max_norm=max_norm)

if default_id is None:
# Broadcast is_row_empty to the same shape as embedding_lookup_result,
# for use in Select.
is_row_empty = array_ops.tile(
array_ops.reshape(is_row_empty, [-1, 1]),
array_ops.stack([1, array_ops.shape(result)[1]]))

result = array_ops.where(is_row_empty,
array_ops.zeros_like(result),
result,
name=scope)

# Reshape back from linear ids back into higher-dimensional dense result.
final_result = array_ops.reshape(
result,
array_ops.concat([
array_ops.slice(
math_ops.cast(original_shape, dtypes.int32), [0],
[original_rank - 1]),
array_ops.slice(array_ops.shape(result), [1], [-1])
], 0))
final_result.set_shape(tensor_shape.unknown_shape(
(original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
return final_result


def _prune_invalid_ids(sparse_ids, sparse_weights):
"""Prune invalid IDs (< 0) from the input ids and weights."""
is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
if sparse_weights is not None:
is_id_valid = math_ops.logical_and(
is_id_valid,
array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
if sparse_weights is not None:
sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
"""Prune invalid weights (< 0) from the input ids and weights."""
if sparse_weights is not None:
is_weights_valid = math_ops.greater(sparse_weights.values, 0)
sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
return sparse_ids, sparse_weights


class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
collections.namedtuple('_IndicatorColumn',
['categorical_column'])):
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2718,6 +2718,7 @@ cuda_py_test(
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:partitioned_variables",
Expand Down