diff --git a/docs/modules/cost.rst b/docs/modules/cost.rst index eba52f4ca..bf9fe939c 100644 --- a/docs/modules/cost.rst +++ b/docs/modules/cost.rst @@ -96,5 +96,5 @@ Special .. autofunction:: maxnorm_i_regularizer Huber Loss -^^^^^^^^^^ +-------------------------- .. autofunction:: huber_loss \ No newline at end of file diff --git a/docs/modules/layers.rst b/docs/modules/layers.rst index 7a70b54dc..81ee90a2d 100644 --- a/docs/modules/layers.rst +++ b/docs/modules/layers.rst @@ -39,6 +39,12 @@ Layer list SeparableConv2d DeformableConv2d GroupConv2d + OctConv2dIn + OctConv2d + OctConv2dOut + OctConv2dHighOut + OctConv2dLowOut + OctConv2dConcat PadLayer PoolLayer @@ -241,6 +247,34 @@ GroupConv2d """"""""""""""""""""" .. autoclass:: GroupConv2d +OctConv +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For OctConv2d, see `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution `__. + +OctConv2dIn +""""""""""""""""""""" +.. autoclass:: OctConv2dIn + +OctConv2d +""""""""""""""""""""" +.. autoclass:: OctConv2d + +OctConv2dOut +""""""""""""""""""""" +.. autoclass:: OctConv2dOut + +OctConv2dHighOut +""""""""""""""""""""" +.. autoclass:: OctConv2dHighOut + +OctConv2dConcat +""""""""""""""""""""" +.. autoclass:: OctConv2dConcat + +OctConv2dLowOut +""""""""""""""""""""" +.. autoclass:: OctConv2dLowOut Separable Convolutions ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/modules/prepro.rst b/docs/modules/prepro.rst index ec65da066..da1bd3189 100644 --- a/docs/modules/prepro.rst +++ b/docs/modules/prepro.rst @@ -79,6 +79,8 @@ API - Data Pre-Processing obj_box_coord_upleft_butright_to_centroid obj_box_coord_centroid_to_upleft obj_box_coord_upleft_to_centroid + obj_box_coord_affine + rotated_obj_box_coord_affine parse_darknet_ann_str_to_list parse_darknet_ann_list_to_cls_box @@ -577,6 +579,14 @@ Image Aug - Zoom ^^^^^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: obj_box_zoom +Image Aug - Affine +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autofunction:: obj_box_coord_affine + +Image Aug - Rotated-Affine +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autofunction:: rotated_obj_box_coord_affine + Keypoints ------------ diff --git a/tensorlayer/layers/convolution/__init__.py b/tensorlayer/layers/convolution/__init__.py index ba68797f2..c96d4c5d0 100644 --- a/tensorlayer/layers/convolution/__init__.py +++ b/tensorlayer/layers/convolution/__init__.py @@ -24,6 +24,7 @@ from .ternary_conv import * from .quan_conv import * from .quan_conv_bn import * +from .oct_conv import * __all__ = [ @@ -80,4 +81,12 @@ #quan_conv 'QuanConv2d', 'QuanConv2dWithBN', + + # octave_conv + 'OctConv2dIn', + 'OctConv2d', + 'OctConv2dOut', + 'OctConv2dHighOut', + 'OctConv2dLowOut', + 'OctConv2dConcat', ] diff --git a/tensorlayer/layers/convolution/oct_conv.py b/tensorlayer/layers/convolution/oct_conv.py new file mode 100644 index 000000000..bfd07f348 --- /dev/null +++ b/tensorlayer/layers/convolution/oct_conv.py @@ -0,0 +1,473 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- + +import tensorflow as tf + +import tensorlayer as tl +from tensorlayer import logging +from tensorlayer.layers.core import Layer + +# from tensorlayer.layers.core import LayersConfig + +__all__ = [ + 'OctConv2dIn', + 'OctConv2d', + 'OctConv2dOut', + 'OctConv2dHighOut', + 'OctConv2dLowOut', + 'OctConv2dConcat', +] + + +class OctConv2dIn(Layer): + """ + The :class:`OctConv2dIn` class is a preprocessing layer for + 2D image [batch, height, width, channel], see `Drop an Octave: Reducing + Spatial Redundancy in Convolutional Neural Networks with Octave + Convolution `__. + Parameters + ---------- + name : None or str + A unique layer name. + Notes + ----- + - The height and width of input must be a multiple of the 2. + - Use this layer before any other octconv layers. + - The output will be a list which contains 2 tensor. + Examples + -------- + With TensorLayer + >>> net = tl.layers.Input([8, 28, 28, 16], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 16),(8, 14, 14, 16)] + """ + + def __init__( + self, + name=None, # 'cnn2d_layer', + ): + super().__init__(name) + self.build(None) + self._built = True + + logging.info("OctConv2dIn %s: " % (self.name, )) + + def __repr__(self): + s = ('{classname}(') + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs_shape=None): + pass + + def forward(self, inputs): + high_out = tf.identity(inputs, name=(self.name + '_high_out')) + low_out = tf.nn.avg_pool2d(inputs, (2, 2), strides=(2, 2), padding='SAME', name=self.name + '_low_out') + outputs = [high_out, low_out] + return outputs + + +class OctConv2d(Layer): + """The :class: `OctConv2d` class is a 2D CNN layer for OctConv2d layer output, see + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with + Octave Convolution `__. Use this layer to process tensor list. + Parameters + ---------- + nfilter : int + The sum of the number of filters. + alpha : float + The percentage of highres output. + filter_size : tuple of int + The filter size (height, width). + strides : tuple of int + The sliding window strides of corresponding input dimensions. + W_init : initializer + The initializer for the weight matrix. + b_init : initializer or None + The initializer for the bias vector. If None, skip biases. + act : activation function + The activation function of this layer. + name : None or str + A unique layer name. + Notes + ----- + - The input should be a list with shape (highrestensor, lowrestensor), the height and width of high_res should be twice of the low_res_tensor. + - If you do not which tensor is larger, use OctConv2dConcat layer. + - The output will be a list which contains 2 tensor. + - You should not use the output directly. + Examples + -------- + With TensorLayer + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2d = tl.layers.OctConv2d(32,0.5,act=tf.nn.relu, name='octconv2d_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 16),(8, 14, 14, 16)] + """ + + def __init__( + self, nfilter=32, alpha=0.5, filter_size=(2, 2), strides=(1, 1), + W_init=tl.initializers.truncated_normal(stddev=0.02), b_init=tl.initializers.constant(value=0.0), act=None, + in_channels=None, name=None + ): + super().__init__(name) + self.filter = nfilter + self.alpha = alpha + if (self.alpha >= 1) or (self.alpha <= 0): + raise ValueError("The alpha must be in (0,1)") + self.high_out = int(self.alpha * self.filter) + self.low_out = self.filter - self.high_out + if (self.high_out == 0) or (self.low_out == 0): + raise ValueError("The output channel must be greater than 0.") + self.filter_size = filter_size + self.strides = strides + self.W_init = W_init + self.b_init = b_init + self.act = act + self.in_channels = in_channels + if self.in_channels: + self.build(None) + self._built = True + + logging.info( + "OctConv2d %s: filter_size: %s strides: %s high_out: %s low_out: %s act: %s" % ( + self.name, str(filter_size), str(strides), str(self.high_out), str(self.low_out), + self.act.__name__ if self.act is not None else 'No Activation' + ) + ) + + def __repr__(self): + actstr = self.act.__name__ if self.act is not None else 'No Activation' + s = ( + '{classname}(in_channels={in_channels}, out_channels={filter}, ' + 'kernel_size={filter_size}, strides={strides}' + ) + if self.b_init is None: + s += ', bias=False' + s += (', ' + actstr) + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs_shape): + if not self.in_channels: + high_ch = inputs_shape[0][-1] + low_ch = inputs_shape[1][-1] + else: + high_ch = self.in_channels[0] + low_ch = self.in_channels[1] + self.high_high_filter_shape = (self.filter_size[0], self.filter_size[1], high_ch, self.high_out) + self.high_low_filter_shape = (self.filter_size[0], self.filter_size[1], high_ch, self.low_out) + self.low_low_filter_shape = (self.filter_size[0], self.filter_size[1], low_ch, self.low_out) + self.low_high_filter_shape = (self.filter_size[0], self.filter_size[1], low_ch, self.high_out) + self.high_high__W = self._get_weights("high_high_filters", shape=self.high_high_filter_shape, init=self.W_init) + self.high_low__W = self._get_weights("high_low_filters", shape=self.high_low_filter_shape, init=self.W_init) + self.low_low_W = self._get_weights("low_low_filters", shape=self.low_low_filter_shape, init=self.W_init) + self.low_high_W = self._get_weights("low_high_filters", shape=self.low_high_filter_shape, init=self.W_init) + if self.b_init: + self.high_b = self._get_weights("high_biases", shape=(self.high_out), init=self.b_init) + self.low_b = self._get_weights("low_biases", shape=(self.low_out), init=self.b_init) + + def forward(self, inputs): + high_input = inputs[0] + low_input = inputs[1] + high_to_high = tf.nn.conv2d(high_input, self.high_high__W, strides=self.strides, padding="SAME") + high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') + high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, strides=self.strides, padding="SAME") + low_to_low = tf.nn.conv2d(low_input, self.low_low_W, strides=self.strides, padding="SAME") + low_to_high = tf.nn.conv2d(low_input, self.low_high_W, strides=self.strides, padding="SAME") + low_to_high = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest')(low_to_high) + high_out = high_to_high + low_to_high + low_out = low_to_low + high_to_low + if self.b_init: + high_out = tf.nn.bias_add(high_out, self.high_b, data_format="NHWC") + low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") + if self.act: + high_out = self.act(high_out, name=self.name + '_high_out') + low_out = self.act(low_out, name=self.name + '_low_out') + else: + high_out = tf.identity(high_out, name=self.name + '_high_out') + low_out = tf.identity(low_out, name=self.name + '_low_out') + outputs = [high_out, low_out] + return outputs + + +class OctConv2dOut(Layer): + """ + The :class:`OctConv2dOut` class is a 2D CNN layer for OctConv2d layer + output to get only a tensor, see` Drop an Octave: Reducing Spatial Redundancy + in Convolutional Neural Networks with Octave Convolution + `__. + Parameters + ---------- + filter : int + The number of filters. + filter_size : tuple of int + The filter size (height, width). + strides : tuple of int + The sliding window strides of corresponding input dimensions. + W_init : initializer + The initializer for the weight matrix. + b_init : initializer or None + The initializer for the bias vector. If None, skip biases. + act : activation function + The activation function of this layer. + name : None or str + A unique layer name. + Notes + ----- + - Use this layer to get only a tensor for other normal layer. + Examples + -------- + With TensorLayer + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2d = tl.layers.OctConv2d(32,0.5,act=tf.nn.relu, name='octconv2d_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 16),(8, 14, 14, 16)] + >>> octconv2d = tl.layers.OctConv2dOut(32,act=tf.nn.relu, name='octconv2dout_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : (8, 14, 14, 32) + """ + + def __init__( + self, + n_filter=32, + filter_size=(2, 2), + strides=(1, 1), + W_init=tl.initializers.truncated_normal(stddev=0.02), + b_init=tl.initializers.constant(value=0.0), + act=None, + in_channels=None, + name=None # 'cnn2d_layer', + ): + super().__init__(name) + + self.high_out = n_filter + self.low_out = n_filter + self.filter_size = filter_size + self.strides = strides + self.W_init = W_init + self.b_init = b_init + self.act = act + self.in_channels = in_channels + if self.in_channels: + self.build(None) + self._built = True + + logging.info( + "OctConv2dOut %s: filter_size: %s strides: %s out_channels: %s act: %s" % ( + self.name, str(filter_size), str(strides), str(self.low_out), + self.act.__name__ if self.act is not None else 'No Activation' + ) + ) + + def __repr__(self): + actstr = self.act.__name__ if self.act is not None else 'No Activation' + s = ( + '{classname}(in_channels={in_channels}, out_channels={low_out}, ' + 'kernel_size={filter_size}, strides={strides}' + ) + if self.b_init is None: + s += ', bias=False' + s += (', ' + actstr) + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs_shape): + if not self.in_channels: + high_ch = inputs_shape[0][-1] + low_ch = inputs_shape[1][-1] + else: + high_ch = self.in_channels[0] + low_ch = self.in_channels[1] + self.high_low_filter_shape = (self.filter_size[0], self.filter_size[1], high_ch, self.high_out) + self.low_low_filter_shape = (self.filter_size[0], self.filter_size[1], low_ch, self.low_out) + self.high_low__W = self._get_weights("high_low_filters", shape=self.high_low_filter_shape, init=self.W_init) + self.low_low_W = self._get_weights("low_low_filters", shape=self.low_low_filter_shape, init=self.W_init) + if self.b_init: + self.low_b = self._get_weights("low_biases", shape=(self.low_out), init=self.b_init) + + def forward(self, inputs): + high_input = inputs[0] + low_input = inputs[1] + high_to_low = tf.nn.avg_pool2d(high_input, (2, 2), strides=(2, 2), padding='SAME') + high_to_low = tf.nn.conv2d(high_to_low, self.high_low__W, strides=self.strides, padding="SAME") + low_to_low = tf.nn.conv2d(low_input, self.low_low_W, strides=self.strides, padding="SAME") + low_out = low_to_low + high_to_low + if self.b_init: + low_out = tf.nn.bias_add(low_out, self.low_b, data_format="NHWC") + if self.act: + outputs = self.act(low_out, name=self.name + '_low_out') + else: + outputs = tf.identity(low_out, name=self.name + '_low_out') + return outputs + + +class OctConv2dHighOut(Layer): + """ + The :class:`OctConv2dHighOut` class is a slice layer for Octconv tensor list, see + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with + Octave Convolution `__. + Parameters + ---------- + name : None or str + A unique layer name. + Notes + ----- + - Use this layer to get high resolution tensor. + - If you want to do some customized normalization ops, use this layer with OctConv2dLowOut and OctConv2dConcat layers to implement your idea. + Examples + -------- + With TensorLayer + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2d = tl.layers.OctConv2dHighOut(name='octconv2dho_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : (8, 28, 28, 32) + """ + + def __init__( + self, + name=None, # 'cnn2d_layer', + ): + super().__init__(name) + self.build(None) + self._built = True + + logging.info("OctConv2dHighOut %s: " % (self.name, )) + + def __repr__(self): + + s = ('{classname}(') + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs_shape=None): + pass + + def forward(self, inputs): + outputs = tf.identity(inputs[0], self.name) + return outputs + + +class OctConv2dLowOut(Layer): + """ + The :class:`OctConv2dLowOut` class is a slice layer for Octconv tensor list, see + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with + Octave Convolution `__. + Parameters + ---------- + name : None or str + A unique layer name. + Notes + ----- + - Use this layer to get low resolution tensor. + - If you want to do some customized normalization ops, use this layer with OctConv2dHighOut and OctConv2dConcat layers to implement your idea. + Examples + -------- + With TensorLayer + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1')(net) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2d = tl.layers.OctConv2dLowOut(name='octconv2dlo_1')(octconv2d) + >>> print(octconv2d) + >>> output shape : (8, 14, 14, 32) + """ + + def __init__( + self, + name=None, # 'cnn2d_layer', + ): + super().__init__(name) + self.build(None) + self._built = True + + logging.info("OctConv2dHighOut %s: " % (self.name, )) + + def __repr__(self): + + s = ('{classname}(') + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs_shape=None): + pass + + def forward(self, inputs): + outputs = tf.identity(inputs[1], self.name) + return outputs + + +class OctConv2dConcat(Layer): + """ + The :class:`OctConv2dConcat` class is a concat layer for two 2D image batches, see + `Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with + Octave Convolution `__. + Parameters + ---------- + name : None or str + A unique layer name. + Notes + ----- + - Use this layer to concat two tensor. + - The height and width of one tensor should be twice of the other tensor. + Examples + -------- + With TensorLayer + >>> net = tl.layers.Input([8, 28, 28, 32], name='input') + >>> octconv2d = tl.layers.OctConv2dIn(name='octconv2din_1') + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + >>> octconv2dl = tl.layers.OctConv2dLowOut(name='octconv2dlo_1')(octconv2d) + >>> octconv2dh = tl.layers.OctConv2dHighOut(name='octconv2dho_1')(octconv2d) + >>> octconv2 = tl.layers.OctConv2dConcat(name='octconv2dcat_1')([octconv2dh,octconv2dl]) + >>> print(octconv2d) + >>> output shape : [(8, 28, 28, 32),(8, 14, 14, 32)] + """ + + def __init__( + self, + name=None, # 'cnn2d_layer', + ): + super().__init__(name) + self.build(None) + self._built = True + + logging.info("OctConv2dConcat %s: " % (self.name, )) + + def __repr__(self): + + s = ('{classname}(') + if self.name is not None: + s += ', name=\'{name}\'' + s += ')' + return s.format(classname=self.__class__.__name__, **self.__dict__) + + def build(self, inputs_shape=None): + pass + + def forward(self, inputs): + if inputs[0].shape[1] > inputs[1].shape[1]: + outputs = [inputs[0], inputs[1]] + else: + outputs = [inputs[1], inputs[0]] + return outputs diff --git a/tensorlayer/prepro.py b/tensorlayer/prepro.py index 931694308..8b98a47f9 100644 --- a/tensorlayer/prepro.py +++ b/tensorlayer/prepro.py @@ -111,6 +111,8 @@ 'keypoint_random_flip', 'keypoint_random_resize', 'keypoint_random_resize_shortestedge', + 'obj_box_coord_affine', + 'rotated_obj_box_coord_affine', ] @@ -4145,3 +4147,152 @@ def pose_resize_shortestedge(image, annos, mask, target_size): return dst, adjust_joint_list, None return pose_resize_shortestedge(image, annos, mask, target_size) + + +def obj_box_coord_affine( + classes=None, coords=None, affine_matrix=None, affine_matrix_inv=None, min_ratio=0.0, min_width=0.0, + min_height=0.0 +): + """Apply affine transform the box coordinates, and gets the new box coordinates. + + Parameters + ----------- + classes : list of int or None + Class IDs. + coords : list of list of 4 int/float or None + Coordinates [[x, y, w, h], [x, y, w, h], ...]. Here x,y are + the center coordinates of the bbox. + affine_matrix : np.ndarray + The affine matrix of the image. A ndarray with shape [2,3]. + affine_matrix_inv : np.ndarray + The inverse of the affine matrix. A ndarray with shape [2,3]. + min_ratio : float + Threshold, remove the box if its ratio of new size to old size less + than the threshold. + min_width : float + Threshold, remove the box if its ratio of width to image size less + than the threshold. + min_height : float + Threshold, remove the box if its ratio of height to image size less + than the threshold. + + Returns + ------- + list of int + A list of classes + list of list of 4 numbers + A list of new bounding boxes. + """ + Me = np.array([[0, 0, 1.]]) + new_classes = [] + new_cords = [] + if affine_matrix_inv is None: + affine_matrix_inv = np.linalg.pinv(np.concatenate((affine_matrix, Me), axis=0)) + for (bbox_idx, bbox) in enumerate(coords): + old_pt = np.array([[bbox[0] * 2. - 1.], [bbox[1] * 2. - 1.], [1.]]) + new_wh_a = np.matmul(affine_matrix_inv[0:2, 0:2], np.array([[bbox[2]], [bbox[3]]])) + new_wh_b = np.matmul(affine_matrix_inv[0:2, 0:2], np.array([[bbox[2]], [-bbox[3]]])) + new_w = max(abs(new_wh_a[0]), abs(new_wh_b[0])) + new_h = max(abs(new_wh_a[1]), abs(new_wh_b[1])) + new_pt = (np.matmul(affine_matrix_inv, old_pt) + 1.) / 2. + bbox_left = new_pt[0] - new_w / 2. + if bbox_left <= 0: + bbox_left = 0. + if bbox_left >= 1: + bbox_left = 1. + bbox_right = new_pt[0] + new_w / 2. + if bbox_right <= 0: + bbox_right = 0. + if bbox_right >= 1: + bbox_right = 1. + bbox_top = new_pt[1] + new_h / 2. + if bbox_top <= 0: + bbox_top = 0. + if bbox_top >= 1: + bbox_top = 1. + bbox_bottom = new_pt[1] - new_h / 2. + if bbox_bottom <= 0: + bbox_bottom = 0. + if bbox_bottom >= 1: + bbox_bottom = 1. + ratio = abs(bbox_right - bbox_left) * abs(bbox_top - bbox_bottom) \ + / (new_w + 0.00001) / (new_h + 0.00001) + bbox_x = abs(bbox_right + bbox_left) / 2. + bbox_y = abs(bbox_top + bbox_bottom) / 2. + bbox_h = abs(bbox_top - bbox_bottom) + bbox_w = abs(bbox_right - bbox_left) + if (ratio > min_ratio) & (bbox_h >= min_height) & (bbox_w >= min_width): + new_classes.append(classes[bbox_idx]) + new_cords.append([bbox_x, bbox_y, bbox_w, bbox_h]) + return (new_classes, new_cords) + + +def rotated_obj_box_coord_affine(classes=None, coords=None, affine_matrix=None, affine_matrix_inv=None): + """Apply affine transform the box coordinates with rotation, and gets the + new box coordinates with rotation. Experimental! Box angle are normalized + to [-pi/2, pi/2]. + + Parameters + ----------- + classes : list of int or None + Class IDs. + coords : list of list of 5 int/float or None + Coordinates [[x, y, w, h, r], [x, y, w, h, r], ...]. Here x,y are + the center coordinates of the bbox, r is the radius. + affine_matrix : np.ndarray + The affine matrix of the image. A ndarray with shape [2,3]. + affine_matrix_inv : np.ndarray + The inverse of the affine matrix. A ndarray with shape [2,3]. + + Returns + ------- + list of int + A list of classes + list of list of 5 numbers + A list of new bounding boxes. + """ + Me = np.array([[0, 0, 1.]]) + new_classes = [] + new_cords = [] + if affine_matrix_inv is None: + affine_matrix_inv = np.linalg.pinv(np.concatenate((affine_matrix, Me), axis=0)) + for (bbox_idx, bbox) in enumerate(coords): + centerx = bbox[0] + centery = bbox[1] + old_pt = np.array([[centerx * 2. - 1.], [centery * 2. - 1.], [1.]]) + top_center = np.array([[0.], [-bbox[3]], [1.]]) + right_center = np.array([[bbox[2]], [0.], [1.]]) + rot = bbox[-1] + rot_mat = np.array([[np.cos(rot), -np.sin(rot), 0], [np.sin(rot), np.cos(rot), 0], [0, 0, 1]]) + top_center = np.matmul(rot_mat, top_center) + old_pt + top_center[2][0] = 1. + right_center = np.matmul(rot_mat, right_center) + old_pt + right_center[2][0] = 1. + new_pt = (np.matmul(affine_matrix_inv, old_pt) + 1.) / 2. + new_topcenter = (np.matmul(affine_matrix_inv, top_center) + 1.) / 2. + new_rightcenter = (np.matmul(affine_matrix_inv, right_center) + 1.) \ + / 2. + new_h = np.sqrt((new_topcenter[0][0] - new_pt[0][0])**2 + (new_topcenter[1][0] - new_pt[1][0])**2) * 2. + new_w = np.sqrt((new_rightcenter[0][0] - new_pt[0][0])**2 + (new_rightcenter[1][0] - new_pt[1][0])**2) * 2. + deltax = -new_topcenter[0][0] + new_pt[0][0] + deltay = (-new_topcenter[1][0] + new_pt[1][0]) * -1 + if deltay == 0: + if deltax >= 0: + new_rot = -math.pi / 2 + else: + new_rot = math.pi / 2 + else: + tanx = deltax / deltay + new_rot = np.arctan(tanx) + if new_pt[0][0] > 0 and new_pt[0][0] < 1 and new_pt[1][0] > 0 \ + and new_pt[1][0] < 1: + new_cords.append([ + 1, + new_pt[0][0], + new_pt[1][0], + new_w, + new_h, + new_rot, + ]) + new_classes.append(classes[bbox_idx]) + return (new_classes, new_cords) diff --git a/tests/layers/test_layers_convolution.py b/tests/layers/test_layers_convolution.py index 0f5979d5b..890f2a330 100644 --- a/tests/layers/test_layers_convolution.py +++ b/tests/layers/test_layers_convolution.py @@ -462,6 +462,69 @@ def test_layer_n4(self): # self.assertEqual(self.net2.count_params(), 19392) # self.assertEqual(self.net2.outputs.get_shape().as_list()[1:], [299, 299, 64]) + +class Layer_OctConv_2D_Test(CustomTestCase): + + @classmethod + def setUpClass(cls): + print('\n#################################') + + cls.batch_size = 5 + cls.inputs_shape = [cls.batch_size, 32, 32, 16] + cls.input_layer = Input(cls.inputs_shape, name='input_layer') + + cls.n1 = tl.layers.OctConv2dIn(name='octconv2din')(cls.input_layer) + + cls.n2 = tl.layers.OctConv2d(32, 0.5, act=tf.nn.relu, name='octconv2d')(cls.n1) + + cls.n3 = tl.layers.OctConv2dHighOut(name='octconv2dho')(cls.n2) + + cls.n4 = tl.layers.OctConv2dLowOut(name='octconv2dlo')(cls.n2) + + cls.n5 = tl.layers.OctConv2dConcat(name='octconv2dconcat')([cls.n3, cls.n4]) + + cls.n6 = tl.layers.OctConv2dOut(n_filter=32, name='octconv2dout')(cls.n5) + + cls.model = Model(cls.input_layer, cls.n6) + print('Testing OctConv2d model: \n', cls.model) + + @classmethod + def tearDownClass(cls): + pass + + # tf.reset_default_graph() + + def test_layer_n1(self): + + self.assertEqual(self.n1[0].get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n1[1].get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(len(self.n1), 2) + + def test_layer_n2(self): + + self.assertEqual(self.n2[0].get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n2[1].get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(len(self.n2), 2) + + def test_layer_n3(self): + + self.assertEqual(self.n3.get_shape().as_list()[1:], [32, 32, 16]) + + def test_layer_n4(self): + + self.assertEqual(self.n4.get_shape().as_list()[1:], [16, 16, 16]) + + def test_layer_n5(self): + + self.assertEqual(self.n5[0].get_shape().as_list()[1:], [32, 32, 16]) + self.assertEqual(self.n5[1].get_shape().as_list()[1:], [16, 16, 16]) + self.assertEqual(len(self.n2), 2) + + def test_layer_n6(self): + + self.assertEqual(self.n6.get_shape().as_list()[1:], [16, 16, 32]) + + if __name__ == '__main__': tl.logging.set_verbosity(tl.logging.DEBUG)