In [1]:
import json
import numpy as np
import torch
from tqdm import tqdm

from src.models.low_level_model import LowLevelModel
from src.data.multi_output_dataset import MultiOutputDataModule
from src.calibration.nonconformity_functions import NONCONFORMITY_FN_DIC
from src.calibration.calibration import CALIBRATION_FN_LOW_DIC
from src.models.conformal_prediction import standard_prediction, clustered_prediction
from src.metrics import compute_efficiency, compute_informativeness, compute_covgap

In [2]:
# Per-task class counts, forwarded as ``task_num_classes`` to the model and data module.
UTKFACE_CLASSES = [2, 5]

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Restore the trained low-level model directly onto `device` and put it in eval mode
# (nn.Module.eval() returns the module itself, so the call can be chained).
model = LowLevelModel.load_from_checkpoint(
    "models/utkface-low-level-model.ckpt",
    task_num_classes=UTKFACE_CLASSES,
    map_location=device,
).eval()

# UTKFace data module; setup() is run here, before test_dataloader() is requested later.
data = MultiOutputDataModule(
    root_dir="data/UTKFace",
    batch_size=64,
    num_workers=0,
    task_num_classes=UTKFACE_CLASSES,
)
data.setup()

In [4]:
# Precomputed calibration results, later indexed as
# calibration_data[nonconformity_fn][calibration_type].
with open("models/utkface-low-level-calibration.json") as calibration_file:
    calibration_data = json.load(calibration_file)

