
Add MovingMNIST dataset #28

Merged 14 commits into tensorflow:master on Feb 7, 2019

Conversation

jackd
Contributor

@jackd jackd commented Jan 28, 2019

moving_sequence doesn't actually provide a dataset per se, but it's obviously strongly linked to moving_mnist. Not sure if that's precisely what this repo is intended for, but I had fun writing it (and watching bouncing shoes/coats from fashion_mnist was interesting).

@googlebot googlebot added the cla: yes Author has signed CLA label Jan 28, 2019
@jackd jackd mentioned this pull request Jan 28, 2019
@Conchylicultor (Member) left a comment

Thanks for this contribution. Can you explain what the issue was with tfds.features.Video?

tensorflow_datasets/video/moving_mnist.py
# sequence = tfds.features.Image(shape=shape)

# as video - doesn't work with 1 as final dim?
# sequence = tfds.features.Video(shape=shape)
Member

Video should work:

tfds.features.Video(shape=(seq_length, height, width, 1))

If not, this is a bug on our end. Which error are you seeing?

@rsepassi (Contributor) left a comment

Thanks for adding this @jackd! And yes, let's try to get Video to work.

tensorflow_datasets/video/__init__.py
tensorflow_datasets/video/moving_mnist.py
@rsepassi rsepassi changed the title added moving_mnist test data and mapping fns for other splits Add MovingMNIST dataset Jan 28, 2019
jackd commented Jan 29, 2019

Well, that was a fun rabbit hole to dive down. Found/fixed a bug in numpy related to squeeze on lists: it was ignoring the axis argument, being a little overzealous, and squeezing out the final axis of the (20, 64, 64, 1) sequence, resulting in shape errors down the line.

That being said, the fixed implementation of squeeze should make the existing implementation of np_to_list (in this repo) raise an error. As far as I can tell, the old np_to_list combined with the bugged squeeze was effectively the identity function. Fixed np_to_list here, but maybe this isn't the best place to issue such a change?
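For reference, a minimal numpy-only sketch of the fixed behaviour (a hypothetical standalone version of this repo's np_to_list, shown on a (20, 64, 64, 1) sequence):

```python
import numpy as np

def np_to_list(elem):
    # Split along the first axis, then squeeze only that axis from
    # each piece; the trailing channel dim of size 1 must survive.
    if isinstance(elem, np.ndarray):
        return [np.squeeze(e, axis=0) for e in np.split(elem, elem.shape[0])]
    return list(elem)

video = np.zeros((20, 64, 64, 1), dtype=np.uint8)
frames = np_to_list(video)
print(len(frames), frames[0].shape)  # 20 (64, 64, 1)
```

With the buggy squeeze, the size-1 channel axis was also dropped, producing (64, 64) frames and the shape errors described above.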

@@ -61,6 +62,8 @@ def __init__(self, shape):
raise ValueError('Video shape should be of rank 4')
if shape.count(None) > 1:
raise ValueError('Video shape cannot have more than 1 unknown dim')
if shape[-1] not in (1, 3):
Contributor Author

Done, but does this mean videos with 0 or 2 channels are also accepted? Asking in the interest of keeping the documentation up to date (it had me confused initially when it said channels had to be 3).

tensorflow_datasets/video/moving_mnist.py
@@ -260,7 +260,7 @@ def np_to_list(elem):
return elem
elif isinstance(elem, np.ndarray):
elem = np.split(elem, elem.shape[0])
elem = np.squeeze(elem, axis=0)
elem = [np.squeeze(e, axis=0) for e in elem]
Member

Nice catch. Thanks for fixing this.

Contributor Author

Had to go away and come back before realizing there was a much better way of doing this, unless I'm failing to appreciate some corner cases (elements won't automatically be converted to np arrays... but given the name, I'm guessing that shouldn't be relevant?).

@rsepassi (Contributor) left a comment

The dataset looks good, but I'm not sure about accepting the moving_sequence module into TFDS. The videos are usable in an ML model as-is; what is moving_sequence useful for?

jackd commented Feb 1, 2019

The moving_mnist dataset is the test dataset used/provided by the authors. They also provided code to dynamically produce their training/validation data. moving_sequence is a port of that to tensorflow with some generalizations (customizable number of images, speed, base dataset, output size) and a simplified bouncing mechanism. The result is visually indistinguishable: the original code clipped movement past a boundary to the boundary before reversing direction, whereas this implementation reflects it directly, which gives a much simpler implementation that doesn't loop over all time steps. Happy to elaborate more.
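The reflection trick can be sketched in a few lines (a numpy stand-in; `bounce` is a hypothetical name, and this assumes linearly advancing positions are first folded mod 2 before being mirrored):

```python
import numpy as np

def bounce(positions):
    # Fold unbounded linear positions into [0, 1] by reflecting at the
    # boundaries: reduce mod 2, then mirror values in (1, 2] back down.
    p = np.mod(positions, 2.0)
    return np.minimum(2.0 - p, p)

# A trajectory that would leave [0, 1] gets reflected instead of clipped.
t = np.linspace(0.0, 3.0, 7)  # 0.0, 0.5, ..., 3.0
print(bounce(t))              # rises to 1.0, falls back to 0.0, rises again
```

Because the fold is a pure elementwise function of time, the whole trajectory can be computed in one shot, with no per-timestep loop.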

In the interest of reproducibility I think it's appropriate to package them together. Anyone wanting to test on this dataset will, presumably, want a similar dataset to train with, and accessing it from the same source as the test data makes sense to me. How one packages it is another question, and I'm open to suggestions.

My original implementation of moving_mnist overrode as_dataset for train/validation splits to return a moving_sequence-mapped mnist dataset. This makes the interface uniform with other datasets, but it also raised other questions.

  • What's an epoch?
  • Should the base mnist dataset be shuffled? If so, how do we discourage users from shuffling the output dataset?
  • Should both train and test sets from mnist be used in the train moving_mnist dataset?

Someone could just make those decisions (or moving_mnist could have a very large set of builder_kwargs for the train/validation case), but I feel that if users already understand the Dataset interface it would be easier to give them the lower-level tools and examples of how to use them. We could always do both.

A lot of the questions above apply to dynamically generated datasets in general (or at least those involving non-trivial mapping operations). Does tfds have a policy on these? Most of the infrastructure seems tailored to datasets on file, but from my understanding of the project's goals and interface I'm not sure that's a hard limit on its scope.

tensorflow_datasets/video/moving_mnist.py (3 review threads)
tensorflow_datasets/video/moving_sequence.py (7 review threads)
rsepassi commented Feb 3, 2019

Thank you for explaining!

I think you made the right call to only include the test data and to include this function here so that users can create the moving sequences themselves from the MNIST dataset.

My main request is that we limit the surface area of the new module to just the one key method, add a test for images_as_moving_sequence on some dummy data, and use the test_utils.run_in_graph_and_eager_modes decorator.

jackd commented Feb 4, 2019

No disagreement here. I've added a few more optional kwargs to the main function in place of other publicly visible functions and adjusted for dynamic sizing, but the implementation remains very much the same.

jackd commented Feb 4, 2019

... lunch-time brought the revelation that I'm overcomplicating this by trying to handle multiple images in one shot, and that removing the foldl entirely is likely much easier. Will make the changes and fix test errors now...

rsepassi commented Feb 4, 2019 via email

jackd commented Feb 4, 2019

Finally, an excuse to play with tf 2.0 :). Merged in master changes and updated to single-image implementation.

@rsepassi (Contributor) left a comment

Ok, just a few little things in the test and then we're good to go!

jackd commented Feb 5, 2019

Ack, just found a tf 2.0 bug with assert_greater. Will put in the pull request and link when done, but until it gets merged the compat.v1 workaround might have to stay.

Update: I have no idea how the assert_xxx_v2 ops are supposed to be used in graph mode. None of them have return statements, and the omission looks intentional: the Returns section of the documentation has also been removed.

@rsepassi (Contributor) left a comment

LGTM, thanks! Will merge after I've verified on my end that the dataset generates ok.

@tfds-copybara tfds-copybara merged commit df56746 into tensorflow:master Feb 7, 2019
tfds-copybara pushed a commit that referenced this pull request Feb 7, 2019
PiperOrigin-RevId: 232784325
@jackd jackd deleted the moving_mnist branch February 7, 2019 23:20
return tf.math.minimum(2 - points, points)


def _get_random_unit_vector(ndims=2, dtype=tf.float32):


The current code generates the direction from a normal distribution. I believe this is not what the original code does, which generates the direction uniformly.

Please see:

line 245 of data_handlers.py in http://www.cs.toronto.edu/~nitish/unsupervised_video/unsup_video_lstm.tar.gz, where the direction is sampled uniformly from 0 to 2*pi.

Is that correct or am I missing something?

Contributor Author

Glad to see someone's actually reading the code! The original takes a unit vector with angle sampled uniformly. Here we take random normal coordinates and normalize them; the resulting distributions are equivalent, by the rotational symmetry of the isotropic normal. See e.g. alternative method 1, or convince yourself with the code below.

```python
import numpy as np
import matplotlib.pyplot as plt

n = 100000
x, y = np.random.normal(size=(2, n))
angle = np.arctan2(y, x)
plt.hist(angle)
plt.show()
```
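The same point as a sampler rather than a histogram (a sketch; `random_unit_vector` is a hypothetical numpy analogue of the PR's `_get_random_unit_vector`):

```python
import numpy as np

def random_unit_vector(ndims=2, rng=None):
    # Normalized i.i.d. normals are uniform on the unit sphere, since
    # the isotropic normal density depends only on the radius.
    rng = np.random.default_rng() if rng is None else rng
    v = rng.normal(size=ndims)
    return v / np.linalg.norm(v)

v = random_unit_vector()
print(np.linalg.norm(v))  # 1.0 up to float rounding
```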


I see, cool thanks!

Labels
cla: yes Author has signed CLA

6 participants