In [None]:
#将数据集转为tfrecord
import os

train_captcha_dir = './baidu/'

train_filenames = []   # full path of every captcha image
train_labels = []      # matching label, UTF-8 encoded bytes

# File names are expected to look like "<prefix>_<label>.<ext>"; the label is
# the second underscore-separated token, up to its first dot.
# sorted() gives a reproducible sample order (os.listdir order is arbitrary).
for filename in sorted(os.listdir(train_captcha_dir)):
    path = os.path.join(train_captcha_dir, filename)
    name_parts = filename.split('_')
    # Skip anything that cannot be a labelled image: hidden entries such as
    # ".ipynb_checkpoints" or ".DS_Store", names carrying no "_<label>" part,
    # and sub-directories (which would fail later when opened as a file).
    if filename.startswith('.') or len(name_parts) < 2 or not os.path.isfile(path):
        continue
    train_filenames.append(path)                                       # image path index
    train_labels.append(name_parts[1].split('.')[0].encode('utf-8'))   # multi-character label

In [None]:
import tensorflow as tf

# Serialise every (image, label) pair into a single TFRecord file.
tfrecord_file = './captcha.tfrecords'

with tf.io.TFRecordWriter(tfrecord_file) as writer:
    for filename, label in zip(train_filenames, train_labels):
        # Read the still-encoded image bytes; the context manager closes the
        # file handle right away instead of leaking one handle per image.
        with open(filename, 'rb') as image_file:
            image = image_file.read()
        feature = {
            # Both the image and the label are stored as single-element bytes lists.
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
            'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label])),
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())   # one serialized Example per record

In [None]:
# Load the dataset back from the TFRecord file.
tfrecord_file = './captcha.tfrecords'
raw_dataset = tf.data.TFRecordDataset(tfrecord_file)

# Parsing schema handed to the decoder: both features are scalar strings
# (the encoded image bytes and the label bytes).
feature_description = {
    key: tf.io.FixedLenFeature([], tf.string)
    for key in ('image', 'label')
}

def _parse_example(example_string):
    """Decode one serialized tf.train.Example into an (image, label) pair.

    The record is parsed against the module-level ``feature_description``.
    The 'image' bytes are decoded as a single-channel JPEG; the 'label'
    string is returned exactly as parsed.
    """
    parsed = tf.io.parse_single_example(example_string, feature_description)
    image = tf.io.decode_jpeg(parsed['image'], channels=1)
    return image, parsed['label']

# Run _parse_example over every record of raw_dataset, so each element of
# `dataset` is the (image, label) pair returned by that function.
dataset = raw_dataset.map(_parse_example)

In [None]:
# Materialise ONE shuffled batch (at most 200 samples) as the in-memory
# training set. Taking the first batch explicitly avoids decoding the whole
# dataset only to keep whatever batch happens to come last, which may be a
# small remainder batch.
first_batch = next(iter(dataset.shuffle(200).batch(200)), None)
if first_batch is None:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError('the captcha dataset is empty; check that the TFRecord file was written')
x_train, y_train = first_batch

In [None]:
import numpy as np

# Scale pixel values into [0, 1]. The dtype guard keeps this cell idempotent:
# running it again on the already-converted float tensor must not divide by
# 255 a second time.
if x_train.dtype != tf.float32:
    x_train = tf.cast(x_train, tf.float32) / 255.0

# Class alphabet: the position of a symbol in this string is its class index
# ('0'-'9' -> 0..9, 'a'-'z' -> 10..35).
_ALPHABET = '0123456789abcdefghijklmnopqrstuvwxyz'
_CLASS_OF = {symbol: index for index, symbol in enumerate(_ALPHABET)}
_LABEL_LEN = 4   # characters per captcha


def _encode_label(text):
    """Return the list of class indices for one captcha label string.

    Letters are lower-cased first, so 'A' and 'a' share a class; without this
    an uppercase letter would produce a negative index and therefore an
    all-zero one-hot row.

    Raises:
        ValueError: if the label does not have exactly ``_LABEL_LEN``
            characters or contains a symbol outside ``_ALPHABET``.
    """
    symbols = text.lower()
    if len(symbols) != _LABEL_LEN:
        raise ValueError('captcha label %r must have exactly %d characters' % (text, _LABEL_LEN))
    unknown = [symbol for symbol in symbols if symbol not in _CLASS_OF]
    if unknown:
        raise ValueError('captcha label %r contains unsupported characters %r' % (text, unknown))
    return [_CLASS_OF[symbol] for symbol in symbols]


