## Visualization of Retail Detections

In [None]:
import cv2
import json
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

In [None]:
def load_json(json_det_file):
    """Read a JSON file and return its decoded content.

    Args:
        json_det_file: Location of the JSON file (``str`` or ``pathlib.Path``).

    Returns:
        Whatever ``json.load`` decodes from the file; the rest of this
        notebook indexes the result like a ``dict``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification. Pinning the encoding keeps
    # non-ASCII content (e.g. image file names) decoding identically on
    # platforms whose default locale encoding is not UTF-8.
    with open(str(json_det_file), 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
# Root folder holding both the detection JSON files and the image folders.
data_dir = Path('../data/test/')
# Subset tag; it prefixes the detection file names and the image folder name.
subset = 'a'
model_names = [
    'yolov4_6a_retail_one_best',
    'yolov4_9a_retail_one_140',
    'yolov4_9a_retail_one_200',
]

# One loaded detection dict per model, each tagged with the model it came
# from so the plots can be titled later.
dets_models = []
for model_name in model_names:
    det_file = data_dir / f'{subset}_det_{model_name}.json'
    dets = load_json(det_file)
    dets.update(model_name=model_name)
    dets_models.append(dets)

In [None]:
def get_detections(index, subset, dets):
    """Collect the boxes and scores recorded for one image.

    Args:
        index: Position of the image inside ``dets['images']``.
        subset: Subset tag used to build the ``<subset>_images`` folder name
            under the module-level ``data_dir``.
        dets: Detection dict with an ``'images'`` list (entries carrying
            ``'file_name'`` and ``'id'``) and an ``'annotations'`` list
            (entries carrying ``'image_id'``, ``'bbox'`` and optionally
            ``'score'``).

    Returns:
        A dict with ``'img_path'`` (path of the image file), ``'bbox_list'``
        (the ``'bbox'`` values of the matching annotations, unchanged) and
        ``'score_list'`` (their scores, ``0.0`` where ``'score'`` is absent).
    """
    image_entry = dets['images'][index]
    file_name = image_entry['file_name']
    wanted_id = image_entry['id']
    image_path = data_dir / '{}_images'.format(subset) / file_name

    # Single pass over the annotations, keeping box and score paired up.
    matched = [
        (ann['bbox'], ann.get('score', 0.0))
        for ann in dets['annotations']
        if ann['image_id'] == wanted_id
    ]

    return {
        'img_path': image_path,
        'bbox_list': [box for box, _ in matched],
        'score_list': [conf for _, conf in matched],
    }

In [None]:
# Default size (width 30, height 10) for every figure created after this cell.
plt.rcParams['figure.figsize'] = [30, 10]
# Disabled: random color table; `colors` is not referenced by the visible code.
# colors = np.random.randint(0, 255, size=(200, 3), dtype=np.uint8)

def show_detection(img_index, subset, dets_models, show_score_thr=0.0, color=(85, 138, 29), show_score=True):
    """Draw the detections of every model for one image, side by side.

    One subplot is created per entry of ``dets_models``. For each model the
    image is loaded, every box whose score reaches ``show_score_thr`` is
    drawn on it, and the result is shown under the model's name.

    Args:
        img_index: Position of the image in each model's ``'images'`` list;
            it is also written in the corner of every panel.
        subset: Subset tag forwarded to ``get_detections``.
        dets_models: Non-empty sequence of detection dicts, each carrying a
            ``'model_name'`` entry used as the panel title.
        show_score_thr: Boxes with a score below this value are skipped.
        color: Drawing color. It is applied to the image as loaded by
            ``cv2.imread`` (BGR order), before the conversion to RGB.
        show_score: When true, the score is written at the box center.

    Raises:
        ValueError: If ``dets_models`` is empty.
        FileNotFoundError: If an image cannot be read from disk.
    """
    if not dets_models:
        raise ValueError('dets_models must contain at least one model')

    # squeeze=False keeps `axes` two-dimensional even for a single model.
    # With the default, plt.subplots returns a bare Axes object in that case
    # and indexing it fails.
    fig, axes = plt.subplots(1, len(dets_models), squeeze=False)

    for model_id, dets in enumerate(dets_models):
        det_dict = get_detections(img_index, subset, dets)
        img_path = det_dict['img_path']

        img = cv2.imread(str(img_path))
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            # Fail here with the offending path rather than inside a later
            # drawing call.
            plt.close(fig)
            raise FileNotFoundError('Cannot read image: {}'.format(img_path))

        for bbox, score in zip(det_dict['bbox_list'], det_dict['score_list']):
            if score < show_score_thr:
                continue

            # Boxes are [x, y, width, height]. OpenCV drawing calls need
            # integer pixel coordinates, so round any fractional values.
            x1, y1, w, h = bbox
            top_left = (int(round(x1)), int(round(y1)))
            bottom_right = (int(round(x1 + w)), int(round(y1 + h)))
            cv2.rectangle(img, top_left, bottom_right, color, 3)
            if show_score:
                text_pos = (int(x1 + w / 2) - 10, int(y1 + h / 2))
                cv2.putText(img, '{:.2f}'.format(score), text_pos, cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        # Stamp the image index so a panel can be traced back to its image.
        cv2.putText(img, '{}'.format(img_index), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, color, 2)

        panel = axes[0, model_id]
        panel.set_title(dets['model_name'])
        # Matplotlib expects RGB while OpenCV loads images as BGR.
        panel.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        panel.set_xticks([])
        panel.set_yticks([])

    plt.tight_layout()
    plt.show()

In [None]:
# Missed detections
for idx in (574, 193, 745, 374, 1022, 1026, 1271, 982):
    show_detection(idx, subset, dets_models, show_score_thr=0.1)

In [None]:
# False detections
for idx in (489, 1308, 903, 1038, 1021, 521, 958):
    show_detection(idx, subset, dets_models, show_score_thr=0.1)

In [None]:
# Localization
for idx in (450, 253, 1044):
    show_detection(idx, subset, dets_models, show_score_thr=0.1)

In [None]:
# Show one randomly chosen image; the index is printed so the same image
# can be displayed again later.
num_images = len(dets_models[0]['images'])
idx = np.random.randint(0, num_images)
print(idx)
show_detection(idx, subset, dets_models, show_score_thr=0.1)