In [None]:
import numpy as np
from matplotlib import pyplot as plt
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from collections import defaultdict
import random
import pickle

from model_benchmark import metrics, utils
from model_benchmark.metric_provider import MetricProvider, METRIC_NAMES
from model_benchmark import metric_provider

## Loading data

In [None]:
# Benchmark configuration.
# Each model is a single row: (display name, detections JSON, cached evaluation pickle).
# Keeping a model's settings on one row guarantees the derived lists below stay index-aligned;
# the later zip(...) loops would otherwise silently truncate or pair the wrong files.
cocoGt_path = "cocoGt_remap.json"
MODELS = [
    ("YOLOv8-L",    "data/model-benchmark/COCO 2017 val (YOLOv8-L, conf-0.01)/cocoDt.json",    "eval_data_conf-0.01.pkl"),
    ("RT-DETR r34", "data/model-benchmark/COCO 2017 val (RT-DETR r34, conf-0.01)/cocoDt.json", "eval_data_rtdetr_conf-0.01.pkl"),
    ("DINO",        "data/model-benchmark/COCO 2017 val (DINO, conf-0.05)/cocoDt.json",        "eval_data_dino.pkl"),
]
names = [model_name for model_name, _, _ in MODELS]
cocoDt_paths = [dt_path for _, dt_path, _ in MODELS]
eval_paths = [pkl_path for _, _, pkl_path in MODELS]

In [None]:
# Load the ground truth once, then attach each model's detections file to it via loadRes.
cocoGt = COCO(cocoGt_path)
cocoDts = list(map(cocoGt.loadRes, cocoDt_paths))

In [None]:
from importlib import reload
from time import time

# Re-import metric_provider so edits to the module are picked up without restarting the kernel.
reload(metric_provider)

# Two providers per model:
#   providers_full - built from every match in the cached evaluation;
#   providers      - built from the matches that metric_provider.filter_by_conf keeps
#                    at the model's F1-optimal confidence threshold.
providers_full, providers = [], []

for cocoDt, eval_path in zip(cocoDts, eval_paths):
    # SECURITY: pickle.load can execute arbitrary code; only open evaluation files you produced yourself.
    with open(eval_path, 'rb') as f:
        eval_data = pickle.load(f)

    t0 = time()
    matches, coco_metrics, params = eval_data['matches'], eval_data['coco_metrics'], eval_data['params']

    m_full = metric_provider.MetricProvider(matches, coco_metrics, params, cocoGt, cocoDt)
    # NOTE(review): score_profile is not used below; the call is kept in case
    # get_f1_optimal_conf depends on it having run first — confirm.
    score_profile = m_full.confidence_score_profile()
    f1_optimal_conf1, best_f1 = m_full.get_f1_optimal_conf()

    matches_thresholded = metric_provider.filter_by_conf(matches, f1_optimal_conf1)
    m = metric_provider.MetricProvider(matches_thresholded, coco_metrics, params, cocoGt, cocoDt)
    print(f"Time: {time() - t0:.2f} s")

    providers_full.append(m_full)
    providers.append(m)

## Overview

In [None]:
# Overall metrics comparison: one closed radar polygon per model (thresholded providers).

fig = go.Figure()
for m, name in zip(providers, names):
    base_metrics = m.base_metrics()
    r = [*base_metrics.values()]
    theta = [metric_provider.METRIC_NAMES[key] for key in base_metrics]
    # Repeat the first vertex so the polygon closes on itself.
    polar_trace = go.Scatterpolar(
        r=[*r, r[0]],
        theta=[*theta, theta[0]],
        name=name,
        hovertemplate='%{theta}: %{r:.2f}<extra></extra>',
    )
    fig.add_trace(polar_trace)

polar_layout = dict(
    radialaxis=dict(range=[0., 1.]),
    angularaxis=dict(rotation=90, direction='clockwise'),
)
fig.update_layout(polar=polar_layout, title="Overall Metrics", width=600, height=500)
fig.show()

