In [2]:
##数据预处理
# 这是一个花朵识别的demo，数据集包含了5种花朵，数据都是脏数据 需要预处理
# 这里用到的数据集:http://download.tensorflow.org/example_images/flower_photos.tgz
# 注意这个数据集处理需要大概16G+的内存 否则会出现MemError，这一问题可通过减少图片数量来解决

import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

#定义输入文件夹和数据存储文件名
INPUT_DATA = '/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos'
OUTPUT_FILE = '/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_processed_data.npy'

#设定验证集和测试集的百分比
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10

def create_image_list(sess, testing_percentage, validation_percentage):

    #列出输入文件夹下的所有子文件夹，此时sub_dirs里面除了有子文件夹还有它自身，在第一个
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    #设置一个bool值，指定第一次循环的时候跳过母文件夹
    is_root_dir = True
    #print(sub_dirs)

    #初始化数据矩阵
    training_images = []
    training_labels = []
    testing_images = []
    testing_labels = []
    validation_images = []
    validation_labels= []
    current_label = 0

    #分别处理每个子文件夹
    for sub_dir in sub_dirs:
        #跳过第一个值，即跳过母文件夹
        if is_root_dir:
            is_root_dir = False
            continue

        #获取子目录中的所有图片文件
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        #用列表记录所有图片文件
        file_list = []
        #获取此子目录的名字比如daisy
        dir_name = os.path.basename(sub_dir)
        #对此子目录中所有图片后缀的文件
        for extension in extensions:
            #获取每种图片的所有正则表达式
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            print(file_glob)
            #将所有符合正则表达式的文件名加入文件列表
            file_list.extend(glob.glob(file_glob))
        print(file_list)
        #如果没有文件跳出循环
        if not file_list:
            continue
        #print("processing ", dir_name)

        i = 0
        #对于每张图片
        for file_name in file_list:
            i+=1
            #打开图片文件
            #print("process num : ",i,"   processing", file_name, file=f)
            image_raw_data = gfile.FastGFile(file_name,'rb').read()
            #解码
            image = tf.image.decode_jpeg(image_raw_data)
            #如果图片格式不是float32则转为float32
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            #将图片源数据转为299*299
            image = tf.image.resize_images(image, [300,300])
            #得到此图片的数据
            image_value = sess.run(image)
            print(np.shape(image_value))

            #生成一个100以内的数
            chance = np.random.randint(100)
            #按概率随机分到三个数据集中
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)
            if i%100 == 0:
                print(i, "images processed.")
        #处理完此种品种就将标签+1
        current_label += 1

    #将训练数据和标签以同样的方式打乱
    state = np.random.get_state()
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)

    #返回所有数据
    return np.asarray([training_images, training_labels,
                       validation_images, validation_labels, testing_images, testing_labels])


