2 changes: 1 addition & 1 deletion docs/modules/activation.rst
@@ -49,7 +49,7 @@ Swish
------------
.. autofunction:: swish

Differentiable Sign
Sign
---------------------
.. autofunction:: sign

103 changes: 103 additions & 0 deletions example/tutorial_binarynet_mnist_cnn.py
@@ -0,0 +1,103 @@
#! /usr/bin/python
# -*- coding: utf-8 -*-

import time
import tensorflow as tf
import tensorlayer as tl

X_train, y_train, X_val, y_val, X_test, y_test = \
tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
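# images are loaded with shape (N, 28, 28, 1) so they can feed the conv input below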

sess = tf.InteractiveSession()

batch_size = 128

x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1])
y_ = tf.placeholder(tf.int64, shape=[batch_size])
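# the placeholders have a fixed batch dimension, so every feed below must
# contain exactly `batch_size` examples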


def model(x, is_train=True, reuse=False):
with tf.variable_scope("binarynet", reuse=reuse):
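        # weights are binarized by BinaryConv2d / BinaryDenseLayer and activations
        # by SignLayer; gradients still flow through the sign ops via the
        # straight-through estimator implemented in tl.act.sign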
net = tl.layers.InputLayer(x, name='input')
net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', name='bcnn1')
net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1')

net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bn')
        net = tl.layers.SignLayer(net, name='sign1')
net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', name='bcnn2')
net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2')

net = tl.layers.SignLayer(net, name='sign2')
net = tl.layers.FlattenLayer(net, name='flatten')
        net = tl.layers.DropoutLayer(net, keep=0.5, is_fix=True, is_train=is_train, name='drop1')
# net = tl.layers.DenseLayer(net, 256, act=tf.nn.relu, name='dense')
net = tl.layers.BinaryDenseLayer(net, 256, name='dense')
        net = tl.layers.DropoutLayer(net, keep=0.5, is_fix=True, is_train=is_train, name='drop2')
# net = tl.layers.DenseLayer(net, 10, act=tf.identity, name='output')
net = tl.layers.BinaryDenseLayer(net, 10, name='bout')
# net = tl.layers.ScaleLayer(net, name='scale')
return net


# define the networks for training and evaluation
net_train = model(x, is_train=True, reuse=False)
net_test = model(x, is_train=False, reuse=True)
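# net_test reuses the weights of net_train (reuse=True) but is built with
# is_train=False, so dropout and batch normalization run in inference mode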

# cost for training
y = net_train.outputs
cost = tl.cost.cross_entropy(y, y_, name='xentropy')

# cost and accuracy for evaluation
y2 = net_test.outputs
cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# define the optimizer
train_params = tl.layers.get_variables_with_name('binarynet', True, True)
train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)
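# only the variables created under the 'binarynet' scope are updated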

# initialize all variables in the session
tl.layers.initialize_global_variables(sess)

net_train.print_params()
net_train.print_layers()

n_epoch = 200
print_freq = 5

# print(sess.run(net_test.all_params)) # print real value of parameters

for epoch in range(n_epoch):
start_time = time.time()
for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})

if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
train_loss, train_acc, n_batch = 0, 0, 0
for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a})
train_loss += err
train_acc += ac
n_batch += 1
print(" train loss: %f" % (train_loss / n_batch))
print(" train acc: %f" % (train_acc / n_batch))
val_loss, val_acc, n_batch = 0, 0, 0
for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True):
err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a})
val_loss += err
val_acc += ac
n_batch += 1
print(" val loss: %f" % (val_loss / n_batch))
print(" val acc: %f" % (val_acc / n_batch))

print('Evaluation')
test_loss, test_acc, n_batch = 0, 0, 0
for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True):
err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a})
test_loss += err
test_acc += ac
n_batch += 1
print(" test loss: %f" % (test_loss / n_batch))
print(" test acc: %f" % (test_acc / n_batch))
6 changes: 4 additions & 2 deletions tensorlayer/activation.py
@@ -123,7 +123,9 @@ def _sign_grad(unused_op, grad):


def sign(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L36
"""Differentiable sign function by clipping linear gradient into [-1, 1], usually be used for quantizing value in binary network, see `tf.sign <https://www.tensorflow.org/api_docs/python/tf/sign>`__.
"""Sign function.

Clip and binarize the input tensor, using the straight-through estimator (STE) for the gradient; usually used for quantizing values in `Binarized Neural Networks <https://arxiv.org/abs/1602.02830>`__.

Parameters
----------
@@ -141,7 +143,7 @@ def sign(x): # https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models

"""
with tf.get_default_graph().gradient_override_map({"sign": "QuantizeGrad"}):
return tf.sign(x, name='tl_sign')
return tf.sign(x, name='sign')


# if tf.__version__ > "1.7":
1 change: 1 addition & 0 deletions tensorlayer/layers/__init__.py
@@ -9,6 +9,7 @@

from .core import *
from .convolution import *
from .binary import *
from .super_resolution import *
from .normalization import *
from .spatial_transformer import *