-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MovingMNIST dataset #28
Changes from 5 commits
a55bf80
3e0940e
5f0a8c7
e312617
8d3f388
e3615c4
3a5fe00
38cd512
faea0ee
c4164b2
9d8b1d0
05be775
2383f64
df56746
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,10 +29,11 @@ class Video(sequence_feature.Sequence): | |
"""`FeatureConnector` for videos, png-encoding frames on disk. | ||
|
||
Video: The image connector accepts as input: | ||
* uint8 array representing an video. | ||
* uint8 array representing a video. | ||
|
||
Output: | ||
video: tf.Tensor of type tf.uint8 and shape [num_frames, height, width, 3] | ||
video: tf.Tensor of type tf.uint8 and shape | ||
[num_frames, height, width, channels], where channels must be 1 or 3 | ||
|
||
Example: | ||
* In the DatasetInfo object: | ||
|
@@ -51,7 +52,7 @@ def __init__(self, shape): | |
|
||
Args: | ||
shape: tuple of ints, the shape of the video (num_frames, height, width, | ||
channels=3). | ||
channels), where channels is 1 or 3. | ||
|
||
Raises: | ||
ValueError: If the shape is invalid | ||
|
@@ -61,6 +62,8 @@ def __init__(self, shape): | |
raise ValueError('Video shape should be of rank 4') | ||
if shape.count(None) > 1: | ||
raise ValueError('Video shape cannot have more than 1 unknown dim') | ||
if shape[-1] not in (1, 3): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove. The test is already in Image() called below: https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/core/features/image_feature.py#L108 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done, but does this mean videos with 0 or 2 channels are also accepted? In the interest of keeping documentation up to date (had me confused initially when it said channels had to be 3). |
||
raise ValueError('Video channels must be 1 or 3, got %d' % shape[-1]) | ||
|
||
super(Video, self).__init__( | ||
image_feature.Image(shape=shape[1:], encoding_format='png'), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import collections | ||
import functools | ||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_datasets.public_api as tfds | ||
|
||
_OUT_RESOLUTION = (64, 64) | ||
_SEQUENCE_LENGTH = 20 | ||
_citation = """ | ||
@article{DBLP:journals/corr/SrivastavaMS15, | ||
author = {Nitish Srivastava and | ||
Elman Mansimov and | ||
Ruslan Salakhutdinov}, | ||
title = {Unsupervised Learning of Video Representations using LSTMs}, | ||
journal = {CoRR}, | ||
volume = {abs/1502.04681}, | ||
year = {2015}, | ||
url = {http://arxiv.org/abs/1502.04681}, | ||
archivePrefix = {arXiv}, | ||
eprint = {1502.04681}, | ||
timestamp = {Mon, 13 Aug 2018 16:47:05 +0200}, | ||
biburl = {https://dblp.org/rec/bib/journals/corr/SrivastavaMS15}, | ||
bibsource = {dblp computer science bibliography, https://dblp.org} | ||
} | ||
""" | ||
|
||
|
||
class MovingMnist(tfds.core.GeneratorBasedBuilder): | ||
|
||
VERSION = tfds.core.Version("0.1.0") | ||
|
||
def _info(self): | ||
return tfds.core.DatasetInfo( | ||
builder=self, | ||
description=( | ||
"Moving variant of MNIST database of handwritten digits. This is the " | ||
"data used by the authors for reporting model performance. See " | ||
"`tfds.video.moving_sequence` for functions to generate training/" | ||
jackd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"validation data."), | ||
jackd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
features=tfds.features.FeaturesDict( | ||
dict(image_sequence=tfds.features.Video( | ||
shape=(_SEQUENCE_LENGTH,) + _OUT_RESOLUTION + (1,)))), | ||
urls=["http://www.cs.toronto.edu/~nitish/unsupervised_video/"], | ||
citation=_citation, | ||
splits=[tfds.Split.TEST] | ||
jackd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
|
||
def _split_generators(self, dl_manager): | ||
data_path = dl_manager.download( | ||
"http://www.cs.toronto.edu/~nitish/unsupervised_video/" | ||
"mnist_test_seq.npy") | ||
|
||
# authors only provide test data. See `tfds.video.moving_sequence` for | ||
jackd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# approach based on creating sequences from existing datasets | ||
return [ | ||
tfds.core.SplitGenerator( | ||
name=tfds.Split.TEST, | ||
num_shards=5, | ||
gen_kwargs=dict(data_path=data_path)), | ||
] | ||
|
||
def _generate_examples(self, data_path): | ||
"""Generate MOVING_MNIST sequences. | ||
|
||
Args: | ||
data_path (str): Path to the data file | ||
|
||
Yields: | ||
20 x 64 x 64 x 1 uint8 numpy arrays | ||
""" | ||
with tf.io.gfile.GFile(data_path, "rb") as fp: | ||
images = np.load(fp) | ||
images = np.transpose(images, (1, 0, 2, 3)) | ||
images = np.expand_dims(images, axis=-1) | ||
for sequence in images: | ||
yield dict(image_sequence=sequence) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch. Thanks for fixing this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had to go away and come back before realizing there was a much better way of doing this, unless I'm failing to appreciate some corner cases (elements won't automatically be converted to np arrays... but given the name, I'm guessing that shouldn't be relevant?)