In [None]:
import os
import os.path as osp
import sys
import numpy as np
import torch, torchvision
import torch.nn.functional as F
from torch import nn
import matplotlib.pyplot as plt
import mmdet
import mmcv
%load_ext autoreload
%autoreload 2
from mmdet.models import build_detector
from mmdet.datasets import get_dataset
from mmdet.datasets import transforms

In [None]:
import tempfile
from mmdet.core.evaluation import coco_utils
from mmdet.core.post_processing.bbox_nms import multiclass_nms
from mmdet.ops.nms import nms_wrapper

In [None]:
# Load the SSD300 VisDrone config and build its test split.
cfg = mmcv.Config.fromfile('./visdrone/configs/ssd300.py')
dataset = get_dataset(cfg.data.test)

In [None]:
# Directory of per-crop detection txt files named '<image stem>_<crop index>.txt';
# passed as `savedir` to get_dets_merge() below.
# NOTE(review): hard-coded /tmp path with an inappropriate name — consider
# renaming it and moving it into a config constant.
txt_dir = '/tmp/fuckyoudir/'

In [None]:
def txt2det(fid, num_classes=10):
    """Parse one detection txt file (one image / crop) into per-class arrays.

    Each non-empty line must hold 8 comma-separated fields:
    ``x1, y1, w, h, score, label, truncation, occlusion``. Lines whose label
    is 0 or ``num_classes + 1`` are skipped (labels 0 and 11 for the default
    10 classes, matching the original behavior); labels ``1..num_classes``
    are stored at output index ``label - 1``.

    Args:
        fid: opened text-mode file handle (anything with ``readlines()``).
        num_classes: number of foreground classes (default 10).

    Returns:
        list of length ``num_classes``; element ``c`` is a float32 array of
        shape [N, 5] with rows ``[x1, y1, x2, y2, score]`` where
        ``x2 = x1 + w`` and ``y2 = y1 + h``. Classes without detections get
        an empty float32 [0, 5] array.

    Raises:
        AssertionError: if a label lies outside ``0..num_classes + 1``.
        ValueError: if a non-empty line does not have exactly 8 fields.
    """
    dets = [[] for _ in range(num_classes)]
    for raw in fid.readlines():
        line = raw.strip()  # also drops '\r' from CRLF files
        if not line:
            # tolerate blank lines, e.g. a trailing empty line at end of file
            continue
        x1, y1, w, h, sc, label, _trun, _occ = line.split(',')
        label = int(label)
        if label == 0 or label == num_classes + 1:
            continue
        assert 0 < label <= num_classes, 'Bad label'
        # parse as float so non-integer box coordinates are accepted too
        x1, y1, w, h = float(x1), float(y1), float(w), float(h)
        dets[label - 1].append(
            np.asarray([x1, y1, x1 + w, y1 + h, float(sc)], dtype=np.float32))
    # keep a uniform float32 dtype for empty and non-empty classes so later
    # concatenation does not silently upcast to float64
    return [np.stack(d, 0) if d else np.empty([0, 5], dtype=np.float32)
            for d in dets]

In [None]:
def get_dets(savedir, dataset, crop_idx=9):
    """Load detections for every image of ``dataset`` from a single crop file.

    Unlike get_dets_merge(), only '<stem>_<crop_idx>.txt' is read for each
    image; the other crop files are ignored. Parsing is done by txt2det().

    Args:
        savedir: directory containing the '<stem>_<i>.txt' detection files.
        dataset: dataset exposing ``img_infos``, a list of dicts with a
            'filename' key.
        crop_idx: index of the crop file to read (default 9, the only index
            the previous implementation used).

    Returns:
        list (per image) of list (per class) of [N, 5] arrays.

    Raises:
        AssertionError: if an expected txt file is missing.
    """
    results = []
    for info in dataset.img_infos:
        # file stem without directories, e.g. 'images/abc.jpg' -> 'abc'
        stem = osp.basename(info['filename']).split('.')[0]
        txt_path = osp.join(savedir, '{}_{}.txt'.format(stem, crop_idx))
        assert osp.exists(txt_path), txt_path
        with open(txt_path) as f:
            results.append(txt2det(f))
    return results

In [None]:
def get_dets_merge(savedir, dataset, num_crops=10):
    """Load and merge the per-crop detections of every image in ``dataset``.

    For each image, '<stem>_0.txt' ... '<stem>_<num_crops - 1>.txt' in
    ``savedir`` are parsed with txt2det(), and their boxes are concatenated
    class by class. No suppression is applied here.

    Args:
        savedir: directory containing the '<stem>_<i>.txt' detection files.
        dataset: dataset exposing ``img_infos``, a list of dicts with a
            'filename' key.
        num_crops: number of crop files per image (default 10).

    Returns:
        list (per image) of list (per class) of [N, 5] arrays.

    Raises:
        AssertionError: if an expected txt file is missing.
    """
    results = []
    for info in dataset.img_infos:
        # file stem without directories, e.g. 'images/abc.jpg' -> 'abc'
        stem = osp.basename(info['filename']).split('.')[0]

        crop_dets = []
        for i in range(num_crops):
            txt_path = osp.join(savedir, '{}_{}.txt'.format(stem, i))
            assert osp.exists(txt_path), txt_path
            with open(txt_path) as f:
                crop_dets.append(txt2det(f))

        # concatenate, class by class, the boxes found in all crops
        num_classes = len(crop_dets[0])
        results.append([np.concatenate([dets[c] for dets in crop_dets], 0)
                        for c in range(num_classes)])
    return results

