In [None]:
import numpy as np
from matplotlib import pyplot as plt
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from collections import defaultdict
import random

from model_benchmark import metrics, utils
from model_benchmark.metric_provider import MetricProvider, METRIC_NAMES
from model_benchmark import metric_provider

## Loading data

In [None]:
# Input files read by the loading cell below: the two JSON files go through the
# COCO API (COCO(...) / loadRes(...)), the pickle holds the evaluation data.
cocoGt_path, cocoDt_path = "cocoGt.json", "cocoDt.json"
eval_data_path = "eval_data.pkl"

In [None]:
import pickle

# Ground truth first; the detections are then registered against it via loadRes.
cocoGt = COCO(cocoGt_path)
cocoDt = cocoGt.loadRes(cocoDt_path)

# NOTE(review): unpickling can execute arbitrary code - only point
# eval_data_path at files from a trusted source.
with open(eval_data_path, mode='rb') as pkl_file:
    eval_data = pickle.loads(pkl_file.read())

In [None]:
# Re-execute the metric_provider module before use (importlib.reload), so the
# MetricProvider class looked up below comes from the freshly reloaded module.
# NOTE(review): this looks like a development convenience; names imported at
# the top with `from ... import` (e.g. METRIC_NAMES) are not rebound by reload.
from importlib import reload
reload(metric_provider)
# `m` is the metric provider every cell below queries.
m = metric_provider.MetricProvider(eval_data, cocoGt, cocoDt)
# Last expression of the cell: shown as the cell output.
m.base_metrics()

## Overview

In [None]:
# Radar chart of the overall metrics
base_metrics = m.base_metrics()
r = [*base_metrics.values()]
theta = [METRIC_NAMES[metric_key] for metric_key in base_metrics.keys()]

# The first point is appended again at the end so the outline is drawn closed.
radar_trace = go.Scatterpolar(
    r=[*r, r[0]],
    theta=[*theta, theta[0]],
    fill='toself',
    name='Overall Metrics',
    hovertemplate='%{theta}: %{r:.2f}<extra></extra>',
)
fig = go.Figure(data=[radar_trace])
fig.update_layout(
    title="Overall Metrics",
    polar={"radialaxis": {"range": [0., 1.]}},
    width=600,
    height=500,
)
fig.show()

## Model Predictions

In [None]:
# Prediction table from the metric provider. The bare `df` on the last line
# makes it the cell output.
# Note: the name `df` is reassigned by later cells (per-class PR curve,
# per-class outcome chart), so it only refers to this table until then.
df = m.prediction_table()
df

## Outcome Counts

In [None]:
# Single stacked horizontal bar with the TP / FN / FP totals
outcome_bars = (
    ('TP', m.TP_count, '#1fb466'),
    ('FN', m.FN_count, '#dd3f3f'),
    ('FP', m.FP_count, '#d5a5a5'),
)

fig = go.Figure()
for outcome_name, outcome_count, bar_color in outcome_bars:
    fig.add_trace(
        go.Bar(
            x=[outcome_count],
            y=["Outcomes"],
            name=outcome_name,
            orientation='h',
            marker={"color": bar_color},
        )
    )
fig.update_layout(barmode='stack', title="Outcome Counts", width=600, height=300)
fig.update_xaxes(title_text="Count")
fig.show()

## Precision and Recall

In [None]:
# Per-class metrics from the metric provider. The chart cells below read its
# "category", "precision", "recall" and "f1" columns.
per_class_metrics_df = m.per_class_metrics()

In [None]:
# Grouped bars: precision next to recall for every class, ordered by F1
# (sort_values default order, i.e. lowest F1 first).
per_class_metrics_df_sorted = per_class_metrics_df.sort_values(by="f1")

blue_color, orange_color = '#1f77b4', '#ff7f0e'
sorted_categories = per_class_metrics_df_sorted["category"]

