In [None]:
import pickle
import random
from collections import defaultdict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params

from model_benchmark import metrics, utils

## Loading data

In [None]:
# Ground truth and detections in COCO format.
cocoGt = COCO("cocoGt.json")
cocoDt = cocoGt.loadRes("cocoDt.json")

# Pre-computed evaluation results consumed by `Metrics` below.
# NOTE: pickle.load can execute arbitrary code — only load eval_data.pkl from a trusted source.
with open("eval_data.pkl", "rb") as f:
    eval_data = pickle.load(f)

In [None]:
import warnings

from sklearn.metrics import log_loss, brier_score_loss
from sklearn.calibration import calibration_curve
import model_benchmark.metrics as metrics


def get_outcomes_per_image(matches, cocoGt: COCO):
    """Count TP / FP / FN matches for every ground-truth image.

    Args:
        matches: iterable of match dicts with ``image_id`` and ``type`` keys.
        cocoGt: ground-truth COCO object; its image ids define the rows.

    Returns:
        ``(img_ids, outcomes_per_image)`` where ``img_ids`` is the sorted list of
        image ids and ``outcomes_per_image`` is a float array of shape
        ``(len(img_ids), 3)`` with columns ``[TP, FP, FN]``. Matches of any other
        type are ignored (but their ``image_id`` must still be a known image).
    """
    img_ids = sorted(cocoGt.getImgIds())
    row_of = dict(zip(img_ids, range(len(img_ids))))
    column_of = {"TP": 0, "FP": 1, "FN": 2}

    counts = np.zeros((len(img_ids), 3), dtype=float)
    for match in matches:
        row = row_of[match["image_id"]]
        column = column_of.get(match["type"])
        if column is not None:
            counts[row, column] += 1
    return img_ids, counts


