Skip to content

Commit

Permalink
test quan_conv, doc conv convlayer 1d2d3d fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
warshallrho committed Mar 17, 2019
1 parent c2f2167 commit 5e78bf7
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 227 deletions.
23 changes: 7 additions & 16 deletions tensorlayer/layers/convolution/deformable_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# -*- coding: utf-8 -*-

import tensorflow as tf
import tensorlayer as tl

from tensorlayer.layers.core import Layer
# from tensorlayer.layers.core import LayersConfig

from tensorlayer import logging

Expand All @@ -22,8 +22,6 @@ class DeformableConv2d(Layer):
Parameters
----------
prev_layer : :class:`Layer`
Previous layer.
offset_layer : :class:`Layer`
To predict the offset of convolution operations.
The output shape is (batchsize, input height, input width, 2*(number of element in the convolution kernel))
Expand All @@ -38,10 +36,6 @@ class DeformableConv2d(Layer):
The initializer for the weight matrix.
b_init : initializer or None
The initializer for the bias vector. If None, skip biases.
W_init_args : dictionary
The arguments for the weight matrix initializer.
b_init_args : dictionary
The arguments for the bias vector initializer.
name : str
A unique layer name.
Expand All @@ -64,24 +58,21 @@ class DeformableConv2d(Layer):
"""

@deprecated_alias(layer='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
# @deprecated_alias(layer='prev_layer', end_support_version=1.9) # TODO remove this line for the 1.9 release
def __init__(
self,
prev_layer,
offset_layer=None,
# shape=(3, 3, 1, 100),
n_filter=32,
filter_size=(3, 3),
act=None,
W_init=tl.initializers.truncated_normal(stddev=0.02),
b_init=tl.initializers.constant(value=0.0),
name='deformable_conv_2d',
W_init=tf.compat.v1.initializers.truncated_normal(stddev=0.02),
b_init=tf.compat.v1.initializers.constant(value=0.0),
W_init_args=None,
b_init_args=None
):

super(DeformableConv2d, self
).__init__(prev_layer=prev_layer, act=act, W_init_args=W_init_args, b_init_args=b_init_args, name=name)
# super(DeformableConv2d, self
# ).__init__(prev_layer=prev_layer, act=act, W_init_args=W_init_args, b_init_args=b_init_args, name=name)
super().__init__(name)

logging.info(
"DeformableConv2d %s: n_filter: %d, filter_size: %s act: %s" %
Expand Down
95 changes: 55 additions & 40 deletions tensorlayer/layers/convolution/expert_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ class Conv1dLayer(Layer):
name : None or str
A unique layer name
Notes
-----
- shape = [w, the number of output channel of previous layer, the number of output channels]
- the number of output channel of a layer is its last dimension.
Examples
--------
With TensorLayer
>>> net = tl.layers.Input([8, 100, 1], name='input')
>>> conv1d = tl.layers.Conv1dLayer(shape=(5, 1, 32), stride=2, b_init=None, name='conv1d_1')
>>> print(conv1d)
>>> tensor = tl.layers.Conv1dLayer(shape=(5, 1, 32), stride=2, act=tf.nn.relu, name='conv1d_2')(net)
>>> print(tensor)
"""

