In [3]:
import warnings
from pathlib import Path

import cv2
import torch
import torchvision
from torch import nn

# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
# (The recorded run shows 'cpu'; the author notes they are on AMD hardware.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cpu'

## Initialize dataset directory

In [4]:
# Locations of the train / test / validation splits.
# All three are still empty placeholders and must be filled in before use.
train_dir, test_dir, val_dir = '', '', ''

## Initialize dataset

In [5]:
import json
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class ImageDataset(Dataset):
    """Dataset yielding an RGB and a grayscale tensor for every ``*.jpg`` in a folder.

    Args:
        image_dir: Directory holding the ``.jpg`` images (``str`` or ``pathlib.Path``).
        word_dir: Path to a JSON file. Its parsed content is stored in
            ``self.words``; it is not used by ``__getitem__`` yet.
    """

    def __init__(self, image_dir, word_dir):
        # Accept plain strings too: ``glob`` below is a Path method.
        self.image_dir = Path(image_dir)
        with open(word_dir, 'r') as f:
            self.words = json.load(f)
        # Sorted so the index -> file mapping is reproducible between runs.
        self.images = sorted(self.image_dir.glob('*.jpg'))
        # Images are opened with PIL, which already delivers RGB / L data, so no
        # OpenCV BGR->RGB step is needed (cv2.cvtColor rejects PIL images anyway,
        # which previously made every item fall into the error branch).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((256, 256)),
        ])

    def __len__(self):
        """Return the number of ``*.jpg`` files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(rgb_img, bw_img)`` for the image at position ``idx``.

        Both tensors come from the same file, converted to ``'RGB'`` and ``'L'``
        respectively, then passed through ``self.transform`` (resize to 256x256).

        Returns:
            A pair of tensors, or ``(torch.tensor(-1), torch.tensor(-1))`` when
            the image file cannot be read.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # IndexError (not an assert) is what the sequence protocol expects, and
        # the upper bound is exclusive: ``idx == len(self)`` is out of range.
        if not -len(self) <= idx < len(self):
            raise IndexError('Index out of range')

        path = self.images[idx]
        try:
            # Open the file once and close it deterministically.
            with Image.open(path) as img:
                rgb_img = self.transform(img.convert('RGB'))
                bw_img = self.transform(img.convert('L'))
        except OSError as e:
            # Best-effort behaviour kept for unreadable files, but only for I/O
            # and decoding errors, and no longer silently.
            warnings.warn(f'Could not read image {path}: {e}')
            return torch.tensor(-1), torch.tensor(-1)
        return rgb_img, bw_img

        # TODO: also return the label, e.g. self.words[self.images[idx].name],
        # once it is needed.

# Loss function

### 1. VGG19 Typeface classifier, C - Text Perceptual Loss
- Perceptual loss is computed from the feature maps at layer $i$, denoted $\phi_i$; $M_i$ is the number of elements in that feature map and is used for normalization. It is only computed for the output image corresponding to the original content.
- Texture loss / style loss computed from Gram matrix of the feature maps.
- Embedding-based loss computed from feature maps of the penultimate layer of this network. (???)

https://gist.github.com/alper111/8233cdb0414b4cb5853f2f730ab95a49#gistcomment-3347450

In [6]:
# Paper trains this model with Synth-Font dataset, however I could not find it
class VGGPerceptualLoss(torch.nn.Module):
    """Feature (perceptual) and style (Gram-matrix) L1 loss on frozen VGG19 activations.

    The first 23 layers of torchvision's pretrained VGG19 ``features`` are split
    into four consecutive blocks (``[:4]``, ``[4:9]``, ``[9:16]``, ``[16:23]``).
    Their parameters are frozen, so only the inputs receive gradients.

    Args:
        resize: If True, both inputs are bilinearly resized to 224x224 before
            being fed to the network.
    """

    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Load the pretrained network once and slice it; the slices are disjoint,
        # so this yields the same four blocks as loading the weights four times.
        # NOTE: recent torchvision releases prefer ``weights=...`` over
        # ``pretrained=True``; the original argument is kept for compatibility.
        features = torchvision.models.vgg19(pretrained=True).features
        blocks = [features[:4].eval(),
                  features[4:9].eval(),
                  features[9:16].eval(),
                  features[16:23].eval()]
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        # Per-channel normalisation constants, broadcastable over (N, 3, H, W).
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, prediction, target, feature_layers=None, style_layers=None):
        """Compute the summed feature and style loss between two image batches.

        Args:
            prediction: Image batch of shape (N, C, H, W) with C equal to 1 or 3.
            target: Image batch to compare against; same layout as ``prediction``.
                Its channel count may differ from ``prediction`` (1 vs 3).
            feature_layers: Indices (0-3) of blocks whose activations contribute
                an L1 feature loss. Defaults to ``[2]``.
            style_layers: Indices (0-3) of blocks whose Gram matrices contribute
                an L1 style loss. Defaults to ``[0, 1, 2, 3]``.

        Returns:
            A scalar tensor holding the sum of all selected loss terms (zero when
            both layer lists are empty).
        """
        if style_layers is None:
            style_layers = [0, 1, 2, 3]
        if feature_layers is None:
            feature_layers = [2]

        # Expand single-channel inputs to three channels. Each tensor is checked
        # on its own so a grayscale/RGB pair works in either order.
        if prediction.shape[1] == 1:
            prediction = prediction.repeat(1, 3, 1, 1)
        if target.shape[1] == 1:
            target = target.repeat(1, 3, 1, 1)

        prediction = (prediction - self.mean) / self.std
        target = (target - self.mean) / self.std
        if self.resize:
            prediction = self.transform(prediction, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)

        # Start from a tensor so the return type does not depend on the layer lists.
        loss = prediction.new_zeros(())
        x = prediction
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss = loss + torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrix: channel-by-channel inner products of the flattened maps.
                # NOTE(review): the Gram matrices are not divided by the number of
                # elements; confirm whether a normalisation is wanted here.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss = loss + torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss

### 2. OCR, R - Text Content Loss
- The relevant modules have already been added, along with the saved configuration mentioned in the paper, and the PyTorch model can be downloaded [here](https://drive.google.com/file/d/1b59rXuGGmKne1AuHnkgDzoYgKeETNMv9/view?usp=drive_link). It has not yet been used to calculate the loss.
- The content loss is computed by measuring the cross entropy between the sequences of characters in the input strings, $c_1, c_2$, and in the predicted strings, $c'_1, c'_2$, respectively; the characters are represented as one-hot vectors.

https://github.com/clovaai/deep-text-recognition-benchmark

In [None]:
from torch import nn
class CombinedLoss(nn.Module):
    """Placeholder for the combined training loss; ``forward`` is not written yet.

    The class is kept so the intended interface is visible, and the stub body
    keeps the cell syntactically valid so the notebook can run top to bottom.
    """

    def __init__(self):
        super(CombinedLoss, self).__init__()

    def forward(self, predictions, target):
        """Combine the individual loss terms for ``predictions`` against ``target``.

        Raises:
            NotImplementedError: Always, until the combination is implemented.
        """
        # TODO: implement the combination of the individual losses.
        raise NotImplementedError('CombinedLoss.forward is not implemented yet')

In [34]:
import ocr
import argparse
from ocr.utils import AttnLabelConverter
from ocr.model import Model

class opt:
    # Plain class used as a stand-in for an argparse namespace: the options are
    # hard-coded here instead of being parsed from the command line,
    # as per https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/demo.py#L96
    image_folder = 'images'
    workers = 4
    batch_size = 192
    saved_model = 'ocr/TPS-ResNet-BiLSTM-Attn.pth'
    batch_max_length = 25
    imgH = 32
    imgW = 100
    rgb = False # Keep False: when True, input_channel is set to 3 below, which breaks loading
    # NOTE(review): `character` holds only digits and lowercase letters although
    # `sensitive` is True; confirm this combination is intended.
    character = '0123456789abcdefghijklmnopqrstuvwxyz'
    sensitive = True
    PAD = True
    # Names of the four model stages; presumably they select the modules built
    # by ocr.model.Model. Verify against that module.
    Transformation = 'TPS'
    FeatureExtraction = 'ResNet'
    SequenceModeling = 'BiLSTM'
    Prediction = 'Attn'
    num_fiducial = 20
    input_channel = 1
    output_channel = 512
    hidden_size = 256
    
    
# The number of output classes is taken from the converter's character list.
converter = AttnLabelConverter(opt.character)
opt.num_class = len(converter.character)

# Switching to 3 input channels breaks loading the checkpoint, so `opt.rgb`
# has to stay False (equivalently, `input_channel` has to stay 1).
if opt.rgb:
    opt.input_channel = 3

# Build the recogniser, wrap it in DataParallel and move it to the chosen device.
model = torch.nn.DataParallel(Model(opt)).to(device)

# SECURITY NOTE: torch.load unpickles the file; only load checkpoints from a
# trusted source.
mappings = torch.load(opt.saved_model, map_location=device)
model.load_state_dict(mappings)

# TODO: the input data still has to be fed in, presumably through a DataLoader
# along the lines of the upstream demo:
#   demo_loader = torch.utils.data.DataLoader(
#       demo_data, batch_size=opt.batch_size,
#       shuffle=False,
#       num_workers=int(opt.workers),
#       collate_fn=AlignCollate_demo, pin_memory=True)
# The shape of the input data is not settled yet, so this is on hold.
model.eval()

# It's getting late, idk anymore, over to you


DataParallel(
  (module): Model(
    (Transformation): TPS_SpatialTransformerNetwork(
      (LocalizationNetwork): LocalizationNetwork(
        (conv): Sequential(
          (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (6): ReLU(inplace=True)
          (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (10): ReLU

# 1. Style Encoder
This is a homebrew version of the ResNet34 architecture. I have no idea if this works, but I am prepared to alter it in the event it fails.

In [18]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a residual (shortcut) connection.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution and of the projection shortcut.
            Defaults to 1, which keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        # padding=1 makes the 3x3 convolutions size-preserving (for stride 1), so
        # the main path and the shortcut have matching shapes when they are added.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the shape changes; then project with a 1x1 conv
        # that applies the same stride and channel change as the main path.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut of ``x``, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return self.relu(out)

In [None]:
from torch import nn
import torchvision
class StyleEncoder(nn.Module):
    """Convolutional style encoder: conv/residual stages followed by RoIAlign pooling.

    The stack alternates plain 3x3 convolutions, residual stages and 2x2 max
    pooling, widening the channels 3 -> 32 -> 64 -> 128 -> 256 -> 512. All 3x3
    convolutions are padded so only the four pooling layers change the spatial
    size, which is what the size annotations below assume (256x256 input ->
    16x16 features). The features are then pooled to 1x1 per box with
    ``torchvision.ops.roi_align``.
    """

    def __init__(self):
        super(StyleEncoder, self).__init__()
        self.layer_stack = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),  # 256x256
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 128x128
            self._make_layer(ResidualBlock, 64, 128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 64x64
            self._make_layer(ResidualBlock, 128, 256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x32
            self._make_layer(ResidualBlock, 256, 512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16x16
            self._make_layer(ResidualBlock, 512, 512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
        )

    def _make_layer(self, block, in_channels, out_channels, num_blocks=1, stride=1):
        """Stack ``num_blocks`` residual blocks into one ``nn.Sequential``.

        Args:
            block: Callable building a block as ``block(in_channels, out_channels)``
                or, when a stride is requested, ``block(in_channels, out_channels, stride)``.
            in_channels: Channels entering the first block.
            out_channels: Channels produced by every block.
            num_blocks: Number of blocks in the stage. Defaults to 1.
            stride: Stride of the first block; later blocks use stride 1.

        Returns:
            An ``nn.Sequential`` holding the created blocks in order.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            # The stride is only passed when it is not the default, so block
            # constructors that take just (in_channels, out_channels) still work.
            if block_stride == 1:
                layers.append(block(in_channels, out_channels))
            else:
                layers.append(block(in_channels, out_channels, block_stride))
            # Every block after the first receives the previous block's output.
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x, boxes=None):
        """Encode ``x`` and pool the features to 1x1 per region of interest.

        Args:
            x: Image batch of shape (N, 3, H, W).
            boxes: Regions in input-pixel coordinates, in any format accepted by
                ``torchvision.ops.roi_align`` (a (K, 5) tensor of
                ``(batch_index, x1, y1, x2, y2)`` rows or a list of (L, 4) tensors).
                When omitted, one box covering each whole image is used.

        Returns:
            Tensor of shape (K, 512, 1, 1), where K is the number of boxes
            (K == N when ``boxes`` is omitted).
        """
        features = self.layer_stack(x)
        if boxes is None:
            n, _, height, width = x.shape
            batch_index = torch.arange(n, dtype=features.dtype, device=features.device).unsqueeze(1)
            whole_image = features.new_tensor([0.0, 0.0, float(width), float(height)]).expand(n, 4)
            boxes = torch.cat([batch_index, whole_image], dim=1)
        # Boxes are given in input pixels; map them onto the feature-map grid.
        # NOTE(review): a single scale is used for both axes; this assumes the
        # stack downsamples height and width by the same factor.
        spatial_scale = features.shape[-1] / x.shape[-1]
        return torchvision.ops.roi_align(features, boxes, output_size=1, spatial_scale=spatial_scale)  # 1x1