In [85]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torchvision
import torch.nn as nn
from torchvision import transforms as T
from torch.utils.data import Dataset, DataLoader

import cv2
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm.autonotebook import tqdm
from glob import glob

from sklearn.model_selection import train_test_split

!pip install -qq editdistance torchsummary
import editdistance
from torchsummary import summary

# Single seed shared by the torch/numpy RNG setup and the train/validation split.
seed = 42

In [86]:
# --- Reproducibility: seed torch (CPU + CUDA) and numpy from the shared seed ---
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)

# --- Data location and hyper-parameters ---
PATH = r'../input/handwritten/synthetic-data/'
LR = 0.001
BATCH_SIZE = 32
HIDDEN = 512
ENC_LAYERS = 2
DEC_LAYERS = 2
N_HEADS = 4
DROPOUT = 0.1
IMG_WIDTH = 256
IMG_HEIGHT = 64
EPOCHS = 100

# --- Vocabulary ---
# 'PAD' must stay at index 0: label batches are zero-padded, the loss is built
# with ignore_index=0 and the model's padding mask tests for token id 0.
VOCAB = ['PAD', 'SOS', ' '] + list('ABCDEFGHIJKLMNOPQRSTUVWXYZ') + ['EOS']

char2idx = {char: idx for idx, char in enumerate(VOCAB)}
idx2char = dict(enumerate(VOCAB))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Currently using {device.upper()} device.')

In [87]:
def text_to_labels(text, char2idx=char2idx):
    """Encode ``text`` as a list of token ids wrapped in SOS ... EOS.

    Every character is upper-cased before lookup; characters whose upper-case
    form is not a key of ``char2idx`` are silently dropped.
    """
    encoded = [char2idx['SOS']]
    for symbol in text:
        key = symbol.upper()
        if key in char2idx:
            encoded.append(char2idx[key])
    encoded.append(char2idx['EOS'])
    return encoded

def labels_to_text(text, idx2char=idx2char):
    """Decode a sequence of token ids into a string, stopping at the first EOS token.

    Args:
        text: iterable of token ids (list, numpy array, ...).
        idx2char: mapping from token id to its string form.

    Returns:
        The concatenated string forms of all tokens preceding the first EOS
        token (or of all tokens when no EOS is present). Other special tokens
        are rendered with whatever string ``idx2char`` maps them to.
    """
    pieces = []
    for idx in text:
        token = idx2char[idx]
        # Stop on the EOS *token*, not on the substring "EOS": searching the
        # joined string would wrongly truncate words that contain the letters
        # E, O, S consecutively (e.g. "GEOSPATIAL").
        if token == 'EOS':
            break
        pieces.append(token)
    return "".join(pieces)

def char_error_rate(p_seq1, p_seq2):
    """Return the edit distance between two sequences, normalised by the longer length.

    Args:
        p_seq1, p_seq2: two sequences of the same concatenable type (e.g. two
            strings or two lists) whose items are hashable.

    Returns:
        ``editdistance(p_seq1, p_seq2) / max(len(p_seq1), len(p_seq2))``, or
        ``0.0`` when both sequences are empty (identical inputs, nothing to
        normalise by).
    """
    longest = max(len(p_seq1), len(p_seq2))
    if longest == 0:
        # Two empty sequences: avoid dividing by zero.
        return 0.0

    # Map every distinct item to a single character so the distance is
    # computed over plain strings regardless of the item type.
    p_vocab = set(p_seq1 + p_seq2)
    p2c = {item: code for code, item in enumerate(p_vocab)}
    c_seq1 = ''.join(chr(p2c[p]) for p in p_seq1)
    c_seq2 = ''.join(chr(p2c[p]) for p in p_seq2)
    return editdistance.eval(c_seq1, c_seq2) / longest

In [88]:
# Training pipeline: resize to the model's input size, then apply light
# photometric and geometric augmentation. Areas exposed by the geometric
# transforms are filled with 255.
transforms = T.Compose([
    T.ToPILImage(),
    T.Resize((IMG_HEIGHT, IMG_WIDTH)),
    T.ColorJitter(contrast=(0.5, 1), saturation=(0.5, 1)),
    T.RandomRotation(degrees=(-9, 9), fill=255),
    # `fill` is used instead of the deprecated `fillcolor` keyword, matching
    # RandomRotation above.
    T.RandomAffine(degrees=10, translate=None, scale=[0.6, 1], shear=3, fill=255),
    T.ToTensor(),
])

# Validation / inference pipeline: deterministic resize + tensor conversion only.
valid_transforms = T.Compose([
    T.ToPILImage(),
    T.Resize((IMG_HEIGHT, IMG_WIDTH)),
    T.ToTensor(),
])