class Metrics:
    """Detection metrics derived from pre-computed COCO evaluation data.

    Args:
        eval_data: dict with the keys ``true_positives``, ``false_positives``,
            ``false_negatives`` (outcome count arrays), ``matches`` (list of match
            dicts), ``coco_stats``, ``coco_precision`` and ``coco_params``.
        cocoGt: ground-truth COCO object; provides category ids/names, image file
            names and the GT annotations looked up by ``confusion_matrix``.
        cocoDt: detections COCO object; accepted but not used by any method.

    NOTE(review): usage in this notebook suggests the outcome arrays are laid out
    as (class, image, IoU threshold) — ``.sum(1)`` yields per-class values and
    ``[..., 0]`` the first threshold. Confirm against the producer of eval_data.
    """

    def __init__(self, eval_data: dict, cocoGt: COCO, cocoDt: COCO):
        # Keep the GT object so methods do not silently depend on a global `cocoGt`.
        self.cocoGt = cocoGt

        # eval_data
        self.true_positives = eval_data["true_positives"]
        self.false_positives = eval_data["false_positives"]
        self.false_negatives = eval_data["false_negatives"]
        self.matches = eval_data["matches"]
        self.coco_stats = eval_data["coco_stats"]
        self.coco_precision = eval_data["coco_precision"]
        self.coco_params : Params = eval_data["coco_params"]

        # Total counts, taken at index 0 of the last axis
        self.TP_count = int(self.true_positives[...,0].sum())
        self.FP_count = int(self.false_positives[...,0].sum())
        self.FN_count = int(self.false_negatives[...,0].sum())

        # Matches split by outcome; "confused" = FP predictions flagged with miss_cls
        self.tp_matches = [m for m in self.matches if m['type'] == "TP"]
        self.fp_matches = [m for m in self.matches if m['type'] == "FP"]
        self.fn_matches = [m for m in self.matches if m['type'] == "FN"]
        self.confused_matches = [m for m in self.fp_matches if m['miss_cls']]
        self.fp_not_confused_matches = [m for m in self.fp_matches if not m['miss_cls']]
        # IoU of every match whose 'iou' is truthy (None and 0.0 are skipped)
        self.ious = np.array([m['iou'] for m in self.matches if m['iou']])

        # Calibration
        self.calibration_metrics = CalibrationMetrics(self.tp_matches, self.fp_matches, self.fn_matches, self.coco_params.iouThrs)

        # info
        self.cat_ids = cocoGt.getCatIds()
        self.cat_names = [cocoGt.cats[cat_id]['name'] for cat_id in self.cat_ids]

    def base_metrics(self):
        """Return the headline metrics.

        Returns:
            dict with ``mAP`` (``coco_stats[0]``), ``precision`` and ``recall``
            (averaged over the per-class, per-last-axis cells of the summed
            outcome arrays), ``iou`` (mean of ``self.ious``),
            ``classification_accuracy`` (TP / (TP + class-confused FP)) and
            ``calibration_score`` (1 - maximum calibration error).

        Cells where precision or recall is 0 / 0 (a class with no predictions or
        no ground truth) are skipped via nanmean rather than turning the whole
        average into NaN.
        """
        tp = self.true_positives.sum(1)
        fp = self.false_positives.sum(1)
        fn = self.false_negatives.sum(1)
        confuse_count = len(self.confused_matches)

        mAP = self.coco_stats[0]
        with np.errstate(divide="ignore", invalid="ignore"):
            precision = np.nanmean(tp / (tp + fp))
            recall = np.nanmean(tp / (tp + fn))
        iou = np.mean(self.ious)
        classification_accuracy = self.TP_count / (self.TP_count + confuse_count)
        calibration_score = 1 - self.calibration_metrics.maximum_calibration_error()

        return {
            "mAP": mAP,
            "precision": precision,
            "recall": recall,
            "iou": iou,
            "classification_accuracy": classification_accuracy,
            "calibration_score": calibration_score
        }

    def per_class_metrics(self):
        """Per-class precision, recall and F1 (counts averaged over the last axis).

        Returns:
            DataFrame with columns ``category``, ``precision``, ``recall``, ``f1``
            in ``self.cat_ids`` order; undefined ratios are NaN.
        """
        tp = self.true_positives.sum(1).mean(1)
        fp = self.false_positives.sum(1).mean(1)
        fn = self.false_negatives.sum(1).mean(1)
        with np.errstate(divide="ignore", invalid="ignore"):
            pr = tp / (tp + fp)
            rc = tp / (tp + fn)
            f1 = 2 * pr * rc / (pr + rc)
        return pd.DataFrame({
            "category": self.cat_names,
            "precision": pr,
            "recall": rc,
            "f1": f1
        })

    def pr_curve(self):
        """Per-class precision curve: ``coco_precision[:, :, :, 0, 2]`` averaged over axis 0.

        The result has one row per recall threshold (``coco_params.recThrs``) and
        one column per category. COCOeval stores -1 for categories without ground
        truth, so such columns come out as exactly -1; callers should mask them.
        """
        pr_curve = self.coco_precision[:,:,:,0,2].mean(0)
        return pr_curve

    def prediction_table(self):
        """Per-image table of GT/detection counts, TP/FP/FN and precision/recall/F1."""
        img_ids, outcomes_per_image = get_outcomes_per_image(self.matches, self.cocoGt)
        image_names = [self.cocoGt.imgs[img_id]["file_name"] for img_id in img_ids]
        # inference_time = ...
        n_gt = outcomes_per_image[:,0] + outcomes_per_image[:,2]
        n_dt = outcomes_per_image[:,0] + outcomes_per_image[:,1]
        # Images without GT or without detections produce 0 / 0 -> NaN; silence the warnings.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            precision_per_image = outcomes_per_image[:,0] / n_dt
            recall_per_image = outcomes_per_image[:,0] / n_gt
            f1_per_image = 2 * precision_per_image * recall_per_image / (precision_per_image + recall_per_image)
        prediction_table = pd.DataFrame({
            "image_name": image_names,
            "N gt": n_gt,
            "N dt": n_dt,
            "TP": outcomes_per_image[:,0],
            "FP": outcomes_per_image[:,1],
            "FN": outcomes_per_image[:,2],
            "Precision": precision_per_image,
            "Recall": recall_per_image,
            "F1": f1_per_image
            })
        return prediction_table

    def confusion_matrix(self):
        """Build a (K + 1, K + 1) integer confusion matrix.

        Rows index the predicted class, columns the ground-truth class, both in
        ``self.cat_ids`` order. The extra last column counts predictions with no
        GT match (non-confused FP); the extra last row counts missed GT (FN).
        """
        K = len(self.cat_ids)
        catId2idx = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}

        confusion_matrix = np.zeros((K+1, K+1), dtype=int)

        # Class confusions: predicted one class, matched a GT of another class
        for m in self.confused_matches:
            cat_idx_pred = catId2idx[m['category_id']]
            cat_idx_gt = catId2idx[self.cocoGt.anns[m['gt_id']]['category_id']]
            confusion_matrix[cat_idx_pred, cat_idx_gt] += 1

        for m in self.tp_matches:
            cat_idx = catId2idx[m['category_id']]
            confusion_matrix[cat_idx, cat_idx] += 1

        for m in self.fp_not_confused_matches:
            cat_idx_pred = catId2idx[m['category_id']]
            confusion_matrix[cat_idx_pred, -1] += 1

        for m in self.fn_matches:
            cat_idx_gt = catId2idx[m['category_id']]
            confusion_matrix[-1, cat_idx_gt] += 1

        return confusion_matrix

    def frequently_confused(self, confusion_matrix, topk_pairs=20):
        """Rank unordered class pairs by how often they are confused with each other.

        Args:
            confusion_matrix: matrix returned by ``confusion_matrix()``
                (rows = predicted, columns = ground truth, last row/column = none).
            topk_pairs: how many pairs with the largest raw confusion count to keep
                before re-ranking them by probability.

        Returns:
            DataFrame with ``category_pair``, ``category_id_pair``, ``count`` and
            ``probability`` columns, sorted by descending probability.
        """
        idx2catId = {i: cat_id for i, cat_id in enumerate(self.cat_ids)}
        cm = confusion_matrix[:-1,:-1]
        # Fold A->B and B->A confusions into the strictly lower triangle.
        cm_l = np.tril(cm, -1)
        cm_u = np.triu(cm, 1)
        cm = cm_l + cm_u.T
        cm_flat = cm.flatten()
        inds_sort = np.argsort(-cm_flat)[:topk_pairs]
        inds_sort = inds_sort[cm_flat[inds_sort] > 0]  # remove zeros
        inds_sort = np.unravel_index(inds_sort, cm.shape)

        # probability of confusion: (predicted A, actually B + predicted B, actually A) / (predicted A + predicted B)
        confused_counts = cm[inds_sort]
        dt_total = confusion_matrix.sum(1)
        dt_pair_sum = np.array([dt_total[i] + dt_total[j] for i, j in zip(*inds_sort)])
        confused_prob = confused_counts / dt_pair_sum
        inds_sort2 = np.argsort(-confused_prob)

        confused_idxs = np.array(inds_sort).T[inds_sort2]
        confused_name_pairs = [(self.cat_names[i], self.cat_names[j]) for i, j in confused_idxs]
        confused_counts = confused_counts[inds_sort2]
        confused_prob = confused_prob[inds_sort2]
        confused_catIds = [(idx2catId[i], idx2catId[j]) for i, j in confused_idxs]

        return pd.DataFrame({
            "category_pair": confused_name_pairs,
            "category_id_pair": confused_catIds,
            "count": confused_counts,
            "probability": confused_prob
        })

    def iou_histogram(self, bins=10):
        """Histogram of ``self.ious`` over [0.5, 1]; returns ``(counts, bin_edges)`` from np.histogram."""
        iou_hist = np.histogram(self.ious, bins=bins, range=(0.5, 1))
        return iou_hist
    
    
