Adding bow_encoder and embed_sequence to encode sequences of symbols (words/categories) into representations.

Additionally:
 - dense_to_sparse_tensor op, which converts large tensors with many zeros into sparse tensors.
 - embedding_lookup_unique op, which reduces bandwidth when looking up repeated embedding ids.
Change: 137539118
Illia Polosukhin authored and tensorflower-gardener committed Oct 28, 2016
1 parent 6ab39dc commit d86d570
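
For orientation, a minimal usage sketch of the new dense_to_sparse_tensor op described above (not part of the diff; values are illustrative, and it assumes, as the tests below suggest, that zeros are treated as the default ignore value and dropped):

import tensorflow as tf

# A dense batch of symbol ids, zero-padded; most entries are zeros.
docs = tf.constant([[5, 1, 0, 0],
                    [2, 7, 4, 0]], dtype=tf.int64)

# Convert to a SparseTensor so downstream sparse lookups skip the padding.
sparse_docs = tf.contrib.layers.sparse_ops.dense_to_sparse_tensor(docs)

with tf.Session() as sess:
  # Returns a SparseTensorValue holding indices/values of non-zero entries.
  print(sess.run(sparse_docs))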
Showing 9 changed files with 563 additions and 1 deletion.
26 changes: 26 additions & 0 deletions tensorflow/contrib/layers/BUILD
@@ -239,6 +239,32 @@ py_test(
    ],
)

py_test(
    name = "sparse_ops_test",
    size = "small",
    srcs = ["python/ops/sparse_ops_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":layers_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

py_test(
    name = "encoders_test",
    size = "small",
    srcs = ["python/layers/encoders_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":layers_py",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
1 change: 1 addition & 0 deletions tensorflow/contrib/layers/__init__.py
@@ -89,6 +89,7 @@

# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.layers.python.layers import *
from tensorflow.contrib.layers.python.ops import sparse_ops
from tensorflow.python.util.all_util import make_all

__all__ = make_all(__name__)
1 change: 1 addition & 0 deletions tensorflow/contrib/layers/python/layers/__init__.py
@@ -20,6 +20,7 @@

# pylint: disable=wildcard-import
from tensorflow.contrib.layers.python.layers.embedding_ops import *
from tensorflow.contrib.layers.python.layers.encoders import *
from tensorflow.contrib.layers.python.layers.feature_column import *
from tensorflow.contrib.layers.python.layers.feature_column_ops import *
from tensorflow.contrib.layers.python.layers.initializers import *
36 changes: 35 additions & 1 deletion tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -31,7 +31,7 @@
from tensorflow.python.platform import tf_logging as logging

__all__ = ["safe_embedding_lookup_sparse", "hashed_embedding_lookup",
"hashed_embedding_lookup_sparse"]
"hashed_embedding_lookup_sparse", "embedding_lookup_unique"]


def safe_embedding_lookup_sparse(embedding_weights,
@@ -340,3 +340,37 @@ def hashed_embedding_lookup_sparse(params,
raise ValueError("Combiner must be one of 'mean', 'sqrtn' or 'sum'.")

return embeddings


def embedding_lookup_unique(params, ids, name=None):
  """Version of embedding_lookup that avoids duplicate lookups.

  This can save communication in the case of repeated ids.
  Same interface as embedding_lookup.

  Args:
    params: A list of tensors with the same shape and type, or a
      `PartitionedVariable`.
    ids: A one-dimensional `Tensor` with type `int32` or `int64` containing
      the ids to be looked up in `params`.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` with the same type as the tensors in `params`.

  Raises:
    ValueError: If `params` is empty.
  """
  with ops.name_scope(name, "EmbeddingLookupUnique", [params, ids]):
    params = ops.convert_to_tensor(params)
    ids = ops.convert_to_tensor(ids)
    shape = array_ops.shape(ids)
    ids_flat = array_ops.reshape(
        ids, math_ops.reduce_prod(shape, keep_dims=True))
    unique_ids, idx = array_ops.unique(ids_flat)
    unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids)
    embeds_flat = array_ops.gather(unique_embeddings, idx)
    embed_shape = array_ops.concat(0, [shape, [-1]])
    embeds = array_ops.reshape(embeds_flat, embed_shape)
    embeds.set_shape(ids.get_shape().concatenate(params.get_shape()[1:]))
    return embeds
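
A small usage sketch of the op above (not part of the diff; 0.11-era API, with arbitrary embedding values):

import numpy as np
import tensorflow as tf

embeds = np.random.randn(10, 4).astype(np.float32)  # 10 ids, 4-dim embeddings
ids = tf.constant([[1, 1, 7], [7, 2, 1]])           # heavily repeated ids

# Only the unique ids {1, 2, 7} hit the embedding table; the results are
# then gathered back into the full [2, 3, 4] output.
embedded = tf.contrib.layers.embedding_lookup_unique(embeds, ids)

with tf.Session() as sess:
  print(sess.run(embedded).shape)  # (2, 3, 4)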
16 changes: 16 additions & 0 deletions tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -356,5 +356,21 @@ def test_hashed_embedding_lookup_sparse(self):
0.5 * (embedding_lookup_result[0] +
embedding_lookup_result[3]))

  def test_embedding_lookup_unique(self):
    d_embed = 5
    n_embed = 10
    idx_shape = (2, 3, 4)
    embeds = np.random.randn(n_embed, d_embed)
    idx = np.random.randint(0, n_embed, idx_shape)

    with self.test_session():
      embedded_np = embeds[idx]
      embedded_tf = tf.contrib.layers.embedding_lookup_unique(
          embeds, idx).eval()

    self.assertEqual(embedded_np.shape, embedded_tf.shape)
    np.testing.assert_almost_equal(embedded_np, embedded_tf)


if __name__ == "__main__":
tf.test.main()
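
The NumPy reference value in the test above relies on fancy indexing: indexing a [n_embed, d_embed] matrix with an integer array of shape (2, 3, 4) yields a (2, 3, 4, d_embed) result, which is exactly the per-position lookup the op performs. A standalone sanity check (names are illustrative):

import numpy as np

embeds = np.arange(12).reshape(4, 3)   # 4 ids, 3-dim embeddings
idx = np.array([[0, 2], [2, 3]])       # (2, 2) index array

out = embeds[idx]                      # fancy indexing: shape (2, 2, 3)
assert out.shape == (2, 2, 3)
assert (out[0, 1] == embeds[2]).all()  # each position receives its id's row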
142 changes: 142 additions & 0 deletions tensorflow/contrib/layers/python/layers/encoders.py
@@ -0,0 +1,142 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Encoders to transform sequence of symbols into vector representation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import embedding_ops as contrib_embedding_ops
from tensorflow.contrib.layers.python.ops import sparse_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope

__all__ = ['bow_encoder', 'embed_sequence']


def bow_encoder(ids,
                vocab_size,
                embed_dim,
                sparse_lookup=True,
                initializer=None,
                regularizer=None,
                trainable=True,
                scope=None,
                reuse=None):
  """Maps a sequence of symbols to a vector per example by averaging embeddings.

  Args:
    ids: `[batch_size, doc_length]` `Tensor` or `SparseTensor` of type
      `int32` or `int64` with symbol ids.
    vocab_size: Integer number of symbols in vocabulary.
    embed_dim: Integer number of dimensions for embedding matrix.
    sparse_lookup: `bool`, if `True`, converts `ids` to a `SparseTensor`
      and performs a sparse embedding lookup. This is usually faster,
      but not desirable if padding tokens should have an embedding. Empty rows
      are assigned a special embedding.
    initializer: An initializer for the embeddings, if `None` default for
      current scope is used.
    regularizer: Optional regularizer for the embeddings.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    scope: Optional string specifying the variable scope for the op, required
      if `reuse=True`.
    reuse: If `True`, variables inside the op will be reused.

  Returns:
    Encoding `Tensor` `[batch_size, embed_dim]` produced by averaging
    embeddings.

  Raises:
    ValueError: If `embed_dim` or `vocab_size` are not specified.
  """
  if not vocab_size or not embed_dim:
    raise ValueError('Must specify vocab size and embedding dimension')
  with variable_scope.variable_scope(
      scope, 'bow_encoder', [ids], reuse=reuse):
    embeddings = variables.model_variable(
        'embeddings', shape=[vocab_size, embed_dim],
        initializer=initializer, regularizer=regularizer,
        trainable=trainable)
    if sparse_lookup:
      if isinstance(ids, ops.SparseTensor):
        sparse_ids = ids
      else:
        sparse_ids = sparse_ops.dense_to_sparse_tensor(ids)
      return contrib_embedding_ops.safe_embedding_lookup_sparse(
          [embeddings], sparse_ids, combiner='mean', default_id=0)
    else:
      if isinstance(ids, ops.SparseTensor):
        raise TypeError('ids are expected to be a dense Tensor, got: %s' % ids)
      return math_ops.reduce_mean(
          embedding_ops.embedding_lookup(embeddings, ids),
          reduction_indices=1)


def embed_sequence(ids,
                   vocab_size=None,
                   embed_dim=None,
                   unique=False,
                   initializer=None,
                   regularizer=None,
                   trainable=True,
                   scope=None,
                   reuse=None):
  """Maps a sequence of symbols to a sequence of embeddings.

  Typical use case would be reusing embeddings between an encoder and decoder.

  Args:
    ids: `[batch_size, doc_length]` `Tensor` of type `int32` or `int64`
      with symbol ids.
    vocab_size: Integer number of symbols in vocabulary.
    embed_dim: Integer number of dimensions for embedding matrix.
    unique: If `True`, will first compute the unique set of indices, and then
      look up each embedding once, repeating them in the output as needed.
    initializer: An initializer for the embeddings, if `None` default for
      current scope is used.
    regularizer: Optional regularizer for the embeddings.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    scope: Optional string specifying the variable scope for the op, required
      if `reuse=True`.
    reuse: If `True`, variables inside the op will be reused.

  Returns:
    `Tensor` of `[batch_size, doc_length, embed_dim]` with embedded sequences.

  Raises:
    ValueError: if `embed_dim` or `vocab_size` are not specified when
      `reuse` is `None` or `False`.
  """
  if not (reuse or (vocab_size and embed_dim)):
    raise ValueError('Must specify vocab size and embedding dimension when '
                     'not reusing. Got vocab_size=%s and embed_dim=%s' % (
                         vocab_size, embed_dim))
  with variable_scope.variable_scope(
      scope, 'EmbedSequence', [ids], reuse=reuse):
    shape = [vocab_size, embed_dim]
    if reuse and (vocab_size is None or embed_dim is None):
      shape = None
    embeddings = variables.model_variable(
        'embeddings', shape=shape,
        initializer=initializer, regularizer=regularizer,
        trainable=trainable)
    if unique:
      return contrib_embedding_ops.embedding_lookup_unique(embeddings, ids)
    return embedding_ops.embedding_lookup(embeddings, ids)
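
A minimal end-to-end sketch of the two encoders defined above (not part of the diff; hypothetical sizes, using the 0.11-era initialize_all_variables API seen in the tests):

import tensorflow as tf

docs = tf.constant([[1, 3, 0, 0],
                    [2, 2, 4, 0]])  # zero-padded symbol ids, batch of 2

# One averaged vector per document: shape [2, 10].
bow = tf.contrib.layers.bow_encoder(docs, vocab_size=5, embed_dim=10)

# One embedding per position: shape [2, 4, 10]. unique=True routes the
# lookup through embedding_lookup_unique to deduplicate repeated ids.
seq = tf.contrib.layers.embed_sequence(
    docs, vocab_size=5, embed_dim=10, unique=True)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  print(sess.run(bow).shape, sess.run(seq).shape)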
126 changes: 126 additions & 0 deletions tensorflow/contrib/layers/python/layers/encoders_test.py
@@ -0,0 +1,126 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.contrib.layers.python.layers.encoders."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.contrib.layers.python.layers import encoders


def _get_const_var(name, shape, value):
  return tf.get_variable(name,
                         shape,
                         initializer=tf.constant_initializer(value))


class EncodersTest(tf.test.TestCase):

  def testBowEncoderSparse(self):
    with self.test_session() as sess:
      docs = [[0, 1], [2, 3]]
      enc = encoders.bow_encoder(docs, 4, 3)
      sess.run(tf.initialize_all_variables())
      self.assertAllEqual([2, 3], enc.eval().shape)

  def testBowEncoderSparseTensor(self):
    with self.test_session() as sess:
      docs = [[0, 1], [2, 3]]
      sparse_docs = tf.contrib.layers.sparse_ops.dense_to_sparse_tensor(docs)
      enc = encoders.bow_encoder(sparse_docs, 4, 3)
      sess.run(tf.initialize_all_variables())
      self.assertAllEqual([2, 3], enc.eval().shape)

  def testBowEncoderSparseEmptyRow(self):
    with self.test_session() as sess:
      docs = [[0, 1], [2, 3], [0, 0]]
      enc = encoders.bow_encoder(docs, 4, 5)
      sess.run(tf.initialize_all_variables())
      self.assertAllEqual([3, 5], enc.eval().shape)

  def testBowEncoderDense(self):
    with self.test_session() as sess:
      docs = [[0, 1], [2, 3], [0, 0], [0, 0]]
      enc = encoders.bow_encoder(docs, 4, 3, sparse_lookup=False)
      sess.run(tf.initialize_all_variables())
      self.assertAllEqual([4, 3], enc.eval().shape)

  def testBowEncoderSparseTensorDenseLookup(self):
    with self.test_session():
      docs = [[0, 1]]
      sparse_docs = tf.contrib.layers.sparse_ops.dense_to_sparse_tensor(docs)
      with self.assertRaises(TypeError):
        encoders.bow_encoder(sparse_docs, 4, 3, sparse_lookup=False)

  def testBowEncodersSharingEmbeddings(self):
    with self.test_session() as sess:
      docs = [[0, 1], [2, 3]]
      enc_1 = encoders.bow_encoder(docs, 4, 3, scope='test')
      enc_2 = encoders.bow_encoder(docs, 4, 3, scope='test', reuse=True)
      sess.run(tf.initialize_all_variables())
      avg_1, avg_2 = sess.run([enc_1, enc_2])
      self.assertAllEqual(avg_1, avg_2)

  def testBowEncodersSharingEmbeddingsInheritedScopes(self):
    with self.test_session() as sess:
      docs = [[0, 1], [2, 3]]
      with tf.variable_scope('test'):
        enc_1 = encoders.bow_encoder(docs, 4, 3)
      with tf.variable_scope('test', reuse=True):
        enc_2 = encoders.bow_encoder(docs, 4, 3)
      sess.run(tf.initialize_all_variables())
      avg_1, avg_2 = sess.run([enc_1, enc_2])
      self.assertAllEqual(avg_1, avg_2)

  def testBowEncodersSharingEmbeddingsSharedScope(self):
    with self.test_session() as sess:
      docs = [[0, 1], [2, 3]]
      enc_1 = encoders.bow_encoder(docs, 4, 3, scope='bow')
      tf.get_variable_scope().reuse_variables()
      enc_2 = encoders.bow_encoder(docs, 4, 3, scope='bow')
      sess.run(tf.initialize_all_variables())
      avg_1, avg_2 = sess.run([enc_1, enc_2])
      self.assertAllEqual(avg_1, avg_2)

  def testBowEncoderReuseEmbeddingsVariable(self):
    with self.test_session() as sess:
      docs = [[1, 1], [2, 3]]
      with tf.variable_scope('test'):
        v = _get_const_var('embeddings', (4, 3),
                           [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
        self.assertEqual(v.name, 'test/embeddings:0')
      enc = encoders.bow_encoder(docs, 4, 3, scope='test', reuse=True)
      sess.run(tf.initialize_all_variables())
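      # With the constant embeddings above, ids [1, 1] average row 1 with
      # itself, giving [3, 4, 5]; ids [2, 3] average rows 2 and 3, giving
      # [7.5, 8.5, 9.5].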
      self.assertAllClose([[3., 4., 5.], [7.5, 8.5, 9.5]], enc.eval())

  def testEmbedSequence(self):
    with self.test_session() as sess:
      docs = [[1, 1], [2, 3]]
      with tf.variable_scope('test'):
        v = _get_const_var('embeddings', (4, 3),
                           [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])
        self.assertEqual(v.name, 'test/embeddings:0')
      emb = encoders.embed_sequence(docs, 4, 3, scope='test', reuse=True)
      sess.run(tf.initialize_all_variables())
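      # Each id selects a row of the constant matrix: [[1, 1], [2, 3]]
      # maps to [[row1, row1], [row2, row3]].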
      self.assertAllClose(
          [[[3., 4., 5.], [3., 4., 5.]], [[6., 7., 8.], [9., 10., 11.]]],
          emb.eval())


if __name__ == '__main__':
tf.test.main()
