
tf.nn: broader support for RaggedTensor #45476

@zaccharieramzi

Description


System information

  • TensorFlow version (you are using): 2.3
  • Are you willing to contribute it (Yes/No): Yes

Describe the feature and the current behavior/state.

I would like to be able to use tf.nn operations, and in turn tf.keras.layers such as Conv2D, on RaggedTensor inputs.

Will this change the current API? How?

I don't think it should change the API in any way; Keras layers and tf.nn operations would simply accept ragged tensors as inputs.
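For comparison, many element-wise and reduction operations in tf.math already dispatch on RaggedTensor and return results that keep the ragged structure; the request is essentially for tf.nn operations (and the Keras layers built on them) to behave the same way. A minimal illustration, assuming TF 2.x running eagerly:

```python
import tensorflow as tf

# A batch of two rows with different lengths.
rt = tf.ragged.constant([[1.0, -2.0, 3.0], [-4.0]])

# Element-wise ops dispatch on RaggedTensor and keep the row lengths of rt.
abs_rt = tf.math.abs(rt)

# Reductions over the ragged dimension also work, giving one value per row.
row_sums = tf.reduce_sum(rt, axis=1)

print(abs_rt)
print(row_sums)
```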

Who will benefit from this feature?

Anyone who wants to use Keras layers on ragged tensors.
I explain my specific use case in more detail below.

Any other info.

In my use case, I want to perform distributed training on images of widely different sizes, which it doesn't make sense to pad. I would like each element of the ragged tensor to be sent to a different GPU.
However, if I pass a ragged tensor as is, my models reject it with an error like:

TypeError: Failed to convert object of type <class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'> to Tensor. Contents: tf.RaggedTensor(values=tf.RaggedTensor(values=tf.RaggedTensor(values=tf.RaggedTensor(values=Tensor("RaggedFromVariant/RaggedTensorFromVariant:4", shape=(None, 1), dtype=complex64), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:3", shape=(None,), dtype=int64)), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:2", shape=(None,), dtype=int64)), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:1", shape=(None,), dtype=int64)), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:0", shape=(None,), dtype=int64)). Consider casting elements to a supported type.

Even if I use the setting advised here, I still get this error.
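The sketch below reproduces the failure with a small batch of two single-channel images of different sizes, then shows one possible stopgap: densifying each image separately and applying the same Conv2D layer to it, so no padding across images is needed. It assumes TF 2.x running eagerly; the helpers images_to_ragged and conv_per_image are hypothetical names for illustration, and the loop does not address distributing the images across GPUs.

```python
import tensorflow as tf


def images_to_ragged(images):
    """Hypothetical helper: pack [H_i, W_i, C] images into a ragged tensor
    of shape [batch, (height), (width), C]."""
    heights = [int(img.shape[0]) for img in images]
    # One row length per image row: every row of image i has width W_i.
    widths = [int(img.shape[1]) for img in images for _ in range(int(img.shape[0]))]
    channels = int(images[0].shape[-1])
    # Flatten each image row-major into [H_i * W_i, C] pixels and concatenate.
    flat = tf.concat([tf.reshape(img, [-1, channels]) for img in images], axis=0)
    return tf.RaggedTensor.from_nested_row_lengths(flat, [heights, widths])


def conv_per_image(layer, ragged_images):
    """Hypothetical workaround: apply `layer` to each image on its own."""
    outputs = []
    for i in range(int(ragged_images.nrows())):
        image = ragged_images[i].to_tensor()          # dense [H_i, W_i, C]
        outputs.append(layer(image[tf.newaxis])[0])   # add, then drop, batch dim
    return outputs


images = [tf.random.normal([4, 5, 1]), tf.random.normal([6, 3, 1])]
ragged_images = images_to_ragged(images)

# Passing the ragged tensor straight to tf.nn.conv2d fails because the op
# tries to convert its input to a dense Tensor (TypeError in TF 2.3; newer
# releases may raise ValueError instead).
kernel = tf.random.normal([3, 3, 1, 8])
try:
    tf.nn.conv2d(ragged_images, kernel, strides=1, padding="SAME")
    conv_failed = False
except (TypeError, ValueError) as err:
    conv_failed = True
    print(type(err).__name__)

# Workaround: one Conv2D layer shared across images of different sizes.
conv = tf.keras.layers.Conv2D(filters=8, kernel_size=3, padding="same")
outputs = conv_per_image(conv, ragged_images)
print([tuple(o.shape) for o in outputs])
```

Because Conv2D only fixes the channel dimension of its input, the same weights can be applied to images of any height and width; the cost is a Python-level loop instead of a single batched call, which is exactly what native ragged support in tf.nn would avoid.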

Metadata


Labels

  • comp:apis (High-level API related issues)
  • stale (This label marks the issue/pr stale - to be closed automatically if no activity)
  • stat:awaiting response (Status - Awaiting response from author)
  • stat:awaiting tensorflower (Status - Awaiting response from tensorflower)
  • type:feature (Feature requests)