class CalibrationMetrics:
    """Score-sorted positive predictions for calibration and score-threshold analysis.

    Every TP/FP match contributes its confidence score, its category id and an
    IoU level ``k``: the number of entries of ``iouThrs`` that are <= the match's
    IoU (0 for FPs and for TPs without an IoU). A prediction counts as a true
    positive at threshold index ``i`` exactly when ``k > i``.

    Args:
        tp_matches, fp_matches: lists of match dicts with ``type``, ``iou``,
            ``score`` and ``category_id`` keys.
        fn_matches: list of match dicts with a ``category_id`` key; only used to
            count ground-truth positives per class.
        iouThrs: ascending IoU thresholds (required by ``np.searchsorted``).
    """

    def __init__(self, tp_matches, fp_matches, fn_matches, iouThrs):
        scores = []
        classes = []
        iou_idxs = []
        per_class_count = defaultdict(int)  # ground-truth positives per class: TP + FN
        for m in tp_matches + fp_matches:
            if m['type'] == "TP" and m['iou'] is not None:
                # side="right" counts thresholds <= iou, so an IoU lying exactly on a
                # threshold (e.g. 0.5) passes it, as COCOeval accepts iou >= threshold.
                iou_idx = int(np.searchsorted(iouThrs, m['iou'], side="right"))
                assert iou_idx > 0, f"TP match with IoU {m['iou']} is below the lowest threshold"
                iou_idxs.append(iou_idx)
            else:
                iou_idxs.append(0)
            scores.append(m['score'])
            classes.append(m["category_id"])
            if m['type'] == "TP":
                per_class_count[m["category_id"]] += 1
        for m in fn_matches:
            per_class_count[m["category_id"]] += 1

        # Sort every prediction by descending confidence.
        scores = np.array(scores)
        order = np.argsort(-scores)
        self.scores = scores[order]
        self.classes = np.array(classes)[order]
        self.iou_idxs = np.array(iou_idxs)[order]
        self.per_class_count = dict(per_class_count)

        # True-positive labels at the lowest IoU threshold (level > 0).
        self.y_true = self.iou_idxs > 0

    def scores_vs_metrics(self, iou_idx=0, cat_id=None):
        """Precision, recall and F1 when thresholding at each prediction's score.

        Args:
            iou_idx: index into ``iouThrs``; a prediction is a TP when its IoU level
                is greater than this index.
            cat_id: restrict to one category id; ``None`` uses all predictions.

        Returns:
            dict of equally long arrays ``scores`` (descending), ``precision``,
            ``recall`` and ``f1``; entry ``i`` describes keeping the top ``i + 1``
            predictions. Undefined ratios (no positives, or P = R = 0 for F1) are NaN.
        """
        tps = self.iou_idxs > iou_idx
        if cat_id is not None:
            cls_mask = self.classes == cat_id
            tps = tps[cls_mask]
            scores = self.scores[cls_mask]
            # A category with neither TP nor FN never entered per_class_count.
            n_positives = self.per_class_count.get(cat_id, 0)
        else:
            scores = self.scores
            n_positives = sum(self.per_class_count.values())
        fps = ~tps

        tps_sum = tps.cumsum()
        fps_sum = fps.cumsum()

        with np.errstate(divide="ignore", invalid="ignore"):
            precision = tps_sum / (tps_sum + fps_sum)
            recall = tps_sum / n_positives
            f1 = 2 * precision * recall / (precision + recall)
        return {
            "scores": scores,
            "precision": precision,
            "recall": recall,
            "f1": f1
        }

    def calibration_curve(self):
        """Return ``(true_probs, pred_probs)`` from sklearn's calibration_curve with 10 bins."""
        true_probs, pred_probs = calibration_curve(self.y_true, self.scores, n_bins=10)
        return true_probs, pred_probs

    def maximum_calibration_error(self):
        """Maximum calibration error of ``self.scores`` against ``self.y_true`` (10 bins)."""
        return metrics.maximum_calibration_error(self.y_true, self.scores, n_bins=10)

