140 changes: 140 additions & 0 deletions slim/nets/densenet.py
@@ -0,0 +1,140 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
DenseNet from arXiv:1608.06993v3
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from nets import densenet_utils

slim = tf.contrib.slim
densenet_arg_scope = densenet_utils.densenet_arg_scope
DenseBlock = densenet_utils.DenseBlock
TransitionLayer = densenet_utils.TransitionLayer



def densenet(
inputs,
num_classes=1000,
n_filters_first_conv=16,
n_dense=4,
growth_rate=12,
n_layers_per_block=[6, 12, 24, 16],
dropout_p=0.2,
bottleneck=False,
compression=1.0,
is_training=False,
dense_prediction=False,
reuse=None,
scope=None):
"""
DenseNet as described for ImageNet use. Supports B (bottleneck) and
C (compression) variants.
Args:
n_classes: number of classes
n_filters_first_conv: number of filters for the first convolution applied
n_dense: number of dense_blocks
growth_rate: number of new feature maps created by each layer in a dense block
n_layers_per_block: number of layers per block. Can be an int or a list of size 2 * n_dense + 1
dropout_p: dropout rate applied after each convolution (0. for not using)
is_training: whether is training or not.
dense_prediction: Bool, defaults to False
reuse: whether or not the network and its variables should be reused. To be
able to reuse 'scope' must be given.
scope: Optional variable_scope.
Returns:
net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
if
end_points: A dictionary from components of the network to the corresponding
activation.
"""
  # Check the n_layers_per_block argument.
  if isinstance(n_layers_per_block, list):
    assert len(n_layers_per_block) == n_dense
  elif isinstance(n_layers_per_block, int):
    n_layers_per_block = [n_layers_per_block] * n_dense
  else:
    raise ValueError('n_layers_per_block must be an int or a list of length '
                     'n_dense, got %r' % (n_layers_per_block,))


with tf.variable_scope(scope, 'densenet', [inputs], reuse=reuse) as sc:
end_points_collection = sc.name + '_end_points'
with slim.arg_scope([slim.conv2d, DenseBlock, TransitionLayer],
outputs_collections=end_points_collection):
with slim.arg_scope([slim.batch_norm, slim.dropout],
is_training=is_training):

#####################
# First Convolution #
#####################
# We perform a first convolution.
# If DenseNet BC, first convolution has 2*growth_rate output channels
if bottleneck and compression < 1.0:
n_filters_first_conv = 2 * growth_rate


I think this behavior should somehow be mentioned in the densenet method documentation.

        net = slim.conv2d(inputs, n_filters_first_conv, [7, 7],
                          stride=[2, 2], scope='first_conv')
        net = slim.pool(net, [2, 2], stride=[2, 2], pooling_type='MAX')
n_filters = n_filters_first_conv

#####################
# Dense blocks #
#####################

for i in range(n_dense-1):
# Dense Block
net, _ = DenseBlock(net, n_layers_per_block[i],
growth_rate, dropout_p,
bottleneck=bottleneck,
scope='denseblock%d' % (i+1))
n_filters += n_layers_per_block[i] * growth_rate

# Transition layer
net = TransitionLayer(net, n_filters, dropout_p,
compression=compression,
scope='transition%d'%(i+1))


# Final dense block (no transition layer afterwards)
        net, _ = DenseBlock(net, n_layers_per_block[n_dense-1],
                            growth_rate, dropout_p,
                            bottleneck=bottleneck,
                            scope='denseblock%d' % n_dense)

#####################
# Outputs #
#####################
pool_name = 'pool%d' % (n_dense + 1)
if dense_prediction:
net = slim.pool(net, [7, 7], pooling_type='AVG', scope=pool_name)


What if we don't have a 224x224 input? From my point of view, we should use dynamic average pooling based on the input shape.
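One way to address this (a sketch only, not part of this PR; the helper name is made up): derive the pooling extent from the feature map itself, and fall back to a spatial reduce_mean when the static shape is unknown.

def global_avg_pool(net):
  # Global average pooling that adapts to the feature map size instead of
  # assuming the 7x7 map produced by 224x224 inputs.
  height, width = net.get_shape().as_list()[1:3]
  if height is not None and width is not None:
    # Static spatial dims are known: pool over the full extent.
    return slim.pool(net, [height, width], pooling_type='AVG')
  # Dynamic spatial dims: reduce over the spatial axes instead.
  return tf.reduce_mean(net, [1, 2], keep_dims=True)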

net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='logits')

else:
net = tf.reduce_mean(net, [1, 2], name=pool_name, keep_dims=True)
net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope='4Dlogits')
net = tf.squeeze(net, [1, 2], name='logits')

# Convert end_points_collection into a dictionary of end_points.
end_points = slim.utils.convert_collection_to_dict(
end_points_collection)

end_points['predictions'] = slim.softmax(net, scope='predictions')
return net, end_points
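For reviewers who want to try the PR, a minimal usage sketch (the placeholder shape and the DenseNet-121-style hyperparameters below are illustrative, not part of this change):

import tensorflow as tf
from nets import densenet

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(densenet.densenet_arg_scope()):
  # DenseNet-121-style configuration: growth rate 32, blocks of 6/12/24/16
  # layers, with bottleneck and compression (the BC variant).
  logits, end_points = densenet.densenet(
      images, num_classes=1000, growth_rate=32,
      n_layers_per_block=[6, 12, 24, 16], bottleneck=True,
      compression=0.5, is_training=False)
probabilities = end_points['predictions']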
46 changes: 46 additions & 0 deletions slim/nets/densenet_test.py
@@ -0,0 +1,46 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slim.nets.densenet"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


This is just an unused import.

import tensorflow as tf

from nets import densenet

slim = tf.contrib.slim


class DensenetTest(tf.test.TestCase):

def testBuildClassificationNetwork(self):
batch_size = 5
height, width = 224, 224
num_classes = 1000

inputs = tf.random_uniform((batch_size, height, width, 3))
logits, end_points = densenet.densenet(inputs, num_classes)
self.assertTrue(logits.op.name.startswith('densenet/logits'))
self.assertListEqual(logits.get_shape().as_list(),
[batch_size, num_classes])
self.assertTrue('predictions' in end_points)
self.assertListEqual(end_points['predictions'].get_shape().as_list(),
[batch_size, num_classes])

if __name__ == '__main__':
tf.test.main()
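A further test that could be added to cover the dense_prediction path (a sketch, not part of this PR; the expected 1x1 spatial shape assumes the 224x224 input is reduced to a 7x7 feature map before the final 7x7 pool):

  def testBuildDensePredictionNetwork(self):
    batch_size = 5
    height, width = 224, 224
    num_classes = 1000

    inputs = tf.random_uniform((batch_size, height, width, 3))
    logits, end_points = densenet.densenet(inputs, num_classes,
                                           dense_prediction=True)
    self.assertTrue(logits.op.name.startswith('densenet/logits'))
    self.assertListEqual(logits.get_shape().as_list(),
                         [batch_size, 1, 1, num_classes])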
162 changes: 162 additions & 0 deletions slim/nets/densenet_utils.py
@@ -0,0 +1,162 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Contains blocks for building DenseNet-based models
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


slim = tf.contrib.slim


def densenet_arg_scope(
weight_decay=0.0001,
batch_norm_decay=0.997,
batch_norm_epsilon=1e-5,
batch_norm_scale=True,
activation_fn=tf.nn.relu,
use_batch_norm=True):
"""
Args:
weight_decay: The weight decay to use for regularizing the model.

batch_norm_decay: The moving average decay when estimating layer activation
statistics in batch normalization.

batch_norm_epsilon: Small constant to prevent division by zero when
normalizing activations by their variance in batch normalization.

batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
activations in the batch normalization layer.

    activation_fn: The activation function which is used in DenseNet.

use_batch_norm: Whether or not to use batch normalization.

Returns:
An `arg_scope` to use for the densenet models.
"""
batch_norm_params = {
'decay': batch_norm_decay,
'epsilon': batch_norm_epsilon,
'scale': batch_norm_scale,
'activation_fn': activation_fn,
'updates_collections': tf.GraphKeys.UPDATE_OPS,
}

with slim.arg_scope(
[slim.conv2d],
padding='SAME',
weights_regularizer=slim.l2_regularizer(weight_decay),
weights_initializer=slim.variance_scaling_initializer(),
activation_fn=None,
normalizer_fn=slim.batch_norm if use_batch_norm else None,
normalizer_params=batch_norm_params):
with slim.arg_scope([slim.batch_norm], **batch_norm_params) as arg_sc:
return arg_sc


def preact_conv(inputs, n_filters, filter_size=[3, 3], dropout_p=0.2):
"""
Basic pre-activation layer for DenseNets
Apply successivly BatchNormalization, ReLU nonlinearity, Convolution and
Dropout (if dropout_p > 0) on the inputs
"""
preact = slim.batch_norm(inputs)
conv = slim.conv2d(preact, n_filters, filter_size, normalizer_fn=None)
if dropout_p != 0.0:
conv = slim.dropout(conv, keep_prob=(1.0-dropout_p))
return conv


@slim.add_arg_scope
def DenseBlock(stack, n_layers, growth_rate, dropout_p, bottleneck=False,
scope=None, outputs_collections=None):
"""
DenseBlock for DenseNet and FC-DenseNet

Args:
stack: input 4D tensor
n_layers: number of internal layers
growth_rate: number of feature maps per internal layer

Returns:
stack: current stack of feature maps (4D tensor)
new_features: 4D tensor containing only the new feature maps generated
in this block
"""
with tf.name_scope(scope) as sc:
    new_features = []
    for j in range(n_layers):
      # Compute new feature maps. If bottleneck, apply a 1x1 convolution
      # before the 3x3 one, without overwriting the running stack.
      if bottleneck:
        bottleneck_out = preact_conv(stack, 4 * growth_rate,
                                     filter_size=[1, 1], dropout_p=0.0)
        layer = preact_conv(bottleneck_out, growth_rate, dropout_p=dropout_p)
      else:
        layer = preact_conv(stack, growth_rate, dropout_p=dropout_p)
      new_features.append(layer)
      # Stack the new layer onto the running concatenation of feature maps.
      stack = tf.concat([stack, layer], axis=-1)
new_features = tf.concat(new_features, axis=-1)
return stack, new_features
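The returned channel counts follow directly from the concatenations above; a quick illustrative check (a sketch, not part of the PR):

# A block with 6 layers and growth rate 12 on a 64-channel input yields a
# stack of 64 + 6 * 12 = 136 channels and 6 * 12 = 72 new feature maps.
x = tf.random_uniform((1, 56, 56, 64))
stack, new_features = DenseBlock(x, n_layers=6, growth_rate=12, dropout_p=0.0)
assert stack.get_shape().as_list()[-1] == 64 + 6 * 12
assert new_features.get_shape().as_list()[-1] == 6 * 12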

@slim.add_arg_scope
def TransitionLayer(inputs, n_filters, dropout_p=0.2, compression=1.0,
scope=None, outputs_collections=None):
"""
Transition layer for DenseNet
Apply 1x1 BN + conv then 2x2 max pooling
"""
with tf.name_scope(scope) as sc:
if compression < 1.0:
n_filters = tf.to_int32(tf.floor(n_filters*compression))
l = preact_conv(inputs, n_filters, filter_size=[1, 1], dropout_p=dropout_p)
l = slim.pool(l, [2, 2], stride=[2, 2], pooling_type='AVG')

return l

@slim.add_arg_scope
def TransitionDown(inputs, n_filters, dropout_p=0.2, scope=None,
outputs_collections=None):
"""
Transition Down (TD) for FC-DenseNet
Apply 1x1 BN + ReLU + conv then 2x2 max pooling
"""
with tf.name_scope(scope) as sc:
l = preact_conv(inputs, n_filters, filter_size=[1, 1], dropout_p=dropout_p)
l = slim.pool(l, [2, 2], stride=[2, 2], pooling_type='MAX')
return l


@slim.add_arg_scope
def TransitionUp(block_to_upsample, skip_connection, n_filters_keep,
scope=None, outputs_collections=None):
"""
Transition Up for FC-DenseNet
  Performs upsampling on block_to_upsample by a factor of 2 and concatenates it
with the skip_connection
"""
with tf.name_scope(scope) as sc:
# Upsample
l = slim.conv2d_transpose(block_to_upsample, n_filters_keep,
kernel_size=[3, 3], stride=[2, 2])
# Concatenate with skip connection
l = tf.concat([l, skip_connection], axis=-1)
return l
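To illustrate how the FC-DenseNet utilities fit together (a sketch with made-up shapes and filter counts, not part of this PR): downsample with TransitionDown, keep the pre-pooling tensor as a skip connection, and later upsample the new features and concatenate the skip back in.

inputs = tf.random_uniform((1, 64, 64, 48))
skip = inputs
down = TransitionDown(inputs, n_filters=48)          # 32x32x48
block, new_features = DenseBlock(down, n_layers=4, growth_rate=12,
                                 dropout_p=0.0)      # new_features: 32x32x48
up = TransitionUp(new_features, skip_connection=skip,
                  n_filters_keep=48)                 # 64x64x(48 + 48)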