In [27]:
import tensorflow as tf
from pathlib import Path
import numpy as np
import random
import json
import time
import os

In [2]:
# Let tf.data choose parallelism / prefetch depth dynamically.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Project root: dataset/, output/ and logs/ are resolved below it. Set the
# CRNN_WORK_PATH environment variable to relocate the project instead of editing
# this cell. os.path.join(..., '') guarantees the trailing separator that every
# `work_path + '...'` string concatenation in this notebook relies on.
work_path = os.path.join(
    os.environ.get('CRNN_WORK_PATH', '/root/python_project/crnn_by_tensorflow2.2.0/'), '')
test_path = work_path + 'dataset/test/'

字符集

In [3]:
table_path = work_path + "dataset/table.txt"
json_path = work_path + "dataset/char.json"
# The charset contains non-ASCII (e.g. Chinese) characters: read and write it
# explicitly as UTF-8 so the result does not depend on the platform locale.
with open(json_path, 'r', encoding='utf-8') as f:
    chardic = json.load(f)
# One character per line; when the file is loaded with a LINE_NUMBER value index,
# each character's line number becomes its class id.
with open(table_path, 'w', encoding='utf-8') as fw:
    for char in chardic:
        fw.write(char + '\n')
# len(chardic) character ids plus 3 extra ids. The lookup table maps unknown
# characters to num_classes - 2, and CTC (blank_index=-1) uses num_classes - 1 as
# the blank. NOTE(review): id len(chardic) appears unused in this notebook; it is
# kept because saved models have a Dense layer of exactly this width.
num_classes = len(chardic) + 3
print('字符数：', num_classes)

字符数： 5531


In [4]:
# Character -> class id lookup built from table.txt: the key is the whole line (one
# character) and the value is its 0-based line number. Characters that are not in
# the file fall back to id num_classes - 2.
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.TextFileInitializer(
        filename=table_path,
        key_dtype=tf.string,
        key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
        value_dtype=tf.int64,
        value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
    ),
    default_value=num_classes - 2,
)

数据预处理方法

In [37]:
def preprocess_image(image):
    """Decode encoded JPEG bytes into a 3-channel image resized to 32 x 280.

    Args:
        image: string tensor with the raw JPEG file contents.

    Returns:
        The image resized (default bilinear) to height 32, width 280, 3 channels.
        Pixel values are not normalized.
    """
    rgb = tf.image.decode_jpeg(image, channels=3)
    return tf.image.resize(rgb, [32, 280])

def load_and_preprocess_image(path,label):
    """Read the file at `path`, preprocess it, and pass `label` through unchanged.

    Used with Dataset.map over (path, label) pairs.
    """
    return preprocess_image(tf.io.read_file(path)), label

def load_and_preprocess_image_pridict(path):
    """Label-free variant of load_and_preprocess_image for inference.

    The (misspelled) name is kept so existing callers keep working.
    """
    return preprocess_image(tf.io.read_file(path))

def decode_label(img,label):
    """Turn label strings into a SparseTensor of class ids; `img` passes through.

    In this notebook it is mapped after .batch(), so `label` holds a batch of
    UTF-8 strings. Each string is split into characters, and every character is
    mapped through `table` (unknown characters get the table's default id). The
    resulting ragged ids are converted to the sparse form that CTC expects.
    """
    char_ids = tf.ragged.map_flat_values(
        table.lookup, tf.strings.unicode_split(label, "UTF-8"))
    return img, char_ids.to_sparse()

图片路径列表及其对应标签列表