def __init__(
Expand Down Expand Up @@ -71,6 +86,10 @@ def __init__(
self.in_channels = shape[-2]
self.name = name

if self.in_channels:
self.build(None)
self._built = True

logging.info(
"Conv1dLayer %s: shape: %s stride: %s pad: %s act: %s" % (
self.name, str(shape), str(stride), padding,
Expand Down Expand Up @@ -98,9 +117,7 @@ def build(self, inputs_shape):
)
if self.b_init:
self.b = self._get_weights(
"biases",
shape=(self.n_filter), #self.shape[-1]),
init=self.b_init
"biases", shape=(self.n_filter), init=self.b_init
)

def forward(self, inputs):
Expand All @@ -110,7 +127,7 @@ def forward(self, inputs):
filters=self.W,
stride=self.stride,
padding=self.padding,
dilations=(self.dilation_rate, ),
dilations=[self.dilation_rate, ],
data_format=self.data_format,
name=self.name,
)
Expand Down Expand Up @@ -139,7 +156,7 @@ class Conv2dLayer(Layer):
The padding algorithm type: "SAME" or "VALID".
data_format : str
"NHWC" or "NCHW", default is "NHWC".
dilation_rate : list of int
dilation_rate : tuple of int
Filter up-sampling/input down-sampling rate.
W_init : initializer
The initializer for the weight matrix.
Expand All @@ -157,30 +174,11 @@ class Conv2dLayer(Layer):
--------
With TensorLayer
>>> x = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))
>>> net = tl.layers.Input(x, name='input_layer')
>>> net = tl.layers.Conv2dLayer(net,
... act = tf.nn.relu,
... shape = (5, 5, 1, 32), # 32 features for each 5x5 patch
... strides = (1, 1, 1, 1),
... padding='SAME',
... W_init=tf.truncated_normal_initializer(stddev=5e-2),
... b_init = tf.constant_initializer(value=0.0),
... name ='cnn_layer1') # output: (?, 28, 28, 32)
>>> net = tl.layers.Pool(net,
... ksize=(1, 2, 2, 1),
... strides=(1, 2, 2, 1),
... padding='SAME',
... pool = tf.nn.max_pool,
... name ='pool_layer1',) # output: (?, 14, 14, 32)
Without TensorLayer, you can implement 2D convolution as follow.
>>> W = tf.Variable(W_init(shape=[5, 5, 1, 32], ), name='W_conv')
>>> b = tf.Variable(b_init(shape=[32], ), name='b_conv')
>>> outputs = tf.nn.relu( tf.nn.conv2d(inputs, W,
... strides=[1, 1, 1, 1],
... padding='SAME') + b )
>>> net = tl.layers.Input([8, 28, 28, 1], name='input')
>>> conv2d = tl.layers.Conv2dLayer(shape=(5, 5, 1, 32), strides=(1, 1, 1, 1), b_init=None, name='conv2d_1')
>>> print(conv2d)
>>> tensor = tl.layers.Conv2dLayer(shape=(5, 5, 1, 32), strides=(1, 1, 1, 1), act=tf.nn.relu, name='conv2d_2')(net)
>>> print(tensor)
"""

Expand All @@ -190,8 +188,8 @@ def __init__(
shape=(5, 5, 1, 100),
strides=(1, 1, 1, 1),
padding='SAME',
data_format=None,
dilation_rate=[1, 1, 1, 1],
data_format='NHWC',
dilation_rate=(1, 1, 1, 1),
W_init=tl.initializers.truncated_normal(stddev=0.02),
b_init=tl.initializers.constant(value=0.0),
name='cnn2d_layer',
Expand All @@ -210,6 +208,10 @@ def __init__(
self.in_channels = shape[-2]
self.name = name

if self.in_channels:
self.build(None)
self._built = True

logging.info(
"Conv2dLayer %s: shape: %s strides: %s pad: %s act: %s" % (
self.name, str(shape), str(strides), padding,
Expand Down Expand Up @@ -247,7 +249,7 @@ def forward(self, inputs):
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilations=self.dilation_rate,
dilations=list(self.dilation_rate),
name=self.name,
)

Expand Down Expand Up @@ -275,7 +277,7 @@ class Conv3dLayer(Layer):
The padding algorithm type: "SAME" or "VALID".
data_format : str
"NDHWC" or "NCDHW", default is "NDHWC".
dilation_rate : list of int
dilation_rate : tuple of int
Filter up-sampling/input down-sampling rate.
W_init : initializer
The initializer for the weight matrix.
Expand All @@ -284,12 +286,21 @@ class Conv3dLayer(Layer):
name : None or str
A unique layer name.
Notes
-----
- shape = [d, h, w, the number of output channel of previous layer, the number of output channels]
- the number of output channel of a layer is its last dimension.
Examples
---------
>>> x = tf.placeholder(tf.float32, (None, 100, 100, 100, 3))
>>> n = tl.layers.Input(x, name='in3')
>>> n = tl.layers.Conv3dLayer(n, shape=(2, 2, 2, 3, 32), strides=(1, 2, 2, 2, 1))
[None, 50, 50, 50, 32]
--------
With TensorLayer
>>> net = tl.layers.Input([8, 100, 100, 100, 3], name='input')
>>> conv2d = tl.layers.Conv2dLayer(shape=(2, 2, 2, 3, 32), strides=(1, 2, 2, 2, 1), b_init=None, name='conv3d_1')
>>> print(conv2d)
>>> tensor = tl.layers.Conv2dLayer(shape=(2, 2, 2, 3, 32), strides=(1, 2, 2, 2, 1), act=tf.nn.relu, name='conv3d_2')(net)
>>> print(tensor)
"""

