-
Notifications
You must be signed in to change notification settings - Fork 45.4k
[mnist]: Use FixedLengthRecordDataset #3093
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """tf.data.Dataset interface to the MNIST dataset.""" | ||
|
|
||
| from __future__ import absolute_import | ||
| from __future__ import division | ||
| from __future__ import print_function | ||
|
|
||
| import os | ||
| import shutil | ||
| import gzip | ||
| import numpy as np | ||
| from six.moves import urllib | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
def read32(bytestream):
  """Decode the next 4 bytes of `bytestream` as a big-endian uint32."""
  raw = bytestream.read(4)
  big_endian_u32 = np.dtype(np.uint32).newbyteorder('>')
  return np.frombuffer(raw, dtype=big_endian_u32)[0]
|
|
||
|
|
||
def check_image_file_header(filename):
  """Validate that filename corresponds to images for the MNIST dataset.

  Args:
    filename: path to an unzipped IDX image file (train-images-idx3-ubyte
      style), as produced by `download`.

  Raises:
    ValueError: if the magic number is not 2051 or the images are not 28x28.
  """
  # The IDX header is binary: open in 'rb' so f.read(4) yields bytes.
  # Text mode (the previous behavior) breaks under Python 3 — the decoder
  # raises/garbles arbitrary bytes and np.frombuffer requires a bytes input.
  with open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # num_images: unused, but the 4 header bytes must be consumed.
    rows = read32(f)
    cols = read32(f)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
    if rows != 28 or cols != 28:
      raise ValueError(
          'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
          (f.name, rows, cols))
|
|
||
|
|
||
def check_labels_file_header(filename):
  """Validate that filename corresponds to labels for the MNIST dataset.

  Args:
    filename: path to an unzipped IDX label file (train-labels-idx1-ubyte
      style), as produced by `download`.

  Raises:
    ValueError: if the magic number is not 2049.
  """
  # Binary IDX header: must open in 'rb'. Text mode (the previous behavior)
  # breaks under Python 3 because read32 needs raw bytes.
  with open(filename, 'rb') as f:
    magic = read32(f)
    read32(f)  # num_items: unused, but the 4 header bytes must be consumed.
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
                                                                     f.name))
|
|
||
|
|
||
def download(directory, filename):
  """Download (and unzip) a file from the MNIST dataset, if it doesn't already exist.

  Args:
    directory: destination directory; created if missing.
    filename: base name of the MNIST file (without the '.gz' suffix).

  Returns:
    The local path of the unzipped file.
  """
  target = os.path.join(directory, filename)
  if tf.gfile.Exists(target):
    return target
  if not tf.gfile.Exists(directory):
    tf.gfile.MakeDirs(directory)
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
  gz_path = target + '.gz'
  print('Downloading %s to %s' % (url, gz_path))
  urllib.request.urlretrieve(url, gz_path)
  # Decompress into the final location, then drop the .gz archive.
  with gzip.open(gz_path, 'rb') as compressed, open(target, 'wb') as out:
    shutil.copyfileobj(compressed, out)
  os.remove(gz_path)
  return target
|
|
||
|
|
||
def dataset(directory, images_file, labels_file):
  """Build a tf.data.Dataset of (image, one-hot label) pairs from IDX files.

  Downloads the two files into `directory` if needed and validates their
  headers before constructing the pipeline.

  Args:
    directory: local directory holding (or receiving) the MNIST files.
    images_file: base name of the IDX image file.
    labels_file: base name of the IDX label file.

  Returns:
    A tf.data.Dataset yielding (float32 [784] image, float32 [10] one-hot
    label) pairs.
  """
  images_path = download(directory, images_file)
  labels_path = download(directory, labels_file)

  check_image_file_header(images_path)
  check_labels_file_header(labels_path)

  def _to_float_image(record):
    """Raw 784-byte record -> flat float32 vector scaled into [0.0, 1.0]."""
    pixels = tf.cast(tf.decode_raw(record, tf.uint8), tf.float32)
    return tf.reshape(pixels, [784]) / 255.0

  def _to_one_hot(record):
    """Single-byte label record -> depth-10 one-hot vector."""
    label = tf.reshape(tf.decode_raw(record, tf.uint8), [])  # scalar label
    return tf.one_hot(label, 10)

  # Skip the 16-byte image header / 8-byte label header, then read
  # fixed-size records: 28*28 bytes per image, 1 byte per label.
  images = tf.data.FixedLengthRecordDataset(
      images_path, 28 * 28, header_bytes=16).map(_to_float_image)
  labels = tf.data.FixedLengthRecordDataset(
      labels_path, 1, header_bytes=8).map(_to_one_hot)
  return tf.data.Dataset.zip((images, labels))