In [6]:
def get_image_path(dir_path):
    '''
    Collect (image path, label) pairs under `dir_path` and split them into train/val.

    Every `*.jpg` file that has a sibling `.txt` file with the same stem is used;
    the text file holds the transcription and is read as UTF-8. Samples whose
    stripped label is empty or 70+ characters long are skipped. After shuffling,
    each sample independently goes to the validation split with probability 1/80.

    Returns:
        (train_paths, train_labels, val_paths, val_labels): four parallel lists.
    '''
    images = []
    for root, _dirs, files in os.walk(dir_path):
        for file in files:
            stem, ext = os.path.splitext(file)
            # Exact extension match; a substring test also accepts e.g. "x.jpg.txt".
            if ext != '.jpg':
                continue
            file_path = os.path.join(root, file)
            # Derive the label path from the file name only, so a ".jpg" that
            # appears in a directory name is not rewritten as well.
            label_path = os.path.join(root, stem + '.txt')
            if not os.path.exists(label_path):
                continue
            # Labels contain non-ASCII characters; don't depend on the locale.
            with open(label_path, encoding='utf-8') as f:
                label = f.read().strip()
            if 0 < len(label) < 70:
                images.append((file_path, label))
    random.shuffle(images)
    train_all_image_paths = []
    train_all_image_labels = []
    val_all_image_paths = []
    val_all_image_labels = []
    for image, label in images:
        # randint(1, 80) == 5 holds with probability 1/80 -> validation sample.
        if random.randint(1, 80) == 5:
            val_all_image_paths.append(image)
            val_all_image_labels.append(label)
        else:
            train_all_image_paths.append(image)
            train_all_image_labels.append(label)
    return train_all_image_paths, train_all_image_labels, val_all_image_paths, val_all_image_labels

In [7]:
# Training and validation datasets: scan dataset/train/ once and split it.
(train_all_image_paths, train_all_image_labels,
 val_all_image_paths, val_all_image_labels) = get_image_path(work_path + 'dataset/train/')
# Sizes of the four lists (paths and labels are parallel, so pairs must match).
print(*map(len, (train_all_image_paths, train_all_image_labels,
                 val_all_image_paths, val_all_image_labels)))

7628393 7628393 96497 96497


训练：tf.data.Dataset

In [38]:
BATCH_SIZE = 256
buffer_size = 10000
train_images_num = len(train_all_image_paths)
train_steps_per_epoch = train_images_num//BATCH_SIZE
# Training pipeline: (path, label) strings -> decoded images -> batches with
# sparse label ids.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_all_image_paths, train_all_image_labels))
    # Shuffle the cheap string pairs *before* decoding. Shuffling decoded
    # 32x280x3 float32 images kept `buffer_size` of them (~1 GB) in memory.
    .shuffle(buffer_size=buffer_size)
    .repeat()
    .map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    # Skip individual unreadable/corrupt images here. Applied after .batch(), a
    # single bad image made the whole batch of BATCH_SIZE samples disappear.
    .apply(tf.data.experimental.ignore_errors())
    .batch(BATCH_SIZE)
    .map(decode_label, num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

验证：tf.data.Dataset

In [39]:
val_images_num = len(val_all_image_paths)
val_steps_per_epoch = val_images_num//BATCH_SIZE
# Validation pipeline: (path, label) strings -> decoded images -> batches with
# sparse label ids.
val_ds = (
    tf.data.Dataset.from_tensor_slices((val_all_image_paths, val_all_image_labels))
    # Shuffle the cheap string pairs *before* decoding instead of buffering
    # `buffer_size` decoded float32 images.
    .shuffle(buffer_size=buffer_size)
    .repeat()
    .map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    # Drop unreadable images one by one, not the whole batch containing them.
    .apply(tf.data.experimental.ignore_errors())
    .batch(BATCH_SIZE)
    .map(decode_label, num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

模型定义


In [40]:
# CRNN: a VGG-style CNN reduces the 32-pixel-high input to a 1-pixel-high feature
# map (32 -> 16 -> 8 -> 4 -> 2 by the pools, -> 1 by the final 2x2 'valid' conv).
# Each remaining column is one time step for a 2-layer BiLSTM. A Dense layer emits
# per-step logits over num_classes (the last id is the CTC blank).
#
# Fixes vs. the earlier definition (which looked ported from PyTorch):
#   * BatchNormalization(axis=1) on NHWC tensors normalized over image *height*.
#     Channels are the default axis=-1.
#   * momentum=0.1 is PyTorch's convention (weight of the new batch). Keras'
#     momentum weights the running average, so the equivalent is 0.9.
#   * The two 256-filter convs had no activation (a purely linear stack). They now
#     use ReLU like the other convs.
# Models restored with load_model keep their own saved architecture; only freshly
# built models are affected.
model = tf.keras.Sequential([
        # Height is fixed at 32 by preprocess_image; width may vary.
        tf.keras.layers.InputLayer(input_shape=(32, None, 3)),
        tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),
        tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='valid'),
        tf.keras.layers.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'),
        tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='valid'),
        tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding='same', activation='relu'),
        tf.keras.layers.BatchNormalization(epsilon=1e-05, momentum=0.9),
        tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding='same', activation='relu'),
        tf.keras.layers.ZeroPadding2D(padding=(0, 1)),
        # Halve the height but keep (almost) the full width: more time steps.
        tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 1), padding='valid'),
        tf.keras.layers.Conv2D(filters=512, kernel_size=3, padding='same', activation='relu'),
        tf.keras.layers.BatchNormalization(epsilon=1e-05, momentum=0.9),
        tf.keras.layers.Conv2D(filters=512, kernel_size=3, padding='same', activation='relu'),
        tf.keras.layers.ZeroPadding2D(padding=(0, 1)),
        tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 1), padding='valid'),
        tf.keras.layers.Conv2D(filters=512, kernel_size=2, padding='valid', activation='relu'),
        tf.keras.layers.BatchNormalization(epsilon=1e-05, momentum=0.9),
        # Height is now 1: fold it away so each width column is one time step.
        tf.keras.layers.Reshape((-1, 512)),
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=256, return_sequences=True, use_bias=True, recurrent_activation='sigmoid')),
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=256, return_sequences=True, use_bias=True, recurrent_activation='sigmoid')),
        # Raw logits; CTC applies its own softmax.
        tf.keras.layers.Dense(units=num_classes)
])

