In [1]:
import tensorflow as tf
from models.crnn import model
from models.ctc_loss import CTCLoss
from models.accuracy import WordAccuracy
from models.config import BATCH_SIZE, BUFFER_SIZE, WORK_PATH
from models.data_prepare import load_and_preprocess_image, decode_label, get_image_path
import numpy as np
import json
import time
import os

In [2]:
# Report how many physical GPU devices TensorFlow detects on this machine.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
print("当前可用GPU数量： ", len(gpu_devices))

当前可用GPU数量：  0


# 一、数据集准备

## 1、获取并划分训练集、验证集 

In [5]:
# get_image_path returns four sequences for the given dataset directory:
# training paths, training labels, validation paths, validation labels.
(train_all_image_paths,
 train_all_image_labels,
 val_all_image_paths,
 val_all_image_labels) = get_image_path(WORK_PATH + 'dataset/train/')

# Sanity check: paths and labels of each split should have matching lengths.
split_sizes = [
    len(split)
    for split in (train_all_image_paths, train_all_image_labels,
                  val_all_image_paths, val_all_image_labels)
]
print(*split_sizes)

数据集加载完毕！
980 980 12 12


## 2、训练集数据预处理

In [None]:

# Number of full batches that make up one pass over the training images.
train_images_num = len(train_all_image_paths)
train_steps_per_epoch = train_images_num // BATCH_SIZE

# Input pipeline for training, expressed as a single chain:
#   (path, label) -> load/preprocess image -> shuffle -> repeat forever
#   -> batch -> decode labels per batch -> skip failing elements -> prefetch.
# The dataset repeats indefinitely, so fit() must be given steps_per_epoch.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_all_image_paths, train_all_image_labels))
    .map(load_and_preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(buffer_size=BUFFER_SIZE)
    .repeat()
    .batch(BATCH_SIZE)
    .map(decode_label, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .apply(tf.data.experimental.ignore_errors())
    .prefetch(tf.data.experimental.AUTOTUNE)
)

## 3、验证集数据预处理

In [None]:
val_images_num = len(val_all_image_paths)
# Plain floor division gives 0 steps when the validation split is smaller than
# one batch (the split loaded above holds only 12 images). Passing
# validation_steps=0 to fit() would leave Keras with no validation batches to
# run, so always request at least one step. This is safe because the dataset
# below calls repeat() before batch(), so a full batch can always be assembled.
val_steps_per_epoch = max(1, val_images_num // BATCH_SIZE)

# Input pipeline for validation, mirroring the training pipeline:
#   (path, label) -> load/preprocess image -> shuffle -> repeat forever
#   -> batch -> decode labels per batch -> skip failing elements -> prefetch.
# The dataset repeats indefinitely, so fit() must be given validation_steps.
val_ds = (
    tf.data.Dataset.from_tensor_slices((val_all_image_paths, val_all_image_labels))
    .map(load_and_preprocess_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(buffer_size=BUFFER_SIZE)
    .repeat()
    .batch(BATCH_SIZE)
    .map(decode_label, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .apply(tf.data.experimental.ignore_errors())
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# 二、模型训练

## 1、模型结构

加载已保存的模型（`output/crnn_30.h5`），并在其基础上继续训练。

In [None]:
# Restore the previously saved network. compile=False skips restoring the
# saved optimizer/loss, so the model has to be compiled again before training.
# Note: this rebinds `model`, replacing the object imported from models.crnn.
pretrained_path = WORK_PATH + 'output/crnn_30.h5'
model = tf.keras.models.load_model(pretrained_path, compile=False)

In [None]:
# Print the layer-by-layer summary of the loaded model as a quick check of
# its structure.
model.summary()

## 2、模型编译

In [None]:
# Attach the training configuration: Adam with a 1e-3 learning rate, the
# project's CTC loss, and whole-word accuracy as the reported metric.
adam = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=adam, loss=CTCLoss(), metrics=[WordAccuracy()])

## 3、配置回调函数

In [None]:
# Write a checkpoint at the end of every epoch; `{epoch}` in the file name is
# filled in by Keras. save_best_only is left at its default, so `monitor`
# does not filter which epochs get saved.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=WORK_PATH + 'output/crnn_{epoch}.h5',
    monitor='val_loss',
    verbose=1,
)
callbacks = [checkpoint_cb]

## 4、模型训练

In [None]:
# Training resumes from the 'crnn_30.h5' checkpoint loaded above, which is
# presumably the model state after epoch 30 (the checkpoint callback names
# files 'crnn_{epoch}.h5'). Starting the epoch counter at 0 would make the
# callback write crnn_1.h5 ... crnn_20.h5 again, overwriting the earlier
# checkpoints, and would mislabel the epochs in the training log.
RESUME_EPOCH = 30   # epoch index encoded in the loaded checkpoint's file name
EXTRA_EPOCHS = 20   # number of additional epochs to train in this run

# In Keras, `epochs` is the index of the final epoch, not a count of epochs to
# run, so training covers epochs RESUME_EPOCH+1 .. RESUME_EPOCH+EXTRA_EPOCHS.
# Both datasets repeat forever, so the per-epoch step counts are required.
model.fit(train_ds,
          epochs=RESUME_EPOCH + EXTRA_EPOCHS,
          steps_per_epoch=train_steps_per_epoch,
          validation_data=val_ds,
          validation_steps=val_steps_per_epoch,
          initial_epoch=RESUME_EPOCH,
          callbacks=callbacks)