2 changes: 1 addition & 1 deletion research/object_detection/model_lib.py
@@ -641,7 +641,7 @@ def create_train_and_eval_specs(train_input_fn,
input_fn=train_input_fn, max_steps=train_steps)

if eval_spec_names is None:
-    eval_spec_names = range(len(eval_input_fns))
+    eval_spec_names = [ str(i) for i in range(len(eval_input_fns)) ]

eval_specs = []
for eval_spec_name, eval_input_fn in zip(eval_spec_names, eval_input_fns):
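Note on the model_lib.py change above: tf.estimator.EvalSpec expects its name argument to be a string (it is used to build the eval directory name), so the integer names produced by range() fail, and range() is not even a list under Python 3. A minimal sketch of the pattern, not taken from this PR:

import tensorflow as tf

def make_eval_specs(eval_input_fns):
  # Stringify the indices so EvalSpec's name handling works.
  eval_spec_names = [str(i) for i in range(len(eval_input_fns))]
  return [
      tf.estimator.EvalSpec(name=name, input_fn=input_fn)
      for name, input_fn in zip(eval_spec_names, eval_input_fns)
  ]
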
20 changes: 14 additions & 6 deletions research/slim/BUILD
@@ -80,7 +80,7 @@ py_binary(
],
)

-py_binary(
+py_library(
name = "cifar10",
srcs = ["datasets/cifar10.py"],
deps = [
@@ -89,7 +89,7 @@ py_binary(
],
)

-py_binary(
+py_library(
name = "flowers",
srcs = ["datasets/flowers.py"],
deps = [
@@ -98,7 +98,7 @@
],
)

-py_binary(
+py_library(
name = "imagenet",
srcs = ["datasets/imagenet.py"],
deps = [
@@ -107,7 +107,7 @@
],
)

-py_binary(
+py_library(
name = "mnist",
srcs = ["datasets/mnist.py"],
deps = [
@@ -715,8 +715,8 @@ py_binary(
],
)

-py_binary(
-    name = "eval_image_classifier",
+py_library(
+    name = "eval_image_classifier_lib",
srcs = ["eval_image_classifier.py"],
deps = [
":dataset_factory",
@@ -726,6 +726,14 @@
],
)

+py_binary(
+    name = "eval_image_classifier",
+    srcs = ["eval_image_classifier.py"],
+    deps = [
+        ":eval_image_classifier_lib",
+    ],
+)

py_binary(
name = "export_inference_graph",
srcs = ["export_inference_graph.py"],
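Note on the BUILD changes above: the recurring py_binary → py_library edit makes the dataset modules and the eval script importable as ordinary Bazel Python libraries, with a thin py_binary kept (or re-added) for command-line use; library targets, unlike binaries, are meant to appear in other rules' deps. A hypothetical downstream target, not part of this PR:

py_binary(
    name = "quant_eval_wrapper",            # hypothetical target
    srcs = ["quant_eval_wrapper.py"],       # hypothetical wrapper script
    deps = [":eval_image_classifier_lib"],  # depend on the new library
)
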
6 changes: 6 additions & 0 deletions research/slim/eval_image_classifier.py
@@ -79,6 +79,9 @@
tf.app.flags.DEFINE_integer(
'eval_image_size', None, 'Eval image size')

+tf.app.flags.DEFINE_bool(
+    'quantize', False, 'whether to use quantized graph or not.')

FLAGS = tf.app.flags.FLAGS


@@ -138,6 +141,9 @@ def main(_):
####################
logits, _ = network_fn(images)

+    if FLAGS.quantize:
+      tf.contrib.quantize.create_eval_graph()

if FLAGS.moving_average_decay:
variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, tf_global_step)
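For context on the quantize flag: tf.contrib.quantize.create_eval_graph() (TF 1.x contrib) rewrites the current default graph in place, inserting fake-quantization ops that match the ones added during training, and it must be called after the forward graph is built but before the saver and session are created. A minimal sketch under those assumptions, with a single conv layer standing in for network_fn:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  images = tf.placeholder(tf.float32, [1, 224, 224, 3])
  logits = tf.layers.conv2d(images, 10, 1)  # stand-in for network_fn(images)
  tf.contrib.quantize.create_eval_graph()   # rewrite before creating the saver
  saver = tf.train.Saver()
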
7 changes: 7 additions & 0 deletions research/slim/export_inference_graph.py
@@ -96,6 +96,9 @@ def with the variables inlined as constants using:
tf.app.flags.DEFINE_string(
'dataset_dir', '', 'Directory to save intermediate dataset files to')

+tf.app.flags.DEFINE_bool(
+    'quantize', False, 'whether to use quantized graph or not.')

FLAGS = tf.app.flags.FLAGS


@@ -115,6 +118,10 @@ def main(_):
shape=[FLAGS.batch_size, image_size,
image_size, 3])
network_fn(placeholder)

+    if FLAGS.quantize:
+      tf.contrib.quantize.create_eval_graph()

graph_def = graph.as_graph_def()
with gfile.GFile(FLAGS.output_file, 'wb') as f:
f.write(graph_def.SerializeToString())
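The exported file is a serialized GraphDef that now includes the fake-quant ops. A hedged sketch of loading it back for inspection (the path is hypothetical):

import tensorflow as tf
from tensorflow.python.platform import gfile

graph_def = tf.GraphDef()
with gfile.GFile('/tmp/inference_graph.pb', 'rb') as f:  # hypothetical path
  graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
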
11 changes: 11 additions & 0 deletions research/slim/nets/mobilenet/mobilenet_v2.py
@@ -91,6 +91,7 @@ def mobilenet(input_tensor,
finegrain_classification_mode=False,
min_depth=None,
divisible_by=None,
+              activation_fn=None,
**kwargs):
"""Creates mobilenet V2 network.

@@ -117,6 +118,8 @@
many channels after application of depth multiplier.
divisible_by: If provided will ensure that all layers # channels
will be divisible by this number.
+    activation_fn: Activation function to use, defaults to tf.nn.relu6 if not
+      specified.
**kwargs: passed directly to mobilenet.mobilenet:
prediction_fn- what prediction function to use.
reuse-: whether to reuse variables (if reuse set to true, scope
@@ -136,6 +139,12 @@
conv_defs = copy.deepcopy(conv_defs)
if depth_multiplier < 1:
conv_defs['spec'][-1].params['num_outputs'] /= depth_multiplier
+  if activation_fn:
+    conv_defs = copy.deepcopy(conv_defs)
+    defaults = conv_defs['defaults']
+    conv_defaults = (
+        defaults[(slim.conv2d, slim.fully_connected, slim.separable_conv2d)])
+    conv_defaults['activation_fn'] = activation_fn

depth_args = {}
# NB: do not set depth_args unless they are provided to avoid overriding
@@ -154,6 +163,8 @@
multiplier=depth_multiplier,
**kwargs)

+mobilenet.default_image_size = 224


def wrapped_partial(func, *args, **kwargs):
partial_func = functools.partial(func, *args, **kwargs)
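A hedged usage sketch of the new activation_fn argument (import path assumed from the slim layout); any unary activation can replace the tf.nn.relu6 default:

import tensorflow as tf
from nets.mobilenet import mobilenet_v2

images = tf.placeholder(tf.float32, [1, 224, 224, 3])
# Use plain relu instead of the relu6 default.
logits, end_points = mobilenet_v2.mobilenet(
    images, num_classes=1001, activation_fn=tf.nn.relu)
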
6 changes: 3 additions & 3 deletions research/slim/nets/mobilenet_v1.py
@@ -120,8 +120,8 @@
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])

-# _CONV_DEFS specifies the MobileNet body
-_CONV_DEFS = [
+# MOBILENETV1_CONV_DEFS specifies the MobileNet body
+MOBILENETV1_CONV_DEFS = [
Conv(kernel=[3, 3], stride=2, depth=32),
DepthSepConv(kernel=[3, 3], stride=1, depth=64),
DepthSepConv(kernel=[3, 3], stride=2, depth=128),
@@ -221,7 +221,7 @@ def mobilenet_v1_base(inputs,
raise ValueError('depth_multiplier is not greater than zero.')

if conv_defs is None:
-    conv_defs = _CONV_DEFS
+    conv_defs = MOBILENETV1_CONV_DEFS

if output_stride is not None and output_stride not in [8, 16, 32]:
raise ValueError('Only allowed output_stride values are 8, 16, 32.')
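Renaming _CONV_DEFS to MOBILENETV1_CONV_DEFS promotes the architecture spec to the public API. A hedged sketch of the customization this enables, truncating the network to six layers purely for illustration:

import copy
import tensorflow as tf
from nets import mobilenet_v1

# Copy the public spec and keep only the first six entries (illustrative).
conv_defs = copy.deepcopy(mobilenet_v1.MOBILENETV1_CONV_DEFS)[:6]
images = tf.placeholder(tf.float32, [1, 224, 224, 3])
net, end_points = mobilenet_v1.mobilenet_v1_base(
    images, final_endpoint='Conv2d_5_pointwise', conv_defs=conv_defs)
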
32 changes: 21 additions & 11 deletions research/slim/nets/nasnet/nasnet.py
@@ -245,7 +245,7 @@ def _build_aux_head(net, end_points, num_classes, hparams, scope):
end_points['AuxLogits'] = aux_logits


-def _imagenet_stem(inputs, hparams, stem_cell):
+def _imagenet_stem(inputs, hparams, stem_cell, current_step=None):
"""Stem used for models trained on ImageNet."""
num_stem_cells = 2

@@ -266,7 +266,8 @@ def _imagenet_stem(inputs, hparams, stem_cell):
filter_scaling=filter_scaling,
stride=2,
prev_layer=cell_outputs[-2],
-        cell_num=cell_num)
+        cell_num=cell_num,
+        current_step=current_step)
cell_outputs.append(net)
filter_scaling *= hparams.filter_scaling_rate
return net, cell_outputs
@@ -286,7 +287,8 @@

def build_nasnet_cifar(images, num_classes,
is_training=True,
-                       config=None):
+                       config=None,
+                       current_step=None):
"""Build NASNet model for the Cifar Dataset."""
hparams = cifar_config() if config is None else copy.deepcopy(config)
_update_hparams(hparams, is_training)
@@ -326,14 +328,16 @@ def build_nasnet_cifar(images, num_classes,
num_classes=num_classes,
hparams=hparams,
is_training=is_training,
-                            stem_type='cifar')
+                            stem_type='cifar',
+                            current_step=current_step)
build_nasnet_cifar.default_image_size = 32


def build_nasnet_mobile(images, num_classes,
is_training=True,
final_endpoint=None,
-                        config=None):
+                        config=None,
+                        current_step=None):
"""Build NASNet Mobile model for the ImageNet Dataset."""
hparams = (mobile_imagenet_config() if config is None
else copy.deepcopy(config))
@@ -377,14 +381,16 @@
hparams=hparams,
is_training=is_training,
stem_type='imagenet',
-                            final_endpoint=final_endpoint)
+                            final_endpoint=final_endpoint,
+                            current_step=current_step)
build_nasnet_mobile.default_image_size = 224


def build_nasnet_large(images, num_classes,
is_training=True,
final_endpoint=None,
-                       config=None):
+                       config=None,
+                       current_step=None):
"""Build NASNet Large model for the ImageNet Dataset."""
hparams = (large_imagenet_config() if config is None
else copy.deepcopy(config))
@@ -428,7 +434,8 @@
hparams=hparams,
is_training=is_training,
stem_type='imagenet',
-                            final_endpoint=final_endpoint)
+                            final_endpoint=final_endpoint,
+                            current_step=current_step)
build_nasnet_large.default_image_size = 331


@@ -439,7 +446,8 @@ def _build_nasnet_base(images,
hparams,
is_training,
stem_type,
-                       final_endpoint=None):
+                       final_endpoint=None,
+                       current_step=None):
"""Constructs a NASNet image model."""

end_points = {}
@@ -482,7 +490,8 @@ def add_and_check_endpoint(endpoint_name, net):
filter_scaling=filter_scaling,
stride=2,
prev_layer=cell_outputs[-2],
-          cell_num=true_cell_num)
+          cell_num=true_cell_num,
+          current_step=current_step)
if add_and_check_endpoint(
'Reduction_Cell_{}'.format(reduction_indices.index(cell_num)), net):
return net, end_points
@@ -496,7 +505,8 @@ def add_and_check_endpoint(endpoint_name, net):
filter_scaling=filter_scaling,
stride=stride,
prev_layer=prev_layer,
-          cell_num=true_cell_num)
+          cell_num=true_cell_num,
+          current_step=current_step)

if add_and_check_endpoint('Cell_{}'.format(cell_num), net):
return net, end_points
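Threading current_step through the NASNet builders lets callers supply the step tensor used for drop path decay explicitly, instead of the model implicitly reading the default global step. A hedged usage sketch for the mobile variant (slim alias as used elsewhere in this package):

import tensorflow as tf
from nets.nasnet import nasnet

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [8, 224, 224, 3])
step = tf.train.get_or_create_global_step()
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  logits, end_points = nasnet.build_nasnet_mobile(
      images, num_classes=1001, current_step=step)
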
18 changes: 18 additions & 0 deletions research/slim/nets/nasnet/nasnet_test.py
@@ -371,6 +371,24 @@ def testOverrideHParamsLargeModel(self):
self.assertListEqual(
end_points['Stem'].shape.as_list(), [batch_size, 336, 42, 42])

+  def testCurrentStepCifarModel(self):
+    batch_size = 5
+    height, width = 32, 32
+    num_classes = 10
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    global_step = tf.train.create_global_step()
+    with slim.arg_scope(nasnet.nasnet_cifar_arg_scope()):
+      logits, end_points = nasnet.build_nasnet_cifar(inputs,
+                                                     num_classes,
+                                                     current_step=global_step)
+    auxlogits = end_points['AuxLogits']
+    predictions = end_points['Predictions']
+    self.assertListEqual(auxlogits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertListEqual(predictions.get_shape().as_list(),
+                         [batch_size, num_classes])

if __name__ == '__main__':
tf.test.main()
18 changes: 10 additions & 8 deletions research/slim/nets/nasnet/nasnet_utils.py
@@ -300,7 +300,7 @@ def _cell_base(self, net, prev_layer):
return net

def __call__(self, net, scope=None, filter_scaling=1, stride=1,
-               prev_layer=None, cell_num=-1):
+               prev_layer=None, cell_num=-1, current_step=None):
"""Runs the conv cell."""
self._cell_num = cell_num
self._filter_scaling = filter_scaling
@@ -325,10 +325,12 @@ def __call__(self, net, scope=None, filter_scaling=1, stride=1,
# Apply conv operations
with tf.variable_scope('left'):
h1 = self._apply_conv_operation(h1, operation_left,
-                                        stride, original_input_left)
+                                        stride, original_input_left,
+                                        current_step)
with tf.variable_scope('right'):
h2 = self._apply_conv_operation(h2, operation_right,
-                                        stride, original_input_right)
+                                        stride, original_input_right,
+                                        current_step)

# Combine hidden states using 'add'.
with tf.variable_scope('combine'):
@@ -343,7 +345,7 @@ def __call__(self, net, scope=None, filter_scaling=1, stride=1,
return net

def _apply_conv_operation(self, net, operation,
-                            stride, is_from_original_input):
+                            stride, is_from_original_input, current_step):
"""Applies the predicted conv operation to net."""
# Dont stride if this is not one of the original hiddenstates
if stride > 1 and not is_from_original_input:
@@ -367,7 +369,7 @@ def _apply_conv_operation(self, net,
raise ValueError('Unimplemented operation', operation)

if operation != 'none':
-      net = self._apply_drop_path(net)
+      net = self._apply_drop_path(net, current_step=current_step)
return net

def _combine_unused_states(self, net):
@@ -433,9 +435,9 @@ def _apply_drop_path(self, net, current_step=None,
drop_path_keep_prob = 1 - layer_ratio * (1 - drop_path_keep_prob)
if drop_connect_version in ['v1', 'v3']:
# Decrease the keep probability over time
-        if not current_step:
-          current_step = tf.cast(tf.train.get_or_create_global_step(),
-                                 tf.float32)
+        if current_step is None:
+          current_step = tf.train.get_or_create_global_step()
+        current_step = tf.cast(current_step, tf.float32)
drop_path_burn_in_steps = self._total_training_steps
current_ratio = current_step / drop_path_burn_in_steps
current_ratio = tf.minimum(1.0, current_ratio)
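The last hunk above also fixes a latent bug: "if not current_step:" would evaluate a passed-in tensor as a Python bool (which TensorFlow forbids) and would wrongly re-create the step for a value of 0, whereas "if current_step is None:" only falls back when no step was supplied. For reference, the keep-probability schedule the surrounding code computes, extracted as a standalone sketch:

import tensorflow as tf

def drop_path_keep_prob_at(final_keep_prob, current_step, total_training_steps):
  # Anneal keep_prob linearly from 1.0 down to final_keep_prob over training.
  current_ratio = tf.cast(current_step, tf.float32) / total_training_steps
  current_ratio = tf.minimum(1.0, current_ratio)
  return 1.0 - current_ratio * (1.0 - final_keep_prob)
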
9 changes: 9 additions & 0 deletions research/slim/train_image_classifier.py
@@ -124,6 +124,11 @@

tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

+tf.app.flags.DEFINE_integer(
+    'quantize_delay', -1,
+    'Number of steps to wait before starting quantized training. Set to -1 '
+    'to disable quantized training.')

#######################
# Learning Rate Flags #
#######################
@@ -511,6 +516,10 @@ def clone_fn(batch_queue):
else:
moving_average_variables, variable_averages = None, None

+    if FLAGS.quantize_delay >= 0:
+      tf.contrib.quantize.create_training_graph(
+          quant_delay=FLAGS.quantize_delay)

#########################################
# Configure the optimization procedure. #
#########################################
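create_training_graph is the training-time counterpart of create_eval_graph: it inserts fake-quant ops into the current default graph and begins collecting quantization range statistics only after quant_delay steps, letting the float model stabilize first; it must run before the optimizer's gradient ops are added, which matches its placement in the diff above. A minimal sketch, with one conv layer standing in for the real network:

import tensorflow as tf

images = tf.placeholder(tf.float32, [8, 224, 224, 3])
labels = tf.placeholder(tf.float32, [8, 10])
logits = tf.reduce_mean(tf.layers.conv2d(images, 10, 3), [1, 2])
loss = tf.losses.softmax_cross_entropy(labels, logits)
# Rewrite the graph for quantized training before building the train op.
tf.contrib.quantize.create_training_graph(quant_delay=20000)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
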