In [5]:
def generate_predictions(model: LowLevelModel, dataloader: "torch.utils.data.DataLoader"):
    """
    Run ``model`` over every batch of ``dataloader`` and collect predictions and targets.

    Args:
        model (LowLevelModel): Trained model. Its ``predict_step`` produces the predictions
            and its ``encode_targets`` produces the stored targets for each batch.
        dataloader (DataLoader): Iterable of batches such as ``data.test_dataloader()``.
            Only the first two elements of each batch (inputs, targets) are used.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(y_preds, y_trues)``, the per-batch predictions and
        encoded targets concatenated along axis 0.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    y_preds = []
    y_trues = []

    # Inference only: keep target encoding inside no_grad too, so every tensor
    # converted with .numpy() below is guaranteed not to require grad.
    with torch.no_grad():
        for batch in tqdm(dataloader):
            x, y = batch[0].to(device), batch[1].to(device)
            pred = model.predict_step((x, y), 0)  # (B, C)

            y_preds.append(pred.cpu().numpy())
            y_trues.append(model.encode_targets(y).cpu().numpy())

    if not y_preds:
        # np.concatenate would otherwise fail with "need at least one array to concatenate".
        raise ValueError("dataloader yielded no batches; nothing to predict")

    return np.concatenate(y_preds, axis=0), np.concatenate(y_trues, axis=0)

In [6]:
def test_calibration_from_preds(
    y_preds: np.ndarray,
    y_trues: np.ndarray,
    nonconformity_fn: str,
    calibration_type: str,
    calibration_data: dict,
    alpha: float = 0.05,
):
    """
    Evaluate one conformal calibration method on precomputed predictions.

    Computes nonconformity scores from ``y_preds`` and builds prediction sets from
    ``calibration_data[nonconformity_fn][calibration_type]``. It then prints and returns
    empirical coverage, efficiency, informativeness and the coverage gap.

    Args:
        y_preds (np.ndarray): Model outputs with one row per sample, as returned by
            ``generate_predictions`` (presumably shape (B, C); TODO confirm).
        y_trues (np.ndarray): Encoded true labels, one entry per sample.
        nonconformity_fn (str): Key into ``NONCONFORMITY_FN_DIC`` (e.g. 'hinge', 'margin').
        calibration_type (str): Key into ``calibration_data[nonconformity_fn]``
            (e.g. 'scp_global_threshold', 'ccp_class_thresholds', 'ccp_global_clusters').
            Types whose name contains "cluster" use ``clustered_prediction``; all others
            use ``standard_prediction``.
        calibration_data (dict): Nested mapping
            ``{nonconformity_fn: {calibration_type: calibration_entry}}``.
        alpha (float): Miscoverage level forwarded to ``compute_covgap``.

    Returns:
        dict: Metrics with keys ``coverage``, ``efficiency``, ``informativeness`` and
        ``covgap``. They are also printed.

    Raises:
        ValueError: If ``y_preds`` is empty.
        KeyError: If the requested entry is missing from ``NONCONFORMITY_FN_DIC`` or
            ``calibration_data``.
    """
    num_samples = len(y_preds)
    if num_samples == 0:
        raise ValueError("y_preds is empty; nothing to evaluate")
    num_classes = len(y_preds[0])

    nonconformity_scores = NONCONFORMITY_FN_DIC[nonconformity_fn](y_preds)

    calibration = calibration_data[nonconformity_fn][calibration_type]
    predict = clustered_prediction if "cluster" in calibration_type else standard_prediction
    prediction = predict(nonconformity_scores, calibration)

    # A sample is covered when its true label lies inside its prediction set. With
    # multi-label sets this is empirical coverage, not accuracy.
    covered = sum(1 for i in range(num_samples) if y_trues[i] in prediction[i])
    coverage = covered / num_samples

    efficiency = compute_efficiency(prediction)
    informativeness = compute_informativeness(prediction)
    covgap = compute_covgap(prediction, y_trues, num_classes, alpha)

    # Reporting
    print(f"Coverage: {covered} / {num_samples} = {coverage:.2%}")
    print(f"Efficiency: {efficiency:.4f}")
    print(f"Informativeness: {informativeness:.4f}")
    print(f"Coverage Gap: {covgap:.4f}")

    return {
        "coverage": coverage,
        "efficiency": efficiency,
        "informativeness": informativeness,
        "covgap": covgap,
    }

In [7]:
# Run the model over the whole test split once. The resulting arrays are reused for
# every nonconformity-function / calibration-type combination in the next cell.
# This is the slow step of the notebook (several minutes on the recorded run).
y_preds, y_trues = generate_predictions(model, data.test_dataloader())

100%|██████████| 57/57 [05:33<00:00,  5.85s/it]


In [8]:
# Sweep every nonconformity function against every low-level calibration method,
# reusing the test-set predictions computed above.
for nonconformity_fn in NONCONFORMITY_FN_DIC:
    print("-------------------------------------------------")
    print(f"Nonconformity function: {nonconformity_fn}")

    for calibration_type in CALIBRATION_FN_LOW_DIC:
        print("-------------------------------------------------")
        print(f"Calibration type: {calibration_type}")
        test_calibration_from_preds(
            y_preds,
            y_trues,
            nonconformity_fn=nonconformity_fn,
            calibration_type=calibration_type,
            calibration_data=calibration_data,
        )

-------------------------------------------------
Nonconformity function: hinge
-------------------------------------------------
Calibration type: scp_global_threshold
Accuracy: 3457 / 3616 = 95.60%
Efficiency: 3.6300
Informativeness: 0.1394
Coverage Gap: 4.0195
-------------------------------------------------
Calibration type: ccp_class_thresholds
Accuracy: 3467 / 3616 = 95.88%
Efficiency: 4.0968
Informativeness: 0.1117
Coverage Gap: 1.3222
-------------------------------------------------
Calibration type: ccp_global_clusters
Accuracy: 3467 / 3616 = 95.88%
Efficiency: 4.0968
Informativeness: 0.1117
Coverage Gap: 1.3222
-------------------------------------------------
Nonconformity function: margin
-------------------------------------------------
Calibration type: scp_global_threshold
Accuracy: 3478 / 3616 = 96.18%
Efficiency: 6.1527
Informativeness: 0.3413
Coverage Gap: 2.3664
-------------------------------------------------
Calibration type: ccp_class_thresholds
Accuracy: 3466 