In [1]:
import json
import numpy as np
import torch
from tqdm import tqdm

from src.models.high_level_model import HighLevelModel
from src.data.multi_output_dataset import MultiOutputDataModule
from src.calibration.nonconformity_functions import NONCONFORMITY_FN_DIC
from src.calibration.calibration import CALIBRATION_FN_HIGH_DIC
from src.models.conformal_prediction import standard_prediction, clustered_prediction
from src.metrics import (
    compute_overall_efficiency,
    compute_overall_informativeness,
    compute_taskwise_informativeness,
    compute_taskwise_efficiency,
    compute_overall_covgap,
)

In [2]:
# Number of classes for each prediction task, in task order (passed to the model
# and the data module as `task_num_classes`).
UTKFACE_CLASSES = [2, 5]

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Restore the trained high-level model onto the selected device and put it in
# evaluation mode right away.
model = HighLevelModel.load_from_checkpoint(
    "models/utkface-high-level-model.ckpt",
    task_num_classes=UTKFACE_CLASSES,
    map_location=device,
).eval()

# Data module over the UTKFace directory; `setup()` must run before any
# dataloader is requested from it.
data = MultiOutputDataModule(
    root_dir="data/UTKFace",
    batch_size=64,
    num_workers=0,
    task_num_classes=UTKFACE_CLASSES,
)
data.setup()

In [4]:
# Load the calibration artefacts for the high-level model. Further down they are
# indexed as calibration_data[nonconformity_fn][calibration_type].
# The encoding is pinned to UTF-8 (the encoding JSON mandates) so the read does
# not depend on the platform's default locale encoding.
with open("models/utkface-high-level-calibration.json", "r", encoding="utf-8") as file:
    calibration_data = json.load(file)

In [5]:
def generate_predictions(model, dataloader):
    """Run ``model`` over ``dataloader`` and collect per-sample outputs and labels.

    Each batch is moved to the module-level ``device``, passed to
    ``model.predict_step`` with gradient tracking disabled, and the outputs are
    moved back to the CPU before being regrouped sample by sample.

    Args:
        model: Object exposing ``predict_step(batch, batch_idx)`` that returns a
            sequence of tensors (one per task) with the batch dimension first.
        dataloader: Iterable of batches; item 0 is the input tensor and item 1
            the label tensor.

    Returns:
        tuple: ``(y_preds, y_trues)``. ``y_preds[k]`` is the list of per-task
        output tensors for sample ``k``; ``y_trues`` holds one entry per row of
        the label tensors, converted to NumPy.
    """
    collected_preds, collected_labels = [], []

    for batch in tqdm(dataloader):
        inputs = batch[0].to(device)
        labels = batch[1].to(device)

        with torch.no_grad():
            # A constant 0 is passed as the batch index.
            task_outputs = model.predict_step((inputs, labels), 0)
        task_outputs = [out.cpu() for out in task_outputs]

        # Regroup task-major outputs (one tensor per task) into sample-major
        # lists (one list of per-task tensors per sample).
        n_in_batch = task_outputs[0].shape[0]
        collected_preds.extend(
            [out[row] for out in task_outputs] for row in range(n_in_batch)
        )
        collected_labels.extend(labels.cpu().numpy())

    return collected_preds, collected_labels

