Keras support for RaggedTensors #27170

@jackd

System Information

  • TensorFlow version: 1.13.1 (issue also present in 2.0 alpha)
  • Are you willing to contribute it: Yes

Current State/Behaviour

tf.RaggedTensor does a fantastic job of hiding its internal representation from the user, allowing ragged tensors to be used like regular tensors in most contexts. This does not extend to Keras models, however. For example:

import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda
from tensorflow.keras.models import Model
if not hasattr(tf, 'nest'):
    tf.nest = tf.contrib.framework.nest


values = Input(shape=(10,), dtype=tf.float32)

## non-ragged version
indices = Input(shape=(5,), dtype=tf.int64)
gathered = Lambda(lambda args: tf.gather(*args))([values, indices])
Model(inputs=(values, indices), outputs=gathered)
# works fine

## ragged version
index_values = Input(shape=(), dtype=tf.int64)
index_row_splits = Input(shape=(), dtype=tf.int64)

indices = tf.RaggedTensor.from_row_splits(index_values, index_row_splits)
gathered = Lambda(lambda args: tf.gather(*args))([values, indices])
Model(inputs=(values, indices), outputs=gathered)
# raises
# 1.13.1: ValueError: Input tensors to a Model must come from `tf.keras.Input`.
# Received: tf.RaggedTensor(values=Tensor("input_3:0", shape=(?,), dtype=int64),
# row_splits=Tensor("input_4:0", shape=(?,), dtype=int64))
# (missing previous layer metadata).
# 2.0: AttributeError: 'RaggedTensor' object has no attribute 'op'
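For reference, the ragged version above encodes the index tensor as a flat values tensor plus row_splits: row i is values[row_splits[i]:row_splits[i + 1]], and row_splits is the running total of the row lengths with a leading 0 (which is what tf.RaggedTensor.from_row_lengths computes). The plain-Python sketch below illustrates only those semantics with lists, not TensorFlow's implementation or validation; the helper names are hypothetical.

```python
from typing import List, Sequence


def row_lengths_to_row_splits(row_lengths: Sequence[int]) -> List[int]:
    """row_splits is the running total of row_lengths, starting at 0."""
    splits = [0]
    for length in row_lengths:
        if length < 0:
            raise ValueError("row lengths must be non-negative")
        splits.append(splits[-1] + length)
    return splits


def rows_from_row_splits(values: Sequence, row_splits: Sequence[int]) -> List[list]:
    """Row i of the ragged tensor is values[row_splits[i]:row_splits[i + 1]]."""
    if not row_splits or row_splits[0] != 0 or row_splits[-1] != len(values):
        raise ValueError("row_splits must start at 0 and end at len(values)")
    if any(a > b for a, b in zip(row_splits, row_splits[1:])):
        raise ValueError("row_splits must be non-decreasing")
    return [list(values[start:stop]) for start, stop in zip(row_splits, row_splits[1:])]


index_values = [3, 1, 4, 1, 5]
index_row_splits = row_lengths_to_row_splits([2, 0, 3])
print(index_row_splits)  # [0, 2, 2, 5]
print(rows_from_row_splits(index_values, index_row_splits))  # [[3, 1], [], [4, 1, 5]]
```

Empty rows are simply repeated split points, which is why both components can be ordinary dense 1-D tensors.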

Workaround

Operations acting on the component tensors work fine, so as a workaround one can wrap them in Lambda layers and pass the components around explicitly:

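def ragged_tensor_from_row_lengths(values, row_lengths):
    """Build a RaggedTensor whose component ops live inside a Lambda layer."""
    def components(args):
        values, row_lengths = args
        ragged = tf.RaggedTensor.from_row_lengths(values, row_lengths)
        # Return only dense tensors, which Keras can track.
        return ragged.values, ragged.row_splits

    components = tf.keras.layers.Lambda(components)([values, row_lengths])
    # Reassemble outside the layer from the Keras-tracked components.
    return tf.RaggedTensor.from_row_splits(*components)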


def as_ragged_components(tensor):
    """Recursively replace RaggedTensors with dicts of their dense components."""
    if isinstance(tensor, tf.RaggedTensor):
        return dict(values=tensor.values, row_splits=tensor.row_splits)
    elif isinstance(tensor, (list, tuple)):
        return tuple(as_ragged_components(t) for t in tensor)
    elif isinstance(tensor, dict):
        return {k: as_ragged_components(v) for k, v in tensor.items()}
    else:
        # dense tensors are left unchanged
        assert isinstance(tensor, tf.Tensor)
        return tensor


def as_ragged(components):
    """Inverse of as_ragged_components: rebuild RaggedTensors from component dicts."""
    if isinstance(components, (list, tuple)):
        return tuple(as_ragged(c) for c in components)
    elif isinstance(components, dict):
        if all(k in components for k in ('values', 'row_splits')):
            return tf.RaggedTensor.from_row_splits(**components)
        else:
            return {k: as_ragged(v) for k, v in components.items()}
    else:
        assert isinstance(components, tf.Tensor)
        return components


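def ragged_lambda(fn, args):
    """Apply fn to args (which may contain RaggedTensors) inside a Lambda layer."""
    assert isinstance(args, (list, tuple))
    if not any(isinstance(a, tf.RaggedTensor) for a in args):
        # No ragged arguments: a plain Lambda layer is enough.
        out_components = tf.keras.layers.Lambda(fn)(args)
    else:
        # Split ragged arguments into dense components and flatten them,
        # since the Lambda layer should only see dense tensors.
        components = as_ragged_components(args)
        flat_args = tf.nest.flatten(components)

        def actual_fn(flat_args):
            # Rebuild the structure from the Lambda's own inputs (not the
            # outer `components`), then reassemble the RaggedTensors.
            args = tf.nest.pack_sequence_as(components, flat_args)
            args = as_ragged(args)
            out = fn(args)
            return as_ragged_components(out)

        out_components = tf.keras.layers.Lambda(actual_fn)(flat_args)
    return as_ragged(out_components)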


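# gathered is again a RaggedTensor, built through the workaround.
gathered = ragged_lambda(
    lambda args: tf.gather(*args), [values, indices])

# Model outputs must be dense, so expose the ragged output as its
# flattened components.
gathered_components = tf.nest.flatten(as_ragged_components(gathered))
Model(
    inputs=(values, index_values, index_row_splits),
    outputs=gathered_components)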
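The workaround relies on gather with ragged indices being expressible purely in terms of components: the result should keep the row partition (row_splits) of the indices, and its flat values are an ordinary dense gather over the flat index values. The plain-Python sketch below shows that component-wise view with lists standing in for tensors and gathering along the first axis only; gather and ragged_gather are hypothetical helpers, not TensorFlow API.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def gather(params: Sequence[T], indices: Sequence[int]) -> List[T]:
    """Dense gather along the first axis: out[j] = params[indices[j]]."""
    return [params[i] for i in indices]


def ragged_gather(
    params: Sequence[T], index_values: Sequence[int], index_row_splits: Sequence[int]
) -> Tuple[List[T], List[int]]:
    """Gather with ragged indices, computed on the components only.

    The values are a dense gather over the flat index values; the row
    partition of the indices is passed through unchanged.
    """
    return gather(params, index_values), list(index_row_splits)


params = ["a", "b", "c", "d", "e", "f"]
values, row_splits = ragged_gather(params, [5, 0, 2, 2, 1], [0, 2, 2, 5])
rows = [values[s:e] for s, e in zip(row_splits, row_splits[1:])]
print(rows)  # [['f', 'a'], [], ['c', 'c', 'b']]
```

Because only the dense values tensor is actually transformed, the work can happen inside an ordinary Lambda layer while row_splits is merely threaded through.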

This works, but it is very convoluted and largely defeats the purpose of composite tensors like RaggedTensors, which should be usable transparently in place of regular tensors.
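Most of that boilerplate is a structure round trip: each ragged argument is split into values/row_splits components, flattened so the Lambda layer sees only dense tensors, then repacked and reassembled inside the layer. The plain-Python sketch below mimics that round trip with strings standing in for symbolic tensors. Ragged, to_components, flatten, pack_sequence_as and from_components are hypothetical stand-ins for tf.RaggedTensor, as_ragged_components, tf.nest.flatten, tf.nest.pack_sequence_as and as_ragged; flatten visits dict keys in sorted order, which is the convention tf.nest uses.

```python
from typing import Any, Iterator, List, NamedTuple


class Ragged(NamedTuple):
    """Hypothetical stand-in for tf.RaggedTensor: values plus row_splits."""
    values: Any
    row_splits: Any


def to_components(structure: Any) -> Any:
    """Replace each Ragged with a dict of its components (cf. as_ragged_components)."""
    if isinstance(structure, Ragged):  # check before tuple: Ragged is a tuple subclass
        return {"values": structure.values, "row_splits": structure.row_splits}
    if isinstance(structure, (list, tuple)):
        return tuple(to_components(s) for s in structure)
    if isinstance(structure, dict):
        return {k: to_components(v) for k, v in structure.items()}
    return structure


def flatten(structure: Any) -> List[Any]:
    """Flatten to a list of leaves, visiting dict keys in sorted order."""
    if isinstance(structure, dict):
        return [leaf for k in sorted(structure) for leaf in flatten(structure[k])]
    if isinstance(structure, (list, tuple)):
        return [leaf for s in structure for leaf in flatten(s)]
    return [structure]


def pack_sequence_as(structure: Any, flat: List[Any]) -> Any:
    """Inverse of flatten: refill `structure` with the leaves of `flat`, in order."""
    leaves: Iterator[Any] = iter(flat)

    def _pack(s: Any) -> Any:
        if isinstance(s, dict):
            return {k: _pack(s[k]) for k in sorted(s)}
        if isinstance(s, (list, tuple)):
            return tuple(_pack(x) for x in s)
        try:
            return next(leaves)
        except StopIteration:
            raise ValueError("not enough leaves for structure") from None

    packed = _pack(structure)
    sentinel = object()
    if next(leaves, sentinel) is not sentinel:
        raise ValueError("too many leaves for structure")
    return packed


def from_components(structure: Any) -> Any:
    """Rebuild Ragged values from component dicts (cf. as_ragged)."""
    if isinstance(structure, dict):
        if set(structure) == {"values", "row_splits"}:
            return Ragged(**structure)
        return {k: from_components(v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return tuple(from_components(s) for s in structure)
    return structure


args = ["params", Ragged(values="idx_values", row_splits="idx_row_splits")]
components = to_components(args)
flat = flatten(components)
print(flat)  # ['params', 'idx_row_splits', 'idx_values']
rebuilt = from_components(pack_sequence_as(components, flat))
print(rebuilt == tuple(args))  # True
```

Every ragged input, and every ragged output, has to go through this round trip by hand, which is the bookkeeping that native Keras support would hide.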

Change to API

None, though this would allow RaggedTensors to be used in Keras models.

Who will benefit

People who appreciate clean code.
