In [1]:
import sys
sys.path.append('../src')

from ocr_data_loader import load_data
from ocr_utils import *
import os
from torchvision import transforms
from torch import nn
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
import numpy as np
import torch.optim as optim
import glob

In [2]:
BASE_DIR = '../../GT4HistOCR/corpus'
DATA_SET_NAME = 'RefCorpus-ENHG-Incunabula/1476-Historij-Wierstaat'
#to_pil = torchvision.transforms.ToPILImage()

In [3]:
class OCRNet(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(OCRNet, self).__init__()
        
        self.hidden_dim = hidden_dim

        # Number of hidden layers
        self.layer_dim = layer_dim

        # Building your LSTM
        # batch_first=True causes input/output tensors to be of shape
        # (batch_dim, seq_dim, feature_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)

        # Readout layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initialize hidden state with zeros
        h0 = torch.randn(self.layer_dim, x.size(0), self.hidden_dim)

        # Initialize cell state
        c0 = torch.randn(self.layer_dim, x.size(0), self.hidden_dim)

        # We need to detach as we are doing truncated backpropagation through time (BPTT)
        # If we don't, we'll backprop all the way to the start even after going through another batch
        out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
        
        out = self.fc(out)

        return F.log_softmax(out, dim=1)

In [4]:
#========= Hyper parameters 

# Image parameters
IMAGE_WIDTH = 70
IMAGE_HEIGHT = 700
SEQUENCES_NUM = 70 # Number of input sequences ( i.e.: How many frames we will split the input image into)

# NN parameters
HIDDEN_LAYER_SIZE = 500
HIDDEN_LAYERS_NUM = 3 # Number of LSTM cells to stack

# Training parameters
LEARNING_RATE = 0.001
MOMENTUM = 0.05
EPOCHS_NUM = 100
TRAIN_TEST_SPLIT = .8
CLIPPING_VALUE = 3

transformation = transforms.Compose(
    [transforms.RandomRotation(degrees=(-90,-90), expand=True), 
     transforms.Resize((IMAGE_HEIGHT,IMAGE_WIDTH)), 
     transforms.ToTensor()])

train_data, test_data, dataset = load_data(base_dir = '../../GT4HistOCR/corpus', dataset_name = 'RefCorpus-ENHG-Incunabula/1476-Historij-Wierstaat',
                                              transformation=transformation,
                                              train_test_split=TRAIN_TEST_SPLIT)

# Fixed values ( i.e.: not configurable)
ALPHABET_SIZE = len(dataset.alphabet)
INPUT_DIMENSION = int( (IMAGE_HEIGHT / SEQUENCES_NUM) * IMAGE_WIDTH )

In [5]:

model = OCRNet(INPUT_DIMENSION, HIDDEN_LAYER_SIZE, HIDDEN_LAYERS_NUM, ALPHABET_SIZE)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(0,EPOCHS_NUM):
    losses = list()
    for batch in enumerate(train_data):
        # Move data to GPU
        #data, target = data.to(device), target.to(device)
        
        # Clear gradients
        optimizer.zero_grad()
        
        # Compute output
        x = batch[1]['image'].view(-1, SEQUENCES_NUM, INPUT_DIMENSION)
        probabilities = model(x).permute(1,0,2)
        
        # Compute crossentropy loss
        ctc_loss = nn.CTCLoss(zero_infinity=True)
        
        probabilities_lengths = torch.full((probabilities.shape[1],), probabilities.shape[0], dtype=torch.long)

        output = torch.tensor( batch[1]['text_vector'] )
        output_lengths = torch.tensor( batch[1]['text_length'] )
        
        loss = ctc_loss(probabilities, output, probabilities_lengths, output_lengths)
         
        # Compute gradient
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIPPING_VALUE)
        
        print(loss)
        
        # Perform gradient descent
        optimizer.step()
        
        # Track losses
        losses.append(loss.item())



tensor(11.5105, grad_fn=<MeanBackward0>)
tensor(11.8900, grad_fn=<MeanBackward0>)
tensor(11.3990, grad_fn=<MeanBackward0>)
tensor(13.6217, grad_fn=<MeanBackward0>)
tensor(11.0550, grad_fn=<MeanBackward0>)
tensor(12.2490, grad_fn=<MeanBackward0>)
tensor(12.0718, grad_fn=<MeanBackward0>)
tensor(11.2731, grad_fn=<MeanBackward0>)
tensor(12.4586, grad_fn=<MeanBackward0>)
tensor(11.4490, grad_fn=<MeanBackward0>)
tensor(12.2378, grad_fn=<MeanBackward0>)
tensor(11.7708, grad_fn=<MeanBackward0>)
tensor(11.7738, grad_fn=<MeanBackward0>)
tensor(12.3497, grad_fn=<MeanBackward0>)
tensor(12.7273, grad_fn=<MeanBackward0>)
tensor(11.1744, grad_fn=<MeanBackward0>)


KeyboardInterrupt: 