In [1]:
# from importlib import import_module

def create_instance(config, *args, **kwargs):
    """Instantiate the object described by ``config``.

    ``config['name']`` is the name of a callable visible from this scope, and
    ``config[name]`` (optional) is a dict of keyword arguments for it. String
    values in that dict are treated as Python expressions and evaluated, so a
    literal string must carry its own quotes (e.g. ``'"resnet50"'``).
    Non-string values are passed through untouched.

    Args:
        config: Mapping with a ``'name'`` key and, optionally, a key equal to
            that name holding the keyword arguments.
        *args: Extra positional arguments forwarded to the callable.
        **kwargs: Extra keyword arguments forwarded to the callable.

    Returns:
        Whatever the resolved callable returns.

    Raises:
        KeyError: If ``config`` has no ``'name'`` entry.
    """
    name = config['name']

    # NOTE(security): eval() executes arbitrary code. Only use this with
    # configs you wrote yourself; never with untrusted input.
    # The evaluated values go into a *new* dict: writing them back into
    # ``config`` would make a second call eval the already-unquoted strings
    # and fail (e.g. NameError on ``resnet50``).
    config_kwargs = {}
    for key, value in config.get(name, {}).items():
        config_kwargs[key] = eval(value) if isinstance(value, str) else value

    return eval(name)(*args, **config_kwargs, **kwargs)

In [2]:
import torch

from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.utils import load_state_dict_from_url
from torchvision.models import resnet
from torchvision.models.segmentation import fcn


# Download locations of the COCO-pretrained FCN checkpoints.
# Keys follow the '<arch_type>_<backbone>_coco' pattern that _load_model builds.
model_urls = {
    'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth',
    'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth',
}


def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True, replace_stride_with_dilation=[False, True, True]):
    """Assemble a ResNet-backed segmentation network.

    Args:
        name: Key selecting the head/model pair; only ``'fcn'`` is registered.
        backbone_name: Name of a constructor in ``torchvision.models.resnet``.
        num_classes: Number of output channels of the classifier head(s).
        aux: When truthy, ``layer3`` is also exposed and an auxiliary head is
            attached to it.
        pretrained_backbone: Forwarded to the ResNet constructor as ``pretrained``.
        replace_stride_with_dilation: Forwarded unchanged to the ResNet constructor.

    Returns:
        The assembled segmentation model.

    Raises:
        KeyError: If ``backbone_name`` or ``name`` is not registered.
    """
    resnet_factory = resnet.__dict__[backbone_name]
    trunk = resnet_factory(
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=replace_stride_with_dilation,
    )

    # 'layer4' always feeds the main head; 'layer3' is tapped only for the aux head.
    if aux:
        tapped_layers = {'layer4': 'out', 'layer3': 'aux'}
    else:
        tapped_layers = {'layer4': 'out'}
    trunk = IntermediateLayerGetter(trunk, return_layers=tapped_layers)

    # Input widths are hard-coded: 1024 channels for the aux head, 2048 for the main head.
    aux_head = fcn.FCNHead(1024, num_classes) if aux else None

    head_cls, model_cls = {'fcn': (fcn.FCNHead, fcn.FCN)}[name]
    main_head = head_cls(2048, num_classes)

    return model_cls(trunk, main_head, aux_head)


def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
    """Build a segmentation model and optionally load COCO-pretrained weights.

    Args:
        arch_type: Architecture key passed to ``_segm_resnet`` (e.g. ``'fcn'``).
        backbone: ResNet constructor name (e.g. ``'resnet50'``).
        pretrained: When truthy, load the checkpoint registered in ``model_urls``.
        progress: Forwarded to ``load_state_dict_from_url``.
        num_classes: Number of output classes.
        aux_loss: Whether to attach the auxiliary head; forced to ``True`` when
            ``pretrained`` is truthy.
        **kwargs: Forwarded to ``_segm_resnet``.

    Returns:
        The constructed (and possibly weight-loaded) model.

    Raises:
        NotImplementedError: If ``pretrained`` is requested for an
            architecture/backbone pair with no registered checkpoint.
    """
    model_url = None
    if pretrained:
        # Presumably the published checkpoints include aux-head weights, so the
        # model must be built with the aux head for load_state_dict to match.
        aux_loss = True
        arch = arch_type + '_' + backbone + '_coco'
        # .get() so that an unknown arch reaches the intended NotImplementedError
        # instead of a bare KeyError, and fail before building the network.
        model_url = model_urls.get(arch)
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)

    if model_url is not None:
        state_dict = load_state_dict_from_url(model_url, progress=progress)
        model.load_state_dict(state_dict)
    return model


