## 1. Dataloader

In [1]:
import cv2
import json
import numpy as np
import imgaug.augmenters as iaa

from pathlib import Path
from natsort import natsorted

from torch.utils.data import DataLoader

from dataloader.augmenter import Augmenter
from dataloader.PAN_dataset import PANDataset

### 1.1 Test Augmenter

In [3]:
# Collect (image, label) pairs from the focused_scene_text_2013 training directory.
train_dir = Path('../dataset/focused_scene_text_2013/train/')
image_paths = natsorted(train_dir.glob('*.jpg'), key=lambda x: x.stem)
label_paths = natsorted(train_dir.glob('*.json'), key=lambda x: x.stem)

# Pair by stem rather than by position: zipping two independently sorted lists
# misaligns every later pair as soon as one image lacks a label (or vice versa),
# and the stem check would then silently discard all of those pairs.
label_by_stem = {label_path.stem: label_path for label_path in label_paths}
data_pairs = [
    (image_path, label_by_stem[image_path.stem])
    for image_path in image_paths
    if image_path.stem in label_by_stem
]
print(f'Number of Data: {len(data_pairs)}')

Number of Data: 229


In [4]:
def to_valid_poly(polygon, image_height, image_width):
    """Clamp polygon vertices so they lie inside an image.

    Args:
        polygon: sequence of vertex rows; column 0 is x, column 1 is y
            (converted with ``np.array``; any extra columns are left untouched).
        image_height: image height in pixels; y is clamped to [0, image_height - 1].
        image_width: image width in pixels; x is clamped to [0, image_width - 1].

    Returns:
        A new nested list with the clamped coordinates.
    """
    vertices = np.array(polygon)
    # (lower, upper) limits for column 0 (x) and column 1 (y), in that order.
    limits = ((0, image_width - 1), (0, image_height - 1))
    for column, (lower, upper) in enumerate(limits):
        vertices[:, column] = np.clip(vertices[:, column], a_min=lower, a_max=upper)
    return vertices.tolist()

def to_4points(points):
    """Expand a two-corner rectangle into its four corner points.

    Args:
        points: sequence whose first two items are (x, y) corners; any further
            items are ignored.

    Returns:
        List of four (x, y) tuples: the first corner, (x of second, y of first),
        the second corner, (x of first, y of second). The corner opposite
        index 0 is therefore at index 2.
    """
    corner_a, corner_b = points[0], points[1]
    x_start, y_start = corner_a[0], corner_a[1]
    x_end, y_end = corner_b[0], corner_b[1]
    return [
        (x_start, y_start),
        (x_end, y_start),
        (x_end, y_end),
        (x_start, y_end),
    ]

# Load the first (image, label) pair, rotate both with Rot90 (k=3 quarter turns,
# canvas resized to fit), and draw the transformed shapes to check label alignment.
augmenter = Augmenter()

image_path, label_path = data_pairs[0]
image_path = str(image_path)
label_path = str(label_path)

image = cv2.imread(image_path)
if image is None:  # cv2.imread reports a missing/unreadable file by returning None
    raise FileNotFoundError(f'Could not read image: {image_path}')
