In [1]:
%cd ..

/home/tuannm/sonhoang/Vietnamese-OCR-from-scratch-pytorch


In [2]:
import os
from PIL import Image
import math
import numpy as np
from typing import OrderedDict
import logging

import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import torch.nn as nn

In [3]:
class SynthDataset(Dataset):
    def __init__(self, opt):
        super(SynthDataset, self).__init__()
        self.path = os.path.join(opt['path'], opt['imgdir'])
        self.images = os.listdir(self.path)
        self.nSamples = len(self.images)
        f = lambda x: os.path.join(self.path, x)
        self.imagepaths = list(map(f, self.images))
        transform_list = [
            transforms.Grayscale(1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]
        self.transform = transforms.Compose(transform_list)
        self.collate_fn = SynthCollator()
        
    def __len__(self):
        return self.nSamples
    
    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        imagepath = self.imagepaths[index]
        imagefile = os.path.basename(imagepath)
        img = Image.open(imagepath)
        if self.transform is not None:
            img = self.transform(img)
        item = {'img': img, 'idx': index}
        item['label'] = imagefile.split('_')[0]
        return item

In [4]:
class SynthCollator(object):
    
    def __call__(self, batch):
        width = [item['img'].shape[2] for item in batch]
        indexes = [item['idx'] for item in batch]
        imgs = torch.ones([len(batch), batch[0]['img'].shape[0], batch[0]['img'].shape[1],
                           max(width)], dtype=torch.float32)
        
        for idx, item in enumerate(batch):
            try:
                imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
            except:
                print(imgs.shape)
                
        item = {'img': imgs, 'idx': indexes}
        if 'label' in batch[0].keys():
            labels = [item['label'] for item in batch]
            item['label']  = labels
        
        return item

In [62]:
alphabet = "AÀÁẢÃẠaàáảãạĂẰẮẲẴẶăằắẳẵặÂẦẤẨẪẬâầấẩẫậBbCcDdĐđEÈÉẺẼẸeèéẻẽẹÊỀẾỂỄỆêềếểễệGgHhIÌÍỈĨỊiìíỉĩịKkLlMmNnOÒÓỎÕỌoòóỏõọÔỒỐỔỖỘôồốổỗộƠỜỚỞỠỢơờớởỡợPpQqRrSsTtUÙÚỦŨỤuùúủũụƯỪỨỬỮỰưừứửữựVvXxYỲÝỶỸỴyỳýỷỹỵZz0123456789jJwWfF"

args = {
    'name':'exp1',
    'path':'data',
    'imgdir': 'train',
    'imgH':32,
    'nChannels':1,
    'nHidden':256,
    'nClasses':len(alphabet),
    'lr':0.001,
    'epochs':4,
    'batch_size':32,
    'save_dir':'checkpoints',
    'log_dir':'logs',
    'resume':False,
    'cuda':False,
    'schedule':False
    
}

In [44]:
class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)
    def forward(self, input):
        self.rnn.flatten_parameters()
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output

class CRNN(nn.Module):

    def __init__(self, opt, leakyRelu=False):
        super(CRNN, self).__init__()

        assert opt['imgH'] % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = opt['nChannels'] if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16
        self.cnn = cnn
        self.rnn = nn.Sequential()
        self.rnn = nn.Sequential(
            BidirectionalLSTM(opt['nHidden']*2, opt['nHidden'], opt['nHidden']),
            BidirectionalLSTM(opt['nHidden'], opt['nHidden'], opt['nClasses']))


    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        # rnn features
        output = self.rnn(conv)
        output = output.transpose(1,0) #Tbh to bth
        return output

In [80]:
cnn_out = model.cnn(data[10]['img'].unsqueeze(0))

In [81]:
cnn_out.size(), cnn_out

(torch.Size([1, 512, 1, 17]),
 tensor([[[[0.3934, 2.0215, 1.1573,  ..., 0.0000, 0.0000, 0.0351]],
 
          [[0.0000, 0.0000, 0.4576,  ..., 0.0000, 0.0000, 1.5227]],
 
          [[0.0000, 1.1980, 2.6751,  ..., 0.0000, 0.0000, 0.0000]],
 
          ...,
 
          [[1.5306, 0.2778, 0.0000,  ..., 2.1162, 1.7593, 0.0000]],
 
          [[0.0000, 0.0000, 0.0000,  ..., 2.4536, 1.3614, 0.5925]],
 
          [[0.0135, 0.0000, 0.0000,  ..., 0.2320, 0.8973, 0.0000]]]],
        grad_fn=<ReluBackward0>))

In [82]:
cnn_out[0][0]