|
|
||
|
|
||
def train(directory):
  """tf.data.Dataset object for MNIST training data."""
  return dataset(
      directory, 'train-images-idx3-ubyte', 'train-labels-idx1-ubyte')
|
|
||
|
|
||
def test(directory):
  """tf.data.Dataset object for MNIST test data."""
  return dataset(directory,
                 't10k-images-idx3-ubyte',
                 't10k-labels-idx1-ubyte')
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,19 +22,7 @@ | |
| import sys | ||
|
|
||
| import tensorflow as tf | ||
| from tensorflow.examples.tutorials.mnist import input_data | ||
|
|
||
|
|
||
| # NOTE(review): removed-side code of this diff — replaced in the PR by dataset.train(). | ||
| # Loads the full training split eagerly via the deprecated tutorial reader, then wraps | ||
| # the in-memory arrays; one (image, label) element per example. | ||
| def train_dataset(data_dir): | ||
| """Returns a tf.data.Dataset yielding (image, label) pairs for training.""" | ||
| data = input_data.read_data_sets(data_dir, one_hot=True).train | ||
| return tf.data.Dataset.from_tensor_slices((data.images, data.labels)) | ||
|
|
||
|
|
||
| # NOTE(review): removed-side code of this diff — replaced in the PR by dataset.test(). | ||
| # from_tensors (not from_tensor_slices): the whole test split is one single batch element. | ||
| def eval_dataset(data_dir): | ||
| """Returns a tf.data.Dataset yielding (image, label) pairs for evaluation.""" | ||
| data = input_data.read_data_sets(data_dir, one_hot=True).test | ||
| return tf.data.Dataset.from_tensors((data.images, data.labels)) | ||
| import dataset | ||
|
|
||
|
|
||
| class Model(object): | ||
|
|
@@ -151,10 +139,10 @@ def train_input_fn(): | |
| # When choosing shuffle buffer sizes, larger sizes result in better | ||
| # randomness, while smaller sizes use less memory. MNIST is a small | ||
| # enough dataset that we can easily shuffle the full epoch. | ||
| dataset = train_dataset(FLAGS.data_dir) | ||
| dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat( | ||
| ds = dataset.train(FLAGS.data_dir) | ||
| ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat( | ||
| FLAGS.train_epochs) | ||
| (images, labels) = dataset.make_one_shot_iterator().get_next() | ||
| (images, labels) = ds.make_one_shot_iterator().get_next() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. While we're in here, would it make sense to switch to the new style of returning a dataset? (Same applies to the eval input function below.)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup, was waiting for 1.5 to land.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair enough! |
||
| return (images, labels) | ||
|
|
||
| # Set up training hook that logs the training accuracy every 100 steps. | ||
|
|
@@ -165,7 +153,8 @@ def train_input_fn(): | |
|
|
||
| # Evaluate the model and print results | ||
| def eval_input_fn(): | ||
| return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next() | ||
| return dataset.test(FLAGS.data_dir).batch( | ||
| FLAGS.batch_size).make_one_shot_iterator().get_next() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I don't think this is the right code style. Can we split this onto two lines instead?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was the result of running the Python formatter (https://github.com/google/yapf), so it should be right? :)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's surprising; the part that seemed weird to me was indenting the arguments to the next line and then following up with more function calls. If the formatter says it's good though, should be fine. |
||
|
|
||
| eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) | ||
| print() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2018! Nice :)