14 changes: 0 additions & 14 deletions official/__init__.py
@@ -1,14 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
8 changes: 5 additions & 3 deletions official/mnist/dataset.py
@@ -17,9 +17,9 @@
from __future__ import division
from __future__ import print_function

+import gzip
import os
import shutil
-import gzip

import numpy as np
from six.moves import urllib
@@ -36,7 +36,7 @@ def check_image_file_header(filename):
"""Validate that filename corresponds to images for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f:
magic = read32(f)
-num_images = read32(f)
+read32(f)  # num_images, unused
rows = read32(f)
cols = read32(f)
if magic != 2051:
@@ -52,7 +52,7 @@ def check_labels_file_header(filename):
"""Validate that filename corresponds to labels for the MNIST dataset."""
with tf.gfile.Open(filename, 'rb') as f:
magic = read32(f)
-num_items = read32(f)
+read32(f)  # num_items, unused
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
@@ -77,6 +77,8 @@ def download(directory, filename):


def dataset(directory, images_file, labels_file):
"""Download and parse MNIST dataset."""

images_file = download(directory, images_file)
labels_file = download(directory, labels_file)

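For context on the header checks above: MNIST IDX files open with big-endian 32-bit fields (magic number, item count, and for image files rows and columns). A minimal sketch of a read32 helper consistent with those checks — the real one is defined earlier in dataset.py, outside this hunk, so this is an assumption about its shape:

```python
import numpy as np

def read32(bytestream):
  """Read one big-endian uint32 from a byte stream, as IDX headers require."""
  dt = np.dtype(np.uint32).newbyteorder('>')
  return np.frombuffer(bytestream.read(4), dtype=dt)[0]

# The magic numbers validated above: 2051 marks an image file, 2049 a label file.
```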
33 changes: 21 additions & 12 deletions official/mnist/mnist.py
@@ -20,14 +20,15 @@
import argparse
import sys

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.mnist import dataset
from official.utils.arg_parsers import parsers
from official.utils.logging import hooks_helper

LEARNING_RATE = 1e-4


class Model(tf.keras.Model):
"""Model to recognize digits in the MNIST dataset.

@@ -145,31 +146,36 @@ def model_fn(features, labels, mode, params):


def validate_batch_size_for_multi_gpu(batch_size):
"""For multi-gpu, batch-size must be a multiple of the number of
available GPUs.
"""For multi-gpu, batch-size must be a multiple of the number of GPUs.

Note that this should eventually be handled by replicate_model_fn
directly. Multi-GPU support is currently experimental, however,
so doing the work here until that feature is in place.

+Args:
+  batch_size: the number of examples processed in each training batch.
+
+Raises:
+  ValueError: if no GPUs are found, or selected batch_size is invalid.
"""
-from tensorflow.python.client import device_lib
+from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
Member: Is there any reason that we have to do a local import and not put it on top?

Contributor (author): We want to hide this code, as it should not be exposed to the user, and we want to remove it as soon as it's done by Estimator directly. So, for now, leaving bundled.


local_device_protos = device_lib.list_local_devices()
num_gpus = sum([1 for d in local_device_protos if d.device_type == 'GPU'])
if not num_gpus:
raise ValueError('Multi-GPU mode was specified, but no GPUs '
-'were found. To use CPU, run without --multi_gpu.')
+'were found. To use CPU, run without --multi_gpu.')

remainder = batch_size % num_gpus
if remainder:
err = ('When running with multiple GPUs, batch size '
-'must be a multiple of the number of available GPUs. '
-'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
-).format(num_gpus, batch_size, batch_size - remainder)
+'must be a multiple of the number of available GPUs. '
+'Found {} GPUs with a batch size of {}; try --batch_size={} instead.'
+).format(num_gpus, batch_size, batch_size - remainder)
raise ValueError(err)
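As a worked example of the remainder logic (not part of the diff): with 4 GPUs, a batch size of 100 fails validation and the error suggests 96, the largest multiple of 4 below it. A standalone sketch:

```python
def largest_valid_batch_size(batch_size, num_gpus):
  """Round batch_size down to a multiple of num_gpus, as the error suggests."""
  return batch_size - batch_size % num_gpus

assert largest_valid_batch_size(100, 4) == 96   # the suggested --batch_size=96
assert largest_valid_batch_size(128, 4) == 128  # already valid, unchanged
```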


-def main(unused_argv):
+def main(_):
model_function = model_fn

if FLAGS.multi_gpu:
@@ -195,6 +201,8 @@ def main(unused_argv):

# Set up training and evaluation input functions.
def train_input_fn():
"""Prepare data for training."""

# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes use less memory. MNIST is a small
# enough dataset that we can easily shuffle the full epoch.
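A sketch of how an input function along these lines plausibly continues (the hunk is truncated in this view; the buffer size and chaining below are assumptions based on the comment above — MNIST has 60,000 training images):

```python
def train_input_fn():
  """Prepare data for training (illustrative sketch, not the exact file)."""
  ds = dataset.train(FLAGS.data_dir)
  # Shuffling with a buffer the size of the full epoch gives perfect
  # shuffling at the cost of holding all 60,000 examples in memory.
  ds = ds.cache().shuffle(buffer_size=60000).batch(FLAGS.batch_size)
  # Iterate for several epochs between evaluations.
  return ds.repeat(FLAGS.epochs_between_evals)
```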
@@ -215,7 +223,7 @@ def eval_input_fn():
FLAGS.hooks, batch_size=FLAGS.batch_size)

# Train and evaluate model.
-for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+for _ in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results)
@@ -231,10 +239,11 @@ def eval_input_fn():

class MNISTArgParser(argparse.ArgumentParser):
"""Argument parser for running MNIST model."""

def __init__(self):
super(MNISTArgParser, self).__init__(parents=[
-parsers.BaseParser(),
-parsers.ImageModelParser()])
+parsers.BaseParser(),
+parsers.ImageModelParser()])

self.add_argument(
'--export_dir',
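The parents= pattern used by MNISTArgParser composes shared flag groups into one CLI. A self-contained sketch of the same argparse technique — the base parser here is a hypothetical stand-in for official.utils.arg_parsers.parsers.BaseParser:

```python
import argparse

def make_base_parser():
  # Parsers used as parents must disable their own --help.
  parser = argparse.ArgumentParser(add_help=False)
  parser.add_argument('--batch_size', type=int, default=100)
  return parser

class DemoArgParser(argparse.ArgumentParser):
  """Mirrors MNISTArgParser: inherit shared flags, then add its own."""

  def __init__(self):
    super(DemoArgParser, self).__init__(parents=[make_base_parser()])
    self.add_argument('--export_dir', type=str, help='Where to export the model.')

flags, _ = DemoArgParser().parse_known_args(['--batch_size', '64'])
print(flags.batch_size)  # 64
```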
19 changes: 10 additions & 9 deletions official/mnist/mnist_eager.py
@@ -31,11 +31,11 @@
import sys
import time

-import tensorflow as tf
-import tensorflow.contrib.eager as tfe
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow.contrib.eager as tfe  # pylint: disable=g-bad-import-order

+from official.mnist import dataset as mnist_dataset
from official.mnist import mnist
-from official.mnist import dataset
from official.utils.arg_parsers import parsers

FLAGS = None
@@ -110,9 +110,9 @@ def main(_):
print('Using device %s, and data format %s.' % (device, data_format))

# Load the datasets
-train_ds = dataset.train(FLAGS.data_dir).shuffle(60000).batch(
+train_ds = mnist_dataset.train(FLAGS.data_dir).shuffle(60000).batch(
FLAGS.batch_size)
-test_ds = dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)
+test_ds = mnist_dataset.test(FLAGS.data_dir).batch(FLAGS.batch_size)

# Create the model and optimizer
model = mnist.Model(data_format)
@@ -159,12 +159,13 @@ def main(_):


class MNISTEagerArgParser(argparse.ArgumentParser):
"""Argument parser for running MNIST model with eager trainng loop."""
"""Argument parser for running MNIST model with eager training loop."""

def __init__(self):
super(MNISTEagerArgParser, self).__init__(parents=[
-parsers.BaseParser(epochs_between_evals=False, multi_gpu=False,
-hooks=False),
-parsers.ImageModelParser()])
+parsers.BaseParser(
+epochs_between_evals=False, multi_gpu=False, hooks=False),
+parsers.ImageModelParser()])

self.add_argument(
'--log_interval', '-li',
5 changes: 3 additions & 2 deletions official/mnist/mnist_eager_test.py
@@ -17,8 +17,8 @@
from __future__ import division
from __future__ import print_function

-import tensorflow as tf
-import tensorflow.contrib.eager as tfe
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow.contrib.eager as tfe  # pylint: disable=g-bad-import-order

from official.mnist import mnist
from official.mnist import mnist_eager
@@ -60,6 +60,7 @@ def evaluate(defun=False):


class MNISTTest(tf.test.TestCase):
"""Run tests for MNIST eager loop."""

def test_train(self):
train(defun=False)
7 changes: 5 additions & 2 deletions official/mnist/mnist_test.py
@@ -17,9 +17,10 @@
from __future__ import division
from __future__ import print_function

-import tensorflow as tf
import time
+
+import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.mnist import mnist

BATCH_SIZE = 100
@@ -42,6 +43,7 @@ def make_estimator():


class Tests(tf.test.TestCase):
"""Run tests for MNIST model."""

def test_mnist(self):
classifier = make_estimator()
@@ -57,7 +59,7 @@ def test_mnist(self):

input_fn = lambda: tf.random_uniform([3, 784])
predictions_generator = classifier.predict(input_fn)
-for i in range(3):
+for _ in range(3):
predictions = next(predictions_generator)
self.assertEqual(predictions['probabilities'].shape, (10,))
self.assertEqual(predictions['classes'].shape, ())
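Context for the loop above: Estimator.predict returns a lazy generator that yields one dict of numpy values per input example, which is why the test draws exactly three predictions. The same consumption pattern as a sketch (classifier is any tf.estimator.Estimator trained on 784-dim inputs, as in this test):

```python
input_fn = lambda: tf.random_uniform([3, 784])  # three fake MNIST vectors
for prediction in classifier.predict(input_fn):
  # For this model each yielded dict carries a (10,) probability vector
  # and a scalar class id.
  print(prediction['classes'], prediction['probabilities'].shape)
```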
@@ -103,6 +105,7 @@ def test_mnist_model_fn_predict_mode(self):


class Benchmarks(tf.test.Benchmark):
"""Simple speed benchmarking for MNIST."""

def benchmark_train_step_time(self):
classifier = make_estimator()
5 changes: 3 additions & 2 deletions official/mnist/mnist_tpu.py
@@ -23,7 +23,8 @@
from __future__ import division
from __future__ import print_function

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order
+
from official.mnist import dataset
from official.mnist import mnist

@@ -132,7 +133,7 @@ def main(argv):
tf.logging.set_verbosity(tf.logging.INFO)

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
-FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
+FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

run_config = tf.contrib.tpu.RunConfig(
cluster=tpu_cluster_resolver,
2 changes: 1 addition & 1 deletion official/resnet/cifar10_download_and_extract.py
@@ -36,7 +36,7 @@
help='Directory to download data and extract the tarball')


-def main(unused_argv):
+def main(_):
"""Download and extract the tarball from Alex's website."""
if not os.path.exists(FLAGS.data_dir):
os.makedirs(FLAGS.data_dir)
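Renaming the parameter to _ is the lint convention for arguments that tf.app.run passes but main never reads. The launch pattern these scripts share, as a runnable sketch (the flag name and body are illustrative):

```python
import argparse
import sys

import tensorflow as tf


def main(_):
  # The underscore signals that the argv list tf.app.run passes is unused.
  print('download and extract would run here')


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str, default='/tmp/cifar10_data')
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```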
20 changes: 13 additions & 7 deletions official/resnet/cifar10_main.py
@@ -21,7 +21,7 @@
import os
import sys

-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet import resnet_model
from official.resnet import resnet_run_loop
@@ -127,22 +127,25 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,

num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']

-return resnet_run_loop.process_record_dataset(dataset, is_training, batch_size,
-_NUM_IMAGES['train'], parse_record, num_epochs, num_parallel_calls,
-examples_per_epoch=num_images, multi_gpu=multi_gpu)
+return resnet_run_loop.process_record_dataset(
+dataset, is_training, batch_size, _NUM_IMAGES['train'],
+parse_record, num_epochs, num_parallel_calls,
+examples_per_epoch=num_images, multi_gpu=multi_gpu)
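One note on the unchanged num_images line above: it uses the old `and/or` selection idiom, which predates Python's conditional expression and silently misbehaves when the middle operand is falsy. Both spellings, for comparison:

```python
_NUM_IMAGES = {'train': 50000, 'validation': 10000}
is_training = True

# Idiom from the file; safe here only because both counts are non-zero.
num_images = is_training and _NUM_IMAGES['train'] or _NUM_IMAGES['validation']

# Equivalent conditional expression, robust even for falsy values.
num_images = _NUM_IMAGES['train'] if is_training else _NUM_IMAGES['validation']
```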


def get_synth_input_fn():
-return resnet_run_loop.get_synth_input_fn(_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)
+return resnet_run_loop.get_synth_input_fn(
+_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES)


###############################################################################
# Running the model
###############################################################################
class Cifar10Model(resnet_model.Model):
"""Model class with appropriate defaults for CIFAR-10 data."""

def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
-version=resnet_model.DEFAULT_VERSION):
+version=resnet_model.DEFAULT_VERSION):
"""These are the parameters that work for CIFAR-10 data.
Args:
Expand All @@ -153,6 +156,9 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
enables users to extend the same model to their own datasets.
version: Integer representing which version of the ResNet network to use.
See README for details. Valid values: [1, 2]
+Raises:
+  ValueError: if invalid resnet_size is chosen
"""
if resnet_size % 6 != 2:
raise ValueError('resnet_size must be 6n + 2:', resnet_size)
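The 6n + 2 constraint reflects the CIFAR-10 ResNet layout: three stages of n residual blocks with two conv layers each, plus the input conv and the final dense layer. The depths reported in the ResNet paper all satisfy it:

```python
# Classic CIFAR-10 ResNet depths and the n that produces each one.
valid_sizes = [6 * n + 2 for n in (3, 5, 7, 9, 18)]
print(valid_sizes)  # [20, 32, 44, 56, 110]; each size % 6 == 2
```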
@@ -195,7 +201,7 @@ def cifar10_model_fn(features, labels, mode, params):
# for the CIFAR-10 dataset, perhaps because the regularization prevents
# overfitting on the small data set. We therefore include all vars when
# regularizing and computing loss during training.
-def loss_filter_fn(name):
+def loss_filter_fn(_):
return True
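For contrast, a typical loss filter excludes batch-normalization variables from weight decay by name; the ImageNet configuration in this repo is believed to use that shape, though the exact helper is an assumption here. A sketch of both:

```python
def exclude_batch_norm(name):
  """Typical filter (assumed): apply L2 only to non-batch-norm variables."""
  return 'batch_normalization' not in name

def loss_filter_fn(_):
  """CIFAR-10 choice above: regularize everything, regardless of name."""
  return True
```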

return resnet_run_loop.resnet_model_fn(features, labels, mode, Cifar10Model,
14 changes: 9 additions & 5 deletions official/resnet/cifar10_test.py
@@ -20,7 +20,7 @@
from tempfile import mkstemp

import numpy as np
-import tensorflow as tf
+import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.resnet import cifar10_main
from official.utils.testing import integration
@@ -34,6 +34,8 @@


class BaseTest(tf.test.TestCase):
"""Tests for the Cifar10 version of Resnet.
"""

def tearDown(self):
super(BaseTest, self).tearDown()
@@ -52,7 +54,7 @@ def test_dataset_input_fn(self):
data_file.close()

fake_dataset = tf.data.FixedLengthRecordDataset(
-filename, cifar10_main._RECORD_BYTES)
+filename, cifar10_main._RECORD_BYTES)  # pylint: disable=protected-access
fake_dataset = fake_dataset.map(
lambda val: cifar10_main.parse_record(val, False))
image, label = fake_dataset.make_one_shot_iterator().get_next()
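For context on the fabricated record: each example in the CIFAR-10 binary format is one label byte followed by a 32x32x3 uint8 image, which is what _RECORD_BYTES encodes:

```python
# CIFAR-10 binary record layout (from the dataset's published format).
label_bytes = 1
image_bytes = 32 * 32 * 3                 # 3072 uint8 values, stored planar RGB
record_bytes = label_bytes + image_bytes  # 3073 == cifar10_main._RECORD_BYTES
```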
@@ -133,9 +135,11 @@ def test_cifar10model_shape(self):
num_classes = 246

for version in (1, 2):
-model = cifar10_main.Cifar10Model(32, data_format='channels_last',
-num_classes=num_classes, version=version)
-fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
+model = cifar10_main.Cifar10Model(
+32, data_format='channels_last', num_classes=num_classes,
+version=version)
+fake_input = tf.random_uniform(
+[batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
output = model(fake_input, training=True)

self.assertAllEqual(output.shape, (batch_size, num_classes))