[mnist]: Use FixedLengthRecordDataset #3093
Conversation
- Prior to this change, the use of tf.data.Dataset essentially embedded the entire training/evaluation dataset into the graph as a constant, leading to unnecessarily humongous graphs (Fixes #3017)
- Also, use batching on the evaluation dataset to allow evaluation on GPUs that cannot fit the entire evaluation dataset in memory (Fixes #3046)
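The core idea is that a FixedLengthRecordDataset streams fixed-size binary records from a file on disk instead of embedding the decoded arrays in the graph as constants. As a rough illustration of that record-reading behavior, here is a plain-Python sketch (not the TF API itself; the function name and the fake file layout are made up for the example):

```python
import tempfile

def read_fixed_length_records(path, record_bytes, header_bytes=0):
    """Yield fixed-size binary records from a file, skipping an optional header.

    Mimics, in plain Python, what tf.data.FixedLengthRecordDataset does: the
    data stays on disk and records are streamed one at a time, so nothing has
    to be baked into a graph as a constant.
    """
    with open(path, 'rb') as f:
        f.seek(header_bytes)  # skip the file header (MNIST files have one)
        while True:
            record = f.read(record_bytes)
            if len(record) < record_bytes:
                return  # end of file (or a trailing partial record)
            yield record

# Fake file: a 4-byte header followed by three 2-byte records.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b'HEAD' + b'ab' + b'cd' + b'ef')
    path = f.name

records = list(read_fixed_length_records(path, record_bytes=2, header_bytes=4))
# records == [b'ab', b'cd', b'ef']
```

In the real pipeline, each record would then be decoded (e.g. reshaped into a 28x28 image) by a map transformation.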
mrry left a comment
Just a couple of documentation nits....
official/mnist/dataset.py (Outdated)

    def maybe_download(directory, filename):
      """Download a file from the MNIST dataset, if it doesn't already exist."""
Perhaps mention that this gunzips the file as well?
Done
      ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
          FLAGS.train_epochs)
    - (images, labels) = dataset.make_one_shot_iterator().get_next()
    + (images, labels) = ds.make_one_shot_iterator().get_next()
While we're in here, would it make sense to switch to the new style of returning a Dataset directly? (Or perhaps, since 1.5 hasn't landed yet, we should have a TODO to make that switch?)
(Same applies to eval_input_fn() below.)
Yup, was waiting for 1.5 to land.
Tempted to avoid TODOs in these "best practices" samples, unless you feel strongly.
Fair enough!
@mrry Does …

@k-w-w - It caches them in CPU memory, not GPU memory. (That said, we could remove the use of …)

@asimshankar ohh, I see. Thanks for the response!
Even with caching it will tend to use less memory, because the payloads of …
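The cache-in-CPU-memory point above can be illustrated without TensorFlow. This is a toy sketch of the semantics of `Dataset.cache()` (the class and names are invented for the example, not the TF API): the first pass pulls elements from the source and stores them; later passes replay from memory without re-reading the source.

```python
class CachingPipeline:
    """Toy illustration of tf.data cache() semantics: the first iteration
    materializes elements from the (possibly expensive) source into memory;
    every later iteration replays from that in-memory cache."""

    def __init__(self, source_fn):
        self._source_fn = source_fn  # zero-arg callable producing an iterable
        self._cache = None

    def __iter__(self):
        if self._cache is None:
            # First epoch: read from the source and fill the cache.
            self._cache = list(self._source_fn())
        return iter(self._cache)

reads = []

def expensive_source():
    for i in range(3):
        reads.append(i)  # record every time the raw source is actually read
        yield i * i

ds = CachingPipeline(expensive_source)
first_epoch = list(ds)   # source is read here
second_epoch = list(ds)  # served entirely from the in-memory cache
# first_epoch == second_epoch == [0, 1, 4]; reads == [0, 1, 2] (source read once)
```

The memory lives on the host (CPU), which is why caching the raw records can still be cheaper than embedding the fully decoded dataset in the graph.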
k-w-w left a comment
Thanks for addressing these issues!
@mrry good to know, thanks! (looking forward to the dataset performance guide to learn more)

@k-w-w On that topic, you can get a sneak peek at the guide here: It should be on tensorflow.org once the 1.5 release is published.
nealwu left a comment
Looks good! Just a few comments.
    @@ -0,0 +1,112 @@
    # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2018! Nice :)
official/mnist/dataset.py (Outdated)

          f.name))


    def maybe_download(directory, filename):
nit: I don't like the name maybe_download since the 'maybe' part seems very ambiguous. Let's call this attempt_download or just download instead?
Renamed to download.
official/mnist/dataset.py (Outdated)

    url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
    zipped_filename = filename + '.gz'
    zipped_filepath = os.path.join(directory, zipped_filename)
    tf.contrib.learn.datasets.base.maybe_download(zipped_filename, directory, url)
Is there a way we can do this without contrib?
Changed to use urllib.request.urlretrieve. This means that the retry logic implemented in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/base.py#L189 no longer applies, but I suspect that this retry business is less relevant now that we're using a CVDF mirror.
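A contrib-free download-and-gunzip helper along these lines can be written with only the standard library. This is a hedged sketch of the approach, not the exact PR code (the function name, `base_url` parameter, and file layout are assumptions); the gunzip step is exercised below on a locally created file so no network access is needed:

```python
import gzip
import os
import shutil
import tempfile
import urllib.request

def download(directory, filename, base_url):
    """Illustrative contrib-free replacement for
    tf.contrib.learn.datasets.base.maybe_download: fetch <filename>.gz with
    urllib (no retry logic) and gunzip it next to the compressed file."""
    filepath = os.path.join(directory, filename)
    if os.path.exists(filepath):
        return filepath  # already downloaded and decompressed
    os.makedirs(directory, exist_ok=True)
    zipped_filepath = filepath + '.gz'
    urllib.request.urlretrieve(base_url + filename + '.gz', zipped_filepath)
    # Decompress in a streaming fashion, then drop the .gz file.
    with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.remove(zipped_filepath)
    return filepath

# The gunzip step can be checked without any network access:
tmpdir = tempfile.mkdtemp()
gz_path = os.path.join(tmpdir, 'sample.gz')
with gzip.open(gz_path, 'wb') as f:
    f.write(b'fixed-length records')
with gzip.open(gz_path, 'rb') as f:
    data = f.read()
# data == b'fixed-length records'
```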
Sounds good!
      def eval_input_fn():
    -   return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
    +   return dataset.test(FLAGS.data_dir).batch(
    +       FLAGS.batch_size).make_one_shot_iterator().get_next()
nit: I don't think this is the right code style. Can we split this onto two lines instead?
This was the result of running the Python formatter (https://github.com/google/yapf), so it should be right? :)
That's surprising; the part that seemed weird to me was indenting the arguments to the next line and then following up with more function calls. If the formatter says it's good though, should be fine.
[mnist]: Use FixedLengthRecordDataset

- Prior to this change, the use of tf.data.Dataset essentially embedded
  the entire training/evaluation dataset into the graph as a constant,
  leading to unnecessarily humongous graphs
  (Fixes #3017: Official MNIST example should avoid huge constants)
- Also, use batching on the evaluation dataset to allow evaluation on
  GPUs that cannot fit the entire evaluation dataset in memory
  (Fixes #3046: the mnist under official model has OOM issue)

Using FixedLengthRecordDataset also provides an opportunity to use the same input pipeline code for the TPU demos (https://github.com/tensorflow/tpu-demos/tree/42a987e/cloud_tpu/models/mnist) without having to convert the raw data to TFRecords.