def main():
    with tf.Session() as sess:
        processed_data = create_image_list(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
        #将数据存到文件中
        np.save(OUTPUT_FILE, processed_data)

if __name__ == "__main__":
    main()

/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/*.jpg
/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/*.jpeg
/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/*.JPG
/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/*.JPEG
['/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/410425647_4586667858.jpg', '/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/3292654244_4a220ab96f_m.jpg', '/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/3102535578_ec8c12a7b6_m.jpg', '/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/4713531680_1110a2fa07_n.jpg', '/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/2976723295_b16ab04231.jpg', '/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/5333437251_ce0aa6925d_n.jpg', '/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/2225411981_6638c3e988.jpg', '/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_photos/roses/387327

(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 

(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 

(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 

(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 

(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 300, 3)
(300, 

In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
# 加载通过slim定义好的resnet_v1模型
import tensorflow.contrib.slim.python.slim.nets.resnet_v1 as resnet_v1

# 数据文件
INPUT_DATA = "/home/se7ven/Desktop/ResNet-V1-50/datasets/flower_processed_data.npy"
# 保存训练好的模型
TRAIN_FILE = "train_dir/model"
# 提供的已经训练好的模型
CKPT_FILE = "/home/se7ven/Desktop/ResNet-V1-50/datasets/resnet_v1_50.ckpt"

# 定义训练所用参数
LEARNING_RATE = 0.0001
STEPS = 500
BATCH = 32
N_CLASSES = 5

# 这里指出了不需要从训练好的模型中加载的参数，就是最后的自定义的全连接层
CHECKPOINT_EXCLUDE_SCOPES = 'Logits'
# 指定最后的全连接层为可训练的参数
TRAINABLE_SCOPES = 'Logits'


# 加载所有需要从训练好的模型加载的参数
def get_tuned_variables():
    ##不需要加载的范围
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
    # 初始化需要加载的参数
    variables_to_restore = []

    # 遍历模型中的所有参数
    for var in slim.get_model_variables():
        # 先指定为不需要移除
        excluded = False
        # 遍历exclusions，如果在exclusions中，就指定为需要移除
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        # 如果遍历完后还是不需要移除，就把参数加到列表里
        if not excluded:
            variables_to_restore.append(var)
    return variables_to_restore


# 获取所有需要训练的参数
def get_trainable_variables():
    # 同上
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
    variables_to_train = []
    # 枚举所有需要训练的参数的前缀，并找到这些前缀的所有参数
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train


def main():
    
    # 这里是因为numpy的版本更新了 如果不allow_pickle=True会报错，于是先打开，再改回来
    np_load_old = np.load
    # modify the default parameters of np.load
    np.load = lambda *a,**k: np_load_old(*a, allow_pickle=True, **k)
    
    # 加载数据
    processed_data = np.load(INPUT_DATA)
    
    # restore np.load for future normal usage
    np.load = np_load_old
    
    training_images = processed_data[0]
    n_training_example = len(training_images)
    training_labels = processed_data[1]
    validation_images = processed_data[2]
    validation_labels = processed_data[3]
    testing_images = processed_data[4]
    testing_labels = processed_data[5]

    print("there is %d training examples, %d validation examples, %d testing examples" %
          (n_training_example, len(validation_labels), len(testing_labels)))

    # 定义数据格式
    images = tf.placeholder(tf.float32, [None, 300, 300, 3], name='input_images')
    labels = tf.placeholder(tf.int64, [None], name='labels')

    # 定义模型，因为给出的只有参数，并没有模型，这里需要指定模型的具体结构
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        # logits就是最后预测值，images就是输入数据，指定num_classes=None是为了使resnet模型最后的输出层禁用
        logits, _ = resnet_v1.resnet_v1_50(images, num_classes=None)

    #自定义的输出层
    with tf.variable_scope("Logits"):
        #将原始模型的输出数据去掉维度为2和3的维度，最后只剩维度1的batch数和维度4的300*300*3
        #也就是将原来的二三四维度全部压缩到第四维度
        net = tf.squeeze(logits, axis=[1,2])
        #加入一层dropout层
        net = slim.dropout(net, keep_prob=0.5,scope='dropout_scope')
        #加入一层全连接层，指定最后输出大小
        logits = slim.fully_connected(net, num_outputs=N_CLASSES, scope='fc')


    # 获取需要训练的变量
    trainable_variables = get_trainable_variables()

    # 定义损失，模型定义的时候已经考虑了正则化了
    tf.losses.softmax_cross_entropy(tf.one_hot(labels, N_CLASSES), logits, weights=1.0)
    # 定义训练过程
    train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss())

    # 定义测试和验证过程
    with tf.name_scope('evaluation'):
        correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # 定义加载模型的函数，就是重新定义load_fn函数，从文件中获取参数，获取指定的变量，忽略缺省值
    load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE, get_tuned_variables(), ignore_missing_vars=True)

    # 定义保存新的训练好的模型的函数
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # 初始化没有加载进来的变量，一定要在模型加载之前，否则会将训练好的参数重新赋值
        init = tf.global_variables_initializer()
        sess.run(init)

        # 加载训练好的模型
        print('Loading tuned variables from %s' % CKPT_FILE)
        load_fn(sess)

        start = 0
        end = BATCH
        for i in range(STEPS):
            # 训练...
            sess.run(train_step, feed_dict={images: training_images[start:end],
                                            labels: training_labels[start:end]})
            # 间断地保存模型，并在验证集上验证
            if i % 50 == 0 or i + 1 == STEPS:
                saver.save(sess, TRAIN_FILE, global_step=i)
                validation_accuracy = sess.run(evaluation_step, feed_dict={images: validation_images,
                                                                           labels: validation_labels})
                print('Step %d: Validation accuracy = %.1f%%' % (
                    i, validation_accuracy * 100.0))

            # 更新起始和末尾,获取下一个batch
            start = end
            if start == n_training_example:
                start = 0
            end = start + BATCH
            if end > n_training_example:
                end = n_training_example

        # 训练完了在测试集上测试正确率
        testing_accuracy = sess.run(evaluation_step, feed_dict={images: testing_images,
                                                                labels: testing_labels})
        print('Final test accuracy = %.1f%%' % (testing_accuracy * 100))


if __name__ == '__main__':
    main()


there is 1215 training examples, 144 validation examples, 141 testing examples


W0829 23:02:37.941703 140606184159040 deprecation.py:323] From /home/se7ven/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/ops/losses/losses_impl.py:121: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0829 23:02:39.068042 140606184159040 deprecation.py:506] From /home/se7ven/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/training/rmsprop.py:119: calling Ones.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0829 23:02:43.354056 140606184159040 deprecation.py:323] From /home/se7ven/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensor

Loading tuned variables from /home/se7ven/Desktop/ResNet-V1-50/datasets/resnet_v1_50.ckpt
Step 0: Validation accuracy = 20.8%
Step 50: Validation accuracy = 27.1%
Step 100: Validation accuracy = 72.9%
Step 150: Validation accuracy = 91.7%
Step 200: Validation accuracy = 95.1%


W0829 23:04:58.831627 140606184159040 deprecation.py:323] From /home/se7ven/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/training/saver.py:960: remove_checkpoint (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to delete files with this prefix.


Step 250: Validation accuracy = 95.1%
Step 300: Validation accuracy = 93.1%
Step 350: Validation accuracy = 93.1%
Step 400: Validation accuracy = 91.0%
Step 450: Validation accuracy = 93.8%
Step 499: Validation accuracy = 91.7%
Final test accuracy = 92.9%
