diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py
index 3497c2678f4f34..453f76220c7267 100644
--- a/tensorflow/python/data/experimental/ops/grouping.py
+++ b/tensorflow/python/data/experimental/ops/grouping.py
@@ -167,6 +167,64 @@ def bucket_by_sequence_length(element_length_func,
   [[ 0  0]
    [21 22]]
 
+  Elements can also be padded up to the bucket boundary, and you can
+  specify the value to use for the padding. The example below pads with
+  `-1` and shows the input data being bucketized into the two buckets
+  `[0, 3]` and `[4, 6]`.
+
+  >>> elements = [
+  ...   [0], [1, 2, 3, 4], [5, 6, 7],
+  ...   [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]
+
+  >>> dataset = tf.data.Dataset.from_generator(
+  ...     lambda: elements, tf.int32, output_shapes=[None])
+
+  >>> dataset = dataset.apply(
+  ...     tf.data.experimental.bucket_by_sequence_length(
+  ...         element_length_func=lambda elem: tf.shape(elem)[0],
+  ...         bucket_boundaries=[4, 7],
+  ...         bucket_batch_sizes=[2, 2, 2],
+  ...         pad_to_bucket_boundary=True,
+  ...         padding_values=-1))
+
+  >>> for elem in dataset.as_numpy_iterator():
+  ...   print(elem)
+  [[ 0 -1 -1]
+   [ 5  6  7]]
+  [[ 1  2  3  4 -1 -1]
+   [ 7  8  9 10 11 -1]]
+  [[21 22 -1]]
+  [[13 14 15 16 19 20]]
+
+  As this example shows, it is not always possible to maintain the
+  bucket batch size when the `pad_to_bucket_boundary` option is used.
+  Batches that fall short of the bucket batch size can be dropped with
+  the `drop_remainder` option. Using the same input data as in the
+  example above gives the following result.
+
+  >>> elements = [
+  ...   [0], [1, 2, 3, 4], [5, 6, 7],
+  ...   [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]
+
+  >>> dataset = tf.data.Dataset.from_generator(
+  ...     lambda: elements, tf.int32, output_shapes=[None])
+
+  >>> dataset = dataset.apply(
+  ...     tf.data.experimental.bucket_by_sequence_length(
+  ...         element_length_func=lambda elem: tf.shape(elem)[0],
+  ...         bucket_boundaries=[4, 7],
+  ...         bucket_batch_sizes=[2, 2, 2],
+  ...         pad_to_bucket_boundary=True,
+  ...         padding_values=-1,
+  ...         drop_remainder=True))
+
+  >>> for elem in dataset.as_numpy_iterator():
+  ...   print(elem)
+  [[ 0 -1 -1]
+   [ 5  6  7]]
+  [[ 1  2  3  4 -1 -1]
+   [ 7  8  9 10 11 -1]]
+
   Args:
     element_length_func: function from element in `Dataset` to `tf.int32`,
       determines the length of the element, which will determine the bucket it
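
For quick local verification of the new examples outside the doctest runner, here is a minimal standalone sketch, assuming TensorFlow 2.x eager execution. It mirrors the second docstring example above; the shape assertion is an illustrative addition, not part of the docstring.

    # Standalone sketch mirroring the new docstring examples; the width
    # assertion below is an illustrative assumption, not part of the diff.
    import tensorflow as tf

    elements = [
        [0], [1, 2, 3, 4], [5, 6, 7],
        [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]

    dataset = tf.data.Dataset.from_generator(
        lambda: elements, tf.int32, output_shapes=[None])

    dataset = dataset.apply(
        tf.data.experimental.bucket_by_sequence_length(
            element_length_func=lambda elem: tf.shape(elem)[0],
            bucket_boundaries=[4, 7],
            bucket_batch_sizes=[2, 2, 2],
            pad_to_bucket_boundary=True,
            padding_values=-1,
            drop_remainder=True))

    for batch in dataset.as_numpy_iterator():
        # With boundaries [4, 7], batches are padded to width 3 or 6
        # (boundary minus one); no element here reaches the last bucket.
        assert batch.shape[1] in (3, 6)
        print(batch)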