In [1]:
import sys
sys.path.append('../src')

from ocr_data_loader import load_data
from ocr_utils import *
from ocr_model import OCRModel
import os
from torchvision import transforms
from torch import nn
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
import numpy as np
import torch.optim as optim
import glob

In [2]:
# ========= Data location
# Paths are relative to the notebook's working directory. BASE_DIR is set to the
# path that was actually passed to load_data (the old constant, '../../...', was
# unused and disagreed with it).
BASE_DIR = '../GT4HistOCR/corpus'
DATA_SET_NAME = 'RefCorpus-ENHG-Incunabula/1476-Historij-Wierstaat'

# ========= Reproducibility
# Seed the RNGs this notebook touches so Restart & Run All gives comparable runs
# (weight initialisation, and presumably the train/test split inside load_data — TODO confirm).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# ========= Hyper parameters

# Image parameters (size of every sample after `transformation` below)
IMAGE_WIDTH = 70
IMAGE_HEIGHT = 700
SEQUENCES_NUM = 20  # Number of input sequences (i.e.: how many frames the input image is split into)

# NN parameters
HIDDEN_LAYER_SIZE = 500
HIDDEN_LAYERS_NUM = 1  # Number of LSTM cells to stack

# Training parameters
LEARNING_RATE = 0.01
MOMENTUM = 0.05
EPOCHS = 100
TRAIN_TEST_SPLIT = .8  # passed to load_data; presumably the training fraction — TODO confirm
CLIPPING_VALUE = 3  # forwarded to model.train; presumably a gradient-clipping threshold — TODO confirm

# `degrees=(-90, -90)` makes RandomRotation deterministic: every image is rotated by
# exactly -90 degrees (clockwise); expand=True enlarges the canvas so nothing is cropped.
# The rotated image is then resized to IMAGE_HEIGHT x IMAGE_WIDTH and converted to a tensor.
transformation = transforms.Compose(
    [transforms.RandomRotation(degrees=(-90, -90), expand=True),
     transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
     transforms.ToTensor()])

train_data, test_data, dataset = load_data(base_dir=BASE_DIR,
                                           dataset_name=DATA_SET_NAME,
                                           transformation=transformation,
                                           train_test_split=TRAIN_TEST_SPLIT)

# ========= Derived values (i.e.: not configurable)
ALPHABET_SIZE = len(dataset.alphabet)

# The transformed image is split along its height into SEQUENCES_NUM frames of
# (IMAGE_HEIGHT / SEQUENCES_NUM) rows x IMAGE_WIDTH columns; INPUT_DIMENSION is the
# flattened size of one frame (assumes single-channel images — TODO confirm).
# Integer arithmetic plus an explicit check replaces int(float) truncation, which
# would silently yield a wrong frame size for non-divisible settings.
assert IMAGE_HEIGHT % SEQUENCES_NUM == 0, (
    "IMAGE_HEIGHT ({}) must be divisible by SEQUENCES_NUM ({})".format(IMAGE_HEIGHT, SEQUENCES_NUM))
INPUT_DIMENSION = (IMAGE_HEIGHT // SEQUENCES_NUM) * IMAGE_WIDTH

# ========= Define and train the model
model = OCRModel(INPUT_DIMENSION, HIDDEN_LAYER_SIZE, HIDDEN_LAYERS_NUM, ALPHABET_SIZE)

optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# NOTE(review): OCRModel.train takes the training data here; if OCRModel is an
# nn.Module this overrides nn.Module.train(mode), so model.eval() would break — confirm.
losses = model.train(train_data, optimizer, SEQUENCES_NUM, INPUT_DIMENSION, EPOCHS, CLIPPING_VALUE)

# np.mean averages everything model.train returned, so this is not only the last epoch's loss.
print("Final Mean loss {}".format(np.mean(losses)))

# ========= Inference on the held-out split
print("============== Infer test data")
# Prediction never needs gradients; no_grad avoids building the autograd graph.
# (model.eval() is intentionally not called — see the note on OCRModel.train.)
with torch.no_grad():
    for batch in test_data:
        true_texts = batch['text']
        predicted_texts = model.infer(batch['image'], SEQUENCES_NUM, INPUT_DIMENSION, dataset.alphabet)

        for true_text, predicted_text in zip(true_texts, predicted_texts):
            print("Original Text:   {} \nPredicted Text:  {} \n".format(true_text, predicted_text))


Epoch 0, Mean loss 0.18370480835437775
Epoch 1, Mean loss 0.18401667475700378
Epoch 2, Mean loss 0.18396171927452087
Epoch 3, Mean loss 0.20243465900421143
Epoch 4, Mean loss 0.20291832089424133
Epoch 5, Mean loss 0.18236330151557922
Epoch 6, Mean loss 0.18220485746860504
Epoch 7, Mean loss 0.18132363259792328
Epoch 8, Mean loss 0.20160581171512604
Epoch 9, Mean loss 0.18078778684139252
Epoch 10, Mean loss 0.18130645155906677
Epoch 11, Mean loss 0.18029743432998657
Epoch 12, Mean loss 0.18127988278865814
Epoch 13, Mean loss 0.18109656870365143
Epoch 14, Mean loss 0.18081340193748474
Epoch 15, Mean loss 0.18109065294265747
Epoch 16, Mean loss 0.20007583498954773
Epoch 17, Mean loss 0.18025022745132446
Epoch 18, Mean loss 0.18003442883491516
Epoch 19, Mean loss 0.2002863883972168
Epoch 20, Mean loss 0.17937038838863373
Epoch 21, Mean loss 0.18015418946743011
Epoch 22, Mean loss 0.1801353394985199
Epoch 23, Mean loss 0.1792815774679184
Epoch 24, Mean loss 0.17913874983787537
Epoch 25, Mea

AttributeError: 'list' object has no attribute 'shape'

Original Text:   Byrnhoultz vyll zo backen ind bruwen dat gemall 
Predicted Text:  AAtmTrMGxTxnE·Ezni·ẽ 

Original Text:   Man hoerd dayr vyll iamers claegen 
Predicted Text:  AAAAtvpyku S i nO ẽz 

Original Text:   Myt gudẽ hertzen dayr zu ſayſſen 
Predicted Text:  AAAATmmewi· n   gbuG 

Original Text:   Got wyll dye gud frund geleyden 
Predicted Text:  AAAvTAGeoo nsO ···e  

Original Text:   Vyll roſſmoelen in der ſtat dye fuegen dayr wall 
Predicted Text:  AAAtiMw WOſ·nJJneDBE 

Original Text:   Alſo geſament zo ſtrijden 
Predicted Text:  AAAArSMſy·s· ·cuJVG  

Original Text:   Vp maenendaygh hoyrt mych vortan 
Predicted Text:  AAAmmiBErTnuuS··n··· 

Original Text:   Gudt zo doyn ind gelucks zo wynſchen 
Predicted Text:  AAAAſjkkVoeDOSffllS· 

Original Text:   Den homechtichſten roemſch keyſer 
Predicted Text:  AAMNy·Nflauuſ iEuuDy 

Original Text:   Vyll wullengewantz man ouch dayr hauen moyt 
Predicted Text:  AAAqmAgueT sBE t  ·S 

