Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions slim/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -164,25 +164,50 @@ py_library(
":inception_v1",
":inception_v2",
":inception_v3",
":inception_v4",
],
)

py_library(
name = "inception_utils",
srcs = ["nets/inception_utils.py"],
srcs_version = "PY2AND3",
)

py_library(
name = "inception_v1",
srcs = ["nets/inception_v1.py"],
srcs_version = "PY2AND3",
deps = [
":inception_utils",
],
)

py_library(
name = "inception_v2",
srcs = ["nets/inception_v2.py"],
srcs_version = "PY2AND3",
deps = [
":inception_utils",
],
)

py_library(
name = "inception_v3",
srcs = ["nets/inception_v3.py"],
srcs_version = "PY2AND3",
deps = [
":inception_utils",
],
)

py_library(
name = "inception_v4",
srcs = ["nets/inception_v4.py"],
srcs_version = "PY2AND3",
deps = [
":inception_utils",
],
)

py_library(
Expand Down Expand Up @@ -218,6 +243,15 @@ py_test(
deps = [":inception"],
)

py_test(
name = "inception_v4_test",
size = "large",
srcs = ["nets/inception_v4_test.py"],
shard_count = 3,
srcs_version = "PY2AND3",
deps = [":inception"],
)

py_test(
name = "inception_resnet_v2_test",
size = "large",
Expand Down
7 changes: 4 additions & 3 deletions slim/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,10 @@ crops at multiple scales.

Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
:----:|:------------:|:----------:|:-------:|:--------:|
[Inception V1](http://arxiv.org/abs/1409.4842v1)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v1.py)|[inception_v1.tar.gz](http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz)|69.8|89.6|
[Inception V2](http://arxiv.org/abs/1502.03167)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v2.py)|[inception_v2.tar.gz](http://download.tensorflow.org/models/inception_v2_2016_08_28.tar.gz)|73.9|91.8|
[Inception V3](http://arxiv.org/abs/1512.00567)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v3.py)|[inception_v3.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)|78.0|93.9|
[Inception V1](http://arxiv.org/abs/1409.4842v1)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_v1.py)|[inception_v1_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz)|69.8|89.6|
[Inception V2](http://arxiv.org/abs/1502.03167)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_v2.py)|[inception_v2_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v2_2016_08_28.tar.gz)|73.9|91.8|
[Inception V3](http://arxiv.org/abs/1512.00567)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_v3.py)|[inception_v3_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)|78.0|93.9|
[Inception V4](http://arxiv.org/abs/1602.07261)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_v4.py)|[inception_v4_2016_09_09.tar.gz](http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz)|80.2|95.2|
[Inception-ResNet-v2](http://arxiv.org/abs/1602.07261)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_resnet_v2.py)|[inception_resnet_v2.tar.gz](http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz)|80.4|95.3|
[ResNet 50](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_50.tar.gz](http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz)|75.2|92.2|
[ResNet 101](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_101.tar.gz](http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz)|76.4|92.9|
Expand Down
5 changes: 4 additions & 1 deletion slim/nets/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Brings inception_v1, inception_v2 and inception_v3 under one namespace."""
"""Brings all inception models under one namespace."""

from __future__ import absolute_import
from __future__ import division
Expand All @@ -30,4 +30,7 @@
from nets.inception_v3 import inception_v3
from nets.inception_v3 import inception_v3_arg_scope
from nets.inception_v3 import inception_v3_base
from nets.inception_v4 import inception_v4
from nets.inception_v4 import inception_v4_arg_scope
from nets.inception_v4 import inception_v4_base
# pylint: enable=unused-import
71 changes: 71 additions & 0 deletions slim/nets/inception_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common code shared by all inception models.

Usage of arg scope:
with slim.arg_scope(inception_arg_scope()):
logits, end_points = inception.inception_v3(images, num_classes,
is_training=is_training)

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

slim = tf.contrib.slim


def inception_arg_scope(weight_decay=0.00004,
use_batch_norm=True,
batch_norm_decay=0.9997,
batch_norm_epsilon=0.001):
"""Defines the default arg scope for inception models.

Args:
weight_decay: The weight decay to use for regularizing the model.
use_batch_norm: "If `True`, batch_norm is applied after each convolution.
batch_norm_decay: Decay for batch norm moving average.
batch_norm_epsilon: Small float added to variance to avoid dividing by zero
in batch norm.

Returns:
An `arg_scope` to use for the inception models.
"""
batch_norm_params = {
# Decay for the moving averages.
'decay': batch_norm_decay,
# epsilon to prevent 0s in variance.
'epsilon': batch_norm_epsilon,
# collection containing update_ops.
'updates_collections': tf.GraphKeys.UPDATE_OPS,
}
if use_batch_norm:
normalizer_fn = slim.batch_norm
normalizer_params = batch_norm_params
else:
normalizer_fn = None
normalizer_params = {}
# Set weight_decay for weights in Conv and FC layers.
with slim.arg_scope([slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay)):
with slim.arg_scope(
[slim.conv2d],
weights_initializer=slim.variance_scaling_initializer(),
activation_fn=tf.nn.relu,
normalizer_fn=normalizer_fn,
normalizer_params=normalizer_params) as sc:
return sc
41 changes: 3 additions & 38 deletions slim/nets/inception_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import tensorflow as tf

from nets import inception_utils

slim = tf.contrib.slim
trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)

Expand Down Expand Up @@ -300,41 +302,4 @@ def inception_v1(inputs,
return logits, end_points
inception_v1.default_image_size = 224


def inception_v1_arg_scope(weight_decay=0.00004,
use_batch_norm=True):
"""Defines the default InceptionV1 arg scope.

Note: Althougth the original paper didn't use batch_norm we found it useful.

Args:
weight_decay: The weight decay to use for regularizing the model.
use_batch_norm: "If `True`, batch_norm is applied after each convolution.

Returns:
An `arg_scope` to use for the inception v3 model.
"""
batch_norm_params = {
# Decay for the moving averages.
'decay': 0.9997,
# epsilon to prevent 0s in variance.
'epsilon': 0.001,
# collection containing update_ops.
'updates_collections': tf.GraphKeys.UPDATE_OPS,
}
if use_batch_norm:
normalizer_fn = slim.batch_norm
normalizer_params = batch_norm_params
else:
normalizer_fn = None
normalizer_params = {}
# Set weight_decay for weights in Conv and FC layers.
with slim.arg_scope([slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay)):
with slim.arg_scope(
[slim.conv2d],
weights_initializer=slim.variance_scaling_initializer(),
activation_fn=tf.nn.relu,
normalizer_fn=normalizer_fn,
normalizer_params=normalizer_params) as sc:
return sc
inception_v1_arg_scope = inception_utils.inception_arg_scope
31 changes: 3 additions & 28 deletions slim/nets/inception_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import tensorflow as tf

from nets import inception_utils

slim = tf.contrib.slim
trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)

Expand Down Expand Up @@ -515,31 +517,4 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
return kernel_size_out


def inception_v2_arg_scope(weight_decay=0.00004):
"""Defines the default InceptionV2 arg scope.

Args:
weight_decay: The weight decay to use for regularizing the model.

Returns:
An `arg_scope` to use for the inception v3 model.
"""
batch_norm_params = {
# Decay for the moving averages.
'decay': 0.9997,
# epsilon to prevent 0s in variance.
'epsilon': 0.001,
# collection containing update_ops.
'updates_collections': tf.GraphKeys.UPDATE_OPS,
}

# Set weight_decay for weights in Conv and FC layers.
with slim.arg_scope([slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay)):
with slim.arg_scope(
[slim.conv2d],
weights_initializer=slim.variance_scaling_initializer(),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params) as sc:
return sc
inception_v2_arg_scope = inception_utils.inception_arg_scope
33 changes: 3 additions & 30 deletions slim/nets/inception_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import tensorflow as tf

from nets import inception_utils

slim = tf.contrib.slim
trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)

Expand Down Expand Up @@ -555,33 +557,4 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
return kernel_size_out


def inception_v3_arg_scope(weight_decay=0.00004,
stddev=0.1):
"""Defines the default InceptionV3 arg scope.

Args:
weight_decay: The weight decay to use for regularizing the model.
stddev: The standard deviation of the trunctated normal weight initializer.

Returns:
An `arg_scope` to use for the inception v3 model.
"""
batch_norm_params = {
# Decay for the moving averages.
'decay': 0.9997,
# epsilon to prevent 0s in variance.
'epsilon': 0.001,
# collection containing update_ops.
'updates_collections': tf.GraphKeys.UPDATE_OPS,
}

# Set weight_decay for weights in Conv and FC layers.
with slim.arg_scope([slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay)):
with slim.arg_scope(
[slim.conv2d],
weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params) as sc:
return sc
inception_v3_arg_scope = inception_utils.inception_arg_scope
Loading