## 1. Dataloader

In [1]:
import cv2
import json
import numpy as np
import imgaug.augmenters as iaa

from pathlib import Path
from natsort import natsorted

from data.augmenter import Augmenter
from data.scene_text_dataset import SceneTextDataset

In [2]:
# Collect the annotation (.json) and image (.jpg) files of the training split,
# each list naturally sorted by file stem.
train_dir = Path('../dataset/focused_scene_text_2013/train/')
label_paths = natsorted(train_dir.glob('*.json'), key=lambda x: x.stem)
image_paths = natsorted(train_dir.glob('*.jpg'), key=lambda x: x.stem)

# Pair every image with the label sharing its stem. Looking the label up by stem
# (rather than zipping the two lists by position) keeps later pairs intact when an
# image or a label is missing, and guarantees the (image_path, label_path) order.
labels_by_stem = {label_path.stem: label_path for label_path in label_paths}
data_pairs = [
    (image_path, labels_by_stem[image_path.stem])
    for image_path in image_paths
    if image_path.stem in labels_by_stem
]

print(len(data_pairs))

229


### 1.1 Test Augmenter

In [4]:
def to_valid_poly(polygon, image_height, image_width):
    """Clamp every vertex of ``polygon`` so that it lies inside the image.

    Args:
        polygon: sequence of (x, y) vertices.
        image_height: height of the image in pixels.
        image_width: width of the image in pixels.

    Returns:
        The vertices as a nested list, with x limited to ``[0, image_width - 1]``
        and y limited to ``[0, image_height - 1]``.
    """
    vertices = np.array(polygon)  # np.array copies, so the caller's data is untouched
    # (column, largest valid coordinate): column 0 holds x, column 1 holds y.
    for column, upper in ((0, image_width - 1), (1, image_height - 1)):
        vertices[:, column] = np.clip(vertices[:, column], a_min=0, a_max=upper)
    return vertices.tolist()

def to_4points(points):
    """Expand a rectangle given by two corner points into its four corners.

    Args:
        points: sequence whose first two items are the (x, y) corners
            ``(x1, y1)`` and ``(x2, y2)``.

    Returns:
        ``[(x1, y1), (x2, y1), (x2, y2), (x1, y2)]``.
    """
    first_corner, second_corner = points[0], points[1]
    xs = (first_corner[0], second_corner[0])
    ys = (first_corner[1], second_corner[1])
    return [(xs[0], ys[0]), (xs[1], ys[0]), (xs[1], ys[1]), (xs[0], ys[1])]

# Smoke test for the augmenter: rotate the first training sample together with
# its label, draw the transformed annotations and save the result for inspection.
augmenter = Augmenter()

image_path = str(image_paths[0])
label_path = str(label_paths[0])

image = cv2.imread(image_path)
if image is None:
    # cv2.imread reports an unreadable or missing file by returning None rather
    # than raising, which would otherwise surface as an AttributeError below.
    raise FileNotFoundError(f'could not read image: {image_path}')

# Line width of the drawn outlines, derived from the largest image dimension.
thickness = max(image.shape) // 400

with open(file=label_path, mode='r', encoding='utf-8') as f:
    label = json.load(f)

# Random rotation in [-90, 90] degrees applied to the image and its label.
image, label = augmenter.apply(image=image, label=label, augmenter=iaa.Rotate(rotate=(-90, 90), fit_output=True))

for shape in label['shapes']:
    if shape['shape_type'] == 'rectangle':
        points = to_4points(shape['points'])
        points = to_valid_poly(points, image_height=image.shape[0], image_width=image.shape[1])
        # to_4points orders the corners (x1, y1), (x2, y1), (x2, y2), (x1, y2), so the
        # corner opposite to index 0 is index 2. Index 1 shares its y with index 0 and
        # would collapse the rectangle into a horizontal line.
        cv2.rectangle(
            img=image,
            pt1=(int(points[0][0]), int(points[0][1])),
            pt2=(int(points[2][0]), int(points[2][1])),
            color=(0, 255, 0),
            thickness=thickness
        )
    elif shape['shape_type'] == 'polygon':
        points = to_valid_poly(shape['points'], image_height=image.shape[0], image_width=image.shape[1])
        cv2.polylines(img=image, pts=[np.int32(points)], isClosed=True, color=(0, 255, 0), thickness=thickness)
    else:
        raise ValueError(f"visual function for {shape['shape_type']} is not implemented.")

cv2.imwrite('image.png', image)

True

