# Input

In [62]:
class CardInfo:
    """Plain attribute container for one card travelling through the pipeline.

    Every constructor argument is optional (defaults to ``None``) and is
    stored unchanged on the instance under the same name. No validation is
    performed here.

    Attributes:
        idx: caller-defined identifier (presumably the card index - confirm).
        original_image: presumably the untouched source image - confirm.
        image: image read by the processing functions (``card_info.image``).
        warped_size: not used in this notebook; meaning to be confirmed.
        card_type: annotated as ``str``.
        angle: annotated as ``int``.
        textlines: set by ``spostprocess`` to a dict keyed by class name.
    """

    def __init__(self, idx=None, original_image=None, image=None,
                 warped_size=None, card_type: str = None, angle: int = None,
                 textlines=None):
        self.idx, self.original_image, self.image = idx, original_image, image
        self.warped_size, self.card_type, self.angle = warped_size, card_type, angle
        self.textlines = textlines

In [63]:
import os

import cv2
import numpy as np

# Load the sample card image that feeds the rest of the notebook.
image = cv2.imread('test_images/input/extracted_card.jpg')
# cv2.imread does not raise on a missing/unreadable file, it returns None;
# fail here with a clear message instead of deep inside cv2.resize later.
if image is None:
    raise FileNotFoundError("Could not read 'test_images/input/extracted_card.jpg'")

card_info = CardInfo()
card_info.image = image
# Was `orginal_image` (typo), which silently created a new attribute and
# left the real `original_image` field set to None.
card_info.original_image = image

# The pipeline functions operate on a list of cards; here it holds one.
card_infos = [card_info]

# Processor

In [64]:
import torch
from fcn import FCN
import time

# Number of classes the model is built with (passed to FCN as `num_classes`).
num_class = 12

# Checkpoint to load. The default is the original author's machine-specific
# absolute path; set the FIELD_EXTRACTION_WEIGHT_PATH environment variable to
# point at the file on any other machine without editing the notebook.
weight_path = os.environ.get(
    'FIELD_EXTRACTION_WEIGHT_PATH',
    '/home/vinhloiit/Documents/VTCC/id_info_extraction/models/weights/field_extraction/cmnd/2011091606/best_model_40_loss=-0.07206277665637788.pth',
)

# Size passed to cv2.resize for every card image before it is fed to the model.
image_size = (256, 256)

In [66]:
# Build the network with the configuration that was used while training.
model = FCN(
    num_classes=num_class,
    backbone="resnet50",
    pretrained_backbone=False,
    replace_stride_with_dilation=[True, True, True],
)

# Time reading the checkpoint from disk plus copying it into the model.
t1 = time.time()
model.load_state_dict(
    torch.load(weight_path, map_location='cpu')  # load the weights onto the CPU first
)
t2 = time.time()

# Pick the device: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()  # switch the module to evaluation mode

print(f'Load weight: {t2 - t1}s')

Load weight: 0.0682060718536377s


## 1. preprocess

In [67]:
def preprocess(card_infos):
    """Turn the card images into one standardised float batch for the model.

    Each ``card_info.image`` is resized to the module-level ``image_size``,
    the results are stacked, moved to ``device``, cast to float, reordered
    from (B, H, W, C) to (B, C, H, W) and standardised per image (mean and
    standard deviation taken over all values of that image).

    Args:
        card_infos: sequence of objects exposing an ``image`` attribute that
            ``cv2.resize`` accepts.

    Returns:
        Tuple ``(card_infos, samples)``: the input sequence unchanged and the
        normalised batch tensor in (B, C, H, W) layout.
    """
    resized = [cv2.resize(card_info.image, image_size) for card_info in card_infos]
    samples = torch.from_numpy(np.array(resized)).to(device).to(torch.float)
    samples = samples.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)

    mean = samples.mean(dim=(1, 2, 3), keepdim=True)
    # A uniformly coloured image has a standard deviation of exactly 0, which
    # used to turn the whole sample into NaN. Clamping leaves every normal
    # image untouched and maps the degenerate one to all zeros.
    std = samples.std(dim=(1, 2, 3), keepdim=True).clamp_min(1e-8)
    return card_infos, (samples - mean) / std

