In [1]:
import json
import math
import numpy as np
import pytorch_lightning as pl
from typing import Union

from src.models.high_level_model import HighLevelModel
from src.models.low_level_model import LowLevelModel
from src.calibration.calibration import compute_gt_nonconformity
from src.data.multi_output_dataset import MultiOutputDataModule
from src.models.model_utils import convert_multitask_preds
from src.calibration.nonconformity_functions import NONCONFORMITY_FN_DIC
from src.calibration.calibration import CALIBRATION_FN_HIGH_DIC, CALIBRATION_FN_LOW_DIC
from src.calibration.calibration_utils import compute_qhat_ccp_global_cluster, compute_qhat_ccp_task_cluster


In [2]:
# Per-task sizes for the SGVehicle dataset.
# Presumably the number of classes of the colour and type tasks (inferred from the names) — TODO confirm.
SGVEHICLE_COLOR = 12
SGVEHICLE_TYPE = 11
# Passed as `task_num_classes` to both the models and the datamodule below; order is [color, type].
SGVEHICLE_TASK_NUM_CLASSES = [SGVEHICLE_COLOR, SGVEHICLE_TYPE]
# Dataset root handed to MultiOutputDataModule (relative to the working directory).
root_dir = "data/SGVehicle"

In [3]:
def convert_numpy_to_native(obj):
    """Recursively replace NumPy values with plain Python equivalents.

    Intended to make nested results safe to pass to ``json.dump``.

    Args:
        obj: Any object. ``np.ndarray`` becomes a (nested) list; NumPy integer,
            floating and boolean scalars become ``int`` / ``float`` / ``bool``;
            dicts, lists and tuples are walked recursively. Anything else is
            returned unchanged.

    Returns:
        A structure with the same layout as ``obj`` (tuples stay tuples, lists
        stay lists) in which the NumPy values listed above have been converted.
    """
    # np.bool_ is neither np.integer nor np.floating, so it has to be listed
    # explicitly; json.dump rejects it otherwise.
    numpy_scalars = (np.integer, np.floating, np.bool_)

    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, numpy_scalars):
        return obj.item()
    if isinstance(obj, dict):
        # Keys are converted too: json.dump refuses NumPy integer keys.
        return {
            (key.item() if isinstance(key, numpy_scalars) else key): convert_numpy_to_native(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_numpy_to_native(item) for item in obj]
    if isinstance(obj, tuple):
        # Tuples previously passed through untouched, leaving NumPy values inside them.
        return tuple(convert_numpy_to_native(item) for item in obj)
    return obj

In [4]:
# Restore the high-level model from its checkpoint, mapping the stored tensors to CPU.
# NOTE(review): this exact setup is repeated in the cell that runs the high-level calibration.
model = HighLevelModel.load_from_checkpoint(
    "models/sgvehicle-high-level-model.ckpt",
    map_location="cpu",
    task_num_classes=SGVEHICLE_TASK_NUM_CLASSES,
)
# Build the datamodule over `root_dir` with the same per-task class counts as the model.
datamodule = MultiOutputDataModule(
    root_dir=root_dir,
    task_num_classes=SGVEHICLE_TASK_NUM_CLASSES,
    batch_size=64,
    num_workers=8,
)
# setup() is called with no stage argument; presumably prepares every split — TODO confirm.
datamodule.setup()

In [6]:
def calibrate_model(
    model: Union[HighLevelModel, LowLevelModel],
    datamodule: MultiOutputDataModule,
    load_preds=False,
    alpha=0.05,
    cluster_method="hierarchical",
    accelerator="gpu",
):
    """Compute clustered conformal thresholds on the calibration split and save them as JSON.

    Steps:
      1. Obtain calibration predictions, either from the cached ``.npz`` file or by
         running ``trainer.predict`` on ``datamodule.calib_dataloader()`` (and then
         caching them).
      2. Collect the ground-truth labels of ``datamodule.datasets["calib"]``; for a
         low-level model they are flattened into one joint class index per sample.
      3. Compute nonconformity scores and, for every nonconformity function and
         every clustered calibration function, the resulting thresholds.
      4. Write everything to
         ``models/sgvehicle-{high|low}-level-calibration-{cluster_method}.json``.

    Args:
        model: A ``HighLevelModel`` or ``LowLevelModel``; anything that is not a
            ``HighLevelModel`` is treated as low-level.
        datamodule: Data source providing ``calib_dataloader()``,
            ``datasets["calib"]`` and ``task_num_classes``.
        load_preds: If true, read cached predictions from
            ``./models/sgvehicle-{high|low}-model-calibpreds.npz`` instead of
            running the model.
        alpha: Value forwarded to the calibration functions (presumably the
            miscoverage level — confirm against ``calibration_utils``). Must lie
            strictly between 0 and 1.
        cluster_method: Forwarded as ``cluster_method=`` to the calibration
            functions and used in the output file name.
        accelerator: Forwarded to ``pl.Trainer`` when predictions are computed.

    Returns:
        None. The thresholds are written to disk only.

    Raises:
        ValueError: If ``alpha`` is not strictly between 0 and 1.
    """
    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be in (0, 1), got {alpha!r}")

    high_level = isinstance(model, HighLevelModel)
    high_level_string = "high" if high_level else "low"
    # np.savez appends ".npz" itself, so the extension is only added when loading.
    preds_path = f"./models/sgvehicle-{high_level_string}-model-calibpreds"

    if load_preds:
        # NpzFile holds an open file handle; the context manager closes it once
        # every array has been read into memory.
        with np.load(f"{preds_path}.npz") as cached_preds:
            calib_preds = [cached_preds[key] for key in cached_preds.files]
        if not high_level:
            # The low-level cache stores one array per saved row; restack them.
            calib_preds = np.array(calib_preds)
    else:
        model.eval()
        trainer = pl.Trainer(accelerator=accelerator)
        calib_preds = trainer.predict(model, dataloaders=datamodule.calib_dataloader())
        if high_level:
            calib_preds = convert_multitask_preds(calib_preds)
        else:
            calib_preds = np.concatenate(calib_preds, axis=0)
        # Each element of calib_preds is stored as its own positional array (arr_0, arr_1, ...).
        np.savez(preds_path, *calib_preds)

    # Stacking along axis=1 puts tasks on axis 0 and samples on axis 1
    # (assumes each dataset item yields a 1-D per-task label vector — TODO confirm).
    true_labels = np.stack(
        [labels for _, labels in datamodule.datasets["calib"]], axis=1
    )
    if not high_level:
        # Mixed-radix weights: task i is scaled by the product of the class counts
        # of all later tasks, so the weighted sum is a single joint class index.
        multiplier = np.array(
            [
                math.prod(datamodule.task_num_classes[i + 1 :])
                for i in range(len(datamodule.task_num_classes))
            ]
        )
        true_labels = (true_labels * multiplier[:, None]).sum(axis=0)

    nonconformity_scores = compute_gt_nonconformity(calib_preds, true_labels)

    # NOTE(review): the global-cluster entry is keyed differently for high-level
    # ("ccp_global_cluster_thresholds") and low-level ("ccp_global_clusters")
    # models; the keys are kept as-is because readers of the JSON may rely on them.
    CLUSTERED_FN_HIGH_DIC = {
        "ccp_task_cluster_thresholds": compute_qhat_ccp_task_cluster,
        "ccp_global_cluster_thresholds": compute_qhat_ccp_global_cluster,
    }
    CLUSTERED_FN_LOW_DIC = {
        "ccp_global_clusters": compute_qhat_ccp_global_cluster
    }
    calibration_fns = CLUSTERED_FN_HIGH_DIC if high_level else CLUSTERED_FN_LOW_DIC

    # Layout: q_hats[nonconformity_name][calibration_type] -> calibration result.
    q_hats = {}
    for calibration_type, calibration_fn in calibration_fns.items():
        for nonconformity_name in NONCONFORMITY_FN_DIC:
            q_hats.setdefault(nonconformity_name, {})[calibration_type] = calibration_fn(
                nonconformity_scores[nonconformity_name],
                true_labels,
                alpha,
                cluster_method=cluster_method,
            )

    output_path = (
        f"models/sgvehicle-{high_level_string}-level-calibration-{cluster_method}.json"
    )
    with open(output_path, "w") as f:
        json.dump(convert_numpy_to_native(q_hats), f, indent=2)

In [7]:
# High-level calibration run: reload the high-level checkpoint on CPU and rebuild the
# datamodule so this cell does not depend on objects created by earlier cells.
model = HighLevelModel.load_from_checkpoint(
    "models/sgvehicle-high-level-model.ckpt",
    map_location="cpu",
    task_num_classes=SGVEHICLE_TASK_NUM_CLASSES,
)
datamodule = MultiOutputDataModule(
    root_dir=root_dir,
    task_num_classes=SGVEHICLE_TASK_NUM_CLASSES,
    batch_size=64,
    num_workers=8,
)
datamodule.setup()

# load_preds=True reads ./models/sgvehicle-high-model-calibpreds.npz instead of running
# prediction, so that file must already exist. The thresholds are written to
# models/sgvehicle-high-level-calibration-hierarchical.json.
calibrate_model(
    model=model,
    datamodule=datamodule,
    load_preds=True,
)

[Task 0] n_clustering=51, num_clusters=25
[Task 1] n_clustering=6, num_clusters=3
[Task 0] n_clustering=51, num_clusters=25
[Task 1] n_clustering=6, num_clusters=3
[Task 0] n_clustering=51, num_clusters=25
[Task 1] n_clustering=6, num_clusters=3


In [8]:
# Low-level calibration run: `model` and `datamodule` are rebound here, replacing the
# high-level objects created in the previous cell.
model = LowLevelModel.load_from_checkpoint(
    "models/sgvehicle-low-level-model.ckpt",
    map_location="cpu",
    task_num_classes=SGVEHICLE_TASK_NUM_CLASSES,
)
datamodule = MultiOutputDataModule(
    root_dir=root_dir,
    task_num_classes=SGVEHICLE_TASK_NUM_CLASSES,
    batch_size=64,
    num_workers=8,
)
datamodule.setup()

# load_preds=True reads ./models/sgvehicle-low-model-calibpreds.npz instead of running
# prediction, so that file must already exist. The thresholds are written to
# models/sgvehicle-low-level-calibration-hierarchical.json.
calibrate_model(
    model=model,
    datamodule=datamodule,
    load_preds=True,
)