In [None]:
# Build the metrics object once; every later cell reads from `m`.
m = Metrics(eval_data=eval_data, cocoGt=cocoGt, cocoDt=cocoDt)
m.base_metrics()

In [None]:
# Human-readable labels for the keys returned by Metrics.base_metrics()
metric_names = dict(
    mAP="mAP",
    precision="Precision",
    recall="Recall",
    iou="IoU",
    classification_accuracy="Classification Accuracy",
    calibration_score="Calibration Score",
)

## Overview

In [None]:
# Radar chart of the headline metrics, radial axis fixed to [0, 1]
base_metrics = m.base_metrics()
theta = [metric_names[key] for key in base_metrics]
r = [base_metrics[key] for key in base_metrics]
fig = px.line_polar(
    r=r,
    theta=theta,
    line_close=True,
    title="Overall Metrics",
    width=600,
    height=500,
)
fig.update_traces(fill="toself")
fig.update_layout(polar={"radialaxis": {"range": [0.0, 1.0]}})
fig.show()

## Model Predictions

In [None]:
# Per-image outcome table; a dedicated name instead of the generic `df` reused by later cells.
prediction_df = m.prediction_table()
prediction_df

## Outcome Counts

In [None]:
# Overall TP / FN / FP counts as one stacked horizontal bar
TP_count, FN_count, FP_count = m.TP_count, m.FN_count, m.FP_count

