In [0]:
# code adapted from https://github.com/keras-team/keras/blob/master/examples/image_ocr.py
# 06/02/2020

# Mount Google Drive into the Colab filesystem; the font folders copied in the
# next cell are read from '/content/drive/My Drive'.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
# open source fonts from Google Fonts https://fonts.google.com

!cp -r '/content/drive/My Drive/Fonts/Caveat' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/CedarvilleCursive' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/Courgette' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/Damion' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/DancingScript' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/GloriaHallelujah' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/GochiHand' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/Kalam' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/KaushanScript' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/LaBelleAurore' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/Merienda' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/Montez' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/Satisfy' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/Tangerine' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/WorkSans' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/UbuntuMono' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/RobotoSlab' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/RobotoMono' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/PTSerif' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/OxygenMono' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/Merriweather' '/usr/share/fonts/truetype/'
!cp -r '/content/drive/My Drive/Fonts/LibreBaskerville' '/usr/share/fonts/truetype/'

In [0]:
!pip install -q cairocffi editdistance
!apt install -q libcairo2-dev
!apt install -q graphviz
!pip install -q pydot
!pip install -q matplotlib graphviz pydot

In [0]:
%tensorflow_version 1.x
import os
import itertools
import codecs
import re
import datetime
import cairocffi as cairo
import editdistance
import numpy as np
import pylab
from scipy import ndimage
from secrets import choice
import matplotlib.pyplot as plt
from keras import backend as K
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers import Input, Dense, Activation
from keras.layers import Reshape, Lambda
from keras.layers.merge import add, concatenate
from keras.models import Model
from keras.layers.recurrent import GRU
from keras.layers.recurrent import LSTM
from keras.layers.recurrent import SimpleRNN
from keras.optimizers import SGD
from keras.optimizers import RMSprop
from keras.optimizers import Adam
from keras.optimizers import Adadelta
from keras.optimizers import Nadam
from keras.utils.data_utils import get_file
from keras.preprocessing import image
import keras.callbacks

In [0]:
# Root directory under which each run stores its per-epoch weights and
# visualisation images (VizCallback joins it with the run name).
OUTPUT_DIR = 'lstm_ocr'

# character classes and matching regex filter
# A character's position in `alphabet` is its class id; the id
# len(alphabet) is treated as the CTC blank by labels_to_text.
regex = r'^[a-z ]+$'
alphabet = u'abcdefghijklmnopqrstuvwxyz '

# Seed NumPy's global RNG, used below for noise, text placement and shuffling.
np.random.seed(55)

# this creates larger "blotches" of noise which look
# more realistic than just adding gaussian noise
# assumes greyscale with pixels ranging from 0 to 1

def speckle(img):
    """Return ``img`` plus blurred Gaussian noise, limited to the range [0, 1].

    A noise field with the same shape as ``img`` is drawn with a random
    strength (uniform in [0, 0.6)), smoothed by a sigma=1 Gaussian filter so
    that it forms blotches instead of per-pixel grain, and added to the
    image.  The input array itself is left unchanged.
    """
    # Draw order matters for reproducibility: strength first, then the field.
    noise_strength = np.random.uniform(0, 0.6)
    raw_noise = np.random.randn(*img.shape) * noise_strength
    blotches = ndimage.gaussian_filter(raw_noise, 1)
    return np.clip(img + blotches, 0, 1)


# paints the string in a random location in the bounding box
# also uses a random font, a slight random rotation,
# and a random amount of speckle noise