## 2. process

In [68]:
def process(card_infos, samples):
    """Run the module-level ``model`` on ``samples`` with gradients disabled.

    Returns:
        Tuple ``(card_infos, preds)`` where ``preds`` is whatever the model
        returns for the batch; ``card_infos`` is passed through unchanged.
    """
    with torch.no_grad():
        preds = model(samples)
    return card_infos, preds

## 3. postprocess

In [69]:
def postprocess(card_infos, preds):
    """Convert the model output to a channels-last NumPy array.

    Args:
        card_infos: passed through unchanged.
        preds: 4-D tensor in (B, C, H, W) layout.

    Returns:
        Tuple ``(card_infos, preds)`` with ``preds`` as a NumPy array in
        (B, H, W, C) layout living in host memory.
    """
    channels_last = preds.permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H, W, C)
    as_numpy = channels_last.detach().cpu().numpy()
    return card_infos, as_numpy

# Stage

## 1. preprocess

In [70]:
def spreprocess(card_infos):
    """Sanity-check the card images before the stage runs.

    When assertions are enabled, every ``card_info.image`` must be an
    ``ndarray`` with three dimensions whose last dimension has size 3.

    Returns:
        A 1-tuple ``(card_infos,)`` so the call site can unpack it.

    Raises:
        AssertionError: if any image fails one of the checks.
    """
    if __debug__:
        for i, card_info in enumerate(card_infos):
            img = card_info.image
            # Compare the type by name so this check needs no numpy import.
            assert type(img).__name__ == 'ndarray', f'Image #{i} must be an ndarray.'
            assert img.ndim == 3, f'Image #{i} must be a 3D ndarray.'
            assert img.shape[-1] == 3, f'Image #{i} must have 3 channels.'

    return (card_infos,)

## 2. process

In [71]:
def sprocess(card_infos):
    """Run the model step end to end: preprocess -> process -> postprocess.

    Each step returns a pair whose first element is ``card_infos``; the pair
    is fed straight into the next step.

    Returns:
        Tuple ``(card_infos, preds)`` as produced by ``postprocess``.
    """
    batch = preprocess(card_infos)
    raw = process(*batch)
    card_infos, preds = postprocess(*raw)
    return card_infos, preds

In [72]:
# Segmentation classes keyed by name. Each value is a 3-item list:
#   [0] channel index of the class in the prediction (used as pred[..., index]),
#   [1] colour handed to cv2.drawContours when the class is drawn
#       (presumably BGR, since the images come from cv2.imread - confirm),
#   [2] whether find_textlines passes the detected box through expand_height.
# NOTE: 'BG' is listed last although its channel index is 0, so an entry's
# position in this dict is NOT its channel index.
classes = {
    'HEADING': [1, [175, 153, 144], True],
    'V_ID': [2, [75, 25, 230], True],
    'V_NAME1': [3, [128, 0, 0], True],
    'V_NAME2': [4, [48, 130, 245], True],
    'V_BD': [5, [128, 128, 0], True],
    'V_BP1': [6, [25, 225, 225], True],
    'V_BP2': [7, [75, 180, 60], True],
    'V_A1': [8, [180, 215, 255], True],
    'V_A2': [9, [240, 240, 70], True],
    'LOGO': [10, [255, 190, 230], False],
    'FIGURE': [11, [255, 255, 255], False],
    'BG': [0, [0, 255, 0], False],
}

In [11]:
# Debug view: show the binary mask predicted for every class channel.
card_infos, preds = sprocess(card_infos)
pred = preds[0]  # first image of the batch, shape (H, W, C)

# Map channel index -> class name. The position of a key in `classes` is not
# its channel index ('BG' is channel 0 but listed last), so titling windows
# with list(classes.keys())[i] labelled every channel with the wrong name.
index_to_name = {index: name for name, (index, _color, _expand) in classes.items()}

