|
7 | 7 | from tensorlayer.layers.core import Layer
|
8 | 8 | from tensorlayer.layers.core import LayersConfig
|
9 | 9 | from tensorlayer.layers.core import TF_GRAPHKEYS_VARIABLES
|
| 10 | +from tensorlayer.layers.utils import get_collection_trainable |
10 | 11 |
|
11 | 12 | from tensorlayer import logging
|
12 | 13 |
|
@@ -323,36 +324,57 @@ class GroupNormLayer(Layer):
|
323 | 324 | """
|
324 | 325 |
|
325 | 326 | @deprecated_alias(layer='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
|
326 |
| - def __init__(self, prev_layer, groups=32, epsilon=1e-06, act=None, name='groupnorm'): |
| 327 | + def __init__(self, prev_layer, groups=32, epsilon=1e-06, act=None, data_format='channels_last', name='groupnorm'): |
327 | 328 | super(GroupNormLayer, self).__init__(prev_layer=prev_layer, act=act, name=name)
|
328 | 329 |
|
329 | 330 | logging.info(
|
330 | 331 | "GroupNormLayer %s: act: %s" % (self.name, self.act.__name__ if self.act is not None else 'No Activation')
|
331 | 332 | )
|
332 | 333 |
|
333 |
| - channels = self.inputs.get_shape().as_list()[-1] |
334 |
| - if groups > channels: |
335 |
| - raise ValueError('Invalid groups %d for %d channels.' % (groups, channels)) |
336 |
| - if channels % groups != 0: |
337 |
| - raise ValueError('%d channels is not commensurate with %d groups.' % (channels, groups)) |
| 334 | + shape = self.inputs.get_shape().as_list() |
| 335 | + if len(shape) != 4: |
| 336 | + raise Exception("GroupNormLayer only supports 2D images.") |
338 | 337 |
|
339 |
| - with tf.variable_scope(name) as vs: |
| 338 | + if data_format == 'channels_last': |
| 339 | + channels = shape[-1] |
340 | 340 | int_shape = tf.concat(
|
341 | 341 | [tf.shape(self.inputs)[0:3],
|
342 | 342 | tf.convert_to_tensor([groups, channels // groups])], axis=0
|
343 | 343 | )
|
| 344 | + elif data_format == 'channels_first': |
| 345 | + channels = shape[1] |
| 346 | + int_shape = tf.concat( |
| 347 | + [ |
| 348 | + tf.shape(self.inputs)[0:1], |
| 349 | + tf.convert_to_tensor([groups, channels // groups]), |
| 350 | + tf.shape(self.inputs)[2:4] |
| 351 | + ], axis=0 |
| 352 | + ) |
| 353 | + else: |
| 354 | + raise ValueError("data_format must be 'channels_last' or 'channels_first'.") |
344 | 355 |
|
| 356 | + if groups > channels: |
| 357 | + raise ValueError('Invalid groups %d for %d channels.' % (groups, channels)) |
| 358 | + if channels % groups != 0: |
| 359 | + raise ValueError('%d channels is not commensurate with %d groups.' % (channels, groups)) |
| 360 | + |
| 361 | + with tf.variable_scope(name): |
345 | 362 | x = tf.reshape(self.inputs, int_shape)
|
346 |
| - mean, var = tf.nn.moments(x, [1, 2, 4], keep_dims=True) |
347 |
| - x = (x - mean) / tf.sqrt(var + epsilon) |
| 363 | + if data_format == 'channels_last': |
| 364 | + mean, var = tf.nn.moments(x, [1, 2, 4], keep_dims=True) |
| 365 | + gamma = tf.get_variable('gamma', channels, initializer=tf.ones_initializer()) |
| 366 | + beta = tf.get_variable('beta', channels, initializer=tf.zeros_initializer()) |
| 367 | + else: |
| 368 | + mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True) |
| 369 | + gamma = tf.get_variable('gamma', [1, channels, 1, 1], initializer=tf.ones_initializer()) |
| 370 | + beta = tf.get_variable('beta', [1, channels, 1, 1], initializer=tf.zeros_initializer()) |
348 | 371 |
|
349 |
| - gamma = tf.get_variable('gamma', channels, initializer=tf.ones_initializer()) |
350 |
| - beta = tf.get_variable('beta', channels, initializer=tf.zeros_initializer()) |
| 372 | + x = (x - mean) / tf.sqrt(var + epsilon) |
351 | 373 |
|
352 | 374 | self.outputs = tf.reshape(x, tf.shape(self.inputs)) * gamma + beta
|
353 | 375 | self.outputs = self._apply_activation(self.outputs)
|
354 | 376 |
|
355 |
| - variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name) |
| 377 | + variables = get_collection_trainable(self.name) |
356 | 378 |
|
357 | 379 | self._add_layers(self.outputs)
|
358 | 380 | self._add_params(variables)
|
|
0 commit comments