add BNReLU as a nonlin
ppwwyyxx committed Apr 11, 2016
1 parent 7d9582a commit fec3a4a
Showing 2 changed files with 17 additions and 10 deletions.
12 changes: 3 additions & 9 deletions examples/cifar10_convnet.py
@@ -46,21 +46,15 @@ def _get_cost(self, input_vars, is_training):

         image = image / 4.0    # just to make range smaller
         l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
-        l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=tf.identity)
-        l = BatchNorm('bn1', l, is_training)
-        l = tf.nn.relu(l)
+        l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training))
         l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')

         l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3)
-        l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=tf.identity)
-        l = BatchNorm('bn2', l, is_training)
-        l = tf.nn.relu(l)
+        l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training))
         l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')

         l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3, padding='VALID')
-        l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=tf.identity)
-        l = BatchNorm('bn3', l, is_training)
-        l = tf.nn.relu(l)
+        l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training))
         l = FullyConnected('fc0', l, 1024 + 512,
                            b_init=tf.constant_initializer(0.1))
         l = tf.nn.dropout(l, keep_prob)
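For context: in each of the three blocks above, the convolution used to be built with nl=tf.identity and then followed by an explicit BatchNorm layer and tf.nn.relu call; passing nl=BNReLU(is_training) folds that same BN + ReLU pair into the Conv2D call. A minimal sketch of the equivalence, reusing the layer names from the diff (illustrative only, not the exact Conv2D internals):

    # before: linear conv, then explicit BatchNorm and ReLU
    l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=tf.identity)
    l = BatchNorm('bn1', l, is_training)
    l = tf.nn.relu(l)

    # after: the same three steps, expressed through the nonlinearity argument
    l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training))
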
15 changes: 14 additions & 1 deletion tensorpack/models/nonlin.py
@@ -7,8 +7,9 @@
 from copy import copy

 from ._common import *
+from .batch_norm import BatchNorm

-__all__ = ['Maxout', 'PReLU', 'LeakyReLU']
+__all__ = ['Maxout', 'PReLU', 'LeakyReLU', 'BNReLU']

 @layer_register()
 def Maxout(x, num_unit):
@@ -59,3 +60,15 @@ def LeakyReLU(x, alpha, name=None):
         return x * 0.5
     else:
         return tf.mul(x, 0.5, name=name)
+
+
+def BNReLU(is_training):
+    """
+    :returns: an activation function that performs BN + ReLU (a too common combination)
+    """
+    def f(x, name=None):
+        with tf.variable_scope('bn'):
+            x = BatchNorm.f(x, is_training)
+        x = tf.nn.relu(x, name=name)
+        return x
+    return f
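
BNReLU is a factory: calling it with is_training returns a closure f(x, name=None) that applies BatchNorm inside a 'bn' variable scope and then ReLU, which matches the (tensor, name) signature tensorpack layers expect for their nl argument. A minimal usage sketch (the input tensor and names here are hypothetical, for illustration only):

    import tensorflow as tf

    is_training = True
    x = tf.placeholder(tf.float32, [None, 32, 32, 64], name='feature_map')  # hypothetical feature map

    act = BNReLU(is_training)      # build the BN + ReLU activation once
    y = act(x, name='output')      # BatchNorm under scope 'bn', then ReLU named 'output'

    # or pass it straight to a layer, as in the CIFAR10 example above:
    # l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training))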
