In [None]:
!pip install textdistance
!pip install Augmentor

In [None]:
import os
import time
import math
import random
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, BatchNorm2d, ReLU, LeakyReLU
from torch.utils.data import Dataset, sampler
from torch.nn.utils.clip_grad import clip_grad_norm_
from torchvision import models
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from textdistance import levenshtein as lev
import cv2

## CONFIG

In [None]:
# Symbols the model can emit. NOTE: '-' appears twice — once after ',' and
# again as the very last character. LabelCoder.decode(raw=True) maps the CTC
# blank (index 0) to alphabet[-1], so the trailing '-' presumably serves as the
# printable placeholder for the blank. Confirm before editing ALPHABET: the
# model's output size is built from len(ALPHABET).
ALPHABET = " %(),-./0123456789:;?[]«»АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё-"
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
REPORT_ACCURACY = True  # NOTE(review): not read anywhere in the visible notebook
PATH_TO_TRAIN_IMGDIR = "../input/cyrillic-handwriting-dataset/train/"
PATH_TO_TRAIN_LABELS = "../input/cyrillic-handwriting-dataset/train.tsv"
PATH_TO_TEST_IMGDIR = "../input/cyrillic-handwriting-dataset/test/"
PATH_TO_TEST_LABELS = "../input/cyrillic-handwriting-dataset/test.tsv"
PATH_TO_CHECKPOINT = "./checkpoints/"
# NOTE(review): only used as the number of example batches plotted below;
# the training DataLoader hard-codes batch_size=8.
BATCH_SIZE = 2
APPLY_AUGS = True # is augmentation applied?
SEED = 41

# Create the checkpoint directory from the configured path. exist_ok keeps the
# cell re-runnable; the previous `!mkdir ./checkpoints/` duplicated the path
# literal (changing PATH_TO_CHECKPOINT would not create the new directory)
# and printed an error on every re-run.
os.makedirs(PATH_TO_CHECKPOINT, exist_ok=True)

torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

## MODELS

### MODEL 1

"An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" https://arxiv.org/abs/1507.05717

In [None]:
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Sequence-first layout: input [T, b, nIn] -> output [T, b, nOut].

    Args:
        nIn: feature size of each input timestep.
        nHidden: hidden size of each LSTM direction (the projection sees 2 * nHidden).
        nOut: size of the projected output features.
        num_layers: number of stacked LSTM layers. Defaults to 1 (previous
            behaviour); Model4 passes num_layers=3, which used to raise TypeError.
    """

    def __init__(self, nIn, nHidden, nOut, num_layers=1):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, num_layers=num_layers, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        # Re-compact LSTM weights into contiguous memory (relevant after
        # .to(device) / replication on cuDNN).
        self.rnn.flatten_parameters()
        recurrent, _ = self.rnn(input)  # [T, b, 2 * nHidden]
        T, b, h = recurrent.size()
        output = self.embedding(recurrent.view(T * b, h))  # [T * b, nOut]
        return output.view(T, b, -1)

class Model1(nn.Module):
    """CRNN-style recognizer (https://arxiv.org/abs/1507.05717): 8 convs + 2 BiLSTMs.

    For 64x256 single-channel inputs, the conv stack ends at height 1 and
    width 121, giving a [121, b, num_classes] sequence of log-probabilities.

    Args:
        nHidden: LSTM hidden size per direction. The first LSTM expects
            nHidden * 2 features, which must equal the 512 conv channels
            (i.e. nHidden=256).
        num_classes: size of the final projection (number of output classes).

    NOTE(review): there is no activation between the convolutions (the CRNN
    paper uses ReLU); consider adding one.
    """

    def __init__(self, nHidden, num_classes):
        super(Model1, self).__init__()

        self.conv0 = Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv1 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv5 = Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv6 = Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv7 = Conv2d(512, 512, kernel_size=7, stride=1, padding=1)

        self.pool1 = MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = MaxPool2d(kernel_size=2, stride=(2,1))
        self.pool3 = MaxPool2d(kernel_size=2, stride=(2,1))
        self.pool4 = MaxPool2d(kernel_size=2, stride=(3,1))

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)

        # Output size comes from the num_classes argument; it previously used
        # the global len(ALPHABET) and silently ignored num_classes.
        self.rnn = nn.Sequential(
            BidirectionalLSTM(nHidden*2, nHidden, nHidden),
            BidirectionalLSTM(nHidden, nHidden, num_classes))


    def forward(self, src):
        '''
        src : [b, c, h, w]
        returns : [w', b, num_classes] log-probabilities (log_softmax over dim 2)
        '''
        x = self.conv0(src)
        x = self.bn1(self.pool1(self.conv1(x)))
        x = self.conv2(x)
        x = self.bn2(self.pool2(self.conv3(x)))
        x = self.conv4(x)
        x = self.bn3(self.pool3(self.conv5(x)))
        x = self.conv6(x)
        x = self.bn4(self.pool4(self.conv7(x)))
        b, c, h, w = x.size() # e.g. [4, 512, 1, 121] for 64x256 inputs
        assert h == 1, "the height of conv must be 1"
        x = x.squeeze(2) # [b, c, w]
        x = x.permute(2, 0, 1)  # [w, b, c]
        logits = self.rnn(x) # [w, b, num_classes]
        output = torch.nn.functional.log_softmax(logits, 2)
        return output

### MODEL 2: ResNet50 + LSTM

In [None]:
class Model2(nn.Module):
    """ImageNet-pretrained ResNet-50 feature extractor followed by two BiLSTMs.

    Args:
        nHidden: LSTM hidden size per direction; the first LSTM expects
            nHidden * 2 features, matching fc1's 512 output channels
            (i.e. nHidden=256).
        num_classes: number of output classes.

    NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    favour of `weights=`; `fc2` is registered but never used in forward().
    """

    def __init__(self, nHidden, num_classes):
        super(Model2, self).__init__()
        backbone = models.resnet50(pretrained=True)
        self.resnet50 = backbone
        # Used in place of ResNet's avgpool/fc (forward() never calls those):
        # a 2x2 conv that maps the 2048-channel feature map to 512 channels.
        backbone.fc1 = nn.Conv2d(2048, 512, kernel_size=(2, 2))
        backbone.fc2 = nn.Linear(8, 16)

        self.rnn = nn.Sequential(
            BidirectionalLSTM(nHidden * 2, nHidden, nHidden),
            BidirectionalLSTM(nHidden, nHidden, num_classes))

    def forward(self, src):
        # The ResNet stem expects 3 channels: replicate a grayscale channel.
        if src.shape[1] == 1:
            src = src.repeat(1, 3, 1, 1)
        net = self.resnet50
        stages = (net.conv1, net.bn1, net.relu, net.maxpool,
                  net.layer1, net.layer2, net.layer3, net.layer4,
                  net.fc1)
        features = src
        for stage in stages:
            features = stage(features)
        b, c, h, w = features.size()
        assert h == 1, "the height of conv must be 1"
        seq = features.squeeze(2).permute(2, 0, 1)  # [b, c, w] -> [w, b, c]
        logits = self.rnn(seq)  # [w, b, num_classes]
        return torch.nn.functional.log_softmax(logits, 2)

### MODEL 3: seven-layer CNN + BiLSTM


In [None]:
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Sequence-first layout: input [T, b, nIn] -> output [T, b, nOut].
    Once this cell runs, this definition shadows any earlier class of the
    same name, so it is the one Model3/Model4 use.

    Args:
        nIn: feature size of each input timestep.
        nHidden: hidden size of each LSTM direction (the projection sees 2 * nHidden).
        nOut: size of the projected output features.
        num_layers: number of stacked LSTM layers. Defaults to 1 (previous
            behaviour); Model4 passes num_layers=3, which used to raise TypeError.
    """

    def __init__(self, nIn, nHidden, nOut, num_layers=1):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, num_layers=num_layers, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        # Re-compact LSTM weights into contiguous memory (relevant after
        # .to(device) / replication on cuDNN).
        self.rnn.flatten_parameters()
        recurrent, _ = self.rnn(input)  # [T, b, 2 * nHidden]
        T, b, h = recurrent.size()
        output = self.embedding(recurrent.view(T * b, h))  # [T * b, nOut]
        return output.view(T, b, -1)

class Model3(nn.Module):
    """Seven-conv CNN followed by two BiLSTMs.

    For 64x256 single-channel inputs the conv stack ends at height 1 and
    width 63, giving a [63, b, num_classes] sequence of log-probabilities.

    Args:
        nHidden: LSTM hidden size per direction; nHidden * 2 must equal the
            512 conv output channels (i.e. nHidden=256).
        num_classes: number of output classes.
    """

    def __init__(self, nHidden, num_classes):
        super(Model3, self).__init__()

        # conv0..conv5: 3x3 kernels, stride 1, padding 1 (spatial size kept).
        channel_pairs = [(1, 64), (64, 128), (128, 256), (256, 256), (256, 512), (512, 512)]
        for layer_no, (c_in, c_out) in enumerate(channel_pairs):
            setattr(self, 'conv%d' % layer_no,
                    Conv2d(c_in, c_out, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
        # Unpadded 4x4 conv shrinks the remaining height of 4 down to 1.
        self.conv6 = Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1))

        self.pool0 = MaxPool2d(kernel_size=2, stride=2)
        self.pool1 = MaxPool2d(kernel_size=2, stride=2)
        # Stride (2, 1): halve the height only; padding (0, 1) adds one column.
        self.pool3 = MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))
        self.pool5 = MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))

        self.bn2 = BatchNorm2d(256)
        self.bn4 = BatchNorm2d(512)
        self.bn6 = BatchNorm2d(512)

        self.relu = ReLU()

        self.rnn = nn.Sequential(
            BidirectionalLSTM(nHidden * 2, nHidden, nHidden),
            BidirectionalLSTM(nHidden, nHidden, num_classes))

    def forward(self, src):
        # (conv, optional batch-norm, optional pool) in forward order;
        # ReLU sits between the (conv[+bn]) and the pool.
        stages = (
            (self.conv0, None, self.pool0),
            (self.conv1, None, self.pool1),
            (self.conv2, self.bn2, None),
            (self.conv3, None, self.pool3),
            (self.conv4, self.bn4, None),
            (self.conv5, None, self.pool5),
            (self.conv6, self.bn6, None),
        )
        x = src
        for conv, norm, pool in stages:
            x = conv(x)
            if norm is not None:
                x = norm(x)
            x = self.relu(x)
            if pool is not None:
                x = pool(x)

        b, c, h, w = x.size()
        assert h == 1, "the height of conv must be 1"
        seq = x.squeeze(2).permute(2, 0, 1)  # [b, c, w] -> [w, b, c]
        logits = self.rnn(seq)  # [w, b, num_classes]
        return torch.nn.functional.log_softmax(logits, 2)

### MODEL 4

"Fine-tuning Handwriting Recognition systems with Temporal Dropout" https://arxiv.org/pdf/2102.00511v1.pdf

In [None]:
def TemporalDropout(x, p = 0.2):
    """Zero out a random subset of time steps of a [B, C, T] feature sequence.

    int(T * p) distinct time positions are drawn with Python's `random`
    module and zeroed across every channel and every sample of the batch
    (one mask shared by the whole batch). No rescaling is applied.

    Args:
        x: tensor shaped [B, C, T] (batch, channels, width*height).
        p: fraction of time steps to drop.

    Returns:
        Tensor with the same shape, dtype and device as `x`.
    """
    B, C, WH = x.shape # BATCH_SIZE, CHANNELS, WIDTH*HEIGHT
    # Mask lives on x's device/dtype (previously moved with the global DEVICE).
    mask = x.new_ones(WH)
    # Count is based on the sequence length. It used to be int(C * p) draws with
    # replacement, so with 512 channels most of a short sequence was dropped.
    for i in random.sample(range(WH), int(WH * p)):
        mask[i] = 0
    return x * mask  # [T] broadcasts over batch and channels

class Model4(nn.Module):
    """13-conv CNN + two 3-layer BiLSTM heads with temporal dropout.

    Based on "Fine-tuning Handwriting Recognition systems with Temporal
    Dropout" (arXiv:2102.00511). For 64x256 single-channel inputs the conv
    stack ends at height 1.

    Args:
        nHidden: LSTM hidden size per direction; 2 * nHidden must equal the
            512 conv output channels (i.e. nHidden=256).
        num_classes: number of output classes per head.

    Requires a BidirectionalLSTM that accepts a `num_layers` keyword.
    """

    def __init__(self, nHidden, num_classes):
        super(Model4, self).__init__()

        self.act = LeakyReLU(negative_slope=0.01, inplace=False)
        self.conv0 = Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv1 = Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv5 = Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv6 = Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv7 = Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv8 = Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv9 = Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv10 = Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv11 = Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv12 = Conv2d(512, 512, kernel_size=4, stride=1, padding=1)

        self.pool1 = MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = MaxPool2d(kernel_size=2, stride=2)
        self.pool3 = MaxPool2d(kernel_size=(2,1), stride=(2,1))
        self.pool4 = MaxPool2d(kernel_size=(2,1), stride=(2,1))
        self.pool5 = MaxPool2d(kernel_size=(2,1), stride=(2,1))

        # One BatchNorm per conv. Previously bn1..bn4 were each shared by
        # 2-6 different convs, so one set of affine params and running stats
        # was fitted to the outputs of different layers.
        self.bns = nn.ModuleList([BatchNorm2d(conv.out_channels) for conv in self._convs()])

        self.rnn1 = BidirectionalLSTM(2*nHidden, nHidden, num_classes, num_layers=3)
        self.rnn2 = BidirectionalLSTM(2*nHidden, nHidden, num_classes, num_layers=3)

    def _convs(self):
        """Return the 13 conv layers in forward order."""
        return [self.conv0, self.conv1, self.conv2, self.conv3, self.conv4,
                self.conv5, self.conv6, self.conv7, self.conv8, self.conv9,
                self.conv10, self.conv11, self.conv12]

    def forward(self, src):
        # Index of the conv after which each pooling layer runs.
        pool_after = {1: self.pool1, 3: self.pool2, 6: self.pool3,
                      9: self.pool4, 12: self.pool5}
        x = src
        for i, (conv, bn) in enumerate(zip(self._convs(), self.bns)):
            x = self.act(bn(conv(x)))
            pool = pool_after.get(i)
            if pool is not None:
                x = pool(x)
        b, c, h, w = x.size()
        assert h == 1, "the height of conv must be 1"
        x = x.squeeze(2) # [b, c, w]
        # Temporal dropout is a training-time regularizer: skip it in eval mode.
        if self.training:
            x = TemporalDropout(x, 0.2)
        x = x.permute(2, 0, 1)  # [w, b, c]
        output1 = self.rnn1(x)
        output2 = self.rnn2(x)
        # NOTE(review): the heads are concatenated along dim 0 (time), so the
        # output has 2*w steps: head 1's sequence followed by head 2's.
        # Averaging/summing the heads may have been intended; confirm against
        # the paper.
        output = torch.cat([output1, output2], 0)
        output = torch.nn.functional.log_softmax(output, 2)
        return output

### DATASET

In [None]:
# class for mapping symbols into indices and vice versa
class LabelCoder(object):
    """Encode text into CTC target indices and decode predicted indices to text.

    Index 0 is reserved for the CTC blank; the symbol at position i of
    `alphabet` gets index i + 1. If a symbol appears more than once in
    `alphabet`, its FIRST position is used.

    Args:
        alphabet: string of symbols.
        ignore_case: accepted for compatibility; currently unused.
    """
    def __init__(self, alphabet, ignore_case=False):
        self.alphabet = alphabet
        self.char2idx = {}
        for i, char in enumerate(alphabet):
            # First occurrence wins. Letting a duplicate (e.g. ALPHABET's
            # trailing '-') overwrite the mapping gave it index len(alphabet),
            # which lies outside [0, num_classes) for num_classes=len(alphabet).
            self.char2idx.setdefault(char, i + 1)
        self.char2idx[''] = 0

    def encode(self, text):
        """Encode a batch of labels.

        Args:
            text: iterable of strings, one label per sample.

        Returns:
            (targets, lengths): 1-D IntTensors holding the concatenated symbol
            indices and the number of encoded symbols per label.

        Symbols missing from the alphabet are skipped. Encoding them as 0 put
        the CTC blank inside the targets, which CTCLoss does not allow.
        """
        length = []
        result = []
        for item in text:
            indices = [self.char2idx.get(char, 0) for char in item]
            indices = [index for index in indices if index != 0]
            length.append(len(indices))
            result.extend(indices)
        return (torch.IntTensor(result), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode index tensors back to text.

        Args:
            t: 1-D tensor of indices, concatenated over the batch.
            length: 1-D tensor of per-sample lengths. With a single element,
                `t` is one sequence and a str is returned; otherwise a list of
                str is returned.
            raw: if True, map every index to a symbol (index 0 -> alphabet[-1]);
                otherwise apply greedy CTC collapsing (drop blanks and repeats).
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
                                                                                                         length)
            indices = t.tolist()
            if raw:
                return ''.join(self.alphabet[i - 1] for i in indices)
            return ''.join(self.alphabet[cur - 1]
                           for pos, cur in enumerate(indices)
                           if cur != 0 and not (pos > 0 and indices[pos - 1] == cur))
        # batch mode
        assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
            t.numel(), length.sum())
        texts = []
        index = 0
        for l in length.tolist():
            texts.append(self.decode(t[index:index + l], torch.IntTensor([l]), raw=raw))
            index += l
        return texts

        
class OCRdataset(Dataset):
    """Handwriting dataset described by a header-less TSV of `image_name<TAB>label`.

    Args:
        path_to_imgdir: directory prefix, joined by string concatenation with
            each image name (so it should end with a path separator).
        path_to_labels: TSV file with image file name and label text per row.
        transform_list: list of torchvision transforms. None selects
            Grayscale(1) -> Resize((64, 256)) -> ToTensor -> Normalize((0.5,), (0.5,)).

    Items are dicts {'idx': int, 'img': transformed image, 'label': str}.
    """
    def __init__(self, path_to_imgdir: str, path_to_labels: str, transform_list = None):
        super(OCRdataset, self).__init__()
        self.imgdir = path_to_imgdir
        # dtype=str + keep_default_na=False keep every label as text. By default
        # pandas turns empty labels and strings like "NA"/"null" into float NaN,
        # which later breaks len() and the Levenshtein-based error rates.
        df = pd.read_csv(path_to_labels, sep = '\t', names = ['image_name', 'label'],
                         dtype=str, keep_default_na=False)
        self.image2label = [(self.imgdir + image, label) for image, label in zip(df['image_name'], df['label'])]
        if transform_list is None:
            transform_list =  [transforms.Grayscale(1),
                              transforms.Resize((64, 256)),
                              transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,))]
        self.transform = transforms.Compose(transform_list)
        self.collate_fn = Collator()

    def __len__(self):
        return len(self.image2label)

    def __getitem__(self, index):
        # The old check (`assert index <= len(self)`) let index == len(self)
        # through; raise IndexError per the sequence protocol instead.
        if index >= len(self):
            raise IndexError('index range error')
        image_path, label = self.image2label[index]
        img = Image.open(image_path)
        if self.transform is not None:
            img = self.transform(img)
        return {'idx' : index, 'img': img, 'label': label}


class Collator(object):
    """Stack variable-width image tensors into one right-padded batch.

    Each batch item is a dict with 'idx', 'img' (a [C, H, W] tensor) and
    optionally 'label'. Images are copied into a float32 tensor of shape
    [B, C, H, max(W)] pre-filled with 1.0. With OCRdataset's default
    Normalize((0.5,), (0.5,)), 1.0 is white, so narrow images are padded on the
    right with background.

    Returns:
        {'img': padded batch, 'idx': list of indices[, 'label': list of labels]}

    Raises:
        ValueError: if an image's channel count or height differs from the
            first image's (previously swallowed by a bare `except`, which left
            that sample as pure padding).
    """

    def __call__(self, batch):
        width = [item['img'].shape[2] for item in batch]
        indexes = [item['idx'] for item in batch]
        channels, height = batch[0]['img'].shape[:2]
        imgs = torch.ones([len(batch), channels, height, max(width)], dtype=torch.float32)
        for pos, item in enumerate(batch):
            img = item['img']
            if tuple(img.shape[:2]) != (channels, height):
                raise ValueError(
                    "image idx={} has shape {}; expected {} channels and height {} "
                    "(batch shape {})".format(item['idx'], tuple(img.shape), channels,
                                              height, tuple(imgs.shape)))
            imgs[pos, :, :, 0:img.shape[2]] = img
        result = {'img': imgs, 'idx': indexes}
        if 'label' in batch[0]:
            result['label'] = [item['label'] for item in batch]
        return result

## AUGMENTATIONS

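The following augmentations are implemented below (both are currently commented out of the training transform list, so they are not applied):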

1. *Vignetting*: a reduction of an image's brightness or saturation toward the periphery compared to the image center.

2. *Lens distortion*: a random radial/tangential warp of the image, simulated with OpenCV's undistortion remapping.

In [None]:
import Augmentor


class Vignetting(object):
    """Randomly darken (or, with random_sign, possibly brighten) an image toward its borders.

    Args:
        p: probability of applying the effect to a given image.
        ratio_min_dist: scales the random per-axis distance from the center
            within which that axis adds no vignetting.
        range_vignette: (low, high) range of the random vignette strength.
        random_sign: if False the image is always darkened; if True it is
            darkened or brightened with equal probability.
    """
    def __init__(self,
                 p = 0.1,
                 ratio_min_dist=0.2,
                 range_vignette=(0.2, 0.8),
                 random_sign=False):
        self.ratio_min_dist = ratio_min_dist
        self.range_vignette = np.array(range_vignette)
        self.random_sign = random_sign
        self.p = p

    def __call__(self, X, Y=None):
        """Apply the vignette to X, assumed [C, H, W] (e.g. after ToTensor()).

        Returns X itself with probability 1 - p. Otherwise returns X scaled by a
        (1, H, W) factor, as a tensor with X's dtype/device when X is a tensor.
        """
        if np.random.binomial(1, self.p) == 0:
            return X
        h, w = X.shape[1:]
        # min_dist[0] belongs to the height (y) axis, min_dist[1] to the width (x) axis.
        min_dist = np.array([h, w]) / 2 * np.random.random() * self.ratio_min_dist
        # distance from the center along each axis
        x, y = np.meshgrid(np.linspace(-w / 2, w / 2, w), np.linspace(-h / 2, h / 2, h))
        x, y = np.abs(x), np.abs(y)
        # Per-axis vignette ramps. Previously the height-based distance was
        # applied to x and the width-based one to y.
        x = np.clip((x - min_dist[1]) / (np.max(x) - min_dist[1]), 0, 1)
        y = np.clip((y - min_dist[0]) / (np.max(y) - min_dist[0]), 0, 1)
        # random intensity of the vignette
        vignette = (x + y) / 2 * np.random.uniform(*self.range_vignette)

        sign = 2 * (np.random.random() < 0.5) * (self.random_sign) - 1
        factor = 1 + sign * vignette[None, ...]
        if torch.is_tensor(X):
            # Keep X's dtype (float64 numpy math would otherwise promote it).
            factor = torch.from_numpy(factor).to(dtype=X.dtype, device=X.device)
        return X * factor


class LensDistortion(object):
    """Randomly warp an image with OpenCV lens-distortion remapping.

    Args:
        p: probability of applying the effect to a given image.
        d_coef: maximum magnitudes of the 5 OpenCV distortion coefficients;
            each is scaled by a uniform random factor and given a random sign.
    """
    def __init__(self, p = 0.1 ,d_coef=(0.15, 0.05, 0.05, 0.05, 0.05)):
        self.d_coef = np.array(d_coef)
        self.p = p

    def __call__(self, X):
        """Apply the distortion to a CPU tensor whose last two axes are H x W.

        Args:
            X: e.g. a [C, H, W] tensor from transforms.ToTensor(), or [H, W].

        Returns:
            X itself with probability 1 - p; otherwise a float32 tensor with
            X's shape, each H x W plane remapped identically.
        """
        if np.random.binomial(1, self.p) == 0:
            return X

        arr = np.float32(X.numpy())
        # Height/width are the LAST two axes. X.shape[:2] gave (C, H) for a
        # [C, H, W] tensor, and the whole 3-D array was remapped as one image.
        h, w = arr.shape[-2:]

        # focal length ~ image diagonal
        f = (h ** 2 + w ** 2) ** 0.5

        # camera matrix with the principal point at the image center
        K = np.array([[f, 0, w / 2],
                      [0, f, h / 2],
                      [0, 0, 1]])

        d_coef = self.d_coef * np.random.random(5)  # value
        d_coef = d_coef * (2 * (np.random.random(5) < 0.5) - 1)  # sign
        # Generate new camera matrix from parameters
        M, _ = cv2.getOptimalNewCameraMatrix(K, d_coef, (w, h), 0)

        # Look-up tables for remapping the image (CV_32FC1 maps)
        map1, map2 = cv2.initUndistortRectifyMap(K, d_coef, None, M, (w, h), cv2.CV_32FC1)

        # Remap every H x W plane with the same maps
        planes = arr.reshape(-1, h, w)
        Z = np.stack([cv2.remap(plane, map1, map2, cv2.INTER_LINEAR) for plane in planes])
        return torch.from_numpy(Z.reshape(arr.shape))
    


# Training-set preprocessing. None makes OCRdataset fall back to its default
# pipeline. NOTE(review): every augmentation step below is commented out, so
# even with APPLY_AUGS=True the active list is just
# Grayscale -> Resize -> ToTensor -> Normalize.
transform_list = None
if APPLY_AUGS:
    transform_list = [
        transforms.Grayscale(1),
        transforms.Resize((64, 256)),
        # transforms.RandomRotation(degrees=(-9, 9), fill=255),
        # transforms.transforms.GaussianBlur(3, sigma=(0.1, 1.5)),
        transforms.ToTensor(),
        # Vignetting(p = 0.5),
        # LensDistortion(p = 0.3),
        transforms.Normalize((0.5,), (0.5,)),
    ]

In [None]:
# Training data: Collator right-pads each batch to a common width.
# NOTE(review): batch_size is hard-coded to 8 here and is independent of BATCH_SIZE.
collator = Collator()
dataset = OCRdataset(PATH_TO_TRAIN_IMGDIR, PATH_TO_TRAIN_LABELS, transform_list=transform_list)
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collator,
)

Explore some training examples

In [None]:
# Plot the first image (with its label) from each of the first BATCH_SIZE
# training batches.
examples = []
for batch in train_loader:
    examples.append([batch['img'], batch['label']])
    if len(examples) >= BATCH_SIZE:
        break

fig = plt.figure(figsize=(10, 10))
rows = int(BATCH_SIZE / 4) + 2
columns = int(BATCH_SIZE / 8) + 2
for j, (imgs, labels) in enumerate(examples):
    ax = fig.add_subplot(rows, columns, j + 1)
    # imgs is [B, 1, H, W]: show sample 0's single channel in grayscale
    # (the old [H, W, 1] permute rendered it with the default color map).
    ax.imshow(imgs[0, 0], cmap='gray')
    ax.set_title(labels[0])
    ax.axis('off')
fig.tight_layout()

## TRAIN

In [None]:
import math
import torch

import math
import torch

class CustomCTCLoss(torch.nn.Module):
    """Mean-reduced CTCLoss that replaces NaN or huge losses with a zero tensor.

    Expects log-probabilities laid out T x B x C (log-softmax over dimension 2).
    A replaced loss is a fresh zero tensor with requires_grad=True. It is not
    connected to the model's graph, so callers can detect it by comparing
    with zero and skip the batch.
    """

    def __init__(self, dim=2):
        super().__init__()
        self.dim = dim  # stored only; the computation does not read it
        self.ctc_loss = torch.nn.CTCLoss(reduction='mean', zero_infinity=True)

    def forward(self, logits, labels,
            prediction_sizes, target_sizes):
        raw_loss = self.ctc_loss(logits, labels, prediction_sizes, target_sizes)
        cleaned = self.sanitize(raw_loss)
        return self.debug(cleaned, logits, labels, prediction_sizes, target_sizes)

    def sanitize(self, loss):
        """Return a new zero tensor if `loss` is NaN or |loss| > 99999, else `loss`."""
        value = loss.item()
        if math.isnan(value) or abs(value) > 99999:
            return torch.zeros_like(loss, requires_grad = True)
        return loss

    def debug(self, loss, logits, labels,
            prediction_sizes, target_sizes):
        """Print all inputs and raise if `loss` is NaN; otherwise return it unchanged."""
        if not math.isnan(loss.item()):
            return loss
        for caption, value in (("Loss:", loss),
                               ("logits:", logits),
                               ("labels:", labels),
                               ("prediction_sizes:", prediction_sizes),
                               ("target_sizes:", target_sizes)):
            print(caption, value)
        raise Exception("NaN loss obtained.")

    
def print_epoch_data(epoch, mean_loss, char_error, word_error, time_elapsed, zero_out_losses):
    """Print one fixed-width progress line (plus the column header at epoch 0).

    The epoch number is zero-padded to two characters. When zero_out_losses is
    non-zero, a warning with the count of skipped batches is appended.
    """
    if epoch == 0:
        print('epoch | mean loss | mean cer | mean wer | time elapsed | warnings')
    gap = ' ' * 7
    columns = [
        str(epoch).zfill(2),
        "%.3f" % mean_loss,
        "%.3f" % char_error,
        "%.3f" % word_error,
        "%.1f" % float(time_elapsed),
    ]
    report_line = gap.join(columns)
    if zero_out_losses != 0:
        report_line += f'       {zero_out_losses} batch losses skipped due to nan value'
    print(report_line)
    
    
def fit(model, optimizer, loss_fn, loader, epochs = 64):
    """Train `model` with a CTC-style loss and report per-epoch metrics.

    Args:
        model: recognizer returning log-probabilities shaped [T, B, num_classes].
        optimizer: optimizer over model.parameters().
        loss_fn: callable(logits, targets, pred_sizes, target_sizes) -> scalar
            tensor. A loss exactly equal to zero is treated as a rejected batch:
            it is counted and skipped without an optimizer step.
        loader: yields dicts with 'img' (batched image tensor) and 'label'
            (list of str).
        epochs: number of passes over `loader`.

    Returns:
        List with one {'mean_loss', 'mean_cer', 'mean_wer'} dict per epoch, or
        None if an epoch produced no usable batch.

    Saves model.state_dict() to PATH_TO_CHECKPOINT every 4th epoch and after
    the final epoch.
    """
    report = []
    coder = LabelCoder(ALPHABET)
    for epoch in range(epochs):
        zero_out_losses = 0
        start_time = time.time()
        model.train()
        outputs = []
        for batch in loader:
            optimizer.zero_grad()
            labels = batch['label']
            targets, lengths = coder.encode(labels)
            logits = model(batch['img'].to(DEVICE)).contiguous().cpu()
            T, B, H = logits.size()
            pred_sizes = torch.LongTensor([T] * B)
            targets = targets.view(-1).contiguous()
            loss = loss_fn(logits, targets, pred_sizes, lengths)
            if (torch.zeros(loss.size()) == loss).all():
                zero_out_losses += 1
                continue

            # Greedy decoding for the running CER / WER estimate.
            _, preds = logits.detach().max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = coder.decode(preds, pred_sizes, raw=False)
            if isinstance(sim_preds, str):
                # decode() returns a bare str for a one-sample batch; indexing it
                # would compare single characters against whole labels.
                sim_preds = [sim_preds]
            # max(..., 1) avoids 0/0 when both label and prediction are empty.
            char_error = sum(lev(label, pred) / max(len(label), len(pred), 1)
                             for label, pred in zip(labels, sim_preds)) / len(labels)
            word_error = 1 - sum(label == pred for label, pred in zip(labels, sim_preds)) / len(labels)

            loss.backward()
            clip_grad_norm_(model.parameters(), 0.05)
            optimizer.step()
            outputs.append({'loss': abs(loss.item()), 'cer': char_error, 'wer': word_error})

        if len(outputs) == 0:
            print('ERROR: bad loss, try to decrease learning rate and batch size')
            return None
        end_time = time.time()
        mean_loss = sum(out['loss'] for out in outputs) / len(outputs)
        char_error = sum(out['cer'] for out in outputs) / len(outputs)
        word_error = sum(out['wer'] for out in outputs) / len(outputs)
        report.append({'mean_loss' : mean_loss, 'mean_cer' : char_error, 'mean_wer' : word_error})
        print_epoch_data(epoch, mean_loss, char_error, word_error, end_time - start_time, zero_out_losses)
        # Also checkpoint the final epoch, which the every-4th schedule could skip.
        if epoch % 4 == 0 or epoch == epochs - 1:
            torch.save(model.state_dict(), PATH_TO_CHECKPOINT + 'checkpoint_epoch_' + str(epoch) + '.pt')
    return report

Choose an architecture (`Model1`–`Model4`)

In [None]:
# Model1..Model4 share the constructor signature (nHidden, num_classes).
model = Model3(256, len(ALPHABET)).to(DEVICE)
loss_fn = CustomCTCLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00002)

We advise starting with a small batch size and increasing it by 4 every 12 epochs.

In [None]:
# Train for 84 epochs. fit() prints one line per epoch and periodically saves
# checkpoints to PATH_TO_CHECKPOINT. NOTE(review): the batch size is fixed by
# train_loader (batch_size=8); following the advice above to grow it requires
# rebuilding the loader between runs.
report = fit(model, optimizer, loss_fn, train_loader, epochs = 84)

## TEST

In [None]:
def evaluate(model, loader):
    """Compute mean character and word error rates of `model` over `loader`.

    Runs in eval mode under torch.no_grad(), so BatchNorm uses its running
    statistics (instead of updating them) and no autograd graph is built.
    The model's previous train/eval mode is restored afterwards.

    Returns:
        {'char_error': mean Levenshtein distance normalized by the longer string,
         'word_error': fraction of samples not predicted exactly}

    Raises:
        ValueError: if `loader` yields no samples.
    """
    coder = LabelCoder(ALPHABET)
    labels, predictions = [], []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(loader):
                labels.extend(batch['label'])
                logits = model(batch['img'].to(DEVICE)).contiguous().cpu()
                T, B, H = logits.size()
                pred_sizes = torch.LongTensor([T] * B)
                _, pos = logits.max(2)
                pos = pos.transpose(1, 0).contiguous().view(-1)
                sim_preds = coder.decode(pos, pred_sizes, raw=False)
                if isinstance(sim_preds, str):
                    # decode() returns a bare str for a one-sample batch;
                    # extend() would otherwise add it character by character.
                    sim_preds = [sim_preds]
                predictions.extend(sim_preds)
    finally:
        model.train(was_training)
    if not labels:
        raise ValueError("evaluate(): loader yielded no samples")
    # max(..., 1) avoids 0/0 when both label and prediction are empty.
    char_error = sum(lev(label, pred) / max(len(label), len(pred), 1)
                     for label, pred in zip(labels, predictions)) / len(labels)
    word_error = 1 - sum(label == pred for label, pred in zip(labels, predictions)) / len(labels)
    return {'char_error' : char_error, 'word_error' : word_error}

In [None]:
# Test data uses OCRdataset's default (non-augmented) transforms.
test_dataset = OCRdataset(PATH_TO_TEST_IMGDIR, PATH_TO_TEST_LABELS)
collator = Collator()
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    collate_fn=collator,
    batch_size=2,
)
evaluate(model, test_loader)

In [None]:
 def predict(model, img):
    logits = model(img.to(DEVICE))
    logits = logits.contiguous().cpu()
    T, B, H = logits.size()
    pred_sizes = torch.LongTensor([T for i in range(B)])
    probs, pos = logits.max(2)
    pos = pos.transpose(1, 0).contiguous().view(-1)
    sim_preds = coder.decode(pos.data, pred_sizes.data, raw=False)
    return sim_preds

In [None]:
# Predict on the first NUM_EXAMPLES test batches and plot the first image of
# each with its true and predicted label.
NUM_EXAMPLES = 9
coder = LabelCoder(ALPHABET)  # module-level coder, kept for any predict() that reads a global `coder`
examples = []
for batch in test_loader:
    img, true_label = batch['img'], batch['label']
    examples.append([img, true_label, predict(model, img)])
    if len(examples) >= NUM_EXAMPLES:
        break

fig = plt.figure(figsize=(10, 10))
rows = int(NUM_EXAMPLES / 4) + 2
columns = int(NUM_EXAMPLES / 8) + 2
for j, (imgs, true_labels, pred_labels) in enumerate(examples):
    ax = fig.add_subplot(rows, columns, j + 1)
    # imgs is [B, 1, H, W]: show sample 0's single channel in grayscale.
    ax.imshow(imgs[0, 0], cmap='gray')
    ax.set_title('true:' + true_labels[0] + '\npred:' + pred_labels[0], loc = 'left')
    ax.axis('off')
fig.tight_layout()