In [None]:
import json
import pickle

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import supervisely as sly
from IPython.display import display
from matplotlib import pyplot as plt
from PIL import Image
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params

from model_benchmark import metrics, utils
from model_benchmark import metric_provider
from model_benchmark.metric_provider import MetricProvider, METRIC_NAMES
from model_benchmark.prediction_gallery import prediction_gallery

## Loading data

In [None]:
# Inputs of this analysis, in loading order:
#   1. COCO ground truth (the "remap" file),
#   2. YOLOv8-L detections exported for the "conf-0.01" benchmark run,
#   3. pickled evaluation results for that same run.
cocoGt_path, cocoDt_path, eval_data_path = (
    "cocoGt_remap.json",
    "data/model-benchmark/COCO 2017 val (YOLOv8-L, conf-0.01)/cocoDt.json",
    "eval_data_conf-0.01.pkl",
)

In [None]:
# Ground truth and detections as pycocotools COCO objects; the detections are
# attached to the ground-truth index via loadRes.
cocoGt = COCO(cocoGt_path)
cocoDt = cocoGt.loadRes(cocoDt_path)

# NOTE(security): pickle.load can execute arbitrary code. Only load eval_data
# files produced by this project, never files from an untrusted source.
with open(eval_data_path, 'rb') as f:
    eval_data = pickle.load(f)

# Later cells read these keys. Fail here with a clear message instead of a
# KeyError further down the notebook.
assert {'matches', 'coco_metrics', 'params'} <= set(eval_data), (
    f"{eval_data_path} lacks one of 'matches', 'coco_metrics', 'params'"
)

In [None]:
from importlib import reload
reload(metric_provider)
m = metric_provider.MetricProvider(eval_data['matches'], eval_data['coco_metrics'], eval_data['params'], cocoGt, cocoDt)
m.base_metrics()

In [None]:
# Restrict the example gallery to the categories that utils.get_rare_classes
# reports as rare. Only the ids are needed here; the names are kept for later use.
rare_classes = utils.get_rare_classes(cocoGt)
cat_ids_rare, cat_names_rare = rare_classes
gallery = prediction_gallery(
    m.matches,
    cocoGt,
    cat_ids_rare,
)

In [None]:
# Images, plus Supervisely-format copies of the benchmark used only for drawing.
# NOTE(review): these come from the "conf-25" export, while cocoDt / eval_data
# above use "conf-0.01". Confirm that mixing the two runs is intentional.
image_dir = "data/COCO2017/img/val2017/"
sly_benchmark_dir = "data/model-benchmark/COCO 2017 val (YOLOv8-L, conf-25)/"
ann_dir_gt = sly_benchmark_dir + "gt_dataset/val2017/ann/"
ann_dir_dt = sly_benchmark_dir + "dt_dataset/val2017/ann/"


def _load_project_meta(meta_path):
    """Read a Supervisely ``meta.json`` file and return it as ``sly.ProjectMeta``.

    The file is decoded as UTF-8 explicitly, so class names with non-ASCII
    characters load the same way regardless of the platform's default encoding.
    """
    with open(meta_path, 'r', encoding='utf-8') as f:
        return sly.ProjectMeta.from_json(json.load(f))


meta_gt = _load_project_meta(sly_benchmark_dir + "gt_dataset/meta.json")
meta_dt = _load_project_meta(sly_benchmark_dir + "dt_dataset/meta.json")

# Colors of the error boxes drawn in the third ("diff") panel of get_diffs.
FN_COLOR = (245, 140, 40)  # orange: ground-truth objects of non-TP matches (FN)
FP_COLOR = (245, 40, 40)   # red: detections of non-TP matches without a gt_id (FP)


def _coco_bbox_to_label(coco_ann, coco, meta, color):
    """Build a ``sly.Label`` rectangle for one COCO annotation dict, drawn in ``color``.

    The object class is looked up in ``meta`` by the COCO category name of
    ``coco_ann['category_id']`` and re-colored. COCO boxes are
    ``[x, y, width, height]``. The ``- 1`` on bottom/right presumably matches
    Supervisely's inclusive pixel bounds (kept from the original code).
    """
    left, top, width, height = coco_ann['bbox']
    cat_name = coco.loadCats(coco_ann['category_id'])[0]['name']
    obj_class = meta.get_obj_class(cat_name).clone(color=color)
    rect = sly.Rectangle(top=top, left=left, bottom=top + height - 1, right=left + width - 1)
    return sly.Label(rect, obj_class)


