In [None]:
# Install the third-party packages that the import cell below relies on
# (keras_ocr, tensorflow_addons, straug). `-qqq` keeps pip's output quiet.
%pip install -qqq keras_ocr
%pip install -qqq tensorflow_addons
%pip install -qqq straug
# System package installed through apt (ImageMagick "MagickWand" development files).
# NOTE(review): presumably required by one of the augmentation dependencies — confirm.
!apt install -qq libmagickwand-dev

In [None]:
import os
import string
import random
import math

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import sklearn.model_selection
import tensorflow as tf
import tensorflow_addons as tfa
import keras_ocr

from straug.blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur, ZoomBlur
from straug.camera import Contrast, Brightness, JpegCompression, Pixelate
from straug.noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise
from straug.process import Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color
from straug.weather import Shadow

In [None]:
# Single seed value reused wherever this notebook needs a random state.
SEED = 42

# Seed, in order: Python's `random`, NumPy's global generator, TensorFlow.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)

In [None]:
# Characters that may appear in generated text: ASCII letters (both cases)
# followed by a fixed set of punctuation marks.
alphabet = ''.join([string.ascii_letters, '!"\'(),-.:;?[]'])

# The recognizer works on lowercase text only: lowercase every character,
# drop duplicates and keep the result in sorted order.
recognizer_alphabet = ''.join(sorted({char.lower() for char in alphabet}))

In [None]:
# Obtain the fonts and background images used to render synthetic text images.
# The fonts are requested for `alphabet`.
# NOTE(review): presumably this limits the result to fonts able to render every
# character in `alphabet`, and both calls may download data on first use — confirm
# against the keras_ocr documentation.
fonts = keras_ocr.data_generation.get_fonts(alphabet=alphabet)
backgrounds = keras_ocr.data_generation.get_backgrounds()

In [None]:
def get_train_test_split(arr):
    """Split ``arr`` into a train part and a test part.

    80% of the items are assigned to the train part. ``SEED`` is passed as
    ``random_state`` so the split is identical on every run.

    Returns:
        A two-element list: the train part followed by the test part.
    """
    train_part, test_part = sklearn.model_selection.train_test_split(
        arr,
        train_size=0.8,
        random_state=SEED,
    )
    return [train_part, test_part]

# Split backgrounds and fonts 80/20 so that the second (test) generator is
# built from assets the first (train) generator never uses.
background_splits = get_train_test_split(backgrounds)
font_splits = get_train_test_split(fonts)

text_generator = keras_ocr.data_generation.get_text_generator(alphabet=alphabet)

# Rendering settings shared by the train and test image generators.
shared_generator_options = dict(
    height=320,
    width=640,
    text_generator=text_generator,
    font_size=(60, 120),
    margin=50,
    rotationX=(-0.05, 0.05),
    rotationY=(-0.05, 0.05),
    rotationZ=(-15, 15),
)

# One image generator per split, in (train, test) order.
image_generators = [
    keras_ocr.data_generation.get_image_generator(
        font_groups={alphabet: split_fonts},
        backgrounds=split_backgrounds,
        **shared_generator_options,
    )
    for split_fonts, split_backgrounds in zip(font_splits, background_splits)
]

In [None]:
# Number of samples per batch produced by the batch generators.
BATCH_SIZE = 224

# Recognizer limited to the lowercase alphabet, initialised from the
# 'kurapan' weights.
recognizer = keras_ocr.recognition.Recognizer(alphabet=recognizer_alphabet, weights='kurapan')

# Freeze every backbone layer so training leaves its weights untouched.
for layer in recognizer.backbone.layers:
    layer.trainable = False

# Compile only once the `trainable` flags are final: Keras documents that the
# flags in effect at compile time are the ones a compiled model honours, so
# changing them after compile() requires compiling again.
recognizer.compile()

