384 changes: 164 additions & 220 deletions examples/quantized_net/tutorial_binarynet_cifar10_tfrecord.py

Large diffs are not rendered by default.

140 changes: 67 additions & 73 deletions examples/quantized_net/tutorial_binarynet_mnist_cnn.py
@@ -3,110 +3,104 @@

import time

import numpy as np
import tensorflow as tf

import tensorlayer as tl
from tensorlayer.layers import (BatchNorm, BinaryConv2d, BinaryDense, Flatten, Input, MaxPool2d, Sign)
from tensorlayer.models import Model

tf.logging.set_verbosity(tf.logging.DEBUG)
tl.logging.set_verbosity(tl.logging.DEBUG)

X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
# X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)

sess = tf.InteractiveSession()

batch_size = 128

x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1])
y_ = tf.placeholder(tf.int64, shape=[batch_size])


def model(x, is_train=True, reuse=False):
def model(inputs_shape, n_class=10):
    # In BNN, all the layers inputs are binary, with the exception of the first layer.
    # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py
    with tf.variable_scope("binarynet", reuse=reuse):
        net = tl.layers.InputLayer(x, name='input')
        net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')
        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1')
        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1')

        net = tl.layers.SignLayer(net)
        net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')
        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2')
        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2')

        net = tl.layers.FlattenLayer(net)
        # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop1')
        net = tl.layers.SignLayer(net)
        net = tl.layers.BinaryDenseLayer(net, 256, b_init=None, name='dense')
        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn3')

        # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop2')
        net = tl.layers.SignLayer(net)
        net = tl.layers.BinaryDenseLayer(net, 10, b_init=None, name='bout')
        net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bno')
    net_in = Input(inputs_shape, name='input')
    net = BinaryConv2d(32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')(net_in)
    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool1')(net)
    net = BatchNorm(act=tl.act.htanh, name='bn1')(net)

    net = Sign("sign1")(net)
    net = BinaryConv2d(64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')(net)
    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool2')(net)
    net = BatchNorm(act=tl.act.htanh, name='bn2')(net)

    net = Flatten('ft')(net)
    net = Sign("sign2")(net)
    net = BinaryDense(256, b_init=None, name='dense')(net)
    net = BatchNorm(act=tl.act.htanh, name='bn3')(net)

    net = Sign("sign3")(net)
    net = BinaryDense(10, b_init=None, name='bout')(net)
    net = BatchNorm(name='bno')(net)
    net = Model(inputs=net_in, outputs=net, name='binarynet')
    return net
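For intuition: the Binary* layers and the Sign layer above binarize weights or activations to {-1, +1} and rely on a straight-through estimator so gradients can still flow through the hard sign. A minimal illustrative sketch in plain TF 2.x (not the TensorLayer implementation; note that tf.sign maps 0 to 0, whereas BinaryNet-style code usually maps 0 to +1):

@tf.custom_gradient
def sign_ste(x):
    # forward pass: hard binarization to -1 / +1
    def grad(dy):
        # backward pass: straight-through estimator, i.e. the hard-tanh derivative;
        # pass the gradient where |x| <= 1 and block it elsewhere
        return dy * tf.cast(tf.abs(x) <= 1.0, dy.dtype)
    return tf.sign(x), grad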


# define inferences
net_train = model(x, is_train=True, reuse=False)
net_test = model(x, is_train=False, reuse=True)

# cost for training
y = net_train.outputs
cost = tl.cost.cross_entropy(y, y_, name='xentropy')
def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        _loss = cost(y_pred, y_batch)
    grad = tape.gradient(_loss, network.trainable_weights)
    train_op.apply_gradients(zip(grad, network.trainable_weights))
    if acc is not None:
        _acc = acc(y_pred, y_batch)
        return _loss, _acc
    else:
        return _loss, None
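Note that the optimizer in the default argument is created once, at function-definition time; the script below passes an explicit train_op, which is the safer pattern. If more speed is needed, the same step could also be compiled with tf.function. A sketch under the assumption that the optimizer is created outside the traced function; the name _train_step_graph is hypothetical:

optimizer = tf.optimizers.Adam(learning_rate=0.0001)

@tf.function
def _train_step_graph(network, X_batch, y_batch):
    # same logic as _train_step, traced into a graph for faster repeated calls
    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        _loss = tl.cost.cross_entropy(y_pred, y_batch)
    grad = tape.gradient(_loss, network.trainable_weights)
    optimizer.apply_gradients(zip(grad, network.trainable_weights))
    return _loss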

# cost and accuracy for evalution
y2 = net_test.outputs
cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# define the optimizer
train_params = tl.layers.get_variables_with_name('binarynet', True, True)
train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)
def accuracy(_logits, y_batch):
    return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
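The same metric can be expressed with TF ops, mirroring the graph-mode correct_prediction / acc pair that this change removes (this sketch assumes _logits is a float tensor of shape [batch, n_class] and y_batch holds integer labels):

def accuracy_tf(_logits, y_batch):
    # fraction of samples whose arg-max class matches the label
    correct = tf.equal(tf.argmax(_logits, 1), tf.cast(y_batch, tf.int64))
    return tf.reduce_mean(tf.cast(correct, tf.float32))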

# initialize all variables in the session
sess.run(tf.global_variables_initializer())

net_train.print_params()
net_train.print_layers()

n_epoch = 200
print_freq = 5

# print(sess.run(net_test.all_params)) # print real values of parameters
net = model([None, 28, 28, 1])
train_op = tf.optimizers.Adam(learning_rate=0.0001)
cost = tl.cost.cross_entropy
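Before entering the training loop, a quick shape sanity check can be useful. A sketch reusing X_train and batch_size from above (TL 2.x models must be switched to train or eval mode before they are called):

net.eval()
_out = net(X_train[:batch_size].astype(np.float32))
print(_out.shape)  # expected: (128, 10)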

for epoch in range(n_epoch):
    start_time = time.time()
    train_loss, train_acc, n_batch = 0, 0, 0
    net.train()

    for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
        sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})
        _loss, acc = _train_step(net, X_train_a, y_train_a, cost=cost, train_op=train_op, acc=accuracy)
        train_loss += _loss
        train_acc += acc
        n_batch += 1

    # print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
    # print(" train loss: %f" % (train_loss / n_batch))
    # print(" train acc: %f" % (train_acc / n_batch))

    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
        print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
        train_loss, train_acc, n_batch = 0, 0, 0
        for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
            err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a})
            train_loss += err
            train_acc += ac
            n_batch += 1
        print(" train loss: %f" % (train_loss / n_batch))
        print(" train acc: %f" % (train_acc / n_batch))
        val_loss, val_acc, n_batch = 0, 0, 0
        val_loss, val_acc, val_batch = 0, 0, 0
        net.eval()
        for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True):
            err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a})
            val_loss += err
            val_acc += ac
            n_batch += 1
        print(" val loss: %f" % (val_loss / n_batch))
        print(" val acc: %f" % (val_acc / n_batch))

print('Evaluation')
test_loss, test_acc, n_batch = 0, 0, 0
            _logits = net(X_val_a)
            val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss')
            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a))
            val_batch += 1
        print(" val loss: {}".format(val_loss / val_batch))
        print(" val acc: {}".format(val_acc / val_batch))

net.test()
test_loss, test_acc, n_test_batch = 0, 0, 0
for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True):
    err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a})
    test_loss += err
    test_acc += ac
    n_batch += 1
print(" test loss: %f" % (test_loss / n_batch))
print(" test acc: %f" % (test_acc / n_batch))
    _logits = net(X_test_a)
    test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss')
    test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a))
    n_test_batch += 1
print(" test loss: %f" % (test_loss / n_test_batch))
print(" test acc: %f" % (test_acc / n_test_batch))