In [1]:
import sys
sys.path.append('../src')

import os
import torch
import glob
import numpy as np
import torch.optim as optim
import torch.nn.functional as F

from ocr_data_loader import *
from ocr_utils import *
from ocr_image_transformations import *
from ocr_model import OCRModel

from torchvision import transforms
from torch import nn
import torchvision
from PIL import Image


In [7]:
# ---------------------------------------------------------------------------
# Notebook configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# Location of the corpus on disk and the sub-corpus used in this run.
BASE_DIR = '../GT4HistOCR/corpus'
DATA_SET_NAME = 'RefCorpus-ENHG-Incunabula'

# Reproducibility: seed the RNGs of the libraries imported by this notebook so
# that a "Restart Kernel -> Run All" yields comparable runs.
# NOTE(review): if the project helpers draw from Python's `random` module
# (not visible from this notebook), seed that as well.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Input framing / batching
FRAME_SIZE = 10
BATCH_SIZE = 5

# NN parameters
HIDDEN_LAYER_SIZE = 200
HIDDEN_LAYERS_NUM = 1 # Number of LSTM cells to stack

# Training parameters
LEARNING_RATE = 0.1
MOMENTUM = 0.9
EPOCHS = 1
TRAIN_TEST_SPLIT = .9
CLIPPING_VALUE = 1000
DROP_OUT_RATIO = 0.1

# Common image dimensions for the selected corpus, as computed by the project
# helper from the corpus directory and the frame size.
MAX_IMAGE_WIDTH, MAX_IMAGE_HEIGHT = get_unfied_image_dimensions(f'{BASE_DIR}/{DATA_SET_NAME}', FRAME_SIZE)
print(MAX_IMAGE_HEIGHT)
print(MAX_IMAGE_WIDTH)

# Size of the model input: image height times the frame size.
INPUT_DIMENSION = MAX_IMAGE_HEIGHT * FRAME_SIZE

# Dimensions noted per corpus.
# NOTE(review): presumably (width, height) from earlier runs; the values may be
# stale and should be confirmed against the two numbers printed above.
# RefCorpus-ENHG-Incunabula (960, 85)
# EarlyModernLatin (4510, 270)
# dta19 (1350, 100)
# Kallimachos (2640, 337)
# RIDGES-Fraktur (3950, 484)

# Per-image preprocessing, applied in this order:
#   1. convert the loaded image to a tensor,
#   2. pad it using the common MAX_IMAGE_HEIGHT / MAX_IMAGE_WIDTH,
#   3. unfold it with FRAME_SIZE passed as both the second and third argument.
preprocessing_steps = [
    transforms.ToTensor(),
    ImageTensorPadding(MAX_IMAGE_HEIGHT, MAX_IMAGE_WIDTH),
    UnfoldImage(1, FRAME_SIZE, FRAME_SIZE),
]
transformation = transforms.Compose(preprocessing_steps)

# Cross-validation loader variant, currently disabled in this cell.
# NOTE(review): the cross-validation cell further down calls
# `ocr_cv_dataloader.load_data(i)`, so it needs a loader built with these arguments.
#ocr_cv_dataloader = OCRCVDataLoader(base_dir = BASE_DIR, dataset_name = DATA_SET_NAME,
#                                              transformation=transformation, batch_size=BATCH_SIZE)
#dataset = ocr_cv_dataloader.get_dataset()

# Single train/test partition of the selected corpus; `dataset` is kept because
# later code reads `dataset.alphabet`.
loader_options = dict(
    base_dir=BASE_DIR,
    dataset_name=DATA_SET_NAME,
    transformation=transformation,
    train_test_split=TRAIN_TEST_SPLIT,
    batch_size=BATCH_SIZE,
)
train_data, test_data, dataset = load_data(**loader_options)

# Number of symbols in the dataset alphabet; handed to OCRModel below.
ALPHABET_SIZE = len(dataset.alphabet)

# Build the model from the configuration constants.
model = OCRModel(
    INPUT_DIMENSION,
    HIDDEN_LAYER_SIZE,
    HIDDEN_LAYERS_NUM,
    ALPHABET_SIZE,
    dropout_ratio=DROP_OUT_RATIO,
)

# SGD with momentum over all model parameters, then train on the training
# partition for EPOCHS epochs (CLIPPING_VALUE is forwarded to train_model).
optimizer = optim.SGD(params=model.parameters(), momentum=MOMENTUM, lr=LEARNING_RATE)
model.train_model(train_data, optimizer, EPOCHS, CLIPPING_VALUE)


# Print sample predictions for both partitions. Each partition gets its own
# banner: previously a single "Infer test data" banner preceded the *training*
# samples, which made the output misleading.
for split_name, split_data in (("train", train_data), ("test", test_data)):
    print(f"============== Infer {split_name} data")
    print_predicted_text(model, split_data, dataset.alphabet, samples_number=40)


42
440
Epoch 0, Mean loss 4.372429132461548
Original Text:   Jnd des hertzongen myt ſcharper wer 
Predicted Text:  cuodoGmpsrGsyGJwymnfkvTſMuigruztehtMt 

Original Text:   Myt ſyme keyſerlygen her 
Predicted Text:  TaTaTMWpstWmVkpGzhr mVzut 

Original Text:   Vyll verdroncken ym waſſer doit 
Predicted Text:  TodzyW suiuxrDpkpmfGWrGhVt 

Original Text:   Wayr ſach ye mynſche des gelychs 
Predicted Text:  mTdGdTycnsMVk kVkWpshcDzTGWfoV tut 

Original Text:   Dorch doeden myt groiſſen ſchaeden 
Predicted Text:  opkrgJDkyMWxWkWnhfhyrVhoyxsysmwphfht 