In [None]:
# Merge the detections of all crops of every test image.
results = get_dets_merge(dataset=dataset, savedir=txt_dir)

In [None]:
# Dump the merged detections to a temporary COCO-format json and run bbox eval.
tf = tempfile.mkstemp(suffix='.json')
os.close(tf[0])  # mkstemp returns (fd, path); only the path is used, so release the fd
js = coco_utils.results2json(dataset, results, tf[1])
coco_utils.coco_eval(tf[1], ['bbox'], dataset.coco)

# NMS post-processing of the merged crop detections

In [None]:
from visdrone.utils import box_ops

In [None]:
from time import time

In [None]:
# Shallow copy: elements of results2 can be reassigned without affecting results.
results2 = list(results)

In [None]:
# Refine every image's detections on 'cuda' and print the elapsed wall-clock seconds.
# The positional arguments are passed through unchanged; their meaning is
# defined in visdrone.utils.box_ops.
t_start = time()
for i in range(len(results2)):
    results2[i] = box_ops.refine_boxes_multi_class(results2[i], 10, 0.5, 500, 0.5, 'cuda')
print(time() - t_start)

In [None]:
# Evaluate the refined detections (COCO bbox metrics) via a temporary json file.
tf = tempfile.mkstemp(suffix='.json')
os.close(tf[0])  # mkstemp returns (fd, path); only the path is used, so release the fd
js = coco_utils.results2json(dataset, results2, tf[1])
coco_utils.coco_eval(tf[1], ['bbox'], dataset.coco)

In [None]:
def transfrom_by_refine(results):
    """Apply box_ops.refine_boxes_multi_class to every image's detections.

    Args:
        results: list (per image) of per-class detections, e.g. as returned
            by get_dets_merge().

    Returns:
        a new list with one refined entry per image; the input list itself
        is not reassigned.
    """
    # positional arguments are passed through unchanged from the original
    # call; their meaning is defined in visdrone.utils.box_ops
    return [box_ops.refine_boxes_multi_class(res, 10, 0.5, 500, 0.5)
            for res in results]

In [None]:
def transform_results_by_nms(results, nms_func, iou_thr=0.5):
    """Run ``nms_func`` independently on every class of every image.

    Args:
        results: list (per image) of list (per class) of [N, 5] arrays.
        nms_func: callable ``nms_func(dets, iou_thr) -> (kept_dets, inds)``;
            it receives float32 arrays.
        iou_thr: IoU threshold forwarded to ``nms_func``.

    Returns:
        a new list (per image) of list (per class) of the kept detections.
        ``results`` is not modified, so the function can be called repeatedly
        (e.g. in a threshold sweep) on the same input.
    """
    new_results = []
    for res in results:
        new_res = []
        for c_res in res:
            if c_res.shape[0] == 0:
                # nothing to suppress; skip the NMS call for empty classes
                new_res.append(c_res.astype(np.float32))
                continue
            kept, _ = nms_func(c_res.astype(np.float32), iou_thr)
            new_res.append(kept)
        new_results.append(new_res)
    return new_results

In [None]:
import functools
# Linear soft-NMS; iou_thr is supplied per call by transform_results_by_nms().
nms_func = functools.partial(nms_wrapper.soft_nms, method='linear', sigma=float(0.5), min_score=float(1e-3))
# Pass shallow copies of the per-image lists so `results` itself stays untouched.
ress = transform_results_by_nms([list(r) for r in results], nms_func)

tf = tempfile.mkstemp(suffix='.json')
os.close(tf[0])  # mkstemp returns (fd, path); only the path is used, so release the fd
js = coco_utils.results2json(dataset, ress, tf[1])
coco_utils.coco_eval(tf[1], ['bbox'], dataset.coco)

In [None]:
# Sweep the NMS IoU threshold and collect coco_eval's return value for each run.
# `nms_func` is the soft-NMS partial defined in the previous cell.
cocores = []
for ithr in np.linspace(0.2, 0.95, 10):
    # copy the per-image lists so every threshold starts from the same
    # un-suppressed detections instead of the previous run's output
    ress = transform_results_by_nms([list(r) for r in results], nms_func, iou_thr=float(ithr))

    tf = tempfile.mkstemp(suffix='.json')
    os.close(tf[0])  # mkstemp returns (fd, path); only the path is used, so release the fd
    js = coco_utils.results2json(dataset, ress, tf[1])
    cocores_ = coco_utils.coco_eval(tf[1], ['bbox'], dataset.coco)
    cocores.append(cocores_)

In [None]:
# Scratch: a 3x4 float32 array of zeros.
a = np.zeros((3, 4), dtype=np.float32)

In [None]:
# Scratch: select the entries of `a` that equal zero.
a[np.nonzero(a == 0)]