In [6]:
def test_calibration_from_preds(
    y_preds: list,
    y_trues: list,
    nonconformity_fn: str,
    calibration_type: str,
    calibration_data: dict,
    task_num_classes: list,
    alpha: float = 0.05,
):
    """
    Evaluate conformal calibration using precomputed predictions for multi-task classification.

    This function computes nonconformity scores, applies the specified conformal calibration
    method, and evaluates prediction sets based on coverage, efficiency, and informativeness.

    Args:
        y_preds (list): List of B samples, each is a list of T tensors/logits of shape (C_t,).
        y_trues (list): List of B samples, each is a list of T true labels (ints).
        nonconformity_fn (str): Name of the nonconformity function to use (e.g., 'hinge', 'margin').
        calibration_type (str): Type of calibration to apply ('scp_task_thresholds', 'ccp_cluster_thresholds', etc.).
        calibration_data (dict): Dictionary containing thresholds or cluster mappings for each method.
        task_num_classes (list): List of ints representing number of classes for each task.
        alpha (float): Significance level for coverage gap computation.
        
    Returns:
        None: Prints taskwise and overall coverage, efficiency, and informativeness statistics.
    """
    T = len(task_num_classes)
    B = len(y_preds)

    # Transpose y_preds: list[B][T][C_t] --> list[T][B][C_t] efficiently
    y_preds_by_task = [np.stack(task_preds) for task_preds in zip(*y_preds)]

    # Compute nonconformity scores
    nonconformity_scores = NONCONFORMITY_FN_DIC[nonconformity_fn](y_preds_by_task)

    clustered = "cluster" in calibration_type
    prediction = (
        clustered_prediction(
            nonconformity_scores,
            calibration_data[nonconformity_fn][calibration_type],
        )
        if clustered
        else standard_prediction(
            nonconformity_scores,
            calibration_data[nonconformity_fn][calibration_type],
        )
    )  # list[T][B], each element is prediction set for one task and sample

    # Reshape predictions: list[T][B] --> list[B][T]
    predictions_by_sample = list(zip(*prediction))

    # Evaluate coverage
    in_it = np.zeros(T, dtype=int)
    for t in range(T):
        for i in range(B):
            if y_trues[i][t] in predictions_by_sample[i][t]:
                in_it[t] += 1

    overall_eff = compute_overall_efficiency(prediction)
    taskwise_eff = compute_taskwise_efficiency(prediction)
    overall_info = compute_overall_informativeness(prediction)
    taskwise_info = compute_taskwise_informativeness(prediction)
    covgap = compute_overall_covgap(prediction, [np.array([yt[t] for yt in y_trues]) for t in range(T)], task_num_classes, alpha)

    # Reporting
    print("Accuracies of the calibrated method:")
    for t in range(T):
        print(f"Accuracy of Task {t}: {in_it[t]} / {B} = {in_it[t] / B:.2%}")
    overall = sum(in_it)
    total = B * T
    print(f"Overall: {overall} / {total} = {overall / total:.2%}")
    for t in range(T):
        print(f"Efficiency of Task {t}: {taskwise_eff[t]:.4f}")
    print(f"Overall Efficiency: {overall_eff:.4f}")
    for t in range(T):
        print(f"Informativeness of Task {t}: {taskwise_info[t]:.4f}")
    print(f"Overall Informativeness: {overall_info:.4f}")
    print(f"Overall CovGap: {covgap:.4f}")

In [7]:
# Run inference over the test dataloader once; the per-sample outputs and labels
# are reused by every calibration evaluation below. The recorded run took about
# 5.5 minutes for 57 batches.
y_preds, y_trues = generate_predictions(model, data.test_dataloader())

100%|██████████| 57/57 [05:31<00:00,  5.81s/it]


In [8]:
# Evaluate every registered nonconformity function against every high-level
# calibration scheme, reusing the predictions computed once above.
for nonconformity_fn in NONCONFORMITY_FN_DIC:
    print(
        "-------------------------------------------------",
        f"Nonconformity function: {nonconformity_fn}",
        sep="\n",
    )
    for calibration_type in CALIBRATION_FN_HIGH_DIC:
        print(
            "-------------------------------------------------",
            f"Calibration type: {calibration_type}",
            sep="\n",
        )
        test_calibration_from_preds(
            y_preds=y_preds,
            y_trues=y_trues,
            nonconformity_fn=nonconformity_fn,
            calibration_type=calibration_type,
            calibration_data=calibration_data,
            task_num_classes=UTKFACE_CLASSES,
        )

-------------------------------------------------
Nonconformity function: hinge
-------------------------------------------------
Calibration type: scp_global_threshold
Accuracies of the calibrated method:
Accuracy of Task 0: 3554 / 3616 = 98.29%
Accuracy of Task 1: 3358 / 3616 = 92.87%
Overall: 6912 / 7232 = 95.58%
Efficiency of Task 0: 1.3219
Efficiency of Task 1: 2.0932
Overall Efficiency: 1.7075
Informativeness of Task 0: 0.6781
Informativeness of Task 1: 0.3993
Overall Informativeness: 0.5387
Overall CovGap: 5.9403
-------------------------------------------------
Calibration type: scp_task_thresholds
Accuracies of the calibrated method:
Accuracy of Task 0: 3435 / 3616 = 94.99%
Accuracy of Task 1: 3464 / 3616 = 95.80%
Overall: 6899 / 7232 = 95.40%
Efficiency of Task 0: 1.1444
Efficiency of Task 1: 2.5285
Overall Efficiency: 1.8364
Informativeness of Task 0: 0.8556
Informativeness of Task 1: 0.3053
Overall Informativeness: 0.5805
Overall CovGap: 3.3835
-----------------------------