In [41]:
# Uncomment to print layer output shapes. NOTE(review): Keras can only summarize a
# built model, i.e. one that declares an input shape or has already been called
# on data; otherwise this raises.
#model.summary()

CTC损失函数

In [59]:
class CTCLoss(tf.keras.losses.Loss):
    """Connectionist Temporal Classification loss as a Keras loss.

    y_true: SparseTensor of class ids per sample (e.g. produced by decode_label).
    y_pred: unnormalized logits, [batch, time, classes], or [time, batch, classes]
            when logits_time_major=True.
    Every sample uses the full time dimension of y_pred as its logit length.
    """
    def __init__(self, logits_time_major=False, blank_index=-1,
                 reduction=tf.keras.losses.Reduction.AUTO, name='ctc_loss'):
        super().__init__(reduction=reduction, name=name)
        self.logits_time_major = logits_time_major
        # Negative values count from the end: -1 makes num_classes - 1 the blank.
        self.blank_index = blank_index

    def call(self, y_true, y_pred):
        """Return the per-sample CTC loss; the base class applies `reduction`."""
        y_true = tf.cast(y_true, tf.int32)
        pred_shape = tf.shape(y_pred)
        # Batch/time axes depend on the layout; the old code always assumed
        # batch-major, which was wrong when logits_time_major=True.
        if self.logits_time_major:
            max_time, batch_size = pred_shape[0], pred_shape[1]
        else:
            batch_size, max_time = pred_shape[0], pred_shape[1]
        logit_length = tf.fill([batch_size], max_time)
        loss = tf.nn.ctc_loss(
            labels=y_true,
            logits=y_pred,
            # None is only allowed because the labels are a SparseTensor.
            label_length=None,
            logit_length=logit_length,
            logits_time_major=self.logits_time_major,
            blank_index=self.blank_index
            )
        # Return per-sample losses. Reducing to a scalar here made `reduction`
        # meaningless and prevented sample_weight from weighting individual
        # samples. With the default reduction the value is still the batch mean.
        return loss