class FCN(nn.Module):
    """FCN segmentation network that returns per-pixel class probabilities.

    Wraps the model built by ``_load_model('fcn', ...)`` and applies a softmax
    over the channel dimension of its ``'out'`` output.
    """

    def __init__(self, backbone, pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs):
        """Build the wrapped model.

        Args:
            backbone: ``'resnet50'`` or ``'resnet101'``.
            pretrained, progress, num_classes, aux_loss: Forwarded to ``_load_model``.
            **kwargs: Forwarded to ``_load_model``.

        Raises:
            ValueError: If ``backbone`` is not one of the supported names.
        """
        super().__init__()

        if backbone not in ('resnet50', 'resnet101'):
            raise ValueError('{} is not supported.'.format(backbone))

        self.model = _load_model('fcn', backbone, pretrained, progress, num_classes, aux_loss, **kwargs)

    def forward(self, x):
        """Run the wrapped model on ``x`` and softmax its 'out' tensor along dim 1."""
        logits = self.model(x)['out']
        return torch.nn.functional.softmax(logits, dim=1)


In [3]:
# Config consumed by create_instance: 'name' selects the class, and the entry
# keyed by that name holds its keyword arguments. String values are eval'd by
# create_instance, hence the nested quotes around "resnet50".
arch_config = {
    'name': 'FCN',
    'FCN': {
        'replace_stride_with_dilation': [True, True, True],
        'backbone': '"resnet50"',
        'pretrained_backbone': False
    }
}

In [4]:
# Build a bare FCN from arch_config. No num_classes is given here, so FCN's
# default of 21 applies. `model` is not referenced by the later cells shown.
model = create_instance(arch_config)

In [5]:
import cv2
import torch

class Extractor():
    """Runs a segmentation model on a single image and returns its prediction map.

    Calling an instance performs preprocess -> process -> postprocess and
    returns ``(image, pred)``, where ``image`` is the untouched input and
    ``pred`` is the model output for that image as a NumPy array.
    """

    def __init__(self, arch_config, image_size, weight_path, device, **kwargs):
        """Build the model, load its weights and put it in eval mode on ``device``.

        Args:
            arch_config: Config dict passed to ``create_instance`` to build the model.
            image_size: ``dsize`` handed to ``cv2.resize`` for every input image.
            weight_path: Path of a state-dict file readable by ``torch.load``.
            device: Device the model and inputs are moved to.
            **kwargs: Extra keyword arguments forwarded to the model constructor.
        """
        super(Extractor, self).__init__()
        self.device = device
        self.image_size = image_size
        self.model = create_instance(arch_config, **kwargs)
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        self.model.load_state_dict(torch.load(weight_path, map_location='cpu'))
        self.model.eval()
        self.model.to(self.device)

    @staticmethod
    def _standardize(samples, eps=1e-6):
        """Zero-mean / unit-std normalisation per sample over dims (1, 2, 3).

        The std is clamped to ``eps`` so that a uniform image (std == 0) yields
        zeros instead of NaN/inf.
        """
        mean = samples.mean(dim=(1, 2, 3), keepdim=True)
        std = samples.std(dim=(1, 2, 3), keepdim=True).clamp_min(eps)
        return (samples - mean) / std

    def preprocess(self, image):
        """Resize ``image`` and turn it into a normalised float batch of size 1.

        Assumes ``image`` is an H x W x C array such as ``cv2.imread`` returns
        (TODO confirm for other callers). Returns ``(image, samples)`` with the
        original image passed through unchanged.
        """
        sample = torch.from_numpy(cv2.resize(image, dsize=self.image_size))
        samples = sample.unsqueeze(dim=0).to(self.device).to(torch.float)
        # Move the last (channel) axis to position 1, as the model expects.
        samples = samples.permute(0, 3, 1, 2)
        samples = self._standardize(samples)
        return image, samples

    def process(self, image, samples):
        """Run the model on ``samples`` without tracking gradients."""
        with torch.no_grad():
            return image, self.model(samples)

    def postprocess(self, image, preds):
        """Move the channel axis last, convert to NumPy and drop the batch axis."""
        preds = preds.permute(0, 2, 3, 1).detach().cpu().numpy()
        pred = preds[0]  # input one image
        return image, pred

    def __call__(self, image):
        """Run the full pipeline on one image and return ``(image, pred)``."""
        image, samples = self.preprocess(image)
        image, preds = self.process(image, samples)
        image, pred = self.postprocess(image, preds)
        return image, pred

