diff --git a/docs/modules/layers.rst b/docs/modules/layers.rst
index 3f9b8f870..aade90e0a 100644
--- a/docs/modules/layers.rst
+++ b/docs/modules/layers.rst
@@ -260,6 +260,7 @@ Layer list
    DeConv2d
    DeConv3d
    DepthwiseConv2d
+   SeparableConv1d
    SeparableConv2d
    DeformableConv2d
    GroupConv2d
@@ -502,6 +503,10 @@ APIs may better for you.
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 .. autoclass:: DepthwiseConv2d
 
+1D Depthwise Separable Conv
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+.. autoclass:: SeparableConv1d
+
 2D Depthwise Separable Conv
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 .. autoclass:: SeparableConv2d
diff --git a/tensorlayer/layers/convolution.py b/tensorlayer/layers/convolution.py
index fdbfefe85..b0b01451e 100644
--- a/tensorlayer/layers/convolution.py
+++ b/tensorlayer/layers/convolution.py
@@ -24,6 +24,7 @@
     'DeConv2d',
     'DeConv3d',
     'DepthwiseConv2d',
+    'SeparableConv1d',
     'SeparableConv2d',
     'GroupConv2d',
 ]
@@ -1152,115 +1153,6 @@ def __init__(
         self.all_params.append(filters)
 
 
-class _SeparableConv2dLayer(Layer):  # TODO
-    """The :class:`SeparableConv2dLayer` class is 2D convolution with separable filters, see `tf.layers.separable_conv2d <https://www.tensorflow.org/api_docs/python/tf/layers/separable_conv2d>`__.
-
-    This layer has not been fully tested yet.
-
-    Parameters
-    ----------
-    prev_layer : :class:`Layer`
-        Previous layer with a 4D output tensor in the shape of [batch, height, width, channels].
-    n_filter : int
-        The number of filters.
-    filter_size : tuple of int
-        The filter size (height, width).
-    strides : tuple of int
-        The strides (height, width).
-        This can be a single integer if you want to specify the same value for all spatial dimensions.
-        Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1.
-    padding : str
-        The type of padding algorithm: "SAME" or "VALID"
-    data_format : str
-        One of channels_last (Default) or channels_first.
-        The order must match the input dimensions.
-        channels_last corresponds to inputs with shapedata_format = 'NWHC' (batch, width, height, channels) while
-        channels_first corresponds to inputs with shape [batch, channels, width, height].
-    dilation_rate : int or tuple of ints
-        The dilation rate of the convolution.
-        It can be a single integer if you want to specify the same value for all spatial dimensions.
-        Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.
-    depth_multiplier : int
-        The number of depthwise convolution output channels for each input channel.
-        The total number of depthwise convolution output channels will be equal to num_filters_in * depth_multiplier.
-    act : activation function
-        The activation function of this layer.
-    use_bias : boolean
-        Whether the layer uses a bias
-    depthwise_initializer : initializer
-        The initializer for the depthwise convolution kernel.
-    pointwise_initializer : initializer
-        The initializer for the pointwise convolution kernel.
-    bias_initializer : initializer
-        The initializer for the bias vector. If None, skip bias.
-    depthwise_regularizer : regularizer
-        Optional regularizer for the depthwise convolution kernel.
-    pointwise_regularizer : regularizer
-        Optional regularizer for the pointwise convolution kernel.
-    bias_regularizer : regularizer
-        Optional regularizer for the bias vector.
-    activity_regularizer : regularizer
-        Regularizer function for the output.
-    name : str
-        A unique layer name.
-
-    """
-
-    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
-    def __init__(
-            self, prev_layer, n_filter, filter_size=5, strides=(1, 1), padding='valid', data_format='channels_last',
-            dilation_rate=(1, 1), depth_multiplier=1, act=tf.identity, use_bias=True, depthwise_initializer=None,
-            pointwise_initializer=None, bias_initializer=tf.zeros_initializer, depthwise_regularizer=None,
-            pointwise_regularizer=None, bias_regularizer=None, activity_regularizer=None, name='atrou2d'
-    ):
-
-        super(_SeparableConv2dLayer, self).__init__(prev_layer=prev_layer, name=name)
-        logging.info(
-            "SeparableConv2dLayer %s: n_filter:%d filter_size:%s strides:%s padding:%s dilation_rate:%s depth_multiplier:%s act:%s"
-            % (
-                name, n_filter, filter_size, str(strides), padding, str(dilation_rate), str(depth_multiplier),
-                act.__name__
-            )
-        )
-
-        self.inputs = prev_layer.outputs
-
-        if tf.__version__ > "0.12.1":
-            raise Exception("This layer only supports for TF 1.0+")
-
-        bias_initializer = bias_initializer()
-
-        with tf.variable_scope(name) as vs:
-            self.outputs = tf.layers.separable_conv2d(
-                self.inputs,
-                filters=n_filter,
-                kernel_size=filter_size,
-                strides=strides,
-                padding=padding,
-                data_format=data_format,
-                dilation_rate=dilation_rate,
-                depth_multiplier=depth_multiplier,
-                activation=act,
-                use_bias=use_bias,
-                depthwise_initializer=depthwise_initializer,
-                pointwise_initializer=pointwise_initializer,
-                bias_initializer=bias_initializer,
-                depthwise_regularizer=depthwise_regularizer,
-                pointwise_regularizer=pointwise_regularizer,
-                bias_regularizer=bias_regularizer,
-                activity_regularizer=activity_regularizer,
-            )
-            # trainable=True, name=None, reuse=None)
-
-            variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
-
-        # self.all_layers = list(layer.all_layers)
-        # self.all_params = list(layer.all_params)
-        # self.all_drop = dict(layer.all_drop)
-        self.all_layers.append(self.outputs)
-        self.all_params.extend(variables)
-
-
 def deconv2d_bilinear_upsampling_initializer(shape):
     """Returns the initializer that can be passed to DeConv2dLayer for initializing the
     weights in correspondence to channel-wise bilinear up-sampling.
@@ -1762,18 +1654,18 @@ def __init__(
         self.inputs = prev_layer.outputs
 
         with tf.variable_scope(name) as vs:
-            self.outputs = tf.contrib.layers.conv3d_transpose(
-                inputs=self.inputs,
-                num_outputs=n_filter,
+            nn = tf.layers.Conv3DTranspose(
+                filters=n_filter,
                 kernel_size=filter_size,
-                stride=strides,
+                strides=strides,
                 padding=padding,
-                activation_fn=act,
-                weights_initializer=W_init,
-                biases_initializer=b_init,
-                scope=name,
+                activation=act,
+                kernel_initializer=W_init,
+                bias_initializer=b_init,
+                name=None,
             )
-            new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+            self.outputs = nn(self.inputs)
+            new_variables = nn.weights  # tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
 
         self.all_layers.append(self.outputs)
         self.all_params.extend(new_variables)
@@ -1908,6 +1800,113 @@ def __init__(
         self.all_params.append(W)
 
+
+class SeparableConv1d(Layer):
+    """The :class:`SeparableConv1d` class is a 1D depthwise separable convolutional layer, see `tf.layers.separable_conv1d <https://www.tensorflow.org/api_docs/python/tf/layers/separable_conv1d>`__.
+
+    This layer performs a depthwise convolution that acts separately on channels, followed by a pointwise convolution that mixes channels.
+
+    Parameters
+    ------------
+    prev_layer : :class:`Layer`
+        Previous layer.
+    n_filter : int
+        The dimensionality of the output space (i.e. the number of filters in the convolution).
+    filter_size : int
+        An integer specifying the length (the single spatial dimension) of the 1D convolution window.
+    strides : int
+        An integer specifying the stride of the convolution. Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1.
+    padding : str
+        One of "valid" or "same" (case-insensitive).
+    data_format : str
+        One of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch, length, channels) while channels_first corresponds to inputs with shape (batch, channels, length).
+    dilation_rate : int
+        An integer specifying the dilation rate to use for dilated convolution. Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.
+    depth_multiplier : int
+        The number of depthwise convolution output channels for each input channel. The total number of depthwise convolution output channels will be equal to num_filters_in * depth_multiplier.
+    depthwise_init : initializer
+        The initializer for the depthwise convolution kernel.
+    pointwise_init : initializer
+        The initializer for the pointwise convolution kernel.
+    b_init : initializer
+        The initializer for the bias vector. If None, skip the bias (the bias is only applied after the pointwise convolution).
+    name : str
+        A unique layer name.
+
+    """
+
+    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
+    def __init__(
+            self,
+            prev_layer,
+            n_filter=100,
+            filter_size=3,
+            strides=1,
+            act=tf.identity,
+            padding='valid',
+            data_format='channels_last',
+            dilation_rate=1,
+            depth_multiplier=1,
+            # activation=None,
+            # use_bias=True,
+            depthwise_init=None,
+            pointwise_init=None,
+            b_init=tf.zeros_initializer(),
+            # depthwise_regularizer=None,
+            # pointwise_regularizer=None,
+            # bias_regularizer=None,
+            # activity_regularizer=None,
+            # depthwise_constraint=None,
+            # pointwise_constraint=None,
+            # W_init=tf.truncated_normal_initializer(stddev=0.1),
+            # b_init=tf.constant_initializer(value=0.0),
+            # W_init_args=None,
+            # b_init_args=None,
+            name='seperable1d',
+    ):
+        # if W_init_args is None:
+        #     W_init_args = {}
+        # if b_init_args is None:
+        #     b_init_args = {}
+
+        super(SeparableConv1d, self).__init__(prev_layer=prev_layer, name=name)
+        logging.info(
+            "SeparableConv1d %s: n_filter:%d filter_size:%s strides:%s depth_multiplier:%d act:%s" %
+            (self.name, n_filter, str(filter_size), str(strides), depth_multiplier, act.__name__)
+        )
+
+        self.inputs = prev_layer.outputs
+
+        with tf.variable_scope(name) as vs:
+            nn = tf.layers.SeparableConv1D(
+                filters=n_filter,
+                kernel_size=filter_size,
+                strides=strides,
+                padding=padding,
+                data_format=data_format,
+                dilation_rate=dilation_rate,
+                depth_multiplier=depth_multiplier,
+                activation=act,
+                use_bias=(True if b_init is not None else False),
+                depthwise_initializer=depthwise_init,
+                pointwise_initializer=pointwise_init,
+                bias_initializer=b_init,
+                # depthwise_regularizer=None,
+                # pointwise_regularizer=None,
+                # bias_regularizer=None,
+                # activity_regularizer=None,
+                # depthwise_constraint=None,
+                # pointwise_constraint=None,
+                # bias_constraint=None,
+                trainable=True,
+                name=None
+            )
+            self.outputs = nn(self.inputs)
+            new_variables = nn.weights
+
+        self.all_layers.append(self.outputs)
+        self.all_params.extend(new_variables)
+
+
 class SeparableConv2d(Layer):
     """The :class:`SeparableConv2d` class is a 2D depthwise separable convolutional layer, see `tf.layers.separable_conv2d <https://www.tensorflow.org/api_docs/python/tf/layers/separable_conv2d>`__.
@@ -1986,8 +1985,7 @@ def __init__(
         self.inputs = prev_layer.outputs
 
         with tf.variable_scope(name) as vs:
-            self.outputs = tf.layers.separable_conv2d(
-                inputs=self.inputs,
+            nn = tf.layers.SeparableConv2D(
                 filters=n_filter,
                 kernel_size=filter_size,
                 strides=strides,
@@ -2010,7 +2008,9 @@ def __init__(
                 trainable=True,
                 name=None
             )
-            new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
+            self.outputs = nn(self.inputs)
+            new_variables = nn.weights
+            # new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
 
         self.all_layers.append(self.outputs)
         self.all_params.extend(new_variables)
diff --git a/tests/test_layers_convolution.py b/tests/test_layers_convolution.py
index e49c2b046..2121794f1 100644
--- a/tests/test_layers_convolution.py
+++ b/tests/test_layers_convolution.py
@@ -23,6 +23,14 @@ def setUpClass(cls):
         n2 = tl.layers.Conv1d(nin1, n_filter=32, filter_size=5, stride=2)
         cls.shape_n2 = n2.outputs.get_shape().as_list()
 
+        n2_1 = tl.layers.SeparableConv1d(
+            nin1, n_filter=32, filter_size=3, strides=1, padding='VALID', act=tf.nn.relu, name='seperable1d1'
+        )
+        cls.shape_n2_1 = n2_1.outputs.get_shape().as_list()
+        cls.n2_1_all_layers = n2_1.all_layers
+        cls.n2_1_params = n2_1.all_params
+        cls.n2_1_count_params = n2_1.count_params()
+
         ############
         #    2D    #
         ############
@@ -65,7 +73,7 @@ def setUpClass(cls):
         cls.shape_n9 = n9.outputs.get_shape().as_list()
 
         n10 = tl.layers.SeparableConv2d(
-            nin2, n_filter=32, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, name='seperable1'
+            nin2, n_filter=32, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, name='seperable2d1'
         )
         cls.shape_n10 = n10.outputs.get_shape().as_list()
         cls.n10_all_layers = n10.all_layers
@@ -101,6 +109,10 @@ def test_shape_n2(self):
         self.assertEqual(self.shape_n2[1], 50)
         self.assertEqual(self.shape_n2[2], 32)
 
+    def test_shape_n2_1(self):
+        self.assertEqual(self.shape_n2_1[1], 98)
+        self.assertEqual(self.shape_n2_1[2], 32)
+
     def test_shape_n3(self):
         self.assertEqual(self.shape_n3[1], 50)
         self.assertEqual(self.shape_n3[2], 50)
@@ -151,6 +163,9 @@ def test_shape_n12(self):
         self.assertEqual(self.shape_n12[3], 200)
         self.assertEqual(self.shape_n12[4], 32)
 
+    def test_params_n2_1(self):
+        self.assertEqual(len(self.n2_1_params), 3)
+
     def test_params_n4(self):
         self.assertEqual(len(self.n4_params), 2)
 
@@ -161,6 +176,9 @@ def test_params_n10(self):
         self.assertEqual(len(self.n10_params), 3)
         self.assertEqual(self.n10_count_params, 155)
 
+    def test_layers_n2_1(self):
+        self.assertEqual(len(self.n2_1_all_layers), 1)
+
     def test_layers_n10(self):
         self.assertEqual(len(self.n10_all_layers), 1)
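
Review note: a minimal usage sketch of the new SeparableConv1d layer. It assumes the TensorLayer 1.x static-graph API (tf.placeholder plus tl.layers.InputLayer) and a single-channel input of length 100, matching the 1D test fixture above; the placeholder name, layer name, and channel count are illustrative, not taken from this patch.

    import tensorflow as tf
    import tensorlayer as tl

    # [batch, length, channels] input; length 100 matches the test fixture,
    # the single channel is an assumption for illustration
    x = tf.placeholder(tf.float32, [None, 100, 1], name='x_1d')
    net = tl.layers.InputLayer(x, name='in_1d')

    # depthwise + pointwise 1D convolution; with filter_size=3, strides=1 and
    # VALID padding the output length is 100 - 3 + 1 = 98, which is what
    # test_shape_n2_1 asserts
    net = tl.layers.SeparableConv1d(
        net, n_filter=32, filter_size=3, strides=1, padding='VALID', act=tf.nn.relu, name='seperable1d1'
    )

    print(net.outputs.get_shape().as_list())  # [None, 98, 32]
    print(len(net.all_params))                # 3: depthwise kernel, pointwise kernel, bias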