fig = go.Figure()
for outcome, count, color in (
    ("TP", TP_count, "#1fb466"),
    ("FN", FN_count, "#dd3f3f"),
    ("FP", FP_count, "#d5a5a5"),
):
    fig.add_trace(go.Bar(x=[count], y=["Outcomes"], name=outcome, orientation="h", marker={"color": color}))
fig.update_layout(barmode="stack", title="Outcome Counts", width=600, height=300)
fig.update_xaxes(title_text="Count")
fig.show()

## Per-class Precision and Recall

In [None]:
# Per-class precision / recall / F1 table (one row per category), reused by the charts below.
per_class_metrics_df = m.per_class_metrics()

In [None]:
# Grouped precision / recall bars per class, classes ordered by ascending F1
per_class_metrics_df_sorted = per_class_metrics_df.sort_values(by="f1")

blue_color = '#1f77b4'
orange_color = '#ff7f0e'
categories = per_class_metrics_df_sorted["category"]
fig = go.Figure()
for metric, label, color in (("precision", "Precision", blue_color), ("recall", "Recall", orange_color)):
    fig.add_trace(go.Bar(y=per_class_metrics_df_sorted[metric], x=categories, name=label, marker=dict(color=color)))
fig.update_layout(barmode='group', title="Per-class Precision and Recall (Sorted by F1)")
fig.update_xaxes(title_text="Category")
fig.update_yaxes(title_text="Value", range=[0, 1])
fig.show()

In [None]:
# Per-class precision, in the same class order as above (ascending F1)
fig = px.bar(per_class_metrics_df_sorted, x='category', y='precision', title="Per-class Precision (Sorted by F1)",
             color='precision', color_continuous_scale='Plasma')
if len(per_class_metrics_df_sorted) <= 20:
    # Value labels only for few classes; cliponaxis=False keeps labels above bars near 1.0 visible.
    fig.update_traces(text=per_class_metrics_df_sorted["precision"].round(2), textposition='outside', cliponaxis=False)
fig.update_xaxes(title_text="Category")
fig.update_yaxes(title_text="Precision", range=[0, 1])
fig.show()

In [None]:
# Per-class recall bar chart, in the same class order as above (ascending F1)
fig = px.bar(per_class_metrics_df_sorted, x='category', y='recall', title="Per-class Recall (Sorted by F1)",
             color='recall', color_continuous_scale='Plasma')