In [6]:
# Config for the Extractor, in the format create_instance expects.
# - String values are eval'd, so literal strings carry their own quotes.
# - 'arch_config' is a dict (not a string) and is handed to Extractor as-is,
#   which passes it to create_instance to build the FCN.
# - 'num_classes' is not an Extractor parameter; it lands in Extractor's
#   **kwargs and is forwarded to the FCN constructor.
config = {
    'name': 'Extractor',
    'Extractor': {
        'arch_config': {
            'name': 'FCN',
            'FCN': {
                'replace_stride_with_dilation': [True, True, True],
                'backbone': '"resnet50"',
                'pretrained_backbone': False
            }
        },
        'image_size': (256, 256),
        'weight_path': "'best_model_40_loss=-0.07206277665637788.pth'",
        'device': "'cpu'",
        'num_classes': 12
    }
}

# Class table: name -> [index, float, bool].
# The first element equals the entry's position, which the loops below use as
# the channel index into `pred`. The float and bool fields are not read by any
# cell shown here; their meaning is not visible in this notebook (TODO confirm).
classes = {
    'BG': [0, 0, False],
    'HEADING': [1, 0.0214, True],
    'V_ID': [2, 0.0096, True],
    'V_NAME1': [3, 0.0163, True],
    'V_NAME2': [4, 0.0180, True],
    'V_BD': [5, 0.0051, True],
    'V_BP1': [6, 0.0075, True],
    'V_BP2': [7, 0.0158, True],
    'V_A1': [8, 0.0081, True],
    'V_A2': [9, 0.0231, True],
    'LOGO': [10, 0.0420, False],
    'FIGURE': [11, 0.1090, False]
}

In [7]:
# Build the Extractor: constructs the FCN with num_classes=12 and loads the
# weights from the configured 'weight_path' onto the configured device.
textline_extractor = create_instance(config)

In [8]:
import cv2
import numpy as np

In [9]:
IMAGE_PATH = './test/extracted_card.jpg'

image = cv2.imread(IMAGE_PATH)
# cv2.imread signals failure by returning None rather than raising, so check
# here to get a clear error instead of an obscure one inside cv2.resize.
if image is None:
    raise FileNotFoundError(f'Could not read image: {IMAGE_PATH}')

# `image` comes back unchanged; `pred` is the per-pixel class probability map.
image, pred = textline_extractor(image)

In [10]:
# Show the input image in an OpenCV window; waitKey() blocks until a key is
# pressed, after which the window is closed.
cv2.imshow('image', image)
cv2.waitKey()
cv2.destroyAllWindows()

In [11]:
# The extractor returns the input image untouched, so this is the original size.
image.shape

(408, 664, 3)

In [12]:
# postprocess moves the channel axis last and drops the batch axis, so the
# last dimension is the number of classes.
pred.shape

