In [1]:
import os
import tensorflow as tf
import numpy as np
import tensorflow.contrib.slim as slim

In [15]:
#数据路径
dataset_dir = './captcha/'

#测试数据集占比
num_test = 0.2
#批次大小
batch_size = 32
#周期大小
epochs = 100
#分类数
num_classes = 10
#学习率
lr = tf.Variable(0.001,dtype=tf.float32)
#是否为训练状态
is_training = tf.placeholder(tf.bool)


In [16]:
#获取所有验证码图片路径和标签
def get_filename_and_classes(dataset_dir):
    photo_filenames = []
    labels = []
    for filename in os.listdir(dataset_dir):
        #获取文件路径
        path = os.path.join(dataset_dir,filename)
        photo_filenames.append(path)
        label = filename[0:4]
        num_labels = []
        for i in range(4):
            num_labels.append(int(label[i]))
        labels.append(num_labels)
    return photo_filenames,labels


In [17]:
#获取图片路径和标签
photo_filenames,labels = get_filename_and_classes(dataset_dir)
photo_filenames = np.array(photo_filenames)
labels = np.array(labels)


In [18]:
#打乱数据
np.random.seed(10)
shuffle_indices = np.random.permutation(np.arange(len(photo_filenames)))
photo_filenames_shuffled = photo_filenames[shuffle_indices]
labels_shuffled = labels[shuffle_indices]

In [25]:
#切分数据集
test_sample_index = -1 * int(num_test * float(len(photo_filenames)))
x_train,x_test = photo_filenames_shuffled[:test_sample_index],photo_filenames_shuffled[test_sample_index:]
y_train,y_test = labels_shuffled[:test_sample_index],labels_shuffled[test_sample_index:]

In [26]:
#图像处理函数
def parse_function(filenames,labels = None):
    #读取图片并解码
    image = tf.read_file(filenames)
    image = tf.image.decode_jpeg(image,channels=3)
    #resize
    image = tf.image.resize_images(image,[224,224])
    
    image = tf.cast(image,tf.float32) / 255.0
    image = tf.subtract(image,0.5)
    image = tf.multiply(image,2.0)
    
    return image,labels

In [27]:
#定义placeholder
features_placeholder = tf.placeholder(photo_filenames_shuffled.dtype,[None])
labels_placeholder = tf.placeholder(labels_shuffled.dtype,[None,4])

#创建dataset对象()
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder,labels_placeholder))
#预处理图片
dataset = dataset.map(parse_function)
#训练周期
dataset = dataset.repeat(1)
#批次大小
dataset = dataset.batch(batch_size)

#初始化迭代器
iterator = dataset.make_initializable_iterator()

#获得一个批次数据和标签
data_batch,label_batch = iterator.get_next()

In [28]:
def alexnet(inputs,is_training=True):
    with slim.arg_scope([slim.conv2d,slim.fully_connected],
                       activation_fn = tf.nn.relu,
                       weights_initializer = tf.glorot_uniform_initializer(),
                       biases_initializer = tf.constant_initializer(0)):
        
        net = slim.conv2d(inputs,64,[11,11],4)
        net = slim.max_pool2d(net,[3,3])
        net = slim.conv2d(net,192,[5,5])
        net = slim.max_pool2d(net,[3,3])
        net = slim.conv2d(net,384,[3,3])
        net = slim.conv2d(net,384,[3,3])
        net = slim.conv2d(net,256,[3,3])
        net = slim.max_pool2d(net,[3,3])
        
        #数据扁平化
        net = slim.flatten(net)
        net = slim.fully_connected(net,1024)
        net = slim.dropout(net,is_training=is_training)
        
        #分为四项输出
        net0 = slim.fully_connected(net,num_classes,activation_fn=tf.nn.softmax)
        net1 = slim.fully_connected(net,num_classes,activation_fn=tf.nn.softmax)
        net2 = slim.fully_connected(net,num_classes,activation_fn=tf.nn.softmax)
        net3 = slim.fully_connected(net,num_classes,activation_fn=tf.nn.softmax)
        
    return net0,net1,net2,net3