for i in range(len(classes)):
    binary = pred[..., i].round().astype(np.uint8)
    # Every non-zero pixel is foreground. This equals the union of all
    # connected-component labels >= 1 that the previous per-label loop
    # accumulated one component at a time.
    mask = (binary != 0).astype(np.uint8)

    cv2.imshow(index_to_name[i], mask * 255)
    cv2.waitKey()
    cv2.destroyAllWindows()

## 3. postprocess

In [74]:
from scipy.spatial import distance

def expand_height(points):
    """Grow a four-corner box across its shorter side.

    The edge ``points[0] -> points[1]`` is compared with the edge
    ``points[0] -> points[3]``. The two edges running along the shorter
    direction are extended by half their length at both ends, so the short
    side of the box doubles while the long side keeps its length.

    Args:
        points: indexable collection of four 2-D points supporting array
            arithmetic (e.g. NumPy arrays). It is modified in place.

    Returns:
        The same ``points`` object with its four entries replaced.
    """
    first_edge = distance.euclidean(points[0], points[1])
    second_edge = distance.euclidean(points[0], points[3])

    # (near, far) index pairs of the two edges that get stretched.
    if first_edge > second_edge:
        edges = ((0, 3), (1, 2))
    else:
        edges = ((0, 1), (3, 2))

    # Step 1: push each near corner outwards by half of the edge.
    for near, far in edges:
        points[near] = points[near] - 0.50 * (points[far] - points[near])
    # Step 2: the edge is now 1.5x long measured from the moved near corner;
    # scaling it by 4/3 puts the far corner half an edge beyond its old spot.
    for near, far in edges:
        points[far] = points[near] + 4 / 3 * (points[far] - points[near])
    return points

In [75]:
def order_points(points):
    """Order four (x, y) points as [top-left, top-right, bottom-right, bottom-left].

    The two points with the smallest x form the left pair and the other two
    the right pair; within each pair the point with the smaller y is "top".

    Raises:
        AssertionError: if ``points`` does not contain exactly four items.
    """
    assert len(points) == 4, 'Length of points must be 4'
    by_x = sorted(points, key=lambda pt: pt[0])
    left_pair, right_pair = by_x[:2], by_x[2:]
    tl, bl = sorted(left_pair, key=lambda pt: pt[1])
    tr, br = sorted(right_pair, key=lambda pt: pt[1])
    return [tl, tr, br, bl]