tensor([[0.3934, 2.0215, 1.1573, 0.0000, 0.0000, 0.0000, 0.1979, 0.0871, 0.2027,
         0.4868, 0.7580, 0.4469, 0.0000, 0.0000, 0.0000, 0.0000, 0.0351]],
       grad_fn=<SelectBackward0>)

In [83]:
cnn_permute = cnn_out.squeeze(2).permute(2, 0 ,1)

In [84]:
cnn_permute.size(), cnn_permute

(torch.Size([17, 1, 512]),
 tensor([[[0.3934, 0.0000, 0.0000,  ..., 1.5306, 0.0000, 0.0135]],
 
         [[2.0215, 0.0000, 1.1980,  ..., 0.2778, 0.0000, 0.0000]],
 
         [[1.1573, 0.4576, 2.6751,  ..., 0.0000, 0.0000, 0.0000]],
 
         ...,
 
         [[0.0000, 0.0000, 0.0000,  ..., 2.1162, 2.4536, 0.2320]],
 
         [[0.0000, 0.0000, 0.0000,  ..., 1.7593, 1.3614, 0.8973]],
 
         [[0.0351, 1.5227, 0.0000,  ..., 0.0000, 0.5925, 0.0000]]],
        grad_fn=<PermuteBackward0>))

In [63]:
args['alphabet'] = alphabet
model = CRNN(args)

In [46]:
data = SynthDataset(args)

In [47]:
data[10]['img'].size()

torch.Size([1, 32, 64])

In [48]:
model

CRNN(
  (cnn): Sequential(
    (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu0): ReLU(inplace=True)
    (pooling0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU(inplace=True)
    (pooling1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU(inplace=True)
    (pooling2): MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1), dilation=1, ceil_mode=False)
    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.

In [88]:
out = model(data[10]['img'].unsqueeze(0))

In [57]:
data[10]['img'].size()

torch.Size([1, 32, 64])

In [89]:
out.transpose(1, 0).size()

torch.Size([17, 1, 196])

In [90]:
out = out.transpose(1, 0)

In [91]:
logits = torch.nn.functional.log_softmax(out, 2)

In [92]:
logits, logits.size()

(tensor([[[-5.3379, -5.2193, -5.2672,  ..., -5.2889, -5.2489, -5.2710]],
 
         [[-5.3359, -5.2211, -5.2689,  ..., -5.2991, -5.2500, -5.2812]],
 
         [[-5.3356, -5.2243, -5.2651,  ..., -5.3010, -5.2502, -5.2842]],
 
         ...,
 
         [[-5.3185, -5.2342, -5.2566,  ..., -5.3004, -5.2426, -5.2962]],
 
         [[-5.3188, -5.2425, -5.2545,  ..., -5.3087, -5.2352, -5.2978]],
 
         [[-5.3171, -5.2432, -5.2548,  ..., -5.3100, -5.2435, -5.2845]]],
        grad_fn=<LogSoftmaxBackward0>),
 torch.Size([17, 1, 196]))

In [116]:
probs, preds = logits.max(2)
probs, preds

(tensor([[-5.1980],
         [-5.1909],
         [-5.1935],
         [-5.1910],
         [-5.1870],
         [-5.1836],
         [-5.1813],
         [-5.1882],
         [-5.1907],
         [-5.1940],
         [-5.1962],
         [-5.1926],
         [-5.1883],
         [-5.1813],
         [-5.1814],
         [-5.1887],
         [-5.2074]], grad_fn=<MaxBackward0>),
 tensor([[ 79],
         [164],
         [164],
         [164],
         [164],
         [150],
         [150],
         [150],
         [164],
         [164],
         [164],
         [164],
         [164],
         [164],
         [164],
         [164],
         [ 64]]))

In [123]:
alphabet[79-1]

'i'

In [124]:
sim_preds = converter.decode(preds.data, pred_sizes.data, raw=True)

In [125]:
sim_preds

'ivvvvụụụvvvvvvvvề'

In [93]:
T, B, H = logits.size()
pred_sizes = torch.LongTensor([T for i in range(B)])

In [94]:
pred_sizes

tensor([17])

In [64]:
from src.utils.utils import OCRLabelConverter

converter = OCRLabelConverter(args["alphabet"])

In [106]:
targets, lenghts = converter.encode([data[10]['label']])

In [107]:
targets, lenghts

(tensor([ 40,  72, 147,   7], dtype=torch.int32),
 tensor([4], dtype=torch.int32))

In [108]:
targets.unsqueeze(0).view(-1).size()

torch.Size([4])

In [109]:
from src.criterions.ctc import CustomCTCLoss

criterion = CustomCTCLoss()

In [110]:
loss = criterion(logits, targets.unsqueeze(0).view(-1), pred_sizes, lenghts)