In [None]:
# Grouped bar chart for base metrics.
# The metric order comes explicitly from the first model, and each model's values are looked up
# by key. This means the cell no longer depends on `base_metrics` leaking out of the radar-chart
# loop, and bars cannot be mislabelled if a provider returns its metrics in a different order.
metric_keys = list(providers[0].base_metrics().keys())
bar_columns = {'Metric': [metric_provider.METRIC_NAMES[k] for k in metric_keys]}
for m, name in zip(providers, names):
    model_metrics = m.base_metrics()
    bar_columns[name] = [model_metrics[k] for k in metric_keys]

df = pd.DataFrame(bar_columns)
# Wide-form data: one bar series per model column.
fig = px.bar(df, x='Metric', y=names, barmode='group')
fig.update_layout(title="Overall Metrics")
fig.show()

## Outcome Counts

In [None]:
# Stacked horizontal bars of TP / FN / FP counts, one row per model.
outcome_specs = [
    ('TP', 'TP_count', '#1fb466'),
    ('FN', 'FN_count', '#dd3f3f'),
    ('FP', 'FP_count', '#d5a5a5'),
]

fig = go.Figure()
for m, name in zip(providers, names):
    fig.add_traces([
        go.Bar(
            x=[getattr(m, count_attr)], y=[name], name=outcome,
            orientation='h', marker=dict(color=bar_color),
            legendgroup=outcome, showlegend=False,
        )
        for outcome, count_attr, bar_color in outcome_specs
    ])

# Only the first model's three bars get legend entries; the other models share their legend groups.
for trace_idx in range(3):
    fig.data[trace_idx].showlegend = True

fig.update_layout(barmode='stack', title="Outcome Counts", width=600, height=None)
fig.update_xaxes(title_text="Count")
fig.update_yaxes(tickangle=0)
fig.show()

## Per-class F1, Precision and Recall

In [None]:
# Named accent colours.
blue_color, orange_color = '#1f77b4', '#ff7f0e'

In [None]:
import plotly.colors

# Per-model colour palette (plotly's qualitative "Plotly" set) plus a translucent copy
# (alpha 0.3) used as the area-fill colour of the same model.
colors = px.colors.qualitative.Plotly
colors_a = [
    "rgba({}, {}, {}, 0.3)".format(*plotly.colors.hex_to_rgb(hex_color))
    for hex_color in colors
]

In [None]:
# Shared class ordering for all per-class plots: average each class's F1 across models,
# then order classes by that mean (ascending, so the weakest classes come first).
# The per-class tables are combined positionally, so they are assumed to list classes in the same order.
per_class_metrics = [m.per_class_metrics() for m in providers]
f1_mean = np.stack([np.asarray(pcm["f1"], dtype=float) for pcm in per_class_metrics]).sum(axis=0)
f1_mean /= len(providers)

idx_sort = np.argsort(f1_mean)

In [None]:
# Per-class F1 for every model as filled lines; classes are ordered by ascending mean F1 (idx_sort).
# Uses `class_metrics` rather than `df` so the overview DataFrame is not clobbered.
fig = go.Figure()

for i, (class_metrics, name) in enumerate(zip(per_class_metrics, names)):
    class_metrics = class_metrics.iloc[idx_sort]
    fig.add_trace(
        go.Scatter(x=class_metrics["category"], y=class_metrics["f1"], mode='lines', name=name,
                   line=dict(width=3, color=colors[i]), legendgroup=name,
                   fill='tozeroy', fillcolor=colors_a[i]),
    )

fig.update_layout(title="F1-score (all classes)", xaxis_title="Class", yaxis_title="F1", width=None, height=500)
fig.update_xaxes(tickangle=45)
fig.update_yaxes(range=[0., 1.])
fig.show()

In [None]:
# F1 comparison bar chart: per-class F1 as grouped, semi-transparent bars, one group per class,
# with classes ordered by ascending mean F1 (idx_sort).
# Uses `class_metrics` rather than `df` so the overview DataFrame is not clobbered.
fig = go.Figure()

for i, (class_metrics, name) in enumerate(zip(per_class_metrics, names)):
    class_metrics = class_metrics.iloc[idx_sort]
    fig.add_trace(
        go.Bar(x=class_metrics["category"], y=class_metrics["f1"], name=name,
               marker=dict(color=colors[i]), opacity=0.5)
    )