fig = go.Figure(data=[
    go.Bar(x=sorted_categories, y=per_class_metrics_df_sorted["precision"],
           name='Precision', marker={"color": blue_color}),
    go.Bar(x=sorted_categories, y=per_class_metrics_df_sorted["recall"],
           name='Recall', marker={"color": orange_color}),
])
fig.update_layout(title="Per-class Precision and Recall (Sorted by F1)", barmode='group')
fig.update_xaxes(title_text="Category")
fig.update_yaxes(title_text="Value", range=[0, 1])
fig.show()

In [None]:
# Per-class precision bars, reusing the F1-sorted frame built in the grouped
# precision/recall cell.
fig = px.bar(
    per_class_metrics_df_sorted,
    x='category',
    y='precision',
    color='precision',
    color_continuous_scale='Plasma',
    title="Per-class Precision (Sorted by F1)",
)

# Value labels are only added for at most 20 classes (presumably to avoid clutter).
show_value_labels = len(per_class_metrics_df_sorted) <= 20
if show_value_labels:
    precision_labels = per_class_metrics_df_sorted["precision"].round(2)
    fig.update_traces(text=precision_labels, textposition='outside')

fig.update_xaxes(title_text="Category")
fig.update_yaxes(title_text="Precision", range=[0, 1])
fig.show()

In [None]:
# Per-class recall bars, reusing the F1-sorted frame built in the grouped
# precision/recall cell.
fig = px.bar(
    per_class_metrics_df_sorted,
    x='category',
    y='recall',
    color='recall',
    color_continuous_scale='Plasma',
    title="Per-class Recall (Sorted by F1)",
)

# Value labels are only added for at most 20 classes (presumably to avoid clutter).
show_value_labels = len(per_class_metrics_df_sorted) <= 20
if show_value_labels:
    recall_labels = per_class_metrics_df_sorted["recall"].round(2)
    fig.update_traces(text=recall_labels, textposition='outside')

fig.update_xaxes(title_text="Category")
fig.update_yaxes(title_text="Recall", range=[0, 1])
fig.show()

## PR-curve

In [None]:
pr_curve = m.pr_curve()

# Precision averaged over the last axis of pr_curve, plotted against
# m.coco_params.recThrs on the x-axis and filled down to y=0.
mean_precision = pr_curve.mean(-1)
pr_trace = go.Scatter(
    x=m.coco_params.recThrs,
    y=mean_precision,
    mode='lines',
    name='PR Curve',
    marker={"color": '#1f77b4'},
    fill='tozeroy',
)

fig = go.Figure(data=[pr_trace])
fig.update_layout(
    title="Precision-Recall Curve",
    xaxis_title="Recall",
    yaxis_title="Precision",
    width=600,
    height=500,
)
fig.show()

In [None]:
# Precision-Recall curve per-class
pr_curve_per_class = m.pr_curve()
# shape (n_recall_thresholds, n_classes)
# Index the rows by the recall thresholds (the same x values the averaged PR
# curve uses) so the "Recall" axis shows recall values instead of the row
# positions 0..n-1 of a default RangeIndex.
df = pd.DataFrame(pr_curve_per_class, index=m.coco_params.recThrs, columns=m.cat_names)

# Wide-form line plot: one line per class column; the unnamed index is exposed
# by plotly express as "index", which the labels mapping renames to "Recall".
fig = px.line(df, x=df.index, y=df.columns, title="Precision-Recall Curve per Class",
              labels={"index": "Recall", "value": "Precision", "variable": "Category"},
              color_discrete_sequence=px.colors.qualitative.Prism, width=800, height=600)
fig.show()

## Confusion Matrix

In [None]:
# Confusion matrix from the metric provider. The heatmap cell below labels both
# of its axes with m.cat_names plus one extra "(None)" entry, and the
# "frequently confused" cell passes it back to m.frequently_confused.
confusion_matrix = m.confusion_matrix()

In [None]:
# Confusion Matrix (colour on a log scale, raw counts in hover/text)
cat_names = m.cat_names
none_name = "(None)"
axis_names = cat_names + [none_name]

