In [None]:
from __future__ import annotations

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch_imp import HistogramBinning

In [None]:
# Load the saved model outputs and ground-truth labels (.npy files in the
# working directory) and wrap each one as a torch tensor for the calibrator.
y_pred, binary_labels = (
    torch.tensor(np.load(path)) for path in ("y_pred.npy", "binary_labels.npy")
)

print("Loaded shapes:", y_pred.shape, binary_labels.shape)

In [None]:
# Split the data: [CAL_START, CAL_END) is used to fit the calibrator and
# everything from CAL_END onwards is held out for evaluation. Samples before
# CAL_START are not used in this notebook
# (NOTE(review): presumably reserved for training the base model — confirm).
CAL_START, CAL_END = 1500, 1750

# Fail loudly instead of silently producing misaligned or empty splits.
assert len(y_pred) == len(binary_labels), "predictions and labels are misaligned"
assert len(y_pred) > CAL_END, f"need more than {CAL_END} samples, got {len(y_pred)}"

cal_preds = y_pred[CAL_START:CAL_END]
cal_labels = binary_labels[CAL_START:CAL_END]

test_preds = y_pred[CAL_END:]
test_labels = binary_labels[CAL_END:]

print("Calibration size:", len(cal_preds))
print("Test size:", len(test_preds))

In [None]:
# Fit histogram-binning calibration on the calibration split (CPU only, no
# base model attached), then map the raw test-set predictions through it.
device = torch.device("cpu")
calibrator = HistogramBinning(device=device, base_model=None)
calibrator.fit(cal_preds, cal_labels)
test_preds_calibrated = calibrator.predict(test_preds)

In [None]:
def compute_ece(preds, labels, n_bins=15):  # noqa: ANN001, ANN201
    """Compute the Expected Calibration Error (ECE) of probability predictions.

    Predictions are grouped into ``n_bins`` equal-width bins over [0, 1]. The
    ECE is the bin-size-weighted average of
    |mean prediction in bin - mean label in bin|.

    Args:
        preds: torch tensor of predicted probabilities (presumably shape (N,)
            with values in [0, 1] — confirm against the calibrator output).
        labels: torch tensor of targets (presumably 0/1), the same shape as
            ``preds`` so the same boolean mask can index both.
        n_bins: number of equal-width bins.

    Returns:
        The ECE as a Python float; 0.0 when ``preds`` is empty.
    """
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    total = len(preds)
    ece = 0.0

    for i in range(n_bins):
        start, end = bins[i], bins[i + 1]

        # Bins are half-open [start, end), except the last one, which is closed
        # so that predictions of exactly 1.0 are still counted.
        if i == n_bins - 1:
            in_bin = (preds >= start) & (preds <= end)
        else:
            in_bin = (preds >= start) & (preds < end)
        bin_size = int(in_bin.sum().item())

        if bin_size > 0:
            avg_pred = preds[in_bin].float().mean().item()
            avg_label = labels[in_bin].float().mean().item()
            ece += (bin_size / total) * abs(avg_pred - avg_label)

    return ece


# Compare calibration quality on the held-out test split, before vs. after.
ece_before, ece_after = (
    compute_ece(test_preds, test_labels),
    compute_ece(test_preds_calibrated, test_labels),
)

print(f"ECE before calibration: {ece_before:.4f}")
print(f"ECE after calibration:  {ece_after:.4f}")

In [None]:
def reliability_plot(raw, calibrated, labels, n_bins=15) -> None:  # noqa: ANN001
    """Draw a reliability diagram comparing raw and calibrated predictions.

    For every non-empty equal-width bin over [0, 1], each curve plots the mean
    predicted probability (x) against the empirical mean label (y). Points on
    the dashed diagonal are perfectly calibrated.

    Args:
        raw: torch tensor of uncalibrated predicted probabilities.
        calibrated: torch tensor of calibrated probabilities, aligned with ``raw``.
        labels: torch tensor of targets (presumably 0/1), aligned with both
            prediction tensors.
        n_bins: number of equal-width bins.
    """
    bins = np.linspace(0.0, 1.0, n_bins + 1)

    def bin_stats(preds):  # noqa: ANN001, ANN202
        # Returns (mean confidence, mean label) arrays over the non-empty bins.
        conf, acc = [], []
        for i in range(n_bins):
            # Close the last bin on the right so predictions of exactly 1.0 count.
            if i == n_bins - 1:
                below_upper = preds <= bins[i + 1]
            else:
                below_upper = preds < bins[i + 1]
            in_bin = (preds >= bins[i]) & below_upper
            # Skip empty bins: plotting them as (0, 0) would draw spurious
            # segments back to the origin.
            if in_bin.any():
                conf.append(preds[in_bin].float().mean().item())
                acc.append(labels[in_bin].float().mean().item())
        return np.array(conf), np.array(acc)

    raw_conf, raw_acc = bin_stats(raw)
    cal_conf, cal_acc = bin_stats(calibrated)

    _, ax = plt.subplots(figsize=(6, 6))
    ax.plot([0, 1], [0, 1], "--", label="Perfect Calibration")
    ax.plot(raw_conf, raw_acc, "o-", label="Before Calibration")
    ax.plot(cal_conf, cal_acc, "s-", label="After Calibration")
    ax.set_xlabel("Avg predicted probability")
    ax.set_ylabel("Accuracy in bin")
    ax.set_title("Reliability Diagram")
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.legend()
    ax.grid(True)
    plt.show()

In [None]:
reliability_plot(raw=test_preds, calibrated=test_preds_calibrated, labels=test_labels)