def __init__(
Expand All @@ -299,7 +310,7 @@ def __init__(
strides=(1, 2, 2, 2, 1),
padding='SAME',
data_format='NDHWC',
dilation_rate=[1, 1, 1, 1, 1],
dilation_rate=(1, 1, 1, 1, 1),
W_init=tl.initializers.truncated_normal(stddev=0.02),
b_init=tl.initializers.constant(value=0.0),
name='cnn3d_layer'
Expand All @@ -318,6 +329,10 @@ def __init__(
self.in_channels = shape[-2]
self.name = name

if self.in_channels:
self.build(None)
self._built = True

logging.info(
"Conv3dLayer %s: shape: %s strides: %s pad: %s act: %s" % (
self.name, str(shape), str(strides), padding,
Expand Down Expand Up @@ -356,7 +371,7 @@ def forward(self, inputs):
strides=self.strides,
padding=self.padding,
data_format=self.data_format, #'NDHWC',
dilations=self.dilation_rate, #[1, 1, 1, 1, 1],
dilations=list(self.dilation_rate), #[1, 1, 1, 1, 1],
name=self.name,
)

Expand Down
14 changes: 7 additions & 7 deletions tensorlayer/layers/convolution/expert_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class DeConv1dLayer(Layer):
The padding algorithm type: "SAME" or "VALID".
data_format : str
"NWC" or "NCW", default is "NWC".
dilation_rate : list of int
dilation_rate : int
Filter up-sampling/input down-sampling rate.
W_init : initializer
The initializer for the weight matrix.
Expand Down Expand Up @@ -122,7 +122,7 @@ def forward(self, inputs):
strides=list(self.strides),
padding=self.padding,
data_format=self.data_format,
dilations=self.dilation_rate,
dilations=list(self.dilation_rate),
name=self.name,
)
if self.b_init:
Expand Down Expand Up @@ -152,7 +152,7 @@ class DeConv2dLayer(Layer):
The padding algorithm type: "SAME" or "VALID".
data_format : str
"NHHWC" or "NCW", default is "NHWC".
dilation_rate : list of int
dilation_rate : tuple of int
Filter up-sampling/input down-sampling rate.
W_init : initializer
The initializer for the weight matrix.
Expand Down Expand Up @@ -269,7 +269,7 @@ def forward(self, inputs):
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilations=self.dilation_rate,
dilations=list(self.dilation_rate),
name=self.name,
)
if self.b_init:
Expand Down Expand Up @@ -297,7 +297,7 @@ class DeConv3dLayer(Layer):
The padding algorithm type: "SAME" or "VALID".
data_format : str
"NDHWC" or "NCDHW", default is "NDHWC".
dilation_rate : list of int
dilation_rate : tuple of int
Filter up-sampling/input down-sampling rate.
W_init : initializer
The initializer for the weight matrix.
Expand All @@ -316,7 +316,7 @@ def __init__(
strides=(1, 2, 2, 2, 1),
padding='SAME',
data_format='NDHWC',
dilation_rate=[1, 1, 1, 1, 1],
dilation_rate=(1, 1, 1, 1, 1),
W_init=tl.initializers.truncated_normal(stddev=0.02),
b_init=tl.initializers.constant(value=0.0),
name='decnn3d_layer',
Expand Down Expand Up @@ -370,7 +370,7 @@ def forward(self, inputs):
strides=self.strides,
padding=self.padding,
data_format=self.data_format,
dilations=self.dilation_rate,
dilations=list(self.dilation_rate),
name=self.name
)
if self.b_init:
Expand Down

0 comments on commit 5e78bf7

Please sign in to comment.