# Mask empty cells before taking the log: np.log(0) yields -inf and emits a
# divide-by-zero RuntimeWarning. NaN cells carry no colour value, which is what
# the -inf cells amounted to once the figure was serialised (assumption about
# plotly's handling of non-finite values - confirm visually).
counts = np.asarray(confusion_matrix, dtype=float)
log_counts = np.log(np.where(counts > 0, counts, np.nan))

confusion_matrix_df = pd.DataFrame(log_counts, index=axis_names, columns=axis_names)
fig = px.imshow(confusion_matrix_df, labels=dict(x="Predicted", y="Ground Truth", color="Count"), title="Confusion Matrix (log-scale)",
                width=1000, height=1000)

# Hover text: show the raw count, not the log value used for colouring
fig.update_traces(customdata=confusion_matrix,
                  hovertemplate='Count: %{customdata}<br>Predicted: %{x}<br>Ground Truth: %{y}')

# Text on cells, only for at most 20 classes
if len(cat_names) <= 20:
    fig.update_traces(text=confusion_matrix,
                      texttemplate="%{text}")

fig.show()

## Frequently Confused Classes

In [None]:
# Bar chart of the 20 most frequently confused class pairs
confused_df = m.frequently_confused(confusion_matrix, topk_pairs=20)
confused_name_pairs = confused_df["category_pair"]
confused_prob = confused_df["probability"]

# One "<first> - <second>" label per pair
x_labels = [f"{name_pair[0]} - {name_pair[1]}" for name_pair in confused_name_pairs]

confusion_bars = go.Bar(
    x=x_labels,
    y=confused_prob,
    marker={"color": confused_prob, "colorscale": "Reds"},
    text=confused_prob.round(2),
)
fig = go.Figure(data=[confusion_bars])
fig.update_layout(
    title="Frequently confused class pairs",
    xaxis_title="Class pair",
    yaxis_title="Probability",
)
fig.show()

## IoU Distribution

In [None]:
# Histogram of m.ious with a dashed vertical marker at the mean
nbins = 40
mean_iou = m.ious.mean()
# Height of the mean marker: number of IoU values integer-divided by the
# requested bin count.
y1 = len(m.ious) // nbins

fig = go.Figure(data=[go.Histogram(x=m.ious, nbinsx=nbins)])
fig.update_layout(
    title="IoU Distribution",
    xaxis_title="IoU",
    yaxis_title="Count",
    width=600,
    height=500,
)

mean_line_style = {"color": "orange", "width": 2, "dash": "dash"}
fig.add_shape(type="line", x0=mean_iou, y0=0, x1=mean_iou, y1=y1, line=mean_line_style)
fig.add_annotation(text=f"Mean IoU: {mean_iou:.2f}", x=mean_iou, y=y1, showarrow=False)
fig.show()

## Calibration Score

In [None]:
# Calibration curve (only positive predictions): the model's curve plotted
# against the y = x diagonal that a perfectly calibrated model would follow.
true_probs, pred_probs = m.calibration_metrics.calibration_curve()

model_trace = go.Scatter(
    x=pred_probs,
    y=true_probs,
    mode='lines+markers',
    name='Calibration plot (Model)',
    line={"color": 'blue'},
    marker={"color": 'blue'},
)
diagonal_trace = go.Scatter(
    x=[0, 1],
    y=[0, 1],
    mode='lines',
    name='Perfectly calibrated',
    line={"color": 'orange', "dash": 'dash'},
)

fig = go.Figure(data=[model_trace, diagonal_trace])
unit_range = [0, 1]
fig.update_layout(
    title='Calibration Curve (only positive predictions)',
    xaxis_title='Confidence Score',
    yaxis_title='Fraction of True Positives',
    legend={"x": 0.6, "y": 0.1},
    xaxis={"range": unit_range},
    yaxis={"range": unit_range},
    width=700, height=500
)

fig.show()

