In [4]:
import numpy as np
import tensorflow as tf
import os
from PIL import Image

# Number GPUs by PCI bus id and make only device "0" visible to TensorFlow.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [5]:
# --- Input geometry --------------------------------------------------------
# Each TFRecord image is decoded into an (N_H, N_W, N_C) uint8 array.
N_H, N_W, N_C = 64, 64, 3
N_CLASS = 369                 # one-hot label depth == number of output logits

# --- Optimisation ----------------------------------------------------------
BATCH_SIZE = 64
learn_rate = 1e-4             # learning rate passed to the optimizer
pool_size = (2, 2)            # pooling window (height, width) used by Model

# --- Dataset sizes and schedule --------------------------------------------
TRAIN_LEN = 1514097           # presumably the number of training records
TEST_LEN = 1000               # also used as the validation batch size
MAX_STEP = 100000             # total number of optimisation steps
# NOTE: despite its name this is the number of *steps per epoch*, not an epoch count.
EPOCHS = TRAIN_LEN // BATCH_SIZE

# --- Paths -----------------------------------------------------------------
tfrecords_dir = './data/tfrecords'   # directory holding the .tfrecords files
log_dir = './logs'                   # directory where model checkpoints are saved

In [6]:
def read_and_decode(filename, batch_size, shuffle):
    """Build a TF1 queue-based input pipeline that reads one TFRecord file.

    Each record must contain an int64 'label' and a raw-bytes 'img_raw' that
    reshapes to (N_H, N_W, N_C) uint8. The file-name producer is created
    without `num_epochs`, so the pipeline cycles over the file indefinitely.

    Args:
        filename: path to the .tfrecords file.
        batch_size: number of examples per dequeued batch.
        shuffle: if True, use tf.train.shuffle_batch; otherwise tf.train.batch.

    Returns:
        (images, labels): float32 images of shape [batch_size, N_H, N_W, N_C]
        with raw pixel values (0-255, no normalisation), and int32 one-hot
        labels of shape [batch_size, N_CLASS].
    """
    file_queue = tf.train.string_input_producer([filename])
    _, record = tf.TFRecordReader().read(file_queue)

    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(record, features=feature_spec)

    label = tf.cast(parsed['label'], tf.int32)
    # raw bytes -> (H, W, C) uint8 -> float32; normalisation is intentionally
    # left out (alternatives tried before: (img - 128) / 128, per-image standardisation).
    image = tf.cast(
        tf.reshape(tf.decode_raw(parsed['img_raw'], tf.uint8), [N_H, N_W, N_C]),
        tf.float32)

    queue_kwargs = dict(batch_size=batch_size, capacity=20000)
    if shuffle:
        image_batch, labels = tf.train.shuffle_batch(
            [image, label], min_after_dequeue=3000, **queue_kwargs)
    else:
        image_batch, labels = tf.train.batch([image, label], **queue_kwargs)

    one_hot = tf.cast(tf.one_hot(labels, depth=N_CLASS), dtype=tf.int32)
    return image_batch, tf.reshape(one_hot, [batch_size, N_CLASS])

