Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion research/slim/export_inference_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def with the variables inlined as constants using:
'image_size', None,
'The image size to use, otherwise use the model default_image_size.')

tf.app.flags.DEFINE_integer(
'image_channel', 3,
'The image channel to use, otherwise use the model default_image_channel.')

tf.app.flags.DEFINE_integer(
'batch_size', None,
'Batch size for the exported model. Defaulted to "None" so batch size can '
Expand Down Expand Up @@ -114,9 +118,11 @@ def main(_):
num_classes=(dataset.num_classes - FLAGS.labels_offset),
is_training=FLAGS.is_training)
image_size = FLAGS.image_size or network_fn.default_image_size
image_channel = network_fn.default_image_channel \
if hasattr(network_fn, "default_image_channel") else FLAGS.image_channel
placeholder = tf.placeholder(name='input', dtype=tf.float32,
shape=[FLAGS.batch_size, image_size,
image_size, 3])
image_size, image_channel])
network_fn(placeholder)

if FLAGS.quantize:
Expand Down
1 change: 1 addition & 0 deletions research/slim/nets/lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def lenet(images, num_classes=10, is_training=False,

return logits, end_points
lenet.default_image_size = 28
lenet.default_image_channel = 1


def lenet_arg_scope(weight_decay=0.0):
Expand Down