In [1]:
import numpy as np
from matplotlib import pyplot as plt
import cv2
import torch
from torch import nn
from datasets.data import PAN_CTW
from post_processing import pa
from models import model
import os
from configer import config



In [2]:
# Instantiate the dataset with the 'test' argument (presumably selecting the
# test split — confirm against datasets.data.PAN_CTW) and take its first item
# as the single sample used by the rest of the notebook.
data = PAN_CTW('test')
sample = data[0]
# Show which fields a sample exposes; later cells read 'imgs' and 'img_metas'.
print(sample.keys())

dict_keys(['imgs', 'img_metas'])


In [3]:
# Load the checkpoint onto the CPU and keep only its 'state_dict' entry.
# NOTE(review): torch.load unpickles the file, so only open checkpoints that
# come from a trusted source.
ckp = torch.load('./pan_r18_ctw.pth.tar', map_location='cpu')['state_dict']

# Parameter names in the checkpoint carry a 7-character prefix — presumably
# the 'module.' that nn.DataParallel adds (confirm for other checkpoints).
# Strip it only when it is actually present, so keys saved without the
# prefix are kept intact instead of losing their first 7 characters.
module_prefix = 'module.'
d = {
    (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
    for key, value in ckp.items()
}

In [4]:
# Build the PAN network from the backbone / neck / head settings held in `config`.
Net = model.PAN(config.backbone, config.neck_param, config.head_param)
# Load the renamed checkpoint weights from `d`; the key-matching report returned
# by load_state_dict is the last expression, so it is shown as the cell output.
Net.load_state_dict(d)

<All keys matched successfully>

In [5]:
# Switch the network to evaluation mode and run one forward pass on the sample
# with gradient tracking disabled.
Net.eval()
with torch.no_grad():
    # unsqueeze(0) adds a leading batch dimension of size 1 to the sample image.
    out = Net(sample['imgs'].unsqueeze(0))
    print(out.shape)

torch.Size([1, 6, 640, 992])


In [7]:
# NOTE(review): the default below is read from the notebook-level name
# `result_text_path` at the moment this cell runs; no visible cell assigns it,
# so it must be defined before this definition is executed.
def write_result_ctw(image_name, outputs, result_path=result_text_path):
    """Write the boxes of one image to ``<result_path>/<image_name>.txt``.

    Every box becomes one line of comma-separated integers. Before writing,
    the flat coordinate array is viewed as pairs and each pair is reversed.

    Args:
        image_name: Base name of the output file; ``.txt`` is appended.
        outputs: Mapping with a ``'bboxes'`` entry — an iterable of flat NumPy
            arrays, each holding an even number of coordinates.
        result_path: Directory that receives the text file. It is created
            when it does not exist yet.

    Returns:
        None. The file is (over)written as a side effect.
    """
    lines = []
    for bbox in outputs['bboxes']:
        # Reverse the order inside every coordinate pair, then flatten again.
        swapped = bbox.reshape(-1, 2)[:, ::-1].reshape(-1)
        if swapped.size == 0:
            # An empty box has nothing to write; skip it rather than fail.
            continue
        lines.append(','.join('%d' % int(v) for v in swapped) + '\n')

    # An empty result_path means "current directory": nothing to create then.
    if result_path:
        os.makedirs(result_path, exist_ok=True)

    file_path = os.path.join(result_path, '%s.txt' % image_name)
    with open(file_path, 'w') as f:
        f.writelines(lines)

In [8]:
# Turn the raw network output for the single sample into boxes and scores.
# NOTE(review): `min_area`, `min_score` and `bbox_type` are read below but no
# visible cell assigns them — they must be defined before this cell runs.
outputs = dict()
img_meta = sample['img_metas']

# Channel 0 feeds `score`; channels 0-1 thresholded at 0 give `kernels`;
# the remaining channels are the embedding `emb`.
score = torch.sigmoid(out[:, 0, :, :])
kernels = out[:, :2, :, :] > 0
text_mask = kernels[:, :1, :, :]
# Keep the second mask only where the first one is set.
kernels[:, 1:, :, :] = kernels[:, 1:, :, :] * text_mask
emb = out[:, 2:, :, :]
emb = emb * text_mask.float()

# Move batch element 0 to NumPy. detach() is applied to all three tensors so
# the conversion also works if `out` was computed with gradients enabled.
score = score.detach().cpu().numpy()[0].astype(np.float32)
kernels = kernels.detach().cpu().numpy()[0].astype(np.uint8)
emb = emb.detach().cpu().numpy()[0].astype(np.float32)

# pa() comes from post_processing; its result is used as an integer label map
# (presumably 0 = background and 1..K = instances — the loop starts at 1).
label = pa(kernels, emb)

# image size
org_img_size = img_meta['org_img_size']
img_size = img_meta['img_size']
img_path = img_meta['img_path']
img_name = img_meta['img_name']

label_num = np.max(label) + 1
# cv2.resize takes (width, height), hence index 1 before index 0.
target_size = (int(img_size[1]), int(img_size[0]))
label = cv2.resize(label, target_size, interpolation=cv2.INTER_NEAREST)
score = cv2.resize(score, target_size, interpolation=cv2.INTER_NEAREST)

# Per-axis factor (x, y) mapping resized coordinates back to the original image.
scale = (float(org_img_size[1]) / float(img_size[1]),
         float(org_img_size[0]) / float(img_size[0]))

bboxes = []
scores = []
for i in range(1, label_num):
    ind = label == i
    # (row, col) coordinates of every pixel carrying label i.
    points = np.array(np.where(ind)).transpose((1, 0))

    # Drop labels that have no pixels at all or fewer than min_area pixels.
    if points.shape[0] == 0 or points.shape[0] < min_area:
        label[ind] = 0
        continue

    # Drop labels whose mean score is below min_score.
    score_i = np.mean(score[ind])
    if score_i < min_score:
        label[ind] = 0
        continue

    if bbox_type == 'rect':
        # Columns are reversed so the points are passed as (x, y).
        rect = cv2.minAreaRect(points[:, ::-1])
        bbox = cv2.boxPoints(rect) * scale
    elif bbox_type == 'poly':
        binary = np.zeros(label.shape, dtype='uint8')
        binary[ind] = 1
        # findContours returns 2 or 3 values depending on the OpenCV version;
        # the contour list is always the second-to-last element.
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL,
                                    cv2.CHAIN_APPROX_SIMPLE)[-2]
        bbox = contours[0] * scale
    else:
        # Without this branch a stale `bbox` left over from an earlier
        # iteration or kernel run would be appended silently.
        raise ValueError("bbox_type must be 'rect' or 'poly', got %r" % (bbox_type,))

    bbox = bbox.astype('int32')
    bboxes.append(bbox.reshape(-1))
    scores.append(score_i)

outputs.update(dict(bboxes=bboxes, scores=scores))

In [10]:
# Save the text result, then draw the boxes on the original image and save it.
# NOTE(review): `result_img_path` is not assigned in any visible cell — it must
# be defined before this cell runs.
write_result_ctw(img_name, outputs)
ori_img = cv2.imread(img_path)
if ori_img is None:
    # cv2.imread reports an unreadable or missing file by returning None
    # instead of raising; fail here with a message that names the path.
    raise FileNotFoundError('cannot read image: %s' % img_path)
# polylines expects one (N, 2) point array per polygon.
boxes = [b.reshape(-1, 2) for b in outputs['bboxes']]
vis_img = cv2.polylines(ori_img, boxes, True, (0, 255, 255), 2)
# Last expression: imwrite's success flag is shown as the cell output.
cv2.imwrite(os.path.join(result_img_path, img_name), vis_img)

True