In [29]:
with tf.Session() as sess:
    #传入数据得到结果
    logits0,logits1,logits2,logits3 = alexnet(data_batch,is_training)
    #定义loss
    #sparse_softmax_cross_entropyy：标签为整数
    # softmax_cross_entripy:标签为one-hot独热编码
    loss0 = tf.losses.sparse_softmax_cross_entropy(label_batch[:,0],logits0)
    loss1 = tf.losses.sparse_softmax_cross_entropy(label_batch[:,1],logits1)
    loss2 = tf.losses.sparse_softmax_cross_entropy(label_batch[:,2],logits2)
    loss3 = tf.losses.sparse_softmax_cross_entropy(label_batch[:,3],logits3)
    #计算总的loss
    total_loss = (loss0+loss1+loss2+loss3)/4.0
    #优化loss
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(total_loss)
    
    #计算准确率
    correct0 = tf.nn.in_top_k(logits0,label_batch[:,0],1)
    accuracy0 = tf.reduce_mean(tf.cast(correct0,tf.float32))
    correct1 = tf.nn.in_top_k(logits1,label_batch[:,1],1)
    accuracy1 = tf.reduce_mean(tf.cast(correct1,tf.float32))
    correct2 = tf.nn.in_top_k(logits2,label_batch[:,2],1)
    accuracy2 = tf.reduce_mean(tf.cast(correct2,tf.float32))
    correct3 = tf.nn.in_top_k(logits3,label_batch[:,3],1)
    accuracy3 = tf.reduce_mean(tf.cast(correct3,tf.float32))
    #总的准确率
    total_correct = tf.cast(correct0,tf.float32)*tf.cast(correct1,tf.float32)*tf.cast(correct2,tf.float32)*tf.cast(correct3,tf.float32)
    total_accuracy = tf.reduce_mean(tf.cast(total_correct,tf.float32))
    
    #所有变量初始化
    sess.run(tf.global_variables_initializer())
    #定义saver保存模型
    saver = tf.train.Saver()
    
    #训练一个epochs周期
    for i in range(epochs):
        if i%30 == 0:
            sess.run(tf.assign(lr,lr/3))
        #训练集传入迭代器
        sess.run(iterator.initializer,feed_dict={features_placeholder:x_train,
                                                labels_placeholder:y_train})
        
        #训练模型
        while True:
            try:
                sess.run(optimizer,feed_dict={is_training:True})
            except tf.errors.OutOfRangeError:
                #所有数据训练完成后跳出循环
                print('第%d个批次训练完成！'%i)
                break
        #测试集放入迭代器中
        sess.run(iterator.initializer,feed_dict={features_placeholder:x_test,
                                                labels_placeholder:y_test})
        #测试结果
        while True:
            try:
                #获得准确率和loss
                acc0,acc1,acc2,acc3,total_acc,result_loss = \
                    sess.run([accuracy0,accuracy1,accuracy2,accuracy3,total_accuracy,total_loss],feed_dict={is_training:False})

                #loss值统计
                tf.add_to_collection('sum_losses',result_loss)
                #准确率统计
                tf.add_to_collection('accuracy0',acc0)
                tf.add_to_collection('accuracy1',acc1)
                tf.add_to_collection('accuracy2',acc2)
                tf.add_to_collection('accuracy3',acc3)
                tf.add_to_collection('total_acc',total_acc)
            except tf.errors.OutOfRangeError:
                #loss值求平均
                avg_loss = sess.run(tf.reduce_mean(tf.get_collection('sum_losses')))
                #准确率求平均
                avg_acc0 = sess.run(tf.reduce_mean(tf.get_collection('accuracy0')))
                avg_acc1 = sess.run(tf.reduce_mean(tf.get_collection('accuracy1')))
                avg_acc2 = sess.run(tf.reduce_mean(tf.get_collection('accuracy2')))
                avg_acc3 = sess.run(tf.reduce_mean(tf.get_collection('accuracy3')))
                avg_total_acc = sess.run(tf.reduce_mean(tf.get_collection('total_acc')))
                print('%d:loss=%.3f acc0=%.3f acc1=%.3f acc2=%.3f acc3=%.3f total_acc=%.3f'%
                     (i,avg_loss,avg_acc0,avg_acc1,avg_acc2,avg_acc3,avg_total_acc))
                
                #清空loss统计
                temp = tf.get_collection_ref('sum_losses')
                del temp[:]
                
                #清空准确率统计
                temp = tf.get_collection_ref('accuracy0')
                del temp[:]
                temp = tf.get_collection_ref('accuracy1')
                del temp[:]
                temp = tf.get_collection_ref('accuracy2')
                del temp[:]
                temp = tf.get_collection_ref('accuracy3')
                del temp[:]
                temp = tf.get_collection_ref('total_acc')
                del temp[:]
                break
    #保存模型
    saver.save(sess,'models/model.ckpt',global_step=epochs)
            

第0个批次训练完成！
0:loss=2.303 acc0=0.100 acc1=0.095 acc2=0.095 acc3=0.089 total_acc=0.000
第1个批次训练完成！
1:loss=2.258 acc0=0.179 acc1=0.159 acc2=0.191 acc3=0.218 total_acc=0.000
第2个批次训练完成！
2:loss=2.199 acc0=0.270 acc1=0.213 acc2=0.220 acc3=0.295 total_acc=0.004
第3个批次训练完成！
3:loss=2.163 acc0=0.337 acc1=0.247 acc2=0.264 acc3=0.286 total_acc=0.005
第4个批次训练完成！
4:loss=2.129 acc0=0.374 acc1=0.313 acc2=0.302 acc3=0.317 total_acc=0.014
第5个批次训练完成！
5:loss=2.054 acc0=0.455 acc1=0.346 acc2=0.365 acc3=0.430 total_acc=0.018
第6个批次训练完成！
6:loss=2.007 acc0=0.531 acc1=0.406 acc2=0.407 acc3=0.462 total_acc=0.035
第7个批次训练完成！
7:loss=1.953 acc0=0.599 acc1=0.460 acc2=0.463 acc3=0.494 total_acc=0.058
第8个批次训练完成！
8:loss=1.914 acc0=0.636 acc1=0.513 acc2=0.501 acc3=0.541 total_acc=0.093
第9个批次训练完成！
9:loss=1.896 acc0=0.647 acc1=0.545 acc2=0.488 acc3=0.568 total_acc=0.108
第10个批次训练完成！
10:loss=1.855 acc0=0.686 acc1=0.544 acc2=0.538 acc3=0.651 total_acc=0.140
第11个批次训练完成！
11:loss=1.850 acc0=0.687 acc1=0.581 acc2=0.560 acc3=0.613 tota

KeyboardInterrupt: 