In [2]:
from db.datasets import datasets
from config import system_configs
import json, os
import numpy as np
from tqdm import tqdm
from sklearn.metrics import average_precision_score

In [3]:
with open('config/KPDetection.json', "r") as f:
    configs = json.load(f)
    
split = 'valchart'

configs["system"]["data_dir"] = "/root/autodl-tmp/extraction_data"
configs["system"]["cache_dir"] = "data/cache/"

configs["system"]["dataset"] = "Chart"
configs["system"]["snapshot_name"] = "PretrainKP"
system_configs.update_config(configs["system"])
db = datasets["Chart"](configs["db"], split)

Label file: /root/autodl-tmp/extraction_data/annotations/val.json
Loading from cache file: data/cache/val_cache.pkl
Loading annotations into memory...
Done (t=0.19s)


In [4]:
def get_pie_center(a, b, c):
    a,b,c = np.array(a), np.array(b), np.array(c)
    ca = c - a
    cb = c - b
    cosine_angle = np.dot(ca, cb) / (np.linalg.norm(ca) * np.linalg.norm(cb))
    angle = np.arccos(cosine_angle)
    r_square = (ca**2).sum()
    
    if ca[0]*cb[1]-ca[1]*cb[0] >= 0:
        return (a[0]+b[0]+c[0])/3., (a[1]+b[1]+c[1])/3., 0.5 * angle * r_square
    else:
        return 2*c[0]-(a[0]+b[0]+c[0])/3., 2*c[1]-(a[1]+b[1]+c[1])/3., np.pi * r_square - 0.5 * angle * r_square

def get_points(gts, preds, chartType):
    gt_keys, gt_cens = [], []
    area = 0
    
    if chartType == 'vbar_categorical':
        for bbox in gts.tolist():
            area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) 
            gt_keys.append((bbox[0],bbox[1], area))
            gt_keys.append((bbox[2],bbox[3], area))
            gt_cens.append( ( (bbox[0] + bbox[2])/2, (bbox[1] + bbox[3])/2, area ) )
    elif chartType == 'pie':
        for bbox in gts.tolist():
            a, b, c = (bbox[0], bbox[1]), (bbox[2], bbox[3]), (bbox[4], bbox[5])
            xce, yce, area = get_pie_center(a,b,c)                
            gt_keys.append((bbox[0],bbox[1], area))
            gt_keys.append((bbox[2],bbox[3], area))
            gt_keys.append((bbox[4],bbox[5], area))
            gt_cens.append((xce, yce, area))   
    elif chartType == 'line':
        for bbox in gts[0]:
            detection = np.array(bbox)
            if len(detection) <= 1: continue
            elif len(detection)//2 % 2 == 0:
                mid = len(detection) // 2
                xce, yce = (detection[mid-2] + detection[mid]) / 2, (detection[mid-1] + detection[mid+1]) / 2
            else:
                mid = len(detection) // 2
                xce, yce = detection[mid-1].copy(), detection[mid].copy()
            assert len(detection) % 2 == 0
            xs = detection[0:len(detection):2]
            ys = detection[1:len(detection):2]
            area = (max(max(xs) - min(xs), max(ys) - min(ys)) / len(detection) * 2) ** 2
                
            for x, y in zip(xs, ys):
                gt_keys.append((x,y, area))
            gt_cens.append((xce, yce, area))   

    pred_keys, pred_cens = [], []
    if '1' not in preds[0]: # baseline predictions
        if chartType == 'pie':
            pred_keys.append((preds[0][0][0], preds[0][0][1], preds[0][-1]))
            for pred in preds:
                pred_keys.append((pred[1][0], pred[1][1], pred[-1]))
        elif chartType == 'line':
            for pred in preds:
                pred_groups.append(np.array(pred))
        else:
            for pred in preds:
                pred_keys.append((pred[0],pred[1], 1.))
                pred_keys.append((pred[2],pred[3], 1.))
    else:  
        for point in preds[0]['1']:
            pred_keys.append((point[2],point[3], point[0]))
        for point in preds[1]['1']:
            pred_cens.append((point[2],point[3], point[0]))
    return gt_keys, gt_cens, pred_keys, pred_cens

In [5]:
def OKS(gt_p, pred_p):
    d2 = (gt_p[0] - pred_p[0]) ** 2 + (gt_p[1] - pred_p[1]) ** 2
    k2 = 0.1
    s2 = gt_p[2]
    return np.exp(d2/(s2 * k2) * (-1))

def computeTargetLabel(gt_ps, pred_ps, thres=0.75):
    y_true = []
    for pred_p in pred_ps:
        found = False
        for gt_p in gt_ps:
            if OKS(gt_p, pred_p) > thres:
                y_true.append(1)
                found = True
                break
        if not found:
            y_true.append(0)
    return y_true

#用于计算给定阈值（默认为 0.75）下，哪些 ground truth 点（gt_ps）被预测点（pred_ps）成功检测到。
def computeDetectedGT(gt_ps, pred_ps, thres=0.75):
    # 初始化一个空列表 y_true，用于存储每个 ground truth 点是否被成功检测到（1 表示检测成功，0 表示未检测到）。
    y_true = []
    for gt_p in gt_ps:
        found = False
        for pred_p in pred_ps:
            if OKS(gt_p, pred_p) > thres:
                y_true.append(1)
                found = True
                break
        # 如果遍历所有预测点后，found 仍为 False，说明当前 ground truth 点未被检测到。
        if not found:
            y_true.append(0)
    return y_true

