Skip to content

Commit

Permalink
Latest changes to the official models
Browse files Browse the repository at this point in the history
  • Loading branch information
nealwu committed Sep 21, 2017
1 parent 2c5c3f3 commit 2bbe262
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 6 deletions.
1 change: 0 additions & 1 deletion official/.gitignore
@@ -1,3 +1,2 @@
cnn/data
MNIST-data
labels.txt
7 changes: 6 additions & 1 deletion official/mnist/convert_to_records.py
Expand Up @@ -13,7 +13,12 @@
# limitations under the License.
# ==============================================================================

"""Converts MNIST data to TFRecords file format with Example protos."""
"""Converts MNIST data to TFRecords file format with Example protos.
To read about optimizations that can be applied to the input preprocessing
stage, see: https://www.tensorflow.org/performance/performance_guide#input_pipeline_optimization.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand Down
1 change: 1 addition & 0 deletions official/mnist/mnist.py
Expand Up @@ -92,6 +92,7 @@ def mnist_model(inputs, mode):
if tf.test.is_built_with_cuda():
# When running on GPU, transpose the data from channels_last (NHWC) to
# channels_first (NCHW) to improve performance.
# See https://www.tensorflow.org/performance/performance_guide#data_formats
data_format = 'channels_first'
inputs = tf.transpose(inputs, [0, 3, 1, 2])

Expand Down
3 changes: 2 additions & 1 deletion official/resnet/cifar10_download_and_extract.py
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
# ==============================================================================

"""Converts MNIST data to TFRecords file format with Example protos."""
"""Downloads and extracts the binary version of the CIFAR-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand Down
3 changes: 3 additions & 0 deletions official/resnet/imagenet.py
Expand Up @@ -27,6 +27,9 @@
WARNING: Don't use for object detection, in this case all the bounding boxes
of the image belong to just one class.
To read about optimizations that can be applied to the input preprocessing
stage, see: https://www.tensorflow.org/performance/performance_guide#input_pipeline_optimization.
"""

from __future__ import absolute_import
Expand Down
6 changes: 3 additions & 3 deletions official/resnet/resnet_model.py
Expand Up @@ -41,6 +41,7 @@
def batch_norm_relu(inputs, is_training, data_format):
"""Performs a batch normalization followed by a ReLU."""
# We set fused=True for a significant performance boost.
# See https://www.tensorflow.org/performance/performance_guide#common_fused_ops
inputs = tf.layers.batch_normalization(
inputs=inputs, axis=1 if data_format == 'channels_first' else 3,
momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True,
Expand Down Expand Up @@ -240,6 +241,7 @@ def model(inputs, is_training):
if data_format == 'channels_first':
# Convert from channels_last (NHWC) to channels_first (NCHW). This
# provides a large performance boost on GPU.
# See https://www.tensorflow.org/performance/performance_guide#data_formats
inputs = tf.transpose(inputs, [0, 3, 1, 2])

inputs = conv2d_fixed_padding(
Expand All @@ -261,14 +263,12 @@ def model(inputs, is_training):
data_format=data_format)

inputs = batch_norm_relu(inputs, is_training, data_format)

inputs = tf.layers.average_pooling2d(
inputs=inputs, pool_size=8, strides=1, padding='VALID',
data_format=data_format)
inputs = tf.identity(inputs, 'final_avg_pool')
inputs = tf.reshape(inputs, [-1, 64])
inputs = tf.layers.dense(
inputs=inputs, units=num_classes)
inputs = tf.layers.dense(inputs=inputs, units=num_classes)
inputs = tf.identity(inputs, 'final_dense')
return inputs

Expand Down

0 comments on commit 2bbe262

Please sign in to comment.