if len(per_class_metrics_df_sorted) <= 20:
    # Value labels only for few classes; cliponaxis=False keeps labels above bars near 1.0 visible.
    fig.update_traces(text=per_class_metrics_df_sorted["recall"].round(2), textposition='outside', cliponaxis=False)
fig.update_xaxes(title_text="Category")
fig.update_yaxes(title_text="Recall", range=[0, 1])
fig.show()

## PR-curve

In [None]:
pr_curve = m.pr_curve()

# COCOeval fills precision with -1 for categories that have no ground truth; mask
# those so the class-averaged curve is not pulled below the true values.
mean_precision = np.nanmean(np.where(pr_curve == -1, np.nan, pr_curve), axis=-1)

fig = go.Figure()
fig.add_trace(go.Scatter(x=m.coco_params.recThrs, y=mean_precision, mode='lines', name='PR Curve', marker=dict(color='#1f77b4')))
fig.update_layout(title="Precision-Recall Curve", xaxis_title="Recall", yaxis_title="Precision")
fig.update_traces(fill='tozeroy')
fig.update_layout(width=600, height=500)
fig.show()

In [None]:
# Precision-Recall curve per class. pr_curve() has shape (n_recall_thresholds, n_classes)
# with rows aligned to coco_params.recThrs, so use those recall values as the index
# (not the 0..N row number). -1 marks categories without ground truth -> NaN (not drawn).
pr_curve_per_class = m.pr_curve()
pr_per_class_df = pd.DataFrame(
    np.where(pr_curve_per_class == -1, np.nan, pr_curve_per_class),
    index=m.coco_params.recThrs,
    columns=m.cat_names,
)

fig = px.line(pr_per_class_df, x=pr_per_class_df.index, y=pr_per_class_df.columns, title="Precision-Recall Curve per Class",
              labels={"index": "Recall", "value": "Precision", "variable": "Category"},
              color_discrete_sequence=px.colors.qualitative.Prism, width=800, height=600)
fig.show()

## Confusion Matrix

In [None]:
# (K + 1) x (K + 1) counts: rows = predicted class, columns = ground-truth class, last row/column = none.
confusion_matrix = m.confusion_matrix()

In [None]:
# Confusion matrix plot. Metrics.confusion_matrix() puts predicted classes on rows (y)
# and ground-truth classes on columns (x); the "(None)" row counts missed GT (FN) and
# the "(None)" column counts predictions without a GT match (FP).
cat_names = m.cat_names
none_name = "(None)"
axis_names = cat_names + [none_name]

# Zero cells become NaN (drawn blank) instead of log(0) = -inf with a divide-by-zero warning.
log_counts = np.log(np.where(confusion_matrix > 0, confusion_matrix, np.nan))
confusion_matrix_df = pd.DataFrame(log_counts, index=axis_names, columns=axis_names)
fig = px.imshow(confusion_matrix_df, labels=dict(x="Ground Truth", y="Predicted", color="log(Count)"),
                title="Confusion Matrix (log-scale)", width=1000, height=1000)

# Hover text shows the raw count
fig.update_traces(customdata=confusion_matrix,
                  hovertemplate='Count: %{customdata}<br>Predicted: %{y}<br>Ground Truth: %{x}')

# Raw counts on cells, only when there are few enough classes to read them
if len(cat_names) <= 20:
    fig.update_traces(text=confusion_matrix,
                      texttemplate="%{text}")

fig.show()

## Frequently Confused Classes

In [None]:
# Most frequently confused class pairs, ranked by probability of confusion
confused_df = m.frequently_confused(confusion_matrix, topk_pairs=20)
confused_name_pairs = confused_df["category_pair"]
confused_prob = confused_df["probability"]
x_labels = [f"{first} - {second}" for first, second in confused_name_pairs]

fig = go.Figure(
    go.Bar(
        x=x_labels,
        y=confused_prob,
        marker=dict(color=confused_prob, colorscale="Reds"),
        text=confused_prob.round(2),
    )
)
fig.update_layout(title="Frequently confused class pairs", xaxis_title="Class pair", yaxis_title="Probability")
fig.show()

