From 1505b3e325795334f7718222f45131376b163439 Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 17:27:07 +0000 Subject: [PATCH 01/13] add BinaryDenseLayer SignLayer etc --- tensorlayer/layers/__init__.py | 1 + tensorlayer/layers/binary.py | 144 +++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 tensorlayer/layers/binary.py diff --git a/tensorlayer/layers/__init__.py b/tensorlayer/layers/__init__.py index cad53aab8..66b0d9736 100644 --- a/tensorlayer/layers/__init__.py +++ b/tensorlayer/layers/__init__.py @@ -9,6 +9,7 @@ from .core import * from .convolution import * +from .binary import * from .super_resolution import * from .normalization import * from .spatial_transformer import * diff --git a/tensorlayer/layers/binary.py b/tensorlayer/layers/binary.py new file mode 100644 index 000000000..a1dff89f0 --- /dev/null +++ b/tensorlayer/layers/binary.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- + +from .core import * +from .. import _logging as logging +import tensorflow as tf +import tensorlayer as tl + +__all__ = [ + 'BinaryDenseLayer', + 'SignLayer', + 'MultiplyScaleLayer', +] + + +@tf.RegisterGradient("TL_Sign_QuantizeGrad") +def quantize_grad(op, grad): + return tf.clip_by_value(tf.identity(grad), -1, 1) + +def quantize(x): + with tf.get_default_graph().gradient_override_map({"Sign": "TL_Sign_QuantizeGrad"}): + return tf.sign(x) + +class BinaryDenseLayer(Layer): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70 + """The :class:`BinaryDenseLayer` class is a binary fully connected layer, which weights are either -1 or 1 while inferencing. + + Parameters + ---------- + layer : :class:`Layer` + Previous layer. + n_units : int + The number of units of this layer. + act : activation function + The activation function of this layer, usually set to ``tf.act.sign`` or apply :class:`SignLayer` after :class:`BatchNormLayer`. + use_gemm : boolean + If True, use gemm instead of ``tf.matmul`` for inferencing. (TODO). + W_init : initializer + The initializer for the weight matrix. + b_init : initializer or None + The initializer for the bias vector. If None, skip biases. + W_init_args : dictionary + The arguments for the weight matrix initializer. + b_init_args : dictionary + The arguments for the bias vector initializer. + name : a str + A unique layer name. + + """ + + def __init__( + self, + prev_layer, + n_units=100, + act=tf.identity, + use_gemm=False, + W_init=tf.truncated_normal_initializer(stddev=0.1), + W_init_args=None, + name='binary_dense', + ): + if W_init_args is None: + W_init_args = {} + + Layer.__init__(self, prev_layer=prev_layer, name=name) + self.inputs = prev_layer.outputs + if self.inputs.get_shape().ndims != 2: + raise Exception("The input dimension must be rank 2, please reshape or flatten it") + + if use_gemm: + raise Exception("TODO. 
The current version use tf.matmul for inferencing.") + + n_in = int(self.inputs.get_shape()[-1]) + self.n_units = n_units + logging.info("BinaryDenseLayer %s: %d %s" % (self.name, self.n_units, act.__name__)) + with tf.variable_scope(name): + W = tf.get_variable(name='W', shape=(n_in, n_units), initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args) + # W = tl.act.sign(W) + W = quantize(W) + # W = tf.Variable(W) + print(W) + self.outputs = act(tf.matmul(self.inputs, W)) + # self.outputs = act(xnor_gemm(self.inputs, W)) # TODO + + self.all_layers.append(self.outputs) + self.all_params.append(W) + +class SignLayer(Layer): + """The :class:`SignLayer` class is for quantizing the layer outputs to -1 or 1 while inferencing. + + Parameters + ---------- + layer : :class:`Layer` + Previous layer. + name : a str + A unique layer name. + + """ + + def __init__( + self, + prev_layer, + name='sign', + ): + + Layer.__init__(self, prev_layer=prev_layer, name=name) + self.inputs = prev_layer.outputs + + logging.info("SignLayer %s" % (self.name)) + with tf.variable_scope(name): + # self.outputs = tl.act.sign(self.inputs) + self.outputs = quantize(self.inputs) + + self.all_layers.append(self.outputs) + +class MultiplyScaleLayer(Layer): + """The :class:`AddScaleLayer` class is for multipling a trainble scale value to the layer outputs. Usually be used on the output of binary net. + + Parameters + ---------- + layer : :class:`Layer` + Previous layer. + init_scale : float + The initial value for the scale factor. + name : a str + A unique layer name. + + """ + + def __init__( + self, + prev_layer, + init_scale=0.05, + name='scale', + ): + + Layer.__init__(self, prev_layer=prev_layer, name=name) + self.inputs = prev_layer.outputs + + logging.info("MultiplyScaleLayer %s: init_scale: %f" % (self.name, init_scale)) + with tf.variable_scope(name): + # scale = tf.get_variable(name='scale_factor', init, trainable=True, ) + scale = tf.get_variable("scale", shape=[1], initializer=tf.constant_initializer(value=init_scale)) + self.outputs = self.inputs * scale + + self.all_layers.append(self.outputs) + self.all_params.append(scale) From 3737b2368c3ae398dbd334dbff4da37eadbfb138 Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 17:46:36 +0000 Subject: [PATCH 02/13] add example of binarynet cnn | add BinaryConv2d --- example/tutorial_binarynet_mnist_cnn.py | 105 +++++++++++++++++++ tensorlayer/layers/binary.py | 132 +++++++++++++++++++++++- 2 files changed, 233 insertions(+), 4 deletions(-) create mode 100644 example/tutorial_binarynet_mnist_cnn.py diff --git a/example/tutorial_binarynet_mnist_cnn.py b/example/tutorial_binarynet_mnist_cnn.py new file mode 100644 index 000000000..0d27758f5 --- /dev/null +++ b/example/tutorial_binarynet_mnist_cnn.py @@ -0,0 +1,105 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import time +import tensorflow as tf +import tensorlayer as tl + +X_train, y_train, X_val, y_val, X_test, y_test = \ + tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) + +sess = tf.InteractiveSession() + +batch_size = 128 + +x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1]) # [batch_size, height, width, channels] +y_ = tf.placeholder(tf.int64, shape=[batch_size]) + + +def mlp(x, is_train=True, reuse=False): + with tf.variable_scope("binarynet", reuse=reuse): + net = tl.layers.InputLayer(x, name='input') + net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', name='bcnn1') + # drop + net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bn1') + net = tl.layers.SignLayer(net, name='sign2') + net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', name='bcnn2') + # drop + net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bn2') + net = tl.layers.SignLayer(net, name='sign2') + net = tl.layers.FlattenLayer(net, name='flatten') + net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop1') + # net = tl.layers.DenseLayer(net, 256, act=tf.nn.relu, name='dense') + net = tl.layers.BinaryDenseLayer(net, 256, name='dense') + net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop2') + # net = tl.layers.DenseLayer(net, 10, act=tf.identity, name='output') + net = tl.layers.BinaryDenseLayer(net, 10, name='bout') + # net = tl.layers.MultiplyScaleLayer(net, name='scale') + return net + + +# define inferences +net_train = mlp(x, is_train=True, reuse=False) +net_test = mlp(x, is_train=False, reuse=True) + +# cost for training +y = net_train.outputs +cost = tl.cost.cross_entropy(y, y_, name='xentropy') + +# cost and accuracy for evalution +y2 = net_test.outputs +cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2') +correct_prediction = tf.equal(tf.argmax(y2, 1), y_) +acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) + +# define the optimizer +train_params = tl.layers.get_variables_with_name('binarynet', True, True) +train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params) + +# initialize all variables in the session +tl.layers.initialize_global_variables(sess) + +net_train.print_params() +net_train.print_layers() + +n_epoch = 200 +print_freq = 5 + +v = tl.layers.get_quantize_sign_params(sess, net_test.all_params) +print(v) + +for epoch in range(n_epoch): + start_time = time.time() + for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True): + sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a}) + + if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: + print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) + train_loss, train_acc, n_batch = 0, 0, 0 + for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True): + err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a}) + train_loss += err + train_acc += ac + n_batch += 1 + print(" train loss: %f" % (train_loss / n_batch)) + print(" train acc: %f" % (train_acc / n_batch)) + val_loss, val_acc, n_batch = 0, 0, 0 + for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True): + err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a}) + val_loss += err + val_acc += ac + n_batch += 1 + print(" val loss: %f" % (val_loss / n_batch)) + print(" val acc: %f" % (val_acc / n_batch)) + + net_train.print_params() + 
+print('Evaluation') +test_loss, test_acc, n_batch = 0, 0, 0 +for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True): + err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a}) + test_loss += err + test_acc += ac + n_batch += 1 +print(" test loss: %f" % (test_loss / n_batch)) +print(" test acc: %f" % (test_acc / n_batch)) diff --git a/tensorlayer/layers/binary.py b/tensorlayer/layers/binary.py index a1dff89f0..761cc1f5c 100644 --- a/tensorlayer/layers/binary.py +++ b/tensorlayer/layers/binary.py @@ -1,26 +1,56 @@ # -*- coding: utf-8 -*- - +import numpy as np from .core import * from .. import _logging as logging import tensorflow as tf import tensorlayer as tl __all__ = [ + 'get_quantize_sign_params', 'BinaryDenseLayer', 'SignLayer', 'MultiplyScaleLayer', + 'BinaryConv2d', ] +def get_quantize_sign_params(sess, params_list): + """Quantize the parameters into -1 and 1 using ``np.sign``. + + Parameters + ---------- + params_list : list of tensor + A list of parameters (tensor). + sess : None or Session + Session may be required in some case. + + Returns + -------- + list of array + The parameters in a list. + + Examples + --------- + >>> tl.layers.initialize_global_variables(sess) + >>> v = tl.layers.get_quantize_sign_params(sess, net_test.all_params) + """ + params_list_var = sess.run(params_list) + for i, p in enumerate(params_list_var): + params_list_var[i] = np.sign(p).astype(int) + return params_list_var + + @tf.RegisterGradient("TL_Sign_QuantizeGrad") def quantize_grad(op, grad): return tf.clip_by_value(tf.identity(grad), -1, 1) + def quantize(x): with tf.get_default_graph().gradient_override_map({"Sign": "TL_Sign_QuantizeGrad"}): return tf.sign(x) -class BinaryDenseLayer(Layer): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70 + +class BinaryDenseLayer(Layer): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70 """The :class:`BinaryDenseLayer` class is a binary fully connected layer, which weights are either -1 or 1 while inferencing. Parameters @@ -72,16 +102,109 @@ def __init__( logging.info("BinaryDenseLayer %s: %d %s" % (self.name, self.n_units, act.__name__)) with tf.variable_scope(name): W = tf.get_variable(name='W', shape=(n_in, n_units), initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args) - # W = tl.act.sign(W) + # W = tl.act.sign(W) # dont update ... W = quantize(W) # W = tf.Variable(W) - print(W) + # print(W) self.outputs = act(tf.matmul(self.inputs, W)) # self.outputs = act(xnor_gemm(self.inputs, W)) # TODO self.all_layers.append(self.outputs) self.all_params.append(W) + +class BinaryConv2d(Layer): + """ + The :class:`BinaryConv2d` class is a 2D binary CNN layer, which weights are either -1 or 1 while inferencing. + + Parameters + ---------- + layer : :class:`Layer` + Previous layer. + n_filter : int + The number of filters. + filter_size : tuple of int + The filter size (height, width). + strides : tuple of int + The sliding window strides of corresponding input dimensions. + It must be in the same order as the ``shape`` parameter. + act : activation function + The activation function of this layer. + padding : str + The padding algorithm type: "SAME" or "VALID". + use_gemm : boolean + If True, use gemm instead of ``tf.matmul`` for inferencing. (TODO). + W_init : initializer + The initializer for the the weight matrix. + W_init_args : dictionary + The arguments for the weight matrix initializer. + use_cudnn_on_gpu : bool + Default is False. 
+ data_format : str + "NHWC" or "NCHW", default is "NHWC". + name : str + A unique layer name. + + """ + + def __init__( + self, + prev_layer, + n_filter=32, + filter_size=(3, 3), + strides=(1, 1), + act=tf.identity, + padding='SAME', + use_gemm=False, + W_init=tf.truncated_normal_initializer(stddev=0.02), + # b_init=tf.constant_initializer(value=0.0), + W_init_args=None, + # b_init_args=None, + use_cudnn_on_gpu=None, + data_format=None, + # act=tf.identity, + # shape=(5, 5, 1, 100), + # strides=(1, 1, 1, 1), + # padding='SAME', + # W_init=tf.truncated_normal_initializer(stddev=0.02), + # b_init=tf.constant_initializer(value=0.0), + # W_init_args=None, + # b_init_args=None, + # use_cudnn_on_gpu=None, + # data_format=None, + name='binary_cnn2d', + ): + if W_init_args is None: + W_init_args = {} + + if use_gemm: + raise Exception("TODO. The current version use tf.matmul for inferencing.") + + Layer.__init__(self, prev_layer=prev_layer, name=name) + self.inputs = prev_layer.outputs + if act is None: + act = tf.identity + logging.info("BinaryConv2d %s: n_filter:%d filter_size:%s strides:%s pad:%s act:%s" % (self.name, n_filter, str(filter_size), str(strides), padding, + act.__name__)) + + if len(strides) != 2: + raise ValueError("len(strides) should be 2.") + try: + pre_channel = int(prev_layer.outputs.get_shape()[-1]) + except Exception: # if pre_channel is ?, it happens when using Spatial Transformer Net + pre_channel = 1 + logging.info("[warnings] unknow input channels, set to 1") + shape = (filter_size[0], filter_size[1], pre_channel, n_filter) + strides = (1, strides[0], strides[1], 1) + with tf.variable_scope(name): + W = tf.get_variable(name='W_conv2d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args) + W = quantize(W) + self.outputs = act(tf.nn.conv2d(self.inputs, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format)) + + self.all_layers.append(self.outputs) + self.all_params.append(W) + + class SignLayer(Layer): """The :class:`SignLayer` class is for quantizing the layer outputs to -1 or 1 while inferencing. @@ -110,6 +233,7 @@ def __init__( self.all_layers.append(self.outputs) + class MultiplyScaleLayer(Layer): """The :class:`AddScaleLayer` class is for multipling a trainble scale value to the layer outputs. Usually be used on the output of binary net. 
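
Both BinaryDenseLayer and the new BinaryConv2d reuse the quantize() wrapper introduced in patch 01: the variables keep full-precision values, tf.sign binarizes them in the forward pass, and the registered gradient passes the incoming gradient straight through, clipped to [-1, 1] (the straight-through estimator, STE). Without the override, tf.sign would have a zero gradient everywhere and the real-valued weights could never be updated. A minimal standalone sketch of the pattern, assuming TensorFlow 1.x; the gradient name STE_QuantizeGrad and the toy values below are chosen only for this illustration and are not part of the patches:

import tensorflow as tf

@tf.RegisterGradient("STE_QuantizeGrad")  # name chosen for this demo only
def _ste_grad(op, grad):
    # Backward pass of Sign: pass the incoming gradient through, clipped to [-1, 1].
    return tf.clip_by_value(tf.identity(grad), -1, 1)

def quantize(x):
    # Override the gradient of the Sign op created by tf.sign(x).
    with tf.get_default_graph().gradient_override_map({"Sign": "STE_QuantizeGrad"}):
        return tf.sign(x)

w = tf.Variable([0.3, -0.7, 2.5], dtype=tf.float32)        # full-precision weights
w_bin = quantize(w)                                         # -1/1 values used in the forward pass
loss = tf.reduce_sum(w_bin * tf.constant([1.0, 2.0, 3.0]))
grad_w = tf.gradients(loss, w)[0]                           # non-zero thanks to the STE

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(w_bin))    # [ 1. -1.  1.]
    print(sess.run(grad_w))   # [ 1.  1.  1.]  (clipped identity), not zeros

The optimizer therefore updates the real-valued variables, while every forward pass sees only -1/1 weights.
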
From ab000290cdfaf19fd7e60458a6e94d436bada08c Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 17:49:53 +0000 Subject: [PATCH 03/13] rename scale layer\ --- example/tutorial_binarynet_mnist_cnn.py | 2 +- tensorlayer/layers/binary.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/example/tutorial_binarynet_mnist_cnn.py b/example/tutorial_binarynet_mnist_cnn.py index 0d27758f5..26605067b 100644 --- a/example/tutorial_binarynet_mnist_cnn.py +++ b/example/tutorial_binarynet_mnist_cnn.py @@ -34,7 +34,7 @@ def mlp(x, is_train=True, reuse=False): net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop2') # net = tl.layers.DenseLayer(net, 10, act=tf.identity, name='output') net = tl.layers.BinaryDenseLayer(net, 10, name='bout') - # net = tl.layers.MultiplyScaleLayer(net, name='scale') + # net = tl.layers.ScaleLayer(net, name='scale') return net diff --git a/tensorlayer/layers/binary.py b/tensorlayer/layers/binary.py index 761cc1f5c..6a362316d 100644 --- a/tensorlayer/layers/binary.py +++ b/tensorlayer/layers/binary.py @@ -9,7 +9,7 @@ 'get_quantize_sign_params', 'BinaryDenseLayer', 'SignLayer', - 'MultiplyScaleLayer', + 'ScaleLayer', 'BinaryConv2d', ] @@ -234,7 +234,7 @@ def __init__( self.all_layers.append(self.outputs) -class MultiplyScaleLayer(Layer): +class ScaleLayer(Layer): """The :class:`AddScaleLayer` class is for multipling a trainble scale value to the layer outputs. Usually be used on the output of binary net. Parameters @@ -258,7 +258,7 @@ def __init__( Layer.__init__(self, prev_layer=prev_layer, name=name) self.inputs = prev_layer.outputs - logging.info("MultiplyScaleLayer %s: init_scale: %f" % (self.name, init_scale)) + logging.info("ScaleLayer %s: init_scale: %f" % (self.name, init_scale)) with tf.variable_scope(name): # scale = tf.get_variable(name='scale_factor', init, trainable=True, ) scale = tf.get_variable("scale", shape=[1], initializer=tf.constant_initializer(value=init_scale)) From 6c94f45bd646b102762684e2d631046708f175da Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 17:54:02 +0000 Subject: [PATCH 04/13] remove unused code --- tensorlayer/layers/binary.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorlayer/layers/binary.py b/tensorlayer/layers/binary.py index 6a362316d..c47f6653a 100644 --- a/tensorlayer/layers/binary.py +++ b/tensorlayer/layers/binary.py @@ -3,7 +3,6 @@ from .core import * from .. 
import _logging as logging import tensorflow as tf -import tensorlayer as tl __all__ = [ 'get_quantize_sign_params', From e150ebf8afe12335cfbceacdb93fd984353eedff Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 17:57:02 +0000 Subject: [PATCH 05/13] remove print params --- example/tutorial_binarynet_mnist_cnn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/example/tutorial_binarynet_mnist_cnn.py b/example/tutorial_binarynet_mnist_cnn.py index 26605067b..35057f685 100644 --- a/example/tutorial_binarynet_mnist_cnn.py +++ b/example/tutorial_binarynet_mnist_cnn.py @@ -92,8 +92,6 @@ def mlp(x, is_train=True, reuse=False): print(" val loss: %f" % (val_loss / n_batch)) print(" val acc: %f" % (val_acc / n_batch)) - net_train.print_params() - print('Evaluation') test_loss, test_acc, n_batch = 0, 0, 0 for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True): From 779d1968d9b3d63e688e8f7ddd6e5c50dd91bb5d Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 20:39:56 +0000 Subject: [PATCH 06/13] rename function name in binarynet example --- example/tutorial_binarynet_mnist_cnn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/tutorial_binarynet_mnist_cnn.py b/example/tutorial_binarynet_mnist_cnn.py index 35057f685..43f2710de 100644 --- a/example/tutorial_binarynet_mnist_cnn.py +++ b/example/tutorial_binarynet_mnist_cnn.py @@ -16,7 +16,7 @@ y_ = tf.placeholder(tf.int64, shape=[batch_size]) -def mlp(x, is_train=True, reuse=False): +def model(x, is_train=True, reuse=False): with tf.variable_scope("binarynet", reuse=reuse): net = tl.layers.InputLayer(x, name='input') net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', name='bcnn1') @@ -39,8 +39,8 @@ def mlp(x, is_train=True, reuse=False): # define inferences -net_train = mlp(x, is_train=True, reuse=False) -net_test = mlp(x, is_train=False, reuse=True) +net_train = model(x, is_train=True, reuse=False) +net_test = model(x, is_train=False, reuse=True) # cost for training y = net_train.outputs From 13d3ecafa25efe9dcf7b6fa48db45bd0b6dfbbfe Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 22:04:01 +0000 Subject: [PATCH 07/13] update all --- example/tutorial_binarynet_mnist_cnn.py | 14 ++--- tensorlayer/layers/binary.py | 76 +++++++++++++------------ 2 files changed, 48 insertions(+), 42 deletions(-) diff --git a/example/tutorial_binarynet_mnist_cnn.py b/example/tutorial_binarynet_mnist_cnn.py index 43f2710de..326c0aa31 100644 --- a/example/tutorial_binarynet_mnist_cnn.py +++ b/example/tutorial_binarynet_mnist_cnn.py @@ -12,7 +12,7 @@ batch_size = 128 -x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1]) # [batch_size, height, width, channels] +x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1]) y_ = tf.placeholder(tf.int64, shape=[batch_size]) @@ -20,12 +20,13 @@ def model(x, is_train=True, reuse=False): with tf.variable_scope("binarynet", reuse=reuse): net = tl.layers.InputLayer(x, name='input') net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', name='bcnn1') - # drop - net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bn1') + net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1') + + net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bn') net = tl.layers.SignLayer(net, name='sign2') net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', name='bcnn2') - # drop - net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bn2') + net = 
tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2') + net = tl.layers.SignLayer(net, name='sign2') net = tl.layers.FlattenLayer(net, name='flatten') net = tl.layers.DropoutLayer(net, 0.5, True, is_train, name='drop1') @@ -65,8 +66,7 @@ def model(x, is_train=True, reuse=False): n_epoch = 200 print_freq = 5 -v = tl.layers.get_quantize_sign_params(sess, net_test.all_params) -print(v) +# print(sess.run(net_test.all_params)) # print real value of parameters for epoch in range(n_epoch): start_time = time.time() diff --git a/tensorlayer/layers/binary.py b/tensorlayer/layers/binary.py index c47f6653a..c5b05d6ce 100644 --- a/tensorlayer/layers/binary.py +++ b/tensorlayer/layers/binary.py @@ -3,42 +3,15 @@ from .core import * from .. import _logging as logging import tensorflow as tf +import tensorlayer as tl __all__ = [ - 'get_quantize_sign_params', 'BinaryDenseLayer', 'SignLayer', 'ScaleLayer', 'BinaryConv2d', ] - -def get_quantize_sign_params(sess, params_list): - """Quantize the parameters into -1 and 1 using ``np.sign``. - - Parameters - ---------- - params_list : list of tensor - A list of parameters (tensor). - sess : None or Session - Session may be required in some case. - - Returns - -------- - list of array - The parameters in a list. - - Examples - --------- - >>> tl.layers.initialize_global_variables(sess) - >>> v = tl.layers.get_quantize_sign_params(sess, net_test.all_params) - """ - params_list_var = sess.run(params_list) - for i, p in enumerate(params_list_var): - params_list_var[i] = np.sign(p).astype(int) - return params_list_var - - @tf.RegisterGradient("TL_Sign_QuantizeGrad") def quantize_grad(op, grad): return tf.clip_by_value(tf.identity(grad), -1, 1) @@ -52,6 +25,8 @@ def quantize(x): class BinaryDenseLayer(Layer): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70 """The :class:`BinaryDenseLayer` class is a binary fully connected layer, which weights are either -1 or 1 while inferencing. + Note that, the bias vector would not be binarized. + Parameters ---------- layer : :class:`Layer` @@ -82,11 +57,15 @@ def __init__( act=tf.identity, use_gemm=False, W_init=tf.truncated_normal_initializer(stddev=0.1), + b_init=tf.constant_initializer(value=0.0), W_init_args=None, + b_init_args=None, name='binary_dense', ): if W_init_args is None: W_init_args = {} + if b_init_args is None: + b_init_args = {} Layer.__init__(self, prev_layer=prev_layer, name=name) self.inputs = prev_layer.outputs @@ -105,17 +84,30 @@ def __init__( W = quantize(W) # W = tf.Variable(W) # print(W) - self.outputs = act(tf.matmul(self.inputs, W)) - # self.outputs = act(xnor_gemm(self.inputs, W)) # TODO + if b_init is not None: + try: + b = tf.get_variable(name='b', shape=(n_units), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args) + except Exception: # If initializer is a constant, do not specify shape. + b = tf.get_variable(name='b', initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args) + self.outputs = act(tf.matmul(self.inputs, W) + b) + # self.outputs = act(xnor_gemm(self.inputs, W) + b) # TODO + else: + self.outputs = act(tf.matmul(self.inputs, W)) + # self.outputs = act(xnor_gemm(self.inputs, W)) # TODO self.all_layers.append(self.outputs) - self.all_params.append(W) + if b_init is not None: + self.all_params.extend([W, b]) + else: + self.all_params.append(W) class BinaryConv2d(Layer): """ The :class:`BinaryConv2d` class is a 2D binary CNN layer, which weights are either -1 or 1 while inferencing. 
+ Note that, the bias vector would not be binarized. + Parameters ---------- layer : :class:`Layer` @@ -135,8 +127,12 @@ class BinaryConv2d(Layer): If True, use gemm instead of ``tf.matmul`` for inferencing. (TODO). W_init : initializer The initializer for the the weight matrix. + b_init : initializer or None + The initializer for the the bias vector. If None, skip biases. W_init_args : dictionary The arguments for the weight matrix initializer. + b_init_args : dictionary + The arguments for the bias vector initializer. use_cudnn_on_gpu : bool Default is False. data_format : str @@ -156,9 +152,9 @@ def __init__( padding='SAME', use_gemm=False, W_init=tf.truncated_normal_initializer(stddev=0.02), - # b_init=tf.constant_initializer(value=0.0), + b_init=tf.constant_initializer(value=0.0), W_init_args=None, - # b_init_args=None, + b_init_args=None, use_cudnn_on_gpu=None, data_format=None, # act=tf.identity, @@ -175,6 +171,8 @@ def __init__( ): if W_init_args is None: W_init_args = {} + if b_init_args is None: + b_init_args = {} if use_gemm: raise Exception("TODO. The current version use tf.matmul for inferencing.") @@ -198,10 +196,18 @@ def __init__( with tf.variable_scope(name): W = tf.get_variable(name='W_conv2d', shape=shape, initializer=W_init, dtype=LayersConfig.tf_dtype, **W_init_args) W = quantize(W) - self.outputs = act(tf.nn.conv2d(self.inputs, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format)) + if b_init: + b = tf.get_variable(name='b_conv2d', shape=(shape[-1]), initializer=b_init, dtype=LayersConfig.tf_dtype, **b_init_args) + self.outputs = act( + tf.nn.conv2d(self.inputs, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format) + b) + else: + self.outputs = act(tf.nn.conv2d(self.inputs, W, strides=strides, padding=padding, use_cudnn_on_gpu=use_cudnn_on_gpu, data_format=data_format)) self.all_layers.append(self.outputs) - self.all_params.append(W) + if b_init: + self.all_params.extend([W, b]) + else: + self.all_params.append(W) class SignLayer(Layer): From dbf2ea8d9da0443791283a9df93d84d748d194ad Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 22:32:23 +0000 Subject: [PATCH 08/13] rename sign act name --- docs/modules/activation.rst | 6 +++--- tensorlayer/activation.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/modules/activation.rst b/docs/modules/activation.rst index c0c2b2223..fdaa06752 100644 --- a/docs/modules/activation.rst +++ b/docs/modules/activation.rst @@ -30,7 +30,7 @@ For more complex activation, TensorFlow API will be required. ramp leaky_relu swish - sign + hard_tanh pixel_wise_softmax Identity @@ -49,9 +49,9 @@ Swish ------------ .. autofunction:: swish -Differentiable Sign +Hard Tanh --------------------- -.. autofunction:: sign +.. autofunction:: hard_tanh Pixel-wise softmax -------------------- diff --git a/tensorlayer/activation.py b/tensorlayer/activation.py index 70818b875..13b1b1d6e 100644 --- a/tensorlayer/activation.py +++ b/tensorlayer/activation.py @@ -9,7 +9,7 @@ 'ramp', 'leaky_relu', 'swish', - 'sign', + 'hard_tanh', 'pixel_wise_softmax', 'linear', 'lrelu', @@ -122,8 +122,8 @@ def _sign_grad(unused_op, grad): return tf.clip_by_value(tf.identity(grad), -1, 1) -def sign(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L36 - """Differentiable sign function by clipping linear gradient into [-1, 1], usually be used for quantizing value in binary network, see `tf.sign `__. 
+def hard_tanh(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L36 + """Differentiable hard tanh function by clipping linear gradient into [-1, 1], usually be used for quantizing value in binary network, see `Binarized Neural Networks `__. Parameters ---------- From 28af90f057120847f0b6c32fae1979a0b8017a7b Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 22:34:30 +0000 Subject: [PATCH 09/13] rename function --- tensorlayer/activation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorlayer/activation.py b/tensorlayer/activation.py index 13b1b1d6e..e37a93f96 100644 --- a/tensorlayer/activation.py +++ b/tensorlayer/activation.py @@ -118,7 +118,7 @@ def swish(x, name='swish'): @tf.RegisterGradient("QuantizeGrad") -def _sign_grad(unused_op, grad): +def _hard_tanh_grad(unused_op, grad): return tf.clip_by_value(tf.identity(grad), -1, 1) @@ -140,8 +140,8 @@ def hard_tanh(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/m - `AngusG/tensorflow-xnor-bnn `__ """ - with tf.get_default_graph().gradient_override_map({"sign": "QuantizeGrad"}): - return tf.sign(x, name='tl_sign') + with tf.get_default_graph().gradient_override_map({"hard_tanh": "QuantizeGrad"}): + return tf.sign(x, name='hard_tanh') # if tf.__version__ > "1.7": From 2c1c5fd141a7846e312f012552e7cfa6d31b3413 Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 22:40:24 +0000 Subject: [PATCH 10/13] fix codacy; --- tensorlayer/layers/binary.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorlayer/layers/binary.py b/tensorlayer/layers/binary.py index c5b05d6ce..b39129b23 100644 --- a/tensorlayer/layers/binary.py +++ b/tensorlayer/layers/binary.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- -import numpy as np from .core import * from .. import _logging as logging import tensorflow as tf -import tensorlayer as tl __all__ = [ 'BinaryDenseLayer', @@ -12,6 +10,7 @@ 'BinaryConv2d', ] + @tf.RegisterGradient("TL_Sign_QuantizeGrad") def quantize_grad(op, grad): return tf.clip_by_value(tf.identity(grad), -1, 1) From 4f49aca61e1153ce07e5f607fc3b388dae071cfc Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 23:15:38 +0000 Subject: [PATCH 11/13] rename sign --- docs/modules/activation.rst | 6 +++--- tensorlayer/activation.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/modules/activation.rst b/docs/modules/activation.rst index fdaa06752..8d7474674 100644 --- a/docs/modules/activation.rst +++ b/docs/modules/activation.rst @@ -30,7 +30,7 @@ For more complex activation, TensorFlow API will be required. ramp leaky_relu swish - hard_tanh + sign pixel_wise_softmax Identity @@ -49,9 +49,9 @@ Swish ------------ .. autofunction:: swish -Hard Tanh +Sign --------------------- -.. autofunction:: hard_tanh +.. 
autofunction:: sign Pixel-wise softmax -------------------- diff --git a/tensorlayer/activation.py b/tensorlayer/activation.py index e37a93f96..e8128b497 100644 --- a/tensorlayer/activation.py +++ b/tensorlayer/activation.py @@ -9,7 +9,7 @@ 'ramp', 'leaky_relu', 'swish', - 'hard_tanh', + 'sign', 'pixel_wise_softmax', 'linear', 'lrelu', @@ -118,12 +118,12 @@ def swish(x, name='swish'): @tf.RegisterGradient("QuantizeGrad") -def _hard_tanh_grad(unused_op, grad): +def _sign_grad(unused_op, grad): return tf.clip_by_value(tf.identity(grad), -1, 1) -def hard_tanh(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L36 - """Differentiable hard tanh function by clipping linear gradient into [-1, 1], usually be used for quantizing value in binary network, see `Binarized Neural Networks `__. +def sign(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L36 + """Differentiable sign function by clipping linear gradient into [-1, 1], usually be used for quantizing value in binary network, see `Binarized Neural Networks `__. Parameters ---------- @@ -140,8 +140,8 @@ def hard_tanh(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/m - `AngusG/tensorflow-xnor-bnn `__ """ - with tf.get_default_graph().gradient_override_map({"hard_tanh": "QuantizeGrad"}): - return tf.sign(x, name='hard_tanh') + with tf.get_default_graph().gradient_override_map({"sign": "QuantizeGrad"}): + return tf.sign(x, name='sign') # if tf.__version__ > "1.7": From f088fc5b45e1f687b8697811d7057a521d17ea6a Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Thu, 15 Mar 2018 23:36:42 +0000 Subject: [PATCH 12/13] improve docs for sign --- tensorlayer/activation.py | 4 +++- tensorlayer/layers/binary.py | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorlayer/activation.py b/tensorlayer/activation.py index e8128b497..9a0c12896 100644 --- a/tensorlayer/activation.py +++ b/tensorlayer/activation.py @@ -123,7 +123,9 @@ def _sign_grad(unused_op, grad): def sign(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L36 - """Differentiable sign function by clipping linear gradient into [-1, 1], usually be used for quantizing value in binary network, see `Binarized Neural Networks `__. + """Sign function. + + Clip and binarize tensor using the straight through estimator (STE) for the gradient, usually be used for quantizing values in `Binarized Neural Networks `__. Parameters ---------- diff --git a/tensorlayer/layers/binary.py b/tensorlayer/layers/binary.py index b39129b23..bddd9792b 100644 --- a/tensorlayer/layers/binary.py +++ b/tensorlayer/layers/binary.py @@ -12,16 +12,16 @@ @tf.RegisterGradient("TL_Sign_QuantizeGrad") -def quantize_grad(op, grad): +def _quantize_grad(op, grad): + """Clip and binarize tensor using the straight through estimator (STE) for the gradient. """ return tf.clip_by_value(tf.identity(grad), -1, 1) - -def quantize(x): +def quantize(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70 https://github.com/itayhubara/BinaryNet.tf/blob/master/nnUtils.py with tf.get_default_graph().gradient_override_map({"Sign": "TL_Sign_QuantizeGrad"}): return tf.sign(x) -class BinaryDenseLayer(Layer): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70 +class BinaryDenseLayer(Layer): """The :class:`BinaryDenseLayer` class is a binary fully connected layer, which weights are either -1 or 1 while inferencing. 
Note that, the bias vector would not be binarized. From 623c04ed1759e03bfdf192ba2f8b7e7e720048e2 Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Fri, 16 Mar 2018 00:39:43 +0000 Subject: [PATCH 13/13] yapf --- tensorlayer/layers/binary.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorlayer/layers/binary.py b/tensorlayer/layers/binary.py index bddd9792b..9d0578ee3 100644 --- a/tensorlayer/layers/binary.py +++ b/tensorlayer/layers/binary.py @@ -16,7 +16,10 @@ def _quantize_grad(op, grad): """Clip and binarize tensor using the straight through estimator (STE) for the gradient. """ return tf.clip_by_value(tf.identity(grad), -1, 1) -def quantize(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70 https://github.com/itayhubara/BinaryNet.tf/blob/master/nnUtils.py + +def quantize(x): + # ref: https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70 + # https://github.com/itayhubara/BinaryNet.tf/blob/master/nnUtils.py with tf.get_default_graph().gradient_override_map({"Sign": "TL_Sign_QuantizeGrad"}): return tf.sign(x)
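
Since patch 07 the example no longer dumps the binarized parameters: print_params() and sess.run(net.all_params) return the stored real-valued variables, because binarization happens inside the graph through quantize(). To inspect the -1/1 weights that inference effectively uses, the fetched values can be binarized externally, mirroring the get_quantize_sign_params helper that patch 07 removed. A hedged sketch only; binarized_params is an illustrative name, sess and net_test refer to the example script, and np.sign maps exact zeros to 0:

import numpy as np

def binarized_params(sess, params_list):
    # Fetch the real-valued parameters and map each entry to -1/0/1 with np.sign.
    values = sess.run(params_list)
    return [np.sign(p).astype(int) for p in values]

# Usage after training, with sess and net_test from tutorial_binarynet_mnist_cnn.py:
# for w in binarized_params(sess, net_test.all_params):
#     print(w.shape, np.unique(w))

Note that applying np.sign to every entry of all_params also touches the biases and batch-normalization variables; only the W matrices of the binary layers are meaningful as -1/1 values.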