Skip to content

Commit 27224ce

Browse files
committed
Support `data_format` ('channels_last' / 'channels_first') in GroupNormLayer
1 parent be188e3 commit 27224ce

File tree

2 files changed

+38
-13
lines changed

2 files changed

+38
-13
lines changed

tensorlayer/layers/normalization.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tensorlayer.layers.core import Layer
88
from tensorlayer.layers.core import LayersConfig
99
from tensorlayer.layers.core import TF_GRAPHKEYS_VARIABLES
10+
from tensorlayer.layers.utils import get_collection_trainable
1011

1112
from tensorlayer import logging
1213

@@ -323,36 +324,57 @@ class GroupNormLayer(Layer):
323324
"""
324325

325326
@deprecated_alias(layer='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
326-
def __init__(self, prev_layer, groups=32, epsilon=1e-06, act=None, name='groupnorm'):
327+
def __init__(self, prev_layer, groups=32, epsilon=1e-06, act=None, data_format='channels_last', name='groupnorm'):
327328
super(GroupNormLayer, self).__init__(prev_layer=prev_layer, act=act, name=name)
328329

329330
logging.info(
330331
"GroupNormLayer %s: act: %s" % (self.name, self.act.__name__ if self.act is not None else 'No Activation')
331332
)
332333

333-
channels = self.inputs.get_shape().as_list()[-1]
334-
if groups > channels:
335-
raise ValueError('Invalid groups %d for %d channels.' % (groups, channels))
336-
if channels % groups != 0:
337-
raise ValueError('%d channels is not commensurate with %d groups.' % (channels, groups))
334+
shape = self.inputs.get_shape().as_list()
335+
if len(shape) != 4:
336+
raise Exception("GroupNormLayer only supports 2D images.")
338337

339-
with tf.variable_scope(name) as vs:
338+
if data_format == 'channels_last':
339+
channels = shape[-1]
340340
int_shape = tf.concat(
341341
[tf.shape(self.inputs)[0:3],
342342
tf.convert_to_tensor([groups, channels // groups])], axis=0
343343
)
344+
elif data_format == 'channels_first':
345+
channels = shape[1]
346+
int_shape = tf.concat(
347+
[
348+
tf.shape(self.inputs)[0:1],
349+
tf.convert_to_tensor([groups, channels // groups]),
350+
tf.shape(self.inputs)[2:4]
351+
], axis=0
352+
)
353+
else:
354+
raise ValueError("data_format must be 'channels_last' or 'channels_first'.")
344355

356+
if groups > channels:
357+
raise ValueError('Invalid groups %d for %d channels.' % (groups, channels))
358+
if channels % groups != 0:
359+
raise ValueError('%d channels is not commensurate with %d groups.' % (channels, groups))
360+
361+
with tf.variable_scope(name):
345362
x = tf.reshape(self.inputs, int_shape)
346-
mean, var = tf.nn.moments(x, [1, 2, 4], keep_dims=True)
347-
x = (x - mean) / tf.sqrt(var + epsilon)
363+
if data_format == 'channels_last':
364+
mean, var = tf.nn.moments(x, [1, 2, 4], keep_dims=True)
365+
gamma = tf.get_variable('gamma', channels, initializer=tf.ones_initializer())
366+
beta = tf.get_variable('beta', channels, initializer=tf.zeros_initializer())
367+
else:
368+
mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
369+
gamma = tf.get_variable('gamma', [1, channels, 1, 1], initializer=tf.ones_initializer())
370+
beta = tf.get_variable('beta', [1, channels, 1, 1], initializer=tf.zeros_initializer())
348371

349-
gamma = tf.get_variable('gamma', channels, initializer=tf.ones_initializer())
350-
beta = tf.get_variable('beta', channels, initializer=tf.zeros_initializer())
372+
x = (x - mean) / tf.sqrt(var + epsilon)
351373

352374
self.outputs = tf.reshape(x, tf.shape(self.inputs)) * gamma + beta
353375
self.outputs = self._apply_activation(self.outputs)
354376

355-
variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
377+
variables = get_collection_trainable(self.name)
356378

357379
self._add_layers(self.outputs)
358380
self._add_params(variables)

tests/test_layers_normalization.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ def model(x, is_train=True, reuse=False):
2121
n = tl.layers.LocalResponseNormLayer(n, name='norm_local')
2222
n = tl.layers.LayerNormLayer(n, reuse=reuse, name='norm_layer')
2323
n = tl.layers.InstanceNormLayer(n, name='norm_instance')
24-
n = tl.layers.GroupNormLayer(n, groups=40, name='groupnorm')
24+
# n = tl.layers.GroupNormLayer(n, groups=40, name='groupnorm')
25+
n.outputs = tf.reshape(n.outputs, [-1, 80, 100, 100])
26+
n = tl.layers.GroupNormLayer(n, groups=40, data_format='channels_first', name='groupnorm')
27+
n.outputs = tf.reshape(n.outputs, [-1, 100, 100, 80])
2528
n = tl.layers.SwitchNormLayer(n, name='switchnorm')
2629
n = tl.layers.QuanConv2dWithBN(n, n_filter=3, is_train=is_train, name='quan_cnn_with_bn')
2730
n = tl.layers.FlattenLayer(n, name='flatten')

0 commit comments

Comments (0)