From a55bf8053b664bbd68124f2ad03f1afad13e4ac6 Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Mon, 28 Jan 2019 18:57:38 +1000 Subject: [PATCH 01/13] added moving_mnist test data and mapping fns for other splits --- tensorflow_datasets/video/__init__.py | 1 + tensorflow_datasets/video/moving_mnist.py | 95 +++++ tensorflow_datasets/video/moving_sequence.py | 419 +++++++++++++++++++ 3 files changed, 515 insertions(+) create mode 100644 tensorflow_datasets/video/moving_mnist.py create mode 100644 tensorflow_datasets/video/moving_sequence.py diff --git a/tensorflow_datasets/video/__init__.py b/tensorflow_datasets/video/__init__.py index 1775a4a4ced..ad9e1ba196a 100644 --- a/tensorflow_datasets/video/__init__.py +++ b/tensorflow_datasets/video/__init__.py @@ -18,3 +18,4 @@ from tensorflow_datasets.video.bair_robot_pushing import BairRobotPushingSmall from tensorflow_datasets.video.starcraft import StarcraftVideo from tensorflow_datasets.video.starcraft import StarcraftVideoConfig +from tensorflow_datasets.video.moving_mnist import MovingMnist diff --git a/tensorflow_datasets/video/moving_mnist.py b/tensorflow_datasets/video/moving_mnist.py new file mode 100644 index 00000000000..84a6c402b87 --- /dev/null +++ b/tensorflow_datasets/video/moving_mnist.py @@ -0,0 +1,95 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import functools +import numpy as np +import tensorflow as tf +import tensorflow_datasets.public_api as tfds +from tensorflow_datasets.image.mnist import _MNIST_IMAGE_SIZE + +_OUT_RESOLUTION = (64, 64) +_TOTAL_PADDING = tuple(o - _MNIST_IMAGE_SIZE for o in _OUT_RESOLUTION) # 36, 36 +_SEQUENCE_LENGTH = 20 +_IMAGES_PER_SEQUENCE = 2 + +_citation = """ +@article{DBLP:journals/corr/SrivastavaMS15, + author = {Nitish Srivastava and + Elman Mansimov and + Ruslan Salakhutdinov}, + title = {Unsupervised Learning of Video Representations using LSTMs}, + journal = {CoRR}, + volume = {abs/1502.04681}, + year = {2015}, + url = {http://arxiv.org/abs/1502.04681}, + archivePrefix = {arXiv}, + eprint = {1502.04681}, + timestamp = {Mon, 13 Aug 2018 16:47:05 +0200}, + biburl = {https://dblp.org/rec/bib/journals/corr/SrivastavaMS15}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +""" + + +class MovingMnist(tfds.core.GeneratorBasedBuilder): + + VERSION = tfds.core.Version("0.1.0") + + def _info(self): + shape = (_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,) + + # as Image - doesn't work with 1 as final dim? + # sequence = tfds.features.Image(shape=shape) + + # as video - doesn't work with 1 as final dim? + # sequence = tfds.features.Video(shape=shape) + + # as base tensor - space inefficient?? + sequence = tfds.features.Tensor(shape=shape, dtype=tf.uint8) + + return tfds.core.DatasetInfo( + builder=self, + description=( + "Moving variant of MNIST database of handwritten digits. This is the " + "data used by the authors for reporting model performance. See " + "`tfds.video.moving_sequence` for functions to generate training/" + "validation data."), + features=tfds.features.FeaturesDict( + dict(image_sequence=sequence)), + # supervised_keys=("inputs",), + urls=["http://www.cs.toronto.edu/~nitish/unsupervised_video/"], + citation=_citation, + splits=[tfds.Split.TEST] + ) + + def _split_generators(self, dl_manager): + data_path = dl_manager.download( + "http://www.cs.toronto.edu/~nitish/unsupervised_video/" + "mnist_test_seq.npy") + + # authors only provide test data. 
See `tfds.video.moving_sequence` for + # approach based on creating sequences from existing datasets + return [ + tfds.core.SplitGenerator( + name=tfds.Split.TEST, + num_shards=5, + gen_kwargs=dict(data_path=data_path)), + ] + + def _generate_examples(self, data_path): + """Generate MOVING_MNIST sequences as a single. + + Args: + data_path (str): Path to the data file + + Returns: + 10000 x 20 x 64 x 64 x 1 uint8 numpy array + """ + with tf.io.gfile.GFile(data_path, "r") as fp: + images = np.load(fp) + images = np.transpose(images, (1, 0, 2, 3)) + images = np.expand_dims(images, axis=-1) + for sequence in images: + yield dict(image_sequence=sequence) diff --git a/tensorflow_datasets/video/moving_sequence.py b/tensorflow_datasets/video/moving_sequence.py new file mode 100644 index 00000000000..7d77b36ffea --- /dev/null +++ b/tensorflow_datasets/video/moving_sequence.py @@ -0,0 +1,419 @@ +""" +Contains functions for creating moving sequences of smaller bouncing images. + +This is a generalization of the code provided by the authors of the moving mnist +dataset. + +Example usage: +```python +import tensorflow as tf +import tensorflow_datasets as tfds +import tensorflow_datasets.video.moving_sequence as ms + + +def animate(sequence): + import matplotlib.pyplot as plt + import matplotlib.animation as animation + + fig = plt.figure() + plt.axis("off") + ims = [[plt.imshow(im, cmap="gray", animated=True)] for im in sequence] + # don't remove `anim =` as linter may suggets + # weird behaviour, plot will freeze on last frame + anim = animation.ArtistAnimation( + fig, ims, interval=50, blit=True, repeat_delay=100) + + plt.show() + plt.close() + + +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# get a base dataset +base_dataset = tfds.load("fashion_mnist")[tfds.Split.TRAIN] +base_dataset = base_dataset.repeat().shuffle(1024) +dataset = ms.as_moving_sequence_dataset( + base_dataset, + speed=lambda n: tf.random_normal(shape=(n,))*0.1, + image_key="image", + sequence_length=20) + +data = dataset.make_one_shot_iterator().get_next() +sequence = data["image_sequence"] +sequence = tf.squeeze(sequence, axis=-1) # output_shape [20, 64, 64] + +with tf.Session() as sess: + seq = sess.run(sequence) + animate(seq) +``` + +Default arguments in `as_moving_sequence_dataset` are for the original +moving mnist dataset, with +```python +base_dataset = tfds.load("mnist")[tfds.Split.TRAIN].repeat().shuffle(1024) +dataset = ms.as_moving_sequence_dataset(base_dataset) +``` + +Compare results above with +``` +dataset = tfds.load("moving_mnist")[tfds.Split.TEST] +``` +(test data provided by original authors) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import tensorflow as tf +import tensorflow_datasets as tfds +import collections + +_merge_fns = { + "max": lambda x, y: tf.cast( + tf.math.maximum(tf.cast(x, tf.int32), tf.cast(y, tf.int32)), tf.uint8) +} + + +def _create_moving_sequence(image, pad_lefts, total_padding): + """See create_moving_sequence.""" + + with tf.name_scope("moving_sequence"): + def get_padded_image(args): + pad_left, = args + pad_right = total_padding - pad_left + padding = tf.stack([pad_left, pad_right], axis=-1) + z = tf.zeros((1, 2), dtype=pad_left.dtype) + padding = tf.concat([padding, z], axis=0) + return tf.pad(image, padding) + + padded_images = tf.map_fn( + get_padded_image, [pad_lefts], 
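        # one output frame per row of `pad_lefts`: each step pads the same
        # `image` with a different left/right split along each spatial axis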
dtype=tf.uint8, infer_shape=False,
        back_prop=False)

    return padded_images


def create_moving_sequence(image, pad_lefts, total_padding):
  """
  Create a moving image sequence from the given image and left padding values.

  Args:
    image: [h, w, n_channels] uint8 array
    pad_lefts: [sequence_length, 2] int32 array of
      left padding values
    total_padding: TensorShape or list/tuple, (pad_h, pad_w) total padding
      in each spatial dimension

  Returns:
    [sequence_length, h + pad_h, w + pad_w, n_channels] uint8 sequence.
  """
  total_padding = tf.TensorShape(total_padding)
  pad_lefts = tf.convert_to_tensor(pad_lefts, dtype=tf.int32)
  image = tf.convert_to_tensor(image, dtype=tf.uint8)
  if image.shape.ndims != 3:
    raise ValueError("`image` must be a rank 3 tensor")
  if pad_lefts.shape.ndims != 2:
    raise ValueError("`pad_lefts` must be a rank 2 tensor")
  if len(total_padding) != 2:
    raise ValueError(
        "`total_padding` must have 2 entries, got %s"
        % str(total_padding.as_list()))
  seq = _create_moving_sequence(
      image, pad_lefts, tf.convert_to_tensor(total_padding))
  ph, pw = total_padding
  h, w, n_channels = image.shape
  sequence_length = pad_lefts.shape[0]
  seq.set_shape((sequence_length, h + ph, w + pw, n_channels))
  return seq


def create_merged_moving_sequence(
    images, sequence_pad_lefts, total_padding, background=tf.zeros,
    merge_fn="max"):
  """
  Args:
    images: [n_images, h, w, n_channels] uint8 array
    sequence_pad_lefts: [n_images, sequence_length, 2] int32 array of
      left padding values
    total_padding: TensorShape (pad_h, pad_w) of total padding in each
      spatial dimension
    background: background image, or callable that takes `shape` and `dtype`
      args.
    merge_fn: "max" for maximum, or callable mapping (seq0, seq1) -> seq, where
      seq0, seq1 and the returned seq are tensors of the same shape/dtype as
      the output.

  Returns:
    [sequence_length, out_h, out_w, n_channels] overlayed padded sequence,
    where out_h = h + pad_h and out_w = w + pad_w.
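
  Example:
    With the default `merge_fn="max"`, overlapping regions of two sub-image
    sequences combine by pixelwise maximum, i.e.
    `seq[t, i, j, c] == max(seq0[t, i, j, c], seq1[t, i, j, c])`, so
    crossing digits remain visible.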
+ """ + if isinstance(merge_fn, six.string_types): + merge_fn = _merge_fns[merge_fn] + images = tf.convert_to_tensor(images, dtype=tf.uint8) + sequence_pad_lefts = tf.convert_to_tensor(sequence_pad_lefts, dtype=tf.int32) + if images.shape.ndims != 4: + raise ValueError("`images` must be a rank 4 tensor") + if sequence_pad_lefts.shape.ndims != 3: + raise ValueError("`sequence_pad_lefts` must be a rank 4 tensor") + if len(total_padding) != 2: + raise ValueError( + "`total_padding` must be len 2, got %s" + % str(total_padding.as_list())) + + image_res = [i + t for i, t in zip(images.shape[1:3], total_padding)] + + n_channels = images.shape[3] + out_image_shape = image_res + [n_channels] + + total_padding_tensor = tf.convert_to_tensor(total_padding) + + def fn(seq0, args): + image, pad_lefts = args + seq1 = _create_moving_sequence(image, pad_lefts, total_padding_tensor) + seq1.set_shape(out_image_shape) + return merge_fn(seq0, seq1) + + if callable(background): + background = background(out_image_shape, tf.uint8) + + if background.shape != out_image_shape: + raise ValueError( + "background shape should be %s, got %s" % + (str(background.shape), str(out_image_shape))) + sequence = tf.foldl( + fn, [images, sequence_pad_lefts], + initializer=background, + back_prop=False, + name="merged_moving_sequence") + + return sequence + + +def get_random_trajectories( + n_trajectories, sequence_length, ndims=2, speed=0.1, + dtype=tf.float32): + """ + Args: + n_trajectories: int32 number of trajectories + sequence_length: int32 length of sequence + ndims: int32 number of dimensions + speed: (float) length of each step, or rank 1 tensor of length + `n_trajectories` + dx = speed*normalized_velocity + dtype: returned data type. Must be float + + Returns: + trajectories: [n_trajectories, sequence_length, ndims] `dtype` tensor + on [0, 1]. + x0: [n_trajectories, ndims] `dtype` tensor of random initial positions + used + velocity: [n_trajectories, ndims] `dtype` tensor of random normalized + velocities used. + """ + if not dtype.is_floating: + raise ValueError("dtype must be float") + speed = tf.convert_to_tensor(speed, dtype=dtype) + if speed.shape.ndims not in (0, 1): + raise ValueError("speed must be scalar or rank 1 tensor") + + nt = n_trajectories + x0 = tf.random.uniform((nt, ndims), dtype=dtype) + velocity = tf.random_normal((nt, ndims), dtype=dtype) + speed = tf.convert_to_tensor(speed, dtype=dtype) + if speed.shape.ndims == 1: + if speed.shape[0].value not in (1, n_trajectories): + raise ValueError( + "If speed is a rank 1 tensor, its length must be 1 or same as " + "`n_trajectories`, got shape %s" % str(speed.shape)) + speed = tf.expand_dims(speed, axis=-1) + velocity = velocity * ( + speed / tf.linalg.norm(velocity, axis=-1, keepdims=True)) + t = tf.range(sequence_length, dtype=dtype) + linear_trajectories = get_linear_trajectories(x0, velocity, t) + bounced_trajectories = bounce_to_bbox(linear_trajectories) + return bounced_trajectories, x0, velocity + + +def get_linear_trajectories(x0, velocity, t): + """ + Args: + x0: [n_trajectories, ndims] float tensor. + velocity: [n_trajectories, ndims] float tensor + t: [sequence_length] float tensor + + Returns: + x: [n_trajectories, sequence_length, ndims] float tensor. 
+ """ + x0 = tf.convert_to_tensor(x0) + velocity = tf.convert_to_tensor(velocity) + if x0.shape.ndims != 2: + raise ValueError("x0 must be a rank 2 tensor") + if velocity.shape.ndims != 2: + raise ValueError("velocity must be a rank 2 tensor") + if t.shape.ndims != 1: + raise ValueError("t must be a rank 1 tensor") + x0 = tf.expand_dims(x0, axis=1) + velocity = tf.expand_dims(velocity, axis=1) + dx = velocity * tf.expand_dims(tf.expand_dims(t, axis=0), axis=-1) + linear_trajectories = x0 + dx + assert linear_trajectories.shape.ndims == 3, \ + "linear_trajectories should be a rank 3 tensor" + return linear_trajectories + + +def bounce_to_bbox(points): + """ + Bounce potentially unbounded points to [0, 1]. + + Bouncing occurs by exact reflection, i.e. a pre-bound point at 1.1 is moved + to 0.9, -0.2 -> 0.2. This theoretically can occur multiple times, e.g. + 2.3 -> -0.7 -> 0.3 + + Implementation + points <- points % 2 + return min(2 - points, points) + + Args: + points: float array + + Returns: + tensor with same shape/dtype but values in [0, 1]. + """ + points = points % 2 + return tf.math.minimum(2 - points, points) + + +MovingSequence = collections.namedtuple( + "MovingSequence", + ["image_sequence", "trajectories", "start_positions", "velocities"]) + + +def images_to_moving_sequence( + images, sequence_length=20, speed=0.1, total_padding=(36, 36), + **kwargs): + """ + Convert images to a moving sequence. + + Args: + images: [?, in_h, in_w, n_channels] uint8 tensor of images. + sequence_length: int, length of sequence. + speed: float, length of each step. Scalar, or rank 1 tensor with length + the same as images.shape[0]. + total_padding: (pad_y, pad_x) total padding to be applied in each dimension. + kwargs: passed to `create_merged_moving_sequence` + + Returns: + `MovingSequence` namedtuple containing: + `image_sequence`: + [sequence_length, in_h + pad_y, in_w + pad_x, n_channels] uint8. + `trajectories`: [sequence_length, n_images, 2] float32 in [0, 1]. + `start_positions`: [n_images, 2] float32 initial positions in [0, 1]. + `velocities`: [n_images, 2] float32 normalized velocities. + """ + images = tf.convert_to_tensor(images, dtype=tf.uint8) + total_padding = tf.TensorShape(total_padding) + speed = tf.convert_to_tensor(speed, dtype=tf.float32) + n_images = images.shape[0].value + trajectories, x0, velocity = get_random_trajectories( + n_images, sequence_length, ndims=2, speed=speed, + dtype=tf.float32) + sequence_pad_lefts = tf.cast( + trajectories * tf.cast(total_padding, tf.float32), + tf.int32) + sequence = create_merged_moving_sequence( + images, sequence_pad_lefts, total_padding, **kwargs) + return MovingSequence( + image_sequence=sequence, + trajectories=trajectories, + start_positions=x0, + velocities=velocity) + + +def as_moving_sequence_dataset( + base_dataset, n_images=2, sequence_length=20, total_padding=(36, 36), + speed=0.1, image_key="image", num_parallel_calls=None, **kwargs): + """ + Get a moving sequence dataset based on another image dataset. + + This is based on batching the base_dataset and mapping through + `images_to_moving_sequence`. For good variety, consider shuffling the + `base_dataset` before calling this rather than shuffling the returned one, as + this will make it extremely unlikely to get the same combination of images + in the sequence. 
+ + Example usage: + ```python + base_dataset = tfds.load("fashion_mnist")[tfds.Split.TRAIN] + base_dataset = base_dataset.repeat().shuffle(1024) + dataset = ms.as_moving_sequence_dataset( + base_dataset, + speed=lambda n: tf.random_normal(shape=(n,)) / 10, + sequence_length=20, total_padding=(36, 36)) + dataset = dataset.batch(128) + features = dataset.make_one_shot_iterator().get_next() + images = features["image_sequence"] + labels = features["label"] + print(images.shape) # [128, 20, 64, 64, 1] + print(labels.shape) # [2] + ``` + + Args: + base_dataset: base image dataset to use. + n_images: number of sub-images making up each frame. + sequence_length: number of frames per sequences. + total_padding: TensorShape/list/tuple with [py, px]. Each image will be + padded with this amount on either left or right (top/bottom) per frame. + speed: normalized rate(s) at which sub-images move around. Each subimage + moves this fraction of the available space each frame. Scalar or rank 1 + tensor of length `n_images`, or a callable mapping `n_images` to one + of the above + image_key: key from the base dataset containing the images. + num_parallel_calls: used in dataset `map`. + kwargs: passed to `images_to_moving_sequence`. + + Returns: + mapped dataset dict entries + `image_sequence`: [ + sequence_length, base_im_h + py, base_im_w + px, base_im_channels], + uint8 tensor. + `trajectories`: [n_images, sequence_length, 2] float tensor with values + in range [0, 1] giving position for each base image in each frame. + `start_positions`: [n_images, 2] starting positions of each subimage. + `velocities`: [n_images, 2] normalized velocities + along with other entries from the base dataset + """ + dtypes = base_dataset.output_types + if image_key not in dtypes: + raise ValueError( + "base_dataset doesn't have key `image_key='%s'`.\nAvailable keys:\n%s" + % (image_key, "\n".join(dtypes))) + dtype = dtypes[image_key] + if dtype != tf.uint8: + raise TypeError("images must be uint8, got %s" % str(dtype)) + shape = base_dataset.output_shapes[image_key] + if shape.ndims != 3: + raise ValueError("images must be rank 3, got %s" % str(shape)) + + dataset = base_dataset.batch(n_images, drop_remainder=True) + + def map_fn(data): + # Do this check inside `map_fn` makes the resulting `Dataset` + # usable via `make_one_shot_iterator` + if callable(speed): + speed_ = speed(n_images) + else: + speed_ = speed + images = data.pop(image_key) + sequence = images_to_moving_sequence( + images, + sequence_length, + total_padding=total_padding, + speed=speed_, + **kwargs) + return dict( + image_sequence=sequence.image_sequence, + trajectories=sequence.trajectories, + start_positions=sequence.start_positions, + velocities=sequence.velocities, + **data + ) + + return dataset.map(map_fn, num_parallel_calls=num_parallel_calls) From 3e0940e5b24d6de8d3bd2b2ec7359c7767e8ab84 Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Tue, 29 Jan 2019 11:56:52 +1000 Subject: [PATCH 02/13] requested changes + np_to_list change --- .../core/features/sequence_feature.py | 2 +- tensorflow_datasets/video/__init__.py | 2 +- tensorflow_datasets/video/moving_mnist.py | 23 +++++-------------- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/tensorflow_datasets/core/features/sequence_feature.py b/tensorflow_datasets/core/features/sequence_feature.py index b58b4d2095b..dad217e0c9f 100644 --- a/tensorflow_datasets/core/features/sequence_feature.py +++ b/tensorflow_datasets/core/features/sequence_feature.py @@ -260,7 +260,7 @@ def 
np_to_list(elem): return elem elif isinstance(elem, np.ndarray): elem = np.split(elem, elem.shape[0]) - elem = np.squeeze(elem, axis=0) + elem = [np.squeeze(e, axis=0) for e in elem] return elem else: raise ValueError( diff --git a/tensorflow_datasets/video/__init__.py b/tensorflow_datasets/video/__init__.py index ad9e1ba196a..7eb4ab09ba6 100644 --- a/tensorflow_datasets/video/__init__.py +++ b/tensorflow_datasets/video/__init__.py @@ -16,6 +16,6 @@ """Video datasets.""" from tensorflow_datasets.video.bair_robot_pushing import BairRobotPushingSmall +from tensorflow_datasets.video.moving_mnist import MovingMnist from tensorflow_datasets.video.starcraft import StarcraftVideo from tensorflow_datasets.video.starcraft import StarcraftVideoConfig -from tensorflow_datasets.video.moving_mnist import MovingMnist diff --git a/tensorflow_datasets/video/moving_mnist.py b/tensorflow_datasets/video/moving_mnist.py index 84a6c402b87..a7b43288e8c 100644 --- a/tensorflow_datasets/video/moving_mnist.py +++ b/tensorflow_datasets/video/moving_mnist.py @@ -10,7 +10,6 @@ from tensorflow_datasets.image.mnist import _MNIST_IMAGE_SIZE _OUT_RESOLUTION = (64, 64) -_TOTAL_PADDING = tuple(o - _MNIST_IMAGE_SIZE for o in _OUT_RESOLUTION) # 36, 36 _SEQUENCE_LENGTH = 20 _IMAGES_PER_SEQUENCE = 2 @@ -38,17 +37,6 @@ class MovingMnist(tfds.core.GeneratorBasedBuilder): VERSION = tfds.core.Version("0.1.0") def _info(self): - shape = (_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,) - - # as Image - doesn't work with 1 as final dim? - # sequence = tfds.features.Image(shape=shape) - - # as video - doesn't work with 1 as final dim? - # sequence = tfds.features.Video(shape=shape) - - # as base tensor - space inefficient?? - sequence = tfds.features.Tensor(shape=shape, dtype=tf.uint8) - return tfds.core.DatasetInfo( builder=self, description=( @@ -57,7 +45,8 @@ def _info(self): "`tfds.video.moving_sequence` for functions to generate training/" "validation data."), features=tfds.features.FeaturesDict( - dict(image_sequence=sequence)), + dict(image_sequence=tfds.features.Video( + shape=(_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,)))), # supervised_keys=("inputs",), urls=["http://www.cs.toronto.edu/~nitish/unsupervised_video/"], citation=_citation, @@ -79,15 +68,15 @@ def _split_generators(self, dl_manager): ] def _generate_examples(self, data_path): - """Generate MOVING_MNIST sequences as a single. + """Generate MOVING_MNIST sequences. Args: data_path (str): Path to the data file - Returns: - 10000 x 20 x 64 x 64 x 1 uint8 numpy array + Yields: + 20 x 64 x 64 x 1 uint8 numpy arrays """ - with tf.io.gfile.GFile(data_path, "r") as fp: + with tf.io.gfile.GFile(data_path, "rb") as fp: images = np.load(fp) images = np.transpose(images, (1, 0, 2, 3)) images = np.expand_dims(images, axis=-1) From 5f0a8c71a3ed438a3779f9e1814e14b8ef5525ec Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Tue, 29 Jan 2019 12:27:32 +1000 Subject: [PATCH 03/13] updated video documentation, added channels error check --- tensorflow_datasets/core/features/video_feature.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow_datasets/core/features/video_feature.py b/tensorflow_datasets/core/features/video_feature.py index 9b87e7683b4..caaa1f1e7e5 100644 --- a/tensorflow_datasets/core/features/video_feature.py +++ b/tensorflow_datasets/core/features/video_feature.py @@ -29,10 +29,11 @@ class Video(sequence_feature.Sequence): """`FeatureConnector` for videos, png-encoding frames on disk. 
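
  Internally stored as a `Sequence` of `Image` features, so each frame is
  png-encoded on disk and decoded back to a uint8 tensor when read.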
Video: The image connector accepts as input: - * uint8 array representing an video. + * uint8 array representing a video. Output: - video: tf.Tensor of type tf.uint8 and shape [num_frames, height, width, 3] + video: tf.Tensor of type tf.uint8 and shape + [num_frames, height, width, channels], where channels must be 1 or 3 Example: * In the DatasetInfo object: @@ -51,7 +52,7 @@ def __init__(self, shape): Args: shape: tuple of ints, the shape of the video (num_frames, height, width, - channels=3). + channels), where channels is 1 or 3. Raises: ValueError: If the shape is invalid @@ -61,6 +62,8 @@ def __init__(self, shape): raise ValueError('Video shape should be of rank 4') if shape.count(None) > 1: raise ValueError('Video shape cannot have more than 1 unknown dim') + if shape[-1] not in (1, 3): + raise ValueError('Video channels must be 1 or 3, got %d' % shape[-1]) super(Video, self).__init__( image_feature.Image(shape=shape[1:], encoding_format='png'), From e31261779b9312208999bdc1edb521ce2bcafe84 Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Tue, 29 Jan 2019 18:16:50 +1000 Subject: [PATCH 04/13] removed unused private variables --- tensorflow_datasets/video/moving_mnist.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow_datasets/video/moving_mnist.py b/tensorflow_datasets/video/moving_mnist.py index a7b43288e8c..c0b5a76a149 100644 --- a/tensorflow_datasets/video/moving_mnist.py +++ b/tensorflow_datasets/video/moving_mnist.py @@ -7,12 +7,9 @@ import numpy as np import tensorflow as tf import tensorflow_datasets.public_api as tfds -from tensorflow_datasets.image.mnist import _MNIST_IMAGE_SIZE _OUT_RESOLUTION = (64, 64) _SEQUENCE_LENGTH = 20 -_IMAGES_PER_SEQUENCE = 2 - _citation = """ @article{DBLP:journals/corr/SrivastavaMS15, author = {Nitish Srivastava and From 8d3f38860f0ad67a9ec09d7d9a21457022041225 Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Tue, 29 Jan 2019 18:18:19 +1000 Subject: [PATCH 05/13] clean up --- tensorflow_datasets/video/moving_mnist.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow_datasets/video/moving_mnist.py b/tensorflow_datasets/video/moving_mnist.py index c0b5a76a149..148e4190d5d 100644 --- a/tensorflow_datasets/video/moving_mnist.py +++ b/tensorflow_datasets/video/moving_mnist.py @@ -44,7 +44,6 @@ def _info(self): features=tfds.features.FeaturesDict( dict(image_sequence=tfds.features.Video( shape=(_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,)))), - # supervised_keys=("inputs",), urls=["http://www.cs.toronto.edu/~nitish/unsupervised_video/"], citation=_citation, splits=[tfds.Split.TEST] From e3615c44c6397f90c38bfb31edd6e0338ae485de Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Tue, 29 Jan 2019 20:18:55 +1000 Subject: [PATCH 06/13] requested changes --- tensorflow_datasets/core/features/video_feature.py | 2 -- tensorflow_datasets/video/moving_mnist.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorflow_datasets/core/features/video_feature.py b/tensorflow_datasets/core/features/video_feature.py index caaa1f1e7e5..152e681fc0e 100644 --- a/tensorflow_datasets/core/features/video_feature.py +++ b/tensorflow_datasets/core/features/video_feature.py @@ -62,8 +62,6 @@ def __init__(self, shape): raise ValueError('Video shape should be of rank 4') if shape.count(None) > 1: raise ValueError('Video shape cannot have more than 1 unknown dim') - if shape[-1] not in (1, 3): - raise ValueError('Video channels must be 1 or 3, got %d' % shape[-1]) super(Video, self).__init__( image_feature.Image(shape=shape[1:], 
encoding_format='png'), diff --git a/tensorflow_datasets/video/moving_mnist.py b/tensorflow_datasets/video/moving_mnist.py index 148e4190d5d..8b555bf4a41 100644 --- a/tensorflow_datasets/video/moving_mnist.py +++ b/tensorflow_datasets/video/moving_mnist.py @@ -45,8 +45,7 @@ def _info(self): dict(image_sequence=tfds.features.Video( shape=(_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,)))), urls=["http://www.cs.toronto.edu/~nitish/unsupervised_video/"], - citation=_citation, - splits=[tfds.Split.TEST] + citation=_citation ) def _split_generators(self, dl_manager): From 3a5fe00c3d8c6d4f9e2ed49e6ff8510b28cdb805 Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Wed, 30 Jan 2019 11:21:35 +1000 Subject: [PATCH 07/13] simplified np_to_list --- tensorflow_datasets/core/features/sequence_feature.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorflow_datasets/core/features/sequence_feature.py b/tensorflow_datasets/core/features/sequence_feature.py index dad217e0c9f..2fd5cf64984 100644 --- a/tensorflow_datasets/core/features/sequence_feature.py +++ b/tensorflow_datasets/core/features/sequence_feature.py @@ -259,9 +259,7 @@ def np_to_list(elem): if isinstance(elem, list): return elem elif isinstance(elem, np.ndarray): - elem = np.split(elem, elem.shape[0]) - elem = [np.squeeze(e, axis=0) for e in elem] - return elem + return list(elem) else: raise ValueError( 'Input elements of a sequence should be either a numpy array or a ' From 38cd51270f3c16c33adb6bf4ffe7554dafe201c9 Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Mon, 4 Feb 2019 13:28:10 +1000 Subject: [PATCH 08/13] adjusted for dynamic shaped images, limited interface --- tensorflow_datasets/video/moving_mnist.py | 10 +- tensorflow_datasets/video/moving_sequence.py | 429 +++++++----------- .../video/moving_sequence_test.py | 66 +++ 3 files changed, 230 insertions(+), 275 deletions(-) create mode 100644 tensorflow_datasets/video/moving_sequence_test.py diff --git a/tensorflow_datasets/video/moving_mnist.py b/tensorflow_datasets/video/moving_mnist.py index 8b555bf4a41..273bc5322ab 100644 --- a/tensorflow_datasets/video/moving_mnist.py +++ b/tensorflow_datasets/video/moving_mnist.py @@ -7,6 +7,7 @@ import numpy as np import tensorflow as tf import tensorflow_datasets.public_api as tfds +from tensorflow_datasets.video.moving_sequence import images_as_moving_sequence # pylint: disable=unused-import _OUT_RESOLUTION = (64, 64) _SEQUENCE_LENGTH = 20 @@ -39,8 +40,8 @@ def _info(self): description=( "Moving variant of MNIST database of handwritten digits. This is the " "data used by the authors for reporting model performance. See " - "`tfds.video.moving_sequence` for functions to generate training/" - "validation data."), + "`tensorflow_datasets.video.moving_mnist.images_as_moving_sequence` " + "for generating training/validation data from the MNIST dataset."), features=tfds.features.FeaturesDict( dict(image_sequence=tfds.features.Video( shape=(_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,)))), @@ -53,8 +54,9 @@ def _split_generators(self, dl_manager): "http://www.cs.toronto.edu/~nitish/unsupervised_video/" "mnist_test_seq.npy") - # authors only provide test data. See `tfds.video.moving_sequence` for - # approach based on creating sequences from existing datasets + # authors only provide test data. + # See `tfds.video.moving_mnist.moving_sequence` for mapping function to + # create training/validation dataset from MNIST. 
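+    # e.g. (sketch, mirroring the example in `moving_sequence`'s docstring):
+    #   mnist_ds = tfds.load("mnist", split=tfds.Split.TRAIN,
+    #                        as_supervised=True)
+    #   mnist_ds = mnist_ds.repeat().shuffle(1024).batch(
+    #       2, drop_remainder=True)
+    #   train_ds = mnist_ds.map(lambda image, label: dict(
+    #       image_sequence=images_as_moving_sequence(
+    #           image, sequence_length=20).image_sequence))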
return [ tfds.core.SplitGenerator( name=tfds.Split.TEST, diff --git a/tensorflow_datasets/video/moving_sequence.py b/tensorflow_datasets/video/moving_sequence.py index 7d77b36ffea..ad86e86bc86 100644 --- a/tensorflow_datasets/video/moving_sequence.py +++ b/tensorflow_datasets/video/moving_sequence.py @@ -1,65 +1,4 @@ -""" -Contains functions for creating moving sequences of smaller bouncing images. - -This is a generalization of the code provided by the authors of the moving mnist -dataset. - -Example usage: -```python -import tensorflow as tf -import tensorflow_datasets as tfds -import tensorflow_datasets.video.moving_sequence as ms - - -def animate(sequence): - import matplotlib.pyplot as plt - import matplotlib.animation as animation - - fig = plt.figure() - plt.axis("off") - ims = [[plt.imshow(im, cmap="gray", animated=True)] for im in sequence] - # don't remove `anim =` as linter may suggets - # weird behaviour, plot will freeze on last frame - anim = animation.ArtistAnimation( - fig, ims, interval=50, blit=True, repeat_delay=100) - - plt.show() - plt.close() - - -# --------------------------------------------------------------------------- -# --------------------------------------------------------------------------- -# get a base dataset -base_dataset = tfds.load("fashion_mnist")[tfds.Split.TRAIN] -base_dataset = base_dataset.repeat().shuffle(1024) -dataset = ms.as_moving_sequence_dataset( - base_dataset, - speed=lambda n: tf.random_normal(shape=(n,))*0.1, - image_key="image", - sequence_length=20) - -data = dataset.make_one_shot_iterator().get_next() -sequence = data["image_sequence"] -sequence = tf.squeeze(sequence, axis=-1) # output_shape [20, 64, 64] - -with tf.Session() as sess: - seq = sess.run(sequence) - animate(seq) -``` - -Default arguments in `as_moving_sequence_dataset` are for the original -moving mnist dataset, with -```python -base_dataset = tfds.load("mnist")[tfds.Split.TRAIN].repeat().shuffle(1024) -dataset = ms.as_moving_sequence_dataset(base_dataset) -``` - -Compare results above with -``` -dataset = tfds.load("moving_mnist")[tfds.Split.TEST] -``` -(test data provided by original authors) -""" +"""Provides `images_as_moving_sequence`.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -76,7 +15,17 @@ def animate(sequence): def _create_moving_sequence(image, pad_lefts, total_padding): - """See create_moving_sequence.""" + """Create a moving image sequence from the given image a left padding values. + + Args: + image: [in_h, in_w, n_channels] uint8 array + pad_lefts: [sequence_length, 2] int32 array of left padding values + total_padding: tensor of padding values, (pad_h, pad_w) + + Returns: + [sequence_length, out_h, out_w, n_channels] uint8 image sequence, where + out_h = in_h + pad_h, out_w = in_w + out_w + """ with tf.name_scope("moving_sequence"): def get_padded_image(args): @@ -94,48 +43,16 @@ def get_padded_image(args): return padded_images -def create_moving_sequence(image, pad_lefts, total_padding): - """ - Create a moving image sequence from the given image and left padding values. - - Args: - image: [h, w, n_channels] uint8 array - pad_lefts: [sequence_length, 2] int32 array of - left padding values - total_padding: TensorShape or list/tuple, (out_h, out_w) - - Returns: - [sequence_length, out_h, out_w, n_shannels] uint8 sequence. 
- """ - total_padding = tf.TensorShape(total_padding) - pad_lefts = tf.convert_to_tensor(pad_lefts, dtype=tf.float32) - image = tf.convert_to_tensor(image, dtype=tf.uint8) - if image.shape.ndims != 3: - raise ValueError("`image` must be a rank 3 tensor") - if pad_lefts.shape.ndims != 2: - raise ValueError("`sequence_pad_lefts` must be a rank 2 tensor") - if len(total_padding) != 2: - raise ValueError( - "`total_padding` must have 2 entres, got %s" - % str(total_padding.as_list())) - seq = _create_moving_sequence( - image, pad_lefts, tf.convert_to_tensor(total_padding)) - ph, pw = total_padding - h, w, n_channels = image.shape - sequence_length = pad_lefts.shape[0] - seq.set_shape((sequence_length, h + ph, w + pw, n_channels)) - return seq - - -def create_merged_moving_sequence( - images, sequence_pad_lefts, total_padding, background=tf.zeros, - merge_fn="max"): +def _create_merged_moving_sequence( + images, sequence_pad_lefts, image_size, total_padding, + background=tf.zeros, merge_fn="max"): """ Args: images: [n_images, h, w, n_channels] uint8 array sequence_pad_lefts: [n_images, sequence_length, 2] int32 array of left padding values - total_padding: TensorShape (out_h, out_w) + image_size: TensorShape (out_h, out_w) + total_padding: tensor, images.shape[1:3] - image_size background: background image, or callable that takes `shape` and `dtype` args. merge_fn: "max" for maximum, or callable mapping (seq0, seq1) -> seq, where @@ -152,22 +69,22 @@ def create_merged_moving_sequence( if images.shape.ndims != 4: raise ValueError("`images` must be a rank 4 tensor") if sequence_pad_lefts.shape.ndims != 3: - raise ValueError("`sequence_pad_lefts` must be a rank 4 tensor") - if len(total_padding) != 2: + raise ValueError("`sequence_pad_lefts` must be a rank 3 tensor") + if total_padding.shape != (2,): raise ValueError( "`total_padding` must be len 2, got %s" % str(total_padding.as_list())) - image_res = [i + t for i, t in zip(images.shape[1:3], total_padding)] - n_channels = images.shape[3] - out_image_shape = image_res + [n_channels] - - total_padding_tensor = tf.convert_to_tensor(total_padding) + out_image_shape = ( + [sequence_pad_lefts.shape[1]] + + image_size.as_list() + + [n_channels] + ) def fn(seq0, args): image, pad_lefts = args - seq1 = _create_moving_sequence(image, pad_lefts, total_padding_tensor) + seq1 = _create_moving_sequence(image, pad_lefts, total_padding) seq1.set_shape(out_image_shape) return merge_fn(seq0, seq1) @@ -187,52 +104,7 @@ def fn(seq0, args): return sequence -def get_random_trajectories( - n_trajectories, sequence_length, ndims=2, speed=0.1, - dtype=tf.float32): - """ - Args: - n_trajectories: int32 number of trajectories - sequence_length: int32 length of sequence - ndims: int32 number of dimensions - speed: (float) length of each step, or rank 1 tensor of length - `n_trajectories` - dx = speed*normalized_velocity - dtype: returned data type. Must be float - - Returns: - trajectories: [n_trajectories, sequence_length, ndims] `dtype` tensor - on [0, 1]. - x0: [n_trajectories, ndims] `dtype` tensor of random initial positions - used - velocity: [n_trajectories, ndims] `dtype` tensor of random normalized - velocities used. 
- """ - if not dtype.is_floating: - raise ValueError("dtype must be float") - speed = tf.convert_to_tensor(speed, dtype=dtype) - if speed.shape.ndims not in (0, 1): - raise ValueError("speed must be scalar or rank 1 tensor") - - nt = n_trajectories - x0 = tf.random.uniform((nt, ndims), dtype=dtype) - velocity = tf.random_normal((nt, ndims), dtype=dtype) - speed = tf.convert_to_tensor(speed, dtype=dtype) - if speed.shape.ndims == 1: - if speed.shape[0].value not in (1, n_trajectories): - raise ValueError( - "If speed is a rank 1 tensor, its length must be 1 or same as " - "`n_trajectories`, got shape %s" % str(speed.shape)) - speed = tf.expand_dims(speed, axis=-1) - velocity = velocity * ( - speed / tf.linalg.norm(velocity, axis=-1, keepdims=True)) - t = tf.range(sequence_length, dtype=dtype) - linear_trajectories = get_linear_trajectories(x0, velocity, t) - bounced_trajectories = bounce_to_bbox(linear_trajectories) - return bounced_trajectories, x0, velocity - - -def get_linear_trajectories(x0, velocity, t): +def _get_linear_trajectories(x0, velocity, t): """ Args: x0: [n_trajectories, ndims] float tensor. @@ -259,7 +131,7 @@ def get_linear_trajectories(x0, velocity, t): return linear_trajectories -def bounce_to_bbox(points): +def _bounce_to_bbox(points): """ Bounce potentially unbounded points to [0, 1]. @@ -281,139 +153,154 @@ def bounce_to_bbox(points): return tf.math.minimum(2 - points, points) +def _get_random_velocities(n_velocities, ndims, speed, dtype=tf.float32): + """Get random velocities with given speed. + + Args: + n_velocities: int, number of velocities to generate + ndims: number of dimensions, e.g. 2 for images + speed: scalar speed for each velocity, or rank 1 tensor giving speed for + each generated velocity. + dtype: `tf.DType` of returned tensor + + Returns: + [n_velocities, ndims] tensor where each row has length speed in a random + direction. + """ + velocity = tf.random_normal((n_velocities, ndims), dtype=dtype) + speed = tf.convert_to_tensor(speed, dtype=dtype) + if speed.shape.ndims == 1: + if ( + speed.shape[0].value not in (1, n_velocities) and + isinstance(n_velocities, int)): + raise ValueError( + "If speed is a rank 1 tensor, its length must be 1 or same as " + "`n_trajectories`, got shape %s" % str(speed.shape)) + speed = tf.expand_dims(speed, axis=-1) + velocity = velocity * ( + speed / tf.linalg.norm(velocity, axis=-1, keepdims=True)) + return velocity + + MovingSequence = collections.namedtuple( "MovingSequence", ["image_sequence", "trajectories", "start_positions", "velocities"]) -def images_to_moving_sequence( - images, sequence_length=20, speed=0.1, total_padding=(36, 36), - **kwargs): - """ - Convert images to a moving sequence. +def images_as_moving_sequence( + images, sequence_length=20, output_size=(64, 64), + speed=0.1, velocities=None, start_positions=None, + background=tf.zeros, merge_fn='max'): + """Turn simple static images into sequences of the originals bouncing around. + + Adapted from Srivastava et al. 
http://www.cs.toronto.edu/~nitish/unsupervised_video/

  Example usage:
  ```python
  import tensorflow as tf
  import tensorflow_datasets as tfds
  from tensorflow_datasets.video import moving_sequence

  def animate(sequence):
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    sequence = np.squeeze(sequence, axis=-1)

    fig = plt.figure()
    plt.axis("off")
    ims = [[plt.imshow(im, cmap="gray", animated=True)] for im in sequence]
    # don't remove `anim =` even if the linter flags it as unused -
    # otherwise weird behaviour occurs: the plot freezes on the last frame
    anim = animation.ArtistAnimation(
        fig, ims, interval=50, blit=True, repeat_delay=100)

    plt.show()
    plt.close()


  tf.enable_eager_execution()
  mnist_ds = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True)
  mnist_ds = mnist_ds.repeat().shuffle(1024).batch(2, drop_remainder=True)

  def map_fn(image, label):
    sequence = moving_sequence.images_as_moving_sequence(
        image, sequence_length=20)
    return dict(image_sequence=sequence.image_sequence)

  moving_mnist_ds = mnist_ds.map(map_fn)
  # # for comparison with test data provided by original authors
  # moving_mnist_ds = tfds.load("moving_mnist", split=tfds.Split.TEST)

  for seq in moving_mnist_ds:
    animate(seq["image_sequence"].numpy())
  ```

  Args:
    images: [n_images, in_h, in_w, n_channels] uint8 tensor of images.
    sequence_length: int, length of sequence.
    output_size: (out_h, out_w) size of the returned images.
    speed: float, length of each step. Scalar, or rank 1 tensor with length
      n_images. Ignored if velocities is not `None`.
    velocities: 2D velocity of each image. Randomly generated with magnitude
      `speed` if not provided. This is the normalized distance moved each time
      step by each image, where normalization occurs over the feasible distance
      the image can move. e.g. if the input image is [10 x 10] and the output
      image is [60 x 60], a speed of 0.1 means the image moves
      (60 - 10) * 0.1 = 5 pixels per time step.
    start_positions: [n_images, 2] float32 normalized initial position of each
      image in [0, 1]. Randomized uniformly if not given.
    background: background image, or callable that takes `shape` and `dtype`
      args.
    merge_fn: "max" for maximum, or callable mapping (seq0, seq1) -> seq, where
      seq0, seq1 and the returned seq are tensors of the same shape/dtype as
      the output.

  Returns:
    `MovingSequence` namedtuple containing:
        `image_sequence`:
          [sequence_length, out_h, out_w, n_channels_out] uint8.
          With default arguments for `background`/`merge_fn`,
          `n_channels_out` is the same as `n_channels`.
        `trajectories`: [sequence_length, n_images, 2] float32 in [0, 1].
          2D normalized coordinates of each image at every time step.
        `start_positions`: [n_images, 2] float32 initial positions in [0, 1].
          2D normalized initial position of each image.
        `velocities`: [n_images, 2] float32 normalized velocities. Each image
          moves by this amount (give or take due to pixel rounding) per time
          step.
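
  Note: pixel positions are computed as `round(trajectories * total_padding)`,
  where `total_padding = output_size - images.shape[1:3]`, so sub-pixel
  motion is quantized to the nearest pixel at each time step.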
""" + ndims = 2 images = tf.convert_to_tensor(images, dtype=tf.uint8) - total_padding = tf.TensorShape(total_padding) - speed = tf.convert_to_tensor(speed, dtype=tf.float32) - n_images = images.shape[0].value - trajectories, x0, velocity = get_random_trajectories( - n_images, sequence_length, ndims=2, speed=speed, - dtype=tf.float32) + output_size = tf.TensorShape(output_size) + if len(output_size) != ndims: + raise ValueError("output_size must have exactly %d elements, got %s" + % (ndims, output_size)) + image_shape = tf.shape(images) + n_images = image_shape[0] + if start_positions is None: + start_positions = tf.random_uniform((n_images, ndims), dtype=tf.float32) + elif start_positions.shape.ndims != 2 or start_positions.shape[-1] != ndims: + raise ValueError("start_positions must be rank 2 and %d-D" % ndims) + if velocities is None: + velocities = _get_random_velocities(n_images, ndims, speed, tf.float32) + t = tf.range(sequence_length, dtype=tf.float32) + trajectories = _get_linear_trajectories(start_positions, velocities, t) + trajectories = _bounce_to_bbox(trajectories) + + total_padding = output_size - image_shape[1:3] + with tf.control_dependencies([tf.assert_non_negative(total_padding)]): + total_padding = tf.identity(total_padding) + sequence_pad_lefts = tf.cast( - trajectories * tf.cast(total_padding, tf.float32), - tf.int32) - sequence = create_merged_moving_sequence( - images, sequence_pad_lefts, total_padding, **kwargs) + tf.math.round(trajectories * tf.cast(total_padding, tf.float32)), tf.int32) + + sequence = _create_merged_moving_sequence( + images, sequence_pad_lefts, output_size, total_padding, + background=background, merge_fn=merge_fn) return MovingSequence( image_sequence=sequence, trajectories=trajectories, - start_positions=x0, - velocities=velocity) - - -def as_moving_sequence_dataset( - base_dataset, n_images=2, sequence_length=20, total_padding=(36, 36), - speed=0.1, image_key="image", num_parallel_calls=None, **kwargs): - """ - Get a moving sequence dataset based on another image dataset. - - This is based on batching the base_dataset and mapping through - `images_to_moving_sequence`. For good variety, consider shuffling the - `base_dataset` before calling this rather than shuffling the returned one, as - this will make it extremely unlikely to get the same combination of images - in the sequence. - - Example usage: - ```python - base_dataset = tfds.load("fashion_mnist")[tfds.Split.TRAIN] - base_dataset = base_dataset.repeat().shuffle(1024) - dataset = ms.as_moving_sequence_dataset( - base_dataset, - speed=lambda n: tf.random_normal(shape=(n,)) / 10, - sequence_length=20, total_padding=(36, 36)) - dataset = dataset.batch(128) - features = dataset.make_one_shot_iterator().get_next() - images = features["image_sequence"] - labels = features["label"] - print(images.shape) # [128, 20, 64, 64, 1] - print(labels.shape) # [2] - ``` - - Args: - base_dataset: base image dataset to use. - n_images: number of sub-images making up each frame. - sequence_length: number of frames per sequences. - total_padding: TensorShape/list/tuple with [py, px]. Each image will be - padded with this amount on either left or right (top/bottom) per frame. - speed: normalized rate(s) at which sub-images move around. Each subimage - moves this fraction of the available space each frame. Scalar or rank 1 - tensor of length `n_images`, or a callable mapping `n_images` to one - of the above - image_key: key from the base dataset containing the images. - num_parallel_calls: used in dataset `map`. 
- kwargs: passed to `images_to_moving_sequence`. - - Returns: - mapped dataset dict entries - `image_sequence`: [ - sequence_length, base_im_h + py, base_im_w + px, base_im_channels], - uint8 tensor. - `trajectories`: [n_images, sequence_length, 2] float tensor with values - in range [0, 1] giving position for each base image in each frame. - `start_positions`: [n_images, 2] starting positions of each subimage. - `velocities`: [n_images, 2] normalized velocities - along with other entries from the base dataset - """ - dtypes = base_dataset.output_types - if image_key not in dtypes: - raise ValueError( - "base_dataset doesn't have key `image_key='%s'`.\nAvailable keys:\n%s" - % (image_key, "\n".join(dtypes))) - dtype = dtypes[image_key] - if dtype != tf.uint8: - raise TypeError("images must be uint8, got %s" % str(dtype)) - shape = base_dataset.output_shapes[image_key] - if shape.ndims != 3: - raise ValueError("images must be rank 3, got %s" % str(shape)) - - dataset = base_dataset.batch(n_images, drop_remainder=True) - - def map_fn(data): - # Do this check inside `map_fn` makes the resulting `Dataset` - # usable via `make_one_shot_iterator` - if callable(speed): - speed_ = speed(n_images) - else: - speed_ = speed - images = data.pop(image_key) - sequence = images_to_moving_sequence( - images, - sequence_length, - total_padding=total_padding, - speed=speed_, - **kwargs) - return dict( - image_sequence=sequence.image_sequence, - trajectories=sequence.trajectories, - start_positions=sequence.start_positions, - velocities=sequence.velocities, - **data - ) - - return dataset.map(map_fn, num_parallel_calls=num_parallel_calls) + start_positions=start_positions, + velocities=velocities) diff --git a/tensorflow_datasets/video/moving_sequence_test.py b/tensorflow_datasets/video/moving_sequence_test.py new file mode 100644 index 00000000000..37aeb242ef1 --- /dev/null +++ b/tensorflow_datasets/video/moving_sequence_test.py @@ -0,0 +1,66 @@ +"""Tests for moving_sequence.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +import tensorflow_datasets.video.moving_sequence as ms + +class MovingSequenceTest(tf.test.TestCase): + @tf.contrib.eager.run_test_in_graph_and_eager_modes() + def test_images_as_moving_sequence(self): + h, w = (28, 28) + sequence_length = 8 + + vh = 1 / (sequence_length) + vw = 1 / (2*(sequence_length)) + image = tf.ones((28, 28, 1), dtype=tf.uint8) + + velocity = tf.constant([vh, vw], dtype=tf.float32) + out_size = (h + sequence_length, w + sequence_length) + start_position = tf.constant([0, 0], dtype=tf.float32) + + + images = tf.expand_dims(image, axis=0) + velocities = tf.expand_dims(velocity, axis=0) + start_positions = tf.expand_dims(start_position, axis=0) + + sequence = ms.images_as_moving_sequence( + images, start_positions=start_positions, velocities=velocities, + output_size=out_size, sequence_length=sequence_length) + sequence = tf.cast(sequence.image_sequence, tf.float32) + + self.assertAllEqual( + tf.reduce_sum(sequence, axis=(1, 2, 3)), + tf.fill((sequence_length,), tf.reduce_sum(tf.cast(image, tf.float32)))) + + for i, full_image in enumerate(tf.unstack(sequence, axis=0)): + j = i // 2 + subimage = full_image[i:i+h, j:j+w] + n_true = tf.reduce_sum(subimage) + # allow for pixel rounding errors in each dimension + self.assertAllEqual(n_true >= (h-1)*(w-1), True) + + def test_dynamic_shape_inputs(self): + graph = tf.Graph() + with graph.as_default(): + 
image_floats = tf.placeholder( + shape=(None, None, None, 1), dtype=tf.float32) + images = tf.cast(image_floats, tf.uint8) + h, w = 64, 64 + sequence_length = 20 + sequence = ms.images_as_moving_sequence( + images, output_size=(h, w), + sequence_length=sequence_length).image_sequence + + with tf.Session(graph=graph) as sess: + for ni, ih, iw in ((2, 31, 32), (3, 37, 38)): + image_vals = np.random.uniform(high=255, size=(ni, ih, iw, 1)) + out = sess.run(sequence, feed_dict={image_floats: image_vals}) + self.assertAllEqual(out.shape, [sequence_length, h, w, 1]) + + + +if __name__ == '__main__': + tf.test.main() From faea0ee8196940265671df89e3bc339da34f744e Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Mon, 4 Feb 2019 13:37:24 +1000 Subject: [PATCH 09/13] test tweak and n_channels_out adjustment --- tensorflow_datasets/video/moving_sequence.py | 6 ++++-- tensorflow_datasets/video/moving_sequence_test.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow_datasets/video/moving_sequence.py b/tensorflow_datasets/video/moving_sequence.py index ad86e86bc86..b1d2092772a 100644 --- a/tensorflow_datasets/video/moving_sequence.py +++ b/tensorflow_datasets/video/moving_sequence.py @@ -60,7 +60,9 @@ def _create_merged_moving_sequence( the output. Returns: - [sequence_length, out_h, out_w, n_channels] overlayed padded sequence. + [sequence_length, out_h, out_w, n_channels_out] overlayed padded sequence. + n_channels_out defined by background/merge_fn output + (same as n_channels in for default values). """ if isinstance(merge_fn, six.string_types): merge_fn = _merge_fns[merge_fn] @@ -91,7 +93,7 @@ def fn(seq0, args): if callable(background): background = background(out_image_shape, tf.uint8) - if background.shape != out_image_shape: + if background.shape[:-1] != out_image_shape[:-1]: raise ValueError( "background shape should be %s, got %s" % (str(background.shape), str(out_image_shape))) diff --git a/tensorflow_datasets/video/moving_sequence_test.py b/tensorflow_datasets/video/moving_sequence_test.py index 37aeb242ef1..4c7b500e48d 100644 --- a/tensorflow_datasets/video/moving_sequence_test.py +++ b/tensorflow_datasets/video/moving_sequence_test.py @@ -54,7 +54,7 @@ def test_dynamic_shape_inputs(self): images, output_size=(h, w), sequence_length=sequence_length).image_sequence - with tf.Session(graph=graph) as sess: + with self.session(graph=graph) as sess: for ni, ih, iw in ((2, 31, 32), (3, 37, 38)): image_vals = np.random.uniform(high=255, size=(ni, ih, iw, 1)) out = sess.run(sequence, feed_dict={image_floats: image_vals}) From c4164b24be4aef3fd0a0023cfd81c12e29e502aa Mon Sep 17 00:00:00 2001 From: Dominic Jack Date: Mon, 4 Feb 2019 15:06:31 +1000 Subject: [PATCH 10/13] removed foldl loop, unvectorized image_as_moving_sequence --- tensorflow_datasets/video/moving_mnist.py | 8 +- tensorflow_datasets/video/moving_sequence.py | 239 ++++++------------ .../video/moving_sequence_test.py | 29 +-- 3 files changed, 90 insertions(+), 186 deletions(-) diff --git a/tensorflow_datasets/video/moving_mnist.py b/tensorflow_datasets/video/moving_mnist.py index 273bc5322ab..e03aa3cb60b 100644 --- a/tensorflow_datasets/video/moving_mnist.py +++ b/tensorflow_datasets/video/moving_mnist.py @@ -7,7 +7,7 @@ import numpy as np import tensorflow as tf import tensorflow_datasets.public_api as tfds -from tensorflow_datasets.video.moving_sequence import images_as_moving_sequence # pylint: disable=unused-import +from tensorflow_datasets.video.moving_sequence import image_as_moving_sequence # 
pylint: disable=unused-import _OUT_RESOLUTION = (64, 64) _SEQUENCE_LENGTH = 20 @@ -40,7 +40,7 @@ def _info(self): description=( "Moving variant of MNIST database of handwritten digits. This is the " "data used by the authors for reporting model performance. See " - "`tensorflow_datasets.video.moving_mnist.images_as_moving_sequence` " + "`tensorflow_datasets.video.moving_mnist.image_as_moving_sequence` " "for generating training/validation data from the MNIST dataset."), features=tfds.features.FeaturesDict( dict(image_sequence=tfds.features.Video( @@ -55,8 +55,8 @@ def _split_generators(self, dl_manager): "mnist_test_seq.npy") # authors only provide test data. - # See `tfds.video.moving_mnist.moving_sequence` for mapping function to - # create training/validation dataset from MNIST. + # See `tfds.video.moving_mnist.image_as_moving_sequence` for mapping + # function to create training/validation dataset from MNIST. return [ tfds.core.SplitGenerator( name=tfds.Split.TEST, diff --git a/tensorflow_datasets/video/moving_sequence.py b/tensorflow_datasets/video/moving_sequence.py index b1d2092772a..330eaae6366 100644 --- a/tensorflow_datasets/video/moving_sequence.py +++ b/tensorflow_datasets/video/moving_sequence.py @@ -1,18 +1,11 @@ -"""Provides `images_as_moving_sequence`.""" +"""Provides `image_as_moving_sequence`.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import six import tensorflow as tf -import tensorflow_datasets as tfds import collections -_merge_fns = { - "max": lambda x, y: tf.cast( - tf.math.maximum(tf.cast(x, tf.int32), tf.cast(y, tf.int32)), tf.uint8) -} - def _create_moving_sequence(image, pad_lefts, total_padding): """Create a moving image sequence from the given image a left padding values. @@ -43,93 +36,31 @@ def get_padded_image(args): return padded_images -def _create_merged_moving_sequence( - images, sequence_pad_lefts, image_size, total_padding, - background=tf.zeros, merge_fn="max"): +def _get_linear_trajectory(x0, velocity, t): """ Args: - images: [n_images, h, w, n_channels] uint8 array - sequence_pad_lefts: [n_images, sequence_length, 2] int32 array of - left padding values - image_size: TensorShape (out_h, out_w) - total_padding: tensor, images.shape[1:3] - image_size - background: background image, or callable that takes `shape` and `dtype` - args. - merge_fn: "max" for maximum, or callable mapping (seq0, seq1) -> seq, where - each of seq0, seq1 and seq2 aretensors of the same shape/dtype as - the output. + x0: N-D float tensor. + velocity: N-D float tensor + t: [sequence_length]-length float tensor Returns: - [sequence_length, out_h, out_w, n_channels_out] overlayed padded sequence. - n_channels_out defined by background/merge_fn output - (same as n_channels in for default values). 
- """ - if isinstance(merge_fn, six.string_types): - merge_fn = _merge_fns[merge_fn] - images = tf.convert_to_tensor(images, dtype=tf.uint8) - sequence_pad_lefts = tf.convert_to_tensor(sequence_pad_lefts, dtype=tf.int32) - if images.shape.ndims != 4: - raise ValueError("`images` must be a rank 4 tensor") - if sequence_pad_lefts.shape.ndims != 3: - raise ValueError("`sequence_pad_lefts` must be a rank 3 tensor") - if total_padding.shape != (2,): - raise ValueError( - "`total_padding` must be len 2, got %s" - % str(total_padding.as_list())) - - n_channels = images.shape[3] - out_image_shape = ( - [sequence_pad_lefts.shape[1]] + - image_size.as_list() + - [n_channels] - ) - - def fn(seq0, args): - image, pad_lefts = args - seq1 = _create_moving_sequence(image, pad_lefts, total_padding) - seq1.set_shape(out_image_shape) - return merge_fn(seq0, seq1) - - if callable(background): - background = background(out_image_shape, tf.uint8) - - if background.shape[:-1] != out_image_shape[:-1]: - raise ValueError( - "background shape should be %s, got %s" % - (str(background.shape), str(out_image_shape))) - sequence = tf.foldl( - fn, [images, sequence_pad_lefts], - initializer=background, - back_prop=False, - name="merged_moving_sequence") - - return sequence - - -def _get_linear_trajectories(x0, velocity, t): - """ - Args: - x0: [n_trajectories, ndims] float tensor. - velocity: [n_trajectories, ndims] float tensor - t: [sequence_length] float tensor - - Returns: - x: [n_trajectories, sequence_length, ndims] float tensor. + x: [sequence_length, ndims] float tensor. """ x0 = tf.convert_to_tensor(x0) velocity = tf.convert_to_tensor(velocity) - if x0.shape.ndims != 2: - raise ValueError("x0 must be a rank 2 tensor") - if velocity.shape.ndims != 2: - raise ValueError("velocity must be a rank 2 tensor") + t = tf.convert_to_tensor(t) + if x0.shape.ndims != 1: + raise ValueError("x0 must be a rank 1 tensor") + if velocity.shape.ndims != 1: + raise ValueError("velocity must be a rank 1 tensor") if t.shape.ndims != 1: raise ValueError("t must be a rank 1 tensor") - x0 = tf.expand_dims(x0, axis=1) - velocity = tf.expand_dims(velocity, axis=1) - dx = velocity * tf.expand_dims(tf.expand_dims(t, axis=0), axis=-1) + x0 = tf.expand_dims(x0, axis=0) + velocity = tf.expand_dims(velocity, axis=0) + dx = velocity * tf.expand_dims(t, axis=-1) linear_trajectories = x0 + dx - assert linear_trajectories.shape.ndims == 3, \ - "linear_trajectories should be a rank 3 tensor" + assert linear_trajectories.shape.ndims == 2, \ + "linear_trajectories should be a rank 2 tensor" return linear_trajectories @@ -155,44 +86,18 @@ def _bounce_to_bbox(points): return tf.math.minimum(2 - points, points) -def _get_random_velocities(n_velocities, ndims, speed, dtype=tf.float32): - """Get random velocities with given speed. - - Args: - n_velocities: int, number of velocities to generate - ndims: number of dimensions, e.g. 2 for images - speed: scalar speed for each velocity, or rank 1 tensor giving speed for - each generated velocity. - dtype: `tf.DType` of returned tensor - - Returns: - [n_velocities, ndims] tensor where each row has length speed in a random - direction. 
@@ -155,44 +86,18 @@ def _bounce_to_bbox(points):
   return tf.math.minimum(2 - points, points)
 
 
-def _get_random_velocities(n_velocities, ndims, speed, dtype=tf.float32):
-  """Get random velocities with given speed.
-
-  Args:
-    n_velocities: int, number of velocities to generate
-    ndims: number of dimensions, e.g. 2 for images
-    speed: scalar speed for each velocity, or rank 1 tensor giving speed for
-      each generated velocity.
-    dtype: `tf.DType` of returned tensor
-
-  Returns:
-    [n_velocities, ndims] tensor where each row has length speed in a random
-    direction.
-  """
-  velocity = tf.random_normal((n_velocities, ndims), dtype=dtype)
-  speed = tf.convert_to_tensor(speed, dtype=dtype)
-  if speed.shape.ndims == 1:
-    if (
-        speed.shape[0].value not in (1, n_velocities) and
-        isinstance(n_velocities, int)):
-      raise ValueError(
-          "If speed is a rank 1 tensor, its length must be 1 or same as "
-          "`n_trajectories`, got shape %s" % str(speed.shape))
-    speed = tf.expand_dims(speed, axis=-1)
-  velocity = velocity * (
-      speed / tf.linalg.norm(velocity, axis=-1, keepdims=True))
-  return velocity
-
+def _get_random_unit_vector(ndims=2, dtype=tf.float32):
+  x = tf.random_normal((ndims,), dtype=dtype)
+  return x / tf.linalg.norm(x, axis=-1, keepdims=True)
 
 
 MovingSequence = collections.namedtuple(
     "MovingSequence",
-    ["image_sequence", "trajectories", "start_positions", "velocities"])
+    ["image_sequence", "trajectory", "start_position", "velocity"])
 
 
-def images_as_moving_sequence(
-    images, sequence_length=20, output_size=(64, 64),
-    speed=0.1, velocities=None, start_positions=None,
-    background=tf.zeros, merge_fn='max'):
+def image_as_moving_sequence(
+    image, sequence_length=20, output_size=(64, 64), velocity=0.1,
+    start_position=None):
   """Turn simple static images into sequences of the originals bouncing around.
 
   Adapted from Srivastava et al.
@@ -224,14 +129,16 @@ def animate(sequence):
   tf.enable_eager_execution()
   mnist_ds = tfds.load("mnist", split=tfds.Split.TRAIN, as_supervised=True)
-  mnist_ds = mnist_ds.repeat().shuffle(1024).batch(2, drop_remainder=True)
+  mnist_ds = mnist_ds.repeat().shuffle(1024)
 
   def map_fn(image, label):
-    sequence = moving_sequence.images_as_moving_sequence(
-        image, sequence_length=20)
-    return dict(image_sequence=sequence.image_sequence)
+    sequence = moving_sequence.image_as_moving_sequence(
+        image, sequence_length=20)
+    return sequence.image_sequence
+
+  moving_mnist_ds = mnist_ds.map(map_fn).batch(2).map(
+      lambda x: dict(image_sequence=tf.reduce_max(x, axis=0)))
 
-  moving_mnist_ds = mnist_ds.map(map_fn)
   #
   # for comparison with test data provided by original authors
   # moving_mnist_ds = tfds.load("moving_mnist", split=tfds.Split.TEST)
@@ -240,69 +147,67 @@ def map_fn(image, label):
   ```
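  The `batch(2)`/`tf.reduce_max` step above is what turns two independent
  single-digit sequences into one two-digit sequence: an elementwise maximum
  overlays overlapping digits instead of summing them, which could wrap around
  for uint8 pixels. The same merge for two unbatched sequences, as a sketch
  with assumed shapes:

  ```python
  # seq_a, seq_b: hypothetical [20, 64, 64, 1] uint8 outputs of two separate
  # image_as_moving_sequence calls on different digits.
  merged = tf.reduce_max(tf.stack([seq_a, seq_b], axis=0), axis=0)
  ```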
 
   Args:
-    images: [n_images, in_h, in_w, n_channels] uint8 tensor of images.
+    image: [in_h, in_w, n_channels] tensor defining the sub-image to bounce
+      around.
     sequence_length: int, length of sequence.
     output_size: (out_h, out_w) size of returned images.
-    speed: float, length of each step. Scalar, or rank 1 tensor with length
-      n_images. Ignored if velocities is not `None`.
-    velocities: 2D velocity of each image. Randomly generated with speed 0.1
-      if not provided. This is the normalized distance moved each time step
-      by each image, where normalization occurs over the feasible distance the
-      image can move. e.g. if the input image is [10 x 10] and the output image
-      is [60 x 60], a speed of 0.1 means the image moves (60 - 10) * 0.1 = 5
-      pixels per time step.
-    start_positions: [n_images, 2] float32 normalized initial position of each
+    velocity: scalar speed or 2D velocity of image. If scalar, the 2D
+      velocity is randomly generated with this magnitude. This is the
+      normalized distance moved each time step by the sub-image, where
+      normalization occurs over the feasible distance the sub-image can move,
+      e.g. if the input image is [10 x 10] and the output image is [60 x 60],
+      a speed of 0.1 means the sub-image moves (60 - 10) * 0.1 = 5 pixels per
+      time step.
+    start_position: 2D float32 normalized initial position of the
       image in [0, 1]. Randomized uniformly if not given.
-    background: background image, or callable that takes `shape` and `dtype`
-      args.
-    merge_fn: "max" for maximum, or callable mapping (seq0, seq1) -> seq, where
-      each of seq0, seq1 and seq2 are tensors of the same shape/dtype as
-      the output.
 
   Returns:
     `MovingSequence` namedtuple containing:
         `image_sequence`:
-          [sequence_length, out_h, out_w, n_channels_out] uint8.
-          With default arguments for `background`/`merge_fn`,
-          `n_channels_out` is the same as `n_channels`
-        `trajectories`: [sequence_length, n_images, 2] float32 in [0, 1]
-          2D normalized coordinates of each image at every time step.
-        `start_positions`: [n_images, 2] float32 initial positions in [0, 1].
-          2D normalized initial position of each image.
-        `velocities`: [n_images, 2] float32 normalized velocities. Each image
-          moves by this amount (give or take due to pixel rounding) per time
-          step.
+          [sequence_length, out_h, out_w, n_channels] image at each time step.
+          Padded values are all zero. Same dtype as input image.
+        `trajectory`: [sequence_length, 2] float32 in [0, 1].
+          2D normalized coordinates of the image at every time step.
+        `start_position`: 2D float32 initial position in [0, 1].
+          2D normalized initial position of image. Same as input if provided,
+          otherwise the randomly generated value.
+        `velocity`: 2D float32 normalized velocity. Same as input velocity
+          if provided as a 2D tensor, otherwise the random velocity generated.
   """
   ndims = 2
-  images = tf.convert_to_tensor(images, dtype=tf.uint8)
+  image = tf.convert_to_tensor(image)
+  if image.shape.ndims != 3:
+    raise ValueError("image must be rank 3, got %s" % str(image))
   output_size = tf.TensorShape(output_size)
   if len(output_size) != ndims:
     raise ValueError("output_size must have exactly %d elements, got %s"
                      % (ndims, output_size))
-  image_shape = tf.shape(images)
-  n_images = image_shape[0]
-  if start_positions is None:
-    start_positions = tf.random_uniform((n_images, ndims), dtype=tf.float32)
-  elif start_positions.shape.ndims != 2 or start_positions.shape[-1] != ndims:
-    raise ValueError("start_positions must be rank 2 and %d-D" % ndims)
-  if velocities is None:
-    velocities = _get_random_velocities(n_images, ndims, speed, tf.float32)
+  image_shape = tf.shape(image)
+  if start_position is None:
+    start_position = tf.random_uniform((ndims,), dtype=tf.float32)
+  elif start_position.shape != (ndims,):
+    raise ValueError("start_position must be (%d,)" % ndims)
+  velocity = tf.convert_to_tensor(velocity, dtype=tf.float32)
+  if velocity.shape.ndims == 0:
+    velocity = _get_random_unit_vector(ndims, tf.float32) * velocity
+  elif velocity.shape.ndims != 1:
+    raise ValueError("velocity must be rank 0 or rank 1, got %s" % velocity)
 
   t = tf.range(sequence_length, dtype=tf.float32)
-  trajectories = _get_linear_trajectories(start_positions, velocities, t)
-  trajectories = _bounce_to_bbox(trajectories)
+  trajectory = _get_linear_trajectory(start_position, velocity, t)
+  trajectory = _bounce_to_bbox(trajectory)
 
-  total_padding = output_size - image_shape[1:3]
+  total_padding = output_size - image_shape[:2]
   with tf.control_dependencies([tf.assert_non_negative(total_padding)]):
     total_padding = tf.identity(total_padding)
 
   sequence_pad_lefts = tf.cast(
-      tf.math.round(trajectories * tf.cast(total_padding, tf.float32)), tf.int32)
+      tf.math.round(trajectory * tf.cast(total_padding, tf.float32)), tf.int32)
 
-  sequence = _create_merged_moving_sequence(
-      images, sequence_pad_lefts, output_size, total_padding,
-      background=background, merge_fn=merge_fn)
+  sequence = _create_moving_sequence(image, sequence_pad_lefts, total_padding)
+  sequence.set_shape(
+      [sequence_length] + output_size.as_list() + [image.shape[-1]])
 
   return MovingSequence(
       image_sequence=sequence,
-      trajectories=trajectories,
-      start_positions=start_positions,
-      velocities=velocities)
+      trajectory=trajectory,
+      start_position=start_position,
+      velocity=velocity)
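For intuition on the normalized-velocity convention documented in the Args above: the feasible travel is the total padding (output edge minus image edge), so with 28x28 MNIST digits on the default 64x64 canvas a scalar speed of 0.1 moves the digit about 3.6 pixels per step before rounding. The arithmetic, as a plain-Python sketch (sizes assumed, not taken from the patch):

```python
out_edge, in_edge = 64, 28
total_padding = out_edge - in_edge   # 36 pixels of feasible travel per axis
speed = 0.1                          # normalized units per time step
print(total_padding * speed)         # 3.6 pixels per step, rounded per frame
```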
diff --git a/tensorflow_datasets/video/moving_sequence_test.py b/tensorflow_datasets/video/moving_sequence_test.py
index 4c7b500e48d..6245369f509 100644
--- a/tensorflow_datasets/video/moving_sequence_test.py
+++ b/tensorflow_datasets/video/moving_sequence_test.py
@@ -5,10 +5,14 @@
 import numpy as np
 import tensorflow as tf
+from tensorflow_datasets.core import test_utils
 import tensorflow_datasets.video.moving_sequence as ms
 
+test_utils.run_test_in_graph_and_eager_modes = tf.contrib.eager.run_test_in_graph_and_eager_modes
+
+
 class MovingSequenceTest(tf.test.TestCase):
-  @tf.contrib.eager.run_test_in_graph_and_eager_modes()
+  @test_utils.run_test_in_graph_and_eager_modes()
   def test_images_as_moving_sequence(self):
     h, w = (28, 28)
     sequence_length = 8
@@ -21,13 +25,8 @@ def test_images_as_moving_sequence(self):
     out_size = (h + sequence_length, w + sequence_length)
     start_position = tf.constant([0, 0], dtype=tf.float32)
-
-    images = tf.expand_dims(image, axis=0)
-    velocities = tf.expand_dims(velocity, axis=0)
-    start_positions = tf.expand_dims(start_position, axis=0)
-
-    sequence = ms.images_as_moving_sequence(
-        images, start_positions=start_positions, velocities=velocities,
+    sequence = ms.image_as_moving_sequence(
+        image, start_position=start_position, velocity=velocity,
         output_size=out_size, sequence_length=sequence_length)
     sequence = tf.cast(sequence.image_sequence, tf.float32)
 
@@ -46,17 +45,17 @@ def test_dynamic_shape_inputs(self):
     graph = tf.Graph()
     with graph.as_default():
       image_floats = tf.placeholder(
-          shape=(None, None, None, 1), dtype=tf.float32)
-      images = tf.cast(image_floats, tf.uint8)
+          shape=(None, None, 1), dtype=tf.float32)
+      image = tf.cast(image_floats, tf.uint8)
       h, w = 64, 64
       sequence_length = 20
-      sequence = ms.images_as_moving_sequence(
-          images, output_size=(h, w),
+      sequence = ms.image_as_moving_sequence(
+          image, output_size=(h, w),
           sequence_length=sequence_length).image_sequence
 
-    with self.session(graph=graph) as sess:
-      for ni, ih, iw in ((2, 31, 32), (3, 37, 38)):
-        image_vals = np.random.uniform(high=255, size=(ni, ih, iw, 1))
+    with tf.Session(graph=graph) as sess:
+      for ih, iw in ((31, 32), (37, 38)):
+        image_vals = np.random.uniform(high=255, size=(ih, iw, 1))
         out = sess.run(sequence, feed_dict={image_floats: image_vals})
         self.assertAllEqual(out.shape, [sequence_length, h, w, 1])
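The sum-based assertion in `test_images_as_moving_sequence` above leans on a simple invariant: padding contributes only zeros, so every frame has the same total pixel mass as the source image no matter where the digit sits. Checked in isolation (NumPy sketch; shapes and values are made up):

```python
import numpy as np

frames = np.zeros((8, 36, 36), dtype=np.float32)
for i in range(8):                   # translate a 2x2 blob of total mass 10
  frames[i, i:i + 2, 2 * i:2 * i + 2] = 2.5
assert np.allclose(frames.sum(axis=(1, 2)), 10.0)  # mass is translation-invariant
```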
From 05be775ed75ae9bc576bf1c1a0af73d17e09858d Mon Sep 17 00:00:00 2001
From: Dominic Jack
Date: Mon, 4 Feb 2019 15:26:29 +1000
Subject: [PATCH 11/13] merged master changes, fixed tests

---
 tensorflow_datasets/video/moving_sequence.py      | 11 +++++----
 .../video/moving_sequence_test.py                 | 23 +------------------
 2 files changed, 8 insertions(+), 26 deletions(-)

diff --git a/tensorflow_datasets/video/moving_sequence.py b/tensorflow_datasets/video/moving_sequence.py
index 330eaae6366..ba36c51c0a5 100644
--- a/tensorflow_datasets/video/moving_sequence.py
+++ b/tensorflow_datasets/video/moving_sequence.py
@@ -87,7 +87,7 @@ def _bounce_to_bbox(points):
 
 def _get_random_unit_vector(ndims=2, dtype=tf.float32):
-  x = tf.random_normal((ndims,), dtype=dtype)
+  x = tf.random.normal((ndims,), dtype=dtype)
   return x / tf.linalg.norm(x, axis=-1, keepdims=True)
 
 MovingSequence = collections.namedtuple(
@@ -184,7 +184,7 @@ def map_fn(image, label):
                      % (ndims, output_size))
   image_shape = tf.shape(image)
   if start_position is None:
-    start_position = tf.random_uniform((ndims,), dtype=tf.float32)
+    start_position = tf.random.uniform((ndims,), dtype=tf.float32)
   elif start_position.shape != (ndims,):
     raise ValueError("start_position must be (%d,)" % ndims)
   velocity = tf.convert_to_tensor(velocity, dtype=tf.float32)
@@ -197,8 +197,11 @@ def map_fn(image, label):
   trajectory = _bounce_to_bbox(trajectory)
 
   total_padding = output_size - image_shape[:2]
-  with tf.control_dependencies([tf.assert_non_negative(total_padding)]):
-    total_padding = tf.identity(total_padding)
+
+  # cond = tf.assert_greater(total_padding, -1)
+  # if not tf.executing_eagerly():
+  #   with tf.control_dependencies([cond]):
+  #     total_padding = tf.identity(total_padding)
 
   sequence_pad_lefts = tf.cast(
       tf.math.round(trajectory * tf.cast(total_padding, tf.float32)), tf.int32)
diff --git a/tensorflow_datasets/video/moving_sequence_test.py b/tensorflow_datasets/video/moving_sequence_test.py
index 6245369f509..9c46edcc0b7 100644
--- a/tensorflow_datasets/video/moving_sequence_test.py
+++ b/tensorflow_datasets/video/moving_sequence_test.py
@@ -8,11 +8,9 @@
 from tensorflow_datasets.core import test_utils
 import tensorflow_datasets.video.moving_sequence as ms
 
-test_utils.run_test_in_graph_and_eager_modes = tf.contrib.eager.run_test_in_graph_and_eager_modes
-
 
 class MovingSequenceTest(tf.test.TestCase):
-  @test_utils.run_test_in_graph_and_eager_modes()
+  @test_utils.run_in_graph_and_eager_modes()
   def test_images_as_moving_sequence(self):
     h, w = (28, 28)
     sequence_length = 8
@@ -41,25 +39,6 @@ def test_images_as_moving_sequence(self):
     # allow for pixel rounding errors in each dimension
     self.assertAllEqual(n_true >= (h-1)*(w-1), True)
 
-  def test_dynamic_shape_inputs(self):
-    graph = tf.Graph()
-    with graph.as_default():
-      image_floats = tf.placeholder(
-          shape=(None, None, 1), dtype=tf.float32)
-      image = tf.cast(image_floats, tf.uint8)
-      h, w = 64, 64
-      sequence_length = 20
-      sequence = ms.image_as_moving_sequence(
-          image, output_size=(h, w),
-          sequence_length=sequence_length).image_sequence
-
-    with tf.Session(graph=graph) as sess:
-      for ih, iw in ((31, 32), (37, 38)):
-        image_vals = np.random.uniform(high=255, size=(ih, iw, 1))
-        out = sess.run(sequence, feed_dict={image_floats: image_vals})
-        self.assertAllEqual(out.shape, [sequence_length, h, w, 1])
-
-
 if __name__ == '__main__':
   tf.test.main()
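Patch 11 comments the shape check out because the old `tf.control_dependencies` form does nothing useful under eager execution; patch 13 below restores it behind an `executing_eagerly()` guard. The pattern in isolation (a sketch against TF 1.x `compat.v1` APIs; the helper name is illustrative, not from the patch):

```python
import tensorflow as tf

def _checked(total_padding):
  # In eager mode the assert runs (and raises) as soon as it is constructed.
  cond = tf.compat.v1.assert_greater(total_padding, -1)
  if not tf.executing_eagerly():
    # In graph mode the assert op only runs if something depends on it, so
    # thread it through an identity on the value used downstream.
    with tf.control_dependencies([cond]):
      total_padding = tf.identity(total_padding)
  return total_padding
```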
From 2383f641af739ec567bcb9374369b4a95cd6a2e7 Mon Sep 17 00:00:00 2001
From: Ryan Sepassi
Date: Mon, 4 Feb 2019 17:44:50 -0800
Subject: [PATCH 12/13] Small tweaks

---
 tensorflow_datasets/video/moving_mnist.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/tensorflow_datasets/video/moving_mnist.py b/tensorflow_datasets/video/moving_mnist.py
index e03aa3cb60b..4b0cc86a09c 100644
--- a/tensorflow_datasets/video/moving_mnist.py
+++ b/tensorflow_datasets/video/moving_mnist.py
@@ -11,7 +11,8 @@
 _OUT_RESOLUTION = (64, 64)
 _SEQUENCE_LENGTH = 20
 
-_citation = """
+_URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/"
+_CITATION = """\
 @article{DBLP:journals/corr/SrivastavaMS15,
   author    = {Nitish Srivastava and
                Elman Mansimov and
@@ -40,19 +41,17 @@ def _info(self):
         description=(
             "Moving variant of MNIST database of handwritten digits. This is the "
             "data used by the authors for reporting model performance. See "
-            "`tensorflow_datasets.video.moving_mnist.image_as_moving_sequence` "
+            "`tfds.video.moving_mnist.image_as_moving_sequence` "
            "for generating training/validation data from the MNIST dataset."),
         features=tfds.features.FeaturesDict(
             dict(image_sequence=tfds.features.Video(
                 shape=(_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,)))),
-        urls=["http://www.cs.toronto.edu/~nitish/unsupervised_video/"],
-        citation=_citation
+        urls=[_URL],
+        citation=_CITATION,
     )
 
   def _split_generators(self, dl_manager):
-    data_path = dl_manager.download(
-        "http://www.cs.toronto.edu/~nitish/unsupervised_video/"
-        "mnist_test_seq.npy")
+    data_path = dl_manager.download(_URL + "mnist_test_seq.npy")
 
     # authors only provide test data.
     # See `tfds.video.moving_mnist.image_as_moving_sequence` for mapping
@@ -65,7 +64,7 @@ def _split_generators(self, dl_manager):
   def _generate_examples(self, data_path):
-    """Generate MOVING_MNIST sequences.
+    """Generate MovingMnist sequences.
 
     Args:
       data_path (str): Path to the data file.
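Once the builder above is registered, consuming the canonical test split is a one-liner; a usage sketch (eager mode assumed, shape taken from the `Video` feature spec above):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

tf.compat.v1.enable_eager_execution()
ds = tfds.load("moving_mnist", split=tfds.Split.TEST)
for example in ds.take(1):
  video = example["image_sequence"]  # [20, 64, 64, 1] uint8 frames
  print(video.shape)
```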
From df567469ed904e2efb409fa918cc9285fe5a6e03 Mon Sep 17 00:00:00 2001
From: Dominic Jack
Date: Tue, 5 Feb 2019 13:07:24 +1000
Subject: [PATCH 13/13] requested changes

---
 tensorflow_datasets/video/moving_sequence.py      | 8 +++++---
 tensorflow_datasets/video/moving_sequence_test.py | 9 ++++++---
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/tensorflow_datasets/video/moving_sequence.py b/tensorflow_datasets/video/moving_sequence.py
index ba36c51c0a5..18d7582d203 100644
--- a/tensorflow_datasets/video/moving_sequence.py
+++ b/tensorflow_datasets/video/moving_sequence.py
@@ -108,6 +108,7 @@ def image_as_moving_sequence(
   import tensorflow as tf
   import tensorflow_datasets as tfds
   from tensorflow_datasets.video import moving_sequence
+  tf.compat.v1.enable_eager_execution()
 
   def animate(sequence):
     import numpy as np
@@ -199,9 +200,10 @@ def map_fn(image, label):
   total_padding = output_size - image_shape[:2]
 
   # cond = tf.assert_greater(total_padding, -1)
-  # if not tf.executing_eagerly():
-  #   with tf.control_dependencies([cond]):
-  #     total_padding = tf.identity(total_padding)
+  cond = tf.compat.v1.assert_greater(total_padding, -1)
+  if not tf.executing_eagerly():
+    with tf.control_dependencies([cond]):
+      total_padding = tf.identity(total_padding)
 
   sequence_pad_lefts = tf.cast(
       tf.math.round(trajectory * tf.cast(total_padding, tf.float32)), tf.int32)
diff --git a/tensorflow_datasets/video/moving_sequence_test.py b/tensorflow_datasets/video/moving_sequence_test.py
index 9c46edcc0b7..148b14e96c1 100644
--- a/tensorflow_datasets/video/moving_sequence_test.py
+++ b/tensorflow_datasets/video/moving_sequence_test.py
@@ -7,6 +7,7 @@
 import tensorflow as tf
 from tensorflow_datasets.core import test_utils
 import tensorflow_datasets.video.moving_sequence as ms
+tf.compat.v1.enable_eager_execution()
 
 
 class MovingSequenceTest(tf.test.TestCase):
@@ -29,15 +30,17 @@ def test_images_as_moving_sequence(self):
     sequence = tf.cast(sequence.image_sequence, tf.float32)
 
     self.assertAllEqual(
-        tf.reduce_sum(sequence, axis=(1, 2, 3)),
-        tf.fill((sequence_length,), tf.reduce_sum(tf.cast(image, tf.float32))))
+        self.evaluate(tf.reduce_sum(sequence, axis=(1, 2, 3))),
+        self.evaluate(
+            tf.fill(
+                (sequence_length,), tf.reduce_sum(tf.cast(image, tf.float32)))))
 
     for i, full_image in enumerate(tf.unstack(sequence, axis=0)):
       j = i // 2
       subimage = full_image[i:i+h, j:j+w]
       n_true = tf.reduce_sum(subimage)
       # allow for pixel rounding errors in each dimension
-      self.assertAllEqual(n_true >= (h-1)*(w-1), True)
+      self.assertTrue(self.evaluate(n_true) >= (h-1)*(w-1))
 
 
 if __name__ == '__main__':