diff --git a/docs/modules/layers.rst b/docs/modules/layers.rst index 95c767090..829086bf3 100644 --- a/docs/modules/layers.rst +++ b/docs/modules/layers.rst @@ -263,7 +263,6 @@ Layer list PadLayer UpSampling2dLayer DownSampling2dLayer - DeformableConv2dLayer AtrousConv1dLayer AtrousConv2dLayer @@ -280,6 +279,7 @@ Layer list MeanPool3d DepthwiseConv2d + DeformableConv2d SubpixelConv1d SubpixelConv2d @@ -456,10 +456,6 @@ Convolutional layer (Pro) ^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: DownSampling2dLayer -2D Deformable Conv -^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: DeformableConv2dLayer - 1D Atrous convolution ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: AtrousConv1dLayer @@ -496,6 +492,11 @@ APIs may better for you. ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: DepthwiseConv2d +2D Deformable Conv +^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: DeformableConv2d + + Super-Resolution layer ------------------------ diff --git a/tensorlayer/layers/convolution.py b/tensorlayer/layers/convolution.py index 37b1b4bef..e313b221b 100644 --- a/tensorlayer/layers/convolution.py +++ b/tensorlayer/layers/convolution.py @@ -55,9 +55,10 @@ def __init__( ): Layer.__init__(self, name=name) self.inputs = layer.outputs - logging.info("Conv1dLayer %s: shape:%s stride:%s pad:%s act:%s" % (self.name, str(shape), str(stride), padding, act.__name__)) if act is None: act = tf.identity + logging.info("Conv1dLayer %s: shape:%s stride:%s pad:%s act:%s" % (self.name, str(shape), str(stride), padding, act.__name__)) + with tf.variable_scope(name) as vs: W = tf.get_variable(name='W_conv1d', shape=shape, initializer=W_init, dtype=D_TYPE, **W_init_args) self.outputs = tf.nn.convolution( @@ -165,6 +166,8 @@ def __init__( ): Layer.__init__(self, name=name) self.inputs = layer.outputs + if act is None: + act = tf.identity logging.info("Conv2dLayer %s: shape:%s strides:%s pad:%s act:%s" % (self.name, str(shape), str(strides), padding, act.__name__)) with tf.variable_scope(name) as vs: @@ -279,6 +282,8 
@@ def __init__( ): Layer.__init__(self, name=name) self.inputs = layer.outputs + if act is None: + act = tf.identity logging.info("DeConv2dLayer %s: shape:%s out_shape:%s strides:%s pad:%s act:%s" % (self.name, str(shape), str(output_shape), str(strides), padding, act.__name__)) # logging.info(" DeConv2dLayer: Untested") @@ -345,6 +350,8 @@ def __init__( ): Layer.__init__(self, name=name) self.inputs = layer.outputs + if act is None: + act = tf.identity logging.info("Conv3dLayer %s: shape:%s strides:%s pad:%s act:%s" % (self.name, str(shape), str(strides), padding, act.__name__)) with tf.variable_scope(name) as vs: @@ -410,6 +417,8 @@ def __init__( ): Layer.__init__(self, name=name) self.inputs = layer.outputs + if act is None: + act = tf.identity logging.info("DeConv3dLayer %s: shape:%s out_shape:%s strides:%s pad:%s act:%s" % (self.name, str(shape), str(output_shape), str(strides), padding, act.__name__)) @@ -546,132 +555,26 @@ def __init__( self.all_layers.extend([self.outputs]) -def _to_bc_h_w(x, x_shape): - """(b, h, w, c) -> (b*c, h, w)""" - x = tf.transpose(x, [0, 3, 1, 2]) - x = tf.reshape(x, (-1, x_shape[1], x_shape[2])) - return x - - -def _to_b_h_w_n_c(x, x_shape): - """(b*c, h, w, n) -> (b, h, w, n, c)""" - x = tf.reshape(x, (-1, x_shape[4], x_shape[1], x_shape[2], x_shape[3])) - x = tf.transpose(x, [0, 2, 3, 4, 1]) - return x - - -def _tf_repeat(a, repeats): - """Tensorflow version of np.repeat for 1D""" - # https://github.com/tensorflow/tensorflow/issues/8521 - assert len(a.get_shape()) == 1 - - a = tf.expand_dims(a, -1) - a = tf.tile(a, [1, repeats]) - a = tf_flatten(a) - return a - - -def tf_batch_map_coordinates(inputs, coords): - """Batch version of tf_map_coordinates - - Only supports 2D feature maps - - Parameters - ---------- - inputs : ``tf.Tensor`` - shape = (b*c, h, w) - coords : ``tf.Tensor`` - shape = (b*c, h, w, n, 2) - - Returns - ------- - ``tf.Tensor`` - A Tensor with the shape as (b*c, h, w, n) - - """ - input_shape = 
inputs.get_shape() - coords_shape = coords.get_shape() - batch_channel = tf.shape(inputs)[0] - input_h = int(input_shape[1]) - input_w = int(input_shape[2]) - kernel_n = int(coords_shape[3]) - n_coords = input_h * input_w * kernel_n - - coords_lt = tf.cast(tf.floor(coords), 'int32') - coords_rb = tf.cast(tf.ceil(coords), 'int32') - coords_lb = tf.stack([coords_lt[:, :, :, :, 0], coords_rb[:, :, :, :, 1]], axis=-1) - coords_rt = tf.stack([coords_rb[:, :, :, :, 0], coords_lt[:, :, :, :, 1]], axis=-1) - - idx = _tf_repeat(tf.range(batch_channel), n_coords) - - vals_lt = _get_vals_by_coords(inputs, coords_lt, idx, (batch_channel, input_h, input_w, kernel_n)) - vals_rb = _get_vals_by_coords(inputs, coords_rb, idx, (batch_channel, input_h, input_w, kernel_n)) - vals_lb = _get_vals_by_coords(inputs, coords_lb, idx, (batch_channel, input_h, input_w, kernel_n)) - vals_rt = _get_vals_by_coords(inputs, coords_rt, idx, (batch_channel, input_h, input_w, kernel_n)) - - coords_offset_lt = coords - tf.cast(coords_lt, 'float32') - - vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, :, :, :, 0] - vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, :, :, :, 0] - mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, :, :, :, 1] - - return mapped_vals - - -def tf_batch_map_offsets(inputs, offsets, grid_offset): - """Batch map offsets into input - - Parameters - ------------ - inputs : ``tf.Tensor`` - shape = (b, h, w, c) - offsets: ``tf.Tensor`` - shape = (b, h, w, 2*n) - grid_offset: `tf.Tensor`` - Offset grids shape = (h, w, n, 2) - - Returns - ------- - ``tf.Tensor`` - A Tensor with the shape as (b, h, w, c) - +class DeformableConv2dLayer(Layer): + """The :class:`DeformableConv2dLayer` class is a 2D + `Deformable Convolutional Networks `__. 
""" - input_shape = inputs.get_shape() - batch_size = tf.shape(inputs)[0] - kernel_n = int(int(offsets.get_shape()[3]) / 2) - input_h = input_shape[1] - input_w = input_shape[2] - channel = input_shape[3] - - # inputs (b, h, w, c) --> (b*c, h, w) - inputs = _to_bc_h_w(inputs, input_shape) - - # offsets (b, h, w, 2*n) --> (b, h, w, n, 2) - offsets = tf.reshape(offsets, (batch_size, input_h, input_w, kernel_n, 2)) - # offsets (b, h, w, n, 2) --> (b*c, h, w, n, 2) - # offsets = tf.tile(offsets, [channel, 1, 1, 1, 1]) - - coords = tf.expand_dims(grid_offset, 0) # grid_offset --> (1, h, w, n, 2) - coords = tf.tile(coords, [batch_size, 1, 1, 1, 1]) + offsets # grid_offset --> (b, h, w, n, 2) - - # clip out of bound - coords = tf.stack( - [ - tf.clip_by_value(coords[:, :, :, :, 0], 0.0, tf.cast(input_h - 1, 'float32')), - tf.clip_by_value(coords[:, :, :, :, 1], 0.0, tf.cast(input_w - 1, 'float32')) - ], - axis=-1) - coords = tf.tile(coords, [channel, 1, 1, 1, 1]) - - mapped_vals = tf_batch_map_coordinates(inputs, coords) - # (b*c, h, w, n) --> (b, h, w, n, c) - mapped_vals = _to_b_h_w_n_c(mapped_vals, [batch_size, input_h, input_w, kernel_n, channel]) - return mapped_vals + def __init__(self, + layer, + act=tf.identity, + offset_layer=None, + shape=(3, 3, 1, 100), + name='deformable_conv_2d_layer', + W_init=tf.truncated_normal_initializer(stddev=0.02), + b_init=tf.constant_initializer(value=0.0), + W_init_args={}, + b_init_args={}): + raise Exception("deprecated, use DeformableConv2d instead") -class DeformableConv2dLayer(Layer): - """The :class:`DeformableConv2dLayer` class is a 2D +class DeformableConv2d(Layer): + """The :class:`DeformableConv2d` class is a 2D `Deformable Convolutional Networks `__. Parameters @@ -682,10 +585,12 @@ class DeformableConv2dLayer(Layer): To predict the offset of convolution operations. The output shape is (batchsize, input height, input width, 2*(number of element in the convolution kernel)) e.g. 
if apply a 3*3 kernel, the number of the last dimension should be 18 (2*3*3) + n_filter : int + The number of filters. + filter_size : tuple of int + The filter size (height, width). act : activation function The activation function of this layer. - shape : tuple of int - The shape of the filters: [filter_height, filter_width, in_channels, out_channels]. W_init : initializer The initializer for the weight matrix. b_init : initializer or None @@ -700,14 +605,14 @@ class DeformableConv2dLayer(Layer): Examples -------- >>> net = tl.layers.InputLayer(x, name='input_layer') - >>> offset_1 = tl.layers.Conv2dLayer(layer=net, act=act, shape=(3, 3, 3, 18), strides=(1, 1, 1, 1),padding='SAME', name='offset_layer1') - >>> net = tl.layers.DeformableConv2dLayer(layer=net, act=act, offset_layer=offset_1, shape=(3, 3, 3, 32), name='deformable_conv_2d_layer1') - >>> offset_2 = tl.layers.Conv2dLayer(layer=net, act=act, shape=(3, 3, 32, 18), strides=(1, 1, 1, 1), padding='SAME', name='offset_layer2') - >>> net = tl.layers.DeformableConv2dLayer(layer=net, act=act, offset_layer=offset_2, shape=(3, 3, 32, 64), name='deformable_conv_2d_layer2') + >>> offset1 = tl.layers.Conv2d(net, 18, (3, 3), (1, 1), act=act, padding='SAME', name='offset1') + >>> net = tl.layers.DeformableConv2d(net, offset1, 32, (3, 3), act=act, name='deformable1') + >>> offset2 = tl.layers.Conv2d(net, 18, (3, 3), (1, 1), act=act, padding='SAME', name='offset2') + >>> net = tl.layers.DeformableConv2d(net, offset2, 64, (3, 3), act=act, name='deformable2') References ---------- - - The deformation operation was adapted from the implementation in ``__ + - The deformation operation was adapted from the implementation in `here `__ Notes ----- @@ -716,24 +621,169 @@ class DeformableConv2dLayer(Layer): """ - def __init__(self, - layer, - act=tf.identity, - offset_layer=None, - shape=(3, 3, 1, 100), - name='deformable_conv_2d_layer', - W_init=tf.truncated_normal_initializer(stddev=0.02), - 
b_init=tf.constant_initializer(value=0.0), - W_init_args={}, - b_init_args={}): + # >>> net = tl.layers.InputLayer(x, name='input_layer') + # >>> offset_1 = tl.layers.Conv2dLayer(layer=net, act=act, shape=(3, 3, 3, 18), strides=(1, 1, 1, 1),padding='SAME', name='offset_layer1') + # >>> net = tl.layers.DeformableConv2dLayer(layer=net, act=act, offset_layer=offset_1, shape=(3, 3, 3, 32), name='deformable_conv_2d_layer1') + # >>> offset_2 = tl.layers.Conv2dLayer(layer=net, act=act, shape=(3, 3, 32, 18), strides=(1, 1, 1, 1), padding='SAME', name='offset_layer2') + # >>> net = tl.layers.DeformableConv2dLayer(layer=net, act=act, offset_layer=offset_2, shape=(3, 3, 32, 64), name='deformable_conv_2d_layer2') + def __init__( + self, + layer, + offset_layer=None, + # shape=(3, 3, 1, 100), + n_filter=32, + filter_size=(3, 3), + act=tf.identity, + name='deformable_conv_2d', + W_init=tf.truncated_normal_initializer(stddev=0.02), + b_init=tf.constant_initializer(value=0.0), + W_init_args={}, + b_init_args={}): if tf.__version__ < "1.4": - raise Exception("Deformable CNN layer requires tensrflow 1.4 or higher version") + raise Exception("Deformable CNN layer requires tensorflow 1.4 or higher version | current version %s" % tf.__version__) + + def _to_bc_h_w(x, x_shape): + """(b, h, w, c) -> (b*c, h, w)""" + x = tf.transpose(x, [0, 3, 1, 2]) + x = tf.reshape(x, (-1, x_shape[1], x_shape[2])) + return x + + def _to_b_h_w_n_c(x, x_shape): + """(b*c, h, w, n) -> (b, h, w, n, c)""" + x = tf.reshape(x, (-1, x_shape[4], x_shape[1], x_shape[2], x_shape[3])) + x = tf.transpose(x, [0, 2, 3, 4, 1]) + return x + + def tf_flatten(a): + """Flatten tensor""" + return tf.reshape(a, [-1]) + + def _get_vals_by_coords(inputs, coords, idx, out_shape): + indices = tf.stack([idx, tf_flatten(coords[:, :, :, :, 0]), tf_flatten(coords[:, :, :, :, 1])], axis=-1) + vals = tf.gather_nd(inputs, indices) + vals = tf.reshape(vals, out_shape) + return vals + + def _tf_repeat(a, repeats): + """Tensorflow version 
of np.repeat for 1D""" + # https://github.com/tensorflow/tensorflow/issues/8521 + assert len(a.get_shape()) == 1 + + a = tf.expand_dims(a, -1) + a = tf.tile(a, [1, repeats]) + a = tf_flatten(a) + return a + + def _tf_batch_map_coordinates(inputs, coords): + """Batch version of tf_map_coordinates + + Only supports 2D feature maps + + Parameters + ---------- + inputs : ``tf.Tensor`` + shape = (b*c, h, w) + coords : ``tf.Tensor`` + shape = (b*c, h, w, n, 2) + + Returns + ------- + ``tf.Tensor`` + A Tensor with the shape as (b*c, h, w, n) + + """ + input_shape = inputs.get_shape() + coords_shape = coords.get_shape() + batch_channel = tf.shape(inputs)[0] + input_h = int(input_shape[1]) + input_w = int(input_shape[2]) + kernel_n = int(coords_shape[3]) + n_coords = input_h * input_w * kernel_n + + coords_lt = tf.cast(tf.floor(coords), 'int32') + coords_rb = tf.cast(tf.ceil(coords), 'int32') + coords_lb = tf.stack([coords_lt[:, :, :, :, 0], coords_rb[:, :, :, :, 1]], axis=-1) + coords_rt = tf.stack([coords_rb[:, :, :, :, 0], coords_lt[:, :, :, :, 1]], axis=-1) + + idx = _tf_repeat(tf.range(batch_channel), n_coords) + + vals_lt = _get_vals_by_coords(inputs, coords_lt, idx, (batch_channel, input_h, input_w, kernel_n)) + vals_rb = _get_vals_by_coords(inputs, coords_rb, idx, (batch_channel, input_h, input_w, kernel_n)) + vals_lb = _get_vals_by_coords(inputs, coords_lb, idx, (batch_channel, input_h, input_w, kernel_n)) + vals_rt = _get_vals_by_coords(inputs, coords_rt, idx, (batch_channel, input_h, input_w, kernel_n)) + + coords_offset_lt = coords - tf.cast(coords_lt, 'float32') + + vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, :, :, :, 0] + vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, :, :, :, 0] + mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, :, :, :, 1] + + return mapped_vals + + def _tf_batch_map_offsets(inputs, offsets, grid_offset): + """Batch map offsets into input + + Parameters + ------------ + inputs : ``tf.Tensor`` + 
shape = (b, h, w, c) + offsets: ``tf.Tensor`` + shape = (b, h, w, 2*n) + grid_offset: ``tf.Tensor`` + Offset grids shape = (h, w, n, 2) + + Returns + ------- + ``tf.Tensor`` + A Tensor with the shape as (b, h, w, c) + + """ + input_shape = inputs.get_shape() + batch_size = tf.shape(inputs)[0] + kernel_n = int(int(offsets.get_shape()[3]) / 2) + input_h = input_shape[1] + input_w = input_shape[2] + channel = input_shape[3] + + # inputs (b, h, w, c) --> (b*c, h, w) + inputs = _to_bc_h_w(inputs, input_shape) + + # offsets (b, h, w, 2*n) --> (b, h, w, n, 2) + offsets = tf.reshape(offsets, (batch_size, input_h, input_w, kernel_n, 2)) + # offsets (b, h, w, n, 2) --> (b*c, h, w, n, 2) + # offsets = tf.tile(offsets, [channel, 1, 1, 1, 1]) + + coords = tf.expand_dims(grid_offset, 0) # grid_offset --> (1, h, w, n, 2) + coords = tf.tile(coords, [batch_size, 1, 1, 1, 1]) + offsets # grid_offset --> (b, h, w, n, 2) + + # clip out of bound + coords = tf.stack( + [ + tf.clip_by_value(coords[:, :, :, :, 0], 0.0, tf.cast(input_h - 1, 'float32')), + tf.clip_by_value(coords[:, :, :, :, 1], 0.0, tf.cast(input_w - 1, 'float32')) + ], + axis=-1) + coords = tf.tile(coords, [channel, 1, 1, 1, 1]) + + mapped_vals = _tf_batch_map_coordinates(inputs, coords) + # (b*c, h, w, n) --> (b, h, w, n, c) + mapped_vals = _to_b_h_w_n_c(mapped_vals, [batch_size, input_h, input_w, kernel_n, channel]) + + return mapped_vals Layer.__init__(self, name=name) self.inputs = layer.outputs self.offset_layer = offset_layer + if act is None: + act = tf.identity + logging.info("DeformableConv2d %s: n_filter: %d, filter_size: %s act:%s" % (self.name, n_filter, str(filter_size), act.__name__)) - logging.info("DeformableConv2dLayer %s: shape:%s, act:%s" % (self.name, str(shape), act.__name__)) + try: + pre_channel = int(layer.outputs.get_shape()[-1]) + except: # if pre_channel is ?, it happens when using Spatial Transformer Net + pre_channel = 1 + logging.info("[warnings] unknown input channels, set to 1") + shape = 
(filter_size[0], filter_size[1], pre_channel, n_filter) with tf.variable_scope(name) as vs: offset = self.offset_layer.outputs @@ -760,14 +810,20 @@ def __init__(self, grid = tf.tile(grid, [1, 1, kernel_n, 1]) # grid --> (h, w, n, 2) grid_offset = grid + initial_offsets # grid_offset --> (h, w, n, 2) - input_deform = tf_batch_map_offsets(self.inputs, offset, grid_offset) + input_deform = _tf_batch_map_offsets(self.inputs, offset, grid_offset) - W = tf.get_variable(name='W_conv2d', shape=[1, 1, shape[0] * shape[1], shape[-2], shape[-1]], initializer=W_init, dtype=D_TYPE, **W_init_args) - b = tf.get_variable(name='b_conv2d', shape=(shape[-1]), initializer=b_init, dtype=D_TYPE, **b_init_args) + W = tf.get_variable( + name='W_deformableconv2d', shape=[1, 1, shape[0] * shape[1], shape[-2], shape[-1]], initializer=W_init, dtype=D_TYPE, **W_init_args) - self.outputs = tf.reshape( - act(tf.nn.conv3d(input_deform, W, strides=[1, 1, 1, 1, 1], padding='VALID', name=None) + b), - (tf.shape(self.inputs)[0], input_h, input_w, shape[-1])) + if b_init: + b = tf.get_variable(name='b_deformableconv2d', shape=(shape[-1]), initializer=b_init, dtype=D_TYPE, **b_init_args) + self.outputs = tf.reshape( + act(tf.nn.conv3d(input_deform, W, strides=[1, 1, 1, 1, 1], padding='VALID', name=None) + b), + (tf.shape(self.inputs)[0], input_h, input_w, shape[-1])) + else: + self.outputs = tf.reshape( + act(tf.nn.conv3d(input_deform, W, strides=[1, 1, 1, 1, 1], padding='VALID', name=None)), + (tf.shape(self.inputs)[0], input_h, input_w, shape[-1])) # fixed self.all_layers = list(layer.all_layers) @@ -784,62 +840,10 @@ def __init__(self, # this layer self.all_layers.extend([self.outputs]) - self.all_params.extend([W, b]) - - -class _DeformableConv2d(DeformableConv2dLayer): # TODO - """Simplified version of :class:`DeformableConv2dLayer`, see - `Deformable Convolutional Networks `__. - - Parameters - ---------- - layer : :class:`Layer` - Previous layer. 
- offset_layer : :class:`Layer` - To predict the offset of convolution operations. - The output shape is (batchsize, input height, input width, 2*(number of element in the convolution kernel)) - e.g. if apply a 3*3 kernel, the number of the last dimension should be 18 (2*3*3) - act : activation function - The activation function of this layer. - n_filter : int - The number of filters. - filter_size : tuple of int - The filter size (height, width). - W_init : initializer - The initializer for the weight matrix. - b_init : initializer or None - The initializer for the bias vector. If None, skip biases. - W_init_args : dictionary - The arguments for the weight matrix initializer. - b_init_args : dictionary - The arguments for the bias vector initializer. - name : str - A unique layer name. - """ - - def __init__( - self, - layer, - act=tf.identity, - offset_layer=None, - # shape=(3, 3, 1, 100), - n_filter=32, - filter_size=(3, 3), - name='deformable_conv_2d_layer', - W_init=tf.truncated_normal_initializer(stddev=0.02), - b_init=tf.constant_initializer(value=0.0), - W_init_args={}, - b_init_args={}): - - try: - pre_channel = int(layer.outputs.get_shape()[-1]) - except: # if pre_channel is ?, it happens when using Spatial Transformer Net - pre_channel = 1 - logging.info("[warnings] unknow input channels, set to 1") - shape = (filter_size[0], filter_size[1], pre_channel, n_filter) - - DeformableConv2dLayer.__init__( - self, act=act, offset_layer=offset_layer, shape=shape, name=name, W_init=W_init, b_init=b_init, W_init_args=W_init_args, b_init_args=b_init_args) + if b_init: + self.all_params.extend([W, b]) + else: + self.all_params.extend([W]) def atrous_conv1d( @@ -957,6 +961,8 @@ def __init__(self, name='atrou2d'): Layer.__init__(self, name=name) self.inputs = layer.outputs + if act is None: + act = tf.identity logging.info("AtrousConv2dLayer %s: n_filter:%d filter_size:%s rate:%d pad:%s act:%s" % (self.name, n_filter, filter_size, rate, padding, act.__name__)) with 
tf.variable_scope(name) as vs: shape = [filter_size[0], filter_size[1], int(self.inputs.get_shape()[-1]), n_filter]