Original Text:   Tuſſchen des roemſchen keyſers her 
Predicted Text:  TpGToxsVhfmnfrifafcMWVhdWzprewout 

Original Text:   Der furſten groys vermoegenheyt 
Predicted Text:  TkpmpeGcGuwWnirxMhfprdDVxpnGautkt 

Original Text:   Man hoerd dayr vyll iamers claegen 
Predicted Text:  Tzuz hwzyezpyrydJpGkvdmtkvJtDkxtxutoM 

Original Text:   Got wyll dye gud frund geleyden 
Predicted Text:  ToVwDWpymoWpGWylrDfsTfmkvhJkhkWxot 

Original Text:   Des dyn

In [3]:

# Cross-validated training: train a fresh model per split and report the
# accuracy on that split's held-out data, followed by the mean over all splits.

# Number of splits evaluated by the loop below.
N_SPLITS = 5

# Build the cross-validation loader in this cell so that it runs on a fresh
# kernel; it was previously referenced without being defined anywhere in the
# executed code. The arguments follow the disabled snippet in the configuration cell.
ocr_cv_dataloader = OCRCVDataLoader(base_dir=BASE_DIR, dataset_name=DATA_SET_NAME,
                                    transformation=transformation, batch_size=BATCH_SIZE)
# Rebind `dataset` to the loader's own dataset so that the alphabet used for
# the model and for decoding belongs to the data the model is trained on.
dataset = ocr_cv_dataloader.get_dataset()

# Fixed values ( i.e.: not configurable); they do not change between splits,
# so they are computed once instead of on every iteration.
ALPHABET_SIZE = len(dataset.alphabet)
INPUT_DIMENSION = MAX_IMAGE_HEIGHT * FRAME_SIZE

accuracies = []

for i in range(N_SPLITS):
    print("======== Training For Split ", i)
    train_data, test_data = ocr_cv_dataloader.load_data(i)

    # Define and train a new model for this split
    model = OCRModel(INPUT_DIMENSION, HIDDEN_LAYER_SIZE, HIDDEN_LAYERS_NUM, ALPHABET_SIZE, dropout_ratio=DROP_OUT_RATIO)

    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    model.train_model(train_data, optimizer, EPOCHS, CLIPPING_VALUE)

    # Score the model on the held-out part of this split
    predictions = model.get_predictions(test_data, dataset.alphabet)
    accuracy = get_prediction_accuracy(predictions)
    accuracies.append(accuracy)

    print(f'Split {i}, model accuracy is {accuracy}')

print("Final model mean accuracy {}".format(np.mean(accuracies)))



Epoch 0, Mean loss 7.180083751678467
Epoch 1, Mean loss 7.179708480834961
Epoch 2, Mean loss 7.179335117340088
Split 0, model accuracy is 1.817252361426065
Epoch 0, Mean loss 7.105437755584717
Epoch 1, Mean loss 7.105776786804199
Epoch 2, Mean loss 7.105532646179199
Split 1, model accuracy is 1.8662055350463898
Epoch 0, Mean loss 7.173081398010254
Epoch 1, Mean loss 7.172935485839844
Epoch 2, Mean loss 7.172610282897949
Split 2, model accuracy is 1.8298295506555076
Epoch 0, Mean loss 7.201098918914795
Epoch 1, Mean loss 7.201435089111328
Epoch 2, Mean loss 7.201291084289551
Split 3, model accuracy is 1.7971567905363017
Epoch 0, Mean loss 7.089958667755127
Epoch 1, Mean loss 7.090210437774658
Epoch 2, Mean loss 7.089930057525635
Split 4, model accuracy is 1.9214711507167448
Final model mean accuracy 1.8463830776762016


In [4]:
# Show sample predictions for the current `model` on the current `test_data`
# (both are whatever the previously executed cell left bound to these names).
SAMPLES_TO_SHOW = 40
print("============== Infer test data")
print_predicted_text(model, test_data, dataset.alphabet, samples_number=SAMPLES_TO_SHOW)

Original Text:   Nu hoirt gud frund wat vort geſchach 
Predicted Text:  HFtmScfAĩ·ãfũãVKbHtOũtoZKãpbSxmjZPZnbieufMſqz LſBẽDãwEſsẽsEdePſLſE 

Original Text:   Dye bũnre goyt Spray chẽ oeuer luyt 
Predicted Text:  BZtAo VQuqõFsãDSſRQbaTmgqgxv·QWfĩmZDgvZSFHkEzEũNrzenNhpE·eĩãiyrDzcR 

Original Text:   Den homechtichſten roemſch keyſer 
Predicted Text:  mPlxgLnVRWRQp v bANk·tshcgoWETũQe uvj·nMBnEuLmdOuLzcdmOwẽjEMdſNtx 

Original Text:   Geſchach eyn ſwayr bedroeffde ſlacht 
Predicted Text:  Sb·j ZxmoZujmLGlũpũ·ycWkokJ KZbĩPfwo ũFũ oOnũjlrLwſDãrBqpvFBzvOyPNWO 

Original Text:   Jch neempt weerlych vp myn leuen 
Predicted Text:  cHWHJnſTRWSVFvdcHcuxTbx SZxũcSQsEKOojhkVyzODOweaEſhkLj rQmſgjNOuJRN 

Original Text:   All vangen woulden offenbayr 
Predicted Text:  NoĩletpVRKmĩVTunRHFhcbZhSRnshbGcQBQmeKpfB ſEZsNHsBsJOBMKe rNãwzLWP 

Original Text:   Meer mach ich nyet dayr aff ſaegen 
Predicted Text:  dbpNsVſTAJOwSũty·xbosSKFZgVT·vbZbſHcJRdOmTLjzxdregEſhV RDaPOLrvkGẽs 

Original Text:   Jnd fueghd