diff --git a/examples/quantized_net/tutorial_binarynet_cifar10_tfrecord.py b/examples/quantized_net/tutorial_binarynet_cifar10_tfrecord.py index 98532debb..3f4d0fcf1 100644 --- a/examples/quantized_net/tutorial_binarynet_cifar10_tfrecord.py +++ b/examples/quantized_net/tutorial_binarynet_cifar10_tfrecord.py @@ -39,236 +39,180 @@ """ -import os +import multiprocessing import time +import numpy as np import tensorflow as tf import tensorlayer as tl +from tensorlayer.layers import ( + BinaryConv2d, BinaryDense, Conv2d, Dense, Flatten, Input, LocalResponseNorm, MaxPool2d, Sign +) +from tensorlayer.models import Model -tf.logging.set_verbosity(tf.logging.DEBUG) tl.logging.set_verbosity(tl.logging.DEBUG) -model_file_name = "./model_cifar10_tfrecord.ckpt" -resume = False # load model, resume from previous checkpoint? - # Download data, and convert to TFRecord format, see ```tutorial_tfrecord.py``` +# prepare cifar10 data X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False) -print('X_train.shape', X_train.shape) # (50000, 32, 32, 3) -print('y_train.shape', y_train.shape) # (50000,) -print('X_test.shape', X_test.shape) # (10000, 32, 32, 3) -print('y_test.shape', y_test.shape) # (10000,) -print('X %s y %s' % (X_test.dtype, y_test.dtype)) - - -def data_to_tfrecord(images, labels, filename): - """Save data into TFRecord.""" - if os.path.isfile(filename): - print("%s exists" % filename) - return - print("Converting data into %s ..." % filename) - # cwd = os.getcwd() - writer = tf.python_io.TFRecordWriter(filename) - for index, img in enumerate(images): - img_raw = img.tobytes() - # Visualize a image - # tl.visualize.frame(np.asarray(img, dtype=np.uint8), second=1, saveable=False, name='frame', fig_idx=1236) - label = int(labels[index]) - # print(label) - # Convert the bytes back to image as follow: - # image = Image.frombytes('RGB', (32, 32), img_raw) - # image = np.fromstring(img_raw, np.float32) - # image = image.reshape([32, 32, 3]) - # tl.visualize.frame(np.asarray(image, dtype=np.uint8), second=1, saveable=False, name='frame', fig_idx=1236) - example = tf.train.Example( - features=tf.train.Features( - feature={ - "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), - 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), - } - ) - ) - writer.write(example.SerializeToString()) # Serialize To String - writer.close() - - -def read_and_decode(filename, is_train=None): - """Return tensor to read from TFRecord.""" - filename_queue = tf.train.string_input_producer([filename]) - reader = tf.TFRecordReader() - _, serialized_example = reader.read(filename_queue) - features = tf.parse_single_example( - serialized_example, features={ - 'label': tf.FixedLenFeature([], tf.int64), - 'img_raw': tf.FixedLenFeature([], tf.string), - } - ) - # You can do more image distortion here for training data - img = tf.decode_raw(features['img_raw'], tf.float32) - img = tf.reshape(img, [32, 32, 3]) - # img = tf.cast(img, tf.float32) #* (1. / 255) - 0.5 - if is_train ==True: - # 1. Randomly crop a [height, width] section of the image. - img = tf.random_crop(img, [24, 24, 3]) - - # 2. Randomly flip the image horizontally. - img = tf.image.random_flip_left_right(img) - - # 3. Randomly change brightness. - img = tf.image.random_brightness(img, max_delta=63) - - # 4. Randomly change contrast. - img = tf.image.random_contrast(img, lower=0.2, upper=1.8) - - # 5. Subtract off the mean and divide by the variance of the pixels. 
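        # (For reference: tf.image.per_image_standardization maps each image to
        #  (x - mean) / adjusted_stddev, where adjusted_stddev = max(stddev, 1/sqrt(num_pixels)),
        #  so a constant image is not divided by zero.)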
- img = tf.image.per_image_standardization(img) - - elif is_train == False: - # 1. Crop the central [height, width] of the image. - img = tf.image.resize_image_with_crop_or_pad(img, 24, 24) - - # 2. Subtract off the mean and divide by the variance of the pixels. - img = tf.image.per_image_standardization(img) - - elif is_train == None: - img = img - - label = tf.cast(features['label'], tf.int32) - return img, label - - -# Save data into TFRecord files -data_to_tfrecord(images=X_train, labels=y_train, filename="train.cifar10") -data_to_tfrecord(images=X_test, labels=y_test, filename="test.cifar10") +def binary_model(input_shape, n_classes): + in_net = Input(shape=input_shape, name='input') + + net = Conv2d(64, (5, 5), (1, 1), act='relu', padding='SAME', name='conv1')(in_net) + net = Sign(name='sign1')(net) + + net = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')(net) + net = LocalResponseNorm(4, 1.0, 0.001 / 9.0, 0.75, name='norm1')(net) + net = BinaryConv2d(64, (5, 5), (1, 1), act='relu', padding='SAME', name='bconv1')(net) + + net = LocalResponseNorm(4, 1.0, 0.001 / 9.0, 0.75, name='norm2')(net) + net = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')(net) + net = Flatten(name='flatten')(net) + net = Sign(name='sign2')(net) + net = BinaryDense(384, act='relu', name='d1relu')(net) + net = Sign(name='sign3')(net) + net = BinaryDense(192, act='relu', name='d2relu')(net) + net = Dense(n_classes, act=None, name='output')(net) + net = Model(inputs=in_net, outputs=net, name='binarynet') + return net + + +# training settings +net = binary_model([None, 24, 24, 3], n_classes=10) batch_size = 128 -model_file_name = "./model_cifar10_advanced.ckpt" -resume = False # load model, resume from previous checkpoint? - -with tf.device('/cpu:0'): - sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) - # prepare data in cpu - x_train_, y_train_ = read_and_decode("train.cifar10", True) - x_test_, y_test_ = read_and_decode("test.cifar10", False) - # set the number of threads here - x_train_batch, y_train_batch = tf.train.shuffle_batch( - [x_train_, y_train_], batch_size=batch_size, capacity=2000, min_after_dequeue=1000, num_threads=32 - ) - # for testing, uses batch instead of shuffle_batch - x_test_batch, y_test_batch = tf.train.batch( - [x_test_, y_test_], batch_size=batch_size, capacity=50000, num_threads=32 - ) - - def model(x_crop, y_, reuse): - """For more simplified CNN APIs, check tensorlayer.org.""" - with tf.variable_scope("model", reuse=reuse): - net = tl.layers.InputLayer(x_crop, name='input') - net = tl.layers.Conv2d(net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', name='cnn1') - net = tl.layers.SignLayer(net) - net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool1') - net = tl.layers.LocalResponseNormLayer(net, 4, 1.0, 0.001 / 9.0, 0.75, name='norm1') - net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', name='cnn2') - net = tl.layers.LocalResponseNormLayer(net, 4, 1.0, 0.001 / 9.0, 0.75, name='norm2') - net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool2') - net = tl.layers.FlattenLayer(net, name='flatten') - net = tl.layers.SignLayer(net) - net = tl.layers.BinaryDenseLayer(net, 384, act=tf.nn.relu, name='d1relu') - net = tl.layers.SignLayer(net) - net = tl.layers.BinaryDenseLayer(net, 192, act=tf.nn.relu, name='d2relu') - net = tl.layers.DenseLayer(net, 10, act=None, name='output') - - y = net.outputs - - ce = tl.cost.cross_entropy(y, y_, name='cost') - # L2 for the MLP, without 
this, the accuracy will be reduced by 15%. - L2 = 0 - for p in tl.layers.get_variables_with_name('relu/W', True, True): - L2 += tf.contrib.layers.l2_regularizer(0.004)(p) - cost = ce + L2 - - # correct_prediction = tf.equal(tf.argmax(tf.nn.softmax(y), 1), y_) - correct_prediction = tf.equal(tf.cast(tf.argmax(y, 1), tf.int32), y_) - acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - - return net, cost, acc - - # You can also use placeholder to feed_dict in data after using - # val, l = sess.run([x_train_batch, y_train_batch]) to get the data - # x_crop = tf.placeholder(tf.float32, shape=[batch_size, 24, 24, 3]) - # y_ = tf.placeholder(tf.int32, shape=[batch_size,]) - # cost, acc, network = model(x_crop, y_, None) - - with tf.device('/gpu:0'): # <-- remove it if you don't have GPU - network, cost, acc, = model(x_train_batch, y_train_batch, False) - _, cost_test, acc_test = model(x_test_batch, y_test_batch, True) - - # train - n_epoch = 50000 - learning_rate = 0.0001 - print_freq = 1 - n_step_epoch = int(len(y_train) / batch_size) - n_step = n_epoch * n_step_epoch - - with tf.device('/gpu:0'): # <-- remove it if you don't have GPU - train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost) - - sess.run(tf.global_variables_initializer()) - if resume: - print("Load existing model " + "!" * 10) - saver = tf.train.Saver() - saver.restore(sess, model_file_name) - - network.print_params(False) - network.print_layers() - - print(' learning_rate: %f' % learning_rate) - print(' batch_size: %d' % batch_size) - print(' n_epoch: %d, step in an epoch: %d, total n_step: %d' % (n_epoch, n_step_epoch, n_step)) - - coord = tf.train.Coordinator() - threads = tf.train.start_queue_runners(sess=sess, coord=coord) - step = 0 - for epoch in range(n_epoch): - start_time = time.time() - train_loss, train_acc, n_batch = 0, 0, 0 - for s in range(n_step_epoch): - # You can also use placeholder to feed_dict in data after using - # val, l = sess.run([x_train_batch, y_train_batch]) - # tl.visualize.images2d(val, second=3, saveable=False, name='batch', dtype=np.uint8, fig_idx=2020121) - # err, ac, _ = sess.run([cost, acc, train_op], feed_dict={x_crop: val, y_: l}) - err, ac, _ = sess.run([cost, acc, train_op]) - step += 1 - train_loss += err - train_acc += ac - n_batch += 1 - - if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: - print( - "Epoch %d : Step %d-%d of %d took %fs" % - (epoch, step, step + n_step_epoch, n_step, time.time() - start_time) - ) - print(" train loss: %f" % (train_loss / n_batch)) - print(" train acc: %f" % (train_acc / n_batch)) - - test_loss, test_acc, n_batch = 0, 0, 0 - for _ in range(int(len(y_test) / batch_size)): - err, ac = sess.run([cost_test, acc_test]) - test_loss += err - test_acc += ac - n_batch += 1 - print(" test loss: %f" % (test_loss / n_batch)) - print(" test acc: %f" % (test_acc / n_batch)) - - if (epoch + 1) % (print_freq * 50) == 0: - print("Save model " + "!" 
* 10) - saver = tf.train.Saver() - save_path = saver.save(sess, model_file_name) - # you can also save model into npz - tl.files.save_npz(network.all_params, name='model.npz', sess=sess) - # and restore it as follow: - # tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network) - - coord.request_stop() - coord.join(threads) - sess.close() +n_epoch = 50000 +learning_rate = 0.0001 +print_freq = 5 +n_step_epoch = int(len(y_train) / batch_size) +n_step = n_epoch * n_step_epoch +shuffle_buffer_size = 128 + +train_weights = net.trainable_weights +optimizer = tf.optimizers.Adam(learning_rate) +cost = tl.cost.cross_entropy + + +def generator_train(): + inputs = X_train + targets = y_train + if len(inputs) != len(targets): + raise AssertionError("The length of inputs and targets should be equal") + for _input, _target in zip(inputs, targets): + # yield _input.encode('utf-8'), _target.encode('utf-8') + yield _input, _target + + +def generator_test(): + inputs = X_test + targets = y_test + if len(inputs) != len(targets): + raise AssertionError("The length of inputs and targets should be equal") + for _input, _target in zip(inputs, targets): + # yield _input.encode('utf-8'), _target.encode('utf-8') + yield _input, _target + + +def _map_fn_train(img, target): + # 1. Randomly crop a [height, width] section of the image. + img = tf.image.random_crop(img, [24, 24, 3]) + # 2. Randomly flip the image horizontally. + img = tf.image.random_flip_left_right(img) + # 3. Randomly change brightness. + img = tf.image.random_brightness(img, max_delta=63) + # 4. Randomly change contrast. + img = tf.image.random_contrast(img, lower=0.2, upper=1.8) + # 5. Subtract off the mean and divide by the variance of the pixels. + img = tf.image.per_image_standardization(img) + target = tf.reshape(target, ()) + return img, target + + +def _map_fn_test(img, target): + # 1. Crop the central [height, width] of the image. + img = tf.image.resize_with_pad(img, 24, 24) + # 2. Subtract off the mean and divide by the variance of the pixels. 
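    # (A caveat on step 1 above: tf.image.resize_with_pad rescales the whole 32x32 image
    #  down to 24x24 rather than taking the central 24x24 crop that the removed TF1 code
    #  produced with resize_image_with_crop_or_pad; if the old behaviour is wanted,
    #  tf.image.resize_with_crop_or_pad(img, 24, 24) is the direct TF2 equivalent.)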
+ img = tf.image.per_image_standardization(img) + img = tf.reshape(img, (24, 24, 3)) + target = tf.reshape(target, ()) + return img, target + + +def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None): + with tf.GradientTape() as tape: + y_pred = network(X_batch) + _loss = cost(y_pred, y_batch) + grad = tape.gradient(_loss, network.trainable_weights) + train_op.apply_gradients(zip(grad, network.trainable_weights)) + if acc is not None: + _acc = acc(y_pred, y_batch) + return _loss, _acc + else: + return _loss, None + + +def accuracy(_logits, y_batch): + return np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + + +# dataset API and augmentation +train_ds = tf.data.Dataset.from_generator( + generator_train, output_types=(tf.float32, tf.int32) +) # , output_shapes=((24, 24, 3), (1))) +train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count()) +# train_ds = train_ds.repeat(n_epoch) +train_ds = train_ds.shuffle(shuffle_buffer_size) +train_ds = train_ds.prefetch(buffer_size=4096) +train_ds = train_ds.batch(batch_size) +# value = train_ds.make_one_shot_iterator().get_next() + +test_ds = tf.data.Dataset.from_generator( + generator_test, output_types=(tf.float32, tf.int32) +) # , output_shapes=((24, 24, 3), (1))) +# test_ds = test_ds.shuffle(shuffle_buffer_size) +test_ds = test_ds.map(_map_fn_test, num_parallel_calls=multiprocessing.cpu_count()) +# test_ds = test_ds.repeat(n_epoch) +test_ds = test_ds.prefetch(buffer_size=4096) +test_ds = test_ds.batch(batch_size) +# value_test = test_ds.make_one_shot_iterator().get_next() + +for epoch in range(n_epoch): + start_time = time.time() + + train_loss, train_acc, n_iter = 0, 0, 0 + for X_batch, y_batch in train_ds: + net.train() + _loss, acc = _train_step(net, X_batch, y_batch, cost=cost, train_op=optimizer, acc=accuracy) + + train_loss += _loss + train_acc += acc + n_iter += 1 + + print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time)) + print(" train loss: {}".format(train_loss / n_iter)) + print(" train acc: {}".format(train_acc / n_iter)) + + # use training and evaluation sets to evaluate the model every print_freq epoch + if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: + net.eval() + val_loss, val_acc, n_val_iter = 0, 0, 0 + for X_batch, y_batch in test_ds: + _logits = net(X_batch) # is_train=False, disable dropout + val_loss += tl.cost.cross_entropy(_logits, y_batch, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_val_iter += 1 + print(" val loss: {}".format(val_loss / n_val_iter)) + print(" val acc: {}".format(val_acc / n_val_iter)) + +# use testing data to evaluate the model +net.eval() +test_loss, test_acc, n_iter = 0, 0, 0 +for X_batch, y_batch in test_ds: + _logits = net(X_batch) + test_loss += tl.cost.cross_entropy(_logits, y_batch, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_iter += 1 +print(" test loss: {}".format(test_loss / n_iter)) +print(" test acc: {}".format(test_acc / n_iter)) diff --git a/examples/quantized_net/tutorial_binarynet_mnist_cnn.py b/examples/quantized_net/tutorial_binarynet_mnist_cnn.py index 248812e23..4eccd5c2e 100644 --- a/examples/quantized_net/tutorial_binarynet_mnist_cnn.py +++ b/examples/quantized_net/tutorial_binarynet_mnist_cnn.py @@ -3,110 +3,104 @@ import time +import numpy as np import tensorflow as tf import tensorlayer as tl +from tensorlayer.layers import (BatchNorm, BinaryConv2d, BinaryDense, Flatten, Input, 
MaxPool2d, Sign) +from tensorlayer.models import Model -tf.logging.set_verbosity(tf.logging.DEBUG) tl.logging.set_verbosity(tl.logging.DEBUG) X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) -# X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False) - -sess = tf.InteractiveSession() batch_size = 128 -x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1]) -y_ = tf.placeholder(tf.int64, shape=[batch_size]) - -def model(x, is_train=True, reuse=False): +def model(inputs_shape, n_class=10): # In BNN, all the layers inputs are binary, with the exception of the first layer. # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py - with tf.variable_scope("binarynet", reuse=reuse): - net = tl.layers.InputLayer(x, name='input') - net = tl.layers.BinaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1') - net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1') - net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1') - - net = tl.layers.SignLayer(net) - net = tl.layers.BinaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2') - net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2') - net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2') - - net = tl.layers.FlattenLayer(net) - # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop1') - net = tl.layers.SignLayer(net) - net = tl.layers.BinaryDenseLayer(net, 256, b_init=None, name='dense') - net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn3') - - # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop2') - net = tl.layers.SignLayer(net) - net = tl.layers.BinaryDenseLayer(net, 10, b_init=None, name='bout') - net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bno') + net_in = Input(inputs_shape, name='input') + net = BinaryConv2d(32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')(net_in) + net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool1')(net) + net = BatchNorm(act=tl.act.htanh, name='bn1')(net) + + net = Sign("sign1")(net) + net = BinaryConv2d(64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')(net) + net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool2')(net) + net = BatchNorm(act=tl.act.htanh, name='bn2')(net) + + net = Flatten('ft')(net) + net = Sign("sign2")(net) + net = BinaryDense(256, b_init=None, name='dense')(net) + net = BatchNorm(act=tl.act.htanh, name='bn3')(net) + + net = Sign("sign3")(net) + net = BinaryDense(10, b_init=None, name='bout')(net) + net = BatchNorm(name='bno')(net) + net = Model(inputs=net_in, outputs=net, name='binarynet') return net -# define inferences -net_train = model(x, is_train=True, reuse=False) -net_test = model(x, is_train=False, reuse=True) - -# cost for training -y = net_train.outputs -cost = tl.cost.cross_entropy(y, y_, name='xentropy') +def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None): + with tf.GradientTape() as tape: + y_pred = network(X_batch) + _loss = cost(y_pred, y_batch) + grad = tape.gradient(_loss, network.trainable_weights) + train_op.apply_gradients(zip(grad, network.trainable_weights)) + if acc is not None: + _acc = acc(y_pred, y_batch) + return _loss, _acc + else: + return _loss, None -# cost and accuracy for evalution -y2 = net_test.outputs -cost_test = 
tl.cost.cross_entropy(y2, y_, name='xentropy2') -correct_prediction = tf.equal(tf.argmax(y2, 1), y_) -acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) -# define the optimizer -train_params = tl.layers.get_variables_with_name('binarynet', True, True) -train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params) +def accuracy(_logits, y_batch): + return np.mean(np.equal(np.argmax(_logits, 1), y_batch)) -# initialize all variables in the session -sess.run(tf.global_variables_initializer()) - -net_train.print_params() -net_train.print_layers() n_epoch = 200 print_freq = 5 -# print(sess.run(net_test.all_params)) # print real values of parameters +net = model([None, 28, 28, 1]) +train_op = tf.optimizers.Adam(learning_rate=0.0001) +cost = tl.cost.cross_entropy for epoch in range(n_epoch): start_time = time.time() + train_loss, train_acc, n_batch = 0, 0, 0 + net.train() + for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True): - sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a}) + _loss, acc = _train_step(net, X_train_a, y_train_a, cost=cost, train_op=train_op, acc=accuracy) + train_loss += _loss + train_acc += acc + n_batch += 1 + + # print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) + # print(" train loss: %f" % (train_loss / n_batch)) + # print(" train acc: %f" % (train_acc / n_batch)) if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) - train_loss, train_acc, n_batch = 0, 0, 0 - for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True): - err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a}) - train_loss += err - train_acc += ac - n_batch += 1 print(" train loss: %f" % (train_loss / n_batch)) print(" train acc: %f" % (train_acc / n_batch)) - val_loss, val_acc, n_batch = 0, 0, 0 + val_loss, val_acc, val_batch = 0, 0, 0 + net.eval() for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True): - err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a}) - val_loss += err - val_acc += ac - n_batch += 1 - print(" val loss: %f" % (val_loss / n_batch)) - print(" val acc: %f" % (val_acc / n_batch)) - -print('Evaluation') -test_loss, test_acc, n_batch = 0, 0, 0 + _logits = net(X_val_a) + val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a)) + val_batch += 1 + print(" val loss: {}".format(val_loss / val_batch)) + print(" val acc: {}".format(val_acc / val_batch)) + +net.test() +test_loss, test_acc, n_test_batch = 0, 0, 0 for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True): - err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a}) - test_loss += err - test_acc += ac - n_batch += 1 -print(" test loss: %f" % (test_loss / n_batch)) -print(" test acc: %f" % (test_acc / n_batch)) + _logits = net(X_test_a) + test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a)) + n_test_batch += 1 +print(" test loss: %f" % (test_loss / n_test_batch)) +print(" test acc: %f" % (test_acc / n_test_batch)) diff --git a/examples/quantized_net/tutorial_dorefanet_cifar10_tfrecord.py b/examples/quantized_net/tutorial_dorefanet_cifar10_tfrecord.py index 9c8ab1239..5ebb7cfa6 100644 --- 
a/examples/quantized_net/tutorial_dorefanet_cifar10_tfrecord.py +++ b/examples/quantized_net/tutorial_dorefanet_cifar10_tfrecord.py @@ -39,232 +39,173 @@ """ -import os +import multiprocessing import time +import numpy as np import tensorflow as tf import tensorlayer as tl +from tensorlayer.layers import (Conv2d, Dense, DorefaConv2d, DorefaDense, Flatten, Input, LocalResponseNorm, MaxPool2d) +from tensorlayer.models import Model -tf.logging.set_verbosity(tf.logging.DEBUG) tl.logging.set_verbosity(tl.logging.DEBUG) -model_file_name = "./model_cifar10_tfrecord.ckpt" -resume = False # load model, resume from previous checkpoint? - # Download data, and convert to TFRecord format, see ```tutorial_tfrecord.py``` +# prepare cifar10 data X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False) -print('X_train.shape', X_train.shape) # (50000, 32, 32, 3) -print('y_train.shape', y_train.shape) # (50000,) -print('X_test.shape', X_test.shape) # (10000, 32, 32, 3) -print('y_test.shape', y_test.shape) # (10000,) -print('X %s y %s' % (X_test.dtype, y_test.dtype)) - - -def data_to_tfrecord(images, labels, filename): - """Save data into TFRecord.""" - if os.path.isfile(filename): - print("%s exists" % filename) - return - print("Converting data into %s ..." % filename) - # cwd = os.getcwd() - writer = tf.python_io.TFRecordWriter(filename) - for index, img in enumerate(images): - img_raw = img.tobytes() - # Visualize a image - # tl.visualize.frame(np.asarray(img, dtype=np.uint8), second=1, saveable=False, name='frame', fig_idx=1236) - label = int(labels[index]) - # print(label) - # Convert the bytes back to image as follow: - # image = Image.frombytes('RGB', (32, 32), img_raw) - # image = np.fromstring(img_raw, np.float32) - # image = image.reshape([32, 32, 3]) - # tl.visualize.frame(np.asarray(image, dtype=np.uint8), second=1, saveable=False, name='frame', fig_idx=1236) - example = tf.train.Example( - features=tf.train.Features( - feature={ - "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), - 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), - } - ) - ) - writer.write(example.SerializeToString()) # Serialize To String - writer.close() - - -def read_and_decode(filename, is_train=None): - """Return tensor to read from TFRecord.""" - filename_queue = tf.train.string_input_producer([filename]) - reader = tf.TFRecordReader() - _, serialized_example = reader.read(filename_queue) - features = tf.parse_single_example( - serialized_example, features={ - 'label': tf.FixedLenFeature([], tf.int64), - 'img_raw': tf.FixedLenFeature([], tf.string), - } - ) - # You can do more image distortion here for training data - img = tf.decode_raw(features['img_raw'], tf.float32) - img = tf.reshape(img, [32, 32, 3]) - # img = tf.cast(img, tf.float32) #* (1. / 255) - 0.5 - if is_train ==True: - # 1. Randomly crop a [height, width] section of the image. - img = tf.random_crop(img, [24, 24, 3]) - - # 2. Randomly flip the image horizontally. - img = tf.image.random_flip_left_right(img) - - # 3. Randomly change brightness. - img = tf.image.random_brightness(img, max_delta=63) - - # 4. Randomly change contrast. - img = tf.image.random_contrast(img, lower=0.2, upper=1.8) - - # 5. Subtract off the mean and divide by the variance of the pixels. - img = tf.image.per_image_standardization(img) - - elif is_train == False: - # 1. Crop the central [height, width] of the image. - img = tf.image.resize_image_with_crop_or_pad(img, 24, 24) - - # 2. 
Subtract off the mean and divide by the variance of the pixels. - img = tf.image.per_image_standardization(img) - - elif is_train == None: - img = img - - label = tf.cast(features['label'], tf.int32) - return img, label - - -# Save data into TFRecord files -data_to_tfrecord(images=X_train, labels=y_train, filename="train.cifar10") -data_to_tfrecord(images=X_test, labels=y_test, filename="test.cifar10") +def dorefanet_model(input_shape, n_classes): + in_net = Input(shape=input_shape, name='input') + net = Conv2d(32, (5, 5), (1, 1), act='relu', padding='SAME', name='conv1')(in_net) + net = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')(net) + net = LocalResponseNorm(4, 1.0, 0.001 / 9.0, 0.75, name='norm1')(net) + net = tl.layers.Sign("sign")(net) + net = DorefaConv2d(8, 32, 64, (5, 5), (1, 1), act='relu', padding='SAME', name='DorefaConv1')(net) + net = LocalResponseNorm(4, 1.0, 0.001 / 9.0, 0.75, name='norm2')(net) + net = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')(net) + net = Flatten(name='flatten')(net) + net = DorefaDense(8, 16, 384, act='relu', name='DorefaDense1')(net) + net = DorefaDense(8, 16, 192, act='relu', name='DorefaDense2')(net) + net = Dense(n_classes, act=None, name='output')(net) + net = Model(inputs=in_net, outputs=net, name='dorefanet') + return net + + +# training settings +net = dorefanet_model([None, 24, 24, 3], n_classes=10) batch_size = 128 -model_file_name = "./model_cifar10_advanced.ckpt" -resume = False # load model, resume from previous checkpoint? - -with tf.device('/cpu:0'): - sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) - # prepare data in cpu - x_train_, y_train_ = read_and_decode("train.cifar10", True) - x_test_, y_test_ = read_and_decode("test.cifar10", False) - # set the number of threads here - x_train_batch, y_train_batch = tf.train.shuffle_batch( - [x_train_, y_train_], batch_size=batch_size, capacity=2000, min_after_dequeue=1000, num_threads=32 - ) - # for testing, uses batch instead of shuffle_batch - x_test_batch, y_test_batch = tf.train.batch( - [x_test_, y_test_], batch_size=batch_size, capacity=50000, num_threads=32 - ) - - def model(x_crop, y_, reuse): - """For more simplified CNN APIs, check tensorlayer.org.""" - with tf.variable_scope("model", reuse=reuse): - net = tl.layers.InputLayer(x_crop, name='input') - net = tl.layers.Conv2d(net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', name='cnn1') - net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool1') - net = tl.layers.LocalResponseNormLayer(net, 4, 1.0, 0.001 / 9.0, 0.75, name='norm1') - net = tl.layers.DorefaConv2d(net, 1, 3, 64, (5, 5), (1, 1), tf.nn.relu, padding='SAME', name='cnn2') - net = tl.layers.LocalResponseNormLayer(net, 4, 1.0, 0.001 / 9.0, 0.75, name='norm2') - net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool2') - net = tl.layers.FlattenLayer(net, name='flatten') - net = tl.layers.DorefaDenseLayer(net, 1, 3, 384, act=tf.nn.relu, name='d1relu') - net = tl.layers.DorefaDenseLayer(net, 1, 3, 192, act=tf.nn.relu, name='d2relu') - net = tl.layers.DenseLayer(net, 10, act=None, name='output') - y = net.outputs - - ce = tl.cost.cross_entropy(y, y_, name='cost') - # L2 for the MLP, without this, the accuracy will be reduced by 15%. 
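            # (The eager rewrite below drops this L2 penalty. A minimal sketch of restoring it
            #  inside _train_step, assuming the 0.004 scale from this removed code, could be:
            #      l2 = 0.004 * tf.add_n([tf.nn.l2_loss(w) for w in network.trainable_weights])
            #      _loss = cost(y_pred, y_batch) + l2
            #  optionally filtering trainable_weights by name to mirror the old 'relu/W' selection.)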
- L2 = 0 - for p in tl.layers.get_variables_with_name('relu/W', True, True): - L2 += tf.contrib.layers.l2_regularizer(0.004)(p) - cost = ce + L2 - - # correct_prediction = tf.equal(tf.argmax(tf.nn.softmax(y), 1), y_) - correct_prediction = tf.equal(tf.cast(tf.argmax(y, 1), tf.int32), y_) - acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - - return net, cost, acc - - # You can also use placeholder to feed_dict in data after using - # val, l = sess.run([x_train_batch, y_train_batch]) to get the data - # x_crop = tf.placeholder(tf.float32, shape=[batch_size, 24, 24, 3]) - # y_ = tf.placeholder(tf.int32, shape=[batch_size,]) - # cost, acc, network = model(x_crop, y_, None) - - with tf.device('/gpu:0'): # <-- remove it if you don't have GPU - network, cost, acc, = model(x_train_batch, y_train_batch, False) - _, cost_test, acc_test = model(x_test_batch, y_test_batch, True) - - # train - n_epoch = 50000 - learning_rate = 0.0001 - print_freq = 1 - n_step_epoch = int(len(y_train) / batch_size) - n_step = n_epoch * n_step_epoch - - with tf.device('/gpu:0'): # <-- remove it if you don't have GPU - train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost) - - sess.run(tf.global_variables_initializer()) - if resume: - print("Load existing model " + "!" * 10) - saver = tf.train.Saver() - saver.restore(sess, model_file_name) - - network.print_params(False) - network.print_layers() - - print(' learning_rate: %f' % learning_rate) - print(' batch_size: %d' % batch_size) - print(' n_epoch: %d, step in an epoch: %d, total n_step: %d' % (n_epoch, n_step_epoch, n_step)) - - coord = tf.train.Coordinator() - threads = tf.train.start_queue_runners(sess=sess, coord=coord) - step = 0 - for epoch in range(n_epoch): - start_time = time.time() - train_loss, train_acc, n_batch = 0, 0, 0 - for s in range(n_step_epoch): - # You can also use placeholder to feed_dict in data after using - # val, l = sess.run([x_train_batch, y_train_batch]) - # tl.visualize.images2d(val, second=3, saveable=False, name='batch', dtype=np.uint8, fig_idx=2020121) - # err, ac, _ = sess.run([cost, acc, train_op], feed_dict={x_crop: val, y_: l}) - err, ac, _ = sess.run([cost, acc, train_op]) - step += 1 - train_loss += err - train_acc += ac - n_batch += 1 - - if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: - print( - "Epoch %d : Step %d-%d of %d took %fs" % - (epoch, step, step + n_step_epoch, n_step, time.time() - start_time) - ) - print(" train loss: %f" % (train_loss / n_batch)) - print(" train acc: %f" % (train_acc / n_batch)) - - test_loss, test_acc, n_batch = 0, 0, 0 - for _ in range(int(len(y_test) / batch_size)): - err, ac = sess.run([cost_test, acc_test]) - test_loss += err - test_acc += ac - n_batch += 1 - print(" test loss: %f" % (test_loss / n_batch)) - print(" test acc: %f" % (test_acc / n_batch)) - - if (epoch + 1) % (print_freq * 50) == 0: - print("Save model " + "!" 
* 10) - saver = tf.train.Saver() - save_path = saver.save(sess, model_file_name) - # you can also save model into npz - tl.files.save_npz(network.all_params, name='model.npz', sess=sess) - # and restore it as follow: - # tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network) - - coord.request_stop() - coord.join(threads) - sess.close() +n_epoch = 50000 +learning_rate = 0.0001 +print_freq = 5 +n_step_epoch = int(len(y_train) / batch_size) +n_step = n_epoch * n_step_epoch +shuffle_buffer_size = 128 + +optimizer = tf.optimizers.Adam(learning_rate) +# optimizer = tf.optimizers.SGD(learning_rate) +cost = tl.cost.cross_entropy + + +def generator_train(): + inputs = X_train + targets = y_train + if len(inputs) != len(targets): + raise AssertionError("The length of inputs and targets should be equal") + for _input, _target in zip(inputs, targets): + # yield _input.encode('utf-8'), _target.encode('utf-8') + yield _input, _target + + +def generator_test(): + inputs = X_test + targets = y_test + if len(inputs) != len(targets): + raise AssertionError("The length of inputs and targets should be equal") + for _input, _target in zip(inputs, targets): + # yield _input.encode('utf-8'), _target.encode('utf-8') + yield _input, _target + + +def _map_fn_train(img, target): + # 1. Randomly crop a [height, width] section of the image. + img = tf.image.random_crop(img, [24, 24, 3]) + # 2. Randomly flip the image horizontally. + img = tf.image.random_flip_left_right(img) + # 3. Randomly change brightness. + img = tf.image.random_brightness(img, max_delta=63) + # 4. Randomly change contrast. + img = tf.image.random_contrast(img, lower=0.2, upper=1.8) + # 5. Subtract off the mean and divide by the variance of the pixels. + img = tf.image.per_image_standardization(img) + target = tf.reshape(target, ()) + return img, target + + +def _map_fn_test(img, target): + # 1. Crop the central [height, width] of the image. + img = tf.image.resize_with_pad(img, 24, 24) + # 2. Subtract off the mean and divide by the variance of the pixels. 
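    # (On the tf.data pipeline set up below: .prefetch() is applied before .batch(), so the
    #  buffer holds individual examples; the common recommendation is to batch first and then
    #  prefetch a few batches, e.g. ds.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE).)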
+ img = tf.image.per_image_standardization(img) + img = tf.reshape(img, (24, 24, 3)) + target = tf.reshape(target, ()) + return img, target + + +def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None): + with tf.GradientTape() as tape: + y_pred = network(X_batch) + _loss = cost(y_pred, y_batch) + grad = tape.gradient(_loss, network.trainable_weights) + train_op.apply_gradients(zip(grad, network.trainable_weights)) + if acc is not None: + _acc = acc(y_pred, y_batch) + return _loss, _acc + else: + return _loss, None + + +def accuracy(_logits, y_batch): + return np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + + +# dataset API and augmentation +train_ds = tf.data.Dataset.from_generator( + generator_train, output_types=(tf.float32, tf.int32) +) # , output_shapes=((24, 24, 3), (1))) +train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count()) +# train_ds = train_ds.repeat(n_epoch) +train_ds = train_ds.shuffle(shuffle_buffer_size) +train_ds = train_ds.prefetch(buffer_size=4096) +train_ds = train_ds.batch(batch_size) +# value = train_ds.make_one_shot_iterator().get_next() + +test_ds = tf.data.Dataset.from_generator( + generator_test, output_types=(tf.float32, tf.int32) +) # , output_shapes=((24, 24, 3), (1))) +# test_ds = test_ds.shuffle(shuffle_buffer_size) +test_ds = test_ds.map(_map_fn_test, num_parallel_calls=multiprocessing.cpu_count()) +# test_ds = test_ds.repeat(n_epoch) +test_ds = test_ds.prefetch(buffer_size=4096) +test_ds = test_ds.batch(batch_size) +# value_test = test_ds.make_one_shot_iterator().get_next() + +for epoch in range(n_epoch): + start_time = time.time() + + train_loss, train_acc, n_iter = 0, 0, 0 + net.train() + for X_batch, y_batch in train_ds: + _loss, acc = _train_step(net, X_batch, y_batch, cost=cost, train_op=optimizer, acc=accuracy) + + train_loss += _loss + train_acc += acc + n_iter += 1 + + # use training and evaluation sets to evaluate the model every print_freq epoch + if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: + print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time)) + print(" train loss: {}".format(train_loss / n_iter)) + print(" train acc: {}".format(train_acc / n_iter)) + + net.eval() + val_loss, val_acc, n_val_iter = 0, 0, 0 + for X_batch, y_batch in test_ds: + _logits = net(X_batch) # is_train=False, disable dropout + val_loss += tl.cost.cross_entropy(_logits, y_batch, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_val_iter += 1 + print(" val loss: {}".format(val_loss / n_val_iter)) + print(" val acc: {}".format(val_acc / n_val_iter)) + +# use testing data to evaluate the model +net.eval() +test_loss, test_acc, n_iter = 0, 0, 0 +for X_batch, y_batch in test_ds: + _logits = net(X_batch) + test_loss += tl.cost.cross_entropy(_logits, y_batch, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_iter += 1 +print(" test loss: {}".format(test_loss / n_iter)) +print(" test acc: {}".format(test_acc / n_iter)) diff --git a/examples/quantized_net/tutorial_dorefanet_mnist_cnn.py b/examples/quantized_net/tutorial_dorefanet_mnist_cnn.py index 90d7b0893..1cfd68124 100644 --- a/examples/quantized_net/tutorial_dorefanet_mnist_cnn.py +++ b/examples/quantized_net/tutorial_dorefanet_mnist_cnn.py @@ -3,110 +3,99 @@ import time +import numpy as np import tensorflow as tf import tensorlayer as tl +from tensorlayer.layers import (BatchNorm, Dense, DorefaConv2d, DorefaDense, Flatten, Input, 
MaxPool2d) +from tensorlayer.models import Model -tf.logging.set_verbosity(tf.logging.DEBUG) tl.logging.set_verbosity(tl.logging.DEBUG) X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) -# X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False) - -sess = tf.InteractiveSession() batch_size = 128 -x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1]) -y_ = tf.placeholder(tf.int64, shape=[batch_size]) - - -def model(x, is_train=True, reuse=False): - # In BNN, all the layers inputs are binary, with the exception of the first layer. - # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py - with tf.variable_scope("binarynet", reuse=reuse): - net = tl.layers.InputLayer(x, name='input') - net = tl.layers.DorefaConv2d(net, 1, 3, 32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1') #pylint: disable=bare-except - net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1') - net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1') - - # net = tl.layers.SignLayer(net) - net = tl.layers.DorefaConv2d(net, 1, 3, 64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2') #pylint: disable=bare-except - net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2') - net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2') - - net = tl.layers.FlattenLayer(net) - # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop1') - # net = tl.layers.SignLayer(net) - net = tl.layers.DorefaDenseLayer(net, 1, 3, 256, b_init=None, name='dense') - net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn3') - - # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop2') - # net = tl.layers.SignLayer(net) - net = tl.layers.DenseLayer(net, 10, b_init=None, name='bout') - net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bno') - return net +def model(inputs_shape, n_class=10): + in_net = Input(inputs_shape, name='input') + net = DorefaConv2d(1, 3, 32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')(in_net) + net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool1')(net) + net = BatchNorm(act=tl.act.htanh, name='bn1')(net) -# define inferences -net_train = model(x, is_train=True, reuse=False) -net_test = model(x, is_train=False, reuse=True) + net = DorefaConv2d(1, 3, 64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')(net) + net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool2')(net) + net = BatchNorm(act=tl.act.htanh, name='bn2')(net) -# cost for training -y = net_train.outputs -cost = tl.cost.cross_entropy(y, y_, name='xentropy') + net = Flatten('flatten')(net) + net = DorefaDense(1, 3, 256, b_init=None, name='dense')(net) + net = BatchNorm(act=tl.act.htanh, name='bn3')(net) + + net = Dense(n_class, b_init=None, name='bout')(net) + net = BatchNorm(name='bno')(net) + net = Model(inputs=in_net, outputs=net, name='dorefanet') + return net -# cost and accuracy for evalution -y2 = net_test.outputs -cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2') -correct_prediction = tf.equal(tf.argmax(y2, 1), y_) -acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) -# define the optimizer -train_params = tl.layers.get_variables_with_name('binarynet', True, True) -train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params) +def _train_step(network, X_batch, y_batch, cost, 
train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None): + with tf.GradientTape() as tape: + y_pred = network(X_batch) + _loss = cost(y_pred, y_batch) + grad = tape.gradient(_loss, network.trainable_weights) + train_op.apply_gradients(zip(grad, network.trainable_weights)) + if acc is not None: + _acc = acc(y_pred, y_batch) + return _loss, _acc + else: + return _loss, None -# initialize all variables in the session -sess.run(tf.global_variables_initializer()) -net_train.print_params() -net_train.print_layers() +def accuracy(_logits, y_batch): + return np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_epoch = 200 print_freq = 5 -# print(sess.run(net_test.all_params)) # print real values of parameters +net = model([None, 28, 28, 1]) +train_op = tf.optimizers.Adam(learning_rate=0.0001) +cost = tl.cost.cross_entropy for epoch in range(n_epoch): start_time = time.time() + train_loss, train_acc, n_batch = 0, 0, 0 + net.train() + for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True): - sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a}) + _loss, acc = _train_step(net, X_train_a, y_train_a, cost=cost, train_op=train_op, acc=accuracy) + train_loss += _loss + train_acc += acc + n_batch += 1 + + # print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) + # print(" train loss: %f" % (train_loss / n_batch)) + # print(" train acc: %f" % (train_acc / n_batch)) if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) - train_loss, train_acc, n_batch = 0, 0, 0 - for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True): - err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a}) - train_loss += err - train_acc += ac - n_batch += 1 print(" train loss: %f" % (train_loss / n_batch)) print(" train acc: %f" % (train_acc / n_batch)) - val_loss, val_acc, n_batch = 0, 0, 0 + val_loss, val_acc, val_batch = 0, 0, 0 + net.eval() for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True): - err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a}) - val_loss += err - val_acc += ac - n_batch += 1 - print(" val loss: %f" % (val_loss / n_batch)) - print(" val acc: %f" % (val_acc / n_batch)) - -print('Evaluation') -test_loss, test_acc, n_batch = 0, 0, 0 + _logits = net(X_val_a) + val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a)) + val_batch += 1 + print(" val loss: {}".format(val_loss / val_batch)) + print(" val acc: {}".format(val_acc / val_batch)) + +net.test() +test_loss, test_acc, n_test_batch = 0, 0, 0 for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True): - err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a}) - test_loss += err - test_acc += ac - n_batch += 1 -print(" test loss: %f" % (test_loss / n_batch)) -print(" test acc: %f" % (test_acc / n_batch)) + _logits = net(X_test_a) + test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a)) + n_test_batch += 1 +print(" test loss: %f" % (test_loss / n_test_batch)) +print(" test acc: %f" % (test_acc / n_test_batch)) diff --git a/examples/quantized_net/tutorial_quanconv_cifar10.py b/examples/quantized_net/tutorial_quanconv_cifar10.py index 6eb35ed67..9b649e6f0 100644 --- 
a/examples/quantized_net/tutorial_quanconv_cifar10.py +++ b/examples/quantized_net/tutorial_quanconv_cifar10.py @@ -38,105 +38,171 @@ we run them inside 16 separate threads which continuously fill a TensorFlow queue. """ +import multiprocessing import time import numpy as np import tensorflow as tf import tensorlayer as tl +from tensorlayer.layers import (Dense, Flatten, Input, MaxPool2d, QuanConv2dWithBN, QuanDense) +from tensorlayer.models import Model -bitW = 8 -bitA = 8 - -tf.logging.set_verbosity(tf.logging.DEBUG) tl.logging.set_verbosity(tl.logging.DEBUG) -sess = tf.InteractiveSession() - +# Download data, and convert to TFRecord format, see ```tutorial_tfrecord.py``` +# prepare cifar10 data X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False) -def model(x, y_, reuse, is_train, bitW, bitA): - with tf.variable_scope("model", reuse=reuse): - net = tl.layers.InputLayer(x, name='input') - net = tl.layers.QuanConv2dWithBN( - net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', is_train=is_train, bitW=bitW, bitA=bitA, - name='qcnnbn1' - ) - net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool1') - # net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1') - net = tl.layers.QuanConv2dWithBN( - net, 64, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, is_train=is_train, bitW=bitW, bitA=bitA, - name='qcnnbn2' - ) - # net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2') - net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool2') - net = tl.layers.FlattenLayer(net, name='flatten') - net = tl.layers.QuanDenseLayer(net, 384, act=tf.nn.relu, bitW=bitW, bitA=bitA, name='qd1relu') - net = tl.layers.QuanDenseLayer(net, 192, act=tf.nn.relu, bitW=bitW, bitA=bitA, name='qd2relu') - net = tl.layers.DenseLayer(net, 10, act=None, name='output') - y = net.outputs - - ce = tl.cost.cross_entropy(y, y_, name='cost') - L2 = 0 - for p in tl.layers.get_variables_with_name('relu/W', True, True): - L2 += tf.contrib.layers.l2_regularizer(0.004)(p) - cost = ce + L2 - - # correct_prediction = tf.equal(tf.argmax(tf.nn.softmax(y), 1), y_) - correct_prediction = tf.equal(tf.cast(tf.argmax(y, 1), tf.int64), y_) - acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - - return net, cost, acc - - -def distort_fn(x, is_train=False): - x = tl.prepro.crop(x, 24, 24, is_random=is_train) - if is_train: - x = tl.prepro.flip_axis(x, axis=1, is_random=True) - x = tl.prepro.brightness(x, gamma=0.1, gain=1, is_random=True) - x = (x - np.mean(x)) / max(np.std(x), 1e-5) # avoid values divided by 0 - return x - - -x = tf.placeholder(dtype=tf.float32, shape=[None, 24, 24, 3], name='x') -y_ = tf.placeholder(dtype=tf.int64, shape=[None], name='y_') - -network, cost, _ = model(x, y_, False, True, bitW=bitW, bitA=bitA) -_, cost_test, acc = model(x, y_, True, False, bitW=bitW, bitA=bitA) - -# train -n_epoch = 50000 -learning_rate = 0.0001 -print_freq = 1 -batch_size = 128 - -train_op = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-08, - use_locking=False).minimize(cost) - -sess.run(tf.global_variables_initializer()) +def model(input_shape, n_classes, bitW, bitA): + in_net = Input(shape=input_shape, name='input') + net = QuanConv2dWithBN(64, (5, 5), (1, 1), act='relu', padding='SAME', bitW=bitW, bitA=bitA, name='qcnnbn1')(in_net) + net = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')(net) + net = QuanConv2dWithBN(64, (5, 5), (1, 1), padding='SAME', 
act='relu', bitW=bitW, bitA=bitA, name='qcnnbn2')(net) + net = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')(net) + net = Flatten(name='flatten')(net) + net = QuanDense(384, act=tf.nn.relu, bitW=bitW, bitA=bitA, name='qd1relu')(net) + net = QuanDense(192, act=tf.nn.relu, bitW=bitW, bitA=bitA, name='qd2relu')(net) + net = Dense(n_classes, act=None, name='output')(net) + net = Model(inputs=in_net, outputs=net, name='dorefanet') + return net -network.print_params(False) -network.print_layers() -print(' learning_rate: %f' % learning_rate) -print(' batch_size: %d' % batch_size) -print(' bitW: %d, bitA: %d' % (bitW, bitA)) +# training settings +bitW = 8 +bitA = 8 +net = model([None, 24, 24, 3], n_classes=10, bitW=bitW, bitA=bitA) +batch_size = 128 +n_epoch = 50000 +learning_rate = 0.0001 +print_freq = 5 +n_step_epoch = int(len(y_train) / batch_size) +n_step = n_epoch * n_step_epoch +shuffle_buffer_size = 128 + +optimizer = tf.optimizers.Adam(learning_rate) +cost = tl.cost.cross_entropy + + +def generator_train(): + inputs = X_train + targets = y_train + if len(inputs) != len(targets): + raise AssertionError("The length of inputs and targets should be equal") + for _input, _target in zip(inputs, targets): + # yield _input.encode('utf-8'), _target.encode('utf-8') + yield _input, _target + + +def generator_test(): + inputs = X_test + targets = y_test + if len(inputs) != len(targets): + raise AssertionError("The length of inputs and targets should be equal") + for _input, _target in zip(inputs, targets): + # yield _input.encode('utf-8'), _target.encode('utf-8') + yield _input, _target + + +def _map_fn_train(img, target): + # 1. Randomly crop a [height, width] section of the image. + img = tf.image.random_crop(img, [24, 24, 3]) + # 2. Randomly flip the image horizontally. + img = tf.image.random_flip_left_right(img) + # 3. Randomly change brightness. + img = tf.image.random_brightness(img, max_delta=63) + # 4. Randomly change contrast. + img = tf.image.random_contrast(img, lower=0.2, upper=1.8) + # 5. Subtract off the mean and divide by the variance of the pixels. + img = tf.image.per_image_standardization(img) + target = tf.reshape(target, ()) + return img, target + + +def _map_fn_test(img, target): + # 1. Crop the central [height, width] of the image. + img = tf.image.resize_with_pad(img, 24, 24) + # 2. Subtract off the mean and divide by the variance of the pixels. 
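    # (On the _train_step helper defined below: its default train_op=tf.optimizers.Adam(...)
    #  is evaluated once, when the function is defined, so every call relying on the default
    #  would share one hidden optimizer; the training loop here avoids that by passing the
    #  shared `optimizer` explicitly.)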
+ img = tf.image.per_image_standardization(img) + img = tf.reshape(img, (24, 24, 3)) + target = tf.reshape(target, ()) + return img, target + + +def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None): + with tf.GradientTape() as tape: + y_pred = network(X_batch) + _loss = cost(y_pred, y_batch) + grad = tape.gradient(_loss, network.trainable_weights) + train_op.apply_gradients(zip(grad, network.trainable_weights)) + if acc is not None: + _acc = acc(y_pred, y_batch) + return _loss, _acc + else: + return _loss, None + + +def accuracy(_logits, y_batch): + return np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + + +# dataset API and augmentation +train_ds = tf.data.Dataset.from_generator( + generator_train, output_types=(tf.float32, tf.int32) +) # , output_shapes=((24, 24, 3), (1))) +train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count()) +# train_ds = train_ds.repeat(n_epoch) +train_ds = train_ds.shuffle(shuffle_buffer_size) +train_ds = train_ds.prefetch(buffer_size=4096) +train_ds = train_ds.batch(batch_size) +# value = train_ds.make_one_shot_iterator().get_next() + +test_ds = tf.data.Dataset.from_generator( + generator_test, output_types=(tf.float32, tf.int32) +) # , output_shapes=((24, 24, 3), (1))) +# test_ds = test_ds.shuffle(shuffle_buffer_size) +test_ds = test_ds.map(_map_fn_test, num_parallel_calls=multiprocessing.cpu_count()) +# test_ds = test_ds.repeat(n_epoch) +test_ds = test_ds.prefetch(buffer_size=4096) +test_ds = test_ds.batch(batch_size) +# value_test = test_ds.make_one_shot_iterator().get_next() for epoch in range(n_epoch): start_time = time.time() - for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True): - X_train_a = tl.prepro.threading_data(X_train_a, fn=distort_fn, is_train=True) # data augmentation for training - sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a}) + train_loss, train_acc, n_iter = 0, 0, 0 + net.train() + for X_batch, y_batch in train_ds: + _loss, acc = _train_step(net, X_batch, y_batch, cost=cost, train_op=optimizer, acc=accuracy) + + train_loss += _loss + train_acc += acc + n_iter += 1 + + # use training and evaluation sets to evaluate the model every print_freq epoch if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: - print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) - test_loss, test_acc, n_batch = 0, 0, 0 - for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=False): - X_test_a = tl.prepro.threading_data(X_test_a, fn=distort_fn, is_train=False) # central crop - err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a}) - test_loss += err - test_acc += ac - n_batch += 1 - print(" test loss: %f" % (test_loss / n_batch)) - print(" test acc: %f" % (test_acc / n_batch)) + print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time)) + print(" train loss: {}".format(train_loss / n_iter)) + print(" train acc: {}".format(train_acc / n_iter)) + + net.eval() + val_loss, val_acc, n_val_iter = 0, 0, 0 + for X_batch, y_batch in test_ds: + _logits = net(X_batch) # is_train=False, disable dropout + val_loss += tl.cost.cross_entropy(_logits, y_batch, name='eval_loss') + val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_val_iter += 1 + print(" val loss: {}".format(val_loss / n_val_iter)) + print(" val acc: {}".format(val_acc / n_val_iter)) + +# use testing data to evaluate the model +net.eval() +test_loss, test_acc, 
n_iter = 0, 0, 0 +for X_batch, y_batch in test_ds: + _logits = net(X_batch) + test_loss += tl.cost.cross_entropy(_logits, y_batch, name='test_loss') + test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch)) + n_iter += 1 +print(" test loss: {}".format(test_loss / n_iter)) +print(" test acc: {}".format(test_acc / n_iter)) diff --git a/examples/quantized_net/tutorial_quanconv_mnist.py b/examples/quantized_net/tutorial_quanconv_mnist.py index 4060c6137..1dbfe8d4d 100644 --- a/examples/quantized_net/tutorial_quanconv_mnist.py +++ b/examples/quantized_net/tutorial_quanconv_mnist.py @@ -1,107 +1,116 @@ -#! /usr/bin/python +#!/usr/bin/env python3 # -*- coding: utf-8 -*- import time +import numpy as np import tensorflow as tf import tensorlayer as tl +from tensorlayer.layers import ( + Dense, Dropout, Flatten, Input, MaxPool2d, QuanConv2d, QuanConv2dWithBN, QuanDense, QuanDenseLayerWithBN +) +from tensorlayer.models import Model -tf.logging.set_verbosity(tf.logging.DEBUG) tl.logging.set_verbosity(tl.logging.DEBUG) X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1)) # X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False) -sess = tf.InteractiveSession() - batch_size = 128 -x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1]) -y_ = tf.placeholder(tf.int64, shape=[batch_size]) +def model(inputs_shape, n_class=10): + net_in = Input(inputs_shape, name="input") -def model(x, is_train=True, reuse=False): - with tf.variable_scope("quan_cnn", reuse=reuse): - net = tl.layers.InputLayer(x, name='input') - net = tl.layers.QuanConv2dWithBN( - net, 32, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, is_train=is_train, name='qcbnb1' - ) - net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1') + net = QuanConv2dWithBN( + n_filter=32, filter_size=(5, 5), strides=(1, 1), padding='SAME', act=tl.nn.relu, name='qconvbn1' + )(net_in) + net = MaxPool2d(filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool1')(net) - net = tl.layers.QuanConv2dWithBN( - net, 64, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, is_train=is_train, name='qcbn2' - ) - net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2') + net = QuanConv2dWithBN( + n_filter=64, filter_size=(5, 5), strides=(1, 1), padding='SAME', act=tl.nn.relu, name='qconvbn2' + )(net) + net = MaxPool2d(filter_size=(2, 2), strides=(2, 2), padding='SAME', name='pool2')(net) - net = tl.layers.FlattenLayer(net) - # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop1') - net = tl.layers.QuanDenseLayerWithBN(net, 256, is_train=is_train, act=tf.nn.relu, name='qdbn') + net = Flatten(name='ft')(net) - # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop2') - net = tl.layers.QuanDenseLayer(net, 10, name='qdbn_out') - return net + # net = QuanDense(256, act="relu", name='qdbn')(net) + # net = QuanDense(n_class, name='qdbn_out')(net) + net = QuanDenseLayerWithBN(256, act="relu", name='qdbn')(net) + net = QuanDenseLayerWithBN(n_class, name='qdbn_out')(net) -# define inferences -net_train = model(x, is_train=True, reuse=False) -net_test = model(x, is_train=False, reuse=True) + # net = Dense(256, act='relu', name='Dense1')(net) + # net = Dense(n_class, name='Dense2')(net) -# cost for training -y = net_train.outputs -cost = tl.cost.cross_entropy(y, y_, name='xentropy') + net = Model(inputs=net_in, outputs=net, name='quan') + return net -# cost and accuracy for evalution -y2 = net_test.outputs -cost_test = 
-correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
-acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
-# define the optimizer
-train_params = tl.layers.get_variables_with_name('quan_cnn', True, True)
-train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)
 
+def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
+    with tf.GradientTape() as tape:
+        y_pred = network(X_batch)
+        _loss = cost(y_pred, y_batch)
+    grad = tape.gradient(_loss, network.trainable_weights)
+    train_op.apply_gradients(zip(grad, network.trainable_weights))
+    if acc is not None:
+        return _loss, acc(y_pred, y_batch)
+    return _loss, None
 
-# initialize all variables in the session
-sess.run(tf.global_variables_initializer())
 
-net_train.print_params(False)
-net_train.print_layers()
+def accuracy(_logits, y_batch):
+    return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
+
 
 n_epoch = 200
-print_freq = 5
+print_freq = 1
 
-# print(sess.run(net_test.all_params)) # print real values of parameters
+net = model([None, 28, 28, 1])
+train_op = tf.optimizers.Adam(learning_rate=0.0001)
+cost = tl.cost.cross_entropy
 
 for epoch in range(n_epoch):
     start_time = time.time()
+    train_loss, train_acc, n_iter = 0, 0, 0
+    net.train()
+
     for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
-        sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})
+        _loss, _acc = _train_step(net, X_train_a, y_train_a, cost=cost, train_op=train_op, acc=accuracy)
+        train_loss += _loss
+        train_acc += _acc
+        n_iter += 1
 
     if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
-        print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
-        train_loss, train_acc, n_batch = 0, 0, 0
-        for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
-            err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a})
-            train_loss += err
-            train_acc += ac
-            n_batch += 1
-        print("   train loss: %f" % (train_loss / n_batch))
-        print("   train acc: %f" % (train_acc / n_batch))
-        val_loss, val_acc, n_batch = 0, 0, 0
+        print("Epoch {} of {} took {:.4f}s".format(epoch + 1, n_epoch, time.time() - start_time))
+        print("   train loss: {}".format(train_loss / n_iter))
+        print("   train acc: {}".format(train_acc / n_iter))
+
+        net.eval()
+        val_loss, val_acc, n_eval = 0, 0, 0
         for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True):
-            err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a})
-            val_loss += err
-            val_acc += ac
-            n_batch += 1
-        print("   val loss: %f" % (val_loss / n_batch))
-        print("   val acc: %f" % (val_acc / n_batch))
+            _logits = net(X_val_a)  # eval mode: BN uses moving statistics, dropout is off
+            val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss')
+            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a))
+            n_eval += 1
+        print("   val loss: {}".format(val_loss / n_eval))
+        print("   val acc: {}".format(val_acc / n_eval))
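accuracy() above computes in NumPy, which pulls the logits back to the host on every batch and would also prevent wrapping _train_step in tf.function later. A TF-native equivalent is a straightforward swap (tf_accuracy is a name introduced here, not part of the tutorials):

    def tf_accuracy(_logits, y_batch):
        # same quantity as accuracy() above, computed with TF ops only
        correct = tf.equal(tf.argmax(_logits, axis=1), tf.cast(y_batch, tf.int64))
        return tf.reduce_mean(tf.cast(correct, tf.float32))

Passing acc=tf_accuracy to _train_step keeps the whole step on-device; the running averages then only need a float(...) conversion when printed.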
-print('Evaluation')
-test_loss, test_acc, n_batch = 0, 0, 0
+
+net.eval()
+test_loss, test_acc, n_test_batch = 0, 0, 0
 for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True):
-    err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a})
-    test_loss += err
-    test_acc += ac
-    n_batch += 1
-print("   test loss: %f" % (test_loss / n_batch))
-print("   test acc: %f" % (test_acc / n_batch))
+    _logits = net(X_test_a)
+    test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss')
+    test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a))
+    n_test_batch += 1
+print("   test loss: {}".format(test_loss / n_test_batch))
+print("   test acc: {}".format(test_acc / n_test_batch))
diff --git a/examples/quantized_net/tutorial_ternaryweight_cifar10_tfrecord.py b/examples/quantized_net/tutorial_ternaryweight_cifar10_tfrecord.py
index f1ee7b4bb..c78686011 100644
--- a/examples/quantized_net/tutorial_ternaryweight_cifar10_tfrecord.py
+++ b/examples/quantized_net/tutorial_ternaryweight_cifar10_tfrecord.py
@@ -38,232 +38,184 @@
 we run them inside 16 separate threads which continuously fill a TensorFlow queue.
 
 """
-import os
+import multiprocessing
 import time
 
+import numpy as np
 import tensorflow as tf
+
 import tensorlayer as tl
+from tensorlayer.layers import (
+    Conv2d, Dense, Flatten, Input, LocalResponseNorm, MaxPool2d, TernaryConv2d, TernaryDense
+)
+from tensorlayer.models import Model
 
-tf.logging.set_verbosity(tf.logging.DEBUG)
 tl.logging.set_verbosity(tl.logging.DEBUG)
 
-model_file_name = "./model_cifar10_tfrecord.ckpt"
-resume = False  # load model, resume from previous checkpoint?
-
 # Download data, and convert to TFRecord format, see ```tutorial_tfrecord.py```
+# prepare cifar10 data
 X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
-print('X_train.shape', X_train.shape)  # (50000, 32, 32, 3)
-print('y_train.shape', y_train.shape)  # (50000,)
-print('X_test.shape', X_test.shape)  # (10000, 32, 32, 3)
-print('y_test.shape', y_test.shape)  # (10000,)
-print('X %s y %s' % (X_test.dtype, y_test.dtype))
-
-
-def data_to_tfrecord(images, labels, filename):
-    """Save data into TFRecord."""
-    if os.path.isfile(filename):
-        print("%s exists" % filename)
-        return
-    print("Converting data into %s ..." 
% filename) - # cwd = os.getcwd() - writer = tf.python_io.TFRecordWriter(filename) - for index, img in enumerate(images): - img_raw = img.tobytes() - # Visualize a image - # tl.visualize.frame(np.asarray(img, dtype=np.uint8), second=1, saveable=False, name='frame', fig_idx=1236) - label = int(labels[index]) - # print(label) - # Convert the bytes back to image as follow: - # image = Image.frombytes('RGB', (32, 32), img_raw) - # image = np.fromstring(img_raw, np.float32) - # image = image.reshape([32, 32, 3]) - # tl.visualize.frame(np.asarray(image, dtype=np.uint8), second=1, saveable=False, name='frame', fig_idx=1236) - example = tf.train.Example( - features=tf.train.Features( - feature={ - "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])), - 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), - } - ) - ) - writer.write(example.SerializeToString()) # Serialize To String - writer.close() - - -def read_and_decode(filename, is_train=None): - """Return tensor to read from TFRecord.""" - filename_queue = tf.train.string_input_producer([filename]) - reader = tf.TFRecordReader() - _, serialized_example = reader.read(filename_queue) - features = tf.parse_single_example( - serialized_example, features={ - 'label': tf.FixedLenFeature([], tf.int64), - 'img_raw': tf.FixedLenFeature([], tf.string), - } - ) - # You can do more image distortion here for training data - img = tf.decode_raw(features['img_raw'], tf.float32) - img = tf.reshape(img, [32, 32, 3]) - # img = tf.cast(img, tf.float32) #* (1. / 255) - 0.5 - if is_train ==True: - # 1. Randomly crop a [height, width] section of the image. - img = tf.random_crop(img, [24, 24, 3]) - - # 2. Randomly flip the image horizontally. - img = tf.image.random_flip_left_right(img) - - # 3. Randomly change brightness. - img = tf.image.random_brightness(img, max_delta=63) - - # 4. Randomly change contrast. - img = tf.image.random_contrast(img, lower=0.2, upper=1.8) - - # 5. Subtract off the mean and divide by the variance of the pixels. - img = tf.image.per_image_standardization(img) - - elif is_train == False: - # 1. Crop the central [height, width] of the image. - img = tf.image.resize_image_with_crop_or_pad(img, 24, 24) - - # 2. Subtract off the mean and divide by the variance of the pixels. 
-        img = tf.image.per_image_standardization(img)
-
-    elif is_train == None:
-        img = img
-
-    label = tf.cast(features['label'], tf.int32)
-    return img, label
-
-
-# Save data into TFRecord files
-data_to_tfrecord(images=X_train, labels=y_train, filename="train.cifar10")
-data_to_tfrecord(images=X_test, labels=y_test, filename="test.cifar10")
+def model(input_shape, n_classes):
+    in_net = Input(shape=input_shape, name='input')
+
+    net = Conv2d(64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', name='cnn1')(in_net)
+    net = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')(net)
+    net = LocalResponseNorm(4, 1.0, 0.001 / 9.0, 0.75, name='norm1')(net)
+
+    net = TernaryConv2d(64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', name='cnn2')(net)
+    net = LocalResponseNorm(4, 1.0, 0.001 / 9.0, 0.75, name='norm2')(net)
+    net = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')(net)
+
+    net = Flatten(name='flatten')(net)
+
+    net = TernaryDense(384, act=tf.nn.relu, name='d1relu')(net)
+    net = TernaryDense(192, act=tf.nn.relu, name='d2relu')(net)
+    net = Dense(n_classes, act=None, name='output')(net)
+
+    net = Model(inputs=in_net, outputs=net, name='ternarynet')
+    return net
+
+
+# training settings
+net = model([None, 24, 24, 3], n_classes=10)
 batch_size = 128
-model_file_name = "./model_cifar10_advanced.ckpt"
-resume = False  # load model, resume from previous checkpoint?
-
-with tf.device('/cpu:0'):
-    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
-    # prepare data in cpu
-    x_train_, y_train_ = read_and_decode("train.cifar10", True)
-    x_test_, y_test_ = read_and_decode("test.cifar10", False)
-    # set the number of threads here
-    x_train_batch, y_train_batch = tf.train.shuffle_batch(
-        [x_train_, y_train_], batch_size=batch_size, capacity=2000, min_after_dequeue=1000, num_threads=32
-    )
-    # for testing, uses batch instead of shuffle_batch
-    x_test_batch, y_test_batch = tf.train.batch(
-        [x_test_, y_test_], batch_size=batch_size, capacity=50000, num_threads=32
-    )
-
-    def model(x_crop, y_, reuse):
-        """For more simplified CNN APIs, check tensorlayer.org."""
-        with tf.variable_scope("model", reuse=reuse):
-            net = tl.layers.InputLayer(x_crop, name='input')
-            net = tl.layers.Conv2d(net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', name='cnn1')
-            net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool1')
-            net = tl.layers.LocalResponseNormLayer(net, 4, 1.0, 0.001 / 9.0, 0.75, name='norm1')
-            net = tl.layers.TernaryConv2d(net, 64, (5, 5), (1, 1), act=tf.nn.relu, padding='SAME', name='cnn2')
-            net = tl.layers.LocalResponseNormLayer(net, 4, 1.0, 0.001 / 9.0, 0.75, name='norm2')
-            net = tl.layers.MaxPool2d(net, (3, 3), (2, 2), padding='SAME', name='pool2')
-            net = tl.layers.FlattenLayer(net, name='flatten')
-            net = tl.layers.TernaryDenseLayer(net, 384, act=tf.nn.relu, name='d1relu')
-            net = tl.layers.TernaryDenseLayer(net, 192, act=tf.nn.relu, name='d2relu')
-            net = tl.layers.DenseLayer(net, 10, act=None, name='output')
-            y = net.outputs
-
-            ce = tl.cost.cross_entropy(y, y_, name='cost')
-            # L2 for the MLP, without this, the accuracy will be reduced by 15%.
- L2 = 0 - for p in tl.layers.get_variables_with_name('relu/W', True, True): - L2 += tf.contrib.layers.l2_regularizer(0.004)(p) - cost = ce + L2 - - # correct_prediction = tf.equal(tf.argmax(tf.nn.softmax(y), 1), y_) - correct_prediction = tf.equal(tf.cast(tf.argmax(y, 1), tf.int32), y_) - acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) - - return net, cost, acc - - # You can also use placeholder to feed_dict in data after using - # val, l = sess.run([x_train_batch, y_train_batch]) to get the data - # x_crop = tf.placeholder(tf.float32, shape=[batch_size, 24, 24, 3]) - # y_ = tf.placeholder(tf.int32, shape=[batch_size,]) - # cost, acc, network = model(x_crop, y_, None) - - with tf.device('/gpu:0'): # <-- remove it if you don't have GPU - network, cost, acc, = model(x_train_batch, y_train_batch, False) - _, cost_test, acc_test = model(x_test_batch, y_test_batch, True) - - # train - n_epoch = 50000 - learning_rate = 0.0001 - print_freq = 1 - n_step_epoch = int(len(y_train) / batch_size) - n_step = n_epoch * n_step_epoch - - with tf.device('/gpu:0'): # <-- remove it if you don't have GPU - train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost) - - sess.run(tf.global_variables_initializer()) - if resume: - print("Load existing model " + "!" * 10) - saver = tf.train.Saver() - saver.restore(sess, model_file_name) - - network.print_params(False) - network.print_layers() - - print(' learning_rate: %f' % learning_rate) - print(' batch_size: %d' % batch_size) - print(' n_epoch: %d, step in an epoch: %d, total n_step: %d' % (n_epoch, n_step_epoch, n_step)) - - coord = tf.train.Coordinator() - threads = tf.train.start_queue_runners(sess=sess, coord=coord) - step = 0 - for epoch in range(n_epoch): - start_time = time.time() - train_loss, train_acc, n_batch = 0, 0, 0 - for s in range(n_step_epoch): - # You can also use placeholder to feed_dict in data after using - # val, l = sess.run([x_train_batch, y_train_batch]) - # tl.visualize.images2d(val, second=3, saveable=False, name='batch', dtype=np.uint8, fig_idx=2020121) - # err, ac, _ = sess.run([cost, acc, train_op], feed_dict={x_crop: val, y_: l}) - err, ac, _ = sess.run([cost, acc, train_op]) - step += 1 - train_loss += err - train_acc += ac - n_batch += 1 - - if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: - print( - "Epoch %d : Step %d-%d of %d took %fs" % - (epoch, step, step + n_step_epoch, n_step, time.time() - start_time) - ) - print(" train loss: %f" % (train_loss / n_batch)) - print(" train acc: %f" % (train_acc / n_batch)) - - test_loss, test_acc, n_batch = 0, 0, 0 - for _ in range(int(len(y_test) / batch_size)): - err, ac = sess.run([cost_test, acc_test]) - test_loss += err - test_acc += ac - n_batch += 1 - print(" test loss: %f" % (test_loss / n_batch)) - print(" test acc: %f" % (test_acc / n_batch)) - - if (epoch + 1) % (print_freq * 50) == 0: - print("Save model " + "!" 
* 10) - saver = tf.train.Saver() - save_path = saver.save(sess, model_file_name) - # you can also save model into npz - tl.files.save_npz(network.all_params, name='model.npz', sess=sess) - # and restore it as follow: - # tl.files.load_and_assign_npz(sess=sess, name='model.npz', network=network) - - coord.request_stop() - coord.join(threads) - sess.close() +n_epoch = 50000 +learning_rate = 0.0001 +print_freq = 5 +n_step_epoch = int(len(y_train) / batch_size) +n_step = n_epoch * n_step_epoch +shuffle_buffer_size = 128 + +optimizer = tf.optimizers.Adam(learning_rate) +cost = tl.cost.cross_entropy + + +def generator_train(): + inputs = X_train + targets = y_train + if len(inputs) != len(targets): + raise AssertionError("The length of inputs and targets should be equal") + for _input, _target in zip(inputs, targets): + # yield _input.encode('utf-8'), _target.encode('utf-8') + yield _input, _target + + +def generator_test(): + inputs = X_test + targets = y_test + if len(inputs) != len(targets): + raise AssertionError("The length of inputs and targets should be equal") + for _input, _target in zip(inputs, targets): + # yield _input.encode('utf-8'), _target.encode('utf-8') + yield _input, _target + + +def _map_fn_train(img, target): + # 1. Randomly crop a [height, width] section of the image. + img = tf.image.random_crop(img, [24, 24, 3]) + # 2. Randomly flip the image horizontally. + img = tf.image.random_flip_left_right(img) + # 3. Randomly change brightness. + img = tf.image.random_brightness(img, max_delta=63) + # 4. Randomly change contrast. + img = tf.image.random_contrast(img, lower=0.2, upper=1.8) + # 5. Subtract off the mean and divide by the variance of the pixels. + img = tf.image.per_image_standardization(img) + target = tf.reshape(target, ()) + return img, target + + +def _map_fn_test(img, target): + # 1. Crop the central [height, width] of the image. + img = tf.image.resize_with_pad(img, 24, 24) + # 2. Subtract off the mean and divide by the variance of the pixels. 
+    img = tf.image.per_image_standardization(img)
+    img = tf.reshape(img, (24, 24, 3))
+    target = tf.reshape(target, ())
+    return img, target
+
+
+def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
+    with tf.GradientTape() as tape:
+        y_pred = network(X_batch)
+        _loss = cost(y_pred, y_batch)
+    grad = tape.gradient(_loss, network.trainable_weights)
+    train_op.apply_gradients(zip(grad, network.trainable_weights))
+    if acc is not None:
+        return _loss, acc(y_pred, y_batch)
+    return _loss, None
+
+
+def accuracy(_logits, y_batch):
+    return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
+
+
+# dataset API and augmentation
+train_ds = tf.data.Dataset.from_generator(
+    generator_train, output_types=(tf.float32, tf.int32)
+)  # , output_shapes=((24, 24, 3), (1)))
+train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
+# train_ds = train_ds.repeat(n_epoch)
+train_ds = train_ds.shuffle(shuffle_buffer_size)
+train_ds = train_ds.prefetch(buffer_size=4096)
+train_ds = train_ds.batch(batch_size)
+
+test_ds = tf.data.Dataset.from_generator(
+    generator_test, output_types=(tf.float32, tf.int32)
+)  # , output_shapes=((24, 24, 3), (1)))
+# test_ds = test_ds.shuffle(shuffle_buffer_size)
+test_ds = test_ds.map(_map_fn_test, num_parallel_calls=multiprocessing.cpu_count())
+# test_ds = test_ds.repeat(n_epoch)
+test_ds = test_ds.prefetch(buffer_size=4096)
+test_ds = test_ds.batch(batch_size)
+
+for epoch in range(n_epoch):
+    start_time = time.time()
+
+    train_loss, train_acc, n_iter = 0, 0, 0
+    net.train()
+    for X_batch, y_batch in train_ds:
+        _loss, _acc = _train_step(net, X_batch, y_batch, cost=cost, train_op=optimizer, acc=accuracy)
+        train_loss += _loss
+        train_acc += _acc
+        n_iter += 1
+
+    # evaluate the model on the test set every `print_freq` epochs
+    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
+        print("Epoch {} of {} took {:.4f}s".format(epoch + 1, n_epoch, time.time() - start_time))
+        print("   train loss: {}".format(train_loss / n_iter))
+        print("   train acc: {}".format(train_acc / n_iter))
+
+        net.eval()
+        val_loss, val_acc, n_val_iter = 0, 0, 0
+        for X_batch, y_batch in test_ds:
+            _logits = net(X_batch)  # network is in eval mode
+            val_loss += tl.cost.cross_entropy(_logits, y_batch, name='eval_loss')
+            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
+            n_val_iter += 1
+        print("   val loss: {}".format(val_loss / n_val_iter))
+        print("   val acc: {}".format(val_acc / n_val_iter))
+
+# use testing data to evaluate the model
+net.eval()
+test_loss, test_acc, n_iter = 0, 0, 0
+for X_batch, y_batch in test_ds:
+    _logits = net(X_batch)
+    test_loss += tl.cost.cross_entropy(_logits, y_batch, name='test_loss')
+    test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
+    n_iter += 1
+print("   test loss: {}".format(test_loss / n_iter))
+print("   test acc: {}".format(test_acc / n_iter))
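Note that both CIFAR-10 pipelines above apply prefetch before batch, so TensorFlow stages single examples rather than ready batches. The conventional tf.data ordering batches first; a sketch of the same train pipeline reordered, using the standard AUTOTUNE setting (behaviour otherwise unchanged):

    train_ds = (
        tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32, tf.int32))
        .map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
        .shuffle(shuffle_buffer_size)
        .batch(batch_size)
        .prefetch(tf.data.experimental.AUTOTUNE)  # prefetch whole batches
    )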
diff --git a/examples/quantized_net/tutorial_ternaryweight_mnist_cnn.py b/examples/quantized_net/tutorial_ternaryweight_mnist_cnn.py
index e1c305db6..a708d1f0e 100644
--- a/examples/quantized_net/tutorial_ternaryweight_mnist_cnn.py
+++ b/examples/quantized_net/tutorial_ternaryweight_mnist_cnn.py
@@ -3,110 +3,100 @@
 
 import time
 
+import numpy as np
 import tensorflow as tf
+
 import tensorlayer as tl
+from tensorlayer.layers import (BatchNorm, Dense, Flatten, Input, MaxPool2d, TernaryConv2d, TernaryDense)
+from tensorlayer.models import Model
 
-tf.logging.set_verbosity(tf.logging.DEBUG)
 tl.logging.set_verbosity(tl.logging.DEBUG)
 
 X_train, y_train, X_val, y_val, X_test, y_test = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
-# X_train, y_train, X_test, y_test = tl.files.load_cropped_svhn(include_extra=False)
-
-sess = tf.InteractiveSession()
 
 batch_size = 128
 
-x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1])
-y_ = tf.placeholder(tf.int64, shape=[batch_size])
-
-
-def model(x, is_train=True, reuse=False):
-    # In BNN, all the layers inputs are binary, with the exception of the first layer.
-    # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py
-    with tf.variable_scope("binarynet", reuse=reuse):
-        net = tl.layers.InputLayer(x, name='input')
-        net = tl.layers.TernaryConv2d(net, 32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')
-        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool1')
-        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn1')
-
-        # net = tl.layers.SignLayer(net)
-        net = tl.layers.TernaryConv2d(net, 64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')
-        net = tl.layers.MaxPool2d(net, (2, 2), (2, 2), padding='SAME', name='pool2')
-        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn2')
-
-        net = tl.layers.FlattenLayer(net)
-        # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop1')
-        # net = tl.layers.SignLayer(net)
-        net = tl.layers.TernaryDenseLayer(net, 256, b_init=None, name='dense')
-        net = tl.layers.BatchNormLayer(net, act=tl.act.htanh, is_train=is_train, name='bn3')
-
-        # net = tl.layers.DropoutLayer(net, 0.8, True, is_train, name='drop2')
-        # net = tl.layers.SignLayer(net)
-        net = tl.layers.TernaryDenseLayer(net, 10, b_init=None, name='bout')
-        net = tl.layers.BatchNormLayer(net, is_train=is_train, name='bno')
-        return net
+def model(inputs_shape, n_class=10):
+    in_net = Input(inputs_shape, name='input')
+    net = TernaryConv2d(32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')(in_net)
+    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool1')(net)
+    net = BatchNorm(act=tl.act.htanh, name='bn1')(net)
+
+    net = TernaryConv2d(64, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn2')(net)
+    net = MaxPool2d((2, 2), (2, 2), padding='SAME', name='pool2')(net)
+    net = BatchNorm(act=tl.act.htanh, name='bn2')(net)
 
-# define inferences
-net_train = model(x, is_train=True, reuse=False)
-net_test = model(x, is_train=False, reuse=True)
+    net = Flatten(name='flatten')(net)
+    # note: this 256-unit layer is full-precision Dense; the TF1 version used TernaryDense
+    net = Dense(256, b_init=None, name='dense')(net)
+    net = BatchNorm(act=tl.act.htanh, name='bn3')(net)
+
+    net = TernaryDense(n_class, b_init=None, name='bout')(net)
+    net = BatchNorm(name='bno')(net)
+
+    net = Model(inputs=in_net, outputs=net, name='ternarynet')
+    return net
 
-# cost for training
-y = net_train.outputs
-cost = tl.cost.cross_entropy(y, y_, name='xentropy')
 
-# cost and accuracy for evalution
-y2 = net_test.outputs
-cost_test = tl.cost.cross_entropy(y2, y_, name='xentropy2')
-correct_prediction = tf.equal(tf.argmax(y2, 1), y_)
-acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
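Before wiring up the training step, the rebuilt model is easy to sanity-check in eager mode; a short sketch that relies only on the trainable_weights attribute the training loop below already uses:

    net = model([None, 28, 28, 1])
    n_params = sum(int(np.prod(w.shape.as_list())) for w in net.trainable_weights)
    print("trainable tensors: %d, parameters: %d" % (len(net.trainable_weights), n_params))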
+
+def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
+    with tf.GradientTape() as tape:
+        y_pred = network(X_batch)
+        _loss = cost(y_pred, y_batch)
+    grad = tape.gradient(_loss, network.trainable_weights)
+    train_op.apply_gradients(zip(grad, network.trainable_weights))
+    if acc is not None:
+        return _loss, acc(y_pred, y_batch)
+    return _loss, None
 
-# define the optimizer
-train_params = tl.layers.get_variables_with_name('binarynet', True, True)
-train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cost, var_list=train_params)
 
-# initialize all variables in the session
-sess.run(tf.global_variables_initializer())
+def accuracy(_logits, y_batch):
+    return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
 
-net_train.print_params()
-net_train.print_layers()
 
 n_epoch = 200
 print_freq = 5
 
-# print(sess.run(net_test.all_params)) # print real values of parameters
+net = model([None, 28, 28, 1])
+train_op = tf.optimizers.Adam(learning_rate=0.0001)
+cost = tl.cost.cross_entropy
 
 for epoch in range(n_epoch):
     start_time = time.time()
+    train_loss, train_acc, n_batch = 0, 0, 0
+    net.train()
+
     for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
-        sess.run(train_op, feed_dict={x: X_train_a, y_: y_train_a})
+        _loss, _acc = _train_step(net, X_train_a, y_train_a, cost=cost, train_op=train_op, acc=accuracy)
+        train_loss += _loss
+        train_acc += _acc
+        n_batch += 1
 
     if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
         print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
-        train_loss, train_acc, n_batch = 0, 0, 0
-        for X_train_a, y_train_a in tl.iterate.minibatches(X_train, y_train, batch_size, shuffle=True):
-            err, ac = sess.run([cost_test, acc], feed_dict={x: X_train_a, y_: y_train_a})
-            train_loss += err
-            train_acc += ac
-            n_batch += 1
         print("   train loss: %f" % (train_loss / n_batch))
         print("   train acc: %f" % (train_acc / n_batch))
 
-        val_loss, val_acc, n_batch = 0, 0, 0
+        val_loss, val_acc, val_batch = 0, 0, 0
+        net.eval()
         for X_val_a, y_val_a in tl.iterate.minibatches(X_val, y_val, batch_size, shuffle=True):
-            err, ac = sess.run([cost_test, acc], feed_dict={x: X_val_a, y_: y_val_a})
-            val_loss += err
-            val_acc += ac
-            n_batch += 1
-        print("   val loss: %f" % (val_loss / n_batch))
-        print("   val acc: %f" % (val_acc / n_batch))
-
-print('Evaluation')
-test_loss, test_acc, n_batch = 0, 0, 0
+            _logits = net(X_val_a)
+            val_loss += tl.cost.cross_entropy(_logits, y_val_a, name='eval_loss')
+            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_val_a))
+            val_batch += 1
+        print("   val loss: %f" % (val_loss / val_batch))
+        print("   val acc: %f" % (val_acc / val_batch))
+
+net.eval()
+test_loss, test_acc, n_test_batch = 0, 0, 0
 for X_test_a, y_test_a in tl.iterate.minibatches(X_test, y_test, batch_size, shuffle=True):
-    err, ac = sess.run([cost_test, acc], feed_dict={x: X_test_a, y_: y_test_a})
-    test_loss += err
-    test_acc += ac
-    n_batch += 1
-print("   test loss: %f" % (test_loss / n_batch))
-print("   test acc: %f" % (test_acc / n_batch))
+    _logits = net(X_test_a)
+    test_loss += tl.cost.cross_entropy(_logits, y_test_a, name='test_loss')
+    test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_test_a))
+    n_test_batch += 1
+print("   test loss: %f" % (test_loss / n_test_batch))
+print("   test acc: %f" % (test_acc / n_test_batch))
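With training finished and the network left in eval mode, single-sample inference needs no extra machinery; a quick sanity-check sketch reusing the arrays loaded above:

    img = X_test[:1].astype(np.float32)  # one held-out digit
    logits = net(img)
    print("predicted: %d, label: %d" % (int(np.argmax(logits, 1)[0]), int(y_test[0])))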