In [None]:
import os, cv2
import numpy as np
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import torch
from matplotlib import pyplot as plt
from mmcv.parallel import collate, scatter

from mmdet.datasets.pipelines import Compose
from mmdet.core import bbox2result, bbox_mask2result
from mmdet.datasets.coco_polar import Coco_Seg_Dataset as DATASET
from mmdet.apis.inference import init_detector, inference_detector, show_result, LoadImage

In [None]:
# Drawing colour per label id (ids 0..79); every class is drawn in the same green.
COLORS = {label_id: (0, 255, 0) for label_id in range(80)}

def get_data(img, cfg, device):
    """Preprocess a single image for the detector.

    The first stage of ``cfg.test_pipeline`` is replaced by ``LoadImage()``; the
    remaining stages are kept. The pipeline is applied to ``dict(img=img)``, the
    sample is collated as a batch of one and scattered to ``device``.

    Returns:
        The data dict for that single device (this notebook later reads its
        ``'img'`` and ``'img_meta'`` entries).
    """
    stages = [LoadImage()]
    stages.extend(cfg.test_pipeline[1:])
    pipeline = Compose(stages)

    sample = pipeline(dict(img=img))
    batch = collate([sample], samples_per_gpu=1)
    per_device = scatter(batch, [device])
    return per_device[0]

def draw_semseg(image, mask, color=(0, 0, 255)):
    """Overlay a binary mask on a copy of ``image``.

    Pixels where ``mask == 1`` get ``color`` added to them with ``cv2.add``;
    all other pixels are left as they are. ``image`` itself is not modified.

    Args:
        image: H x W x 3 image (the 3-tuple colour is assigned along the last axis).
        mask: H x W array matching ``image``'s first two dimensions; only
            entries equal to 1 are highlighted.
        color: 3-tuple added to the masked pixels, in the same channel order as
            ``image``. Defaults to ``(0, 0, 255)``, the previously hard-coded
            value (blue for the RGB images displayed in this notebook).

    Returns:
        The overlaid copy of ``image``.
    """
    # Copy first: keeps the caller's array untouched and hands cv2 a contiguous
    # buffer even if ``image`` is a strided view such as ``img[..., ::-1]``.
    canvas = image.copy()
    overlay = np.zeros_like(canvas)
    overlay[mask == 1, ...] = color
    return cv2.add(canvas, overlay)

def draw_bboxes(image, bboxes, labels, scores=None, thick=2):
    """Draw captioned bounding boxes on a copy of ``image``.

    Args:
        image: image array to draw on; it is copied, never modified in place.
        bboxes: iterable of boxes, each unpacking into exactly four values
            ``(left, top, right, bottom)`` (converted to ``int``).
        labels: class indices used to look up ``COLORS`` and ``DATASET.CLASSES``.
        scores: optional per-box scores. When given, each caption is
            ``"<class>_<score:.2f>"``; otherwise it is just ``"<class>"``.
        thick: rectangle line thickness in pixels.

    Returns:
        The annotated copy of ``image``.
    """
    from itertools import repeat

    font, font_scale, font_thick = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2
    canvas = image.copy()

    # Pair every box with None when no scores are given so a single loop covers
    # both cases; zip still stops at the shortest input, as before.
    score_iter = scores if scores is not None else repeat(None)

    for (left, top, right, bottom), label, score in zip(bboxes, labels, score_iter):
        left, top, right, bottom = int(left), int(top), int(right), int(bottom)
        color = COLORS[label]
        name = DATASET.CLASSES[label]
        text = name if score is None else "%s_%.2f" % (name, score)

        cv2.rectangle(canvas, (left, top), (right, bottom), color, thick)

        # putText anchors at the text's bottom-left corner. The (-10, -10) offset
        # used to push captions of boxes near the top/left border off the image;
        # clamp the anchor so the glyphs stay visible.
        (_, text_h), _ = cv2.getTextSize(text, font, font_scale, font_thick)
        origin = (max(left - 10, 0), max(top - 10, text_h))
        cv2.putText(canvas, text, origin, font, font_scale, color, font_thick)

    return canvas

In [None]:
# Paths. The config is relative to the notebook's working directory. The
# checkpoint, image and output live under a per-user workspace root, which can be
# overridden with the POLARMASK_WORKSPACE environment variable instead of editing
# the notebook (the default reproduces the original absolute paths).
WORKSPACE = os.environ.get("POLARMASK_WORKSPACE", "/home/member/Workspace")
config = "../ccdetection/configs/polarmask/polar_b1_semseg.py"
checkpoint = os.path.join(WORKSPACE, "thuync/checkpoints/polar_b1_semseg/epoch_10.pth")
# Alternative test image from COCO val2017:
# img_file = os.path.join(WORKSPACE, "dataset/coco/images/val2017/000000397133.jpg")
img_file = os.path.join(WORKSPACE, "thuync/ccdetpose/mmdetection/demo/demo.jpg")
out_file = os.path.join(WORKSPACE, "thuync/checkpoints/polar_b1_semseg/debug.png")

# Build the model on CUDA (the visible GPU is selected via CUDA_VISIBLE_DEVICES
# in the import cell). Then preprocess the image onto the device holding the
# model's parameters.
model = init_detector(config, checkpoint=checkpoint, device='cuda')
data = get_data(img_file, model.cfg, next(model.parameters()).device)

In [None]:
threshold = 0.3  # minimum detection score kept for visualisation