In [None]:
# F1 score, Precision, Recall vs Confidence Score
scores_vs_metrics = pd.DataFrame(m.calibration_metrics.scores_vs_metrics_avg())

# Row with the highest F1: its score is reported as the F1-optimal threshold.
idxmax = scores_vs_metrics['f1'].idxmax()
best_threshold = scores_vs_metrics.loc[idxmax, 'scores']
best_f1 = scores_vs_metrics.loc[idxmax, 'f1']

# Plot the frame built above. The previous code passed the unrelated `df`
# left over from an earlier cell, which has no "scores"/"precision"/"recall"/
# "f1" columns.
fig = px.line(scores_vs_metrics, x="scores", y=["precision", "recall", "f1"], title="Performance at Different Confidence Thresholds",
                labels={"scores": "Confidence Threshold", "value": "Value", "variable": "Metric"},
                width=800, height=500)
fig.update_layout(yaxis=dict(range=[0, 1]))

# Add vertical line for best threshold (from the x-axis up to the best F1)
fig.add_shape(type="line", x0=best_threshold, x1=best_threshold, y0=0, y1=best_f1, line=dict(color="orange", width=2, dash="dash"))
# Annotation sits slightly above the top of the line
fig.add_annotation(x=best_threshold, y=best_f1+0.04, text=f"F1-optimal threshold: {best_threshold:.2f}", showarrow=False)
fig.show()

In [None]:
# Overlaid confidence-score histograms for true positives and false positives
scores_tp, scores_fp = m.calibration_metrics.scores_tp_and_fp(iou_idx=0)

histogram_specs = (
    (scores_tp, 'TP', '#1fb466'),
    (scores_fp, 'FP', '#dd3f3f'),
)
fig = go.Figure()
for outcome_scores, outcome_name, bar_color in histogram_specs:
    fig.add_trace(
        go.Histogram(x=outcome_scores, name=outcome_name, marker={"color": bar_color}, opacity=0.7)
    )

fig.update_layout(
    title="Histogram of Confidence Scores (TP vs FP)",
    barmode='overlay',
    width=800,
    height=500,
)
fig.update_xaxes(title_text="Confidence Score", range=[0, 1])
fig.update_yaxes(title_text="Count")
fig.show()

## Per-class

In [None]:
# Per-class Average Precision (AP)
# Fix the 4th axis at 0 and the 5th at 2, then average over the first two axes.
# NOTE(review): presumably the pycocotools precision layout
# (IoU thr, recall thr, class, area range, max dets) - confirm.
precision_slice = m.coco_precision[:, :, :, 0, 2]
ap_per_class = precision_slice.mean(axis=(0, 1))

fig = px.scatter_polar(
    r=ap_per_class,
    theta=m.cat_names,
    range_r=[0, 1],
    title="Per-class Average Precision (AP)",
    labels={"r": "Average Precision", "theta": "Category"},
    width=800,
    height=800,
)
# Fill the area enclosed by the points
fig.update_traces(fill='toself')
fig.show()

In [None]:
# Per-class outcome counts, normalised and ordered for the stacked chart below
iou_thres = 0  # used as a column index after summing over axis 1

tp, fp, fn = (
    outcome.sum(1)[:, iou_thres]
    for outcome in (m.true_positives, m.false_positives, m.false_negatives)
)

# Express every outcome relative to the support (TP + FN)
support = tp + fn
tp_rel, fp_rel, fn_rel = tp / support, fp / support, fn / support

# F1 = 2TP / (2TP + FP + FN) is the sort key
sort_scores = 2 * tp / (2 * tp + fp + fn)

K = len(m.cat_names)
sort_indices = np.argsort(sort_scores)
cat_names_sorted = [m.cat_names[idx] for idx in sort_indices]
tp_rel, fn_rel, fp_rel = (
    relative[sort_indices] for relative in (tp_rel, fn_rel, fp_rel)
)