In [7]:
# Network definitions: several architecture variants built from the shared layer helpers below.
class Model:
    """Builder for TF1 graph-mode image classifiers (Resnet_18 / Resnet_50 variants).

    NOTE: layers are created via tf.layers.* and tf.Variable without explicit
    names, so TensorFlow auto-numbers them in creation order (conv2d, conv2d_1,
    batch_normalization_1, Variable_1, ...). Checkpoints written by tf.train.Saver
    are keyed by those names, so the call order inside these methods must not be
    changed if existing checkpoints need to be restored.
    """

    def __init__(self, class_num, pool_size, is_train):
        """Store network hyper-parameters; no graph ops are created here.

        Args:
            class_num: number of logits produced by the final fc layer.
            pool_size: (height, width) window used by maxpool/avgpool.
            is_train: forwarded as `training=` to tf.layers.batch_normalization
                (the training cell in this notebook passes a tf.bool placeholder).
        """
        self.class_num = class_num
        self.pool_size = pool_size
        self.is_train = is_train

    def conv(self, x_tensor, conv_num_outputs, conv_ksize=3, conv_strides=1, conv_padding='SAME', name=None):
        """2-D convolution (3x3, stride 1, SAME by default) with Xavier-initialised kernels.

        No activation is passed, so the output is linear; callers add batch
        norm / ReLU themselves. `name`, if given, names the layer.
        """
        res = tf.layers.conv2d(x_tensor, conv_num_outputs, conv_ksize, strides=conv_strides, padding=conv_padding,
                               kernel_initializer=tf.contrib.layers.xavier_initializer(), name=name)
        return res

    def maxpool(self, x_tensor, pool_strides=(2, 2)):
        """Max-pool with a `self.pool_size` window, `pool_strides` stride and SAME padding.

        ksize/strides are laid out as [1, h, w, 1], i.e. an NHWC input is assumed.
        """
        res = tf.nn.max_pool(x_tensor, ksize=[1, self.pool_size[0], self.pool_size[1], 1],
                             strides=[1, pool_strides[0], pool_strides[1], 1], padding='SAME')
        return res

    def avgpool(self, x_tensor, pool_strides=(2, 2)):
        """Average-pool counterpart of `maxpool` (same window, stride and padding rules)."""
        res = tf.nn.avg_pool(x_tensor, ksize=[1, self.pool_size[0], self.pool_size[1], 1],
                             strides=[1, pool_strides[0], pool_strides[1], 1], padding='SAME')
        return res

    def fc(self, x_tensor, num_outputs, active=None, name=None):
        """Fully connected layer computing x @ W + b, optionally followed by ReLU.

        W is drawn from N(0, stddev = fan_in ** -0.5) and b starts at zero.
        Expects a 2-D (batch, features) tensor whose last dimension is known
        statically (`.value` must not be None).

        Args:
            x_tensor: input tensor.
            num_outputs: number of output units.
            active: 'relu' applies ReLU; any other value leaves the output linear.
            name: name of the pre-activation add op. NOTE: when active == 'relu'
                the named tensor is the value *before* the ReLU.

        Returns:
            Tensor of shape (batch, num_outputs).
        """
        # fan-in based scaling for the random-normal weight initialisation
        std_dev = x_tensor.shape[-1].value ** -0.5
        weight = tf.Variable(tf.random_normal([x_tensor.shape[-1].value, num_outputs], stddev=std_dev))
        bias = tf.Variable(tf.zeros([num_outputs]))
        res = tf.add(tf.matmul(x_tensor, weight), bias, name=name)
        if active == 'relu':
            res = tf.nn.relu(res)
        return res

    def conv_with_batch_norm(self, X, size):
        """`conv` (3x3, `size` filters) followed by batch norm; no activation."""
        net = self.conv(X, size)
        net = tf.layers.batch_normalization(net, training = self.is_train)
        return net

    def conv_relu(self, X, size):
        """conv -> batch norm -> ReLU. Not used by the networks defined in this class."""
        net = self.conv(X, size)
        net = tf.layers.batch_normalization(net, training=self.is_train)
        net = tf.nn.relu(net)
        return net

    def basic_residual_block(self, X, size):
        """Two-conv residual unit: 1x1-projection shortcut + (conv-BN, ReLU, conv-BN).

        No activation is applied after the addition; callers apply ReLU.
        """
        # 1x1 projection so the shortcut has `size` channels, matching the main path
        residual = tf.layers.conv2d(X, size, kernel_size = 1, padding='SAME')
        net = self.conv_with_batch_norm(X, size)
        net = tf.nn.relu(net)
        net = self.conv_with_batch_norm(net, size)
        return residual + net

    def basic_residual_block_3(self, X, size):
        """Three-conv (bottleneck-like) residual unit with a 1x1-projection shortcut.

        Main path: 1x1 conv -> ReLU -> 3x3 conv + BN -> 1x1 conv -> ReLU.
        Only the middle conv is batch-normalised, and every conv keeps `size`
        channels (there is no channel reduction/expansion).
        """
        residual = tf.layers.conv2d(X, size, kernel_size=1, padding='SAME')

        net = tf.layers.conv2d(X, size, kernel_size=1, padding='SAME')
        net = tf.nn.relu(net)
        net = self.conv_with_batch_norm(net, size)
        net = tf.layers.conv2d(net, size, kernel_size=1, padding='SAME')
        net = tf.nn.relu(net)
        return residual + net

    def residual_block(self, X, size, is_reduce=True):
        """Stage of two `basic_residual_block`s, optionally halved spatially by maxpool, then ReLU."""
        net = self.basic_residual_block(X, size)
        net = tf.nn.relu(net)
        net = self.basic_residual_block(net, size)
        if is_reduce:
            net = self.maxpool(net)
        net = tf.nn.relu(net)
        return net

    def residual_block_3(self, X, size, is_reduce=True):
        """Stage of two `basic_residual_block_3`s, optionally halved spatially by maxpool, then ReLU."""
        net = self.basic_residual_block_3(X, size)
        net = tf.nn.relu(net)
        net = self.basic_residual_block_3(net, size)
        if is_reduce:
            net = self.maxpool(net)
        net = tf.nn.relu(net)
        return net


    def Resnet_18(self, input_op):
        """Small ResNet-style classifier; returns unscaled logits of shape (batch, class_num).

        Layout: stem conv(64) + maxpool, residual stages 64 and 128 (pooled),
        256 (not pooled), avgpool, flatten, fc(768, ReLU), fc(class_num).
        With pool_size=(2, 2) and the 64x64 inputs configured in this notebook,
        the spatial size goes 64 -> 32 -> 16 -> 8 -> 8 -> 4 before flattening.
        """
        # NOTE(review): "input_node" names the stem conv layer, not the input
        # placeholder; presumably used to locate the graph entry on export — confirm.
        net = self.conv(input_op, 64, name="input_node")
        net = self.maxpool(net)
        net = tf.nn.relu(net)
        net = self.residual_block(net, 64)
        net = self.residual_block(net, 128)
        # disabled alternative: a pooled 256-channel stage