In [89]:
# Collect every PNG under PATH; the label is the part of the file name before
# the first underscore.
images_paths = sorted(glob(PATH + '*.png'))
images_labels = [
    image_path.split('/')[-1].split('_')[0]
    for image_path in images_paths
]
df = pd.DataFrame({'path': images_paths, 'label': images_labels})
df.sample(3)

In [90]:
# Shuffled 75/25 train/validation split, seeded for reproducibility.
df_train, df_valid = train_test_split(
    df,
    test_size=0.25,
    shuffle=True,
    random_state=seed,
)
print(f'Train size: {len(df_train)}, validation size: {len(df_valid)}')

In [91]:
class CharDataset(Dataset):
    """Dataset of (image, encoded label) pairs read from a dataframe.

    Args:
        dataframe: frame with a 'path' column (image file) and a 'label'
            column (ground-truth text).
        transforms: callable applied to the RGB image array; it must return a
            tensor.
    """

    def __init__(self, dataframe, transforms=transforms):
        self.dataframe = dataframe
        self.transforms = transforms

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Load, transform and encode the sample at position ``idx``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        row = self.dataframe.iloc[idx].squeeze()
        image = cv2.imread(row['path'])
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {row['path']}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Use the pipeline this instance was built with; calling the global
        # `transforms` here would apply training augmentation to every dataset.
        image = self.transforms(image)
        label = text_to_labels(row['label'])
        # NOTE(review): the pipelines in this notebook end with ToTensor(), which
        # presumably already scales pixels to [0, 1]; the extra division is kept
        # because prediction() applies the same scaling at inference time.
        return torch.FloatTensor(image / 255.), torch.LongTensor(label)

    def collate_fn(self, batch):
        """Stack images and zero-pad labels into a [max_len, batch] tensor.

        Returns:
            (images, labels) moved to the global ``device``; labels are
            sequence-first and padded with 0.
        """
        images = torch.stack([image for image, _ in batch])
        max_y_len = max(label.size(0) for _, label in batch)
        y_padded = torch.zeros(max_y_len, len(batch), dtype=torch.long)
        for column, (_, label) in enumerate(batch):
            y_padded[:label.size(0), column] = label
        return images.to(device), y_padded.to(device)