def paint_text(text, w, h, multi_fonts, rotate=False, ud=False):
    """Render ``text`` in black on a white ``w`` x ``h`` canvas and return a noisy greyscale array.

    Args:
        text: String to draw.
        w: Canvas width in pixels.
        h: Canvas height in pixels.
        multi_fonts: Font mode.  ``'full'``, ``'narrow'``, ``'block'`` and
            ``'fancy'`` pick a random family from the matching list below;
            the string ``'False'`` always uses Courier; for any other value
            no font face is selected.
        rotate: If true, apply ``image.random_rotation``; the rotation range
            is larger the further left the text starts.
        ud: If true, the vertical position is random; otherwise the text is
            placed at ``h // 2``.

    Returns:
        Array of shape ``(1, h, w)`` produced by ``speckle``.

    Raises:
        IOError: If the rendered text does not fit inside the canvas borders.
    """
    block_fonts = ['Courier', 'WorkSans', 'Ubuntumono', 'RobotoSlab', 'RobotoMono', 'PTSerif',
                   'OxygenMono', 'Merriweather', 'LibreBaskerville']
    # NOTE(review): these names are passed to cairo as font family names, while
    # the installed folders are named e.g. 'UbuntuMono' -- confirm each name
    # actually resolves to the intended font.
    font_sets = {
        'full': ['Courier', 'Tangerine', 'Satisfy', 'Montez', 'LaBelleAurore', 'DancingScript', 'Damion',
                 'CedarvilleCursive', 'Merienda', 'KaushanScript', 'Kalam', 'GochiHand',
                 'GloriaHallelujah', 'Courgette', 'Caveat', 'WorkSans', 'Ubuntumono',
                 'RobotoSlab', 'RobotoMono', 'PTSerif', 'OxygenMono', 'Merriweather',
                 'LibreBaskerville'],
        'narrow': ['Courier', 'Tangerine', 'Merienda', 'Courgette', 'Caveat', 'Kalam', 'GochiHand',
                   'GloriaHallelujah'],
        'block': block_fonts,
        'fancy': ['Satisfy', 'Montez', 'LaBelleAurore', 'DancingScript', 'Damion',
                  'CedarvilleCursive', 'KaushanScript'],
    }
    border_w, border_h = 4, 4

    surface = cairo.ImageSurface(cairo.FORMAT_RGB24, w, h)
    with cairo.Context(surface) as context:
        context.set_source_rgb(1, 1, 1)  # white background
        context.paint()

        candidates = font_sets.get(multi_fonts) if isinstance(multi_fonts, str) else None
        if candidates is not None:
            # Draw from NumPy's RNG (not `secrets`) so the font sequence is
            # governed by np.random.seed like every other random choice here.
            family = candidates[np.random.randint(len(candidates))]
            context.select_font_face(family, cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
        elif multi_fonts == 'False':
            context.select_font_face('Courier', cairo.FONT_SLANT_NORMAL, cairo.FONT_WEIGHT_NORMAL)
        context.set_font_size(20)

        # First four extents: x offset, y offset, width, height of the inked text.
        x_offset, y_offset, text_w, text_h = context.text_extents(text)[:4]
        if text_w > (w - 2 * border_w) or text_h > (h - 2 * border_h):
            raise IOError('Could not fit string into image. Max char count is too large for given image width.')

        # teach the RNN translational invariance by
        # fitting text box randomly on canvas, with some room to rotate
        max_shift_x = w - text_w - border_w
        max_shift_y = h - text_h - border_h
        top_left_x = np.random.randint(0, int(max_shift_x))
        if ud:
            top_left_y = np.random.randint(0, int(max_shift_y))
        else:
            top_left_y = h // 2
        context.move_to(top_left_x - int(x_offset), top_left_y - int(y_offset))
        context.set_source_rgb(0, 0, 0)
        context.show_text(text)

    # The surface buffer holds 4 bytes per pixel; one channel is enough
    # because the drawing is pure black on white.
    buf = surface.get_data()
    a = np.frombuffer(buf, np.uint8)
    a.shape = (h, w, 4)
    a = a[:, :, 0]  # grab single channel
    a = a.astype(np.float32) / 255
    a = np.expand_dims(a, 0)
    if rotate:
        a = image.random_rotation(a, 3 * (w - top_left_x) / w + 1)
    a = speckle(a)

    return a


def shuffle_mats_or_lists(matrix_list, stop_ind=None):
    """Apply one shared random permutation to several equally long sequences.

    Only the first ``stop_ind`` positions are permuted; positions from
    ``stop_ind`` onwards keep their original order.  With ``stop_ind=None``
    the whole length is shuffled.

    Args:
        matrix_list: List of ``numpy.ndarray`` and/or ``list`` objects, all of
            the same length.
        stop_ind: Number of leading positions to shuffle.

    Returns:
        A list with one reordered copy per input, in input order.

    Raises:
        TypeError: If an element is neither an ndarray nor a list.
    """
    n_items = len(matrix_list[0])
    assert all([len(entry) == n_items for entry in matrix_list])
    if stop_ind is None:
        stop_ind = n_items
    assert stop_ind <= n_items

    # Shuffled head followed by the untouched tail.
    order = list(range(stop_ind))
    np.random.shuffle(order)
    order.extend(range(stop_ind, n_items))

    def reorder(mat):
        if isinstance(mat, np.ndarray):
            return mat[order]
        if isinstance(mat, list):
            return [mat[k] for k in order]
        raise TypeError('`shuffle_mats_or_lists` only supports '
                        'numpy.array and list objects.')

    return [reorder(mat) for mat in matrix_list]


# Translation of characters to unique integer values
def text_to_labels(text):
    """Return the list of ``alphabet`` indices for the characters of ``text``.

    A character that is not in ``alphabet`` yields -1, the value
    ``str.find`` returns for a miss.
    """
    return [alphabet.find(symbol) for symbol in text]


# Reverse translation of numerical classes back to characters
def labels_to_text(labels):
    """Turn a sequence of class ids back into a string.

    The id ``len(alphabet)`` is the CTC blank and contributes nothing; every
    other id is looked up in ``alphabet``.
    """
    return "".join(
        "" if label == len(alphabet) else alphabet[label]  # CTC Blank -> nothing
        for label in labels
    )


# only a-z and space
def is_valid_str(in_str):
    """Return True when ``in_str`` matches the module-level ``regex`` pattern."""
    return re.search(regex, in_str, re.UNICODE) is not None

In [0]:

# Uses generator functions to supply train/test with
# data. Image renderings and text are created on the fly
# each time with random perturbations

class TextImageGenerator(keras.callbacks.Callback):
    """Generator of rendered-text batches that doubles as a Keras callback.

    ``next_train`` / ``next_val`` yield batches whose images are painted on
    the fly by ``paint_text``.  As a callback, the object (re)builds its word
    list and rebinds the paint function at the start of training and, when
    ``curriculum`` is enabled, at epoch boundaries.

    ``val_split`` is an index into the word list, not a fraction: words
    before it are used for training, words from it onwards for validation.
    """

    # Number of words loaded for every stage.
    WORDS_PER_STAGE = 16000

    def __init__(self, monogram_file, minibatch_size,
                 img_w, img_h, downsample_factor, multi_fonts,
                 val_split, absolute_max_string_len=16, curriculum=True):
        super().__init__()
        self.minibatch_size = minibatch_size
        self.img_w = img_w
        self.img_h = img_h
        self.monogram_file = monogram_file
        self.downsample_factor = downsample_factor
        self.curriculum = curriculum
        self.multi_fonts = multi_fonts
        self.val_split = val_split
        self.blank_label = self.get_output_size() - 1
        self.absolute_max_string_len = absolute_max_string_len

    def get_output_size(self):
        """Number of output classes: one per alphabet character plus the CTC blank."""
        return len(alphabet) + 1

    # num_words can be independent of the epoch size due to the use of generators
    # as max_string_len grows, num_words can grow
    def build_word_list(self, num_words, max_string_len=None, mono_fraction=0.5):
        """Load ``num_words`` words from the monogram file and encode their labels.

        Args:
            num_words: Number of words to load; must be a multiple of the
                minibatch size.
            max_string_len: Longest word accepted.  ``None`` or ``-1`` means
                "no extra limit", i.e. up to ``absolute_max_string_len``.
            mono_fraction: Fraction of ``num_words`` to read from the
                monogram file.  The file is the only word source here, so
                values other than 1 normally end in the IOError below.

        Raises:
            IOError: If the number of words collected differs from ``num_words``.
        """
        unlimited = max_string_len is None or max_string_len == -1
        # Words longer than absolute_max_string_len cannot be stored in Y_data,
        # so "unlimited" is capped at that width.
        length_cap = self.absolute_max_string_len if unlimited else max_string_len
        assert length_cap <= self.absolute_max_string_len
        assert num_words % self.minibatch_size == 0
        assert (self.val_split * num_words) % self.minibatch_size == 0
        self.num_words = num_words
        self.string_list = [''] * self.num_words
        tmp_string_list = []
        self.max_string_len = max_string_len
        # -1 marks unused label positions.
        self.Y_data = np.ones([self.num_words, self.absolute_max_string_len]) * -1
        self.X_text = []
        self.Y_len = [0] * self.num_words

        wanted = int(self.num_words * mono_fraction)
        # monogram file is sorted by frequency in english speech
        with codecs.open(self.monogram_file, mode='r', encoding='utf-8') as f:
            for line in f:
                if len(tmp_string_list) == wanted:
                    break
                word = line.rstrip()
                if len(word) <= length_cap:
                    tmp_string_list.append(word)

        if len(tmp_string_list) != self.num_words:
            raise IOError('Could not pull enough words from supplied monogram file: '
                          'got %d, need %d.' % (len(tmp_string_list), self.num_words))
        # interlace to mix up the easy and hard words
        self.string_list[::2] = tmp_string_list[:self.num_words // 2]
        self.string_list[1::2] = tmp_string_list[self.num_words // 2:]

        for i, word in enumerate(self.string_list):
            self.Y_len[i] = len(word)
            self.Y_data[i, 0:len(word)] = text_to_labels(word)
            self.X_text.append(word)
        self.Y_len = np.expand_dims(np.array(self.Y_len), 1)

        self.cur_val_index = self.val_split
        self.cur_train_index = 0

    # each time an image is requested from train/val/test, a new random
    # painting of the text is performed
    def get_batch(self, index, size, train):
        """Paint ``size`` words starting at ``index`` and return ``(inputs, outputs)``.

        When ``train`` is true the last three samples of the batch are blank
        images labelled with the CTC blank.
        """
        # width and height are backwards from typical Keras convention
        # because width is the time dimension when it gets fed into the RNN
        channels_first = K.image_data_format() == 'channels_first'
        if channels_first:
            X_data = np.ones([size, 1, self.img_w, self.img_h])
        else:
            X_data = np.ones([size, self.img_w, self.img_h, 1])

        labels = np.ones([size, self.absolute_max_string_len])
        input_length = np.zeros([size, 1])
        label_length = np.zeros([size, 1])
        source_str = []
        # Same for every sample; the -2 matches the two leading RNN outputs
        # that ctc_lambda_func discards.
        time_steps = self.img_w // self.downsample_factor - 2
        for i in range(size):
            # Mix in some blank inputs.  This seems to be important for
            # achieving translational invariance
            is_blank = train and i > size - 4
            text = '' if is_blank else self.X_text[index + i]
            rendered = self.paint_func(text)[0, :, :].T
            if channels_first:
                X_data[i, 0, 0:self.img_w, :] = rendered
            else:
                X_data[i, 0:self.img_w, :, 0] = rendered
            if is_blank:
                labels[i, 0] = self.blank_label
                label_length[i] = 1
            else:
                labels[i, :] = self.Y_data[index + i]
                label_length[i] = self.Y_len[index + i]
            input_length[i] = time_steps
            source_str.append(text)
        inputs = {'the_input': X_data,
                  'the_labels': labels,
                  'input_length': input_length,
                  'label_length': label_length,
                  'source_str': source_str  # used for visualization only
                  }
        outputs = {'ctc': np.zeros([size])}  # dummy data for dummy loss function
        return (inputs, outputs)

    def next_train(self):
        """Yield training batches forever, reshuffling the training words on each wrap-around."""
        while 1:
            ret = self.get_batch(self.cur_train_index, self.minibatch_size, train=True)
            self.cur_train_index += self.minibatch_size
            if self.cur_train_index >= self.val_split:
                # Wrap relative to the configured batch size (was a literal 32).
                self.cur_train_index = self.cur_train_index % self.minibatch_size
                (self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
                    [self.X_text, self.Y_data, self.Y_len], self.val_split)
            yield ret

    def next_val(self):
        """Yield validation batches forever, cycling through words from ``val_split`` onwards."""
        while 1:
            ret = self.get_batch(self.cur_val_index, self.minibatch_size, train=False)
            self.cur_val_index += self.minibatch_size
            if self.cur_val_index >= self.num_words:
                self.cur_val_index = self.val_split + self.cur_val_index % self.minibatch_size
            yield ret

    def _set_stage(self, max_string_len, multi_fonts, rotate, ud):
        """Reload the word list for ``max_string_len`` and rebind the paint function."""
        self.build_word_list(self.WORDS_PER_STAGE, max_string_len, 1)
        self.paint_func = lambda text: paint_text(text, self.img_w, self.img_h,
                                                  multi_fonts=multi_fonts, rotate=rotate,
                                                  ud=ud)

    def on_train_begin(self, start_epoch=None, logs=None):
        """Set up the first stage.

        Neither argument is used.  ``start_epoch`` is kept for compatibility;
        a caller that passes a single positional argument (as the Keras
        callback hook is expected to do with ``logs``) fills this slot.
        """
        if self.curriculum:
            # easiest stage: short words, fixed font, no vertical jitter
            self._set_stage(4, 'False', rotate=False, ud=False)
        else:
            # no curriculum: start directly at the hardest setting
            self._set_stage(7, self.multi_fonts, rotate=True, ud=True)

    def on_epoch_begin(self, epoch, logs=None):
        """Advance the curriculum; a no-op when ``curriculum`` is disabled."""
        # rebind the paint function to implement curriculum learning
        if not self.curriculum:
            return
        if 3 <= epoch < 6:
            self._set_stage(5, 'False', rotate=False, ud=True)
        elif 6 <= epoch < 9:
            self._set_stage(6, self.multi_fonts, rotate=False, ud=True)
        elif epoch >= 9:
            self._set_stage(7, self.multi_fonts, rotate=True, ud=True)
              



In [0]:
# the actual loss calc occurs here despite it not being
# an internal Keras loss function

def ctc_lambda_func(args):
    """Compute the CTC batch cost from ``(y_pred, labels, input_length, label_length)``.

    The first two time steps of the prediction are dropped before the cost
    is evaluated, because the first couple of RNN outputs tend to be unusable.
    """
    predictions, labels, input_length, label_length = args
    usable_predictions = predictions[:, 2:, :]
    return K.ctc_batch_cost(labels, usable_predictions, input_length, label_length)


# Decodes with the best-path algorithm (most likely class at each time step).
# Beam search with a language model is an alternative, but a language model
# would prevent misspelled words from being read as written.

def decode_batch(test_func, word_batch):
    """Decode a batch of images into strings.

    ``test_func`` is called with ``[word_batch]`` and its first result is
    decoded sample by sample: skip the first two time steps, take the argmax
    class at each remaining step, collapse consecutive repeats and convert
    the ids to text with ``labels_to_text``.
    """
    activations = test_func([word_batch])[0]
    decoded = []
    for sample in activations:
        best_path = np.argmax(sample[2:], 1)
        collapsed = [label for label, _run in itertools.groupby(list(best_path))]
        decoded.append(labels_to_text(collapsed))
    return decoded

In [0]:
class VizCallback(keras.callbacks.Callback):
    """Callback that saves weights, reports edit distance and saves sample decodings each epoch.

    Args:
        run_name: Sub-directory of ``OUTPUT_DIR`` that receives the output files.
        test_func: Callable taking ``[image_batch]`` and returning the softmax
            output as its first result (it is passed on to ``decode_batch``).
        text_img_gen: Generator yielding ``(inputs, outputs)`` batches whose
            inputs contain ``'the_input'`` and ``'source_str'``.
        num_display_words: Number of samples drawn in the per-epoch figure.
    """

    def __init__(self, run_name, test_func, text_img_gen, num_display_words=6):
        super().__init__()
        self.test_func = test_func
        self.output_dir = os.path.join(
            OUTPUT_DIR, run_name)
        self.text_img_gen = text_img_gen
        self.num_display_words = num_display_words
        os.makedirs(self.output_dir, exist_ok=True)

    def on_train_begin(self, logs=None):
        """Reset the per-epoch histories."""
        self.losses = []
        self.mean_eds = []
        self.mean_norm_eds = []

    def show_edit_distance(self, num):
        """Decode ``num`` samples, print and record the mean (normalized) edit distance."""
        num_left = num
        mean_norm_ed = 0.0
        mean_ed = 0.0
        while num_left > 0:
            word_batch = next(self.text_img_gen)[0]
            num_proc = min(word_batch['the_input'].shape[0], num_left)
            decoded_res = decode_batch(self.test_func, word_batch['the_input'][0:num_proc])
            for j in range(num_proc):
                truth = word_batch['source_str'][j]
                edit_dist = float(editdistance.eval(decoded_res[j], truth))
                mean_ed += edit_dist
                # An empty ground-truth string would otherwise divide by zero.
                mean_norm_ed += edit_dist / max(len(truth), 1)
            num_left -= num_proc
        mean_norm_ed = mean_norm_ed / num
        mean_ed = mean_ed / num
        print('\nOut of %d samples:  Mean edit distance: %.8f Mean normalized edit distance: %0.8f'
              % (num, mean_ed, mean_norm_ed))
        self.mean_eds.append(mean_ed)
        self.mean_norm_eds.append(mean_norm_ed)

    def on_epoch_end(self, epoch, logs=None):
        """Save this epoch's weights, log the metrics and write a figure of sample decodings."""
        logs = logs or {}
        self.model.save_weights(os.path.join(self.output_dir, 'weights%02d.h5' % (epoch)))
        self.losses.append(logs.get('val_loss'))
        self.show_edit_distance(256)
        word_batch = next(self.text_img_gen)[0]
        res = decode_batch(self.test_func, word_batch['the_input'][0:self.num_display_words])
        # Never index past what was actually decoded.
        num_shown = min(self.num_display_words, len(res))
        # Narrow images go side by side, wide ones one per row.
        if word_batch['the_input'][0].shape[0] < 256:
            cols = 2
        else:
            cols = 1
        # Round up so an odd number of samples still fits in the grid.
        rows = (num_shown + cols - 1) // cols
        fig = plt.figure(figsize=(10, 13))
        for i in range(num_shown):
            ax = fig.add_subplot(rows, cols, i + 1)
            if K.image_data_format() == 'channels_first':
                the_input = word_batch['the_input'][i, 0, :, :]
            else:
                the_input = word_batch['the_input'][i, :, :, 0]
            # Batches are stored width-first; transpose back for display.
            ax.imshow(the_input.T, cmap='Greys_r')
            ax.set_xlabel('Truth = \'%s\'\nDecoded = \'%s\'' % (word_batch['source_str'][i], res[i]))
        fig.savefig(os.path.join(self.output_dir, 'e%02d.png' % (epoch)))
        plt.close(fig)

    def on_train_end(self, logs=None):
        """Print the recorded loss and edit-distance histories."""
        print(self.losses)
        print(self.mean_eds)
        print(self.mean_norm_eds)

In [0]:
!mkdir /content/lstm_ocr
!mkdir /content/lstm_ocr/weights

In [0]:
def train(run_name, start_epoch, stop_epoch, img_w, weights=False):
    """Build the parallel-CNN + LSTM/SimpleRNN OCR model and train it with CTC loss.

    Args:
        run_name: Name of the run; used by ``VizCallback`` as the output
            sub-directory and to locate weights when resuming.
        start_epoch: Passed to ``fit_generator`` as ``initial_epoch``.
        stop_epoch: Passed to ``fit_generator`` as ``epochs``.
        img_w: Width of the input images (the height is fixed at 64).
        weights: Initial weights, used only when ``start_epoch == 0``.
            Falsy: keep the random initialisation.  A string: path of the
            weights file to load.  Any other truthy value (e.g. ``True``):
            load the default pre-trained file on Drive.

    Returns:
        The compiled training model, whose output is the CTC loss.
    """
    default_weights_external = ('/content/drive/My Drive/weights/cnn_parallel_lstm_rnn/'
                                'final_cnn_parallel_lstm_rnn_e100.h5')
    # Input Parameters
    img_h = 64
    words_per_epoch = 16000
    val_split = 0.2
    val_words = int(words_per_epoch * (val_split))
    multi_fonts = 'full'
    curriculum = False

    # Network parameters
    conv_filters = 16
    kernel_size = (3, 3)
    kernel_size_2 = (4, 4)
    kernel_size_3 = (5, 5)
    kernel_size_4 = (2, 2)
    pool_size = 2
    time_dense_size = 32
    rnn_size = 512
    minibatch_size = 32

    if K.image_data_format() == 'channels_first':
        input_shape = (1, img_w, img_h)
    else:
        input_shape = (img_w, img_h, 1)

    fdir = os.path.dirname(get_file('wordlists.tgz',
                                    origin='http://www.mythic-ai.com/datasets/wordlists.tgz', untar=True))

    # The generator takes the train/validation boundary as a word index.
    img_gen = TextImageGenerator(monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),
                                 minibatch_size=minibatch_size,
                                 img_w=img_w,
                                 img_h=img_h,
                                 downsample_factor=(pool_size ** 2),
                                 multi_fonts=multi_fonts,
                                 val_split=words_per_epoch - val_words,
                                 curriculum=curriculum
                                 )
    #### model ####
    act = 'relu'
    input_data = Input(name='the_input', shape=input_shape, dtype='float32')

    def conv_branch(kernel, suffix):
        """Two conv + max-pool stages applied to the input image with one kernel size."""
        branch = Conv2D(conv_filters, kernel, padding='same', activation=act,
                        kernel_initializer='he_normal', name='conv1' + suffix)(input_data)
        branch = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1' + suffix)(branch)
        branch = Conv2D(conv_filters, kernel, padding='same', activation=act,
                        kernel_initializer='he_normal', name='conv2' + suffix)(branch)
        branch = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2' + suffix)(branch)
        return branch

    # Four parallel branches that differ only in kernel size; layers are
    # created in the same order and with the same names as before.
    inner = conv_branch(kernel_size, '')
    inner_2 = conv_branch(kernel_size_2, '_2')
    inner_3 = conv_branch(kernel_size_3, '_3')
    inner_4 = conv_branch(kernel_size_4, '_4')

    # Two 2x poolings shrink both axes by pool_size ** 2; fold height and
    # filters into one feature axis so width becomes the time axis.
    conv_to_rnn_dims = (img_w // (pool_size ** 2), (img_h // (pool_size ** 2)) * conv_filters)

    inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)
    inner_2 = Reshape(target_shape=conv_to_rnn_dims, name='reshape_2')(inner_2)
    inner_3 = Reshape(target_shape=conv_to_rnn_dims, name='reshape_3')(inner_3)
    inner_4 = Reshape(target_shape=conv_to_rnn_dims, name='reshape_4')(inner_4)

    # cuts down input size going into RNN:
    inner = Dense(time_dense_size, activation=act, name='dense1')(concatenate([inner, inner_2, inner_3, inner_4]))

    # Two layers of bidirectional LSTMs:
    lstm_1 = LSTM(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='LSTM1')(inner)
    lstm_1b = LSTM(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='LSTM1_b')(inner)
    lstm1_merged = add([lstm_1, lstm_1b])
    lstm_2 = LSTM(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='LSTM2')(lstm1_merged)
    lstm_2b = LSTM(rnn_size, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='LSTM2_b')(lstm1_merged)

    # A second, plain-RNN path fed from the same dense layer.
    rnn_1 = SimpleRNN(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='RNN1')(inner)
    rnn_2 = SimpleRNN(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='RNN2')(rnn_1)

    # transforms RNN output to character activations:
    inner = Dense(img_gen.get_output_size(), kernel_initializer='he_normal',
                  name='dense2')(concatenate([lstm_2, lstm_2b, rnn_2]))
    y_pred = Activation('softmax', name='softmax')(inner)
    Model(inputs=input_data, outputs=y_pred).summary()

    labels = Input(name='the_labels', shape=[img_gen.absolute_max_string_len], dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')

    # CTC loss is implemented in a lambda layer
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])

    adadelta = Adadelta(lr=1.0)

    model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)

    # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=adadelta)
    # NOTE(review): because `curriculum` is False above, a run with
    # start_epoch > 0 does not reload the previous epoch's weights -- confirm
    # that this is intended.
    if start_epoch > 0 and curriculum==True:
        weight_file = os.path.join(OUTPUT_DIR, os.path.join(run_name, 'weights%02d.h5' % (start_epoch - 1)))
        model.load_weights(weight_file)
    if start_epoch == 0 and weights:
        # `weights_external` used to be referenced without being defined,
        # which raised NameError whenever `weights` was truthy.
        weights_external = weights if isinstance(weights, str) else default_weights_external
        model.load_weights(weights_external)
    # captures output of softmax so we can decode the output during visualization
    test_func = K.function([input_data], [y_pred])

    viz_cb = VizCallback(run_name, test_func, img_gen.next_val())

    model.fit_generator(generator=img_gen.next_train(),
                        steps_per_epoch=(words_per_epoch - val_words) // minibatch_size,
                        epochs=stop_epoch,
                        validation_data=img_gen.next_val(),
                        validation_steps=val_words // minibatch_size,
                        callbacks=[viz_cb, img_gen],
                        initial_epoch=start_epoch)

    return model

In [0]:
# Name the run after its start time, then train for 100 epochs on
# 128-pixel-wide images, starting from epoch 0 with initial weights loaded.
run_name = datetime.datetime.now().strftime('%Y:%m:%d:%H:%M:%S')
model = train(run_name, start_epoch=0, stop_epoch=100, img_w=128, weights=True)

In [0]:
# Write the model architecture as JSON to the working directory and the
# trained weights to a file named after this run.
model_json = model.to_json()
with open("model.json", "w") as json_file:
   json_file.write(model_json)
model.save_weights('/content/lstm_ocr/weights/'+run_name+'.h5')

In [0]:
# Also save the model returned by train() as a single HDF5 file.
model.save('/content/lstm_ocr/cnn_parallel_lstm_rnn.h5')