(256, 256, 12)

In [13]:
# Show the binarised prediction of every class, one window at a time.
# Relies on `classes` being ordered by channel index.
for i, class_name in enumerate(classes):
    # Round the probabilities to 0/1. The union of all connected components is
    # exactly the set of non-zero pixels, so no component labelling is needed.
    binary = pred[..., i].round().astype(np.uint8)
    mask = (binary != 0).astype(np.uint8)
    print(class_name)
    cv2.imshow(f'image_{i}', mask * 255)
    cv2.waitKey()
    cv2.destroyAllWindows()

BG
HEADING
V_ID
V_NAME1
V_NAME2
V_BD
V_BP1
V_BP2
V_A1
V_A2
LOGO
FIGURE


In [14]:
def order_points(points):
    """Arrange four 2-D points as [top-left, top-right, bottom-right, bottom-left].

    The two points with the smallest x form the left pair and the other two the
    right pair; within each pair the point with the smaller y is the top one.

    Args:
        points: Sequence of exactly four indexable (x, y) points.

    Returns:
        List ``[tl, tr, br, bl]`` containing the original point objects.
    """
    assert len(points) == 4, 'Length of points must be 4'

    # One stable sort by x gives both halves.
    by_x = sorted(points, key=lambda pt: pt[0])
    left_pair, right_pair = by_x[:2], by_x[2:]

    top_left, bottom_left = sorted(left_pair, key=lambda pt: pt[1])
    top_right, bottom_right = sorted(right_pair, key=lambda pt: pt[1])

    return [top_left, top_right, bottom_right, bottom_left]

In [15]:
def get_line(mask):
    """Return the ordered corners of the rotated box around the largest blob in ``mask``.

    The mask is rounded and cast to uint8 before labelling. When it contains
    several connected components, the one with the most pixels is used; the
    previous behaviour kept whichever component happened to have the highest
    label id, so a stray speck could replace the real text line.

    Args:
        mask: 2-D array whose non-zero (after rounding) pixels are foreground.

    Returns:
        ``[tl, tr, br, bl]`` corner points of the minimum-area rectangle, or
        ``None`` when the mask has no foreground pixel.
    """
    num_labels, label = cv2.connectedComponents(mask.round().astype(np.uint8))
    if num_labels == 1:  # only the background label is present
        return None

    # Pixel count per label; label 0 is the background, so skip it.
    areas = np.bincount(label.ravel(), minlength=num_labels)[1:]
    largest = 1 + int(np.argmax(areas))

    contours, _ = cv2.findContours(np.uint8(label == largest), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contour = contours[0]
    textline = cv2.boxPoints(cv2.minAreaRect(contour))
    return order_points(textline)

In [16]:
# For every non-background class, find its text-line box in the prediction map,
# scale it up to the original image and draw it. Note that the boxes are drawn
# onto `image` in place.
for i in range(1, len(classes)):
    # The union of all connected components is just the set of non-zero pixels.
    binary = pred[..., i].round().astype(np.uint8)
    mask = (binary != 0).astype(np.uint8)

    line = get_line(mask)
    if line is not None:
        # Scale (x, y) from prediction resolution to image resolution using the
        # actual shapes instead of a hard-coded 256, and round only after
        # scaling so no precision is lost before the upscale.
        scale = np.array([image.shape[1] / pred.shape[1], image.shape[0] / pred.shape[0]])
        # np.int0 was removed in NumPy 2.0; int32 is what OpenCV expects for contours.
        line = np.round(np.asarray(line) * scale).astype(np.int32)
        cv2.drawContours(image, [line], -1, (0, 255, 0), 2)

    cv2.imshow(f'image_{i}', mask * 255)
    cv2.waitKey()
    cv2.destroyAllWindows()

In [19]:
# Show `image` again: the earlier cv2.drawContours calls drew the detected
# boxes onto this same array in place.
cv2.imshow('image', image)
cv2.waitKey()
cv2.destroyAllWindows()