10 changes: 4 additions & 6 deletions official/resnet/README.md
@@ -55,9 +55,7 @@ You can download 190 MB pre-trained versions of ResNet-50 achieving 76.3% and 75
 
 Other versions and formats:
 
-* [ResNet-v2-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnetv2_imagenet_checkpoint.tar.gz)
-* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnetv2_imagenet_savedmodel.tar.gz)
-* [ResNet-v2-ImageNet Frozen Graph](http://download.tensorflow.org/models/official/resnetv2_imagenet_frozen_graph.pb)
-* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnetv1_imagenet_checkpoint.tar.gz)
-* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnetv1_imagenet_savedmodel.tar.gz)
-* [ResNet-v1-ImageNet Frozen Graph](http://download.tensorflow.org/models/official/resnetv1_imagenet_frozen_graph.pb)
+* [ResNet-v2-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v2_imagenet_checkpoint.tar.gz)
+* [ResNet-v2-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v2_imagenet_savedmodel.tar.gz)
+* [ResNet-v1-ImageNet Checkpoint](http://download.tensorflow.org/models/official/resnet_v1_imagenet_checkpoint.tar.gz)
+* [ResNet-v1-ImageNet SavedModel](http://download.tensorflow.org/models/official/resnet_v1_imagenet_savedmodel.tar.gz)
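
The SavedModel bundles listed above can be loaded directly with the TF 1.x loader API. A minimal sketch, assuming the archive has been extracted locally and was exported under the usual SERVING tag (neither the local path nor the tag is specified by the README):

```python
import tensorflow as tf

# Hypothetical path to an extracted resnet_v2_imagenet_savedmodel archive.
export_dir = 'path/to/resnet_v2_imagenet_savedmodel'

with tf.Session(graph=tf.Graph()) as sess:
  # Restores the graph and variables into the session for inference.
  tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING], export_dir)
```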
33 changes: 22 additions & 11 deletions official/resnet/cifar10_main.py
@@ -145,7 +145,8 @@ class Cifar10Model(resnet_model.Model):
   """Model class with appropriate defaults for CIFAR-10 data."""
 
   def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
-               version=resnet_model.DEFAULT_VERSION):
+               version=resnet_model.DEFAULT_VERSION,
+               dtype=resnet_model.DEFAULT_DTYPE):
     """These are the parameters that work for CIFAR-10 data.
 
     Args:
@@ -156,6 +157,7 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
         enables users to extend the same model to their own datasets.
       version: Integer representing which version of the ResNet network to use.
         See README for details. Valid values: [1, 2]
+      dtype: The TensorFlow dtype to use for calculations.
 
     Raises:
       ValueError: if invalid resnet_size is chosen
@@ -180,7 +182,9 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
         block_strides=[1, 2, 2],
         final_size=64,
         version=version,
-        data_format=data_format)
+        data_format=data_format,
+        dtype=dtype
+    )
 
 
 def cifar10_model_fn(features, labels, mode, params):
@@ -204,15 +208,22 @@ def cifar10_model_fn(features, labels, mode, params):
   def loss_filter_fn(_):
     return True
 
-  return resnet_run_loop.resnet_model_fn(features, labels, mode, Cifar10Model,
-                                         resnet_size=params['resnet_size'],
-                                         weight_decay=weight_decay,
-                                         learning_rate_fn=learning_rate_fn,
-                                         momentum=0.9,
-                                         data_format=params['data_format'],
-                                         version=params['version'],
-                                         loss_filter_fn=loss_filter_fn,
-                                         multi_gpu=params['multi_gpu'])
+  return resnet_run_loop.resnet_model_fn(
+      features=features,
+      labels=labels,
+      mode=mode,
+      model_class=Cifar10Model,
+      resnet_size=params['resnet_size'],
+      weight_decay=weight_decay,
+      learning_rate_fn=learning_rate_fn,
+      momentum=0.9,
+      data_format=params['data_format'],
+      version=params['version'],
+      loss_scale=params['loss_scale'],
+      loss_filter_fn=loss_filter_fn,
+      multi_gpu=params['multi_gpu'],
+      dtype=params['dtype']
+  )
 
 
 def main(argv):
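
The two new arguments threaded through this call, `loss_scale` and `dtype`, correspond to the usual mixed-precision recipe: compute in fp16, but multiply the loss by a constant before taking gradients and divide the scale back out before the update so that small gradient values do not flush to zero. A minimal sketch of that idea with a plain MomentumOptimizer, independent of whatever `resnet_run_loop.resnet_model_fn` actually does internally:

```python
import tensorflow as tf

def scaled_minimize(optimizer, loss, loss_scale):
  """Sketch of static loss scaling for fp16 training."""
  # Gradients are taken against the scaled loss, which keeps small fp16
  # gradient values away from underflow.
  grads_and_vars = optimizer.compute_gradients(loss * loss_scale)
  # Undo the scaling before applying the update so the effective step is unchanged.
  unscaled = [(grad / loss_scale, var)
              for grad, var in grads_and_vars if grad is not None]
  return optimizer.apply_gradients(unscaled)

# Example wiring, mirroring the momentum used above:
# optimizer = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
# train_op = scaled_minimize(optimizer, loss, loss_scale=128)
```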
106 changes: 66 additions & 40 deletions official/resnet/cifar10_test.py
@@ -71,38 +71,61 @@ def test_dataset_input_fn(self):
       for pixel in row:
         self.assertAllClose(pixel, np.array([-1.225, 0., 1.225]), rtol=1e-3)
 
+  def _cifar10_model_fn_helper(self, mode, version, dtype, multi_gpu=False):
+    with tf.Graph().as_default() as g:
+      input_fn = cifar10_main.get_synth_input_fn()
+      dataset = input_fn(True, '', _BATCH_SIZE)
+      iterator = dataset.make_one_shot_iterator()
+      features, labels = iterator.get_next()
+      spec = cifar10_main.cifar10_model_fn(
+          features, labels, mode, {
+              'dtype': dtype,
+              'resnet_size': 32,
+              'data_format': 'channels_last',
+              'batch_size': _BATCH_SIZE,
+              'version': version,
+              'loss_scale': 128 if dtype == tf.float16 else 1,
+              'multi_gpu': multi_gpu
+          })
+
+      predictions = spec.predictions
+      self.assertAllEqual(predictions['probabilities'].shape,
+                          (_BATCH_SIZE, 10))
+      self.assertEqual(predictions['probabilities'].dtype, tf.float32)
+      self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
+      self.assertEqual(predictions['classes'].dtype, tf.int64)
+
+      if mode != tf.estimator.ModeKeys.PREDICT:
+        loss = spec.loss
+        self.assertAllEqual(loss.shape, ())
+        self.assertEqual(loss.dtype, tf.float32)
+
+      if mode == tf.estimator.ModeKeys.EVAL:
+        eval_metric_ops = spec.eval_metric_ops
+        self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
+        self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
+        self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
+        self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
+
+      for v in tf.trainable_variables():
+        self.assertEqual(v.dtype.base_dtype, tf.float32)
+
+      tensors_to_check = ('initial_conv:0', 'block_layer1:0', 'block_layer2:0',
+                          'block_layer3:0', 'final_reduce_mean:0',
+                          'final_dense:0')
+
+      for tensor_name in tensors_to_check:
+        tensor = g.get_tensor_by_name('resnet_model/' + tensor_name)
+        self.assertEqual(tensor.dtype, dtype,
+                         'Tensor {} has dtype {}, while dtype {} was '
+                         'expected'.format(tensor, tensor.dtype,
+                                           dtype))
+
   def cifar10_model_fn_helper(self, mode, version, multi_gpu=False):
-    input_fn = cifar10_main.get_synth_input_fn()
-    dataset = input_fn(True, '', _BATCH_SIZE)
-    iterator = dataset.make_one_shot_iterator()
-    features, labels = iterator.get_next()
-    spec = cifar10_main.cifar10_model_fn(
-        features, labels, mode, {
-            'resnet_size': 32,
-            'data_format': 'channels_last',
-            'batch_size': _BATCH_SIZE,
-            'version': version,
-            'multi_gpu': multi_gpu
-        })
-
-    predictions = spec.predictions
-    self.assertAllEqual(predictions['probabilities'].shape,
-                        (_BATCH_SIZE, 10))
-    self.assertEqual(predictions['probabilities'].dtype, tf.float32)
-    self.assertAllEqual(predictions['classes'].shape, (_BATCH_SIZE,))
-    self.assertEqual(predictions['classes'].dtype, tf.int64)
-
-    if mode != tf.estimator.ModeKeys.PREDICT:
-      loss = spec.loss
-      self.assertAllEqual(loss.shape, ())
-      self.assertEqual(loss.dtype, tf.float32)
-
-    if mode == tf.estimator.ModeKeys.EVAL:
-      eval_metric_ops = spec.eval_metric_ops
-      self.assertAllEqual(eval_metric_ops['accuracy'][0].shape, ())
-      self.assertAllEqual(eval_metric_ops['accuracy'][1].shape, ())
-      self.assertEqual(eval_metric_ops['accuracy'][0].dtype, tf.float32)
-      self.assertEqual(eval_metric_ops['accuracy'][1].dtype, tf.float32)
+    self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float32,
+                                  multi_gpu=multi_gpu)
+    self._cifar10_model_fn_helper(mode=mode, version=version, dtype=tf.float16,
+                                  multi_gpu=multi_gpu)
 
   def test_cifar10_model_fn_train_mode_v1(self):
     self.cifar10_model_fn_helper(tf.estimator.ModeKeys.TRAIN, version=1)
@@ -130,19 +153,22 @@ def test_cifar10_model_fn_predict_mode_v1(self):
   def test_cifar10_model_fn_predict_mode_v2(self):
     self.cifar10_model_fn_helper(tf.estimator.ModeKeys.PREDICT, version=2)
 
-  def test_cifar10model_shape(self):
+  def _test_cifar10model_shape(self, version):
     batch_size = 135
     num_classes = 246
 
-    for version in (1, 2):
-      model = cifar10_main.Cifar10Model(
-          32, data_format='channels_last', num_classes=num_classes,
-          version=version)
-      fake_input = tf.random_uniform(
-          [batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
-      output = model(fake_input, training=True)
+    model = cifar10_main.Cifar10Model(32, data_format='channels_last',
+                                      num_classes=num_classes, version=version)
+    fake_input = tf.random_uniform([batch_size, _HEIGHT, _WIDTH, _NUM_CHANNELS])
+    output = model(fake_input, training=True)
 
-      self.assertAllEqual(output.shape, (batch_size, num_classes))
+    self.assertAllEqual(output.shape, (batch_size, num_classes))
+
+  def test_cifar10model_shape_v1(self):
+    self._test_cifar10model_shape(version=1)
+
+  def test_cifar10model_shape_v2(self):
+    self._test_cifar10model_shape(version=2)
 
   def test_cifar10_end_to_end_synthetic_v1(self):
     integration.run_synthetic(
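
The new `_cifar10_model_fn_helper` asserts that trainable variables stay fp32 even when the named graph tensors carry the requested dtype, i.e. fp32 master weights with fp16 activations. One common way to get that split — shown here only as a sketch of the general pattern, not necessarily how `resnet_model` implements it — is a custom variable getter that stores variables in fp32 and casts them to the compute dtype on use:

```python
import tensorflow as tf

def fp32_storage_getter(getter, name, shape=None, dtype=tf.float32,
                        *args, **kwargs):
  """Creates variables in fp32 and casts them to the requested dtype."""
  storage_dtype = tf.float32 if dtype in (tf.float16, tf.float32) else dtype
  var = getter(name, shape, dtype=storage_dtype, *args, **kwargs)
  if dtype != var.dtype.base_dtype:
    # The variable itself stays fp32; only the value fed to the layer is cast.
    var = tf.cast(var, dtype)
  return var

with tf.variable_scope('model', custom_getter=fp32_storage_getter):
  inputs = tf.zeros([8, 16], dtype=tf.float16)
  outputs = tf.layers.dense(inputs, 10)  # fp16 compute, fp32 kernel/bias variables
```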
33 changes: 22 additions & 11 deletions official/resnet/imagenet_main.py
@@ -203,7 +203,8 @@ class ImagenetModel(resnet_model.Model):
   """Model class with appropriate defaults for Imagenet data."""
 
   def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
-               version=resnet_model.DEFAULT_VERSION):
+               version=resnet_model.DEFAULT_VERSION,
+               dtype=resnet_model.DEFAULT_DTYPE):
     """These are the parameters that work for Imagenet data.
 
     Args:
@@ -214,6 +215,7 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
        enables users to extend the same model to their own datasets.
      version: Integer representing which version of the ResNet network to use.
        See README for details. Valid values: [1, 2]
+      dtype: The TensorFlow dtype to use for calculations.
     """
 
     # For bigger models, we want to use "bottleneck" layers
@@ -239,7 +241,9 @@ def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES,
         block_strides=[1, 2, 2, 2],
         final_size=final_size,
         version=version,
-        data_format=data_format)
+        data_format=data_format,
+        dtype=dtype
+    )
 
 
 def _get_block_sizes(resnet_size):
@@ -283,15 +287,22 @@ def imagenet_model_fn(features, labels, mode, params):
       num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
       decay_rates=[1, 0.1, 0.01, 0.001, 1e-4])
 
-  return resnet_run_loop.resnet_model_fn(features, labels, mode, ImagenetModel,
-                                         resnet_size=params['resnet_size'],
-                                         weight_decay=1e-4,
-                                         learning_rate_fn=learning_rate_fn,
-                                         momentum=0.9,
-                                         data_format=params['data_format'],
-                                         version=params['version'],
-                                         loss_filter_fn=None,
-                                         multi_gpu=params['multi_gpu'])
+  return resnet_run_loop.resnet_model_fn(
+      features=features,
+      labels=labels,
+      mode=mode,
+      model_class=ImagenetModel,
+      resnet_size=params['resnet_size'],
+      weight_decay=1e-4,
+      learning_rate_fn=learning_rate_fn,
+      momentum=0.9,
+      data_format=params['data_format'],
+      version=params['version'],
+      loss_scale=params['loss_scale'],
+      loss_filter_fn=None,
+      multi_gpu=params['multi_gpu'],
+      dtype=params['dtype']
+  )
 
 
 def main(argv):
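
With `dtype` now accepted by the constructor, the model can also be instantiated directly for half-precision inference outside the Estimator path. A small sketch, assuming the script is run from `official/resnet` with a 50-layer model and the standard 224x224 ImageNet input shape:

```python
import tensorflow as tf
import imagenet_main

# ResNet-50 computing in fp16; variable storage follows the model's own
# dtype handling.
model = imagenet_main.ImagenetModel(resnet_size=50, dtype=tf.float16)
images = tf.random_uniform([8, 224, 224, 3], dtype=tf.float16)
logits = model(images, training=False)
```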