In [None]:
# Stacked per-class outcomes. The plotted values are tp_rel / fn_rel / fp_rel,
# i.e. counts divided by the class support (TP + FN), not raw counts - the
# title and axis label say so explicitly.
data = {
    # Long-form layout: K TP rows, then K FN rows, then K FP rows
    "count": np.concatenate([tp_rel, fn_rel, fp_rel]),
    "type": ["TP"]*K + ["FN"]*K + ["FP"]*K,
    "category": cat_names_sorted*3
}

df = pd.DataFrame(data)

# Same outcome colours as the overall "Outcome Counts" chart
color_map = {
    'TP': '#1fb466',
    'FN': '#dd3f3f',
    'FP': '#d5a5a5'
}
fig = px.bar(df, x="category", y="count", color="type", title="Per-class Outcome Counts (normalized by support)",
             labels={'count': 'Count / Support', "category": "Category"},
             color_discrete_map=color_map)

fig.show()

In [None]:
# 2-D histogram of (TP count, FP+FN count) after summing the outcome arrays
# over axis 0, drawn as an RGB image: hue encodes the error fraction of a
# cell, brightness encodes how many items fall into it.
from matplotlib import cm  # no longer used here; kept so the name `cm` stays bound for any other cell

t = 0  # column index taken after summing over axis 0
tp = m.true_positives.sum(0)[:,t]
fp = m.false_positives.sum(0)[:,t]
fn = m.false_negatives.sum(0)[:,t]
errors = fp + fn

# Unit-width bins centred on the integer counts. histogram2d puts its first
# argument (TP) on the rows and the second (FP+FN) on the columns, so
# y_edges belong to TP and x_edges to FP+FN.
y_edges = np.arange(min(tp) - 0.5, max(tp) + 1.5, 1)
x_edges = np.arange(min(errors) - 0.5, max(errors) + 1.5, 1)
heatmap, y_edges, x_edges = np.histogram2d(tp, errors, bins=(y_edges, x_edges))

z_max = np.max(heatmap)
gamma = 0.95

colors = np.zeros((heatmap.shape[0], heatmap.shape[1], 3))  # for RGB channels
colormap_name = 'RdYlGn_r'
# matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap is
# available in old and new releases alike.
cmap = plt.get_cmap(colormap_name)
log_z_max = np.log(z_max)

for i in range(heatmap.shape[0]):
    for j in range(heatmap.shape[1]):
        tp_count = y_edges[i] + 0.5      # bin centre on the TP (row) axis
        error_count = x_edges[j] + 0.5   # bin centre on the FP+FN (column) axis

        intensity = heatmap[i, j]
        total = tp_count + error_count
        # Error fraction: 0 maps to the green end and 1 to the red end of the
        # reversed RdYlGn colormap.
        value = error_count / total if total > 0 else 0
        color = cmap(value)

        # Adjust the color intensity based on the heatmap value
        if intensity > 0:
            # Log-scaled brightness relative to the fullest cell, floored at 0.2.
            # When the fullest cell holds a single item log(z_max) is 0; use
            # full brightness there instead of dividing 0 by 0.
            brightness = np.log(intensity) / log_z_max if log_z_max > 0 else 1.0
            c = np.array(color[:3]) * max(0.2, brightness)
            colors[i, j, :] = c**gamma
        else:
            # Empty cells keep a faint tint of their hue
            colors[i, j, :] = np.array(color[:3]) * 0.12

# Plot the colored heatmap
# NOTE(review): the image axes are pixel indices; they equal the actual counts
# only when the smallest TP and FP+FN counts are both 0.
fig = px.imshow(colors, labels=dict(x="Count of Errors", y="Count of True Predictions"), title="TP vs FP+FN", text_auto=True, origin='lower',
                width=1000, height=1000)

# Adding text to each pixel: the number of items that fall into the cell
for i, j in np.ndindex(heatmap.shape):
    fig.add_annotation(
        x=j,
        y=i,
        text=str(int(heatmap[i, j])),
        showarrow=False,
        font=dict(color="#ddd", size=10)
    )

fig.show()