## IoU Distribution

In [None]:
# (counts, bin_edges) as returned by np.histogram over match IoUs in [0.5, 1]
iou_hist = m.iou_histogram()

In [None]:
# np.histogram returns one more edge than counts; place each bar at its bin centre
# with the bin's width so bars line up with the IoU range they count.
iou_counts, iou_bin_edges = iou_hist
iou_bin_centers = (iou_bin_edges[:-1] + iou_bin_edges[1:]) / 2

fig = go.Figure()
fig.add_trace(go.Bar(x=iou_bin_centers, y=iou_counts, width=np.diff(iou_bin_edges)))
fig.update_layout(title="IoU Distribution", xaxis_title="IoU", yaxis_title="Count")
fig.show()

## Calibration Score

In [None]:
true_probs, pred_probs = m.calibration_metrics.calibration_curve()

# Calibration curve: fraction of true positives per confidence bin vs confidence score
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(pred_probs, true_probs, marker='o', linewidth=1, label='Calibration plot (Model)')
ax.plot([0, 1], [0, 1], linestyle='--', label='Perfectly calibrated')
ax.set_xlabel('Confidence Score')
ax.set_ylabel('Fraction of True Positives')
ax.set_title('Calibration Curve')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Precision, recall and F1 as functions of the confidence threshold
# (all classes, default iou_idx=0 of scores_vs_metrics)
scores_vs_metrics = pd.DataFrame(m.calibration_metrics.scores_vs_metrics())

fig = px.line(
    scores_vs_metrics,
    x="scores",
    y=["precision", "recall", "f1"],
    title="F1 score, Precision, Recall vs Confidence Score",
    labels={"scores": "Confidence Score", "value": "Value", "variable": "Metric"},
)
fig.show()


In [None]:
# Confidence-score distributions of TP vs FP predictions (TP = IoU level > 0)
tps = m.calibration_metrics.iou_idxs > 0
scores = m.calibration_metrics.scores
scores_tp = scores[tps]
scores_fp = scores[~tps]

fig = go.Figure()
# Overlaid histograms are binned independently unless they share a bingroup;
# sharing it keeps bar heights directly comparable.
fig.add_trace(go.Histogram(x=scores_tp, name='TP', marker=dict(color='#1fb466'), opacity=0.7, bingroup="scores"))
fig.add_trace(go.Histogram(x=scores_fp, name='FP', marker=dict(color='#dd3f3f'), opacity=0.7, bingroup="scores"))
fig.update_layout(barmode='overlay', title="Histogram of Confidence Scores (TP vs FP)")
fig.update_xaxes(title_text="Confidence Score")
# No histnorm is set, so bar heights are raw counts, not densities.
fig.update_yaxes(title_text="Count")
fig.show()

## Per-class

In [None]:
# Per-class AP: coco_precision averaged over its first two axes, using the same
# [..., 0, 2] slice as Metrics.pr_curve.
ap_per_class = m.coco_precision[:, :, :, 0, 2].mean(axis=(0, 1))
# COCOeval leaves precision at -1 for categories without ground truth; drop them
# rather than plotting a meaningless negative AP outside the [0, 1] range.
has_gt = ap_per_class != -1
fig = px.scatter_polar(r=ap_per_class[has_gt],
                       theta=[name for name, keep in zip(m.cat_names, has_gt) if keep],
                       title="Per-class Average Precision (AP)",
                       labels=dict(r="Average Precision", theta="Category"),
                       width=800, height=800,
                       range_r=[0, 1])
# fill the polygon through the points
fig.update_traces(fill='toself')
fig.show()

In [None]:
# Per-class outcome counts at index 0 of the last axis
# (presumably the lowest IoU threshold — confirm against eval_data)
iou_thres = 0

tp, fp, fn = (
    outcomes.sum(1)[:, iou_thres]
    for outcomes in (m.true_positives, m.false_positives, m.false_negatives)
)