with torch.no_grad():
    # Prepare data: `data` holds one image scattered to one device (see get_data).
    img = data['img'][0]
    img_meta = data['img_meta'][0]
    print(img_meta[0].keys())
    ori_shape = img_meta[0]['ori_shape']
    img_h, img_w, _ = ori_shape

    # Instance branch: features -> raw head outputs -> decoded detections.
    # (img_meta, test_cfg, True) are appended as extra positional args of
    # get_bboxes. NOTE(review): True is presumably `rescale` (results mapped back
    # to ori_shape); confirm against the head's signature.
    x = model.extract_feat(img)
    bbox_outs = model.bbox_head(x)
    bbox_pred = bbox_outs + (img_meta, model.test_cfg, True)
    bboxes, labels, masks = model.bbox_head.get_bboxes(*bbox_pred)[0]

    # Filter: keep rows whose last bbox column (presumably the score) exceeds
    # `threshold`, then drop that column, leaving (left, top, right, bottom).
    bboxes = bboxes.cpu().numpy()
    labels = labels.cpu().numpy()
    indicator = (bboxes[:, -1] > threshold)
    bboxes = bboxes[indicator][:, :-1]
    labels = labels[indicator]

    # Transform instance masks: rasterise only the kept polygons into filled
    # binary (img_h, img_w) maps. Filtering first avoids drawing discarded
    # detections. The explicit empty array avoids np.stack([]) raising when no
    # detection was produced at all.
    _masks = []
    for i in np.flatnonzero(indicator):
        im_mask = np.zeros((img_h, img_w), dtype=np.uint8)
        # NOTE(review): assumes masks[i] is a (2, N) tensor of point coordinates;
        # it is reshaped to the (N, 1, 2) int contour layout cv2.drawContours takes.
        contour = masks[int(i)].transpose(1, 0).unsqueeze(1).int().data.cpu().numpy()
        im_mask = cv2.drawContours(im_mask, [contour], -1, 1, -1)  # thickness -1: fill
        _masks.append(im_mask)
    if _masks:
        masks = np.stack(_masks)
    else:
        masks = np.zeros((0, img_h, img_w), dtype=np.uint8)

    # Semantic branch: segmentation head output thresholded at 0.5; the first
    # image / first channel is taken as the foreground map.
    mask_pred = model.semseg_head(x)
    mask_pred = model.semseg_head.get_seg_masks(mask_pred, ori_shape, scale_factor=1.0, rescale=False, threshold=0.5)
    foreground = mask_pred[0, 0].astype('uint8')

print("bboxes:", bboxes.shape)
print("labels:", labels.shape)
print("masks:", masks.shape)
print("foregrounds:", foreground.shape)

In [None]:
# Load the original image for display. cv2 reads BGR, so reverse the channels to
# RGB for matplotlib. cv2.imread signals failure by returning None rather than
# raising, so check explicitly.
image = cv2.imread(img_file)
if image is None:
    raise FileNotFoundError("could not read image %r" % img_file)
image = image[..., ::-1]

# Union of all kept instance masks, then bring both binary maps to image size.
# cv2.resize takes (width, height). Nearest-neighbour is used because these are
# label maps, not intensities.
target_size = (image.shape[1], image.shape[0])
_masks = masks.sum(axis=0).clip(0, 1).astype('uint8')
_masks = cv2.resize(_masks, target_size, interpolation=cv2.INTER_NEAREST)
foreground = cv2.resize(foreground, target_size, interpolation=cv2.INTER_NEAREST)

image_bbox = draw_bboxes(image, bboxes, labels)
image_sem = draw_semseg(image_bbox, foreground)
image_ins = draw_semseg(image_bbox, _masks)

fig, (ax_sem, ax_ins) = plt.subplots(2, 1, figsize=(20, 20))
ax_sem.imshow(image_sem)
ax_sem.set_title("Detections + semantic foreground")
ax_ins.imshow(image_ins)
ax_ins.set_title("Detections + instance masks")
for ax in (ax_sem, ax_ins):
    ax.axis('off')
plt.show()

In [None]:
# Where does the head fire? For every FPN level, mark locations where some class
# has sigmoid(cls_score) * sigmoid(centerness) > sc_thres. Each level's map is
# upsampled to image size, and the union is taken across levels.
cls_score = bbox_outs[0]
centerness = bbox_outs[2]

sc_thres = 0.05
cls_scores = []
for score, center in zip(cls_score, centerness):
    # NOTE(review): FCOS/PolarMask-style heads return centerness as raw logits
    # and apply .sigmoid() to it inside get_bboxes. Multiplying by the raw logit
    # (as this cell used to) lets negative logits flip the sign of the score.
    # Confirm against this repo's bbox_head.
    score = score.sigmoid() * center.sigmoid()
    score = (score > sc_thres).float().cpu().numpy()
    # Assumes (N, C, H, W) per level: take the first image, then collapse the
    # class axis to "any class active".
    score = score[0].sum(axis=0).clip(0, 1).astype('uint8')
    # Binary map -> nearest-neighbour upsampling to (width, height) of the image.
    score = cv2.resize(score, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
    cls_scores.append(score)
cls_scores = np.stack(cls_scores).sum(axis=0).clip(0, 1)
print(cls_scores.shape)

image_p3 = draw_semseg(image_bbox, cls_scores)
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(image_p3)
ax.set_title("Locations with score * centerness > %.2f (all FPN levels)" % sc_thres)
ax.axis('off')
plt.show()