Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/modules/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ Layer list
DeConv2d
DeConv3d
DepthwiseConv2d
SeparableConv1d
SeparableConv2d
DeformableConv2d
GroupConv2d
Expand Down Expand Up @@ -502,6 +503,10 @@ APIs may better for you.
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: DepthwiseConv2d

1D Depthwise Separable Conv
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: SeparableConv1d

2D Depthwise Separable Conv
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: SeparableConv2d
Expand Down
242 changes: 121 additions & 121 deletions tensorlayer/layers/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
'DeConv2d',
'DeConv3d',
'DepthwiseConv2d',
'SeparableConv1d',
'SeparableConv2d',
'GroupConv2d',
]
Expand Down Expand Up @@ -1152,115 +1153,6 @@ def __init__(
self.all_params.append(filters)


class _SeparableConv2dLayer(Layer):  # TODO
    """The :class:`SeparableConv2dLayer` class is 2D convolution with separable filters, see `tf.layers.separable_conv2d <https://www.tensorflow.org/api_docs/python/tf/layers/separable_conv2d>`__.

    This layer has not been fully tested yet.

    Parameters
    ----------
    prev_layer : :class:`Layer`
        Previous layer with a 4D output tensor in the shape of [batch, height, width, channels].
    n_filter : int
        The number of filters.
    filter_size : tuple of int
        The filter size (height, width).
    strides : tuple of int
        The strides (height, width).
        This can be a single integer if you want to specify the same value for all spatial dimensions.
        Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1.
    padding : str
        The type of padding algorithm: "SAME" or "VALID"
    data_format : str
        One of channels_last (Default) or channels_first.
        The order must match the input dimensions.
        channels_last corresponds to inputs with shape [batch, height, width, channels] while
        channels_first corresponds to inputs with shape [batch, channels, height, width].
    dilation_rate : int or tuple of ints
        The dilation rate of the convolution.
        It can be a single integer if you want to specify the same value for all spatial dimensions.
        Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.
    depth_multiplier : int
        The number of depthwise convolution output channels for each input channel.
        The total number of depthwise convolution output channels will be equal to num_filters_in * depth_multiplier.
    act : activation function
        The activation function of this layer.
    use_bias : boolean
        Whether the layer uses a bias
    depthwise_initializer : initializer
        The initializer for the depthwise convolution kernel.
    pointwise_initializer : initializer
        The initializer for the pointwise convolution kernel.
    bias_initializer : initializer
        The initializer class for the bias vector (instantiated internally). If None, skip bias.
    depthwise_regularizer : regularizer
        Optional regularizer for the depthwise convolution kernel.
    pointwise_regularizer : regularizer
        Optional regularizer for the pointwise convolution kernel.
    bias_regularizer : regularizer
        Optional regularizer for the bias vector.
    activity_regularizer : regularizer
        Regularizer function for the output.
    name : str
        A unique layer name.

    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(
        self, prev_layer, n_filter, filter_size=5, strides=(1, 1), padding='valid', data_format='channels_last',
        dilation_rate=(1, 1), depth_multiplier=1, act=tf.identity, use_bias=True, depthwise_initializer=None,
        pointwise_initializer=None, bias_initializer=tf.zeros_initializer, depthwise_regularizer=None,
        pointwise_regularizer=None, bias_regularizer=None, activity_regularizer=None, name='atrou2d'
    ):

        super(_SeparableConv2dLayer, self).__init__(prev_layer=prev_layer, name=name)
        logging.info(
            "SeparableConv2dLayer %s: n_filter:%d filter_size:%s strides:%s padding:%s dilation_rate:%s depth_multiplier:%s act:%s"
            % (
                name, n_filter, filter_size, str(strides), padding, str(dilation_rate), str(depth_multiplier),
                act.__name__
            )
        )

        self.inputs = prev_layer.outputs

        # BUG FIX: the original condition was inverted — it raised for
        # tf.__version__ > "0.12.1", i.e. exactly the versions that DO provide
        # ``tf.layers.separable_conv2d``. The layer needs TF 1.0+, so reject
        # only the older versions.
        if tf.__version__ <= "0.12.1":
            raise Exception("This layer only supports TF 1.0+")

        # ``bias_initializer`` is passed as an initializer class
        # (e.g. ``tf.zeros_initializer``) and instantiated here.
        bias_initializer = bias_initializer()

        with tf.variable_scope(name) as vs:
            self.outputs = tf.layers.separable_conv2d(
                self.inputs,
                filters=n_filter,
                kernel_size=filter_size,
                strides=strides,
                padding=padding,
                data_format=data_format,
                dilation_rate=dilation_rate,
                depth_multiplier=depth_multiplier,
                activation=act,
                use_bias=use_bias,
                depthwise_initializer=depthwise_initializer,
                pointwise_initializer=pointwise_initializer,
                bias_initializer=bias_initializer,
                depthwise_regularizer=depthwise_regularizer,
                pointwise_regularizer=pointwise_regularizer,
                bias_regularizer=bias_regularizer,
                activity_regularizer=activity_regularizer,
            )
            # trainable=True, name=None, reuse=None)

            # Collect the variables created inside this scope as this layer's params.
            variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)

        self.all_layers.append(self.outputs)
        self.all_params.extend(variables)


def deconv2d_bilinear_upsampling_initializer(shape):
"""Returns the initializer that can be passed to DeConv2dLayer for initializing the
weights in correspondence to channel-wise bilinear up-sampling.
Expand Down Expand Up @@ -1762,18 +1654,18 @@ def __init__(
self.inputs = prev_layer.outputs

with tf.variable_scope(name) as vs:
self.outputs = tf.contrib.layers.conv3d_transpose(
inputs=self.inputs,
num_outputs=n_filter,
nn = tf.layers.Conv3DTranspose(
filters=n_filter,
kernel_size=filter_size,
stride=strides,
strides=strides,
padding=padding,
activation_fn=act,
weights_initializer=W_init,
biases_initializer=b_init,
scope=name,
activation=act,
kernel_initializer=W_init,
bias_initializer=b_init,
name=None,
)
new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
self.outputs = nn(self.inputs)
new_variables = nn.weights # tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)

self.all_layers.append(self.outputs)
self.all_params.extend(new_variables)
Expand Down Expand Up @@ -1908,6 +1800,113 @@ def __init__(
self.all_params.append(W)


class SeparableConv1d(Layer):
    """The :class:`SeparableConv1d` class is a 1D depthwise separable convolutional layer, see `tf.layers.separable_conv1d <https://www.tensorflow.org/api_docs/python/tf/layers/separable_conv1d>`__.

    This layer performs a depthwise convolution that acts separately on channels, followed by a pointwise convolution that mixes channels.

    Parameters
    ------------
    prev_layer : :class:`Layer`
        Previous layer.
    n_filter : int
        The dimensionality of the output space (i.e. the number of filters in the convolution).
    filter_size : int
        Specifying the spatial dimensions of the filters. Can be a single integer to specify the same value for all spatial dimensions.
    strides : int
        Specifying the stride of the convolution. Can be a single integer to specify the same value for all spatial dimensions. Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1.
    act : activation function
        The activation function of this layer.
    padding : str
        One of "valid" or "same" (case-insensitive).
    data_format : str
        One of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch, length, channels) while channels_first corresponds to inputs with shape (batch, channels, length).
    dilation_rate : int
        Specifying the dilation rate to use for dilated convolution. Can be a single integer to specify the same value for all spatial dimensions. Currently, specifying any dilation_rate value != 1 is incompatible with specifying any stride value != 1.
    depth_multiplier : int
        The number of depthwise convolution output channels for each input channel. The total number of depthwise convolution output channels will be equal to num_filters_in * depth_multiplier.
    depthwise_init : initializer
        For the depthwise convolution kernel.
    pointwise_init : initializer
        For the pointwise convolution kernel.
    b_init : initializer
        For the bias vector. If None, ignore bias in the pointwise part only.
    name : a str
        A unique layer name.

    """

    @deprecated_alias(layer='prev_layer', end_support_version=1.9)  # TODO remove this line for the 1.9 release
    def __init__(
            self,
            prev_layer,
            n_filter=100,
            filter_size=3,
            strides=1,
            act=tf.identity,
            padding='valid',
            data_format='channels_last',
            dilation_rate=1,
            depth_multiplier=1,
            depthwise_init=None,
            pointwise_init=None,
            b_init=tf.zeros_initializer(),
            name='seperable1d',
    ):
        super(SeparableConv1d, self).__init__(prev_layer=prev_layer, name=name)
        # BUG FIX: the original format string labelled the stride value with a
        # second "filter_size"; it is now correctly labelled "strides".
        logging.info(
            "SeparableConv1d %s: n_filter:%d filter_size:%s strides:%s depth_multiplier:%d act:%s" %
            (self.name, n_filter, str(filter_size), str(strides), depth_multiplier, act.__name__)
        )

        self.inputs = prev_layer.outputs

        with tf.variable_scope(name) as vs:
            nn = tf.layers.SeparableConv1D(
                filters=n_filter,
                kernel_size=filter_size,
                strides=strides,
                padding=padding,
                data_format=data_format,
                dilation_rate=dilation_rate,
                depth_multiplier=depth_multiplier,
                activation=act,
                # bias only exists in the pointwise part; skip it when b_init is None
                use_bias=(b_init is not None),
                depthwise_initializer=depthwise_init,
                pointwise_initializer=pointwise_init,
                bias_initializer=b_init,
                trainable=True,
                name=None
            )
            self.outputs = nn(self.inputs)
            # depthwise kernel, pointwise kernel, and (optionally) bias
            new_variables = nn.weights

        self.all_layers.append(self.outputs)
        self.all_params.extend(new_variables)


class SeparableConv2d(Layer):
"""The :class:`SeparableConv2d` class is a 2D depthwise separable convolutional layer, see `tf.layers.separable_conv2d <https://www.tensorflow.org/api_docs/python/tf/layers/separable_conv2d>`__.

Expand Down Expand Up @@ -1986,8 +1985,7 @@ def __init__(
self.inputs = prev_layer.outputs

with tf.variable_scope(name) as vs:
self.outputs = tf.layers.separable_conv2d(
inputs=self.inputs,
nn = tf.layers.SeparableConv2D(
filters=n_filter,
kernel_size=filter_size,
strides=strides,
Expand All @@ -2010,7 +2008,9 @@ def __init__(
trainable=True,
name=None
)
new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)
self.outputs = nn(self.inputs)
new_variables = nn.weights
# new_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name)

self.all_layers.append(self.outputs)
self.all_params.extend(new_variables)
Expand Down
20 changes: 19 additions & 1 deletion tests/test_layers_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def setUpClass(cls):
n2 = tl.layers.Conv1d(nin1, n_filter=32, filter_size=5, stride=2)
cls.shape_n2 = n2.outputs.get_shape().as_list()

n2_1 = tl.layers.SeparableConv1d(
nin1, n_filter=32, filter_size=3, strides=1, padding='VALID', act=tf.nn.relu, name='seperable1d1'
)
cls.shape_n2_1 = n2_1.outputs.get_shape().as_list()
cls.n2_1_all_layers = n2_1.all_layers
cls.n2_1_params = n2_1.all_params
cls.n2_1_count_params = n2_1.count_params()

############
# 2D #
############
Expand Down Expand Up @@ -65,7 +73,7 @@ def setUpClass(cls):
cls.shape_n9 = n9.outputs.get_shape().as_list()

n10 = tl.layers.SeparableConv2d(
nin2, n_filter=32, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, name='seperable1'
nin2, n_filter=32, filter_size=(3, 3), strides=(1, 1), act=tf.nn.relu, name='seperable2d1'
)
cls.shape_n10 = n10.outputs.get_shape().as_list()
cls.n10_all_layers = n10.all_layers
Expand Down Expand Up @@ -101,6 +109,10 @@ def test_shape_n2(self):
self.assertEqual(self.shape_n2[1], 50)
self.assertEqual(self.shape_n2[2], 32)

def test_shape_n2_1(self):
    # SeparableConv1d output is [batch, steps, channels].
    steps, channels = self.shape_n2_1[1], self.shape_n2_1[2]
    self.assertEqual(steps, 98)
    self.assertEqual(channels, 32)

def test_shape_n3(self):
self.assertEqual(self.shape_n3[1], 50)
self.assertEqual(self.shape_n3[2], 50)
Expand Down Expand Up @@ -151,6 +163,9 @@ def test_shape_n12(self):
self.assertEqual(self.shape_n12[3], 200)
self.assertEqual(self.shape_n12[4], 32)

def test_params_n2_1(self):
    # depthwise kernel + pointwise kernel + bias = 3 parameter tensors
    n_params = len(self.n2_1_params)
    self.assertEqual(n_params, 3)

def test_params_n4(self):
self.assertEqual(len(self.n4_params), 2)

Expand All @@ -161,6 +176,9 @@ def test_params_n10(self):
self.assertEqual(len(self.n10_params), 3)
self.assertEqual(self.n10_count_params, 155)

def test_layers_n2_1(self):
    # the SeparableConv1d layer registers exactly one output tensor
    n_layers = len(self.n2_1_all_layers)
    self.assertEqual(n_layers, 1)

def test_layers_n10(self):
    # the SeparableConv2d layer registers exactly one output tensor
    n_layers = len(self.n10_all_layers)
    self.assertEqual(n_layers, 1)

Expand Down