# Express every outcome relative to the class's ground-truth count (TP + FN)
support = tp + fn
tp_rel, fp_rel, fn_rel = tp / support, fp / support, fn / support

# Order classes by F1 = 2TP / (2TP + FP + FN), lowest first
sort_scores = 2 * tp / (2 * tp + fp + fn)
sort_indices = np.argsort(sort_scores)

K = len(m.cat_names)
cat_names_sorted = [m.cat_names[i] for i in sort_indices]
tp_rel = tp_rel[sort_indices]
fn_rel = fn_rel[sort_indices]
fp_rel = fp_rel[sort_indices]

In [None]:
# Stacked per-class outcomes. The previous cell divides every value by the class's
# ground-truth count (TP + FN), so the bars are ratios, not raw counts.
per_class_outcomes_df = pd.DataFrame({
    "count": np.concatenate([tp_rel, fn_rel, fp_rel]),
    "type": ["TP"] * K + ["FN"] * K + ["FP"] * K,
    "category": cat_names_sorted * 3,
})

color_map = {
    'TP': '#1fb466',
    'FN': '#dd3f3f',
    'FP': '#d5a5a5'
}
fig = px.bar(per_class_outcomes_df, x="category", y="count", color="type", title="Per-class Outcome Counts",
             labels={'count': 'Count relative to GT (TP + FN)', "category": "Category"},
             color_discrete_map=color_map)

fig.show()

In [None]:
# TP vs FP+FN: 2-D histogram of (true-positive count, error count) pairs, summed over
# the first (class) axis — presumably one pair per image; confirm against eval_data.
# Each cell is coloured by its error share and shaded by how many items fall into it.
t = 0  # index along the last axis of the outcome arrays, as in the other cells
tp_per_item = m.true_positives.sum(0)[:, t]
errors_per_item = m.false_positives.sum(0)[:, t] + m.false_negatives.sum(0)[:, t]

# Unit-width bins centred on every integer count
tp_edges = np.arange(tp_per_item.min() - 0.5, tp_per_item.max() + 1.5, 1)
err_edges = np.arange(errors_per_item.min() - 0.5, errors_per_item.max() + 1.5, 1)
# heatmap[i, j] = number of items whose TP count is in bin i and error count in bin j
heatmap, tp_edges, err_edges = np.histogram2d(tp_per_item, errors_per_item, bins=(tp_edges, err_edges))

z_max = np.max(heatmap)
gamma = 0.95
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains available.
cmap = plt.get_cmap('RdYlGn_r')

colors = np.zeros((heatmap.shape[0], heatmap.shape[1], 3))  # RGB per cell
for i in range(heatmap.shape[0]):
    for j in range(heatmap.shape[1]):
        tp_val = tp_edges[i] + 0.5    # bin centre = integer TP count
        err_val = err_edges[j] + 0.5  # bin centre = integer FP+FN count
        total = tp_val + err_val
        # Share of errors: 0 -> green, 1 -> red with the reversed RdYlGn colormap
        error_share = err_val / total if total > 0 else 0
        rgb = np.array(cmap(error_share)[:3])

        count = heatmap[i, j]
        if count > 0:
            # Brightness grows with log(count); guard z_max == 1, where log(z_max) == 0.
            brightness = np.log(count) / np.log(z_max) if z_max > 1 else 1.0
            colors[i, j, :] = (rgb * max(0.2, brightness)) ** gamma
        else:
            colors[i, j, :] = rgb * 0.12

# Plot the colored heatmap
fig = px.imshow(colors, labels=dict(x="Count of Errors", y="Count of True Predictions"), title="TP vs FP+FN", text_auto=True, origin='lower',
                width=1000, height=1000)

# Raw item count on every cell
for i in range(heatmap.shape[0]):
    for j in range(heatmap.shape[1]):
        fig.add_annotation(
            x=j,
            y=i,
            text=str(int(heatmap[i, j])),
            showarrow=False,
            font=dict(color="#ddd", size=10)
        )

fig.show()