#         net = self.residual_block(net, 256)
        net = self.residual_block(net, 256, is_reduce=False)
        net = self.avgpool(net)
        net = tf.contrib.layers.flatten(net)
        net = self.fc(net, 768, active = 'relu')
        logits = self.fc(net, self.class_num, name="output_node")

        return logits

    def Resnet_50(self, input_op):
        """Deeper variant built from `residual_block_3` stages; returns logits (batch, class_num).

        Despite the name it is not the canonical ResNet-50. Layout: stem conv(64)
        + maxpool, stages 64/128/256 (pooled) and 512 (not pooled), avgpool,
        flatten, fc(512, ReLU), fc(class_num). For 64x64 inputs with
        pool_size=(2, 2) the final feature map before avgpool is 4x4.
        """
        net = self.conv(input_op, 64, name="input_node")
        net = self.maxpool(net)
        net = tf.nn.relu(net)
        net = self.residual_block_3(net, 64)
        net = self.residual_block_3(net, 128)
        net = self.residual_block_3(net, 256)
        net = self.residual_block_3(net, 512, is_reduce=False)
        net = self.avgpool(net)
        net = tf.contrib.layers.flatten(net)
        net = self.fc(net, 512, active = 'relu')
        logits = self.fc(net, self.class_num, name="output_node")

        return logits


In [None]:
if __name__ == '__main__':
    # --- Input pipelines (TF1 queue runners) ---------------------------------
    tra_image_batch, tra_label_batch = read_and_decode(filename=os.path.join(tfrecords_dir, 'train.tfrecords'),
                                                       batch_size=BATCH_SIZE,
                                                       shuffle=True)
    # Each validation pass draws one shuffled batch of TEST_LEN records from the
    # test file (the producer repeats, so records may recur across/within batches).
    val_image_batch, val_label_batch = read_and_decode(filename=os.path.join(tfrecords_dir, 'test.tfrecords'),
                                                       batch_size=TEST_LEN,
                                                       shuffle=True)

    # --- Graph -----------------------------------------------------------------
    X = tf.placeholder(tf.float32, [None, N_H, N_W, N_C])
    Y = tf.placeholder(tf.float32, [None, N_CLASS])
    # Batch-norm mode switch: True while training, False for evaluation.
    Z = tf.placeholder(tf.bool, name='training')

    # Choose the neural network: Resnet_18 or Resnet_50.
    model = Model(N_CLASS, pool_size, Z)
    logits = model.Resnet_18(X)

    cost = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))

    # tf.layers.batch_normalization only *registers* its moving mean/variance
    # updates in UPDATE_OPS; unless the train op depends on them they never run,
    # and inference-mode batch norm would keep using the initial statistics.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.RMSPropOptimizer(learn_rate).minimize(cost)  # alternative: AdamOptimizer

    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(Y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    saver = tf.train.Saver(max_to_keep=10)
    # Saver.save refuses to write into a directory that does not exist yet.
    os.makedirs(log_dir, exist_ok=True)
    checkpoint_path = os.path.join(log_dir, 'model.ckpt')

    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # --- Training loop -----------------------------------------------------------
    # `with` guarantees log.txt is closed; `finally` guarantees the queue-runner
    # threads are stopped and the session released, also on KeyboardInterrupt.
    with open('./log.txt', 'w') as log:
        try:
            for step in range(MAX_STEP):
                if coord.should_stop():
                    break

                tra_images, tra_labels = sess.run([tra_image_batch, tra_label_batch])
                _, tra_loss = sess.run([optimizer, cost],
                                       feed_dict={X: tra_images, Y: tra_labels, Z: True})

                if step % 10 == 0 or (step + 1) == MAX_STEP:
                    val_images, val_labels = sess.run([val_image_batch, val_label_batch])
                    # Inference mode: batch norm normalises with the moving statistics
                    # accumulated during training, not the validation batch's own.
                    # (Early val_loss values may be high until those averages warm up.)
                    val_loss = sess.run(cost, feed_dict={X: val_images,
                                                         Y: val_labels,
                                                         Z: False})
                    message = 'Step: %d, tra_loss: %.8f, val_loss: %.8f' % (step, tra_loss, val_loss)
                    print(message)
                    print(message, file=log)

                # EPOCHS holds the number of steps per epoch, so this checkpoints once
                # per epoch (including step 0) and at the final step.
                if step % EPOCHS == 0 or (step + 1) == MAX_STEP:
                    saver.save(sess, checkpoint_path, global_step=step)

        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
            try:
                coord.join(threads)
            finally:
                sess.close()