# Chinese OCR simple testing

import packages

In [125]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
from alfred.utils.log import logger as logging



## load datasets
load data from directory "dataset". The format is stored as tfrecord.

In [126]:
def parse_example_v2(record):
    """
    latest version format
    :param record:
    :return:
    """
    features = tf.io.parse_single_example(record,
                                          features={
                                              'width':
                                                  tf.io.FixedLenFeature([], tf.int64),
                                              'height':
                                                  tf.io.FixedLenFeature([], tf.int64),
                                              'label':
                                                  tf.io.FixedLenFeature([], tf.int64),
                                              'image':
                                                  tf.io.FixedLenFeature([], tf.string),
                                          })
    img = tf.io.decode_raw(features['image'], out_type=tf.uint8)
    # we can not reshape since it stores with original size
    w = features['width']
    h = features['height']
    img = tf.cast(tf.reshape(img, (w, h)), dtype=tf.float32)
    label = tf.cast(features['label'], tf.int64)
    return {'image': img, 'label': label}



def load_ds(filedir):
    input_files = [filedir]
    ds = tf.data.TFRecordDataset(input_files)
    ds = ds.map(parse_example_v2)
    return ds

In [127]:
train = load_ds('dataset/train.tfrecord') # read train.tfrecord
test = load_ds('dataset/test.tfrecord')

In [128]:
train
test
train_mapped = train.shuffle(100).batch(32).repeat()
train_mapped


<RepeatDataset element_spec={'image': TensorSpec(shape=(None, None, None), dtype=tf.float32, name=None), 'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}>

## Training

In [129]:
# some arguments
target_size = 64
num_classes = 7356
# use_keras_fit = False
use_keras_fit = True
ckpt_path = './checkpoints/cn_ocr-{epoch}.ckpt'
this_dir = os.path.dirname(os.path.abspath("__file__"))
train_path = 'dataset/train.tfrecord'
test_path = 'dataset/test.tfrecord'

'f:\\OCR'

In [130]:
def load_characters():
    a = open(os.path.join(this_dir, 'dataset\\characters.txt'), 'r').readlines()
    return [i.strip() for i in a]

def preprocess(x):
    """
    minus mean pixel or normalize?
    """
    # original is 64x64, add a channel dim
    x['image'] = tf.expand_dims(x['image'], axis=-1)
    x['image'] = tf.image.resize(x['image'], (target_size, target_size))
    x['image'] = (x['image'] - 128.) / 128.
    return x['image'], x['label']



In [131]:
# model
# some simple models
def build_net_001(input_shape, n_classes):
    assert len(input_shape) == 3, 'only support 3 channels'
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Conv2D(
        input_shape=input_shape, filters=32, kernel_size=(3, 3), strides=(1, 1),
        padding='valid', activation='relu'))
    model.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2)))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(32, activation='relu'))
    model.add(tf.keras.layers.Dense(n_classes, activation='softmax'))
    return model


def build_net_002(input_shape, n_classes):
    model = tf.keras.Sequential([
        layers.Conv2D(input_shape=input_shape, filters=64, kernel_size=(3, 3), strides=(1, 1),
                      padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=256, kernel_size=(3, 3), padding='same'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),

        layers.Flatten(),
        layers.Dense(1024, activation='relu'),
        layers.Dense(n_classes, activation='softmax')
    ])
    return model


# this model is converge in terms of chinese characters classification
# so simply is effective sometimes, adding a dense maybe model will be better?
def build_net_003(input_shape, n_classes):
    model = tf.keras.Sequential([
        layers.Conv2D(input_shape=input_shape, filters=32, kernel_size=(3, 3), strides=(1, 1),
                      padding='same', activation='relu'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),
        layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same'),
        layers.MaxPool2D(pool_size=(2, 2), padding='same'),

        layers.Flatten(),
        # layers.Dense(1024, activation='relu'),
        layers.Dense(n_classes, activation='softmax')
    ])
    return model

In [134]:

def train():
    all_characters = load_characters()
    num_classes = len(all_characters)
    logging.info('all characters: {}'.format(num_classes))
    train_dataset = load_ds(train_path)
    train_dataset = train_dataset.shuffle(100).map(preprocess).batch(32).repeat()

    val_ds = load_ds(test_path)
    val_ds = val_ds.shuffle(100).map(preprocess).batch(32).repeat()
    
    for data in train_dataset.take(2):
        print(data)

    # init model
    model = build_net_003((64, 64, 1), num_classes)
    model.summary()
    logging.info('model loaded.')

    start_epoch = 0
    latest_ckpt = tf.train.latest_checkpoint(os.path.dirname(ckpt_path))
    if latest_ckpt:
        start_epoch = int(latest_ckpt.split('-')[1].split('.')[0])
        model.load_weights(latest_ckpt)
        logging.info('model resumed from: {}, start at epoch: {}'.format(latest_ckpt, start_epoch))
    else:
        logging.info('passing resume since weights not there. training from scratch')

    if use_keras_fit:
        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            metrics=['accuracy'])
        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(ckpt_path,
                                               save_weights_only=True,
                                               verbose=1,
                                               period=500)
        ]
        try:
            model.fit(
                train_dataset,
                validation_data=val_ds,
                validation_steps=1000,
                epochs=15000,
                steps_per_epoch=1024,
                callbacks=callbacks)
        except KeyboardInterrupt:
            model.save_weights(ckpt_path.format(epoch=0))
            logging.info('keras model saved.')
        model.save_weights(ckpt_path.format(epoch=0))
        model.save(os.path.join(os.path.dirname(ckpt_path), 'cn_ocr.h5'))
    else:
        loss_fn = tf.losses.SparseCategoricalCrossentropy()
        optimizer = tf.optimizers.RMSprop()

        train_loss = tf.metrics.Mean(name='train_loss')
        train_accuracy = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')

        for epoch in range(start_epoch, 120):
            try:
                for batch, data in enumerate(train_dataset):
                    # images, labels = data['image'], data['label']
                    images, labels = data
                    with tf.GradientTape() as tape:
                        predictions = model(images)
                        loss = loss_fn(labels, predictions)
                    gradients = tape.gradient(loss, model.trainable_variables)
                    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                    train_loss(loss)
                    train_accuracy(labels, predictions)
                    if batch % 10 == 0:
                        logging.info('Epoch: {}, iter: {}, loss: {}, train_acc: {}'.format(
                            epoch, batch, train_loss.result(), train_accuracy.result()))
            except KeyboardInterrupt:
                logging.info('interrupted.')
                model.save_weights(ckpt_path.format(epoch=epoch))
                logging.info('model saved into: {}'.format(ckpt_path.format(epoch=epoch)))
                exit(0)

In [135]:
train()

11:19:33 05.08 [1mINFO[0m 1599783940.py:4]: all characters: 3755


(<tf.Tensor: shape=(32, 64, 64, 1), dtype=float32, numpy=
array([[[[ 0.9921875 ],
         [ 0.9921875 ],
         [ 0.9921875 ],
         ...,
         [ 0.9921875 ],
         [ 0.9921875 ],
         [ 0.9921875 ]],

        [[ 0.9921875 ],
         [ 0.9921875 ],
         [ 0.9921875 ],
         ...,
         [ 0.9921875 ],
         [ 0.9921875 ],
         [ 0.9921875 ]],

        [[ 0.9921875 ],
         [ 0.9921875 ],
         [ 0.9921875 ],
         ...,
         [ 0.9921875 ],
         [ 0.9921875 ],
         [ 0.9921875 ]],

        ...,

        [[ 0.9921875 ],
         [ 0.9921875 ],
         [ 0.9921875 ],
         ...,
         [ 0.22584534],
         [ 0.5256119 ],
         [ 0.8377762 ]],

        [[ 0.9921875 ],
         [ 0.9921875 ],
         [ 0.9921875 ],
         ...,
         [-0.03772736],
         [ 0.23122406],
         [ 0.42704773]],

        [[ 0.9921875 ],
         [ 0.9921875 ],
         [ 0.9921875 ],
         ...,
         [ 0.44918823],
         [ 0.50251

11:19:33 05.08 [1mINFO[0m 1599783940.py:17]: model loaded.
11:19:33 05.08 [1mINFO[0m 1599783940.py:26]: passing resume since weights not there. training from scratch


Epoch 1/15000
Epoch 2/15000
Epoch 3/15000
Epoch 4/15000
Epoch 5/15000
Epoch 6/15000
Epoch 7/15000
Epoch 8/15000
Epoch 9/15000
Epoch 10/15000
Epoch 11/15000
Epoch 12/15000
Epoch 13/15000
Epoch 14/15000
Epoch 15/15000
Epoch 16/15000
Epoch 17/15000
Epoch 18/15000
Epoch 19/15000
Epoch 20/15000
Epoch 21/15000
Epoch 22/15000
Epoch 23/15000
Epoch 24/15000
Epoch 25/15000
Epoch 26/15000
Epoch 27/15000
Epoch 28/15000
Epoch 29/15000
Epoch 30/15000
Epoch 31/15000
Epoch 32/15000
Epoch 33/15000
Epoch 34/15000
Epoch 35/15000
Epoch 36/15000
Epoch 37/15000
Epoch 38/15000
Epoch 39/15000
Epoch 40/15000
Epoch 41/15000
Epoch 42/15000
Epoch 43/15000
Epoch 44/15000
Epoch 45/15000
Epoch 46/15000
Epoch 47/15000
Epoch 48/15000
Epoch 49/15000
Epoch 50/15000
Epoch 51/15000
Epoch 52/15000
Epoch 53/15000
Epoch 54/15000
Epoch 55/15000
Epoch 56/15000
Epoch 57/15000
Epoch 58/15000
Epoch 59/15000
Epoch 60/15000
Epoch 61/15000
Epoch 62/15000
Epoch 63/15000
Epoch 64/15000
Epoch 65/15000
Epoch 66/15000
Epoch 67/15000
Epoc

14:07:26 05.08 [1mINFO[0m 1599783940.py:49]: keras model saved.