In [92]:
class TransformerModel(nn.Module):
    """ResNet-style CNN backbone feeding an encoder-decoder ``nn.Transformer``.

    The backbone's feature map is read column by column as the source sequence;
    the decoder predicts target tokens over a vocabulary of size ``outtoken``.

    Args:
        bb_name: name of a constructor in ``torchvision.models``; ``forward``
            uses the backbone's ``conv1``/``bn1``/``relu``/``maxpool``/
            ``layer1..4``/``fc`` attributes.
        outtoken: vocabulary size (number of output classes).
        hidden: transformer model dimension.
        enc_layers, dec_layers: number of encoder / decoder layers.
        nhead: number of attention heads.
        dropout: dropout probability used by the transformer and the
            positional encodings.
        pretrained: forwarded to the backbone constructor.
    """

    def __init__(self, bb_name, outtoken, hidden, enc_layers=1, dec_layers=1, nhead=1, dropout=0.1, pretrained=False):
        super().__init__()
        self.backbone = getattr(torchvision.models, bb_name)(pretrained=pretrained)
        # Replace the classification head with a 1x1 conv that projects the
        # backbone's channels to hidden/2. The input width is read from the
        # original head so backbones with a different channel count also work.
        backbone_channels = self.backbone.fc.in_features
        self.backbone.fc = nn.Conv2d(backbone_channels, hidden // 2, 1)

        self.pos_encoder = PositionalEncoding(hidden, dropout)
        self.decoder = nn.Embedding(outtoken, hidden)
        self.pos_decoder = PositionalEncoding(hidden, dropout)
        self.transformer = nn.Transformer(d_model=hidden, nhead=nhead, num_encoder_layers=enc_layers,
                                          num_decoder_layers=dec_layers, dim_feedforward=hidden * 4, dropout=dropout,
                                          activation='relu')

        self.fc_out = nn.Linear(hidden, outtoken)
        self.src_mask = None
        self.trg_mask = None  # cached causal mask, rebuilt when the target length changes
        self.memory_mask = None

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return an ``sz x sz`` causal mask: -inf above the diagonal, 0 elsewhere."""
        return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)

    def make_len_mask(self, inp):
        """Return a batch-first boolean mask that is True where ``inp`` equals 0."""
        return (inp == 0).transpose(0, 1)

    def forward(self, src, trg):
        """Compute per-position vocabulary logits.

        Args:
            src: image batch [B, C, H, W].
            trg: target token ids, sequence-first [L, B].

        Returns:
            Logits of shape [L, B, outtoken].
        """
        # Build the causal mask on the same device as the inputs instead of
        # relying on a module-level `device` variable.
        if (self.trg_mask is None or self.trg_mask.size(0) != len(trg)
                or self.trg_mask.device != trg.device):
            self.trg_mask = self.generate_square_subsequent_mask(len(trg), device=trg.device)

        x = self.backbone.conv1(src)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)  # [B, C', H', W']

        x = self.backbone.fc(x)  # [B, hidden/2, H', W']
        x = x.permute(0, 3, 1, 2)  # [B, W', hidden/2, H']
        # Channels and height are merged into the feature axis; this equals
        # `hidden` only when H' == 2 (e.g. 64-pixel-high inputs through resnet50).
        x = x.flatten(2)  # [B, W', hidden/2 * H']
        x = x.permute(1, 0, 2)  # [W', B, hidden/2 * H']
        # A source position counts as padding when its first feature is exactly 0.
        # NOTE(review): the CNN features are not explicitly padded, so this mask
        # is presumably almost always all-False - confirm the intent.
        src_pad_mask = self.make_len_mask(x[:, :, 0])
        src = self.pos_encoder(x)
        trg_pad_mask = self.make_len_mask(trg)  # True where the target token id is 0
        trg = self.decoder(trg)
        trg = self.pos_decoder(trg)

        output = self.transformer(src, trg, src_mask=self.src_mask, tgt_mask=self.trg_mask,
                                  memory_mask=self.memory_mask,
                                  src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=trg_pad_mask,
                                  memory_key_padding_mask=src_pad_mask)  # [L, B, hidden]
        output = self.fc_out(output)  # [L, B, outtoken]

        return output
    
class PositionalEncoding(nn.Module):
    """Add a sinusoidal position signal, weighted by a learnable scalar, then apply dropout.

    Inputs are sequence-first: the encoding is indexed by ``x.size(0)`` and
    broadcast over the second (batch) dimension.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.scale = nn.Parameter(torch.ones(1))

        # One row per position, one column per sine/cosine frequency.
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = positions * frequencies

        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even feature indices
        table[:, 1::2] = torch.cos(angles)  # odd feature indices
        # Stored as [max_len, 1, d_model]; a buffer moves with the module but is not trained.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        seq_len = x.size(0)
        encoded = x + self.scale * self.pe[:seq_len, :]
        return self.dropout(encoded)

In [93]:
def train_one_batch(model, data, optimizer, criterion):
    """Run a single optimisation step on one batch and return the loss as a float.

    ``data`` is an ``(image, target)`` pair with a sequence-first ``target``.
    The model is fed ``target[:-1]`` and scored against ``target[1:]``.
    """
    model.train()
    image, target = data
    decoder_input, expected = target[:-1, :], target[1:, :]

    optimizer.zero_grad()
    logits = model(image, decoder_input)
    # Flatten the (sequence, batch) axes so the criterion sees one row per token.
    loss = criterion(logits.view(-1, logits.shape[-1]), expected.reshape(-1))
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def validate_one_batch(model, data, criterion):
    """Return the loss of one batch as a float, with the model in eval mode and no gradients.

    ``data`` is an ``(image, target)`` pair with a sequence-first ``target``.
    The model is fed ``target[:-1]`` and scored against ``target[1:]``.
    """
    model.eval()
    image, target = data
    decoder_input, expected = target[:-1, :], target[1:, :]

    logits = model(image, decoder_input)
    # Flatten the (sequence, batch) axes so the criterion sees one row per token.
    flat_logits = logits.view(-1, logits.shape[-1])
    return criterion(flat_logits, expected.reshape(-1)).item()

def evaluate(model, dataloader, max_len=30):  # assuming dataloader has batch_size=1
    """Greedy-decode every sample and return (character error rate, word error rate) in percent.

    Args:
        model: model called as ``model(src, trg)`` and returning logits of
            shape [L, B, vocab].
        dataloader: yields ``(src, trg)`` batches of size 1 with a
            sequence-first ``trg`` that starts with the SOS token.
        max_len: maximum number of tokens generated per sample.

    Returns:
        ``(cer, wer)`` averaged over ``len(dataloader)`` and multiplied by 100.
        A sample counts as a word error when the decoded string differs from
        the reference; an empty prediction is scored with a CER of 1.
    """
    model.eval()
    wer_overall = 0
    cer_overall = 0
    sos_idx, eos_idx = char2idx['SOS'], char2idx['EOS']

    with torch.no_grad():
        for src, trg in tqdm(dataloader, leave=False):
            out_indexes = [sos_idx]

            # Greedy decoding: feed back the most likely token until EOS or max_len.
            for _ in range(max_len):
                trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(src.device)
                output = model(src, trg_tensor)
                out_token = output.argmax(2)[-1].item()
                out_indexes.append(out_token)
                if out_token == eos_idx:
                    break

            out_char = labels_to_text(out_indexes[1:])
            # Both strings are decoded by labels_to_text and compared as-is.
            # Lower-casing only the reference made every comparison against the
            # upper-case predictions fail.
            real_char = labels_to_text(trg[1:, 0].detach().cpu().numpy())
            wer_overall += int(real_char != out_char)
            cer_overall += char_error_rate(real_char, out_char) if out_char else 1

    return cer_overall / len(dataloader) * 100, wer_overall / len(dataloader) * 100

@torch.no_grad()
def prediction(model, filepath='random', max_len=30):
    """Greedy-decode one image, plot it with the prediction as title and return the text.

    Args:
        model: model called as ``model(src, trg)`` and returning logits of
            shape [L, B, vocab].
        filepath: path of the image to read, or ``'random'`` to draw a random
            row from ``df_valid`` (its label is then shown in the title).
        max_len: maximum number of tokens generated.

    Returns:
        The decoded string.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    label = None
    if filepath == 'random':
        idx = np.random.randint(len(df_valid))
        filepath = df_valid.iloc[idx, 0]
        label = df_valid.iloc[idx, 1]

    model.eval()
    img = cv2.imread(filepath)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f'Could not read image: {filepath}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    image = valid_transforms(img)
    # Same scaling as CharDataset.__getitem__, so inference matches training.
    src = torch.FloatTensor(image / 255.).unsqueeze(0).to(device)

    sos_idx, eos_idx = char2idx['SOS'], char2idx['EOS']
    out_indexes = [sos_idx]

    # Greedy decoding: feed back the most likely token until EOS or max_len.
    for _ in range(max_len):
        trg_tensor = torch.LongTensor(out_indexes).unsqueeze(1).to(device)
        output = model(src, trg_tensor)
        out_token = output.argmax(2)[-1].item()
        out_indexes.append(out_token)
        if out_token == eos_idx:
            break

    preds = labels_to_text(out_indexes[1:], idx2char)
    plt.figure(figsize=(6,4))
    plt.title(f'Prediction: {preds}, Truth: {label if label is not None else "NO label"}')
    plt.imshow(img)
    plt.tight_layout()
    plt.show()
    plt.pause(0.001)

    return preds

In [94]:
# Each dataset is handed its own transform pipeline.
train_dataset = CharDataset(df_train, transforms)
valid_dataset = CharDataset(df_valid, valid_transforms)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
    drop_last=True,
)
# Validation runs one sample per batch, which is what evaluate() assumes.
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=valid_dataset.collate_fn,
    drop_last=True,
)

In [95]:
# Model, optimiser, loss and LR schedule, all driven by the constants in the config cell.
model = TransformerModel(
    'resnet50',
    len(VOCAB),
    hidden=HIDDEN,
    enc_layers=ENC_LAYERS,
    dec_layers=DEC_LAYERS,
    nhead=N_HEADS,
    dropout=DROPOUT,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-6)
# Token id 0 is 'PAD' and is excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
# Halve the learning rate after 10 epochs without improvement of the monitored value.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,
    mode='min',
    patience=10,
    min_lr=1e-6,
)

In [96]:
# Per-epoch mean losses, kept for later inspection.
train_losses, valid_losses = [], []

for epoch in range(EPOCHS):
    print(f'{epoch+1}/{EPOCHS} epoch.')

    # --- Training pass ---
    epoch_train_losses = [
        train_one_batch(model, batch, optimizer, criterion)
        for batch in tqdm(train_dataloader, leave=False)
    ]
    train_epoch_loss = np.mean(epoch_train_losses)
    train_losses.append(train_epoch_loss)

    # --- Validation pass ---
    epoch_valid_losses = [
        validate_one_batch(model, batch, criterion)
        for batch in tqdm(valid_dataloader, leave=False)
    ]
    valid_epoch_loss = np.mean(epoch_valid_losses)
    valid_losses.append(valid_epoch_loss)
    print(f'Train loss: {train_epoch_loss:.4f}, validation loss: {valid_epoch_loss:.4f}')

    # Full greedy-decoding metrics every 5 epochs.
    if (epoch + 1) % 5 == 0:
        valid_cer, valid_wer = evaluate(model, valid_dataloader)
        print(f'Char_error_rate: {valid_cer:.4f}, Word_error_rate: {valid_wer:.4f}')

    # The plateau scheduler (mode='min') is driven by the validation loss.
    # The previous argument, `epoch_loss`, was never defined and raised a
    # NameError at the end of the first epoch.
    scheduler.step(valid_epoch_loss)

    # Every 10 epochs: show a sample prediction and checkpoint the weights.
    if (epoch+1) % 10 == 0:
        pred = prediction(model)
        torch.save(model.state_dict(), 'model.pth')

In [79]:
# Show the model's prediction for a randomly drawn validation image.
prediction(model)