diff --git a/docs/modules/layers.rst b/docs/modules/layers.rst index 3d393ad62..64124d034 100644 --- a/docs/modules/layers.rst +++ b/docs/modules/layers.rst @@ -279,6 +279,8 @@ Layer list GlobalMeanPool1d GlobalMaxPool2d GlobalMeanPool2d + GlobalMaxPool3d + GlobalMeanPool3d SubpixelConv1d SubpixelConv2d @@ -603,6 +605,14 @@ Pooling layer for any dimensions and any pooling functions. ^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: GlobalMeanPool2d +3D Global Max pooling +^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: GlobalMaxPool3d + +3D Global Mean pooling +^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: GlobalMeanPool3d + Normalization layer -------------------- diff --git a/tensorlayer/layers/padding.py b/tensorlayer/layers/padding.py index 8ed5f8c31..0d566401c 100644 --- a/tensorlayer/layers/padding.py +++ b/tensorlayer/layers/padding.py @@ -137,92 +137,3 @@ def __init__( logging.info("ZeroPad3d %s: padding:%s" % (self.name, str(padding))) self.outputs = tf.keras.layers.ZeroPadding3D(padding=padding, name=name)(self.inputs) self.all_layers.append(self.outputs) - - -class ZeroPad1d(Layer): - """ - The :class:`ZeroPad1d` class is a 1D padding layer for signal [batch, length, channel]. - - Parameters - ---------- - layer : :class:`Layer` - The previous layer. - padding : int, or tuple of 2 ints - - If int, zeros to add at the beginning and end of the padding dimension (axis 1). - - If tuple of 2 ints, zeros to add at the beginning and at the end of the padding dimension. - name : str - A unique layer name. - - """ - - def __init__( - self, - prev_layer, - padding, - name='zeropad1d', - ): - Layer.__init__(self, prev_layer=prev_layer, name=name) - self.inputs = prev_layer.outputs - logging.info("ZeroPad1d %s: padding:%s" % (self.name, str(padding))) - self.outputs = tf.keras.layers.ZeroPadding1D(padding=padding, name=name)(self.inputs) - self.all_layers.append(self.outputs) - - -class ZeroPad2d(Layer): - """ - The :class:`ZeroPad2d` class is a 2D padding layer for image [batch, height, width, channel]. - - Parameters - ---------- - layer : :class:`Layer` - The previous layer. - padding : int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. - - If int, the same symmetric padding is applied to width and height. - - If tuple of 2 ints, interpreted as two different symmetric padding values for height and width as ``(symmetric_height_pad, symmetric_width_pad)``. - - If tuple of 2 tuples of 2 ints, interpreted as ``((top_pad, bottom_pad), (left_pad, right_pad))``. - name : str - A unique layer name. - - """ - - def __init__( - self, - prev_layer, - padding, - name='zeropad2d', - ): - Layer.__init__(self, prev_layer=prev_layer, name=name) - self.inputs = prev_layer.outputs - logging.info("ZeroPad2d %s: padding:%s" % (self.name, str(padding))) - self.outputs = tf.keras.layers.ZeroPadding2D(padding=padding, name=name)(self.inputs) - self.all_layers.append(self.outputs) - - -class ZeroPad3d(Layer): - """ - The :class:`ZeroPad3d` class is a 3D padding layer for volume [batch, height, width, depth, channel]. - - Parameters - ---------- - layer : :class:`Layer` - The previous layer. - padding : int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. - - If int, the same symmetric padding is applied to width and height. - - If tuple of 2 ints, interpreted as two different symmetric padding values for height and width as ``(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)``. - - If tuple of 2 tuples of 2 ints, interpreted as ``((left_dim1_pad, right_dim1_pad), (left_dim2_pad, right_dim2_pad), (left_dim3_pad, right_dim3_pad))``. - name : str - A unique layer name. - - """ - - def __init__( - self, - prev_layer, - padding, - name='zeropad3d', - ): - Layer.__init__(self, prev_layer=prev_layer, name=name) - self.inputs = prev_layer.outputs - logging.info("ZeroPad3d %s: padding:%s" % (self.name, str(padding))) - self.outputs = tf.keras.layers.ZeroPadding3D(padding=padding, name=name)(self.inputs) - self.all_layers.append(self.outputs) diff --git a/tensorlayer/layers/pooling.py b/tensorlayer/layers/pooling.py index 034ebd1ac..926bc431f 100644 --- a/tensorlayer/layers/pooling.py +++ b/tensorlayer/layers/pooling.py @@ -17,6 +17,8 @@ 'GlobalMeanPool1d', 'GlobalMaxPool2d', 'GlobalMeanPool2d', + 'GlobalMaxPool3d', + 'GlobalMeanPool3d', ] @@ -62,12 +64,7 @@ def __init__( Layer.__init__(self, prev_layer=prev_layer, name=name) self.inputs = prev_layer.outputs logging.info("PoolLayer %s: ksize:%s strides:%s padding:%s pool:%s" % (self.name, str(ksize), str(strides), padding, pool.__name__)) - self.outputs = pool(self.inputs, ksize=ksize, strides=strides, padding=padding, name=name) - - # self.all_layers = list(layer.all_layers) - # self.all_params = list(layer.all_params) - # self.all_drop = dict(layer.all_drop) self.all_layers.append(self.outputs) @@ -77,7 +74,7 @@ def maxpool1d(net, filter_size=3, strides=2, padding='valid', data_format='chann Parameters ---------- net : :class:`Layer` - The previous layer with a output rank as 3. + The previous layer with a output rank as 3 [batch, length, channel]. filter_size : tuple of int Pooling window size. strides : tuple of int @@ -113,7 +110,7 @@ def meanpool1d(net, filter_size=3, strides=2, padding='valid', data_format='chan Parameters ------------ net : :class:`Layer` - The previous layer with a output rank as 3. + The previous layer with a output rank as 3 [batch, length, channel]. filter_size : tuple of int Pooling window size. strides : tuple of int @@ -149,7 +146,7 @@ def maxpool2d(net, filter_size=(3, 3), strides=(2, 2), padding='SAME', name='max Parameters ----------- net : :class:`Layer` - The previous layer with a output rank as 4. + The previous layer with a output rank as 4 [batch, height, width, channel]. filter_size : tuple of int (height, width) for filter size. strides : tuple of int @@ -178,7 +175,7 @@ def meanpool2d(net, filter_size=(3, 3), strides=(2, 2), padding='SAME', name='me Parameters ----------- layer : :class:`Layer` - The previous layer with a output rank as 4. + The previous layer with a output rank as 4 [batch, height, width, channel]. filter_size : tuple of int (height, width) for filter size. strides : tuple of int @@ -208,7 +205,7 @@ class MaxPool3d(Layer): Parameters ------------ layer : :class:`Layer` - The previous layer with a output rank as 5. + The previous layer with a output rank as 5 [batch, height, width, depth, channel]. filter_size : tuple of int Pooling window size. strides : tuple of int @@ -231,22 +228,12 @@ class MaxPool3d(Layer): """ def __init__(self, prev_layer, filter_size=(3, 3, 3), strides=(2, 2, 2), padding='valid', data_format='channels_last', name='maxpool3d'): - # check layer name (fixed) Layer.__init__(self, prev_layer=prev_layer, name=name) - # the input of this layer is the output of previous layer (fixed) self.inputs = prev_layer.outputs - logging.info("MaxPool3d %s: filter_size:%s strides:%s padding:%s" % (name, str(filter_size), str(strides), str(padding))) - self.outputs = tf.layers.max_pooling3d(prev_layer.outputs, filter_size, strides, padding=padding, data_format=data_format, name=name) - - # get stuff from previous layer (fixed) - # self.all_layers = list(layer.all_layers) - # self.all_params = list(layer.all_params) - # self.all_drop = dict(layer.all_drop) - # update layer (customized) self.all_layers.append(self.outputs) @@ -258,7 +245,7 @@ class MeanPool3d(Layer): Parameters ------------ layer : :class:`Layer` - The previous layer with a output rank as 5. + The previous layer with a output rank as 5 [batch, height, width, depth, channel]. filter_size : tuple of int Pooling window size. strides : tuple of int @@ -283,24 +270,14 @@ class MeanPool3d(Layer): def __init__(self, prev_layer, filter_size=(3, 3, 3), strides=(2, 2, 2), padding='valid', data_format='channels_last', name='meanpool3d'): # check layer name (fixed) Layer.__init__(self, prev_layer=prev_layer, name=name) - # the input of this layer is the output of previous layer (fixed) self.inputs = prev_layer.outputs - # print out info (customized) logging.info("MeanPool3d %s: filter_size:%s strides:%s padding:%s" % (name, str(filter_size), str(strides), str(padding))) - # operation (customized) self.outputs = tf.layers.average_pooling3d(prev_layer.outputs, filter_size, strides, padding=padding, data_format=data_format, name=name) - - # get stuff from previous layer (fixed) - # self.all_layers = list(layer.all_layers) - # self.all_params = list(layer.all_params) - # self.all_drop = dict(layer.all_drop) - # update layer (customized) self.all_layers.append(self.outputs) - # self.all_params.extend( [W, b] ) class GlobalMaxPool1d(Layer): @@ -309,7 +286,7 @@ class GlobalMaxPool1d(Layer): Parameters ------------ layer : :class:`Layer` - The previous layer with a output rank as 3. + The previous layer with a output rank as 3 [batch, length, channel]. name : str A unique layer name. @@ -328,24 +305,14 @@ def __init__( ): # check layer name (fixed) Layer.__init__(self, prev_layer=prev_layer, name=name) - # the input of this layer is the output of previous layer (fixed) self.inputs = prev_layer.outputs - # print out info (customized) logging.info("GlobalMaxPool1d %s" % name) - # operation (customized) self.outputs = tf.reduce_max(prev_layer.outputs, axis=1, name=name) - - # get stuff from previous layer (fixed) - # self.all_layers = list(layer.all_layers) - # self.all_params = list(layer.all_params) - # self.all_drop = dict(layer.all_drop) - # update layer (customized) self.all_layers.append(self.outputs) - # self.all_params.extend( [W, b] ) class GlobalMeanPool1d(Layer): @@ -354,7 +321,7 @@ class GlobalMeanPool1d(Layer): Parameters ------------ layer : :class:`Layer` - The previous layer with a output rank as 3. + The previous layer with a output rank as 3 [batch, length, channel]. name : str A unique layer name. @@ -373,24 +340,14 @@ def __init__( ): # check layer name (fixed) Layer.__init__(self, prev_layer=prev_layer, name=name) - # the input of this layer is the output of previous layer (fixed) self.inputs = prev_layer.outputs - # print out info (customized) logging.info("GlobalMeanPool1d %s" % name) - # operation (customized) self.outputs = tf.reduce_mean(prev_layer.outputs, axis=1, name=name) - - # get stuff from previous layer (fixed) - # self.all_layers = list(layer.all_layers) - # self.all_params = list(layer.all_params) - # self.all_drop = dict(layer.all_drop) - # update layer (customized) self.all_layers.append(self.outputs) - # self.all_params.extend( [W, b] ) class GlobalMaxPool2d(Layer): @@ -399,7 +356,7 @@ class GlobalMaxPool2d(Layer): Parameters ------------ layer : :class:`Layer` - The previous layer with a output rank as 4. + The previous layer with a output rank as 4 [batch, height, width, channel]. name : str A unique layer name. @@ -418,24 +375,14 @@ def __init__( ): # check layer name (fixed) Layer.__init__(self, prev_layer=prev_layer, name=name) - # the input of this layer is the output of previous layer (fixed) self.inputs = prev_layer.outputs - # print out info (customized) logging.info("GlobalMaxPool2d %s" % name) - # operation (customized) self.outputs = tf.reduce_max(prev_layer.outputs, axis=[1, 2], name=name) - - # get stuff from previous layer (fixed) - # self.all_layers = list(layer.all_layers) - # self.all_params = list(layer.all_params) - # self.all_drop = dict(layer.all_drop) - # update layer (customized) self.all_layers.append(self.outputs) - # self.all_params.extend( [W, b] ) class GlobalMeanPool2d(Layer): @@ -444,7 +391,7 @@ class GlobalMeanPool2d(Layer): Parameters ------------ layer : :class:`Layer` - The previous layer with a output rank as 4. + The previous layer with a output rank as 4 [batch, height, width, channel]. name : str A unique layer name. @@ -463,24 +410,84 @@ def __init__( ): # check layer name (fixed) Layer.__init__(self, prev_layer=prev_layer, name=name) - # the input of this layer is the output of previous layer (fixed) self.inputs = prev_layer.outputs - # print out info (customized) logging.info("GlobalMeanPool2d %s" % name) - # operation (customized) self.outputs = tf.reduce_mean(prev_layer.outputs, axis=[1, 2], name=name) + # update layer (customized) + self.all_layers.append(self.outputs) + + +class GlobalMaxPool3d(Layer): + """The :class:`GlobalMaxPool3d` class is a 3D Global Max Pooling layer. + + Parameters + ------------ + layer : :class:`Layer` + The previous layer with a output rank as 5 [batch, height, width, depth, channel]. + name : str + A unique layer name. + + Examples + --------- + >>> x = tf.placeholder("float32", [None, 100, 100, 100, 30]) + >>> n = InputLayer(x, name='in') + >>> n = GlobalMaxPool3d(n) + ... [None, 30] + """ + + def __init__( + self, + prev_layer=None, + name='globalmaxpool3d', + ): + # check layer name (fixed) + Layer.__init__(self, prev_layer=prev_layer, name=name) + # the input of this layer is the output of previous layer (fixed) + self.inputs = prev_layer.outputs + # print out info (customized) + logging.info("GlobalMaxPool3d %s" % name) + # operation (customized) + self.outputs = tf.reduce_max(prev_layer.outputs, axis=[1, 2, 3], name=name) + # update layer (customized) + self.all_layers.append(self.outputs) + + +class GlobalMeanPool3d(Layer): + """The :class:`GlobalMeanPool3d` class is a 3D Global Mean Pooling layer. + + Parameters + ------------ + layer : :class:`Layer` + The previous layer with a output rank as 5 [batch, height, width, depth, channel]. + name : str + A unique layer name. - # get stuff from previous layer (fixed) - # self.all_layers = list(layer.all_layers) - # self.all_params = list(layer.all_params) - # self.all_drop = dict(layer.all_drop) + Examples + --------- + >>> x = tf.placeholder("float32", [None, 100, 100, 100, 30]) + >>> n = InputLayer(x, name='in') + >>> n = GlobalMeanPool2d(n) + ... [None, 30] + """ + def __init__( + self, + prev_layer=None, + name='globalmeanpool3d', + ): + # check layer name (fixed) + Layer.__init__(self, prev_layer=prev_layer, name=name) + # the input of this layer is the output of previous layer (fixed) + self.inputs = prev_layer.outputs + # print out info (customized) + logging.info("GlobalMeanPool3d %s" % name) + # operation (customized) + self.outputs = tf.reduce_mean(prev_layer.outputs, axis=[1, 2, 3], name=name) # update layer (customized) self.all_layers.append(self.outputs) - # self.all_params.extend( [W, b] ) # Alias diff --git a/tests/test_layers_pooling.py b/tests/test_layers_pooling.py index ce5bf7b10..6086c568d 100644 --- a/tests/test_layers_pooling.py +++ b/tests/test_layers_pooling.py @@ -68,3 +68,25 @@ shape = n.outputs.get_shape().as_list() if shape[-1] != 32: raise Exception("shape dont match") + +## 3D ======================================================================== +x = tf.placeholder(tf.float32, (None, 100, 100, 100, 3)) +nin = tl.layers.InputLayer(x, name='in') + +n = tl.layers.MeanPool3d(nin, (3, 3, 3), (2, 2, 2), 'SAME', name='meanpool3d') +print(n) +shape = n.outputs.get_shape().as_list() +if shape != [None, 50, 50, 50, 3]: + raise Exception("shape dont match") + +n = tl.layers.GlobalMaxPool3d(nin) +print(n) +shape = n.outputs.get_shape().as_list() +if shape != [None, 3]: + raise Exception("shape dont match") + +n = tl.layers.GlobalMeanPool3d(nin) +print(n) +shape = n.outputs.get_shape().as_list() +if shape != [None, 3]: + raise Exception("shape dont match")