Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MovingMNIST dataset #28

Merged
merged 14 commits into from
Feb 7, 2019
2 changes: 1 addition & 1 deletion tensorflow_datasets/core/features/sequence_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def np_to_list(elem):
return elem
elif isinstance(elem, np.ndarray):
elem = np.split(elem, elem.shape[0])
elem = np.squeeze(elem, axis=0)
elem = [np.squeeze(e, axis=0) for e in elem]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch. Thanks for fixing this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to go away and come back before realizing there was a much better way of doing this, unless I'm failing to appreciate some corner cases (elements won't automatically be converted to np arrays — but given the name, I'm guessing that shouldn't be relevant?).

return elem
else:
raise ValueError(
Expand Down
9 changes: 6 additions & 3 deletions tensorflow_datasets/core/features/video_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ class Video(sequence_feature.Sequence):
"""`FeatureConnector` for videos, png-encoding frames on disk.

Video: The image connector accepts as input:
* uint8 array representing an video.
* uint8 array representing a video.

Output:
video: tf.Tensor of type tf.uint8 and shape [num_frames, height, width, 3]
video: tf.Tensor of type tf.uint8 and shape
[num_frames, height, width, channels], where channels must be 1 or 3

Example:
* In the DatasetInfo object:
Expand All @@ -51,7 +52,7 @@ def __init__(self, shape):

Args:
shape: tuple of ints, the shape of the video (num_frames, height, width,
channels=3).
channels), where channels is 1 or 3.

Raises:
ValueError: If the shape is invalid
Expand All @@ -61,6 +62,8 @@ def __init__(self, shape):
raise ValueError('Video shape should be of rank 4')
if shape.count(None) > 1:
raise ValueError('Video shape cannot have more than 1 unknown dim')
if shape[-1] not in (1, 3):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, but does this mean videos with 0 or 2 channels are also accepted? In the interest of keeping documentation up to date (had me confused initially when it said channels had to be 3).

raise ValueError('Video channels must be 1 or 3, got %d' % shape[-1])

super(Video, self).__init__(
image_feature.Image(shape=shape[1:], encoding_format='png'),
Expand Down
1 change: 1 addition & 0 deletions tensorflow_datasets/video/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
"""Video datasets."""

from tensorflow_datasets.video.bair_robot_pushing import BairRobotPushingSmall
from tensorflow_datasets.video.moving_mnist import MovingMnist
from tensorflow_datasets.video.starcraft import StarcraftVideo
from tensorflow_datasets.video.starcraft import StarcraftVideoConfig
80 changes: 80 additions & 0 deletions tensorflow_datasets/video/moving_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import numpy as np
import tensorflow as tf
import tensorflow_datasets.public_api as tfds

_OUT_RESOLUTION = (64, 64)
_SEQUENCE_LENGTH = 20
_citation = """
@article{DBLP:journals/corr/SrivastavaMS15,
author = {Nitish Srivastava and
Elman Mansimov and
Ruslan Salakhutdinov},
title = {Unsupervised Learning of Video Representations using LSTMs},
journal = {CoRR},
volume = {abs/1502.04681},
year = {2015},
url = {http://arxiv.org/abs/1502.04681},
archivePrefix = {arXiv},
eprint = {1502.04681},
timestamp = {Mon, 13 Aug 2018 16:47:05 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/SrivastavaMS15},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""


class MovingMnist(tfds.core.GeneratorBasedBuilder):
  """Moving MNIST: videos of handwritten digits moving on a 64x64 canvas.

  The dataset authors only release an evaluation (test) split of
  20-frame, 64x64, single-channel sequences. See
  `tfds.video.moving_sequence` for functions to generate
  training/validation data from existing image datasets.
  """

  VERSION = tfds.core.Version("0.1.0")

  def _info(self):
    """Returns the `tfds.core.DatasetInfo` describing this dataset."""
    return tfds.core.DatasetInfo(
        builder=self,
        description=(
            "Moving variant of MNIST database of handwritten digits. This is "
            "the data used by the authors for reporting model performance. "
            "See `tfds.video.moving_sequence` for functions to generate "
            "training/validation data."),
        features=tfds.features.FeaturesDict(
            # Single-channel video: (num_frames, height, width, 1).
            dict(image_sequence=tfds.features.Video(
                shape=(_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,)))),
        urls=["http://www.cs.toronto.edu/~nitish/unsupervised_video/"],
        citation=_citation,
        splits=[tfds.Split.TEST])

  def _split_generators(self, dl_manager):
    """Downloads the data and defines the splits.

    Args:
      dl_manager: `tfds.download.DownloadManager` used to fetch the
        single `.npy` file released by the authors.

    Returns:
      List with a single `tfds.core.SplitGenerator` for the TEST split.
    """
    data_path = dl_manager.download(
        "http://www.cs.toronto.edu/~nitish/unsupervised_video/"
        "mnist_test_seq.npy")

    # The authors only provide test data. See `tfds.video.moving_sequence`
    # for an approach based on creating sequences from existing datasets.
    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TEST,
            num_shards=5,
            gen_kwargs=dict(data_path=data_path)),
    ]

  def _generate_examples(self, data_path):
    """Generates MOVING_MNIST sequences.

    Args:
      data_path (str): Path to the downloaded `.npy` data file.

    Yields:
      Dicts with key `image_sequence` mapping to 20 x 64 x 64 x 1 uint8
      numpy arrays.
    """
    with tf.io.gfile.GFile(data_path, "rb") as fp:
      images = np.load(fp)
    # Raw file is stored (time, sequence, h, w); reorder so that the
    # leading axis indexes sequences, then add a trailing channel dim.
    images = np.transpose(images, (1, 0, 2, 3))
    images = np.expand_dims(images, axis=-1)
    for sequence in images:
      yield dict(image_sequence=sequence)