fig.update_layout(title="F1-score (all classes)", xaxis_title="Class", yaxis_title="F1", width=None, height=500,
                  barmode='group')
fig.update_xaxes(tickangle=45)
fig.update_yaxes(range=[0., 1.])
fig.show()

In [None]:
# Per-class precision for every model as filled lines; classes ordered by ascending mean F1 (idx_sort).
precision_traces = []
for i, (df, name) in enumerate(zip(per_class_metrics, names)):
    df = df.iloc[idx_sort]
    precision_traces.append(go.Scatter(
        x=df["category"], y=df["precision"], mode='lines', name=name,
        line=dict(width=3, color=colors[i]), legendgroup=name,
        fill='tozeroy', fillcolor=colors_a[i],
    ))

fig = go.Figure(data=precision_traces)
fig.update_layout(
    title="Precision (all classes)",
    xaxis_title="Class",
    yaxis_title="Precision",
    width=None,
    height=500,
)
fig.update_xaxes(tickangle=45)
fig.update_yaxes(range=[0., 1.])
fig.show()

In [None]:
# Per-class recall for every model as filled lines; classes ordered by ascending mean F1 (idx_sort).
recall_traces = []
for i, (df, name) in enumerate(zip(per_class_metrics, names)):
    df = df.iloc[idx_sort]
    recall_traces.append(go.Scatter(
        x=df["category"], y=df["recall"], mode='lines', name=name,
        line=dict(width=3, color=colors[i]), legendgroup=name,
        fill='tozeroy', fillcolor=colors_a[i],
    ))

fig = go.Figure(data=recall_traces)
fig.update_layout(
    title="Recall (all classes)",
    xaxis_title="Class",
    yaxis_title="Recall",
    width=None,
    height=500,
)
fig.update_xaxes(tickangle=45)
fig.update_yaxes(range=[0., 1.])
fig.show()

## PR-curve

In [None]:
# Precision-Recall curve: for each model, precision at every recall threshold, averaged over the
# last axis of pr_curve() (presumably classes — TODO confirm against MetricProvider.pr_curve).
# NOTE(review): if pr_curve() follows COCOeval and holds -1 for classes without ground truth,
# a plain mean would be biased downward — confirm.
fig = go.Figure()
for i, (m, name) in enumerate(zip(providers, names)):
    pr_curve = m.pr_curve()
    fig.add_trace(
        go.Scatter(x=m.recThrs, y=pr_curve.mean(-1), mode='lines', name=name,
                   line=dict(width=2, color=colors[i]), fill='tozeroy', fillcolor=colors_a[i]),
    )

fig.update_layout(title="Precision-Recall Curve", xaxis_title="Recall", yaxis_title="Precision",
                  width=700, height=600)
fig.update_yaxes(range=[0., 1.01])
fig.update_xaxes(range=[0., 1.])

# Reference line for a perfect detector (precision 1 at every recall threshold). The thresholds are
# read from an explicit provider (the last one, as before) instead of the loop variable `m` leaking
# out of the loop above.
rec_thrs = providers[-1].recThrs
fig.add_trace(
    go.Scatter(
        x=rec_thrs,
        y=[1] * len(rec_thrs),
        name="Perfect",
        line=dict(color='orange', dash='dash'),
        showlegend=True
    )
)
fig.show()

## Calibration

In [None]:
# Calibration curve (only positive predictions), one curve per model from its unthresholded provider.
# Line and marker colours are taken from `colors`, so each model keeps the same colour as in
# the other charts regardless of the active plotly template.

fig = go.Figure()

for i, (m_full, name) in enumerate(zip(providers_full, names)):
    # Unpacked as (y values, x values); the axes are labelled fraction of true positives vs.
    # confidence score below. Binning is defined by MetricProvider — TODO confirm.
    true_probs, pred_probs = m_full.calibration_metrics.calibration_curve()
    fig.add_trace(go.Scatter(x=pred_probs, y=true_probs, mode='lines+markers', name=name,
                             line=dict(color=colors[i]), marker=dict(color=colors[i])))

