# Example OCR model

**This notebook is still experimental.** It will be extended in upcoming updates with a concrete training example and documentation ;) Nevertheless, the available pretrained models seem to be accurate enough for simple usage! Check the `ocr` notebook for more information and examples!

In [1]:
import torch # to avoid errors when converting the pre-trained weights
import numpy as np
import pandas as pd
import tensorflow as tf

from sklearn.utils import shuffle as sklearn_shuffle

from loggers import set_level
from utils import plot, plot_multiple
from utils.image import load_image
from models import get_pretrained
from models.ocr import CRNN
from datasets import get_dataset, train_test_split, prepare_dataset, test_dataset_time

# Name of the model used throughout the notebook: the model-creation cell derives
# the language from the part after the last `_`, and `get_pretrained` is called with it.
model_name = 'crnn_latin'
# Report the TensorFlow version in use.
print('Tensorflow version : {}'.format(tf.__version__))

Tensorflow version : 2.10.0


## Model creation

In [None]:
# The language is the trailing `_`-separated token of the model name
# (e.g. `latin` for `crnn_latin`).
lang = model_name.rsplit('_', 1)[-1]

# Build the CRNN with the same language for both `lang` and `pretrained_lang`.
model = CRNN(
    nom             = model_name,
    lang            = lang,
    pretrained_lang = lang
)

# Show the model, then its text encoder (one per line), then the summary.
print(model, model.text_encoder, sep = '\n')
model.summary()

## Model instantiation + dataset loading

In [None]:
# Fetch the model registered under `model_name`.
model = get_pretrained(model_name)

# Learning-rate configuration handed to the optimizer: a `DivideByStep` entry
# with `maxval` 1e-3 and `minval` 1e-4.
lr = dict(name = 'DivideByStep', maxval = 1e-3, minval = 1e-4)

model.compile(
    optimizer        = 'adam',
    optimizer_config = dict(lr = lr)
)
print(model)

In [None]:
dataset = get_dataset('synthtext', one_line_per_box = True, add_image_size = False)

# `get_dataset` may return either a single dataset or a dict with pre-defined splits.
if not isinstance(dataset, dict):
    # No pre-defined split: hold out 10% for validation, with a fixed seed.
    train, valid = train_test_split(
        dataset, valid_size = 0.1, shuffle = True, random_state = 10
    )
else:
    train, valid = dataset['train'], dataset['valid']

# Shuffle the training set with a fixed seed, whichever branch produced it.
train = sklearn_shuffle(train, random_state = 10)

print(
    'Dataset length :\n  Train size : {}\n  Valid size : {}'.format(len(train), len(valid))
)

## Training

In [None]:
# Training hyper-parameters.
epochs     = 5
batch_size = 128

# Augmentation percentage is 0; `shuffle_size` is 8 batches worth of samples.
augment_prct = 0.0
shuffle_size = 8 * batch_size

max_output_length = 64

# `cache` is True only when `train` has fewer than 200,000 entries.
model.train(
    train,
    validation_data   = valid,
    epochs            = epochs,
    batch_size        = batch_size,
    max_output_length = max_output_length,
    augment_prct      = augment_prct,
    shuffle_size      = shuffle_size,
    cache             = len(train) < 200_000
)

In [None]:
# Plot the model's history, then print the history object itself.
model.plot_history()
print(model.history)

## Prediction

In [None]:
# Load `coco_text` with `modes = 'valid'` (presumably its validation split) and
# `one_line_per_box` enabled.
# NOTE(review): the next cell reassigns `samples = valid`, so this dataset is
# currently not used by the prediction loop — confirm which one is intended.
samples = get_dataset('coco_text', modes = 'valid', one_line_per_box = True)

In [None]:
# Qualitative check: run the model on 10 rows drawn with a fixed seed and print
# each label next to the decoded prediction.
# NOTE(review): this assignment replaces whatever `samples` held before this cell
# (the previous cell loads `coco_text` into it) — confirm which dataset is intended.
samples = valid

for idx, row in samples.sample(10, random_state = 2).iterrows():
    # Shape of the image loaded from `filename` restricted to `box`, for comparison
    # with the shape of the input the model builds from the same row.
    print(load_image(row['filename'], bbox = row['box']).shape)

    # Build the model input once, so that the printed shape is the shape of the
    # tensor that is actually preprocessed and fed to the model below.
    image = model.get_input(row)
    print(image.shape)

    # Add a leading axis of size 1 before preprocessing.
    inp = model.preprocess_input(tf.expand_dims(image, axis = 0))
    # Display the preprocessed input a single time (also shown for skipped rows).
    plot(inp[0], plot_type = 'imshow')
    print(inp.shape)

    # Skip inputs having any dimension, other than the first and last, below 16.
    if tf.reduce_any(tf.shape(inp)[1:-1] < 16): continue

    # NOTE(review): `max_length = 10` is lower than the `max_output_length = 64`
    # used for training — confirm that it does not truncate longer labels.
    out = model.infer(inp, max_length = 10)

    print(row['label'], model.decode_output(out))

## Tests

In [None]:
# Set the `datasets` level to 'debug' while the pipeline is being prepared.
set_level('debug', 'datasets')

# Dataset configuration requested with `is_validation = False`, batches of 64
# and prefetching disabled.
config = model.get_dataset_config(
    is_validation = False,
    batch_size    = 64,
    prefetch      = False
)

# NOTE(review): despite its name, `train_ds` is built from `valid` — confirm this is intended.
train_ds = prepare_dataset(valid, **config)

# Back to the 'info' level before timing.
set_level('info', 'datasets')

# Time 1000 steps over the prepared dataset.
test_dataset_time(train_ds, steps = 1000)