Keras support for RaggedTensors #27170
Closed
Labels
comp:keras (Keras related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), type:feature (Feature requests)
Description
System Information
- TensorFlow version: 1.13.1 (issue also present in 2.0 alpha)
- Are you willing to contribute it: Yes
Current State/Behaviour
tf.RaggedTensors do a fantastic job of hiding their internal representation from the user, allowing them to be used as regular tensors in most contexts. This does not extend to Keras models, however. As an example:
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda
from tensorflow.keras.models import Model

# Fall back to the contrib location when tf.nest is unavailable.
if not hasattr(tf, 'nest'):
    tf.nest = tf.contrib.framework.nest

values = Input(shape=(10,), dtype=tf.float32)

## non-ragged version
indices = Input(shape=(5,), dtype=tf.int64)
gathered = Lambda(lambda args: tf.gather(*args))([values, indices])
Model(inputs=(values, indices), outputs=gathered)
# works fine

## ragged version
index_values = Input(shape=(), dtype=tf.int64)
index_row_splits = Input(shape=(), dtype=tf.int64)
indices = tf.RaggedTensor.from_row_splits(index_values, index_row_splits)
gathered = Lambda(lambda args: tf.gather(*args))([values, indices])
Model(inputs=(values, indices), outputs=gathered)
# raises
# 1.13.1: ValueError: Input tensors to a Model must come from `tf.keras.Input`.
#         Received: tf.RaggedTensor(values=Tensor("input_3:0", shape=(?,), dtype=int64),
#         row_splits=Tensor("input_4:0", shape=(?,), dtype=int64))
#         (missing previous layer metadata).
# 2.0: AttributeError: 'RaggedTensor' object has no attribute 'op'

Workaround
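For context, the `values` and `row_splits` named in the error message above are the two component tensors of a RaggedTensor: a flat sequence of elements plus the offsets at which each row starts and ends. The following is a minimal plain-Python sketch of that layout, not TensorFlow code; the helper names `row_lengths_to_row_splits` and `to_rows` are hypothetical and exist only for illustration.

```python
from itertools import accumulate


def row_lengths_to_row_splits(row_lengths):
    """Turn per-row lengths into cumulative row boundaries."""
    return [0] + list(accumulate(row_lengths))


def to_rows(values, row_splits):
    """Slice the flat values into rows using consecutive split offsets."""
    return [values[start:end]
            for start, end in zip(row_splits[:-1], row_splits[1:])]


values = [3, 1, 4, 1, 5]
row_splits = row_lengths_to_row_splits([3, 0, 2])
rows = to_rows(values, row_splits)

print(row_splits)  # [0, 3, 3, 5]
print(rows)        # [[3, 1, 4], [], [1, 5]]
```

Row `i` spans `values[row_splits[i]:row_splits[i + 1]]`, so an empty row shows up as two equal consecutive offsets.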
Operations acting on the component tensors work fine. As a workaround, one can do the following:
def ragged_tensor_from_row_lengths(values, row_lengths):
    def components(args):
        values, row_lengths = args
        ragged = tf.RaggedTensor.from_row_lengths(values, row_lengths)
        return ragged.values, ragged.row_splits
    components = tf.keras.layers.Lambda(components)([values, row_lengths])
    return tf.RaggedTensor.from_row_splits(*components)


def as_ragged_components(tensor):
    # Replace every RaggedTensor in a nested structure with a dict of its components.
    if isinstance(tensor, tf.RaggedTensor):
        return dict(values=tensor.values, row_splits=tensor.row_splits)
    elif isinstance(tensor, (list, tuple)):
        return tuple(as_ragged_components(t) for t in tensor)
    elif isinstance(tensor, dict):
        return {k: as_ragged_components(v) for k, v in tensor.items()}
    else:
        # leave unchanged
        assert isinstance(tensor, tf.Tensor)
        return tensor


def as_ragged(components):
    # Inverse of as_ragged_components: rebuild RaggedTensors from component dicts.
    if isinstance(components, (list, tuple)):
        return tuple(as_ragged(c) for c in components)
    elif isinstance(components, dict):
        if all(k in components for k in ('values', 'row_splits')):
            return tf.RaggedTensor.from_row_splits(**components)
        else:
            return {k: as_ragged(v) for k, v in components.items()}
    else:
        assert isinstance(components, tf.Tensor)
        return components


def ragged_lambda(fn, args):
    assert isinstance(args, (list, tuple))
    if not any(isinstance(a, tf.RaggedTensor) for a in args):
        out_components = tf.keras.layers.Lambda(fn)(args)
    else:
        components = as_ragged_components(args)
        flat_args = tf.nest.flatten(components)

        def actual_fn(flat_args):
            args = tf.nest.pack_sequence_as(components, flat_args)
            # Rebuild from the tensors passed into the layer, not the outer `components`.
            args = as_ragged(args)
            out = fn(args)
            return as_ragged_components(out)

        out_components = tf.keras.layers.Lambda(actual_fn)(flat_args)
    return as_ragged(out_components)


gathered = ragged_lambda(
    lambda args: tf.gather(*args), [values, indices])
gathered_components = tf.nest.flatten(as_ragged_components(gathered))
Model(
    inputs=(values, index_values, index_row_splits),
    outputs=gathered_components)

This is, however, very convoluted, and largely defeats the purpose of having compound tensors like RaggedTensors, which should be usable transparently in place of regular tensors.
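The workaround relies on the fact that a gather with ragged indices only needs to touch the flat index values; the row partition carries over to the result unchanged. A minimal plain-Python sketch of that idea follows. It is an illustration under simplifying assumptions (1-D `params`, lists instead of tensors), and the helpers `gather`, `ragged_gather` and `to_rows` are hypothetical names, not TensorFlow APIs.

```python
def to_rows(values, row_splits):
    """Slice the flat values into rows using consecutive split offsets."""
    return [values[start:end]
            for start, end in zip(row_splits[:-1], row_splits[1:])]


def gather(params, indices):
    """Dense gather: pick params[i] for every index i."""
    return [params[i] for i in indices]


def ragged_gather(params, index_values, index_row_splits):
    """Gather on the flat index values and reuse the row partition as-is."""
    return gather(params, index_values), index_row_splits


params = [10, 20, 30, 40]
index_values = [3, 0, 2, 1, 1]
index_row_splits = [0, 2, 2, 5]  # rows: [3, 0], [], [2, 1, 1]

out_values, out_row_splits = ragged_gather(
    params, index_values, index_row_splits)

# Reference result: gather each row of indices separately.
expected = [gather(params, row)
            for row in to_rows(index_values, index_row_splits)]

print(to_rows(out_values, out_row_splits))  # [[40, 10], [], [30, 20, 20]]
```

Because only ordinary tensors cross the layer boundary in this formulation, it is the component-level version of the operation that Keras accepts, at the cost of the wrapping and unwrapping shown in the workaround.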
Change to API
None, though this would allow RaggedTensors to be used in Keras Models.
Who will benefit
People who appreciate clean code.