In [76]:
def get_line(mask):
    """Return one ordered four-corner box per connected component of ``mask``.

    The mask is rounded and cast to ``uint8``, split into connected
    components, and for each component (label 0, the background, is skipped)
    the minimum-area rotated rectangle around its first external contour is
    computed.

    Returns:
        List of boxes, each ``[tl, tr, br, bl]`` as produced by ``order_points``.
    """
    binary = mask.round().astype(np.uint8)
    num_labels, label = cv2.connectedComponents(binary)

    def _component_box(component_id):
        # Isolate this component and fit the tightest rotated rectangle.
        component = np.uint8(label == component_id)
        contours, _ = cv2.findContours(component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        corners = cv2.boxPoints(cv2.minAreaRect(contours[0]))
        return order_points(corners)

    return [_component_box(component_id) for component_id in range(1, num_labels)]

In [77]:
def find_textlines(image, pred):
    """Locate the text-line box of every non-background class and draw it.

    For each entry of ``classes`` except ``'BG'`` the matching prediction
    channel is binarised by rounding, its connected components are boxed with
    ``get_line``, optionally grown with ``expand_height``, and rescaled from
    prediction coordinates to ``image`` coordinates.

    Args:
        image: H x W x C image the boxes are mapped onto and drawn over.
        pred: prediction for that image, shape (H_pred, W_pred, n_classes).

    Returns:
        Tuple ``(textlines, output_image)``: a dict mapping class name to a
        4 x 2 array of corner points in image coordinates, and a copy of
        ``image`` with every detected box drawn in the class colour.

    Note:
        Only one box is kept per class. When a channel holds several
        components, each is drawn but the last one overwrites the earlier
        ones in ``textlines``.
    """
    image_h, image_w = image.shape[:2]
    # Scale by the prediction's own size instead of a hard-coded 256, so the
    # mapping stays right if the model resolution changes.
    pred_h, pred_w = pred.shape[:2]
    output_image = image.copy()

    textlines = {}
    for class_name, (i, color, expand) in classes.items():
        if class_name == 'BG':
            continue

        mask = pred[..., i].round().astype(np.uint8)  # binary mask of class i

        for line in get_line(mask):
            if expand:
                line = expand_height(line)

            # Prediction coordinates -> original image coordinates.
            line = np.array([[x * image_w // pred_w, y * image_h // pred_h] for x, y in line])
            textlines[class_name] = line
            cv2.drawContours(output_image, [np.int32(line)], -1, color, 2)
    return textlines, output_image

In [78]:
def get_warped_images(image, pts):
    """Cut the quadrilateral ``pts`` out of ``image`` and rectify it.

    Args:
        image: source image accepted by ``cv2.warpPerspective``.
        pts: four (x, y) corner points, in any order.

    Returns:
        The perspective-warped crop. Its width/height are the integer parts
        of the longer of each pair of opposite edges, and at least 1 pixel.
    """
    rect = order_points(pts)
    tl, tr, br, bl = rect

    def _edge_length(a, b):
        # Euclidean distance between two corner points.
        return np.sqrt(((a[0] - b[0]) ** 2) + ((a[1] - b[1]) ** 2))

    max_width = max(_edge_length(br, bl), _edge_length(tr, tl))
    max_height = max(_edge_length(tr, br), _edge_length(tl, bl))

    # Destination corners, in the same tl, tr, br, bl order as `rect`.
    dst = np.array([
        [0, 0],
        [max_width - 1, 0],
        [max_width - 1, max_height - 1],
        [0, max_height - 1]], dtype="float32")

    rect = np.array(rect, dtype="float32")
    M = cv2.getPerspectiveTransform(rect, dst)

    # `dsize` must be one (width, height) tuple. Passing the two ints as
    # separate positionals put the width in `dsize` and the height in `dst`.
    # Clamp to 1 so a degenerate box cannot request a zero-sized output.
    dsize = (max(int(max_width), 1), max(int(max_height), 1))
    warped = cv2.warpPerspective(image, M, dsize)

    return warped

In [79]:
def cut_textlines(image, pred):
    """Find the text-line boxes of ``image`` and crop each one out of it.

    Returns:
        Tuple ``(textlines, output_image)``: a dict mapping class name to the
        rectified crop returned by ``get_warped_images``, and the annotated
        image returned by ``find_textlines``.
    """
    boxes, output_image = find_textlines(image, pred)
    crops = {
        class_name: get_warped_images(image, box)
        for class_name, box in boxes.items()
    }
    return crops, output_image

In [80]:
def spostprocess(card_infos, preds, show=True):
    """Crop the text lines of every card and store them on the card.

    Args:
        card_infos: sequence of cards; ``card_info.image`` is read and
            ``card_info.textlines`` is overwritten.
        preds: per-card predictions, iterated in step with ``card_infos``.
        show: when true (the default, matching the previous behaviour) the
            annotated image of each card is shown in an OpenCV window and
            the call blocks until a key is pressed. Pass ``False`` for
            headless or batch runs.

    Returns:
        A 1-tuple ``(card_infos,)`` so the call site can unpack it.
    """
    for card_info, pred in zip(card_infos, preds):
        textlines, textline_image = cut_textlines(card_info.image, pred)
        card_info.textlines = textlines
        if show:
            cv2.imshow('textline image', textline_image)
            cv2.waitKey()
            cv2.destroyAllWindows()
    return (card_infos,)

# TEST

In [81]:
# Run the three stage steps in order. The first and last return a 1-tuple,
# hence the single-element unpacking.
(card_infos,) = spreprocess(card_infos)
card_infos, preds = sprocess(card_infos)
(card_infos,) = spostprocess(card_infos, preds)

In [82]:
# Show every cropped text line in a window titled with its class name;
# each window waits for a key press before the next one opens.
for card_info in card_infos:
    for class_name in card_info.textlines:
        textline = card_info.textlines[class_name]
        cv2.imshow(class_name, textline)
        cv2.waitKey()
        cv2.destroyAllWindows()