In [6]:
import os
from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import torch.nn as nn


In [2]:
class SynthDataset(Dataset):
    """Dataset of word images stored as individual files in one directory.

    The label of a sample is derived from its file name: everything before
    the first underscore (``hello_0001.png`` -> ``'hello'``).

    Args:
        opt: Options mapping. ``opt['path']`` and ``opt['imgdir']`` are joined
            to form the directory that is scanned for images.

    Attributes:
        path: Directory that was scanned.
        images: Sorted file names found in ``path``.
        imagepaths: Full path of every entry in ``images``.
        nSamples: Number of samples.
        transform: Grayscale -> tensor -> normalize pipeline applied per image.
        collate_fn: ``SynthCollator`` instance for use with a ``DataLoader``.
    """

    def __init__(self, opt):
        super(SynthDataset, self).__init__()
        self.path = os.path.join(opt['path'], opt['imgdir'])
        # os.listdir returns entries in arbitrary order; sort them so a given
        # index refers to the same file on every run.
        self.images = sorted(os.listdir(self.path))
        self.nSamples = len(self.images)
        self.imagepaths = [os.path.join(self.path, name) for name in self.images]
        transform_list = [
            transforms.Grayscale(1),
            # The torchvision class is `ToTensor`; `toTensor` does not exist.
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ]
        self.transform = transforms.Compose(transform_list)
        self.collate_fn = SynthCollator()

    def __len__(self):
        """Return the number of images in the dataset."""
        return self.nSamples

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position of the sample; negative values index from the end,
                as with a list.

        Returns:
            dict with keys ``'img'`` (the transformed image), ``'idx'`` (the
            index that was requested) and ``'label'`` (file-name prefix before
            the first ``'_'``).

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        # IndexError (rather than an assert) is what the sequence protocol
        # expects: it lets `for item in dataset` terminate cleanly and it is
        # not stripped when Python runs with -O.
        if not -len(self) <= index < len(self):
            raise IndexError('index range error')
        imagepath = self.imagepaths[index]
        imagefile = os.path.basename(imagepath)
        img = Image.open(imagepath)
        if self.transform is not None:
            img = self.transform(img)
        item = {'img': img, 'idx': index}
        item['label'] = imagefile.split('_')[0]
        return item

In [5]:
class SynthCollator(object):
    """Collate callable that right-pads variable-width images into one tensor."""

    def __call__(self, batch):
        """Merge a list of samples into a single batch.

        Args:
            batch: Non-empty list of dicts. Each dict has an ``'img'`` tensor
                indexed as (channels, height, width) and an ``'idx'`` entry;
                a ``'label'`` entry is optional. Channels and height must be
                the same for every sample; widths may differ.

        Returns:
            dict with ``'img'`` (float32 tensor of shape
            ``(len(batch), channels, height, max_width)``), ``'idx'`` (list of
            the samples' indices) and, when the first sample carries one,
            ``'label'`` (list of labels).

        Raises:
            ValueError: If a sample's channel count or height differs from the
                first sample's.
        """
        first_shape = batch[0]['img'].shape
        channels, height = first_shape[0], first_shape[1]
        widths = [sample['img'].shape[2] for sample in batch]
        indexes = [sample['idx'] for sample in batch]

        # Padding value is 1.0; images narrower than the widest one keep this
        # value to the right of their own pixels.
        imgs = torch.ones([len(batch), channels, height, max(widths)],
                          dtype=torch.float32)

        for row, sample in enumerate(batch):
            img = sample['img']
            # Fail loudly instead of leaving an all-padding row in the batch
            # that would still be paired with a real label.
            if img.shape[0] != channels or img.shape[1] != height:
                raise ValueError(
                    f"cannot collate sample {sample['idx']} with image shape "
                    f"{tuple(img.shape)}: expected {channels} channel(s) and "
                    f"height {height}")
            imgs[row, :, :, :img.shape[2]] = img

        collated = {'img': imgs, 'idx': indexes}
        if 'label' in batch[0]:
            collated['label'] = [sample['label'] for sample in batch]

        return collated

In [None]:
class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    Args:
        nIn: Size of each input feature vector.
        nHidden: Hidden size of the LSTM in each direction.
        nOut: Size of each output feature vector.
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Forward and backward states are concatenated, hence nHidden * 2.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Run the LSTM and project every timestep to ``nOut`` features.

        Args:
            input: Tensor of shape (T, b, nIn) - sequence-first, matching the
                default (non batch_first) layout of ``nn.LSTM``.

        Returns:
            Tensor of shape (T, b, nOut).
        """
        self.rnn.flatten_parameters()
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        # Collapse time and batch so the linear layer sees a 2-D matrix,
        # then restore the (T, b, ...) layout.
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)
        output = output.view(T, b, -1)
        return output
    
class CRNN(nn.Module):
    """Convolutional recurrent network for image-based sequence recognition.

    A seven-layer CNN reduces the input image to a feature map of height 1;
    its columns are then fed, left to right, through two stacked
    ``BidirectionalLSTM`` layers that emit per-column class scores.

    Args:
        opt: Options mapping with keys ``'imgH'`` (input image height),
            ``'nChannels'`` (input image channels), ``'nHidden'`` (LSTM hidden
            size) and ``'nClasses'`` (size of the output per timestep).
        leakyRelu: Use ``LeakyReLU(0.2)`` instead of ``ReLU`` after each
            convolution.
    """

    def __init__(self, opt, leakyRelu=False):
        super(CRNN, self).__init__()

        # NOTE: forward() additionally requires the CNN output height to be
        # exactly 1; with the stack below that holds for imgH == 32.
        assert opt['imgH'] % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]  # kernel size per conv layer
        ps = [1, 1, 1, 1, 1, 1, 0]  # padding per conv layer
        ss = [1, 1, 1, 1, 1, 1, 1]  # stride per conv layer
        nm = [64, 128, 256, 256, 512, 512, 512]  # output channels per conv layer

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            """Append conv layer ``i`` (+ optional batch norm) and its activation."""
            nIn = opt['nChannels'] if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module(f'conv{i}',
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))

            if batchNormalization:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(nOut))

            if leakyRelu:
                cnn.add_module(f'relu{i}', nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module(f'relu{i}', nn.ReLU(True))

        # Shape comments give channels x height x width for a 32x128 input.
        convRelu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        # Stride 1 / padding 1 along the width halves only the height, keeping
        # the horizontal resolution (sequence length) high.
        cnn.add_module('pooling2',
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x33
        # Layers 4 and 5 are required: conv6 takes nm[5] input channels and
        # the height must be 2 before its unpadded 2x2 kernel reduces it to 1.
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling3',
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x34
        convRelu(6, True)  # 512x1x33
        self.cnn = cnn
        self.rnn = nn.Sequential(
            # The first recurrent layer consumes the CNN's output channels.
            BidirectionalLSTM(nm[-1], opt['nHidden'], opt['nHidden']),
            # The second consumes the nHidden-sized projection of the first.
            BidirectionalLSTM(opt['nHidden'], opt['nHidden'], opt['nClasses'])
        )

    def forward(self, input):
        """Compute per-timestep class scores for a batch of images.

        Args:
            input: Image batch of shape (b, nChannels, imgH, W).

        Returns:
            Tensor of shape (b, T, nClasses), where T is the width of the
            CNN feature map.
        """
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        # rnn features
        output = self.rnn(conv)
        output = output.transpose(1, 0)  # Tbh to bth
        return output
