diff --git a/tensorflow_datasets/core/features/sequence_feature.py b/tensorflow_datasets/core/features/sequence_feature.py
index 4a960b5bbfa..0f6f96e1bc2 100644
--- a/tensorflow_datasets/core/features/sequence_feature.py
+++ b/tensorflow_datasets/core/features/sequence_feature.py
@@ -262,9 +262,7 @@ def np_to_list(elem):
   elif isinstance(elem, tuple):
     return list(elem)
   elif isinstance(elem, np.ndarray):
-    elem = np.split(elem, elem.shape[0])
-    elem = np.squeeze(elem, axis=0)
-    return elem
+    return list(elem)
   else:
     raise ValueError(
         'Input elements of a sequence should be either a numpy array, a '
diff --git a/tensorflow_datasets/core/features/video_feature.py b/tensorflow_datasets/core/features/video_feature.py
index 9b87e7683b4..152e681fc0e 100644
--- a/tensorflow_datasets/core/features/video_feature.py
+++ b/tensorflow_datasets/core/features/video_feature.py
@@ -29,10 +29,11 @@ class Video(sequence_feature.Sequence):
   """`FeatureConnector` for videos, png-encoding frames on disk.

   Video: The image connector accepts as input:
-    * uint8 array representing an video.
+    * uint8 array representing a video.

   Output:
-    video: tf.Tensor of type tf.uint8 and shape [num_frames, height, width, 3]
+    video: tf.Tensor of type tf.uint8 and shape
+      [num_frames, height, width, channels], where channels must be 1 or 3

   Example:
     * In the DatasetInfo object:
@@ -51,7 +52,7 @@ def __init__(self, shape):
     Args:
       shape: tuple of ints, the shape of the video (num_frames, height, width,
-        channels=3).
+        channels), where channels is 1 or 3.

     Raises:
       ValueError: If the shape is invalid
diff --git a/tensorflow_datasets/video/__init__.py b/tensorflow_datasets/video/__init__.py
index 1775a4a4ced..7eb4ab09ba6 100644
--- a/tensorflow_datasets/video/__init__.py
+++ b/tensorflow_datasets/video/__init__.py
@@ -16,5 +16,6 @@
 """Video datasets."""

 from tensorflow_datasets.video.bair_robot_pushing import BairRobotPushingSmall
+from tensorflow_datasets.video.moving_mnist import MovingMnist
 from tensorflow_datasets.video.starcraft import StarcraftVideo
 from tensorflow_datasets.video.starcraft import StarcraftVideoConfig
diff --git a/tensorflow_datasets/video/moving_mnist.py b/tensorflow_datasets/video/moving_mnist.py
new file mode 100644
index 00000000000..4b0cc86a09c
--- /dev/null
+++ b/tensorflow_datasets/video/moving_mnist.py
@@ -0,0 +1,80 @@
+"""MovingMnist dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+import tensorflow_datasets.public_api as tfds
+from tensorflow_datasets.video.moving_sequence import image_as_moving_sequence  # pylint: disable=unused-import
+
+_OUT_RESOLUTION = (64, 64)
+_SEQUENCE_LENGTH = 20
+_URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/"
+_CITATION = """\
+@article{DBLP:journals/corr/SrivastavaMS15,
+  author    = {Nitish Srivastava and
+               Elman Mansimov and
+               Ruslan Salakhutdinov},
+  title     = {Unsupervised Learning of Video Representations using LSTMs},
+  journal   = {CoRR},
+  volume    = {abs/1502.04681},
+  year      = {2015},
+  url       = {http://arxiv.org/abs/1502.04681},
+  archivePrefix = {arXiv},
+  eprint    = {1502.04681},
+  timestamp = {Mon, 13 Aug 2018 16:47:05 +0200},
+  biburl    = {https://dblp.org/rec/bib/journals/corr/SrivastavaMS15},
+  bibsource = {dblp computer science bibliography, https://dblp.org}
+}
+"""
+
+
+class MovingMnist(tfds.core.GeneratorBasedBuilder):
+  """MovingMnist: sequences of MNIST digits bouncing around a 64x64 frame."""
+
+  VERSION = tfds.core.Version("0.1.0")
+
+  def _info(self):
+    return tfds.core.DatasetInfo(
+        builder=self,
+        description=(
+            "Moving variant of MNIST database of handwritten digits. This is "
+            "the data used by the authors for reporting model performance. "
+            "See `tfds.video.moving_mnist.image_as_moving_sequence` "
+            "for generating training/validation data from the MNIST dataset."),
+        features=tfds.features.FeaturesDict(
+            dict(image_sequence=tfds.features.Video(
+                shape=(_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,)))),
+        urls=[_URL],
+        citation=_CITATION,
+    )
+
+  def _split_generators(self, dl_manager):
+    data_path = dl_manager.download(_URL + "mnist_test_seq.npy")
+
+    # The authors only provide test data. See
+    # `tfds.video.moving_mnist.image_as_moving_sequence` for a mapping
+    # function to create training/validation data from the MNIST dataset.
+    return [
+        tfds.core.SplitGenerator(
+            name=tfds.Split.TEST,
+            num_shards=5,
+            gen_kwargs=dict(data_path=data_path)),
+    ]
+
+  def _generate_examples(self, data_path):
+    """Generate MovingMnist sequences.
+
+    Args:
+      data_path (str): Path to the data file.
+
+    Yields:
+      Dicts with a single "image_sequence" key, holding a
+      20 x 64 x 64 x 1 uint8 numpy array.
+    """
+    with tf.io.gfile.GFile(data_path, "rb") as fp:
+      images = np.load(fp)
+    # Data is stored as [time, batch, height, width]; reorder to
+    # [batch, time, height, width] and add a trailing channels axis.
+    images = np.transpose(images, (1, 0, 2, 3))
+    images = np.expand_dims(images, axis=-1)
+    for sequence in images:
+      yield dict(image_sequence=sequence)
diff --git a/tensorflow_datasets/video/moving_sequence.py b/tensorflow_datasets/video/moving_sequence.py
new file mode 100644
index 00000000000..18d7582d203
--- /dev/null
+++ b/tensorflow_datasets/video/moving_sequence.py
@@ -0,0 +1,218 @@
+"""Provides `image_as_moving_sequence`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import tensorflow as tf
+
+
+def _create_moving_sequence(image, pad_lefts, total_padding):
+  """Create a moving image sequence from the given image and left paddings.
+
+  Args:
+    image: [in_h, in_w, n_channels] uint8 array.
+    pad_lefts: [sequence_length, 2] int32 array of left padding values.
+    total_padding: tensor of padding values, (pad_h, pad_w).
+
+  Returns:
+    [sequence_length, out_h, out_w, n_channels] uint8 image sequence, where
+    out_h = in_h + pad_h, out_w = in_w + pad_w.
+  """
+  with tf.name_scope("moving_sequence"):
+    def get_padded_image(args):
+      pad_left, = args
+      pad_right = total_padding - pad_left
+      padding = tf.stack([pad_left, pad_right], axis=-1)
+      z = tf.zeros((1, 2), dtype=pad_left.dtype)
+      padding = tf.concat([padding, z], axis=0)
+      return tf.pad(image, padding)
+
+    padded_images = tf.map_fn(
+        get_padded_image, [pad_lefts], dtype=tf.uint8, infer_shape=False,
+        back_prop=False)
+
+  return padded_images
+
+
+def _get_linear_trajectory(x0, velocity, t):
+  """Construct the linear trajectory x(t) = x0 + velocity * t.
+
+  Args:
+    x0: N-D float tensor.
+    velocity: N-D float tensor.
+    t: [sequence_length]-length float tensor.
+
+  Returns:
+    x: [sequence_length, ndims] float tensor.
+ """ + x0 = tf.convert_to_tensor(x0) + velocity = tf.convert_to_tensor(velocity) + t = tf.convert_to_tensor(t) + if x0.shape.ndims != 1: + raise ValueError("x0 must be a rank 1 tensor") + if velocity.shape.ndims != 1: + raise ValueError("velocity must be a rank 1 tensor") + if t.shape.ndims != 1: + raise ValueError("t must be a rank 1 tensor") + x0 = tf.expand_dims(x0, axis=0) + velocity = tf.expand_dims(velocity, axis=0) + dx = velocity * tf.expand_dims(t, axis=-1) + linear_trajectories = x0 + dx + assert linear_trajectories.shape.ndims == 2, \ + "linear_trajectories should be a rank 2 tensor" + return linear_trajectories + + +def _bounce_to_bbox(points): + """ + Bounce potentially unbounded points to [0, 1]. + + Bouncing occurs by exact reflection, i.e. a pre-bound point at 1.1 is moved + to 0.9, -0.2 -> 0.2. This theoretically can occur multiple times, e.g. + 2.3 -> -0.7 -> 0.3 + + Implementation + points <- points % 2 + return min(2 - points, points) + + Args: + points: float array + + Returns: + tensor with same shape/dtype but values in [0, 1]. + """ + points = points % 2 + return tf.math.minimum(2 - points, points) + + +def _get_random_unit_vector(ndims=2, dtype=tf.float32): + x = tf.random.normal((ndims,), dtype=dtype) + return x / tf.linalg.norm(x, axis=-1, keepdims=True) + +MovingSequence = collections.namedtuple( + "MovingSequence", + ["image_sequence", "trajectory", "start_position", "velocity"]) + + +def image_as_moving_sequence( + image, sequence_length=20, output_size=(64, 64), velocity=0.1, + start_position=None): + """Turn simple static images into sequences of the originals bouncing around. + + Adapted from Srivastava et al. + http://www.cs.toronto.edu/~nitish/unsupervised_video/ + + Example usage: + ```python + import tensorflow as tf + import tensorflow_datasets as tfds + from tensorflow_datasets.video import moving_sequence + tf.compat.v1.enable_eager_execution() + + def animate(sequence): + import numpy as np + import matplotlib.pyplot as plt + import matplotlib.animation as animation + sequence = np.squeeze(sequence, axis=-1) + + fig = plt.figure() + plt.axis("off") + ims = [[plt.imshow(im, cmap="gray", animated=True)] for im in sequence] + # don't remove `anim =` as linter may suggets + # weird behaviour, plot will freeze on last frame + anim = animation.ArtistAnimation( + fig, ims, interval=50, blit=True, repeat_delay=100) + + plt.show() + plt.close() + + + tf.enable_eager_execution() + mnist_ds = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True) + mnist_ds = mnist_ds.repeat().shuffle(1024) + + def map_fn(image, label): + sequence = moving_sequence.image_as_moving_sequence( + image, sequence_length=20) + return sequence.image_sequence + + moving_mnist_ds = mnist_ds.map(map_fn).batch(2).map( + lambda x: dict(image_sequence=tf.reduce_max(x, axis=0))) + + # # for comparison with test data provided by original authors + # moving_mnist_ds = tfds.load("moving_mnist", split=tfds.Split.TEST) + + for seq in moving_mnist_ds: + animate(seq["image_sequence"].numpy()) + ``` + + Args: + image: [in_h, in_w, n_channels] tensor defining the sub-image to be bouncing + around. + sequence_length: int, length of sequence. + output_size: (out_h, out_w) size returned images. + velocity: scalar speed or 2D velocity of image. If scalar, the 2D + velocity is randomly generated with this magnitude. 
+      This is the normalized distance moved each time step by the sub-image,
+      where normalization occurs over the feasible distance the sub-image can
+      move, e.g. if the input image is [10 x 10] and the output image is
+      [60 x 60], a speed of 0.1 means the sub-image moves
+      (60 - 10) * 0.1 = 5 pixels per time step.
+    start_position: 2D float32 normalized initial position of each
+      image in [0, 1]. Randomized uniformly if not given.
+
+  Returns:
+    `MovingSequence` namedtuple containing:
+      `image_sequence`: [sequence_length, out_h, out_w, n_channels] image at
+        each time step; padded values are all zero. Same dtype as input image.
+      `trajectory`: [sequence_length, 2] float32 in [0, 1]. 2D normalized
+        coordinates of the image at every time step.
+      `start_position`: 2D float32 normalized initial position in [0, 1].
+        Same as the input if provided, otherwise the randomly generated value.
+      `velocity`: 2D float32 normalized velocity. Same as the input velocity
+        if provided as a 2D tensor, otherwise the randomly generated velocity.
+  """
+  ndims = 2
+  image = tf.convert_to_tensor(image)
+  if image.shape.ndims != 3:
+    raise ValueError("image must be rank 3, got %s" % str(image))
+  output_size = tf.TensorShape(output_size)
+  if len(output_size) != ndims:
+    raise ValueError("output_size must have exactly %d elements, got %s"
+                     % (ndims, output_size))
+  image_shape = tf.shape(image)
+  if start_position is None:
+    start_position = tf.random.uniform((ndims,), dtype=tf.float32)
+  elif start_position.shape != (ndims,):
+    raise ValueError("start_position must have shape (%d,)" % ndims)
+  velocity = tf.convert_to_tensor(velocity, dtype=tf.float32)
+  if velocity.shape.ndims == 0:
+    velocity = _get_random_unit_vector(ndims, tf.float32) * velocity
+  elif velocity.shape.ndims != 1:
+    raise ValueError("velocity must be rank 0 or rank 1, got %s" % velocity)
+  t = tf.range(sequence_length, dtype=tf.float32)
+  trajectory = _get_linear_trajectory(start_position, velocity, t)
+  trajectory = _bounce_to_bbox(trajectory)
+
+  total_padding = output_size - image_shape[:2]
+
+  cond = tf.compat.v1.assert_greater(total_padding, -1)
+  if not tf.executing_eagerly():
+    with tf.control_dependencies([cond]):
+      total_padding = tf.identity(total_padding)
+
+  sequence_pad_lefts = tf.cast(
+      tf.math.round(trajectory * tf.cast(total_padding, tf.float32)),
+      tf.int32)
+
+  sequence = _create_moving_sequence(image, sequence_pad_lefts, total_padding)
+  sequence.set_shape(
+      [sequence_length] + output_size.as_list() + [image.shape[-1]])
+  return MovingSequence(
+      image_sequence=sequence,
+      trajectory=trajectory,
+      start_position=start_position,
+      velocity=velocity)
diff --git a/tensorflow_datasets/video/moving_sequence_test.py b/tensorflow_datasets/video/moving_sequence_test.py
new file mode 100644
index 00000000000..148b14e96c1
--- /dev/null
+++ b/tensorflow_datasets/video/moving_sequence_test.py
@@ -0,0 +1,47 @@
+"""Tests for moving_sequence."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_datasets.core import test_utils
+import tensorflow_datasets.video.moving_sequence as ms
+
+tf.compat.v1.enable_eager_execution()
+
+
+class MovingSequenceTest(tf.test.TestCase):
+
+  @test_utils.run_in_graph_and_eager_modes()
+  def test_images_as_moving_sequence(self):
+    h, w = (28, 28)
+    sequence_length = 8
+
+    vh = 1 / sequence_length
+    vw = 1 / (2 * sequence_length)
+    image = tf.ones((h, w, 1), dtype=tf.uint8)
+
+    velocity = tf.constant([vh, vw], dtype=tf.float32)
+    out_size = (h + sequence_length, w + sequence_length)
+    start_position = tf.constant([0, 0], dtype=tf.float32)
+
+    sequence = ms.image_as_moving_sequence(
+        image, start_position=start_position, velocity=velocity,
+        output_size=out_size, sequence_length=sequence_length)
+    sequence = tf.cast(sequence.image_sequence, tf.float32)
+
+    # every frame should contain exactly one unclipped copy of the sub-image
+    self.assertAllEqual(
+        self.evaluate(tf.reduce_sum(sequence, axis=(1, 2, 3))),
+        self.evaluate(
+            tf.fill(
+                (sequence_length,),
+                tf.reduce_sum(tf.cast(image, tf.float32)))))
+
+    for i, full_image in enumerate(tf.unstack(sequence, axis=0)):
+      j = i // 2
+      subimage = full_image[i:i+h, j:j+w]
+      n_true = tf.reduce_sum(subimage)
+      # allow for pixel rounding errors in each dimension
+      self.assertGreaterEqual(self.evaluate(n_true), (h - 1) * (w - 1))
+
+
+if __name__ == '__main__':
+  tf.test.main()
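
The mod-2 reflection in `_bounce_to_bbox` is the heart of the bouncing motion, so it is worth sanity-checking in isolation. Below is a minimal NumPy sketch of the same two-line formula; it is a standalone illustration rather than part of the patch, and the `bounce_to_bbox` name here is hypothetical:

```python
import numpy as np

def bounce_to_bbox(points):
  # Wrap into [0, 2). Values in [1, 2) have crossed a wall an odd number of
  # times, so folding with min(2 - p, p) reflects them back into [0, 1].
  points = points % 2
  return np.minimum(2 - points, points)

# 1.1 reflects to 0.9, -0.2 to 0.2, and 2.3 bounces twice (2.3 -> -0.7 -> 0.3);
# 0.5 is already in bounds and passes through unchanged.
print(bounce_to_bbox(np.array([1.1, -0.2, 2.3, 0.5])))
# approximately [0.9 0.2 0.3 0.5]
```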