
Conversation

@asimshankar (Contributor):

Using FixedLengthRecordDataset also provides an opportunity to use the same input pipeline code for the TPU demos (https://github.com/tensorflow/tpu-demos/tree/42a987e/cloud_tpu/models/mnist) without having to convert the raw data to TFRecords (a sketch of the approach follows the change summary below).

- Prior to this change, the use of tf.data.Dataset essentially embedded
  the entire training/evaluation dataset into the graph as a constant,
  leading to unnecessarily humongous graphs (Fixes #3017)
- Also, use batching on the evaluation dataset to allow
  evaluation on GPUs that cannot fit the entire evaluation dataset in
  memory (Fixes #3046)
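
For readers new to the API, here is a minimal sketch of the idea, assuming the standard MNIST IDX file names, record sizes, and headers, and the TF 1.4-era tf.data API; it is an illustration, not the PR's exact code:

import tensorflow as tf

def decode_image(record):
  # Each record is one 28x28 uint8 image; the 16-byte IDX header is
  # skipped via header_bytes below.
  image = tf.decode_raw(record, tf.uint8)
  return tf.cast(tf.reshape(image, [784]), tf.float32) / 255.0

def decode_label(record):
  # Each label is a single byte after an 8-byte IDX header.
  return tf.reshape(tf.to_int32(tf.decode_raw(record, tf.uint8)), [])

images = tf.data.FixedLengthRecordDataset(
    'train-images-idx3-ubyte', record_bytes=784, header_bytes=16).map(decode_image)
labels = tf.data.FixedLengthRecordDataset(
    'train-labels-idx1-ubyte', record_bytes=1, header_bytes=8).map(decode_label)
ds = tf.data.Dataset.zip((images, labels))

Because the records are fixed-size byte strings read straight from the raw files, the same pipeline can feed CPUs, GPUs, and TPUs without a TFRecord conversion step.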
@asimshankar asimshankar requested review from mrry and nealwu January 2, 2018 22:04
@asimshankar asimshankar requested a review from k-w-w as a code owner January 2, 2018 22:04
@mrry (Contributor) left a comment:

Just a couple of documentation nits....



def maybe_download(directory, filename):
  """Download a file from the MNIST dataset, if it doesn't already exist."""
Contributor:

Perhaps mention that this gunzips the file as well?

Contributor Author:

Done

  ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
      FLAGS.train_epochs)
- (images, labels) = dataset.make_one_shot_iterator().get_next()
+ (images, labels) = ds.make_one_shot_iterator().get_next()
Contributor:

While we're in here, would it make sense to switch to the new style of returning a Dataset directly? (Or perhaps, since 1.5 hasn't landed yet, we should have a TODO to make that switch?)

(Same applies to eval_input_fn() below.)
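
For concreteness, the switch being discussed would look roughly like this; the dataset.train() helper is assumed from this file's module (its sibling dataset.test() appears in a hunk below):

# Old style: the input_fn builds a one-shot iterator and returns tensors.
def train_input_fn():
  ds = dataset.train(FLAGS.data_dir).cache().shuffle(
      buffer_size=50000).batch(FLAGS.batch_size).repeat(FLAGS.train_epochs)
  return ds.make_one_shot_iterator().get_next()

# New style (TF 1.5+): return the Dataset itself and let the Estimator
# create the iterator internally.
def train_input_fn():
  return dataset.train(FLAGS.data_dir).cache().shuffle(
      buffer_size=50000).batch(FLAGS.batch_size).repeat(FLAGS.train_epochs)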

Contributor Author:

Yup, was waiting for 1.5 to land.
Tempted to avoid TODOs in these "best practices" samples, unless you feel strongly.

Contributor:

Fair enough!

@k-w-w (Contributor) commented Jan 2, 2018:

@mrry Does Dataset.cache() cache all the examples into memory (which would defeat the purpose of trying to use less memory)?

@asimshankar (Contributor Author):

@k-w-w - It caches them in CPU memory, not GPU memory. (That said, we could remove the use of cache() here).

@k-w-w (Contributor) commented Jan 2, 2018:

@asimshankar ohh, I see. Thanks for the response!

@mrry (Contributor) commented Jan 2, 2018:

Even with caching it will tend to use less memory, because the payloads of tf.constant() ops end up being stored multiple times in RAM (something that would be good to fix independently...). While you could disable caching, that moves the reading and parsing onto the critical path, and there's a good chance it would make the training process I/O-bound.
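
To illustrate the trade-off being described, assuming the pipeline shown above:

# In-memory cache: after the first pass, parsed examples are served from
# host (CPU) RAM, keeping file I/O and parsing off the critical path.
ds = ds.cache()

# File-backed alternative if host RAM is tight (path is illustrative):
# ds = ds.cache(filename='/tmp/mnist_cache')

Note that cache() comes before shuffle(), so the cached contents are fixed while the shuffle order still changes between epochs.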

@k-w-w (Contributor) left a comment:

Thanks for addressing these issues!

@k-w-w (Contributor) commented Jan 2, 2018:

@mrry good to know, thanks! (looking forward to the dataset performance guide to learn more)

@mrry (Contributor) commented Jan 2, 2018:

@k-w-w On that topic, you can get a sneak peek at the guide here:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/performance/datasets_performance.md

It should be on tensorflow.org once the 1.5 release is published.
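
For a flavor of what the guide covers, its headline advice is to overlap preprocessing with training, e.g. (values illustrative, not tuned; decode_image as in the sketch above):

# Parallelize the decode step and prefetch ahead of the training loop.
ds = ds.map(decode_image, num_parallel_calls=4).prefetch(buffer_size=1)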

@nealwu (Contributor) left a comment:

Looks good! Just a few comments.

@@ -0,0 +1,112 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Contributor:

2018! Nice :)

f.name))


def maybe_download(directory, filename):
Contributor:

nit: I don't like the name maybe_download since the 'maybe' part seems very ambiguous. Let's call this attempt_download or just download instead?

Contributor Author:

Renamed to download.

  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  zipped_filename = filename + '.gz'
  zipped_filepath = os.path.join(directory, zipped_filename)
  tf.contrib.learn.datasets.base.maybe_download(zipped_filename, directory, url)
Contributor:

Is there a way we can do this without contrib?

Contributor Author:

Changed to use urllib.request.urlretrieve. This means that the retry logic implemented in
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/base.py#L189 no longer applies, but I suspect that this retry business is less relevant now that we're using a CVDF mirror.

Contributor:

Sounds good!
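
Putting the thread's resolution together, a self-contained sketch of a contrib-free download helper (Python 3 shown for brevity; the merged code may differ, e.g. by using six for Python 2 compatibility):

import gzip
import os
import shutil
import urllib.request

def download(directory, filename):
  """Download (and gunzip) a file from the MNIST dataset, if needed."""
  filepath = os.path.join(directory, filename)
  if os.path.exists(filepath):
    return filepath
  if not os.path.exists(directory):
    os.makedirs(directory)
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  zipped_filepath = filepath + '.gz'
  # Unlike tf.contrib.learn.datasets.base.maybe_download, no retry logic.
  urllib.request.urlretrieve(url, zipped_filepath)
  with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)
  os.remove(zipped_filepath)
  return filepath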

  def eval_input_fn():
-   return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
+   return dataset.test(FLAGS.data_dir).batch(
+       FLAGS.batch_size).make_one_shot_iterator().get_next()
Contributor:

nit: I don't think this is the right code style. Can we split this onto two lines instead?

Contributor Author:

This was the result of running the Python formatter (https://github.com/google/yapf), so it should be right? :)

Contributor:

That's surprising; the part that seemed weird to me was indenting the arguments to the next line and then following up with more function calls. If the formatter says it's good though, should be fine.

@nealwu nealwu changed the title from "[mnist]: Use FixedLengthRecordDatatest" to "[mnist]: Use FixedLengthRecordDataset" Jan 2, 2018
@asimshankar asimshankar merged commit 8e4a1e2 into tensorflow:master Jan 3, 2018
@asimshankar asimshankar deleted the mnist branch January 3, 2018 01:46
Adrrei pushed a commit to Adrrei/models that referenced this pull request Dec 16, 2018
[mnist]: Use FixedLengthRecordDataset