# Line width scales with image size but never drops below 1 px
# (integer division alone yields 0 for images smaller than 400 px).
thickness = max(1, max(image.shape) // 400)

with open(file=label_path, mode='r', encoding='utf-8') as f:
    label = json.load(f)

image, label = augmenter.apply(image=image, label=label, augmenter=iaa.Rot90(k=3, keep_size=False))
image = np.ascontiguousarray(image)  # cv2 drawing functions need a contiguous buffer

image_height, image_width = image.shape[0], image.shape[1]
for shape in label['shapes']:
    if shape['shape_type'] == 'rectangle':
        # to_4points orders corners (x1, y1), (x2, y1), (x2, y2), (x1, y2), so the corner
        # opposite points[0] is points[2]; points[1] would give a zero-height box.
        points = to_4points(shape['points'])
        points = to_valid_poly(points, image_height=image_height, image_width=image_width)
        cv2.rectangle(
            img=image,
            pt1=(int(points[0][0]), int(points[0][1])),
            pt2=(int(points[2][0]), int(points[2][1])),
            color=(0, 255, 0),
            thickness=thickness
        )
    elif shape['shape_type'] == 'polygon':
        points = to_valid_poly(shape['points'], image_height=image_height, image_width=image_width)
        cv2.polylines(img=image, pts=[np.int32(points)], isClosed=True, color=(0, 255, 0), thickness=thickness)
    else:
        raise ValueError(f"visual function for {shape['shape_type']} is not implemented.")

cv2.imwrite('image.png', image)

True

### 1.2 Test Dataset

In [5]:
import torch

# Identity normalisation statistics; reused below when de-normalising samples for display.
mean = [0, 0, 0]
std = [1, 1, 1]

# Training-time augmentations handed to the dataset, in the order listed.
# (ChangeColorTemperature and Clouds were tried previously and are disabled.)
train_transforms = [
    iaa.Rot90(k=[0, 1, 2, 3], keep_size=False),
    iaa.Add(value=(-50, 50), per_channel=True),
    iaa.GaussianBlur(sigma=(0, 1)),
    iaa.MotionBlur(),
    iaa.Affine(rotate=(0, 10), shear=(-5, 5), fit_output=True),
    iaa.PerspectiveTransform(scale=(0, 0.1)),
]

dataset = PANDataset(
    dirnames=['../dataset/totaltext/train/'],
    imsize=640,
    mean=mean,
    std=std,
    shrink_ratio=0.5,
    image_extents=['.jpg'],
    label_extent='.json',
    transforms=train_transforms,
    require_transforms=None,
)

train - 50


In [6]:
from typing import List, Tuple

def tensor2image(
    sample: torch.Tensor,
    mean: List[float] = [0, 0, 0],
    std: List[float] = [1, 1, 1],
    image_size: Tuple[int, int] = None,
):
    """Convert a normalised 3xHxW tensor back into a uint8 HxWx3 numpy image.

    Args:
        sample: tensor of shape (3, H, W); ``.view(3, 1, 1)`` below fixes C = 3.
            It may live on any device and may carry autograd history.
        mean: per-channel mean used to undo normalisation.
        std: per-channel std used to undo normalisation. The inverse applied is
            ``(sample * std + mean) * 255``. NOTE(review): this assumes the dataset
            normalised as ``(x / 255 - mean) / std``; confirm against PANDataset.
        image_size: forwarded to ``cv2.resize`` as ``dsize`` (OpenCV order is
            (width, height)). ``None`` keeps the tensor's spatial size.

    Returns:
        numpy uint8 array (H', W', 3) with the channel order reversed.
    """
    # `.numpy()` only works on CPU tensors without autograd history.
    sample = sample.detach().cpu()
    mean = torch.tensor(mean, dtype=torch.float, device=sample.device).view(3, 1, 1)
    std = torch.tensor(std, dtype=torch.float, device=sample.device).view(3, 1, 1)

    sample = (sample * std + mean) * 255  # denormalize
    # Clamp before the uint8 cast: the float -> uint8 conversion does not saturate
    # out-of-range values. Round so e.g. 254.99 maps to 255 rather than 254.
    sample = sample.clamp(0, 255).round()
    sample = sample.permute(1, 2, 0).contiguous()  # C x H x W -> H x W x C
    image = sample.to(torch.uint8).numpy()  # tensor, float32 -> numpy, uint8
    # Reverse channel order (same swap as cv2.COLOR_BGR2RGB for 3 channels).
    # NOTE(review): the result is written with cv2.imwrite, which expects BGR,
    # so the dataset presumably feeds RGB; confirm.
    image = np.ascontiguousarray(image[..., ::-1])
    if image_size is not None:
        image = cv2.resize(image, dsize=image_size)

    return image


def segmap2segmask(masks: torch.Tensor, image_size: Tuple[int, int] = None):
    """Colourise the text and kernel instance-id maps for visualisation.

    Args:
        masks: tensor whose entries ``masks[0]`` (text map) and ``masks[1]``
            (kernel map) are 2-D id maps; 0 is treated as background.
        image_size: forwarded to ``cv2.resize`` as ``dsize`` (OpenCV order is
            (width, height)). ``None`` keeps the maps' spatial size.

    Returns:
        (text_mask, kernel_mask): uint8 arrays (H', W', 3). Each non-zero id gets
        one random colour, shared between both maps; background stays black.
    """
    # `.numpy()` below only works on CPU tensors without autograd history.
    text_map, kernel_map = masks[0].detach().cpu(), masks[1].detach().cpu()
    text_mask = torch.zeros((*text_map.shape, 3), dtype=torch.uint8)
    kernel_mask = torch.zeros((*kernel_map.shape, 3), dtype=torch.uint8)

    # Colour every id found in either map, so no raw id value leaks through;
    # one colour per id keeps an instance's text and kernel visually matched.
    instance_ids = torch.unique(torch.cat([text_map.flatten(), kernel_map.flatten()]))
    for instance_id in instance_ids:
        if instance_id == 0:  # background
            continue
        color = torch.tensor(np.random.randint(0, 256, size=3), dtype=torch.uint8)
        text_mask[text_map == instance_id] = color
        kernel_mask[kernel_map == instance_id] = color

    text_mask = text_mask.numpy()
    kernel_mask = kernel_mask.numpy()

    if image_size is not None:
        # Nearest-neighbour keeps label colours crisp; the default bilinear filter
        # blends neighbouring instance colours into meaningless new ones at edges.
        text_mask = cv2.resize(text_mask, dsize=image_size, interpolation=cv2.INTER_NEAREST)
        kernel_mask = cv2.resize(kernel_mask, dsize=image_size, interpolation=cv2.INTER_NEAREST)

    return text_mask, kernel_mask

In [7]:
# Visualise one dataset item: de-normalise the image, colourise its text/kernel maps,
# then alpha-blend each coloured map over the image (0.4 mask, 0.6 image).
sample, masks, image_info = dataset[1]
image_path, image_size = image_info

image = tensor2image(sample=sample, mean=mean, std=std, image_size=image_size)
text_mask, kernel_mask = segmap2segmask(masks=masks, image_size=image_size)

mask_weight, image_weight = 0.4, 0.6
text_mask, kernel_mask = (
    (mask_weight * overlay + image_weight * image).astype(np.uint8)
    for overlay in (text_mask, kernel_mask)
)

In [8]:
# Save both blended overlays into the current working directory, named after the
# source image's stem. The cell's displayed value is the last imwrite's success flag.
cv2.imwrite(f'{Path(image_path).stem}_text.png', text_mask)
cv2.imwrite(f'{Path(image_path).stem}_kernel.png', kernel_mask)

True

### 1.3 Test Dataloader

In [9]:
# Batches of two in index order (shuffle=False); per-item augmentation still happens inside the dataset.
data_loader = DataLoader(dataset, batch_size=2, shuffle=False)

In [10]:
# Explicit iterator over the loader so the next cell can pull a single batch from it.
data_iter = iter(data_loader)

In [11]:
# Pull one batch. Use the builtin `next()`: the `.next()` method alias on DataLoader
# iterators was removed in newer PyTorch releases (it raises AttributeError there).
samples, masks, image_infos = next(data_iter)
image_paths, image_sizes = image_infos

In [12]:
# Blend each sample's coloured text/kernel maps over its de-normalised image and save both.
for sample, mask, image_path, image_size in zip(samples, masks, image_paths, image_sizes):
    # cv2.resize wants plain Python ints for dsize; `tuple(tensor.numpy())` yields numpy
    # integer scalars, which some OpenCV builds reject. Convert once and reuse.
    # NOTE(review): assumes default_collate gave one size row per sample in `image_sizes`.
    # If PANDataset returns sizes as a Python tuple, collate transposes them instead and
    # this zip pairs samples with the wrong values; confirm against the dataset.
    dsize = tuple(int(value) for value in image_size.tolist())
    image = tensor2image(sample=sample, mean=mean, std=std, image_size=dsize)
    text_mask, kernel_mask = segmap2segmask(masks=mask, image_size=dsize)

    text_mask = (0.4 * text_mask + 0.6 * image).astype(np.uint8)
    kernel_mask = (0.4 * kernel_mask + 0.6 * image).astype(np.uint8)

    cv2.imwrite(f'{Path(image_path).stem}_text.png', text_mask)
    cv2.imwrite(f'{Path(image_path).stem}_kernel.png', kernel_mask)

## 2. PANNet

In [13]:
import torch
from model.PANNet.PAN_net import PANNet

In [14]:
# Randomly initialised model (no pretrained backbone weights) for shape/loss smoke tests.
model = PANNet(
    backbone_name='resnet18',
    backbone_pretrained=False,
    num_FPEMs=2,
)

In [15]:
# Count only parameters that will be updated during training (requires_grad=True).
trainable_params = (p for p in model.parameters() if p.requires_grad)
params = sum(p.numel() for p in trainable_params)
print(f'The number of parameters: {params}')

The number of parameters: 11520582


In [13]:
# Forward pass on the batch fetched in the dataloader section. There is no torch.no_grad()
# here, so autograd history is recorded (the loss outputs further down carry grad_fn).
# NOTE(review): this cell's execution count is lower than the preceding cells', so it was
# run out of order; re-check the notebook with Restart Kernel -> Run All.
preds = model(samples)
print(preds.shape)



torch.Size([2, 6, 640, 640])


## 3. PAN Loss

In [14]:
import torch
from loss.PAN_loss import PANLoss

In [15]:
# PANLoss criterion built with its default arguments; it is called in the next cell.
# NOTE(review): `loss` holds the criterion object, not a loss value; `criterion` would read clearer.
loss = PANLoss()

In [16]:
# Evaluate the criterion on the model predictions and the batch masks (`masks` is the batch
# pulled from the dataloader iterator). The recorded output is a 5-tuple of scalar tensors;
# the first (grad_fn=AddBackward0) is presumably the combined total. TODO confirm the
# meaning of each component against PANLoss.
loss(preds=preds, targets=masks)

(tensor(1.7438, grad_fn=<AddBackward0>),
 tensor(0.1975, grad_fn=<MeanBackward0>),
 tensor(0.7776, grad_fn=<MeanBackward0>),
 tensor(1., grad_fn=<MeanBackward0>),
 tensor(1., grad_fn=<MeanBackward0>))