7 changes: 7 additions & 0 deletions official/mnist/README.md
@@ -12,6 +12,13 @@ APIs.
## Setup

To begin, you'll simply need the latest version of TensorFlow installed.

First convert the MNIST data to TFRecord file format by running the following:

```
python convert_to_records.py
```
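
`convert_to_records.py` accepts two optional flags, defined in the script below: `--directory` (where the raw data is downloaded and the `.tfrecords` files are written; defaults to `/tmp/mnist_data`) and `--validation_size` (how many training examples to hold out as a validation set; defaults to 0). For example, to hold out 5,000 examples:

```
python convert_to_records.py --directory /tmp/mnist_data --validation_size 5000
```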

Then to train the model, run the following:

```
python mnist.py
```
97 changes: 97 additions & 0 deletions official/mnist/convert_to_records.py
@@ -0,0 +1,97 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Converts MNIST data to TFRecords file format with Example protos.

To read about optimizations that can be applied to the input preprocessing
stage, see: https://www.tensorflow.org/performance/performance_guide#input_pipeline_optimization.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

import tensorflow as tf

from tensorflow.contrib.learn.python.learn.datasets import mnist

parser = argparse.ArgumentParser()

parser.add_argument('--directory', type=str, default='/tmp/mnist_data',
help='Directory to download data files and write the '
'converted result.')

parser.add_argument('--validation_size', type=int, default=0,
help='Number of examples to separate from the training '
'data for the validation set.')


def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to(dataset, name, directory):
"""Converts a dataset to TFRecords."""
images = dataset.images
labels = dataset.labels
num_examples = dataset.num_examples

  if images.shape[0] != num_examples:
    raise ValueError('Image count %d does not match expected number of '
                     'examples %d.' % (images.shape[0], num_examples))
rows = images.shape[1]
cols = images.shape[2]
depth = images.shape[3]

filename = os.path.join(directory, name + '.tfrecords')
print('Writing', filename)
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image_raw = images[index].tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'label': _int64_feature(int(labels[index])),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
writer.close()


def main(unused_argv):
# Get the data.
datasets = mnist.read_data_sets(FLAGS.directory,
dtype=tf.uint8,
reshape=False,
validation_size=FLAGS.validation_size)

# Convert to Examples and write the result to TFRecords.
convert_to(datasets.train, 'train', FLAGS.directory)
convert_to(datasets.validation, 'validation', FLAGS.directory)
convert_to(datasets.test, 'test', FLAGS.directory)


if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
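
Each record written above is a serialized `tf.train.Example` containing `height`, `width`, `depth`, `label`, and `image_raw` features. As a quick sanity check — a minimal sketch, not part of this change, assuming the default `/tmp/mnist_data` output directory and the TF 1.x `tf.python_io` API — the records can be read back and decoded directly:

```
import tensorflow as tf

# Iterate over the serialized records and decode the first Example to
# confirm the schema written by convert_to() above.
for serialized in tf.python_io.tf_record_iterator('/tmp/mnist_data/train.tfrecords'):
  example = tf.train.Example.FromString(serialized)
  feature = example.features.feature
  print('label:', feature['label'].int64_list.value[0])
  print('height x width:', feature['height'].int64_list.value[0],
        feature['width'].int64_list.value[0])
  print('image bytes:', len(feature['image_raw'].bytes_list.value[0]))
  break  # One record is enough for a sanity check.
```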
142 changes: 81 additions & 61 deletions official/mnist/mnist.py
@@ -22,53 +22,75 @@
import sys

import tensorflow as tf
-from tensorflow.examples.tutorials.mnist import input_data

parser = argparse.ArgumentParser()

# Basic model parameters.
-parser.add_argument(
-    '--batch_size',
-    type=int,
-    default=100,
-    help='Number of images to process in a batch')
parser.add_argument('--batch_size', type=int, default=100,
                    help='Number of images to process in a batch')

-parser.add_argument(
-    '--data_dir',
-    type=str,
-    default='/tmp/mnist_data',
-    help='Path to directory containing the MNIST dataset')
parser.add_argument('--data_dir', type=str, default='/tmp/mnist_data',
                    help='Path to the MNIST data directory.')

-parser.add_argument(
-    '--model_dir',
-    type=str,
-    default='/tmp/mnist_model',
-    help='The directory where the model will be stored.')
parser.add_argument('--model_dir', type=str, default='/tmp/mnist_model',
                    help='The directory where the model will be stored.')

-parser.add_argument(
-    '--train_epochs', type=int, default=40, help='Number of epochs to train.')
parser.add_argument('--train_epochs', type=int, default=40,
                    help='Number of epochs to train.')

parser.add_argument(
-    '--data_format',
-    type=str,
-    default=None,
    '--data_format', type=str, default=None,
    choices=['channels_first', 'channels_last'],
    help='A flag to override the data format used in the model. channels_first '
-    'provides a performance boost on GPU but is not always compatible '
-    'with CPU. If left unspecified, the data format will be chosen '
-    'automatically based on whether TensorFlow was built for CPU or GPU.')
         'provides a performance boost on GPU but is not always compatible '
         'with CPU. If left unspecified, the data format will be chosen '
         'automatically based on whether TensorFlow was built for CPU or GPU.')

_NUM_IMAGES = {
'train': 50000,
'validation': 10000,
}


def input_fn(is_training, filename, batch_size=1, num_epochs=1):
"""A simple input_fn using the tf.data input pipeline."""

def example_parser(serialized_example):
"""Parses a single tf.Example into image and label tensors."""
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape([28 * 28])

# Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
image = tf.cast(image, tf.float32) / 255 - 0.5
label = tf.cast(features['label'], tf.int32)
return image, tf.one_hot(label, 10)

dataset = tf.data.TFRecordDataset([filename])

# Apply dataset transformations
if is_training:
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance. Because MNIST is
# a small dataset, we can easily shuffle the full epoch.
dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])

-def train_dataset(data_dir):
-  """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
-  data = input_data.read_data_sets(data_dir, one_hot=True).train
-  return tf.data.Dataset.from_tensor_slices((data.images, data.labels))
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)

# Map example_parser over dataset, and batch results by up to batch_size
dataset = dataset.map(example_parser).prefetch(batch_size)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()

-def eval_dataset(data_dir):
-  """Returns a tf.data.Dataset yielding (image, label) pairs for evaluation."""
-  data = input_data.read_data_sets(data_dir, one_hot=True).test
-  return tf.data.Dataset.from_tensors((data.images, data.labels))
return images, labels


def mnist_model(inputs, mode, data_format):
@@ -82,8 +104,8 @@ def mnist_model(inputs, mode, data_format):
# When running on GPU, transpose the data from channels_last (NHWC) to
# channels_first (NCHW) to improve performance.
# See https://www.tensorflow.org/performance/performance_guide#data_formats
-    data_format = ('channels_first'
-                   if tf.test.is_built_with_cuda() else 'channels_last')
data_format = ('channels_first' if tf.test.is_built_with_cuda() else
'channels_last')

if data_format == 'channels_first':
inputs = tf.transpose(inputs, [0, 3, 1, 2])
@@ -105,8 +127,8 @@ def mnist_model(inputs, mode, data_format):
# First max pooling layer with a 2x2 filter and stride of 2
# Input Tensor Shape: [batch_size, 28, 28, 32]
# Output Tensor Shape: [batch_size, 14, 14, 32]
-  pool1 = tf.layers.max_pooling2d(
-      inputs=conv1, pool_size=[2, 2], strides=2, data_format=data_format)
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2,
data_format=data_format)

# Convolutional Layer #2
# Computes 64 features using a 5x5 filter.
@@ -125,8 +147,8 @@ def mnist_model(inputs, mode, data_format):
# Second max pooling layer with a 2x2 filter and stride of 2
# Input Tensor Shape: [batch_size, 14, 14, 64]
# Output Tensor Shape: [batch_size, 7, 7, 64]
-  pool2 = tf.layers.max_pooling2d(
-      inputs=conv2, pool_size=[2, 2], strides=2, data_format=data_format)
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2,
data_format=data_format)

# Flatten tensor into a batch of vectors
# Input Tensor Shape: [batch_size, 7, 7, 64]
@@ -137,7 +159,8 @@ def mnist_model(inputs, mode, data_format):
# Densely connected layer with 1024 neurons
# Input Tensor Shape: [batch_size, 7 * 7 * 64]
# Output Tensor Shape: [batch_size, 1024]
-  dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
dense = tf.layers.dense(inputs=pool2_flat, units=1024,
activation=tf.nn.relu)

# Add dropout operation; 0.6 probability that element will be kept
dropout = tf.layers.dropout(
@@ -188,37 +211,34 @@ def mnist_model_fn(features, labels, mode, params):


def main(unused_argv):
# Make sure that training and testing data have been converted.
train_file = os.path.join(FLAGS.data_dir, 'train.tfrecords')
test_file = os.path.join(FLAGS.data_dir, 'test.tfrecords')
assert (tf.gfile.Exists(train_file) and tf.gfile.Exists(test_file)), (
'Run convert_to_records.py first to convert the MNIST data to TFRecord '
'file format.')

# Create the Estimator
mnist_classifier = tf.estimator.Estimator(
-      model_fn=mnist_model_fn,
-      model_dir=FLAGS.model_dir,
-      params={
-          'data_format': FLAGS.data_format
-      })
model_fn=mnist_model_fn, model_dir=FLAGS.model_dir,
params={'data_format': FLAGS.data_format})

# Set up training hook that logs the training accuracy every 100 steps.
-  tensors_to_log = {'train_accuracy': 'train_accuracy'}
tensors_to_log = {
'train_accuracy': 'train_accuracy'
}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)

# Train the model
-  def train_input_fn():
-    # When choosing shuffle buffer sizes, larger sizes result in better
-    # randomness, while smaller sizes use less memory. MNIST is a small
-    # enough dataset that we can easily shuffle the full epoch.
-    dataset = train_dataset(FLAGS.data_dir)
-    dataset = dataset.shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
-        FLAGS.train_epochs)
-    (images, labels) = dataset.make_one_shot_iterator().get_next()
-    return (images, labels)

-  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
mnist_classifier.train(
input_fn=lambda: input_fn(
True, train_file, FLAGS.batch_size, FLAGS.train_epochs),
hooks=[logging_hook])

# Evaluate the model and print results
-  def eval_input_fn():
-    return eval_dataset(FLAGS.data_dir).make_one_shot_iterator().get_next()

-  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
eval_results = mnist_classifier.evaluate(
input_fn=lambda: input_fn(False, test_file, FLAGS.batch_size))
print()
print('Evaluation results:\n\t%s' % eval_results)
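
As a quick smoke test of the `input_fn` pipeline above — a hypothetical snippet, not part of this change, assuming `train.tfrecords` already exists under `/tmp/mnist_data` — the returned tensors can be evaluated directly in a session:

```
images, labels = input_fn(is_training=True,
                          filename='/tmp/mnist_data/train.tfrecords',
                          batch_size=32)
with tf.Session() as sess:
  batch_images, batch_labels = sess.run([images, labels])
  print(batch_images.shape)  # (32, 784): flattened 28x28 images in [-0.5, 0.5]
  print(batch_labels.shape)  # (32, 10): one-hot labels
```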
