
Fix batchnormlayer compatibility to TF12 #42

Merged
merged 1 commit into tensorlayer:master on Dec 21, 2016

Conversation

boscotsang
Contributor

The default value of the gamma_init argument is ones_initializer, and that API changed in TF 0.12. Although a user can fix this by explicitly setting the gamma_init argument to ones_initializer(), it is easier for newcomers to avoid the issue altogether by updating the TensorLayer source.
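
For reference, the user-side workaround mentioned above looks roughly like this under TF 0.12 (a minimal sketch; the layer names are illustrative, and gamma_init is the parameter name in the TL 1.x BatchNormLayer signature):

import tensorflow as tf
import tensorlayer as tl

x = tf.placeholder(tf.float32, [None, 28, 28, 1])
net = tl.layers.InputLayer(x, name='input')
net = tl.layers.Conv2d(net, n_filter=32, filter_size=(5, 5), strides=(1, 1),
                       padding='SAME', name='conv1')
# Under TF 0.12 the initializer has to be instantiated, so pass it explicitly
# instead of relying on the old default:
net = tl.layers.BatchNormLayer(net, act=tf.nn.relu, is_train=True,
                               gamma_init=tf.ones_initializer(), name='bn1')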

Add compatibility with TF 0.12, needed because of the change in the ones_initializer API.
@zsdonghao zsdonghao merged commit 0f2ab83 into tensorlayer:master Dec 21, 2016
@wagamamaz
Collaborator

wagamamaz commented Jan 6, 2017

@boscotsang hi, are you using TL with TF12?

I found an interesting thing: when I use TF11, a BatchNormLayer only has 4 parameters, but it has 8 parameters when using TF12. Do you have any idea why that is?

Thank you in advance.

TF11 TL1.3

  param   0: (5, 5, 1, 32)      CNN/cnn_layer1/W_conv2d:0
  param   1: (32,)              CNN/batch1/beta:0
  param   2: (32,)              CNN/batch1/gamma:0
  param   3: (32,)              CNN/batch1/moving_mean:0
  param   4: (32,)              CNN/batch1/moving_variance:0
  param   5: (5, 5, 32, 64)     CNN/cnn_layer2/W_conv2d:0
  param   6: (64,)              CNN/cnn_layer2/b_conv2d:0
  param   7: (64,)              CNN/batch2/beta:0
  param   8: (64,)              CNN/batch2/gamma:0
  param   9: (64,)              CNN/batch2/moving_mean:0
  param  10: (64,)              CNN/batch2/moving_variance:0
  param  11: (3136, 256)        CNN/relu1/W:0
  param  12: (256,)             CNN/relu1/b:0
  param  13: (256, 10)          CNN/output_layer/W:0
  param  14: (10,)              CNN/output_layer/b:0

TF12 TL1.3

  param   0: (5, 5, 1, 32)      CNN/cnn_layer1/W_conv2d:0
  param   1: (32,)              CNN/batch1/beta:0
  param   2: (32,)              CNN/batch1/gamma:0
  param   3: (32,)              CNN/batch1/moving_mean:0
  param   4: (32,)              CNN/batch1/moving_variance:0
  param   5: (32,)              CNN/batch1/CNN/batch1/moving_mean/biased:0
  param   6: ()                 CNN/batch1/CNN/batch1/moving_mean/local_step:0
  param   7: (32,)              CNN/batch1/CNN/batch1/moving_variance/biased:0
  param   8: ()                 CNN/batch1/CNN/batch1/moving_variance/local_step:0
  param   9: (5, 5, 32, 64)     CNN/cnn_layer2/W_conv2d:0
  param  10: (64,)              CNN/cnn_layer2/b_conv2d:0
  param  11: (64,)              CNN/batch2/beta:0
  param  12: (64,)              CNN/batch2/gamma:0
  param  13: (64,)              CNN/batch2/moving_mean:0
  param  14: (64,)              CNN/batch2/moving_variance:0
  param  15: (64,)              CNN/batch2/CNN/batch2/moving_mean/biased:0
  param  16: ()                 CNN/batch2/CNN/batch2/moving_mean/local_step:0
  param  17: (64,)              CNN/batch2/CNN/batch2/moving_variance/biased:0
  param  18: ()                 CNN/batch2/CNN/batch2/moving_variance/local_step:0
  param  19: (3136, 256)        CNN/relu1/W:0
  param  20: (256,)             CNN/relu1/b:0
  param  21: (256, 10)          CNN/output_layer/W:0
  param  22: (10,)              CNN/output_layer/b:0
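
The extra entries are the .../biased and .../local_step variables inside each batch-norm scope. A quick way to see exactly what TF 0.12 adds is to list every variable living under a batch scope after building the network in the code below (a minimal sketch; tf.global_variables() exists from TF 0.12, on TF 0.11 use tf.all_variables()):

expected = ('beta', 'gamma', 'moving_mean', 'moving_variance')
for v in tf.global_variables():
    if '/batch' in v.name:
        # The last name component is beta/gamma/moving_mean/moving_variance for
        # the usual parameters; anything else (biased, local_step) is the extra
        # bookkeeping that only shows up under TF 0.12.
        leaf = v.name.split('/')[-1].split(':')[0]
        flag = '' if leaf in expected else '  <-- extra in TF 0.12'
        print(v.get_shape().as_list(), v.name, flag)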

Code

import time
import numpy as np
import tensorflow as tf
import tensorlayer as tl

X_train, y_train, X_val, y_val, X_test, y_test = \
                tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))

X_train = np.asarray(X_train, dtype=np.float32)
y_train = np.asarray(y_train, dtype=np.int64)
X_val = np.asarray(X_val, dtype=np.float32)
y_val = np.asarray(y_val, dtype=np.int64)
X_test = np.asarray(X_test, dtype=np.float32)
y_test = np.asarray(y_test, dtype=np.int64)

print('X_train.shape', X_train.shape)
print('y_train.shape', y_train.shape)
print('X_val.shape', X_val.shape)
print('y_val.shape', y_val.shape)
print('X_test.shape', X_test.shape)
print('y_test.shape', y_test.shape)
print('X %s   y %s' % (X_test.dtype, y_test.dtype))

sess = tf.InteractiveSession()

# Define the batch size at the beginning. Giving a fixed batch size in x and y_
# rather than 'None' allows TensorFlow to apply some optimizations,
# especially for convolutional layers.
batch_size = 128

x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1])   # [batch_size, height, width, channels]
y_ = tf.placeholder(tf.int64, shape=[batch_size,])

def inference(x, is_train, reuse=None):
    with tf.variable_scope("CNN", reuse=reuse):
        tl.layers.set_name_reuse(reuse)
        network = tl.layers.InputLayer(x, name='input_layer')
        network = tl.layers.Conv2d(network, n_filter=32, filter_size=(5, 5), strides=(1, 1),
                act=None, b_init=None, padding='SAME', name='cnn_layer1')
        network = tl.layers.BatchNormLayer(network, act=tf.nn.relu, is_train=is_train, name='batch1')

        network = tl.layers.MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
                padding='SAME', name='pool_layer1')
        network = tl.layers.Conv2d(network, n_filter=64, filter_size=(5, 5), strides=(1, 1),
                act=None, padding='SAME', name='cnn_layer2')
        network = tl.layers.BatchNormLayer(network, act=tf.nn.relu, is_train=is_train, name='batch2')

        network = tl.layers.MaxPool2d(network, filter_size=(2, 2), strides=(2, 2),
                padding='SAME', name='pool_layer2')
        ## end of conv
        network = tl.layers.FlattenLayer(network, name='flatten_layer')   # output: (?, 3136)
        network = tl.layers.DropoutLayer(network, keep=0.5, name='drop1') # output: (?, 3136)
        network = tl.layers.DenseLayer(network, n_units=256,
                                        act = tf.nn.relu, name='relu1')   # output: (?, 256)
        network = tl.layers.DropoutLayer(network, keep=0.5, name='drop2') # output: (?, 256)
        network = tl.layers.DenseLayer(network, n_units=10,
                                        act = tf.identity,
                                        name='output_layer')    # output: (?, 10)
    return network

network = inference(x, is_train=True, reuse=False)
y = network.outputs

ce = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=y_))
cost = ce

correct_prediction = tf.equal(tf.argmax(y, 1), y_)
acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# train
n_epoch = 200
learning_rate = 0.0001
print_freq = 1

train_params = network.all_params
train_op = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.999,
    epsilon=1e-08, use_locking=False).minimize(cost, var_list=train_params)

tl.layers.initialize_global_variables(sess)
network.print_params(False)
network.print_layers()

print('   learning_rate: %f' % learning_rate)
print('   batch_size: %d' % batch_size)

for epoch in range(n_epoch):
    start_time = time.time()
    for X_train_a, y_train_a in tl.iterate.minibatches(
                                X_train, y_train, batch_size, shuffle=True):
        feed_dict = {x: X_train_a, y_: y_train_a}
        feed_dict.update( network.all_drop )        # enable noise layers
        sess.run(train_op, feed_dict=feed_dict)

    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
        print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
        train_loss, train_acc, n_batch = 0, 0, 0
        for X_train_a, y_train_a in tl.iterate.minibatches(
                                X_train, y_train, batch_size, shuffle=True):
            dp_dict = tl.utils.dict_to_one( network.all_drop )    # disable noise layers
            feed_dict = {x: X_train_a, y_: y_train_a}
            feed_dict.update(dp_dict)
            err, ac = sess.run([cost, acc], feed_dict=feed_dict)
            train_loss += err; train_acc += ac; n_batch += 1
        print("   train loss: %f" % (train_loss/ n_batch))
        print("   train acc: %f" % (train_acc/ n_batch))
        val_loss, val_acc, n_batch = 0, 0, 0
        for X_val_a, y_val_a in tl.iterate.minibatches(
                                    X_val, y_val, batch_size, shuffle=True):
            dp_dict = tl.utils.dict_to_one( network.all_drop )    # disable noise layers
            feed_dict = {x: X_val_a, y_: y_val_a}
            feed_dict.update(dp_dict)
            err, ac = sess.run([cost, acc], feed_dict=feed_dict)
            val_loss += err; val_acc += ac; n_batch += 1
        print("   val loss: %f" % (val_loss/ n_batch))
        print("   val acc: %f" % (val_acc/ n_batch))

print('Evaluation')
test_loss, test_acc, n_batch = 0, 0, 0
for X_test_a, y_test_a in tl.iterate.minibatches(
                            X_test, y_test, batch_size, shuffle=True):
    dp_dict = tl.utils.dict_to_one( network.all_drop )    # disable noise layers
    feed_dict = {x: X_test_a, y_: y_test_a}
    feed_dict.update(dp_dict)
    err, ac = sess.run([cost, acc], feed_dict=feed_dict)
    test_loss += err; test_acc += ac; n_batch += 1
print("   test loss: %f" % (test_loss/n_batch))
print("   test acc: %f" % (test_acc/n_batch))

@zsdonghao zsdonghao added the bug label Jan 6, 2017
@boscotsang
Contributor Author

Yes, I'm using TF 0.12, and I found that when I use BatchNormLayer and share variables between train and test as in your code, in my ResNet-164 on CIFAR-10 the training cost drops normally while the test cost barely changes. Did you have this issue?

@wagamamaz
Collaborator

@boscotsang Can you show your code?

The test accuracy increases in my case:

Epoch 1 of 200 took 133.758089s
   train loss: 0.711337
   train acc: 0.874659
   val loss: 0.662424
   val acc: 0.889323

Can you run your code again under TensorFlow 0.12, just to see whether your test accuracy increases?
Or try changing the variables list of BatchNormLayer to variables = [beta, gamma, moving_mean, moving_variance]?

@boscotsang
Contributor Author

@wagamamaz The following is my code. Images are read through the TensorFlow input pipeline. The data is the CIFAR-10 binary set, placed under the dataset directory (datasets/cifar10_data/cifar-10-batches-bin, per the data_dir flag below).

# cifar10_input.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

IMAGE_SIZE = 32

# Global constants describing the CIFAR-10 data set.
NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000


def read_cifar10(filename_queue):
    class CIFAR10Record(object):
        pass

    result = CIFAR10Record()

    label_bytes = 1
    result.height = 32
    result.width = 32
    result.depth = 3
    image_bytes = result.height * result.width * result.depth
    record_bytes = label_bytes + image_bytes
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    result.key, value = reader.read(filename_queue)
    record_bytes = tf.decode_raw(value, tf.uint8)
    result.label = tf.cast(
        tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
    depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]),
                             [result.depth, result.height, result.width])
    result.uint8image = tf.pad(tf.transpose(depth_major, [1, 2, 0]), [[1, 1], [1, 1], [0, 0]])
    return result


def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
    num_preprocess_threads = 24
    if shuffle:
        images, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples + 3 * batch_size,
            min_after_dequeue=min_queue_examples)
    else:
        images, label_batch = tf.train.batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples + 3 * batch_size)
    return images, tf.reshape(label_batch, [batch_size])


def distorted_inputs(data_dir, batch_size):
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in xrange(1, 6)]
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    filename_queue = tf.train.string_input_producer(filenames)
    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)
    height = IMAGE_SIZE
    width = IMAGE_SIZE
    distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    distorted_image = tf.image.random_brightness(distorted_image,
                                                 max_delta=63)
    distorted_image = tf.image.random_contrast(distorted_image,
                                               lower=0.2, upper=1.8)
    float_image = tf.image.per_image_standardization(distorted_image)
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                             min_fraction_of_examples_in_queue)
    print('Filling queue with %d CIFAR images before starting to train. '
          'This will take a few minutes.' % min_queue_examples)
    return _generate_image_and_label_batch(float_image, read_input.label,
                                           min_queue_examples, batch_size,
                                           shuffle=True)


def inputs(eval_data, data_dir, batch_size):
    if not eval_data:
        filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                     for i in xrange(1, 6)]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
    else:
        filenames = [os.path.join(data_dir, 'test_batch.bin')]
        num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    filename_queue = tf.train.string_input_producer(filenames)
    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)
    height = IMAGE_SIZE
    width = IMAGE_SIZE
    resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                           width, height)
    float_image = tf.image.per_image_standardization(resized_image)
    # float_image = tf.image.per_image_standardization(reshaped_image)
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(num_examples_per_epoch *
                             min_fraction_of_examples_in_queue)
    return _generate_image_and_label_batch(float_image, read_input.label,
                                           min_queue_examples, batch_size,
                                           shuffle=False)

# cifar10_resnet.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import tensorflow as tf
import tensorlayer as tl
import cifar10_input

FLAGS = tf.app.flags.FLAGS

# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', 'datasets/cifar10_data',
                           """Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
                            """Train the model using fp16.""")

IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

MOVING_AVERAGE_DECAY = 0.9999  # The decay to use for the moving average.
NUM_EPOCHS_PER_DECAY = 350.0  # Epochs after which learning rate decays.


# LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.


# INITIAL_LEARNING_RATE = 0.05  # Initial learning rate.


def distorted_inputs():
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                    batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels


def inputs(eval_data):
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.inputs(eval_data=eval_data,
                                          data_dir=data_dir,
                                          batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels

def inference(images, reuse=False, is_train=True):
    def cblock(x, in_filter, nb_filter, stride, stage):
        with tf.variable_scope("CBlock{}".format(stage), reuse=reuse) as scope:
            shortcut = x
            x = tl.layers.BatchNormLayer(x, act=tf.nn.relu, is_train=is_train, name=scope.name + "BN1")
            # x = tl.layers.PReluLayer(x, name=scope.name + "_PRELU1")
            _shortcut = x
            x = tl.layers.Conv2d(x, nb_filter // 4, (1, 1), (stride, stride), padding='SAME',
                                 W_init=tf.contrib.layers.variance_scaling_initializer(
                                     factor=2.0, mode='FAN_IN', uniform=False, seed=None, dtype=tf.float32),
                                 name=scope.name + "CONV1")
            x = tl.layers.BatchNormLayer(x, act=tf.nn.relu, is_train=is_train, name=scope.name + "BN2")
            # x = tl.layers.PReluLayer(x, name=scope.name + "_PRELU2")
            x = tl.layers.Conv2d(x, nb_filter // 4, (3, 3), (1, 1),
                                 padding="SAME",
                                 W_init=tf.contrib.layers.variance_scaling_initializer(
                                     factor=2.0, mode='FAN_IN', uniform=False, seed=None, dtype=tf.float32),
                                 name=scope.name + "CONV2")
            x = tl.layers.BatchNormLayer(x, act=tf.nn.relu, is_train=is_train, name=scope.name + "BN3")
            # x = tl.layers.PReluLayer(x, name=scope.name + "_PRELU3")
            x = tl.layers.Conv2d(x, nb_filter, (1, 1), (1, 1),
                                 padding="SAME",
                                 W_init=tf.contrib.layers.variance_scaling_initializer(
                                     factor=2.0, mode='FAN_IN', uniform=False, seed=None, dtype=tf.float32),
                                 name=scope.name + "CONV3")
            if nb_filter != in_filter:
                shortcut = tl.layers.Conv2d(_shortcut, nb_filter, (1, 1), (stride, stride), padding='VALID',
                                            W_init=tf.contrib.layers.variance_scaling_initializer(
                                                factor=2.0, mode='FAN_IN', uniform=False, seed=None, dtype=tf.float32),
                                            name=scope.name + "IDENTITY")
            out = tl.layers.ElementwiseLayer((x, shortcut), tf.add, name=scope.name + "_Add")
            return out

    with tf.variable_scope("Net", reuse=reuse) as scope:
        tl.layers.set_name_reuse(reuse)
        x = tl.layers.InputLayer(images, name=scope.name + "_INPUT")
        x = tl.layers.BatchNormLayer(x, act=tf.nn.relu, is_train=is_train, name=scope.name + "_BN")
        # x = tl.layers.PReluLayer(x, name=scope.name + "_CONVOUT")
        x = tl.layers.Conv2d(x, 16, (3, 3), (1, 1),
                             padding='SAME',
                             W_init=tf.contrib.layers.variance_scaling_initializer(
                                 factor=2.0, mode='FAN_IN', uniform=False, seed=None, dtype=tf.float32),
                             name=scope.name + "_CONV")
        start, n = 0, 18
        for i in range(start, start + n):
            x = cblock(x, 16, 64, 1, i)
        start += n
        for i in range(start, start + n):
            if i == start:
                x = cblock(x, 64, 128, 2, i)
            else:
                x = cblock(x, 64, 128, 1, i)
        start += n
        for i in range(start, start + n):
            if i == start:
                x = cblock(x, 128, 256, 2, i)
            else:
                x = cblock(x, 128, 256, 1, i)
        x = tl.layers.BatchNormLayer(x, act=tf.nn.relu, is_train=is_train, name=scope.name + "_OUTBN")
        # x = tl.layers.PReluLayer(x, name=scope.name + "_OUTPRELU")
        pool = tl.layers.MeanPool2d(x, (8, 8), (1, 1), padding='VALID')
        out = tl.layers.FlattenLayer(pool, name=scope.name + "_Flatten")
        fc = tl.layers.DenseLayer(out, 10)
        return fc


def train(total_loss, global_step):
    num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
    # decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
    boundaries = (2000, 30000, 80000)
    values = (0.01, 0.1, 0.01, 0.001)
    lr = tf.train.piecewise_constant(global_step, boundaries, values)
    tf.summary.scalar('learning_rate', lr)
    opt = tf.train.AdamOptimizer(lr)
    grads = opt.compute_gradients(total_loss)
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
    for var in tf.trainable_variables():
        tf.summary.histogram(var.op.name, var)
    for grad, var in grads:
        if grad is not None:
            tf.summary.histogram(var.op.name + '/gradients', grad)
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())
    with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
        train_op = tf.no_op(name='train')
    return train_op


def correct(logits, labels):
    return tf.nn.in_top_k(logits.outputs, labels, 1)


def loss(logits, labels, l2=0.0001):
    l2_v = tf.Variable(l2, trainable=False, dtype=tf.float32)
    l2_loss = tf.add_n([tf.nn.l2_loss(x) for x in logits.all_params if x.name.endswith("W_conv2d:0")])
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits.outputs, labels)) + l2_v*l2_loss
# training script (separate file; imports cifar10_resnet)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import os.path
import time

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import cifar10_resnet

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', 'C:/Users/zgj/.keras/datasets/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 100000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")


def train():
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False)
        images, labels = cifar10_resnet.distorted_inputs()
        images_eval, labels_eval = cifar10_resnet.inputs(True)

        logits = cifar10_resnet.inference(images, False, True)
        logits_eval = cifar10_resnet.inference(images_eval, True, False)
        correct = cifar10_resnet.correct(logits_eval, labels_eval)

        # Calculate loss.
        loss = cifar10_resnet.loss(logits, labels)
        loss_eval = cifar10_resnet.loss(logits_eval, labels_eval)
        train_op = cifar10_resnet.train(loss, global_step)
        saver = tf.train.Saver(tf.all_variables())
        summary_op = tf.summary.merge_all()

        init = tf.global_variables_initializer()

        sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))
        sess.run(init)
        tf.train.start_queue_runners(sess=sess)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
        correct_cnt = 0.
        for step in xrange(FLAGS.max_steps):
            start_time = time.time()
            _, loss_value, loss_eval_value, correct_val = sess.run([train_op, loss, loss_eval, correct])
            # _, loss_value = sess.run([train_op, loss])
            correct_cnt += np.sum(correct_val)
            duration = time.time() - start_time
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
            if step % 100 == 0:
                sec_per_batch = float(duration)

                format_str = ('%s: step %d, loss = %.4f, eval_loss = %.4f, eval_accuracy = %.4f (%.3f '
                              'sec/batch)')
                print(format_str % (datetime.now(), step, loss_value, loss_eval_value,
                                    correct_cnt / (100 * FLAGS.batch_size), sec_per_batch))
                summary_str = sess.run(summary_op)
                summary_writer.add_summary(summary_str, step)
                correct_cnt = 0
            if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

def main(argv=None):  # pylint: disable=unused-argument
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()


if __name__ == '__main__':
    tf.app.run()

@wagamamaz
Copy link
Collaborator

wagamamaz commented Jan 7, 2017

@boscotsang To evaluate the performance, you need an inference graph built with is_train=False.

e.g.

network = inference(x, is_train=True, reuse=False)
network_test = inference(x, is_train=False, reuse=True)

Do not use the training network to evaluate the performance.
Also, @zsdonghao just updated BatchNormLayer for TF12; it works in my case, please try yours.
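
In terms of the MNIST script above, that means building the shared-variable test network once and running the validation/test loops on its outputs. A minimal sketch:

# Build both graphs once; the second call reuses the same variables.
network = inference(x, is_train=True, reuse=False)
network_test = inference(x, is_train=False, reuse=True)

y = network.outputs            # used for the training op
y_test = network_test.outputs  # used only for reporting metrics

cost = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=y_))
cost_test = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y_test, labels=y_))
acc_test = tf.reduce_mean(
    tf.cast(tf.equal(tf.argmax(y_test, 1), y_), tf.float32))

# Minimize cost as before, but feed the validation/test minibatches through
# cost_test / acc_test when printing loss and accuracy.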