# (batch, 4) matrix of class indices; reshape keeps the rank for an empty batch.
class_ids = np.array(
    [_encode_label(str(item.numpy(), 'utf-8')) for item in y_train],   # bytes -> str
    dtype=np.int64,
).reshape(-1, _LABEL_LEN)

# tf.one_hot appends the class axis: (4, batch) indices -> (4, batch, 36).
y_onehot = tf.cast(tf.one_hot(class_ids.T, len(_ALPHABET)), tf.int32)
# One (batch, 36) target array per character position, i.e. per model output.
y_label = [np.array(item) for item in y_onehot]

In [None]:
# 定义网络结构
from tensorflow import keras
from tensorflow.keras import layers

# Problem dimensions: classes per character, characters per captcha, and the
# input image size in pixels.
n_class = 36
n_len = 4
width = 80
height = 40

# Input node: single-channel image laid out as (height, width, channels).
img_inputs = keras.Input(shape=(height, width, 1), name='img')

# Convolutional trunk: five stages, each with two Conv-BatchNorm-ReLU units
# followed by one 2x2 max-pool. The filter count doubles per stage and is
# capped from the fourth stage on (32, 64, 128, 256, 256).
x = img_inputs
for i, n_cnn in enumerate([2] * 5):
    n_filters = 32 * 2 ** min(i, 3)
    for j in range(n_cnn):
        x = layers.Conv2D(n_filters, kernel_size=3, padding='same', kernel_initializer='he_uniform')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    x = layers.MaxPooling2D(2)(x)
x = layers.Flatten()(x)

# One softmax classification head per character position, named c1..c{n_len};
# every head reads the same flattened feature vector.
heads = []
for position in range(1, n_len + 1):
    heads.append(layers.Dense(n_class, activation='softmax', name='c%d' % position)(x))
x = heads

# Build the functional model from its input node and the list of heads.
model = keras.Model(inputs=img_inputs, outputs=x, name='captcha_multilabel_class')

model.summary()   # print the layer summary; the model is already built here

In [None]:
# 训练模型
from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
from tensorflow.keras.optimizers import *

# fit() below receives no validation data, so `val_loss` (the quantity both
# EarlyStopping and ModelCheckpoint monitor by default) is never reported.
# With the defaults, early stopping would never trigger and save_best_only
# would skip every save, so no checkpoint file would be written. Monitor the
# training loss instead; switch back to 'val_loss' if a validation set is added.
callbacks = [
    EarlyStopping(monitor='loss', patience=3),
    CSVLogger('./captcha.csv'),
    ModelCheckpoint('./captcha.h5', monitor='loss', save_best_only=True),
]

# One categorical cross-entropy loss (and accuracy metric) per output head.
model.compile(loss='categorical_crossentropy', optimizer=Adam(1e-3, amsgrad=True), metrics=['accuracy'])

# y_label is a list with one target array per output head.
model.fit(x_train, y_label, epochs=20, callbacks=callbacks, batch_size=10)

In [None]:
# 预测
import string
import cv2 as cv

# Class alphabet: index k of a prediction corresponds to characters[k].
characters = string.digits + string.ascii_lowercase


def decode(y):
    """Turn per-position class scores into the predicted captcha string.

    ``y`` is a sequence with one (batch, n_class) score array per character
    position. Only the first sample of the batch is decoded: for each
    position the highest-scoring class is looked up in ``characters``.
    """
    class_ids = np.array(y).argmax(axis=2)[:, 0]
    return ''.join(characters[class_id] for class_id in class_ids)

# Run the model on the first training sample, reshaped into a batch of one.
sample = x_train[0]
x_test = tf.reshape(sample, [1, 40, 80, 1])
y_pred = model.predict(x_test)

# Show the (40, 80, 1) image in an OpenCV window whose title is the decoded
# prediction; the window stays open until a key is pressed.
window_title = decode(y_pred)
cv.imshow(window_title, sample.numpy())
cv.waitKey(0)
cv.destroyAllWindows()