### 1.2 Test Dataset

In [54]:
import torch

# Per-channel statistics handed to the dataset (presumably used for input
# normalisation; verify against SceneTextDataset). They are reused further down
# to undo the scaling before visualising a sample.
mean = [0, 0, 0]
std = [1, 1, 1]

# All constructor arguments gathered in one place, then unpacked into the call.
dataset_kwargs = {
    'dirnames': ['../dataset/totaltext/train/'],
    'imsize': 640,
    'mean': mean,
    'std': std,
    'shrink_ratio': 0.5,
    'image_extents': ['.jpg'],
    'label_extent': '.json',
    'transforms': None,
    'require_transforms': None,
}
dataset = SceneTextDataset(**dataset_kwargs)

train - 50


In [55]:
# Fetch the sample at index 1; the dataset item unpacks into an
# (image, mask, image_info) triple.
image, mask, image_info = dataset[1]

In [56]:
# Undo the normalisation on the device the sample already lives on. Moving only
# mean/std to 'cuda' would raise a device-mismatch error against a CPU sample.
device = image.device

mean = torch.tensor(mean, dtype=torch.float, device=device).view(3, 1, 1)
std = torch.tensor(std, dtype=torch.float, device=device).view(3, 1, 1)

# Back to the 0..255 range, then channel-first -> channel-last for OpenCV.
image = (image * std + mean) * 255
image = image.permute(1, 2, 0).contiguous()
# Round and clamp before the uint8 cast: a plain cast truncates (254.999 -> 254)
# and out-of-range values would wrap around. `.cpu()` makes `.numpy()` valid for
# tensors that are not on the CPU.
image = image.round().clamp(0, 255).to(torch.uint8).cpu().numpy()
# Swap the first and last colour channel (NOTE(review): presumably the dataset
# yields RGB and cv2.imwrite expects BGR; confirm against SceneTextDataset).
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [57]:
# The first two channels of the mask are treated as the text map and the kernel map.
text_map = mask[0]
kernel_map = mask[1]

# Replicate each map three times along a new last axis so that it can hold a
# colour per pixel.
text_mask = torch.stack((text_map, text_map, text_map), dim=2)
kernel_mask = torch.stack((kernel_map, kernel_map, kernel_map), dim=2)

# Distinct label values occurring in the text map.
text_id = torch.unique(text_map)

print(text_id.numpy().tolist())

[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]


In [58]:
# Paint every non-zero label of the text map with its own random colour, using
# the same colour for the matching label in the kernel map.
for i in text_id:
    if i == 0:
        continue  # label 0 is left unpainted

    # np.random.randint excludes the upper bound, so 256 is required for the
    # full 0..255 range of an 8-bit colour channel.
    color = np.random.randint(0, 256, size=3).tolist()

    text_mask[text_map == i] = torch.tensor(color, dtype=text_map.dtype, device=text_map.device)
    kernel_mask[kernel_map == i] = torch.tensor(color, dtype=kernel_map.dtype, device=kernel_map.device)

text_mask = text_mask.to(torch.uint8).numpy()
kernel_mask = kernel_mask.to(torch.uint8).numpy()

# Blend 40% colour overlay with 60% of the image
# (assumes `image` has the same height/width/3-channel layout as the masks).
text_mask = (0.4 * text_mask + 0.6 * image).astype(np.uint8)
kernel_mask = (0.4 * kernel_mask + 0.6 * image).astype(np.uint8)

In [60]:
# Save both overlays to the current working directory.
# cv2.imwrite returns a bool; the value of the last call is what the cell displays.
cv2.imwrite('text.png', text_mask)
cv2.imwrite('kernel.png', kernel_mask)

True

## 2. PANNet

In [1]:
import torch
from model.pan_net import PANNet

In [2]:
# Build PANNet on a 'resnet18' backbone without requesting pretrained weights,
# with num_FPEMs=2.
model = PANNet(backbone_name='resnet18', backbone_pretrained=False, num_FPEMs=2)

In [3]:
# Batch of 2 random 3x640x640 inputs. torch.FloatTensor(2, 3, 640, 640) only
# allocates memory and leaves it uninitialised (it may hold NaN/inf), so a
# properly initialised random tensor is used instead.
dummy_x = torch.randn(2, 3, 640, 640)

# Only the output shape is inspected, so no autograd graph is needed.
with torch.no_grad():
    x = model(dummy_x)



In [4]:
# Shape of the network output for the dummy batch.
x.shape

torch.Size([2, 6, 640, 640])