def get_diffs(img_id, matches):
    """Render ground truth, predictions and errors for one image side by side.

    Returns a single array made of three horizontally concatenated copies of the image:
      1. ground-truth rectangles read from ``ann_dir_gt``;
      2. predictions read from ``ann_dir_dt``;
      3. the non-TP entries of ``matches`` for this image: FN boxes (entries
         with a ``gt_id``, crowd regions skipped) in orange and FP boxes
         (entries with only a ``dt_id``) in red.

    Args:
        img_id: COCO image id known to ``cocoGt``.
        matches: iterable of match dicts with keys ``image_id``, ``type``,
            ``gt_id`` and ``dt_id``. Entries whose ``type`` is ``'TP'`` are ignored.

    Relies on the notebook globals ``cocoGt``, ``cocoDt``, ``meta_gt``,
    ``meta_dt``, ``image_dir``, ``ann_dir_gt`` and ``ann_dir_dt``.
    """
    img_name = cocoGt.loadImgs(img_id)[0]['file_name']
    # convert('RGB') makes grayscale images yield an (H, W, 3) array that
    # colored boxes can be drawn on. The context manager closes the file handle.
    with Image.open(image_dir + img_name) as img:
        img_rgb = np.array(img.convert('RGB'))

    ann_gt = sly.Annotation.load_json_file(ann_dir_gt + img_name + '.json', meta_gt)
    # Only rectangles are compared against the bbox detections.
    ann_gt = ann_gt.clone(labels=[lbl for lbl in ann_gt.labels if lbl.geometry.geometry_name() == "rectangle"])
    ann_dt = sly.Annotation.load_json_file(ann_dir_dt + img_name + '.json', meta_dt)

    img_gt = img_rgb.copy()
    img_dt = img_rgb.copy()
    ann_gt.draw(img_gt, thickness=2, draw_class_names=True, fill_rectangles=False)
    ann_dt.draw(img_dt, thickness=2, draw_class_names=True, fill_rectangles=False)

    error_labels = []
    for match in matches:
        if match['image_id'] != img_id or match['type'] == 'TP':
            continue
        if match['gt_id'] is not None:
            # FN
            gt_ann = cocoGt.anns[match['gt_id']]
            if gt_ann['iscrowd']:
                continue  # crowd regions are not drawn as individual misses
            error_labels.append(_coco_bbox_to_label(gt_ann, cocoGt, meta_gt, FN_COLOR))
        elif match['dt_id'] is not None:
            # FP
            dt_ann = cocoDt.anns[match['dt_id']]
            error_labels.append(_coco_bbox_to_label(dt_ann, cocoDt, meta_dt, FP_COLOR))
        # A match with neither id has no box to draw. It is skipped instead of
        # re-appending the previous iteration's label (or hitting an unbound name).

    img_diff = img_rgb.copy()
    ann_gt.clone(labels=error_labels).draw_pretty(img_diff, thickness=2)

    # Join the three panels along the X axis.
    return np.concatenate([img_gt, img_dt, img_diff], axis=1)


In [None]:
# Display N error visualizations, cycling round-robin over the first
# N_GALLERY_KEYS gallery groups: row 0 of every group, then row 1, and so on.
N = 12
N_GALLERY_KEYS = 6
gallery_keys = list(gallery.keys())[:N_GALLERY_KEYS]
if not gallery_keys:
    # Guard: the modulo/divmod below would raise ZeroDivisionError.
    print("gallery is empty - nothing to display")
else:
    for i in range(N):
        k, idx = divmod(i, len(gallery_keys))
        key = gallery_keys[idx]
        # NOTE(review): assumes gallery[key] is a sequence of rows whose first
        # item is a COCO image id. Groups with fewer than k + 1 rows are skipped
        # instead of raising IndexError.
        if k >= len(gallery[key]):
            continue
        row = gallery[key][k]
        img_id = row[0]
        img_join = get_diffs(img_id, m.matches)
        print(key, row[1])
        display(Image.fromarray(img_join))