# Perfectly calibrated line
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Perfectly calibrated',
                         line=dict(color='orange', dash='dash')))

fig.update_layout(
    title='Calibration Curve (only positive predictions)',
    xaxis_title='Confidence Score',
    yaxis_title='Fraction of True Positives',
    legend=dict(x=0.65, y=0.1),
    xaxis=dict(range=[0, 1]),
    yaxis=dict(range=[0, 1]),
    width=700, height=500
)

fig.show()

In [None]:
# Confidence score profile: F1 as a function of the confidence threshold for each model
# (unthresholded providers), with a dashed marker at the model's F1-optimal threshold.
PROFILE_MAX_POINTS = 5000     # curves longer than this are downsampled before plotting
PROFILE_TARGET_POINTS = 1000  # approximate number of points kept after downsampling

fig = go.Figure()
for i, (m_full, name) in enumerate(zip(providers_full, names)):
    score_profile = m_full.confidence_score_profile()
    f1_optimal_conf, best_f1 = m_full.get_f1_optimal_conf()
    f1 = score_profile['f1']
    scores = score_profile['scores']

    # Downsample with a single shared step so x and y stay paired point-for-point.
    if len(f1) > PROFILE_MAX_POINTS:
        step = len(f1) // PROFILE_TARGET_POINTS
        f1 = f1[::step]
        scores = scores[::step]

    # Explicit colour so the curve matches its threshold marker and annotation below.
    fig.add_trace(go.Scatter(x=scores, y=f1, mode='lines', name=name, line=dict(color=colors[i])))

    # Add vertical line for the best threshold
    fig.add_shape(type="line", x0=f1_optimal_conf, x1=f1_optimal_conf, y0=0, y1=best_f1,
                  line=dict(color=colors[i], width=1, dash="dash"))
    fig.add_annotation(x=f1_optimal_conf, y=best_f1 + 0.04,
                       text=f"F1-optimal threshold: {f1_optimal_conf:.2f}",
                       showarrow=False, font=dict(color=colors[i]))

fig.update_layout(
    yaxis=dict(range=[0, 1]),
    xaxis=dict(range=[0, 1], tick0=0, dtick=0.1),
    title='Confidence Score Profile',
    xaxis_title='Confidence Score',
    yaxis_title='F1',
    width=None, height=500,
)

fig.show()

## All Classes

In [None]:
# Per-class Average Precision (AP)
# coco_precision is sliced as [iou_thr, recall_thr, class, area, max_dets] here — presumably the
# layout of COCOeval's eval['precision'] (TODO confirm). Under COCOeval's default params, area
# index 0 is "all" and max-dets index 2 is 100. COCOeval stores -1 where a class has no ground
# truth. Those entries are excluded from the mean so they cannot drag AP below its true value;
# a class with no valid entry gets NaN, which plotly draws as a gap.

fig = go.Figure()
for i, (m, name) in enumerate(zip(providers, names)):
    precision = m.coco_precision[:, :, :, 0, 2]
    valid = precision > -1
    n_valid = valid.sum(axis=(0, 1))
    ap_per_class = np.divide(
        np.where(valid, precision, 0.).sum(axis=(0, 1)), n_valid,
        out=np.full(n_valid.shape, np.nan), where=n_valid > 0,
    )
    # Repeat the first vertex so the polygon closes.
    r = np.concatenate([ap_per_class, [ap_per_class[0]]])
    theta = m.cat_names + [m.cat_names[0]]
    fig.add_trace(go.Scatterpolar(r=r, theta=theta, name=name,
                                  mode='lines+markers',
                                  fill='toself',
                                  hovertemplate='%{theta}: %{r:.2f} (AP)',
                                  fillcolor=colors_a[i],))

fig.update_layout(
    polar=dict(
        radialaxis=dict(range=[0, 1]),
        angularaxis=dict(direction="clockwise", rotation=90),
        ),
    title="Per-class Average Precision (AP)",
    width=800, height=800,
)

fig.show()

## Per-class

## Corr-plot