In [None]:
def augment_image(image_generator, M=0, random_state=SEED):
    """Yield ``(image, text)`` pairs with a random chain of augmentations applied.

    For every incoming sample, between 0 and ``len(op_sets)`` operation groups
    are picked without repetition, and one operation from each picked group is
    applied to the image in sequence.

    Every random decision made here (how many groups, which groups, which
    operation) is drawn from a single generator seeded with ``random_state``;
    the same generator is handed to each operation on construction. The result
    therefore depends only on ``random_state`` and not on the state of the
    global ``random`` module.

    Args:
        image_generator: Iterable of ``(image, text)`` pairs. ``image`` must be
            something ``PIL.Image.fromarray`` accepts.
        M: Value forwarded to every operation as its ``mag`` argument.
        random_state: Seed, or any other value accepted by
            ``np.random.default_rng``.

    Yields:
        ``(augmented_image, text)``, where ``augmented_image`` is the result of
        ``np.asarray`` on the augmented RGB PIL image and ``text`` is passed
        through unchanged.
    """
    rng = np.random.default_rng(random_state)

    # Operations grouped by family; at most one operation per family is
    # applied to any given image.
    op_sets = [
        [GaussianNoise(rng), ShotNoise(rng), ImpulseNoise(rng), SpeckleNoise(rng)],
        [GaussianBlur(rng), DefocusBlur(rng), MotionBlur(rng), GlassBlur(rng), ZoomBlur(rng)],
        [Contrast(rng), Brightness(rng), JpegCompression(rng), Pixelate(rng)],
        [Shadow(rng)],
        [Posterize(rng), Solarize(rng), Invert(rng), Equalize(rng), AutoContrast(rng), Sharpness(rng), Color(rng)],
    ]

    for image, text in image_generator:
        image = Image.fromarray(image).convert('RGB')

        # Number of families to use for this sample, from 0 (no augmentation)
        # up to and including all of them.
        n_op_set = int(rng.integers(0, len(op_sets), endpoint=True))

        # A random ordering of the family indices, truncated to the chosen
        # count: sampling without replacement, in random order.
        sampled_indices = rng.permutation(len(op_sets))[:n_op_set]

        for set_index in sampled_indices:
            op_set = op_sets[set_index]
            sampled_op = op_set[int(rng.integers(len(op_set)))]
            image = sampled_op(image, mag=M)

        yield np.asarray(image), text

In [None]:
# Conversion settings taken from the recognizer's own input shapes.
recognizer_input_options = dict(
    max_string_length=recognizer.training_model.input_shape[1][1],
    target_width=recognizer.model.input_shape[2],
    target_height=recognizer.model.input_shape[1],
    margin=1,
)

# Turn each full-image generator into a stream of recognizer inputs,
# keeping the (train, test) order of `image_generators`.
recognition_image_generators = [
    keras_ocr.data_generation.convert_image_generator_to_recognizer_input(
        image_generator=source_generator,
        **recognizer_input_options,
    )
    for source_generator in image_generators
]

# Only the first (training) stream is wrapped with augmentation; the second
# stream is left as generated.
recognition_image_generators[0] = augment_image(recognition_image_generators[0], M=0)

# Batch both streams; labels are lowercased to match the recognizer alphabet.
recognition_train_generator, recognition_test_generator = (
    recognizer.get_batch_generator(
        image_generator=sample_stream,
        batch_size=BATCH_SIZE,
        lowercase=True,
    )
    for sample_stream in recognition_image_generators
)

In [None]:
# Directory that receives the training log and the model checkpoint.
# NOTE(review): the path looks like a mounted Google Drive shared drive — make
# sure it is mounted before running this cell.
save_path = '/content/drive/Shareddrives/วิสัยทัศน์สมองกล/keras_ocr/checkpoints/version_2'
os.makedirs(save_path, exist_ok=True)
recognizer_basepath = os.path.join(save_path, 'recognizer')

# Train from the Python batch generators.
# - `batch_size` is intentionally not passed: the generators already yield
#   batches of BATCH_SIZE samples, and Keras documents that `batch_size` must
#   not be specified for generator input.
# - `workers` is left at its default; the previous value (-1) is not a
#   documented setting for `fit`.
# - The length of a training / validation epoch is defined as the number of
#   backgrounds in the corresponding split divided by BATCH_SIZE, rounded up.
recognizer.training_model.fit(
    recognition_train_generator,
    epochs=1000,
    steps_per_epoch=math.ceil(len(background_splits[0]) / BATCH_SIZE),
    callbacks=[
        # Lower the learning rate after 10 epochs without improvement.
        tf.keras.callbacks.ReduceLROnPlateau(patience=10),
        # Stop after 50 epochs without improvement and restore the best weights
        # into the in-memory model.
        tf.keras.callbacks.EarlyStopping(restore_best_weights=True, patience=50),
        # Append per-epoch metrics so the log survives restarts of training.
        tf.keras.callbacks.CSVLogger(f'{recognizer_basepath}.csv', append=True),
        # NOTE(review): with the default save_best_only=False this file is
        # overwritten every epoch, so on disk it holds the most recent epoch,
        # not the best weights restored by EarlyStopping.
        tf.keras.callbacks.ModelCheckpoint(filepath=f'{recognizer_basepath}.h5')
    ],
    validation_data=recognition_test_generator,
    validation_steps=math.ceil(len(background_splits[1]) / BATCH_SIZE),
)