In [6]:
mAP_keys = []
mAP_cens = []
max_iter = db.db_inds.size
print(max_iter)
for i in tqdm(range(50)):
    db_ind = db.db_inds[i]
    image_file = db.image_file(db_ind)
    gts = db.detections(db_ind)
    print(image_file.split('/')[-1])

3695


100%|██████████| 50/50 [00:00<00:00, 22229.72it/s]

c49285ca77f6aff6214dad492b688113_d3d3LmRhbmUuZ292LmNvCTE3MC4yMzguNjQuNzg=.xls-0-0.png
c49285ca77f6aff6214dad492b688113_d3d3LmRhbmUuZ292LmNvCTE3MC4yMzguNjQuNzg=.xls-0-1.png
c49830551627d9220fc08c5e9fe007b6_d3d3Lmdvdi5zY290CTEzNC4xOS4xNjEuMjQ5.xls-25-0.png
c49bfcb997437ebb0afc13986b889817_d3d3LndpbmVwaS5jb20JMjA3LjE1MC4yMTIuOTk=.xls-0-0.png
c49bfcb997437ebb0afc13986b889817_d3d3LndpbmVwaS5jb20JMjA3LjE1MC4yMTIuOTk=.xls-1-0.png
c49d11bc39a7a613709ce556f619c0e2_d3d3Lmpwby5nby5qcAkyMy40MS4yNTEuOTc=.xls-0-0.png
c49e369ac4e7b5cf62a8849afe9ecfb8_d3d3LnJpbnlhLm1hZmYuZ28uanAJMTA0LjgwLjg5LjUx.xls-0-0.png
c49e369ac4e7b5cf62a8849afe9ecfb8_d3d3LnJpbnlhLm1hZmYuZ28uanAJMTA0LjgwLjg5LjUx.xls-0-1.png
c49e369ac4e7b5cf62a8849afe9ecfb8_d3d3LnJpbnlhLm1hZmYuZ28uanAJMTA0LjgwLjg5LjUx.xls-0-2.png
c4a4e9e26218c55f620920c6e1f4269e_dzMuZGdlZWMubWVjLnB0CTIxMi41NS4xNDMuMzY=-0-0.png
c4a4e9e26218c55f620920c6e1f4269e_dzMuZGdlZWMubWVjLnB0CTIxMi41NS4xNDMuMzY=-0-1.png
c4aafc3df2bf520812a72afa1bfce6d3_c2Ftc2V0cHJvamVjdC5uZXQJ




In [7]:
with open('evaluation/KPDetection50000.json') as f:
    prediction = json.load(f)
    
mAP_keys = []
mAP_cens = []
max_iter = db.db_inds.size
print(list(prediction.keys())[:3])

['c4c0627d16eb05af88880a92faec6aaa_amFyY2hpdmVzLmNvbQkxOTIuMTg1Ljk4LjE5OA==-5-0.png', 'c49285ca77f6aff6214dad492b688113_d3d3LmRhbmUuZ292LmNvCTE3MC4yMzguNjQuNzg=.xls-0-0.png', 'c4c0627d16eb05af88880a92faec6aaa_amFyY2hpdmVzLmNvbQkxOTIuMTg1Ljk4LjE5OA==-6-0.png']


In [8]:
for i in tqdm(range(max_iter)):
    db_ind = db.db_inds[i]
    #print(db_ind)
    image_file = db.image_file(db_ind)
    #print(image_file)
    gts = db.detections(db_ind)
    #print(gts)
    # 如果没有 ground truth 数据，则跳过当前迭代。
    if gts is None or len(gts) == 0: continue
    #print(image_file.split('/')[-1])
    preds = prediction[image_file.split('/')[-1]]
    if preds is None or len(preds) == 0: continue
    if len(preds) == 3 and len(preds[2]) == 0: continue
    gt_keys, gt_cens, pred_keys, pred_cens = get_points(gts, preds, chartType)
    
    # 计算关于关键点（keys）的评估指标
    y_true_keys = computeTargetLabel(gt_keys, pred_keys)
    y_score_keys= [key[2] for key in pred_keys]
    
    detected_gt_keys = computeDetectedGT(gt_keys, pred_keys)
    miss_count = len(detected_gt_keys) - sum(detected_gt_keys)
    # 对漏检的 ground truth，其真实标签应为 1。
    y_true_keys = y_true_keys + [1] * miss_count
    # 漏检的 ground truth 的预测得分应为 0。
    y_score_keys = y_score_keys + [0] * miss_count
    
    score = average_precision_score(y_true_keys, y_score_keys)
#     if score < 0.3:
#         print(image_file)
    mAP_keys = np.append(mAP_keys,score)
    
    # cens
    y_true_cens = computeTargetLabel(gt_cens, pred_cens)
    y_score_cens= [key[2] for key in pred_cens]
    
    detected_gt_cens = computeDetectedGT(gt_cens, pred_cens)
    miss_count = len(detected_gt_cens) - sum(detected_gt_cens)
    y_true_cens = y_true_cens + [1] * miss_count
    y_score_cens = y_score_cens + [0] * miss_count
    
    mAP_cens = np.append(mAP_cens, average_precision_score(y_true_cens, y_score_cens))

  0%|          | 0/3695 [00:00<?, ?it/s]


NameError: name 'chartType' is not defined

In [None]:
mAP_keys = np.array(mAP_keys)
mAP_cens = np.array(mAP_cens)
print('mAP for keypoints:', mAP_keys[~np.isnan(mAP_keys)].mean(), " mAP for center points:",mAP_cens[~np.isnan(mAP_cens)].mean())