From 71c1893835eda5e973be5141c17e4c612d78f6dd Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 14:32:26 +0800 Subject: [PATCH 01/17] Add octconv layers and obj coord preprocessing ops. Add octconv layers and obj coord preprocessing ops. --- docs/modules/cost.rst | 2 +- docs/modules/layers.rst | 34 ++ docs/modules/prepro.rst | 10 + tensorlayer/layers/convolution/__init__.py | 8 + tensorlayer/layers/convolution/oct_conv.py | 555 +++++++++++++++++++++ tensorlayer/prepro.py | 135 +++++ tests/layers/test_layers_convolution.py | 60 +++ 7 files changed, 803 insertions(+), 1 deletion(-) create mode 100644 tensorlayer/layers/convolution/oct_conv.py diff --git a/docs/modules/cost.rst b/docs/modules/cost.rst index eba52f4ca..bf9fe939c 100644 --- a/docs/modules/cost.rst +++ b/docs/modules/cost.rst @@ -96,5 +96,5 @@ Special .. autofunction:: maxnorm_i_regularizer Huber Loss -^^^^^^^^^^ +-------------------------- .. autofunction:: huber_loss \ No newline at end of file diff --git a/docs/modules/layers.rst b/docs/modules/layers.rst index 7a70b54dc..d410ac5d6 100644 --- a/docs/modules/layers.rst +++ b/docs/modules/layers.rst @@ -39,6 +39,12 @@ Layer list SeparableConv2d DeformableConv2d GroupConv2d + OctConv2dIn + OctConv2d + OctConv2dOut + OctConv2dHighOut + OctConv2dLowOut + OctConv2dConcat PadLayer PoolLayer @@ -241,6 +247,34 @@ GroupConv2d """"""""""""""""""""" .. autoclass:: GroupConv2d +OctConv2d +-------------------------- + +For OctConv2d, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + +OctConv2dIn +""""""""""""""""""""" +.. autoclass:: OctConv2dIn + +OctConv2d +""""""""""""""""""""" +.. autoclass:: OctConv2d + +OctConv2dOut +""""""""""""""""""""" +.. autoclass:: OctConv2dOut + +OctConv2dHighOut +""""""""""""""""""""" +.. autoclass:: OctConv2dHighOut + +OctConv2dConcat +""""""""""""""""""""" +.. autoclass:: OctConv2dConcat + +OctConv2dLowOut +""""""""""""""""""""" +.. 
autoclass:: OctConv2dLowOut Separable Convolutions ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/modules/prepro.rst b/docs/modules/prepro.rst index ec65da066..da1bd3189 100644 --- a/docs/modules/prepro.rst +++ b/docs/modules/prepro.rst @@ -79,6 +79,8 @@ API - Data Pre-Processing obj_box_coord_upleft_butright_to_centroid obj_box_coord_centroid_to_upleft obj_box_coord_upleft_to_centroid + obj_box_coord_affine + rotated_obj_box_coord_affine parse_darknet_ann_str_to_list parse_darknet_ann_list_to_cls_box @@ -577,6 +579,14 @@ Image Aug - Zoom ^^^^^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: obj_box_zoom +Image Aug - Affine +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autofunction:: obj_box_coord_affine + +Image Aug - Rotated-Affine +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autofunction:: rotated_obj_box_coord_affine + Keypoints ------------ diff --git a/tensorlayer/layers/convolution/__init__.py b/tensorlayer/layers/convolution/__init__.py index ba68797f2..0f9ebe237 100644 --- a/tensorlayer/layers/convolution/__init__.py +++ b/tensorlayer/layers/convolution/__init__.py @@ -80,4 +80,12 @@ #quan_conv 'QuanConv2d', 'QuanConv2dWithBN', + + # octave_conv + 'OctConv2dIn', + 'OctConv2d', + 'OctConv2dOut', + 'OctConv2dHighOut', + 'OctConv2dLowOut', + 'OctConv2dConcat', ] diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py new file mode 100644 index 000000000..3ea81befb --- /dev/null +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -0,0 +1,555 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import tensorflow as tf + +import tensorlayer as tl +from tensorlayer import logging +from tensorlayer.decorators import deprecated_alias +from tensorlayer.layers.core import Layer + +# from tensorlayer.layers.core import LayersConfig + + + +__all__ = [ + 'OctConv2dIn', + 'OctConv2d', + 'OctConv2dOut', + 'OctConv2dHighOut', + 'OctConv2dLowOut', + 'OctConv2dConcat', +] + +class OctConv2dIn(Layer): + """ + The :class:`OctConv2dIn` class is a preprocessing layer for 2D image [batch, height, width, channel], see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + + Parameters + ---------- + name : None or str + A unique layer name. + + Notes + ----- + - The height and width of input must be a multiple of the 2. + - Use this layer before any other octconv layers. + - The output will be a list which contains 2 tensor. + + Examples + -------- + With TensorLayer + + >>> net = tl.layers.Input([8, 28, 28, 16], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 16),(8, 14, 14, 16)] + + """ + + def __init__( + self, + name=None, # 'cnn2d_layer', + ): + super().__init__(name) + self.build(None) + self._built = True + + logging.info( + "OctConv2dIn %s: " % ( + self.name, + ) + ) + + def __repr__(self): + s = ('{classname}(') + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs): + pass + + def forward(self, inputs): + high_out=tf.identity(inputs,name=(self.name+'_high_out')) + low_out = tf.nn.avg_pool2d(inputs, (2,2), strides=(2,2),padding='SAME',name=self.name+'_low_out') + outputs=[high_out,low_out] + return outputs + + +class OctConv2d(Layer): + """ + The :class:`OctConv2d` class is a 2D CNN layer for OctConv2d layer output, see `Drop an Octave: Reducing Spatial Redundancy in 
Convolutional Neural Networks with Octave Convolution `__. Use this layer to process tensor list. + + Parameters + ---------- + filter : int + The sum of the number of filters. + alpha : :float + The percentage of high_res output. + filter_size : tuple of int + The filter size (height, width). + strides : tuple of int + The sliding window strides of corresponding input dimensions. + It must be in the same order as the ``shape`` parameter. + W_init : initializer + The initializer for the weight matrix. + b_init : initializer or None + The initializer for the bias vector. If None, skip biases. + act : activation function + The activation function of this layer. + name : None or str + A unique layer name. + + Notes + ----- + - The input should be a list with shape [high_res_tensor , low_res_tensor], the height and width of high_res should be twice of the low_res_tensor. + - If you do not which tensor is larger, use OctConv2dConcat layer. + - The output will be a list which contains 2 tensor. + - You should not use the output directly. 
+ + Examples + -------- + With TensorLayer + + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2d = tl.layers.OctConv2d(32,0.5,act=tf.nn.relu, name='octconv2d_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 16),(8, 14, 14, 16)] + + """ + + def __init__( + self, + filter=32, + alpha=0.5, + filter_size=(2, 2), + strides=(1,1), + W_init=tl.initializers.truncated_normal(stddev=0.02), + b_init=tl.initializers.constant(value=0.0), + act=None, + in_channels=None, + name=None # 'cnn2d_layer', + ): + super().__init__(name) + self.filter = filter + self.alpha = alpha + if (self.alpha >= 1) or (self.alpha <= 0): + raise ValueError( + "The alpha must be in (0,1)") + self.high_out = int(self.alpha * self.filter) + self.low_out = self.filter - self.high_out + if (self.high_out == 0) or (self.low_out == 0): + raise ValueError( + "The output channel must be greater than 0.") + self.filter_size = filter_size + self.strides = strides + self.W_init = W_init + self.b_init = b_init + self.act = act + self.in_channels = in_channels + if self.in_channels: + self.build(None) + self._built = True + + + logging.info( + "OctConv2d %s: filter_size: %s strides: %s high_out: %s low_out: %s act: %s" % ( + self.name, str(filter_size), str(strides), str(self.high_out), str(self.low_out), + self.act.__name__ if self.act is not None else 'No Activation' + ) + ) + + def __repr__(self): + actstr = self.act.__name__ if self.act is not None else 'No Activation' + s = ('{classname}(in_channels={in_channels}, out_channels={filter} kernel_size={filter_size}' + ', strides={strides}') + if self.b_init is None: + s += ', bias=False' + s += (', ' + actstr) + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, 
inputs_shape): + if not self.in_channels: + high_ch=inputs_shape[0][-1] + low_ch=inputs_shape[1][-1] + else: + high_ch=self.in_channels[0] + low_ch=self.in_channels[1] + self.high_high_filter_shape = ( + self.filter_size[0], self.filter_size[1], high_ch, self.high_out + ) + self.high_low_filter_shape = ( + self.filter_size[0], self.filter_size[1], high_ch, self.low_out + ) + self.low_low_filter_shape = ( + self.filter_size[0], self.filter_size[1], low_ch, self.low_out + ) + self.low_high_filter_shape = ( + self.filter_size[0], self.filter_size[1], low_ch, self.high_out + ) + self.high_high__W = self._get_weights( + "high_high_filters", shape=self.high_high_filter_shape, init=self.W_init + ) + self.high_low__W = self._get_weights( + "high_low_filters", shape=self.high_low_filter_shape, init=self.W_init + ) + self.low_low_W = self._get_weights( + "low_low_filters", shape=self.low_low_filter_shape, init=self.W_init + ) + self.low_high_W = self._get_weights( + "low_high_filters", shape=self.low_high_filter_shape, init=self.W_init + ) + if self.b_init: + self.high_b = self._get_weights( + "high_biases", shape=(self.high_out), init=self.b_init + ) + self.low_b = self._get_weights( + "low_biases", shape=(self.low_out), init=self.b_init + ) + + def forward(self, inputs): + high_input = inputs[0] + low_input=inputs[1] + high_to_high = tf.nn.conv2d(high_input, self.high_high__W, + strides=self.strides, padding="SAME") + high_to_low =tf.nn.avg_pool2d(high_input, (2,2), strides=(2,2),padding='SAME') + high_to_low=tf.nn.conv2d(high_to_low, self.high_low__W, + strides=self.strides, padding="SAME") + low_to_low = tf.nn.conv2d(low_input, self.low_low_W, + strides=self.strides, padding="SAME") + low_to_high = tf.nn.conv2d(low_input, self.low_high_W, + strides=self.strides, padding="SAME") + low_to_high=tf.keras.layers.UpSampling2D(size=(2,2), interpolation='nearest')(low_to_high) + high_out=high_to_high+low_to_high + low_out=low_to_low+high_to_low + if self.b_init: + high_out = 
tf.nn.bias_add(high_out, self.high_b, data_format="NHWC") + low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") + if self.act: + high_out = self.act(high_out,name=self.name+'_high_out') + low_out= self.act(low_out,name=self.name+'_low_out') + else: + high_out=tf.identity(high_out,name=self.name+'_high_out') + low_out=tf.identity(low_out,name=self.name+'_low_out') + outputs=[high_out,low_out] + return outputs + + + +class OctConv2dOut(Layer): + """ + The :class:`OctConv2dOut` class is a 2D CNN layer for OctConv2d layer output to get only a tensor, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. Use this layer after other Octconv layers and get a normal tensor output. + + Parameters + ---------- + filter : int + The number of filters. + filter_size : tuple of int + The filter size (height, width). + strides : tuple of int + The sliding window strides of corresponding input dimensions. + It must be in the same order as the ``shape`` parameter. + W_init : initializer + The initializer for the weight matrix. + b_init : initializer or None + The initializer for the bias vector. If None, skip biases. + act : activation function + The activation function of this layer. + name : None or str + A unique layer name. + + Notes + ----- + - Use this layer to get only a tensor for other normal layer. 
+ + Examples + -------- + With TensorLayer + + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2d = tl.layers.OctConv2d(32,0.5,act=tf.nn.relu, name='octconv2d_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 16),(8, 14, 14, 16)] + >>> octconv2d = tl.layers.OctConv2dOut(32,act=tf.nn.relu, name='octconv2dout_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : (8, 14, 14, 32) + + """ + + def __init__( + self, + n_filter=32, + filter_size=(2, 2), + strides=(1,1), + W_init=tl.initializers.truncated_normal(stddev=0.02), + b_init=tl.initializers.constant(value=0.0), + act=None, + in_channels=None, + name=None # 'cnn2d_layer', + ): + super().__init__(name) + + self.high_out = n_filter + self.low_out = n_filter + self.filter_size = filter_size + self.strides = strides + self.W_init = W_init + self.b_init = b_init + self.act = act + self.in_channels = in_channels + if self.in_channels: + self.build(None) + self._built = True + + logging.info( + "OctConv2dOut %s: filter_size: %s strides: %s out_channels: %s act: %s" % ( + self.name, str(filter_size), str(strides), str(self.low_out), + self.act.__name__ if self.act is not None else 'No Activation' + ) + ) + + def __repr__(self): + actstr = self.act.__name__ if self.act is not None else 'No Activation' + s = ('{classname}(in_channels={in_channels}, out_channels={low_out}, kernel_size={filter_size}' + ', strides={strides}') + if self.b_init is None: + s += ', bias=False' + s += (', ' + actstr) + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs_shape): + if not self.in_channels: + high_ch=inputs_shape[0][-1] + low_ch=inputs_shape[1][-1] + else: + high_ch=self.in_channels[0] + low_ch=self.in_channels[1] + 
self.high_low_filter_shape = ( + self.filter_size[0], self.filter_size[1], high_ch, self.high_out + ) + self.low_low_filter_shape = ( + self.filter_size[0], self.filter_size[1], low_ch, self.low_out + ) + self.high_low__W = self._get_weights( + "high_low_filters", shape=self.high_low_filter_shape, init=self.W_init + ) + self.low_low_W = self._get_weights( + "low_low_filters", shape=self.low_low_filter_shape, init=self.W_init + ) + if self.b_init: + self.low_b = self._get_weights( + "low_biases", shape=(self.low_out), init=self.b_init + ) + + def forward(self, inputs): + high_input = inputs[0] + low_input=inputs[1] + high_to_low =tf.nn.avg_pool2d(high_input, (2,2), strides=(2,2),padding='SAME') + high_to_low=tf.nn.conv2d(high_to_low, self.high_low__W, + strides=self.strides, padding="SAME") + low_to_low = tf.nn.conv2d(low_input, self.low_low_W, + strides=self.strides, padding="SAME") + low_out=low_to_low+high_to_low + if self.b_init: + low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") + if self.act: + outputs= self.act(low_out,name=self.name+'_low_out') + else: + outputs=tf.identity(low_out,name=self.name+'_low_out') + return outputs + + + + +class OctConv2dHighOut(Layer): + """ + The :class:`OctConv2dHighOut` class is a slice layer for Octconv tensor list, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + + Parameters + ---------- + name : None or str + A unique layer name. + + Notes + ----- + - Use this layer to get high resolution tensor. + - If you want to do some customized normalization ops, use this layer with OctConv2dLowOut and OctConv2dConcat layers to implement your idea. 
+ + Examples + -------- + With TensorLayer + + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2d = tl.layers.OctConv2dHighOut(name='octconv2dho_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : (8, 28, 28, 32) + + """ + + def __init__( + self, + name=None, # 'cnn2d_layer', + ): + super().__init__(name) + self.build(None) + self._built = True + + logging.info( + "OctConv2dHighOut %s: " % ( + self.name, + ) + ) + + def __repr__(self): + + s = ('{classname}(') + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs): + pass + + def forward(self, inputs): + outputs=tf.identity(inputs[0],self.name) + return outputs + + +class OctConv2dLowOut (Layer): + """ + The :class:`OctConv2dLowOut` class is a slice layer for Octconv tensor list, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + + Parameters + ---------- + name : None or str + A unique layer name. + + Notes + ----- + - Use this layer to get low resolution tensor. + - If you want to do some customized normalization ops, use this layer with OctConv2dHighOut and OctConv2dConcat layers to implement your idea. 
+ + Examples + -------- + With TensorLayer + + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2d = tl.layers.OctConv2dLowOut(name='octconv2dlo_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : (8, 14, 14, 32) + + + """ + + def __init__( + self, + name=None, # 'cnn2d_layer', + ): + super().__init__(name) + self.build(None) + self._built = True + + logging.info( + "OctConv2dHighOut %s: " % ( + self.name, + ) + ) + + def __repr__(self): + + s = ('{classname}(') + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs): + pass + + def forward(self, inputs): + outputs=tf.identity(inputs[1],self.name) + return outputs + +class OctConv2dConcat(Layer): + """ + The :class:`OctConv2dConcat` class is a concat layer for two 2D image batches, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + + Parameters + ---------- + name : None or str + A unique layer name. + + Notes + ----- + - Use this layer to concat two tensor. + - The height and width of one tensor should be twice of the other tensor. 
+ + Examples + -------- + With TensorLayer + + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1') + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2dl = tl.layers.OctConv2dLowOut(name='octconv2dlo_1')(octconv2d) + >>> octconv2dh = tl.layers.OctConv2dHighOut(name='octconv2dho_1')(octconv2d) + >>> octconv2 = tl.layers.OctConv2dConcat(name='octconv2dcat_1')([octconv2dh,octconv2dl]) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + + """ + + def __init__( + self, + name=None, # 'cnn2d_layer', + ): + super().__init__(name) + self.build(None) + self._built = True + + logging.info( + "OctConv2dConcat %s: " % ( + self.name, + ) + ) + + def __repr__(self): + + s = ('{classname}(') + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs): + pass + + def forward(self, inputs): + if inputs[0].shape[1]>inputs[1].shape[1]: + outputs=[inputs[0],inputs[1]] + else: + outputs = [inputs[1], inputs[0]] + return outputs diff --git a/tensorlayer/prepro.py b/tensorlayer/prepro.py index 931694308..e10755c1f 100644 --- a/tensorlayer/prepro.py +++ b/tensorlayer/prepro.py @@ -111,6 +111,8 @@ 'keypoint_random_flip', 'keypoint_random_resize', 'keypoint_random_resize_shortestedge', + 'obj_box_coord_affine', + 'rotated_obj_box_coord_affine', ] @@ -4145,3 +4147,136 @@ def pose_resize_shortestedge(image, annos, mask, target_size): return dst, adjust_joint_list, None return pose_resize_shortestedge(image, annos, mask, target_size) + + +def obj_box_coord_affine( + classes=None,coords=None, affine_matrix=None,affine_matrix_inv=None,min_ratio=0.0,min_width=0.0,min_height=0.0 +): + """Apply affine transform the box coordinates, and gets the new box coordinates. + + Parameters + ----------- + classes : list of int or None + Class IDs. 
+ coords : list of list of 4 int/float or None + Coordinates [[x, y, w, h], [x, y, w, h], ...]. Here x,y are the center coordinates of the bbox. + affine_matrix : np.ndarray + The affine matrix of the image. A ndarray with shape [2,3]. + affine_matrix_inv : np.ndarray + The inverse of the affine matrix. A ndarray with shape [2,3]. + min_ratio : float + Threshold, remove the box if its ratio of new size to old size less than the threshold. + min_width : float + Threshold, remove the box if its ratio of width to image size less than the threshold. + min_height : float + Threshold, remove the box if its ratio of height to image size less than the threshold. + + Returns + ------- + list of int + A list of classes + list of list of 4 numbers + A list of new bounding boxes. + """ + Me = np.array([[0, 0, 1.]]) + new_classes=[] + new_cords=[] + if affine_matrix_inv==None: + affine_matrix_inv=np.linalg.pinv(np.concatenate((affine_matrix, Me), axis=0)) + for bbox_idx,bbox in enumerate(coords): + old_pt = np.array([[bbox[0] * 2. - 1.0], [bbox[1] * 2. - 1.0], [1.0]]) + new_wh_a = np.matmul(affine_matrix_inv[0:2, 0:2], np.array([[affine_matrix_inv[2]], [bbox[3]]])) + new_wh_b = np.matmul(affine_matrix_inv[0:2, 0:2], np.array([[affine_matrix_inv[2]], [-bbox[3]]])) + new_w = max(abs(new_wh_a[0]), abs(new_wh_b[0])) + new_h = max(abs(new_wh_a[1]), abs(new_wh_b[1])) + new_pt = (np.matmul(affine_matrix_inv, old_pt) + 1.0) / 2.0 + bbox_left = new_pt[0] - new_w / 2.0 + if bbox_left <= 0: + bbox_left = 0. + if bbox_left >= 1: + bbox_left = 1. + bbox_right = new_pt[0] + new_w / 2.0 + if bbox_right <= 0: + bbox_right = 0. + if bbox_right >= 1: + bbox_right = 1. + bbox_top = new_pt[1] + new_h / 2.0 + if bbox_top <= 0: + bbox_top = 0. + if bbox_top >= 1: + bbox_top = 1. + bbox_bottom = new_pt[1] - new_h / 2.0 + if bbox_bottom <= 0: + bbox_bottom = 0. + if bbox_bottom >= 1: + bbox_bottom = 1. 
+ ratio = abs(bbox_right - bbox_left) * abs(bbox_top - bbox_bottom) / (new_w + 0.00001) / (new_h + 0.00001) + bbox_x = abs(bbox_right + bbox_left) / 2. + bbox_y = abs(bbox_top + bbox_bottom) / 2. + bbox_h = abs(bbox_top - bbox_bottom) + bbox_w = abs(bbox_right - bbox_left) + if (ratio > min_ratio) & (bbox_h >= min_height) & (bbox_w >= min_width): + new_classes.append(classes[bbox_idx]) + new_cords.append([bbox_x,bbox_y,bbox_w,bbox_h]) + return new_classes,new_cords + + +def rotated_obj_box_coord_affine( + classes=None,coords=None, affine_matrix=None,affine_matrix_inv=None + ): + """Apply affine transform the box coordinates with rotation, and gets the new box coordinates with rotation. Experimental! + + Parameters + ----------- + classes : list of int or None + Class IDs. + coords : list of list of 5 int/float or None + Coordinates [[x, y, w, h, r], [x, y, w, h, r], ...]. Here x,y are the center coordinates of the bbox, r is the radius. + affine_matrix : np.ndarray + The affine matrix of the image. A ndarray with shape [2,3]. + affine_matrix_inv : np.ndarray + The inverse of the affine matrix. A ndarray with shape [2,3]. + + Returns + ------- + list of int + A list of classes + list of list of 5 numbers + A list of new bounding boxes. + """ + Me = np.array([[0, 0, 1.]]) + new_classes=[] + new_cords=[] + if affine_matrix_inv==None: + affine_matrix_inv=np.linalg.pinv(np.concatenate((affine_matrix, Me), axis=0)) + for bbox_idx,bbox in enumerate(coords): + centerx = bbox[0] + centery = bbox[1] + old_pt = np.array([[centerx * 2. - 1.0], [centery * 2. - 1.0], [1.0]]) + top_center = np.array([[0.], [-bbox[3]], [1.0]]) + right_center = np.array([[bbox[2]], [0.], [1.0]]) + rot = bbox[-1] + rot_mat = np.array([[np.cos(rot), -np.sin(rot), 0], [np.sin(rot), np.cos(rot), 0], [0, 0, 1]]) + top_center = np.matmul(rot_mat, top_center) + old_pt + top_center[2][0] = 1. + right_center = np.matmul(rot_mat, right_center) + old_pt + right_center[2][0] = 1. 
+ new_pt = (np.matmul(affine_matrix_inv, old_pt) + 1.0) / 2.0 + new_topcenter = (np.matmul(affine_matrix_inv, top_center) + 1.0) / 2.0 + new_rightcenter = (np.matmul(affine_matrix_inv, right_center) + 1.0) / 2.0 + new_h = np.sqrt((new_topcenter[0][0] - new_pt[0][0]) ** 2 + (new_topcenter[1][0] - new_pt[1][0]) ** 2) * 2. + new_w = np.sqrt((new_rightcenter[0][0] - new_pt[0][0]) ** 2 + (new_rightcenter[1][0] - new_pt[1][0]) ** 2) * 2. + deltax = -new_topcenter[0][0] + new_pt[0][0] + deltay = (-new_topcenter[1][0] + new_pt[1][0])* -1 + if deltay==0: + if deltax>=0: + new_rot = -math.pi/2 + else: + new_rot = math.pi / 2 + else: + tanx = deltax / deltay + new_rot = np.arctan(tanx) + if (new_pt[0][0]>0) and (new_pt[0][0]<1)and (new_pt[1][0]>0)and (new_pt[1][0]<1): + new_cords.append([1, new_pt[0][0], new_pt[1][0], new_w, new_h, new_rot]) + new_classes.append(classes[bbox_idx]) + return new_classes,new_cords \ No newline at end of file diff --git a/tests/layers/test_layers_convolution.py b/tests/layers/test_layers_convolution.py index 0f5979d5b..546592139 100644 --- a/tests/layers/test_layers_convolution.py +++ b/tests/layers/test_layers_convolution.py @@ -462,6 +462,66 @@ def test_layer_n4(self): # self.assertEqual(self.net2.count_params(), 19392) # self.assertEqual(self.net2.outputs.get_shape().as_list()[1:], [299, 299, 64]) +class Layer_OctConv_2D_Test(CustomTestCase): + + @classmethod + def setUpClass(cls): + print("\n#################################") + + cls.batch_size = 5 + cls.inputs_shape = [cls.batch_size, 32, 32, 16] + cls.input_layer = Input(cls.inputs_shape, name='input_layer') + + cls.n1 = tl.layers.OctConv2dIn(name='octconv2din')(cls.input_layer) + + cls.n2 = tl.layers.OctConv2d(32,0.5,act=tf.nn.relu, name='octconv2d')(cls.n1) + + cls.n3 = tl.layers.OctConv2dHighOut(name='octconv2dho')(cls.n2) + + cls.n4 = tl.layers.OctConv2dLowOut(name='octconv2dlo')(cls.n2) + + cls.n5 = tl.layers.OctConv2dConcat(name='octconv2dconcat')([cls.n3,cls.n4]) + + cls.n6 = 
tl.layers.OctConv2dOut(n_filter=32,name='octconv2dout')(cls.n5) + + cls.model = Model(cls.input_layer, cls.n6) + print("Testing OctConv2d model: \n", cls.model) + + @classmethod + def tearDownClass(cls): + pass + # tf.reset_default_graph() + + def test_layer_n1(self): + + self.assertEqual(self.n1[0].get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n1[1].get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(len(self.n1), 2) + + def test_layer_n2(self): + + self.assertEqual(self.n2[0].get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n2[1].get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(len(self.n2), 2) + + def test_layer_n3(self): + + self.assertEqual(self.n3.get_shape().as_list()[1:], [32, 32, 16]) + + def test_layer_n4(self): + + self.assertEqual(self.n4.get_shape().as_list()[1:], [16, 16, 16]) + + def test_layer_n5(self): + + self.assertEqual(self.n5[0].get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n5[1].get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(len(self.n2), 2) + + def test_layer_n6(self): + + self.assertEqual(self.n6.get_shape().as_list()[1:], [16, 16, 32]) + if __name__ == '__main__': tl.logging.set_verbosity(tl.logging.DEBUG) From e18a62b42a699d0b1c8926ca29e6d6c30b7dd518 Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 14:57:43 +0800 Subject: [PATCH 02/17] Add import octconv. Add import octconv. 
--- tensorlayer/layers/convolution/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorlayer/layers/convolution/__init__.py b/tensorlayer/layers/convolution/__init__.py index 0f9ebe237..c96d4c5d0 100644 --- a/tensorlayer/layers/convolution/__init__.py +++ b/tensorlayer/layers/convolution/__init__.py @@ -24,6 +24,7 @@ from .ternary_conv import * from .quan_conv import * from .quan_conv_bn import * +from .oct_conv import * __all__ = [ From 73083fda5a67ffbaaa4b4f63a8011ecf4f44f06f Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 16:02:40 +0800 Subject: [PATCH 03/17] Code formatting Code formatting --- tensorlayer/layers/convolution/oct_conv.py | 145 +++++++++++---------- tensorlayer/prepro.py | 116 ++++++++++------- tests/layers/test_layers_convolution.py | 44 ++++--- 3 files changed, 176 insertions(+), 129 deletions(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index 3ea81befb..78ecd372b 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -2,16 +2,10 @@ # -*- coding: utf-8 -*- import tensorflow as tf - import tensorlayer as tl from tensorlayer import logging -from tensorlayer.decorators import deprecated_alias from tensorlayer.layers.core import Layer -# from tensorlayer.layers.core import LayersConfig - - - __all__ = [ 'OctConv2dIn', 'OctConv2d', @@ -21,9 +15,12 @@ 'OctConv2dConcat', ] + class OctConv2dIn(Layer): """ - The :class:`OctConv2dIn` class is a preprocessing layer for 2D image [batch, height, width, channel], see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + The :class:`OctConv2dIn` class is a preprocessing layer for 2D image + [batch, height, width, channel], see `Drop an Octave: Reducing Spatial Redundancy in + Convolutional Neural Networks with Octave Convolution `__. 
Parameters ---------- @@ -68,19 +65,22 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self, inputs): + def build(self): pass def forward(self, inputs): - high_out=tf.identity(inputs,name=(self.name+'_high_out')) - low_out = tf.nn.avg_pool2d(inputs, (2,2), strides=(2,2),padding='SAME',name=self.name+'_low_out') - outputs=[high_out,low_out] + high_out = tf.identity(inputs, name=(self.name + '_high_out')) + low_out = tf.nn.avg_pool2d(inputs, (2, 2), strides=(2, 2), + padding='SAME', name=self.name + '_low_out') + outputs = [high_out, low_out] return outputs class OctConv2d(Layer): """ - The :class:`OctConv2d` class is a 2D CNN layer for OctConv2d layer output, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. Use this layer to process tensor list. + The :class:`OctConv2d` class is a 2D CNN layer for OctConv2d layer output, see `Drop an Octave: + Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution + `__. Use this layer to process tensor list. Parameters ---------- @@ -104,7 +104,8 @@ class OctConv2d(Layer): Notes ----- - - The input should be a list with shape [high_res_tensor , low_res_tensor], the height and width of high_res should be twice of the low_res_tensor. + - The input should be a list with shape [high_res_tensor , low_res_tensor], + the height and width of high_res should be twice of the low_res_tensor. - If you do not which tensor is larger, use OctConv2dConcat layer. - The output will be a list which contains 2 tensor. - You should not use the output directly. 
@@ -128,12 +129,12 @@ def __init__( filter=32, alpha=0.5, filter_size=(2, 2), - strides=(1,1), + strides=(1, 1), W_init=tl.initializers.truncated_normal(stddev=0.02), b_init=tl.initializers.constant(value=0.0), act=None, in_channels=None, - name=None # 'cnn2d_layer', + name=None ): super().__init__(name) self.filter = filter @@ -156,7 +157,6 @@ def __init__( self.build(None) self._built = True - logging.info( "OctConv2d %s: filter_size: %s strides: %s high_out: %s low_out: %s act: %s" % ( self.name, str(filter_size), str(strides), str(self.high_out), str(self.low_out), @@ -166,8 +166,8 @@ def __init__( def __repr__(self): actstr = self.act.__name__ if self.act is not None else 'No Activation' - s = ('{classname}(in_channels={in_channels}, out_channels={filter} kernel_size={filter_size}' - ', strides={strides}') + s = ('{classname}(in_channels={in_channels}, out_channels={filter} ,' + 'kernel_size={filter_size}, strides={strides}') if self.b_init is None: s += ', bias=False' s += (', ' + actstr) @@ -179,11 +179,11 @@ def __repr__(self): def build(self, inputs_shape): if not self.in_channels: - high_ch=inputs_shape[0][-1] - low_ch=inputs_shape[1][-1] + high_ch = inputs_shape[0][-1] + low_ch = inputs_shape[1][-1] else: - high_ch=self.in_channels[0] - low_ch=self.in_channels[1] + high_ch = self.in_channels[0] + low_ch = self.in_channels[1] self.high_high_filter_shape = ( self.filter_size[0], self.filter_size[1], high_ch, self.high_out ) @@ -218,36 +218,38 @@ def build(self, inputs_shape): def forward(self, inputs): high_input = inputs[0] - low_input=inputs[1] + low_input = inputs[1] high_to_high = tf.nn.conv2d(high_input, self.high_high__W, strides=self.strides, padding="SAME") - high_to_low =tf.nn.avg_pool2d(high_input, (2,2), strides=(2,2),padding='SAME') - high_to_low=tf.nn.conv2d(high_to_low, self.high_low__W, - strides=self.strides, padding="SAME") + high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') + high_to_low = 
tf.nn.conv2d(high_to_low, self.high_low__W, + strides=self.strides, padding="SAME") low_to_low = tf.nn.conv2d(low_input, self.low_low_W, - strides=self.strides, padding="SAME") + strides=self.strides, padding="SAME") low_to_high = tf.nn.conv2d(low_input, self.low_high_W, - strides=self.strides, padding="SAME") - low_to_high=tf.keras.layers.UpSampling2D(size=(2,2), interpolation='nearest')(low_to_high) - high_out=high_to_high+low_to_high - low_out=low_to_low+high_to_low + strides=self.strides, padding="SAME") + low_to_high = tf.keras.layers.UpSampling2D(size=(2, 2), + interpolation='nearest')(low_to_high) + high_out = high_to_high + low_to_high + low_out = low_to_low + high_to_low if self.b_init: high_out = tf.nn.bias_add(high_out, self.high_b, data_format="NHWC") low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") if self.act: - high_out = self.act(high_out,name=self.name+'_high_out') - low_out= self.act(low_out,name=self.name+'_low_out') + high_out = self.act(high_out, name=self.name + '_high_out') + low_out = self.act(low_out, name=self.name + '_low_out') else: - high_out=tf.identity(high_out,name=self.name+'_high_out') - low_out=tf.identity(low_out,name=self.name+'_low_out') - outputs=[high_out,low_out] + high_out = tf.identity(high_out, name=self.name + '_high_out') + low_out = tf.identity(low_out, name=self.name + '_low_out') + outputs = [high_out, low_out] return outputs - class OctConv2dOut(Layer): """ - The :class:`OctConv2dOut` class is a 2D CNN layer for OctConv2d layer output to get only a tensor, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. Use this layer after other Octconv layers and get a normal tensor output. + The :class:`OctConv2dOut` class is a 2D CNN layer for OctConv2d layer output to get + only a tensor, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural + Networks with Octave Convolution `__. 
Parameters ---------- @@ -292,7 +294,7 @@ def __init__( self, n_filter=32, filter_size=(2, 2), - strides=(1,1), + strides=(1, 1), W_init=tl.initializers.truncated_normal(stddev=0.02), b_init=tl.initializers.constant(value=0.0), act=None, @@ -322,8 +324,8 @@ def __init__( def __repr__(self): actstr = self.act.__name__ if self.act is not None else 'No Activation' - s = ('{classname}(in_channels={in_channels}, out_channels={low_out}, kernel_size={filter_size}' - ', strides={strides}') + s = ('{classname}(in_channels={in_channels}, out_channels={low_out},' + ' kernel_size={filter_size}, strides={strides}') if self.b_init is None: s += ', bias=False' s += (', ' + actstr) @@ -334,11 +336,11 @@ def __repr__(self): def build(self, inputs_shape): if not self.in_channels: - high_ch=inputs_shape[0][-1] - low_ch=inputs_shape[1][-1] + high_ch = inputs_shape[0][-1] + low_ch = inputs_shape[1][-1] else: - high_ch=self.in_channels[0] - low_ch=self.in_channels[1] + high_ch = self.in_channels[0] + low_ch = self.in_channels[1] self.high_low_filter_shape = ( self.filter_size[0], self.filter_size[1], high_ch, self.high_out ) @@ -358,27 +360,27 @@ def build(self, inputs_shape): def forward(self, inputs): high_input = inputs[0] - low_input=inputs[1] - high_to_low =tf.nn.avg_pool2d(high_input, (2,2), strides=(2,2),padding='SAME') - high_to_low=tf.nn.conv2d(high_to_low, self.high_low__W, - strides=self.strides, padding="SAME") + low_input = inputs[1] + high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') + high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, + strides=self.strides, padding="SAME") low_to_low = tf.nn.conv2d(low_input, self.low_low_W, - strides=self.strides, padding="SAME") - low_out=low_to_low+high_to_low + strides=self.strides, padding="SAME") + low_out = low_to_low + high_to_low if self.b_init: low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") if self.act: - outputs= self.act(low_out,name=self.name+'_low_out') + outputs = 
self.act(low_out, name=self.name + '_low_out') else: - outputs=tf.identity(low_out,name=self.name+'_low_out') + outputs = tf.identity(low_out, name=self.name + '_low_out') return outputs - - class OctConv2dHighOut(Layer): """ - The :class:`OctConv2dHighOut` class is a slice layer for Octconv tensor list, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + The :class:`OctConv2dHighOut` class is a slice layer for Octconv tensor list, + see `Drop an Octave: Reducing Spatial Redundancy in Convolutional + Neural Networks with Octave Convolution `__. Parameters ---------- @@ -388,7 +390,8 @@ class OctConv2dHighOut(Layer): Notes ----- - Use this layer to get high resolution tensor. - - If you want to do some customized normalization ops, use this layer with OctConv2dLowOut and OctConv2dConcat layers to implement your idea. + - If you want to do some customized normalization ops, use this layer with + OctConv2dLowOut and OctConv2dConcat layers to implement your idea. Examples -------- @@ -426,17 +429,19 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self, inputs): + def build(self): pass def forward(self, inputs): - outputs=tf.identity(inputs[0],self.name) + outputs = tf.identity(inputs[0], self.name) return outputs -class OctConv2dLowOut (Layer): +class OctConv2dLowOut(Layer): """ - The :class:`OctConv2dLowOut` class is a slice layer for Octconv tensor list, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + The :class:`OctConv2dLowOut` class is a slice layer for Octconv tensor list, see + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks + with Octave Convolution `__. Parameters ---------- @@ -446,7 +451,8 @@ class OctConv2dLowOut (Layer): Notes ----- - Use this layer to get low resolution tensor. 
- - If you want to do some customized normalization ops, use this layer with OctConv2dHighOut and OctConv2dConcat layers to implement your idea. + - If you want to do some customized normalization ops, use this layer with + OctConv2dHighOut and OctConv2dConcat layers to implement your idea. Examples -------- @@ -485,16 +491,19 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self, inputs): + def build(self): pass def forward(self, inputs): - outputs=tf.identity(inputs[1],self.name) + outputs = tf.identity(inputs[1], self.name) return outputs + class OctConv2dConcat(Layer): """ - The :class:`OctConv2dConcat` class is a concat layer for two 2D image batches, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + The :class:`OctConv2dConcat` class is a concat layer for two 2D image batches, see + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks + with Octave Convolution `__. 
Parameters ---------- @@ -544,12 +553,12 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self, inputs): + def build(self): pass def forward(self, inputs): - if inputs[0].shape[1]>inputs[1].shape[1]: - outputs=[inputs[0],inputs[1]] + if inputs[0].shape[1] > inputs[1].shape[1]: + outputs = [inputs[0], inputs[1]] else: outputs = [inputs[1], inputs[0]] return outputs diff --git a/tensorlayer/prepro.py b/tensorlayer/prepro.py index e10755c1f..fc835b9fa 100644 --- a/tensorlayer/prepro.py +++ b/tensorlayer/prepro.py @@ -4150,7 +4150,7 @@ def pose_resize_shortestedge(image, annos, mask, target_size): def obj_box_coord_affine( - classes=None,coords=None, affine_matrix=None,affine_matrix_inv=None,min_ratio=0.0,min_width=0.0,min_height=0.0 + classes=None,coords=None, affine_matrix=None, affine_matrix_inv=None, min_ratio=0.0, min_width=0.0, min_height=0.0 ): """Apply affine transform the box coordinates, and gets the new box coordinates. @@ -4159,17 +4159,21 @@ def obj_box_coord_affine( classes : list of int or None Class IDs. coords : list of list of 4 int/float or None - Coordinates [[x, y, w, h], [x, y, w, h], ...]. Here x,y are the center coordinates of the bbox. + Coordinates [[x, y, w, h], [x, y, w, h], ...]. Here x,y are + the center coordinates of the bbox. affine_matrix : np.ndarray The affine matrix of the image. A ndarray with shape [2,3]. affine_matrix_inv : np.ndarray The inverse of the affine matrix. A ndarray with shape [2,3]. min_ratio : float - Threshold, remove the box if its ratio of new size to old size less than the threshold. + Threshold, remove the box if its ratio of new size to old size less + than the threshold. min_width : float - Threshold, remove the box if its ratio of width to image size less than the threshold. + Threshold, remove the box if its ratio of width to image size less + than the threshold. 
min_height : float - Threshold, remove the box if its ratio of height to image size less than the threshold. + Threshold, remove the box if its ratio of height to image size less + than the threshold. Returns ------- @@ -4179,59 +4183,66 @@ def obj_box_coord_affine( A list of new bounding boxes. """ Me = np.array([[0, 0, 1.]]) - new_classes=[] - new_cords=[] - if affine_matrix_inv==None: - affine_matrix_inv=np.linalg.pinv(np.concatenate((affine_matrix, Me), axis=0)) - for bbox_idx,bbox in enumerate(coords): - old_pt = np.array([[bbox[0] * 2. - 1.0], [bbox[1] * 2. - 1.0], [1.0]]) - new_wh_a = np.matmul(affine_matrix_inv[0:2, 0:2], np.array([[affine_matrix_inv[2]], [bbox[3]]])) - new_wh_b = np.matmul(affine_matrix_inv[0:2, 0:2], np.array([[affine_matrix_inv[2]], [-bbox[3]]])) + new_classes = [] + new_cords = [] + if affine_matrix_inv is None: + affine_matrix_inv = np.linalg.pinv(np.concatenate((affine_matrix, + Me), axis=0)) + for (bbox_idx, bbox) in enumerate(coords): + old_pt = np.array([[bbox[0] * 2. - 1.], [bbox[1] * 2. - 1.], [1.]]) + new_wh_a = np.matmul(affine_matrix_inv[0:2, 0:2], + np.array([[bbox[2]], [bbox[3]]])) + new_wh_b = np.matmul(affine_matrix_inv[0:2, 0:2], + np.array([[bbox[2]], [-bbox[3]]])) new_w = max(abs(new_wh_a[0]), abs(new_wh_b[0])) new_h = max(abs(new_wh_a[1]), abs(new_wh_b[1])) - new_pt = (np.matmul(affine_matrix_inv, old_pt) + 1.0) / 2.0 - bbox_left = new_pt[0] - new_w / 2.0 + new_pt = (np.matmul(affine_matrix_inv, old_pt) + 1.) / 2. + bbox_left = new_pt[0] - new_w / 2. if bbox_left <= 0: bbox_left = 0. if bbox_left >= 1: bbox_left = 1. - bbox_right = new_pt[0] + new_w / 2.0 + bbox_right = new_pt[0] + new_w / 2. if bbox_right <= 0: bbox_right = 0. if bbox_right >= 1: bbox_right = 1. - bbox_top = new_pt[1] + new_h / 2.0 + bbox_top = new_pt[1] + new_h / 2. if bbox_top <= 0: bbox_top = 0. if bbox_top >= 1: bbox_top = 1. - bbox_bottom = new_pt[1] - new_h / 2.0 + bbox_bottom = new_pt[1] - new_h / 2. if bbox_bottom <= 0: bbox_bottom = 0. 
if bbox_bottom >= 1: bbox_bottom = 1. - ratio = abs(bbox_right - bbox_left) * abs(bbox_top - bbox_bottom) / (new_w + 0.00001) / (new_h + 0.00001) + ratio = abs(bbox_right - bbox_left) * abs(bbox_top - bbox_bottom) \ + / (new_w + 0.00001) / (new_h + 0.00001) bbox_x = abs(bbox_right + bbox_left) / 2. bbox_y = abs(bbox_top + bbox_bottom) / 2. bbox_h = abs(bbox_top - bbox_bottom) bbox_w = abs(bbox_right - bbox_left) - if (ratio > min_ratio) & (bbox_h >= min_height) & (bbox_w >= min_width): + if (ratio > min_ratio) & (bbox_h >= min_height) & (bbox_w + >= min_width): new_classes.append(classes[bbox_idx]) - new_cords.append([bbox_x,bbox_y,bbox_w,bbox_h]) - return new_classes,new_cords + new_cords.append([bbox_x, bbox_y, bbox_w, bbox_h]) + return (new_classes, new_cords) def rotated_obj_box_coord_affine( - classes=None,coords=None, affine_matrix=None,affine_matrix_inv=None + classes=None, coords=None, affine_matrix=None, affine_matrix_inv=None ): - """Apply affine transform the box coordinates with rotation, and gets the new box coordinates with rotation. Experimental! + """Apply affine transform the box coordinates with rotation, and gets the + new box coordinates with rotation. Experimental! Parameters ----------- classes : list of int or None Class IDs. coords : list of list of 5 int/float or None - Coordinates [[x, y, w, h, r], [x, y, w, h, r], ...]. Here x,y are the center coordinates of the bbox, r is the radius. + Coordinates [[x, y, w, h, r], [x, y, w, h, r], ...]. Here x,y are + the center coordinates of the bbox, r is the radius. affine_matrix : np.ndarray The affine matrix of the image. A ndarray with shape [2,3]. affine_matrix_inv : np.ndarray @@ -4245,38 +4256,51 @@ def rotated_obj_box_coord_affine( A list of new bounding boxes. 
""" Me = np.array([[0, 0, 1.]]) - new_classes=[] - new_cords=[] - if affine_matrix_inv==None: - affine_matrix_inv=np.linalg.pinv(np.concatenate((affine_matrix, Me), axis=0)) - for bbox_idx,bbox in enumerate(coords): + new_classes = [] + new_cords = [] + if affine_matrix_inv is None: + affine_matrix_inv = np.linalg.pinv(np.concatenate((affine_matrix, + Me), axis=0)) + for (bbox_idx, bbox) in enumerate(coords): centerx = bbox[0] centery = bbox[1] - old_pt = np.array([[centerx * 2. - 1.0], [centery * 2. - 1.0], [1.0]]) - top_center = np.array([[0.], [-bbox[3]], [1.0]]) - right_center = np.array([[bbox[2]], [0.], [1.0]]) + old_pt = np.array([[centerx * 2. - 1.], [centery * 2. - 1.], [1.]]) + top_center = np.array([[0.], [-bbox[3]], [1.]]) + right_center = np.array([[bbox[2]], [0.], [1.]]) rot = bbox[-1] - rot_mat = np.array([[np.cos(rot), -np.sin(rot), 0], [np.sin(rot), np.cos(rot), 0], [0, 0, 1]]) + rot_mat = np.array([[np.cos(rot), -np.sin(rot), 0], [np.sin(rot), + np.cos(rot), 0], [0, 0, 1]]) top_center = np.matmul(rot_mat, top_center) + old_pt top_center[2][0] = 1. right_center = np.matmul(rot_mat, right_center) + old_pt right_center[2][0] = 1. - new_pt = (np.matmul(affine_matrix_inv, old_pt) + 1.0) / 2.0 - new_topcenter = (np.matmul(affine_matrix_inv, top_center) + 1.0) / 2.0 - new_rightcenter = (np.matmul(affine_matrix_inv, right_center) + 1.0) / 2.0 - new_h = np.sqrt((new_topcenter[0][0] - new_pt[0][0]) ** 2 + (new_topcenter[1][0] - new_pt[1][0]) ** 2) * 2. - new_w = np.sqrt((new_rightcenter[0][0] - new_pt[0][0]) ** 2 + (new_rightcenter[1][0] - new_pt[1][0]) ** 2) * 2. + new_pt = (np.matmul(affine_matrix_inv, old_pt) + 1.) / 2. + new_topcenter = (np.matmul(affine_matrix_inv, top_center) + 1.) / 2. + new_rightcenter = (np.matmul(affine_matrix_inv, right_center) + 1.) \ + / 2. + new_h = np.sqrt((new_topcenter[0][0] - new_pt[0][0]) ** 2 + + (new_topcenter[1][0] - new_pt[1][0]) ** 2) * 2. 
+ new_w = np.sqrt((new_rightcenter[0][0] - new_pt[0][0]) ** 2 + + (new_rightcenter[1][0] - new_pt[1][0]) ** 2) * 2. deltax = -new_topcenter[0][0] + new_pt[0][0] - deltay = (-new_topcenter[1][0] + new_pt[1][0])* -1 - if deltay==0: - if deltax>=0: - new_rot = -math.pi/2 + deltay = (-new_topcenter[1][0] + new_pt[1][0]) * -1 + if deltay == 0: + if deltax >= 0: + new_rot = -math.pi / 2 else: new_rot = math.pi / 2 else: tanx = deltax / deltay new_rot = np.arctan(tanx) - if (new_pt[0][0]>0) and (new_pt[0][0]<1)and (new_pt[1][0]>0)and (new_pt[1][0]<1): - new_cords.append([1, new_pt[0][0], new_pt[1][0], new_w, new_h, new_rot]) + if new_pt[0][0] > 0 and new_pt[0][0] < 1 and new_pt[1][0] > 0 \ + and new_pt[1][0] < 1: + new_cords.append([ + 1, + new_pt[0][0], + new_pt[1][0], + new_w, + new_h, + new_rot, + ]) new_classes.append(classes[bbox_idx]) - return new_classes,new_cords \ No newline at end of file + return (new_classes, new_cords) diff --git a/tests/layers/test_layers_convolution.py b/tests/layers/test_layers_convolution.py index 546592139..3621d5e8c 100644 --- a/tests/layers/test_layers_convolution.py +++ b/tests/layers/test_layers_convolution.py @@ -466,61 +466,75 @@ class Layer_OctConv_2D_Test(CustomTestCase): @classmethod def setUpClass(cls): - print("\n#################################") + print ('\n#################################') cls.batch_size = 5 cls.inputs_shape = [cls.batch_size, 32, 32, 16] cls.input_layer = Input(cls.inputs_shape, name='input_layer') - cls.n1 = tl.layers.OctConv2dIn(name='octconv2din')(cls.input_layer) + cls.n1 = tl.layers.OctConv2dIn(name='octconv2din' + )(cls.input_layer) - cls.n2 = tl.layers.OctConv2d(32,0.5,act=tf.nn.relu, name='octconv2d')(cls.n1) + cls.n2 = tl.layers.OctConv2d(32, 0.5, act=tf.nn.relu, + name='octconv2d')(cls.n1) cls.n3 = tl.layers.OctConv2dHighOut(name='octconv2dho')(cls.n2) cls.n4 = tl.layers.OctConv2dLowOut(name='octconv2dlo')(cls.n2) - cls.n5 = tl.layers.OctConv2dConcat(name='octconv2dconcat')([cls.n3,cls.n4]) + 
cls.n5 = tl.layers.OctConv2dConcat(name='octconv2dconcat' + )([cls.n3, cls.n4]) - cls.n6 = tl.layers.OctConv2dOut(n_filter=32,name='octconv2dout')(cls.n5) + cls.n6 = tl.layers.OctConv2dOut(n_filter=32, name='octconv2dout' + )(cls.n5) cls.model = Model(cls.input_layer, cls.n6) - print("Testing OctConv2d model: \n", cls.model) + print ('Testing OctConv2d model: \n', cls.model) @classmethod def tearDownClass(cls): pass + # tf.reset_default_graph() def test_layer_n1(self): - self.assertEqual(self.n1[0].get_shape().as_list()[1:], [32, 32, 16]) - self.assertEqual(self.n1[1].get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(self.n1[0].get_shape().as_list()[1:], [32, 32, + 16]) + self.assertEqual(self.n1[1].get_shape().as_list()[1:], [16, 16, + 16]) self.assertEqual(len(self.n1), 2) def test_layer_n2(self): - self.assertEqual(self.n2[0].get_shape().as_list()[1:], [32, 32, 16]) - self.assertEqual(self.n2[1].get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(self.n2[0].get_shape().as_list()[1:], [32, 32, + 16]) + self.assertEqual(self.n2[1].get_shape().as_list()[1:], [16, 16, + 16]) self.assertEqual(len(self.n2), 2) def test_layer_n3(self): - self.assertEqual(self.n3.get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n3.get_shape().as_list()[1:], [32, 32, + 16]) def test_layer_n4(self): - self.assertEqual(self.n4.get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(self.n4.get_shape().as_list()[1:], [16, 16, + 16]) def test_layer_n5(self): - self.assertEqual(self.n5[0].get_shape().as_list()[1:], [32, 32, 16]) - self.assertEqual(self.n5[1].get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(self.n5[0].get_shape().as_list()[1:], [32, 32, + 16]) + self.assertEqual(self.n5[1].get_shape().as_list()[1:], [16, 16, + 16]) self.assertEqual(len(self.n2), 2) def test_layer_n6(self): - self.assertEqual(self.n6.get_shape().as_list()[1:], [16, 16, 32]) + self.assertEqual(self.n6.get_shape().as_list()[1:], [16, 16, + 32]) if __name__ == 
'__main__': From fdc4a068632161470fd7546c267c355dbad1bcd3 Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 16:34:33 +0800 Subject: [PATCH 04/17] Formatting Code --- tensorlayer/layers/convolution/oct_conv.py | 177 +++++++++------------ 1 file changed, 76 insertions(+), 101 deletions(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index 78ecd372b..d5307aa1f 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -2,10 +2,16 @@ # -*- coding: utf-8 -*- import tensorflow as tf + import tensorlayer as tl from tensorlayer import logging +from tensorlayer.decorators import deprecated_alias from tensorlayer.layers.core import Layer +# from tensorlayer.layers.core import LayersConfig + + + __all__ = [ 'OctConv2dIn', 'OctConv2d', @@ -15,33 +21,27 @@ 'OctConv2dConcat', ] - class OctConv2dIn(Layer): """ - The :class:`OctConv2dIn` class is a preprocessing layer for 2D image - [batch, height, width, channel], see `Drop an Octave: Reducing Spatial Redundancy in - Convolutional Neural Networks with Octave Convolution `__. - + The :class:`OctConv2dIn` class is a preprocessing layer for 2D image [batch, height, width, channel], + see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave + Convolution `__. Parameters ---------- name : None or str A unique layer name. - Notes ----- - The height and width of input must be a multiple of the 2. - Use this layer before any other octconv layers. - The output will be a list which contains 2 tensor. 
- Examples -------- With TensorLayer - >>> net = tl.layers.Input([8, 28, 28, 16], name='input') >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) >>> print(octconv2d) >>> output shape : [(8, 28, 28, 16),(8, 14, 14, 16)] - """ def __init__( @@ -65,23 +65,21 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self): + def build(self, inputs): pass def forward(self, inputs): - high_out = tf.identity(inputs, name=(self.name + '_high_out')) - low_out = tf.nn.avg_pool2d(inputs, (2, 2), strides=(2, 2), - padding='SAME', name=self.name + '_low_out') - outputs = [high_out, low_out] + high_out=tf.identity(inputs,name=(self.name+'_high_out')) + low_out = tf.nn.avg_pool2d(inputs, (2,2), strides=(2,2),padding='SAME',name=self.name+'_low_out') + outputs=[high_out,low_out] return outputs class OctConv2d(Layer): """ - The :class:`OctConv2d` class is a 2D CNN layer for OctConv2d layer output, see `Drop an Octave: - Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution - `__. Use this layer to process tensor list. - + The :class:`OctConv2d` class is a 2D CNN layer for OctConv2d layer output, see + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with + Octave Convolution `__. Use this layer to process tensor list. Parameters ---------- filter : int @@ -101,7 +99,6 @@ class OctConv2d(Layer): The activation function of this layer. name : None or str A unique layer name. - Notes ----- - The input should be a list with shape [high_res_tensor , low_res_tensor], @@ -109,11 +106,9 @@ class OctConv2d(Layer): - If you do not which tensor is larger, use OctConv2dConcat layer. - The output will be a list which contains 2 tensor. - You should not use the output directly. 
- Examples -------- With TensorLayer - >>> net = tl.layers.Input([8, 28, 28, 32], name='input') >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) >>> print(octconv2d) @@ -121,7 +116,6 @@ class OctConv2d(Layer): >>> octconv2d = tl.layers.OctConv2d(32,0.5,act=tf.nn.relu, name='octconv2d_1')(octconv2d) >>> print(octconv2d) >>> output shape : [(8, 28, 28, 16),(8, 14, 14, 16)] - """ def __init__( @@ -129,12 +123,12 @@ def __init__( filter=32, alpha=0.5, filter_size=(2, 2), - strides=(1, 1), + strides=(1,1), W_init=tl.initializers.truncated_normal(stddev=0.02), b_init=tl.initializers.constant(value=0.0), act=None, in_channels=None, - name=None + name=None # 'cnn2d_layer', ): super().__init__(name) self.filter = filter @@ -157,6 +151,7 @@ def __init__( self.build(None) self._built = True + logging.info( "OctConv2d %s: filter_size: %s strides: %s high_out: %s low_out: %s act: %s" % ( self.name, str(filter_size), str(strides), str(self.high_out), str(self.low_out), @@ -166,8 +161,8 @@ def __init__( def __repr__(self): actstr = self.act.__name__ if self.act is not None else 'No Activation' - s = ('{classname}(in_channels={in_channels}, out_channels={filter} ,' - 'kernel_size={filter_size}, strides={strides}') + s = ('{classname}(in_channels={in_channels}, out_channels={filter} kernel_size={filter_size}' + ', strides={strides}') if self.b_init is None: s += ', bias=False' s += (', ' + actstr) @@ -179,11 +174,11 @@ def __repr__(self): def build(self, inputs_shape): if not self.in_channels: - high_ch = inputs_shape[0][-1] - low_ch = inputs_shape[1][-1] + high_ch=inputs_shape[0][-1] + low_ch=inputs_shape[1][-1] else: - high_ch = self.in_channels[0] - low_ch = self.in_channels[1] + high_ch=self.in_channels[0] + low_ch=self.in_channels[1] self.high_high_filter_shape = ( self.filter_size[0], self.filter_size[1], high_ch, self.high_out ) @@ -218,39 +213,38 @@ def build(self, inputs_shape): def forward(self, inputs): high_input = inputs[0] - low_input = inputs[1] + 
low_input=inputs[1] high_to_high = tf.nn.conv2d(high_input, self.high_high__W, strides=self.strides, padding="SAME") - high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') - high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, - strides=self.strides, padding="SAME") + high_to_low =tf.nn.avg_pool2d(high_input, (2,2), strides=(2,2),padding='SAME') + high_to_low=tf.nn.conv2d(high_to_low, self.high_low__W, + strides=self.strides, padding="SAME") low_to_low = tf.nn.conv2d(low_input, self.low_low_W, - strides=self.strides, padding="SAME") + strides=self.strides, padding="SAME") low_to_high = tf.nn.conv2d(low_input, self.low_high_W, - strides=self.strides, padding="SAME") - low_to_high = tf.keras.layers.UpSampling2D(size=(2, 2), - interpolation='nearest')(low_to_high) - high_out = high_to_high + low_to_high - low_out = low_to_low + high_to_low + strides=self.strides, padding="SAME") + low_to_high=tf.keras.layers.UpSampling2D(size=(2,2), interpolation='nearest')(low_to_high) + high_out=high_to_high+low_to_high + low_out=low_to_low+high_to_low if self.b_init: high_out = tf.nn.bias_add(high_out, self.high_b, data_format="NHWC") low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") if self.act: - high_out = self.act(high_out, name=self.name + '_high_out') - low_out = self.act(low_out, name=self.name + '_low_out') + high_out = self.act(high_out,name=self.name+'_high_out') + low_out= self.act(low_out,name=self.name+'_low_out') else: - high_out = tf.identity(high_out, name=self.name + '_high_out') - low_out = tf.identity(low_out, name=self.name + '_low_out') - outputs = [high_out, low_out] + high_out=tf.identity(high_out,name=self.name+'_high_out') + low_out=tf.identity(low_out,name=self.name+'_low_out') + outputs=[high_out,low_out] return outputs + class OctConv2dOut(Layer): """ - The :class:`OctConv2dOut` class is a 2D CNN layer for OctConv2d layer output to get - only a tensor, see `Drop an Octave: Reducing Spatial Redundancy in 
Convolutional Neural - Networks with Octave Convolution `__. - + The :class:`OctConv2dOut` class is a 2D CNN layer for OctConv2d layer output to get only a tensor, see + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution + `__. Parameters ---------- filter : int @@ -268,15 +262,12 @@ class OctConv2dOut(Layer): The activation function of this layer. name : None or str A unique layer name. - Notes ----- - Use this layer to get only a tensor for other normal layer. - Examples -------- With TensorLayer - >>> net = tl.layers.Input([8, 28, 28, 32], name='input') >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) >>> print(octconv2d) @@ -287,14 +278,13 @@ class OctConv2dOut(Layer): >>> octconv2d = tl.layers.OctConv2dOut(32,act=tf.nn.relu, name='octconv2dout_1')(octconv2d) >>> print(octconv2d) >>> output shape : (8, 14, 14, 32) - """ def __init__( self, n_filter=32, filter_size=(2, 2), - strides=(1, 1), + strides=(1,1), W_init=tl.initializers.truncated_normal(stddev=0.02), b_init=tl.initializers.constant(value=0.0), act=None, @@ -324,8 +314,8 @@ def __init__( def __repr__(self): actstr = self.act.__name__ if self.act is not None else 'No Activation' - s = ('{classname}(in_channels={in_channels}, out_channels={low_out},' - ' kernel_size={filter_size}, strides={strides}') + s = ('{classname}(in_channels={in_channels}, out_channels={low_out}, kernel_size={filter_size}' + ', strides={strides}') if self.b_init is None: s += ', bias=False' s += (', ' + actstr) @@ -336,11 +326,11 @@ def __repr__(self): def build(self, inputs_shape): if not self.in_channels: - high_ch = inputs_shape[0][-1] - low_ch = inputs_shape[1][-1] + high_ch=inputs_shape[0][-1] + low_ch=inputs_shape[1][-1] else: - high_ch = self.in_channels[0] - low_ch = self.in_channels[1] + high_ch=self.in_channels[0] + low_ch=self.in_channels[1] self.high_low_filter_shape = ( self.filter_size[0], self.filter_size[1], high_ch, self.high_out ) @@ -360,43 
+350,41 @@ def build(self, inputs_shape): def forward(self, inputs): high_input = inputs[0] - low_input = inputs[1] - high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') - high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, - strides=self.strides, padding="SAME") + low_input=inputs[1] + high_to_low =tf.nn.avg_pool2d(high_input, (2,2), strides=(2,2),padding='SAME') + high_to_low=tf.nn.conv2d(high_to_low, self.high_low__W, + strides=self.strides, padding="SAME") low_to_low = tf.nn.conv2d(low_input, self.low_low_W, - strides=self.strides, padding="SAME") - low_out = low_to_low + high_to_low + strides=self.strides, padding="SAME") + low_out=low_to_low+high_to_low if self.b_init: low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") if self.act: - outputs = self.act(low_out, name=self.name + '_low_out') + outputs= self.act(low_out,name=self.name+'_low_out') else: - outputs = tf.identity(low_out, name=self.name + '_low_out') + outputs=tf.identity(low_out,name=self.name+'_low_out') return outputs + + class OctConv2dHighOut(Layer): """ - The :class:`OctConv2dHighOut` class is a slice layer for Octconv tensor list, - see `Drop an Octave: Reducing Spatial Redundancy in Convolutional - Neural Networks with Octave Convolution `__. - + The :class:`OctConv2dHighOut` class is a slice layer for Octconv tensor list, see + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with + Octave Convolution `__. Parameters ---------- name : None or str A unique layer name. - Notes ----- - Use this layer to get high resolution tensor. - If you want to do some customized normalization ops, use this layer with - OctConv2dLowOut and OctConv2dConcat layers to implement your idea. - + OctConv2dLowOut and OctConv2dConcat layers to implement your idea. 
Examples -------- With TensorLayer - >>> net = tl.layers.Input([8, 28, 28, 32], name='input') >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) >>> print(octconv2d) @@ -404,7 +392,6 @@ class OctConv2dHighOut(Layer): >>> octconv2d = tl.layers.OctConv2dHighOut(name='octconv2dho_1')(octconv2d) >>> print(octconv2d) >>> output shape : (8, 28, 28, 32) - """ def __init__( @@ -429,35 +416,31 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self): + def build(self, inputs): pass def forward(self, inputs): - outputs = tf.identity(inputs[0], self.name) + outputs=tf.identity(inputs[0],self.name) return outputs class OctConv2dLowOut(Layer): """ The :class:`OctConv2dLowOut` class is a slice layer for Octconv tensor list, see - `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks - with Octave Convolution `__. - + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with + Octave Convolution `__. Parameters ---------- name : None or str A unique layer name. - Notes ----- - Use this layer to get low resolution tensor. - If you want to do some customized normalization ops, use this layer with OctConv2dHighOut and OctConv2dConcat layers to implement your idea. 
- Examples -------- With TensorLayer - >>> net = tl.layers.Input([8, 28, 28, 32], name='input') >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) >>> print(octconv2d) @@ -465,8 +448,6 @@ class OctConv2dLowOut(Layer): >>> octconv2d = tl.layers.OctConv2dLowOut(name='octconv2dlo_1')(octconv2d) >>> print(octconv2d) >>> output shape : (8, 14, 14, 32) - - """ def __init__( @@ -491,34 +472,29 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self): + def build(self, inputs): pass def forward(self, inputs): - outputs = tf.identity(inputs[1], self.name) + outputs=tf.identity(inputs[1],self.name) return outputs - class OctConv2dConcat(Layer): """ The :class:`OctConv2dConcat` class is a concat layer for two 2D image batches, see - `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks - with Octave Convolution `__. - + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with + Octave Convolution `__. Parameters ---------- name : None or str A unique layer name. - Notes ----- - Use this layer to concat two tensor. - The height and width of one tensor should be twice of the other tensor. 
- Examples -------- With TensorLayer - >>> net = tl.layers.Input([8, 28, 28, 32], name='input') >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1') >>> print(octconv2d) @@ -528,7 +504,6 @@ class OctConv2dConcat(Layer): >>> octconv2 = tl.layers.OctConv2dConcat(name='octconv2dcat_1')([octconv2dh,octconv2dl]) >>> print(octconv2d) >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] - """ def __init__( @@ -553,12 +528,12 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self): + def build(self, inputs): pass def forward(self, inputs): - if inputs[0].shape[1] > inputs[1].shape[1]: - outputs = [inputs[0], inputs[1]] + if inputs[0].shape[1]>inputs[1].shape[1]: + outputs=[inputs[0],inputs[1]] else: outputs = [inputs[1], inputs[0]] - return outputs + return outputs \ No newline at end of file From d9688a1437ccf3c4448dcf6538a480cffd4634df Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 16:59:13 +0800 Subject: [PATCH 05/17] Formatting Code --- tensorlayer/layers/convolution/oct_conv.py | 120 ++++++++++----------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index d5307aa1f..8c0e54911 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -5,13 +5,11 @@ import tensorlayer as tl from tensorlayer import logging -from tensorlayer.decorators import deprecated_alias from tensorlayer.layers.core import Layer # from tensorlayer.layers.core import LayersConfig - __all__ = [ 'OctConv2dIn', 'OctConv2d', @@ -21,10 +19,12 @@ 'OctConv2dConcat', ] + class OctConv2dIn(Layer): """ - The :class:`OctConv2dIn` class is a preprocessing layer for 2D image [batch, height, width, channel], - see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave + The :class:`OctConv2dIn` class is a preprocessing layer for + 2D image 
[batch, height, width, channel], see `Drop an Octave: Reducing + Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. Parameters ---------- @@ -65,13 +65,14 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self, inputs): + def build(self, inputs_shape=None): pass def forward(self, inputs): - high_out=tf.identity(inputs,name=(self.name+'_high_out')) - low_out = tf.nn.avg_pool2d(inputs, (2,2), strides=(2,2),padding='SAME',name=self.name+'_low_out') - outputs=[high_out,low_out] + high_out = tf.identity(inputs, name=(self.name + '_high_out')) + low_out = tf.nn.avg_pool2d(inputs, (2, 2), strides=(2, 2), padding='SAME', + name=self.name + '_low_out') + outputs = [high_out, low_out] return outputs @@ -120,10 +121,10 @@ class OctConv2d(Layer): def __init__( self, - filter=32, + nfilter=32, alpha=0.5, filter_size=(2, 2), - strides=(1,1), + strides=(1, 1), W_init=tl.initializers.truncated_normal(stddev=0.02), b_init=tl.initializers.constant(value=0.0), act=None, @@ -131,7 +132,7 @@ def __init__( name=None # 'cnn2d_layer', ): super().__init__(name) - self.filter = filter + self.filter = nfilter self.alpha = alpha if (self.alpha >= 1) or (self.alpha <= 0): raise ValueError( @@ -151,7 +152,6 @@ def __init__( self.build(None) self._built = True - logging.info( "OctConv2d %s: filter_size: %s strides: %s high_out: %s low_out: %s act: %s" % ( self.name, str(filter_size), str(strides), str(self.high_out), str(self.low_out), @@ -161,8 +161,8 @@ def __init__( def __repr__(self): actstr = self.act.__name__ if self.act is not None else 'No Activation' - s = ('{classname}(in_channels={in_channels}, out_channels={filter} kernel_size={filter_size}' - ', strides={strides}') + s = ('{classname}(in_channels={in_channels}, out_channels={filter}, ' + 'kernel_size={filter_size}, strides={strides}') if self.b_init is None: s += ', bias=False' s += (', ' + actstr) @@ -174,11 +174,11 @@ def __repr__(self): 
def build(self, inputs_shape): if not self.in_channels: - high_ch=inputs_shape[0][-1] - low_ch=inputs_shape[1][-1] + high_ch = inputs_shape[0][-1] + low_ch = inputs_shape[1][-1] else: - high_ch=self.in_channels[0] - low_ch=self.in_channels[1] + high_ch = self.in_channels[0] + low_ch = self.in_channels[1] self.high_high_filter_shape = ( self.filter_size[0], self.filter_size[1], high_ch, self.high_out ) @@ -213,37 +213,38 @@ def build(self, inputs_shape): def forward(self, inputs): high_input = inputs[0] - low_input=inputs[1] + low_input = inputs[1] high_to_high = tf.nn.conv2d(high_input, self.high_high__W, strides=self.strides, padding="SAME") - high_to_low =tf.nn.avg_pool2d(high_input, (2,2), strides=(2,2),padding='SAME') - high_to_low=tf.nn.conv2d(high_to_low, self.high_low__W, - strides=self.strides, padding="SAME") + high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') + high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, + strides=self.strides, padding="SAME") low_to_low = tf.nn.conv2d(low_input, self.low_low_W, - strides=self.strides, padding="SAME") + strides=self.strides, padding="SAME") low_to_high = tf.nn.conv2d(low_input, self.low_high_W, - strides=self.strides, padding="SAME") - low_to_high=tf.keras.layers.UpSampling2D(size=(2,2), interpolation='nearest')(low_to_high) - high_out=high_to_high+low_to_high - low_out=low_to_low+high_to_low + strides=self.strides, padding="SAME") + low_to_high = tf.keras.layers.UpSampling2D(size=(2, 2), + interpolation='nearest')(low_to_high) + high_out = high_to_high + low_to_high + low_out = low_to_low + high_to_low if self.b_init: high_out = tf.nn.bias_add(high_out, self.high_b, data_format="NHWC") low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") if self.act: - high_out = self.act(high_out,name=self.name+'_high_out') - low_out= self.act(low_out,name=self.name+'_low_out') + high_out = self.act(high_out, name=self.name + '_high_out') + low_out = self.act(low_out, 
name=self.name + '_low_out') else: - high_out=tf.identity(high_out,name=self.name+'_high_out') - low_out=tf.identity(low_out,name=self.name+'_low_out') - outputs=[high_out,low_out] + high_out = tf.identity(high_out, name=self.name + '_high_out') + low_out = tf.identity(low_out, name=self.name + '_low_out') + outputs = [high_out, low_out] return outputs - class OctConv2dOut(Layer): """ - The :class:`OctConv2dOut` class is a 2D CNN layer for OctConv2d layer output to get only a tensor, see - `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution + The :class:`OctConv2dOut` class is a 2D CNN layer for OctConv2d layer + output to get only a tensor, see` Drop an Octave: Reducing Spatial Redundancy + in Convolutional Neural Networks with Octave Convolution `__. Parameters ---------- @@ -284,7 +285,7 @@ def __init__( self, n_filter=32, filter_size=(2, 2), - strides=(1,1), + strides=(1, 1), W_init=tl.initializers.truncated_normal(stddev=0.02), b_init=tl.initializers.constant(value=0.0), act=None, @@ -314,8 +315,8 @@ def __init__( def __repr__(self): actstr = self.act.__name__ if self.act is not None else 'No Activation' - s = ('{classname}(in_channels={in_channels}, out_channels={low_out}, kernel_size={filter_size}' - ', strides={strides}') + s = ('{classname}(in_channels={in_channels}, out_channels={low_out}, ' + 'kernel_size={filter_size}, strides={strides}') if self.b_init is None: s += ', bias=False' s += (', ' + actstr) @@ -326,11 +327,11 @@ def __repr__(self): def build(self, inputs_shape): if not self.in_channels: - high_ch=inputs_shape[0][-1] - low_ch=inputs_shape[1][-1] + high_ch = inputs_shape[0][-1] + low_ch = inputs_shape[1][-1] else: - high_ch=self.in_channels[0] - low_ch=self.in_channels[1] + high_ch = self.in_channels[0] + low_ch = self.in_channels[1] self.high_low_filter_shape = ( self.filter_size[0], self.filter_size[1], high_ch, self.high_out ) @@ -350,24 +351,22 @@ def build(self, inputs_shape): def 
forward(self, inputs): high_input = inputs[0] - low_input=inputs[1] - high_to_low =tf.nn.avg_pool2d(high_input, (2,2), strides=(2,2),padding='SAME') - high_to_low=tf.nn.conv2d(high_to_low, self.high_low__W, - strides=self.strides, padding="SAME") + low_input = inputs[1] + high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') + high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, + strides=self.strides, padding="SAME") low_to_low = tf.nn.conv2d(low_input, self.low_low_W, - strides=self.strides, padding="SAME") - low_out=low_to_low+high_to_low + strides=self.strides, padding="SAME") + low_out = low_to_low + high_to_low if self.b_init: low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") if self.act: - outputs= self.act(low_out,name=self.name+'_low_out') + outputs = self.act(low_out, name=self.name + '_low_out') else: - outputs=tf.identity(low_out,name=self.name+'_low_out') + outputs = tf.identity(low_out, name=self.name + '_low_out') return outputs - - class OctConv2dHighOut(Layer): """ The :class:`OctConv2dHighOut` class is a slice layer for Octconv tensor list, see @@ -416,11 +415,11 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self, inputs): + def build(self, inputs_shape=None): pass def forward(self, inputs): - outputs=tf.identity(inputs[0],self.name) + outputs = tf.identity(inputs[0], self.name) return outputs @@ -472,13 +471,14 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, **self.__dict__) - def build(self, inputs): + def build(self, inputs_shape=None): pass def forward(self, inputs): - outputs=tf.identity(inputs[1],self.name) + outputs = tf.identity(inputs[1], self.name) return outputs + class OctConv2dConcat(Layer): """ The :class:`OctConv2dConcat` class is a concat layer for two 2D image batches, see @@ -528,12 +528,12 @@ def __repr__(self): s += ')' return s.format(classname=self.__class__.__name__, 
**self.__dict__) - def build(self, inputs): + def build(self, inputs_shape=None): pass def forward(self, inputs): - if inputs[0].shape[1]>inputs[1].shape[1]: - outputs=[inputs[0],inputs[1]] + if inputs[0].shape[1] > inputs[1].shape[1]: + outputs = [inputs[0], inputs[1]] else: outputs = [inputs[1], inputs[0]] - return outputs \ No newline at end of file + return outputs From 596025b238ba1b7aeaa94b4539ae4d1cadc200e0 Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 17:25:47 +0800 Subject: [PATCH 06/17] Fix doc --- docs/modules/layers.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/modules/layers.rst b/docs/modules/layers.rst index d410ac5d6..81ee90a2d 100644 --- a/docs/modules/layers.rst +++ b/docs/modules/layers.rst @@ -247,8 +247,8 @@ GroupConv2d """"""""""""""""""""" .. autoclass:: GroupConv2d -OctConv2d --------------------------- +OctConv +^^^^^^^^^^^^^^^^^^^^^^^^^^ For OctConv2d, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. From fe89e149741060ace41e28261a3feec60933f6f4 Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 17:43:05 +0800 Subject: [PATCH 07/17] Yapf fix. 
--- tensorlayer/layers/convolution/oct_conv.py | 127 ++++++--------------- tensorlayer/prepro.py | 31 ++--- tests/layers/test_layers_convolution.py | 45 +++----- 3 files changed, 65 insertions(+), 138 deletions(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index 8c0e54911..d59658b5e 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -9,7 +9,6 @@ # from tensorlayer.layers.core import LayersConfig - __all__ = [ 'OctConv2dIn', 'OctConv2d', @@ -52,11 +51,7 @@ def __init__( self.build(None) self._built = True - logging.info( - "OctConv2dIn %s: " % ( - self.name, - ) - ) + logging.info("OctConv2dIn %s: " % (self.name, )) def __repr__(self): s = ('{classname}(') @@ -70,8 +65,7 @@ def build(self, inputs_shape=None): def forward(self, inputs): high_out = tf.identity(inputs, name=(self.name + '_high_out')) - low_out = tf.nn.avg_pool2d(inputs, (2, 2), strides=(2, 2), padding='SAME', - name=self.name + '_low_out') + low_out = tf.nn.avg_pool2d(inputs, (2, 2), strides=(2, 2), padding='SAME', name=self.name + '_low_out') outputs = [high_out, low_out] return outputs @@ -135,13 +129,11 @@ def __init__( self.filter = nfilter self.alpha = alpha if (self.alpha >= 1) or (self.alpha <= 0): - raise ValueError( - "The alpha must be in (0,1)") + raise ValueError("The alpha must be in (0,1)") self.high_out = int(self.alpha * self.filter) self.low_out = self.filter - self.high_out if (self.high_out == 0) or (self.low_out == 0): - raise ValueError( - "The output channel must be greater than 0.") + raise ValueError("The output channel must be greater than 0.") self.filter_size = filter_size self.strides = strides self.W_init = W_init @@ -161,8 +153,10 @@ def __init__( def __repr__(self): actstr = self.act.__name__ if self.act is not None else 'No Activation' - s = ('{classname}(in_channels={in_channels}, out_channels={filter}, ' - 'kernel_size={filter_size}, strides={strides}') + 
s = ( + '{classname}(in_channels={in_channels}, out_channels={filter}, ' + 'kernel_size={filter_size}, strides={strides}' + ) if self.b_init is None: s += ', bias=False' s += (', ' + actstr) @@ -179,52 +173,27 @@ def build(self, inputs_shape): else: high_ch = self.in_channels[0] low_ch = self.in_channels[1] - self.high_high_filter_shape = ( - self.filter_size[0], self.filter_size[1], high_ch, self.high_out - ) - self.high_low_filter_shape = ( - self.filter_size[0], self.filter_size[1], high_ch, self.low_out - ) - self.low_low_filter_shape = ( - self.filter_size[0], self.filter_size[1], low_ch, self.low_out - ) - self.low_high_filter_shape = ( - self.filter_size[0], self.filter_size[1], low_ch, self.high_out - ) - self.high_high__W = self._get_weights( - "high_high_filters", shape=self.high_high_filter_shape, init=self.W_init - ) - self.high_low__W = self._get_weights( - "high_low_filters", shape=self.high_low_filter_shape, init=self.W_init - ) - self.low_low_W = self._get_weights( - "low_low_filters", shape=self.low_low_filter_shape, init=self.W_init - ) - self.low_high_W = self._get_weights( - "low_high_filters", shape=self.low_high_filter_shape, init=self.W_init - ) + self.high_high_filter_shape = (self.filter_size[0], self.filter_size[1], high_ch, self.high_out) + self.high_low_filter_shape = (self.filter_size[0], self.filter_size[1], high_ch, self.low_out) + self.low_low_filter_shape = (self.filter_size[0], self.filter_size[1], low_ch, self.low_out) + self.low_high_filter_shape = (self.filter_size[0], self.filter_size[1], low_ch, self.high_out) + self.high_high__W = self._get_weights("high_high_filters", shape=self.high_high_filter_shape, init=self.W_init) + self.high_low__W = self._get_weights("high_low_filters", shape=self.high_low_filter_shape, init=self.W_init) + self.low_low_W = self._get_weights("low_low_filters", shape=self.low_low_filter_shape, init=self.W_init) + self.low_high_W = self._get_weights("low_high_filters", shape=self.low_high_filter_shape, 
init=self.W_init) if self.b_init: - self.high_b = self._get_weights( - "high_biases", shape=(self.high_out), init=self.b_init - ) - self.low_b = self._get_weights( - "low_biases", shape=(self.low_out), init=self.b_init - ) + self.high_b = self._get_weights("high_biases", shape=(self.high_out), init=self.b_init) + self.low_b = self._get_weights("low_biases", shape=(self.low_out), init=self.b_init) def forward(self, inputs): high_input = inputs[0] low_input = inputs[1] - high_to_high = tf.nn.conv2d(high_input, self.high_high__W, - strides=self.strides, padding="SAME") + high_to_high = tf.nn.conv2d(high_input, self.high_high__W, strides=self.strides, padding="SAME") high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') - high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, - strides=self.strides, padding="SAME") - low_to_low = tf.nn.conv2d(low_input, self.low_low_W, - strides=self.strides, padding="SAME") - low_to_high = tf.nn.conv2d(low_input, self.low_high_W, - strides=self.strides, padding="SAME") - low_to_high = tf.keras.layers.UpSampling2D(size=(2, 2), - interpolation='nearest')(low_to_high) + high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, strides=self.strides, padding="SAME") + low_to_low = tf.nn.conv2d(low_input, self.low_low_W, strides=self.strides, padding="SAME") + low_to_high = tf.nn.conv2d(low_input, self.low_high_W, strides=self.strides, padding="SAME") + low_to_high = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest')(low_to_high) high_out = high_to_high + low_to_high low_out = low_to_low + high_to_low if self.b_init: @@ -315,8 +284,10 @@ def __init__( def __repr__(self): actstr = self.act.__name__ if self.act is not None else 'No Activation' - s = ('{classname}(in_channels={in_channels}, out_channels={low_out}, ' - 'kernel_size={filter_size}, strides={strides}') + s = ( + '{classname}(in_channels={in_channels}, out_channels={low_out}, ' + 'kernel_size={filter_size}, strides={strides}' + ) if 
self.b_init is None: s += ', bias=False' s += (', ' + actstr) @@ -332,31 +303,19 @@ def build(self, inputs_shape): else: high_ch = self.in_channels[0] low_ch = self.in_channels[1] - self.high_low_filter_shape = ( - self.filter_size[0], self.filter_size[1], high_ch, self.high_out - ) - self.low_low_filter_shape = ( - self.filter_size[0], self.filter_size[1], low_ch, self.low_out - ) - self.high_low__W = self._get_weights( - "high_low_filters", shape=self.high_low_filter_shape, init=self.W_init - ) - self.low_low_W = self._get_weights( - "low_low_filters", shape=self.low_low_filter_shape, init=self.W_init - ) + self.high_low_filter_shape = (self.filter_size[0], self.filter_size[1], high_ch, self.high_out) + self.low_low_filter_shape = (self.filter_size[0], self.filter_size[1], low_ch, self.low_out) + self.high_low__W = self._get_weights("high_low_filters", shape=self.high_low_filter_shape, init=self.W_init) + self.low_low_W = self._get_weights("low_low_filters", shape=self.low_low_filter_shape, init=self.W_init) if self.b_init: - self.low_b = self._get_weights( - "low_biases", shape=(self.low_out), init=self.b_init - ) + self.low_b = self._get_weights("low_biases", shape=(self.low_out), init=self.b_init) def forward(self, inputs): high_input = inputs[0] low_input = inputs[1] high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') - high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, - strides=self.strides, padding="SAME") - low_to_low = tf.nn.conv2d(low_input, self.low_low_W, - strides=self.strides, padding="SAME") + high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, strides=self.strides, padding="SAME") + low_to_low = tf.nn.conv2d(low_input, self.low_low_W, strides=self.strides, padding="SAME") low_out = low_to_low + high_to_low if self.b_init: low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") @@ -401,11 +360,7 @@ def __init__( self.build(None) self._built = True - logging.info( - "OctConv2dHighOut %s: " % ( - 
self.name, - ) - ) + logging.info("OctConv2dHighOut %s: " % (self.name, )) def __repr__(self): @@ -457,11 +412,7 @@ def __init__( self.build(None) self._built = True - logging.info( - "OctConv2dHighOut %s: " % ( - self.name, - ) - ) + logging.info("OctConv2dHighOut %s: " % (self.name, )) def __repr__(self): @@ -514,11 +465,7 @@ def __init__( self.build(None) self._built = True - logging.info( - "OctConv2dConcat %s: " % ( - self.name, - ) - ) + logging.info("OctConv2dConcat %s: " % (self.name, )) def __repr__(self): diff --git a/tensorlayer/prepro.py b/tensorlayer/prepro.py index fc835b9fa..ebbde9de4 100644 --- a/tensorlayer/prepro.py +++ b/tensorlayer/prepro.py @@ -4150,7 +4150,8 @@ def pose_resize_shortestedge(image, annos, mask, target_size): def obj_box_coord_affine( - classes=None,coords=None, affine_matrix=None, affine_matrix_inv=None, min_ratio=0.0, min_width=0.0, min_height=0.0 + classes=None, coords=None, affine_matrix=None, affine_matrix_inv=None, min_ratio=0.0, min_width=0.0, + min_height=0.0 ): """Apply affine transform the box coordinates, and gets the new box coordinates. @@ -4186,14 +4187,11 @@ def obj_box_coord_affine( new_classes = [] new_cords = [] if affine_matrix_inv is None: - affine_matrix_inv = np.linalg.pinv(np.concatenate((affine_matrix, - Me), axis=0)) + affine_matrix_inv = np.linalg.pinv(np.concatenate((affine_matrix, Me), axis=0)) for (bbox_idx, bbox) in enumerate(coords): old_pt = np.array([[bbox[0] * 2. - 1.], [bbox[1] * 2. - 1.], [1.]]) - new_wh_a = np.matmul(affine_matrix_inv[0:2, 0:2], - np.array([[bbox[2]], [bbox[3]]])) - new_wh_b = np.matmul(affine_matrix_inv[0:2, 0:2], - np.array([[bbox[2]], [-bbox[3]]])) + new_wh_a = np.matmul(affine_matrix_inv[0:2, 0:2], np.array([[bbox[2]], [bbox[3]]])) + new_wh_b = np.matmul(affine_matrix_inv[0:2, 0:2], np.array([[bbox[2]], [-bbox[3]]])) new_w = max(abs(new_wh_a[0]), abs(new_wh_b[0])) new_h = max(abs(new_wh_a[1]), abs(new_wh_b[1])) new_pt = (np.matmul(affine_matrix_inv, old_pt) + 1.) / 2. 
@@ -4223,16 +4221,13 @@ def obj_box_coord_affine( bbox_y = abs(bbox_top + bbox_bottom) / 2. bbox_h = abs(bbox_top - bbox_bottom) bbox_w = abs(bbox_right - bbox_left) - if (ratio > min_ratio) & (bbox_h >= min_height) & (bbox_w - >= min_width): + if (ratio > min_ratio) & (bbox_h >= min_height) & (bbox_w >= min_width): new_classes.append(classes[bbox_idx]) new_cords.append([bbox_x, bbox_y, bbox_w, bbox_h]) return (new_classes, new_cords) -def rotated_obj_box_coord_affine( - classes=None, coords=None, affine_matrix=None, affine_matrix_inv=None - ): +def rotated_obj_box_coord_affine(classes=None, coords=None, affine_matrix=None, affine_matrix_inv=None): """Apply affine transform the box coordinates with rotation, and gets the new box coordinates with rotation. Experimental! @@ -4259,8 +4254,7 @@ def rotated_obj_box_coord_affine( new_classes = [] new_cords = [] if affine_matrix_inv is None: - affine_matrix_inv = np.linalg.pinv(np.concatenate((affine_matrix, - Me), axis=0)) + affine_matrix_inv = np.linalg.pinv(np.concatenate((affine_matrix, Me), axis=0)) for (bbox_idx, bbox) in enumerate(coords): centerx = bbox[0] centery = bbox[1] @@ -4268,8 +4262,7 @@ def rotated_obj_box_coord_affine( top_center = np.array([[0.], [-bbox[3]], [1.]]) right_center = np.array([[bbox[2]], [0.], [1.]]) rot = bbox[-1] - rot_mat = np.array([[np.cos(rot), -np.sin(rot), 0], [np.sin(rot), - np.cos(rot), 0], [0, 0, 1]]) + rot_mat = np.array([[np.cos(rot), -np.sin(rot), 0], [np.sin(rot), np.cos(rot), 0], [0, 0, 1]]) top_center = np.matmul(rot_mat, top_center) + old_pt top_center[2][0] = 1. right_center = np.matmul(rot_mat, right_center) + old_pt @@ -4278,10 +4271,8 @@ def rotated_obj_box_coord_affine( new_topcenter = (np.matmul(affine_matrix_inv, top_center) + 1.) / 2. new_rightcenter = (np.matmul(affine_matrix_inv, right_center) + 1.) \ / 2. - new_h = np.sqrt((new_topcenter[0][0] - new_pt[0][0]) ** 2 - + (new_topcenter[1][0] - new_pt[1][0]) ** 2) * 2. 
- new_w = np.sqrt((new_rightcenter[0][0] - new_pt[0][0]) ** 2 - + (new_rightcenter[1][0] - new_pt[1][0]) ** 2) * 2. + new_h = np.sqrt((new_topcenter[0][0] - new_pt[0][0])**2 + (new_topcenter[1][0] - new_pt[1][0])**2) * 2. + new_w = np.sqrt((new_rightcenter[0][0] - new_pt[0][0])**2 + (new_rightcenter[1][0] - new_pt[1][0])**2) * 2. deltax = -new_topcenter[0][0] + new_pt[0][0] deltay = (-new_topcenter[1][0] + new_pt[1][0]) * -1 if deltay == 0: diff --git a/tests/layers/test_layers_convolution.py b/tests/layers/test_layers_convolution.py index 3621d5e8c..890f2a330 100644 --- a/tests/layers/test_layers_convolution.py +++ b/tests/layers/test_layers_convolution.py @@ -462,34 +462,31 @@ def test_layer_n4(self): # self.assertEqual(self.net2.count_params(), 19392) # self.assertEqual(self.net2.outputs.get_shape().as_list()[1:], [299, 299, 64]) + class Layer_OctConv_2D_Test(CustomTestCase): @classmethod def setUpClass(cls): - print ('\n#################################') + print('\n#################################') cls.batch_size = 5 cls.inputs_shape = [cls.batch_size, 32, 32, 16] cls.input_layer = Input(cls.inputs_shape, name='input_layer') - cls.n1 = tl.layers.OctConv2dIn(name='octconv2din' - )(cls.input_layer) + cls.n1 = tl.layers.OctConv2dIn(name='octconv2din')(cls.input_layer) - cls.n2 = tl.layers.OctConv2d(32, 0.5, act=tf.nn.relu, - name='octconv2d')(cls.n1) + cls.n2 = tl.layers.OctConv2d(32, 0.5, act=tf.nn.relu, name='octconv2d')(cls.n1) cls.n3 = tl.layers.OctConv2dHighOut(name='octconv2dho')(cls.n2) cls.n4 = tl.layers.OctConv2dLowOut(name='octconv2dlo')(cls.n2) - cls.n5 = tl.layers.OctConv2dConcat(name='octconv2dconcat' - )([cls.n3, cls.n4]) + cls.n5 = tl.layers.OctConv2dConcat(name='octconv2dconcat')([cls.n3, cls.n4]) - cls.n6 = tl.layers.OctConv2dOut(n_filter=32, name='octconv2dout' - )(cls.n5) + cls.n6 = tl.layers.OctConv2dOut(n_filter=32, name='octconv2dout')(cls.n5) cls.model = Model(cls.input_layer, cls.n6) - print ('Testing OctConv2d model: \n', cls.model) + 
print('Testing OctConv2d model: \n', cls.model) @classmethod def tearDownClass(cls): @@ -499,42 +496,34 @@ def tearDownClass(cls): def test_layer_n1(self): - self.assertEqual(self.n1[0].get_shape().as_list()[1:], [32, 32, - 16]) - self.assertEqual(self.n1[1].get_shape().as_list()[1:], [16, 16, - 16]) + self.assertEqual(self.n1[0].get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n1[1].get_shape().as_list()[1:], [16, 16, 16]) self.assertEqual(len(self.n1), 2) def test_layer_n2(self): - self.assertEqual(self.n2[0].get_shape().as_list()[1:], [32, 32, - 16]) - self.assertEqual(self.n2[1].get_shape().as_list()[1:], [16, 16, - 16]) + self.assertEqual(self.n2[0].get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n2[1].get_shape().as_list()[1:], [16, 16, 16]) self.assertEqual(len(self.n2), 2) def test_layer_n3(self): - self.assertEqual(self.n3.get_shape().as_list()[1:], [32, 32, - 16]) + self.assertEqual(self.n3.get_shape().as_list()[1:], [32, 32, 16]) def test_layer_n4(self): - self.assertEqual(self.n4.get_shape().as_list()[1:], [16, 16, - 16]) + self.assertEqual(self.n4.get_shape().as_list()[1:], [16, 16, 16]) def test_layer_n5(self): - self.assertEqual(self.n5[0].get_shape().as_list()[1:], [32, 32, - 16]) - self.assertEqual(self.n5[1].get_shape().as_list()[1:], [16, 16, - 16]) + self.assertEqual(self.n5[0].get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n5[1].get_shape().as_list()[1:], [16, 16, 16]) self.assertEqual(len(self.n2), 2) def test_layer_n6(self): - self.assertEqual(self.n6.get_shape().as_list()[1:], [16, 16, - 32]) + self.assertEqual(self.n6.get_shape().as_list()[1:], [16, 16, 32]) + if __name__ == '__main__': From b6fd6e6a420e07d6528d1677068a9a249ffd2f5d Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 17:50:19 +0800 Subject: [PATCH 08/17] Fix indentation --- tensorlayer/layers/convolution/oct_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index d59658b5e..a736c2e08 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -24,7 +24,7 @@ class OctConv2dIn(Layer): The :class:`OctConv2dIn` class is a preprocessing layer for 2D image [batch, height, width, channel], see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave - Convolution `__. + Convolution `__. Parameters ---------- name : None or str From 9c6997f3281942a36b58f78352a07b16e6b9be66 Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 17:59:24 +0800 Subject: [PATCH 09/17] Fix doc --- tensorlayer/layers/convolution/oct_conv.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index a736c2e08..ff0e22e4e 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -96,8 +96,7 @@ class OctConv2d(Layer): A unique layer name. Notes ----- - - The input should be a list with shape [high_res_tensor , low_res_tensor], - the height and width of high_res should be twice of the low_res_tensor. + - The input should be a list with shape [high_res_tensor , low_res_tensor], the height and width of high_res should be twice of the low_res_tensor. - If you do not which tensor is larger, use OctConv2dConcat layer. - The output will be a list which contains 2 tensor. - You should not use the output directly. @@ -338,8 +337,7 @@ class OctConv2dHighOut(Layer): Notes ----- - Use this layer to get high resolution tensor. - - If you want to do some customized normalization ops, use this layer with - OctConv2dLowOut and OctConv2dConcat layers to implement your idea. + - If you want to do some customized normalization ops, use this layer with OctConv2dLowOut and OctConv2dConcat layers to implement your idea. 
Examples -------- With TensorLayer @@ -390,8 +388,7 @@ class OctConv2dLowOut(Layer): Notes ----- - Use this layer to get low resolution tensor. - - If you want to do some customized normalization ops, use this layer with - OctConv2dHighOut and OctConv2dConcat layers to implement your idea. + - If you want to do some customized normalization ops, use this layer with OctConv2dHighOut and OctConv2dConcat layers to implement your idea. Examples -------- With TensorLayer From 3c4fdf62752215c3e40ec48666e823f1e6fc8077 Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 18:07:14 +0800 Subject: [PATCH 10/17] Fix doc --- tensorlayer/layers/convolution/oct_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index ff0e22e4e..8a74fa32e 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -77,7 +77,7 @@ class OctConv2d(Layer): Octave Convolution `__. Use this layer to process tensor list. Parameters ---------- - filter : int + nfilter : int The sum of the number of filters. alpha : :float The percentage of high_res output. From 055f6dde8a3c37899dcb728bab6a646feda1483e Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 18:17:06 +0800 Subject: [PATCH 11/17] Fix doc --- tensorlayer/layers/convolution/oct_conv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index 8a74fa32e..079265332 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -79,7 +79,7 @@ class OctConv2d(Layer): ---------- nfilter : int The sum of the number of filters. - alpha : :float + alpha : float The percentage of high_res output. filter_size : tuple of int The filter size (height, width). 
@@ -122,7 +122,7 @@ def __init__( b_init=tl.initializers.constant(value=0.0), act=None, in_channels=None, - name=None # 'cnn2d_layer', + name=None ): super().__init__(name) self.filter = nfilter From 7fcb9289fadb4bbaedae7cb9eb36dd53a2c5c1fc Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 18:37:39 +0800 Subject: [PATCH 12/17] Fix doc --- tensorlayer/layers/convolution/oct_conv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index 079265332..2b03c50fd 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -80,7 +80,7 @@ class OctConv2d(Layer): nfilter : int The sum of the number of filters. alpha : float - The percentage of high_res output. + The percentage of highres output. filter_size : tuple of int The filter size (height, width). strides : tuple of int @@ -96,7 +96,7 @@ class OctConv2d(Layer): A unique layer name. Notes ----- - - The input should be a list with shape [high_res_tensor , low_res_tensor], the height and width of high_res should be twice of the low_res_tensor. + - The input should be a list with shape (highrestensor , lowrestensor), the height and width of high_res should be twice of the low_res_tensor. - If you do not which tensor is larger, use OctConv2dConcat layer. - The output will be a list which contains 2 tensor. - You should not use the output directly. 
From e7f40f2b1a92321d596486c41c3a90e09cb6e719 Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 18:44:01 +0800 Subject: [PATCH 13/17] Fix doc --- tensorlayer/layers/convolution/oct_conv.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index 2b03c50fd..85055d60d 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -113,16 +113,9 @@ class OctConv2d(Layer): """ def __init__( - self, - nfilter=32, - alpha=0.5, - filter_size=(2, 2), - strides=(1, 1), - W_init=tl.initializers.truncated_normal(stddev=0.02), - b_init=tl.initializers.constant(value=0.0), - act=None, - in_channels=None, - name=None + self, nfilter=32, alpha=0.5, filter_size=(2, 2), strides=(1, 1), + W_init=tl.initializers.truncated_normal(stddev=0.02), b_init=tl.initializers.constant(value=0.0), act=None, + in_channels=None, name=None ): super().__init__(name) self.filter = nfilter From 8cc823ca2fea80968bb58edbdc3e5a23b4318996 Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 18:57:49 +0800 Subject: [PATCH 14/17] Fix doc --- tensorlayer/layers/convolution/oct_conv.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index 85055d60d..534e5bad1 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -71,8 +71,7 @@ def forward(self, inputs): class OctConv2d(Layer): - """ - The :class:`OctConv2d` class is a 2D CNN layer for OctConv2d layer output, see + """The :class: `OctConv2d` class is a 2D CNN layer for OctConv2d layer output, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. Use this layer to process tensor list. Parameters @@ -96,7 +95,7 @@ class OctConv2d(Layer): A unique layer name. 
Notes ----- - - The input should be a list with shape (highrestensor , lowrestensor), the height and width of high_res should be twice of the low_res_tensor. + - The input should be a list with shape (highrestensor, lowrestensor), the height and width of high_res should be twice of the low_res_tensor. - If you do not which tensor is larger, use OctConv2dConcat layer. - The output will be a list which contains 2 tensor. - You should not use the output directly. From b09cfe537153903cd83dc8c2fc51841dd813fc8c Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 19:04:58 +0800 Subject: [PATCH 15/17] Fix doc --- tensorlayer/layers/convolution/oct_conv.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index 534e5bad1..894a1bf7f 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -84,7 +84,6 @@ class OctConv2d(Layer): The filter size (height, width). strides : tuple of int The sliding window strides of corresponding input dimensions. - It must be in the same order as the ``shape`` parameter. W_init : initializer The initializer for the weight matrix. b_init : initializer or None From 59df2c56bc461929ec62356ab163b2ec75d73132 Mon Sep 17 00:00:00 2001 From: Windaway Date: Fri, 17 May 2019 19:16:41 +0800 Subject: [PATCH 16/17] Fix doc --- tensorlayer/layers/convolution/oct_conv.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py index 894a1bf7f..bfd07f348 100644 --- a/tensorlayer/layers/convolution/oct_conv.py +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -213,7 +213,6 @@ class OctConv2dOut(Layer): The filter size (height, width). strides : tuple of int The sliding window strides of corresponding input dimensions. - It must be in the same order as the ``shape`` parameter. W_init : initializer The initializer for the weight matrix. 
b_init : initializer or None From 7afd8f0a39a4f1864a82e508f7a326fc998dc033 Mon Sep 17 00:00:00 2001 From: Windaway Date: Sun, 2 Jun 2019 14:25:09 +0800 Subject: [PATCH 17/17] Add OctConv Layers and experiential coord transform func. Add Basic OctConv Layers . Add Coord Transform Func. Fix Huber Loss Doc. --- tensorlayer/prepro.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorlayer/prepro.py b/tensorlayer/prepro.py index ebbde9de4..8b98a47f9 100644 --- a/tensorlayer/prepro.py +++ b/tensorlayer/prepro.py @@ -4229,7 +4229,8 @@ def obj_box_coord_affine( def rotated_obj_box_coord_affine(classes=None, coords=None, affine_matrix=None, affine_matrix_inv=None): """Apply an affine transform to the box coordinates with rotation, and get the new box coordinates with rotation. Experimental! Box angles are normalized to [-pi/2, pi/2]. Parameters -----------