Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions official/mnist/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2018! Nice :)

#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tf.data.Dataset interface to the MNIST dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import gzip
import numpy as np
from six.moves import urllib
import tensorflow as tf


def read32(bytestream):
"""Read 4 bytes from bytestream as an unsigned 32-bit integer."""
dt = np.dtype(np.uint32).newbyteorder('>')
return np.frombuffer(bytestream.read(4), dtype=dt)[0]


def check_image_file_header(filename):
"""Validate that filename corresponds to images for the MNIST dataset."""
with open(filename) as f:
magic = read32(f)
num_images = read32(f)
rows = read32(f)
cols = read32(f)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
if rows != 28 or cols != 28:
raise ValueError(
'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
(f.name, rows, cols))


def check_labels_file_header(filename):
"""Validate that filename corresponds to labels for the MNIST dataset."""
with open(filename) as f:
magic = read32(f)
num_items = read32(f)
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))


def download(directory, filename):
"""Download (and unzip) a file from the MNIST dataset, if it doesn't already exist."""
if not tf.gfile.Exists(directory):
tf.gfile.MakeDirs(directory)
filepath = os.path.join(directory, filename)
if tf.gfile.Exists(filepath):
return filepath
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
zipped_filepath = filepath + '.gz'
print('Downloading %s to %s' % (url, zipped_filepath))
urllib.request.urlretrieve(url, zipped_filepath)
with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove(zipped_filepath)
return filepath


def dataset(directory, images_file, labels_file):
images_file = download(directory, images_file)
labels_file = download(directory, labels_file)

check_image_file_header(images_file)
check_labels_file_header(labels_file)

def decode_image(image):
# Normalize from [0, 255] to [0.0, 1.0]
image = tf.decode_raw(image, tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.reshape(image, [784])
return image / 255.0

def one_hot_label(label):
label = tf.decode_raw(label, tf.uint8) # tf.string -> tf.uint8
label = tf.reshape(label, []) # label is a scalar
return tf.one_hot(label, 10)

images = tf.data.FixedLengthRecordDataset(
images_file, 28 * 28, header_bytes=16).map(decode_image)
labels = tf.data.FixedLengthRecordDataset(
labels_file, 1, header_bytes=8).map(one_hot_label)
return tf.data.Dataset.zip((images, labels))


def train(directory):
"""tf.data.Dataset object for MNIST training data."""
return dataset(directory, 'train-images-idx3-ubyte',
'train-labels-idx1-ubyte')


def test(directory):
"""tf.data.Dataset object for MNIST test data."""
return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
23 changes: 6 additions & 17 deletions official/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,7 @@
import sys

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def train_dataset(data_dir):
"""Returns a tf.data.Dataset yielding (image, label) pairs for training."""
data = input_data.read_data_sets(data_dir, one_hot=True).train
return tf.data.Dataset.from_tensor_slices((data.images, data.labels))


def eval_dataset(data_dir):
"""Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
data = input_data.read_data_sets(data_dir, one_hot=True).test
return tf.data.Dataset.from_tensors((data.images, data.labels))
import dataset


class Model(object):
Expand Down Expand Up @@ -151,10 +139,10 @@ def train_input_fn():
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
dataset = train_dataset(FLAGS.data_dir)
dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
ds = dataset.train(FLAGS.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
FLAGS.train_epochs)
(images, labels) = dataset.make_one_shot_iterator().get_next()
(images, labels) = ds.make_one_shot_iterator().get_next()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While we're in here, would it make sense to switch to the new style of returning a Dataset directly? (Or perhaps, since 1.5 hasn't landed yet, we should have a TODO to make that switch?)

(Same applies to eval_input_fn() below.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, was waiting for 1.5 to land.
Tempted to avoid TODOs in these "best practices" samples, unless you feel strongly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough!

return (images, labels)

# Set up training hook that logs the training accuracy every 100 steps.
Expand All @@ -165,7 +153,8 @@ def train_input_fn():

# Evaluate the model and print results
def eval_input_fn():
return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()
return dataset.test(FLAGS.data_dir).batch(
FLAGS.batch_size).make_one_shot_iterator().get_next()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I don't think this is the right code style. Can we split this onto two lines instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the result of running the Python formatter (https://github.com/google/yapf), so it should be right? :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's surprising; the part that seemed weird to me was indenting the arguments to the next line and then following up with more function calls. If the formatter says it's good though, should be fine.


eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print()
Expand Down