In [60]:
class WordAccuracy(tf.keras.metrics.Metric):
    """
    Sequence-level ("word") accuracy for CTC outputs.

    y_pred (logits, [batch, time, classes]) is greedy-decoded with
    tf.nn.ctc_greedy_decoder, which treats the last class as blank, consistent
    with CTCLoss's default blank_index=-1. A sample counts as correct only if the
    decoded id sequence equals its label exactly. y_true is a SparseTensor of
    class ids. sample_weight is accepted but ignored.
    """
    def __init__(self, name='word_accuracy', **kwargs):
        super().__init__(name=name, **kwargs)
        # Running totals across batches: samples seen / samples fully correct.
        self.total = self.add_weight(name='total', dtype=tf.int32,
                                     initializer=tf.zeros_initializer())
        self.count = self.add_weight(name='count', dtype=tf.int32,
                                     initializer=tf.zeros_initializer())

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate exact-match counts for one batch."""
        b = tf.shape(y_true)[0]
        # Common width so both sparse tensors can be densified and compared.
        max_width = tf.maximum(tf.shape(y_true)[1], tf.shape(y_pred)[1])
        logit_length = tf.fill([tf.shape(y_pred)[0]], tf.shape(y_pred)[1])
        # The decoder expects time-major input.
        decoded, _ = tf.nn.ctc_greedy_decoder(
            inputs=tf.transpose(y_pred, perm=[1, 0, 2]),
            sequence_length=logit_length)
        y_true = tf.sparse.reset_shape(y_true, [b, max_width])
        y_pred = tf.sparse.reset_shape(decoded[0], [b, max_width])
        # Pad with -1 (never a class id) so differing lengths compare unequal.
        y_true = tf.sparse.to_dense(y_true, default_value=-1)
        y_pred = tf.sparse.to_dense(y_pred, default_value=-1)
        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.cast(y_pred, tf.int32)
        # Number of samples with at least one mismatching position.
        values = tf.math.reduce_any(tf.math.not_equal(y_true, y_pred), axis=1)
        values = tf.cast(values, tf.int32)
        values = tf.reduce_sum(values)
        self.total.assign_add(b)
        self.count.assign_add(b - values)

    def result(self):
        """Fraction of exactly matched samples; 0.0 (not NaN) before any update."""
        return tf.math.divide_no_nan(tf.cast(self.count, tf.float32),
                                     tf.cast(self.total, tf.float32))

    def reset_state(self):
        """Zero both counters (name used by newer TensorFlow versions)."""
        self.count.assign(0)
        self.total.assign(0)

    def reset_states(self):
        """Zero both counters (name used by TensorFlow 2.2)."""
        self.reset_state()

加载已保存模型

In [70]:
# Resume from a saved full model (architecture + weights). The epoch in the file
# name should agree with initial_epoch passed to model.fit. compile=False: the
# model is compiled afterwards with the custom CTCLoss / WordAccuracy objects.
model = tf.keras.models.load_model(
    filepath=work_path + 'output/crnn_4.h5',
    compile=False,
)

模型编译

In [71]:
# NOTE(review): a fresh Adam optimizer is created, so optimizer state from the
# resumed checkpoint is not carried over.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=CTCLoss(),
    metrics=[WordAccuracy()],
)

显示使用GPU数量

In [72]:
# Sanity check that TensorFlow sees a GPU before starting a long training run.
print("Num GPUs Available: ",
      len(tf.config.experimental.list_physical_devices(device_type='GPU')))

Num GPUs Available:  1


In [73]:
callbacks = [
    # Save the full model after every epoch as output/crnn_<epoch>.h5.
    # NOTE: `monitor` is only consulted when save_best_only=True.
    tf.keras.callbacks.ModelCheckpoint(work_path + 'output/crnn_{epoch}.h5', monitor='val_loss', verbose=1),
    # One TensorBoard run directory per launch. time.asctime() produced names with
    # spaces and colons (e.g. 'Sat Jun  6 12:00:00 2020'), which are awkward in
    # shells and invalid on Windows. Use a sortable timestamp instead.
    tf.keras.callbacks.TensorBoard(log_dir=work_path + "logs/{}".format(time.strftime('%Y%m%d-%H%M%S'))),
]

加载模型继续训练

In [75]:
# Resume training. initial_epoch is 0-based: 4 means the first epoch run is
# reported as "Epoch 5/20". Keep it in sync with the checkpoint that was loaded.
# `workers` is not passed: Keras only uses it for generator / keras.utils.Sequence
# input and ignores it for a tf.data.Dataset. Parallelism comes from the pipeline's
# num_parallel_calls / prefetch.
history = model.fit(
    train_ds,
    epochs=20,
    steps_per_epoch=train_steps_per_epoch,
    validation_data=val_ds,
    validation_steps=val_steps_per_epoch,
    initial_epoch=4,
    callbacks=callbacks,
)

Train for 29798 steps, validate for 376 steps
Epoch 5/20
Epoch 00005: saving model to /root/python_project/jupyter_file/output/crnn_5.h5
Epoch 6/20
Epoch 00006: saving model to /root/python_project/jupyter_file/output/crnn_6.h5

KeyboardInterrupt: 