Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add normalizer_fn support for sequence_numeric_column #19649

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ def sequence_numeric_column(
key,
shape=(1,),
default_value=0.,
dtype=dtypes.float32):
dtype=dtypes.float32,
normalizer_fn=None):
"""Returns a feature column that represents sequences of numeric data.

Example:
Expand All @@ -370,6 +371,12 @@ def sequence_numeric_column(
default_value: A single value compatible with `dtype` that is used for
padding the sparse data into a dense `Tensor`.
dtype: The type of values.
normalizer_fn: If not `None`, a function that can be used to normalize the
value of the tensor after `default_value` is applied for parsing.
The normalizer function takes the input `Tensor` as its argument, and
returns the output `Tensor` (e.g. `lambda x: (x - 3.0) / 4.2`). Please note
that even though the most common use case of this function is
normalization, it can be used for any kind of TensorFlow transformation.

Returns:
A `_SequenceNumericColumn`.
Expand All @@ -383,12 +390,16 @@ def sequence_numeric_column(
if not (dtype.is_integer or dtype.is_floating):
raise ValueError('dtype must be convertible to float. '
'dtype: {}, key: {}'.format(dtype, key))
if normalizer_fn is not None and not callable(normalizer_fn):
raise TypeError(
'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

return _SequenceNumericColumn(
key,
shape=shape,
default_value=default_value,
dtype=dtype)
dtype=dtype,
normalizer_fn=normalizer_fn)


def _assert_all_equal_and_return(tensors, name=None):
Expand All @@ -407,7 +418,7 @@ class _SequenceNumericColumn(
fc._SequenceDenseColumn,
collections.namedtuple(
'_SequenceNumericColumn',
['key', 'shape', 'default_value', 'dtype'])):
['key', 'shape', 'default_value', 'dtype', 'normalizer_fn'])):
"""Represents sequences of numeric data."""

@property
Expand All @@ -419,7 +430,10 @@ def _parse_example_spec(self):
return {self.key: parsing_ops.VarLenFeature(self.dtype)}

def _transform_feature(self, inputs):
return inputs.get(self.key)
input_tensor = inputs.get(self.key)
if self.normalizer_fn is not None:
input_tensor = self.normalizer_fn(input_tensor)
return input_tensor

@property
def _variable_shape(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session

Expand Down Expand Up @@ -670,6 +671,7 @@ def test_defaults(self):
self.assertEqual((1,), a.shape)
self.assertEqual(0., a.default_value)
self.assertEqual(dtypes.float32, a.dtype)
self.assertIsNone(a.normalizer_fn)

def test_shape_saved_as_tuple(self):
a = sfc.sequence_numeric_column('aaa', shape=[1, 2])
Expand All @@ -688,6 +690,10 @@ def test_dtype_is_convertible_to_float(self):
ValueError, 'dtype must be convertible to float'):
sfc.sequence_numeric_column('aaa', dtype=dtypes.string)

def test_normalizer_fn_must_be_callable(self):
  """A non-callable `normalizer_fn` must be rejected with a `TypeError`."""
  not_a_callable = 'NotACallable'
  with self.assertRaisesRegexp(TypeError, 'must be a callable'):
    sfc.sequence_numeric_column('aaa', normalizer_fn=not_a_callable)

def test_get_sequence_dense_tensor(self):
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, values [[0.], [1.]]
Expand All @@ -708,6 +714,41 @@ def test_get_sequence_dense_tensor(self):
self.assertAllEqual(
expected_dense_tensor, dense_tensor.eval(session=sess))

def test_get_sequence_dense_tensor_with_normalizer_fn(self):
  """Runs `_get_sequence_dense_tensor` on a column with a `normalizer_fn`."""

  def _increment_two(input_sparse_tensor):
    # Sparse addend holding 2.0 at positions (0, 0) and (1, 1).
    addend = sparse_tensor.SparseTensor(
        ((0, 0), (1, 1)), (2.0, 2.0), (2, 2))
    return sparse_ops.sparse_add(input_sparse_tensor, addend)

  column = sfc.sequence_numeric_column('aaa', normalizer_fn=_increment_two)

  # example 0, values [[0.], [1.]]
  # example 1, values [[10.]]
  features = {
      'aaa': sparse_tensor.SparseTensorValue(
          indices=((0, 0), (0, 1), (1, 0)),
          values=(0., 1., 10.),
          dense_shape=(2, 2)),
  }

  # Dense view before _increment_two:
  #   [[0.], [1.]],
  #   [[10.], [0.]],
  # Dense view after _increment_two:
  #   [[2.], [1.]],
  #   [[10.], [2.]],
  expected = [
      [[2.], [1.]],
      [[10.], [2.]],
  ]

  actual, _ = column._get_sequence_dense_tensor(_LazyBuilder(features))

  with monitored_session.MonitoredSession() as sess:
    self.assertAllEqual(expected, actual.eval(session=sess))

def test_get_sequence_dense_tensor_with_shape(self):
"""Tests get_sequence_dense_tensor with shape !=(1,)."""
sparse_input = sparse_tensor.SparseTensorValue(
Expand Down