From 427b5508b368d4a19faa810ea4aaffdbbed4179f Mon Sep 17 00:00:00 2001
From: lllcho
Date: Thu, 15 Mar 2018 22:14:53 +0800
Subject: [PATCH 1/5] update by lllcho on March 15

---
 tensorlayer/_logging.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tensorlayer/_logging.py b/tensorlayer/_logging.py
index 8e386ad21..65dc90faf 100644
--- a/tensorlayer/_logging.py
+++ b/tensorlayer/_logging.py
@@ -1,6 +1,12 @@
-import logging
+import logging as _logger
 
-logging.basicConfig(level=logging.INFO, format='[TL] %(message)s')
+# logging.basicConfig(level=logging.INFO, format='[TL] %(message)s')
+logging = _logger.getLogger('tensorlayer')
+logging.setLevel(_logger.INFO)
+_handler = _logger.StreamHandler()
+formatter = _logger.Formatter('[TL] %(message)s')
+_handler.setFormatter(formatter)
+logging.addHandler(_handler)
 
 
 def info(fmt, *args):

From 0d7ac7c8dfd42ebf5c11a73725a6a0f12b6fffbb Mon Sep 17 00:00:00 2001
From: lllcho
Date: Thu, 15 Mar 2018 22:28:51 +0800
Subject: [PATCH 2/5] update logging

---
 tensorlayer/_logging.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorlayer/_logging.py b/tensorlayer/_logging.py
index 65dc90faf..be075bb47 100644
--- a/tensorlayer/_logging.py
+++ b/tensorlayer/_logging.py
@@ -1,6 +1,5 @@
 import logging as _logger
 
-# logging.basicConfig(level=logging.INFO, format='[TL] %(message)s')
 logging = _logger.getLogger('tensorlayer')
 logging.setLevel(_logger.INFO)
 _handler = _logger.StreamHandler()

From 7542a4337ed14a2d835a4ffbc486f1db908df012 Mon Sep 17 00:00:00 2001
From: lllcho
Date: Fri, 16 Mar 2018 21:37:38 +0800
Subject: [PATCH 3/5] b_init in c3d can be None and gamma/beta in BN layer can be skipped

---
 tensorlayer/layers/convolution.py   | 33 +++++++++++++++--------
 tensorlayer/layers/normalization.py | 42 +++++++++++++++++------------
 2 files changed, 47 insertions(+), 28 deletions(-)

diff --git a/tensorlayer/layers/convolution.py b/tensorlayer/layers/convolution.py
index f19cd62e6..02321f4af 100644
--- a/tensorlayer/layers/convolution.py
+++ b/tensorlayer/layers/convolution.py
@@ -358,8 +358,8 @@ class Conv3dLayer(Layer):
         The padding algorithm type: "SAME" or "VALID".
     W_init : initializer
         The initializer for the weight matrix.
-    b_init : initializer
-        The initializer for the bias vector.
+    b_init : initializer or None
+        The initializer for the the bias vector. If None, skip biases.
     W_init_args : dictionary
         The arguments for the weight matrix initializer.
     b_init_args : dictionary
@@ -403,8 +403,11 @@ def __init__(
             # W = tf.Variable(W_init(shape=shape, **W_init_args), name='W_conv')
             # b = tf.Variable(b_init(shape=[shape[-1]], **b_init_args), name='b_conv')
             W = tf.get_variable(name='W_conv3d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args)
-            b = tf.get_variable(name='b_conv3d', shape=(shape[-1]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args)
-            self.outputs = act(tf.nn.conv3d(self.inputs, W, strides=strides, padding=padding, name=None) + b)
+            if b_init:
+                b = tf.get_variable(name='b_conv3d', shape=(shape[-1]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args)
+                self.outputs = act(tf.nn.conv3d(self.inputs, W, strides=strides, padding=padding, name=None) + b)
+            else:
+                self.outputs = act(tf.nn.conv3d(self.inputs, W, strides=strides, padding=padding, name=None))
 
         # self.outputs = act( tf.nn.conv3d(self.inputs, W, strides=strides, padding=padding, name=None) + b )
 
@@ -412,7 +415,10 @@ def __init__(
         # self.all_layers = list(layer.all_layers)
        # self.all_params = list(layer.all_params)
         # self.all_drop = dict(layer.all_drop)
         self.all_layers.append(self.outputs)
-        self.all_params.extend([W, b])
+        if b_init:
+            self.all_params.extend([W, b])
+        else:
+            self.all_params.extend([W])
 
 
 class DeConv3dLayer(Layer):
@@ -435,8 +441,8 @@ class DeConv3dLayer(Layer):
         The padding algorithm type: "SAME" or "VALID".
     W_init : initializer
         The initializer for the weight matrix.
-    b_init : initializer
-        The initializer for the bias vector.
+    b_init : initializer or None
+        The initializer for the the bias vector. If None, skip biases.
     W_init_args : dictionary
         The arguments for the weight matrix initializer.
     b_init_args : dictionary
@@ -474,15 +480,20 @@ def __init__(
 
         with tf.variable_scope(name):
             W = tf.get_variable(name='W_deconv3d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args)
-            b = tf.get_variable(name='b_deconv3d', shape=(shape[-2]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args)
-
-            self.outputs = act(tf.nn.conv3d_transpose(self.inputs, W, output_shape=output_shape, strides=strides, padding=padding) + b)
+            if b_init:
+                b = tf.get_variable(name='b_deconv3d', shape=(shape[-2]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args)
+                self.outputs = act(tf.nn.conv3d_transpose(self.inputs, W, output_shape=output_shape, strides=strides, padding=padding) + b)
+            else:
+                self.outputs = act(tf.nn.conv3d_transpose(self.inputs, W, output_shape=output_shape, strides=strides, padding=padding))
 
         # self.all_layers = list(layer.all_layers)
         # self.all_params = list(layer.all_params)
         # self.all_drop = dict(layer.all_drop)
         self.all_layers.append(self.outputs)
-        self.all_params.extend([W, b])
+        if b_init:
+            self.all_params.extend([W, b])
+        else:
+            self.all_params.extend([W])
 
 
 class UpSampling2dLayer(Layer):
diff --git a/tensorlayer/layers/normalization.py b/tensorlayer/layers/normalization.py
index 6d5f02814..aae953e39 100644
--- a/tensorlayer/layers/normalization.py
+++ b/tensorlayer/layers/normalization.py
@@ -75,10 +75,10 @@ class BatchNormLayer(Layer):
         The activation function of this layer.
     is_train : boolean
         Is being used for training or inference.
-    beta_init : initializer
-        The initializer for initializing beta.
-    gamma_init : initializer
-        The initializer for initializing gamma.
+    beta_init : initializer or None
+        The initializer for initializing beta, if None, skip beta
+    gamma_init : initializer or None
+        The initializer for initializing gamma, if None, skip gamma
     dtype : TensorFlow dtype
         tf.float32 (default) or tf.float16.
     name : str
@@ -112,19 +112,27 @@ def __init__(
 
         with tf.variable_scope(name):
             axis = list(range(len(x_shape) - 1))
-
-            # 1. beta, gamma
-            if tf.__version__ > '0.12.1' and beta_init == tf.zeros_initializer:
-                beta_init = beta_init()
-            beta = tf.get_variable('beta', shape=params_shape, initializer=beta_init, dtype=LayersConfig.tf_dtype, trainable=is_train)
-
-            gamma = tf.get_variable(
-                'gamma',
-                shape=params_shape,
-                initializer=gamma_init,
-                dtype=LayersConfig.tf_dtype,
-                trainable=is_train,
-            )
+            variables = []
+            if beta_init:
+                if tf.__version__ > '0.12.1' and beta_init == tf.zeros_initializer:
+                    beta_init = beta_init()
+                beta = tf.get_variable('beta', shape=params_shape, initializer=beta_init, dtype=LayersConfig.tf_dtype, trainable=is_train)
+                variables.append(beta)
+            else:
+                beta = None
+
+            if gamma_init:
+                gamma = tf.get_variable(
+                    'gamma',
+                    shape=params_shape,
+                    initializer=gamma_init,
+                    dtype=LayersConfig.tf_dtype,
+                    trainable=is_train,
+                )
+                variables.append(gamma)
+            else:
+                gamma = None
 
             # 2.
             if tf.__version__ > '0.12.1':
@@ -163,7 +171,7 @@ def mean_var_with_update():
             else:
                 self.outputs = act(tf.nn.batch_normalization(self.inputs, moving_mean, moving_variance, beta, gamma, epsilon))
 
-            variables = [beta, gamma, moving_mean, moving_variance]
+            variables.extend([moving_mean, moving_variance])
 
             # logging.info(len(variables))
             # for idx, v in enumerate(variables):

From e54ab3dbb9f7f9bf6d7e1acc5ef0056fa5066d55 Mon Sep 17 00:00:00 2001
From: lllcho
Date: Fri, 16 Mar 2018 21:51:12 +0800
Subject: [PATCH 4/5] fix some comments

---
 tensorlayer/layers/convolution.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorlayer/layers/convolution.py b/tensorlayer/layers/convolution.py
index 02321f4af..1cc92d048 100644
--- a/tensorlayer/layers/convolution.py
+++ b/tensorlayer/layers/convolution.py
@@ -123,9 +123,9 @@ class Conv2dLayer(Layer):
     padding : str
         The padding algorithm type: "SAME" or "VALID".
     W_init : initializer
-        The initializer for the the weight matrix.
+        The initializer for the weight matrix.
     b_init : initializer or None
-        The initializer for the the bias vector. If None, skip biases.
+        The initializer for the bias vector. If None, skip biases.
     W_init_args : dictionary
         The arguments for the weight matrix initializer.
     b_init_args : dictionary
@@ -359,7 +359,7 @@ class Conv3dLayer(Layer):
     W_init : initializer
         The initializer for the weight matrix.
     b_init : initializer or None
-        The initializer for the the bias vector. If None, skip biases.
+        The initializer for the bias vector. If None, skip biases.
     W_init_args : dictionary
         The arguments for the weight matrix initializer.
     b_init_args : dictionary
@@ -442,7 +442,7 @@ class DeConv3dLayer(Layer):
     W_init : initializer
         The initializer for the weight matrix.
     b_init : initializer or None
-        The initializer for the the bias vector. If None, skip biases.
+        The initializer for the bias vector. If None, skip biases.
     W_init_args : dictionary
         The arguments for the weight matrix initializer.
     b_init_args : dictionary

From 8afebe584d5ee41dffc315f218541a828976c3f1 Mon Sep 17 00:00:00 2001
From: lllcho
Date: Sat, 17 Mar 2018 09:44:41 +0800
Subject: [PATCH 5/5] add comments in bn layer

---
 tensorlayer/layers/normalization.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tensorlayer/layers/normalization.py b/tensorlayer/layers/normalization.py
index aae953e39..42b1dcf86 100644
--- a/tensorlayer/layers/normalization.py
+++ b/tensorlayer/layers/normalization.py
@@ -76,9 +76,12 @@ class BatchNormLayer(Layer):
     is_train : boolean
         Is being used for training or inference.
     beta_init : initializer or None
-        The initializer for initializing beta, if None, skip beta
+        The initializer for initializing beta, if None, skip beta.
+        Usually you should not skip beta unless you know what you are doing.
     gamma_init : initializer or None
-        The initializer for initializing gamma, if None, skip gamma
+        The initializer for initializing gamma, if None, skip gamma.
+        When the batch normalization layer is used instead of 'biases', or the next layer is linear, this can be
+        disabled since the scaling can be done by the next layer. See `Inception-ResNet-v2 `__
     dtype : TensorFlow dtype
         tf.float32 (default) or tf.float16.
     name : str
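
For reference, the named-logger pattern that patches 1-2 introduce, as a minimal standalone sketch. The logger and handler names mirror the diff; the demo call at the end is illustrative and not part of the patch:

    import logging as _logger

    # Use a library-scoped logger instead of logging.basicConfig, so the
    # library does not reconfigure the root logger of user applications.
    logging = _logger.getLogger('tensorlayer')
    logging.setLevel(_logger.INFO)
    _handler = _logger.StreamHandler()
    _handler.setFormatter(_logger.Formatter('[TL] %(message)s'))
    logging.addHandler(_handler)

    logging.info('model built')  # prints: [TL] model built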
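A hypothetical usage sketch for the b_init=None behaviour added in patch 3, assuming the TensorLayer 1.x graph API; the placeholder shape and layer names are made up for illustration:

    import tensorflow as tf
    import tensorlayer as tl

    x = tf.placeholder(tf.float32, [None, 16, 112, 112, 3])
    net = tl.layers.InputLayer(x, name='input')
    # With this patch, b_init=None skips the bias entirely: no 'b_conv3d'
    # variable is created and only W is appended to net.all_params.
    net = tl.layers.Conv3dLayer(
        net,
        act=tf.nn.relu,
        shape=(3, 3, 3, 3, 32),   # (depth, height, width, in_channels, out_channels)
        strides=(1, 2, 2, 2, 1),
        padding='SAME',
        b_init=None,
        name='conv3d_no_bias')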
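Similarly, a sketch of skipping gamma in BatchNormLayer per the patch-5 docstring: when the batch norm replaces biases and the next layer is linear, that layer can absorb the scaling. Continuing the hypothetical graph above, again with assumed TL 1.x layer names:

    # Bias-free convolution, then batch norm without the gamma scale; the
    # following dense (linear) layer can learn the scaling itself.
    net = tl.layers.Conv3dLayer(net, shape=(3, 3, 3, 32, 64), strides=(1, 1, 1, 1, 1),
                                padding='SAME', b_init=None, name='conv3d2')
    net = tl.layers.BatchNormLayer(net, act=tf.nn.relu, is_train=True,
                                   gamma_init=None,  # skip the scale parameter
                                   name='bn_no_gamma')
    net = tl.layers.FlattenLayer(net, name='flatten')
    net = tl.layers.DenseLayer(net, n_units=10, name='dense_linear')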