In [1]:
import numpy as np
import pytorch_lightning as pl
import json

from src.models.high_level_model import HighLevelModel
from src.calibration.calibration import calibration
from src.data.multi_output_dataset import MultiOutputDataModule
from src.models.model_utils import convert_multitask_preds

In [2]:
# Per-task class counts (presumably colour and type heads -- confirm against
# the dataset definition). Their order fixes the order of MDC_TASK_NUM_CLASSES.
MDC_COLOR, MDC_TYPE = 12, 11
MDC_TASK_NUM_CLASSES = [MDC_COLOR, MDC_TYPE]

# Directory handed to the data module as its `root_dir`.
root_dir = "data"

In [3]:
# Restore the trained model from its checkpoint, mapping tensors onto the CPU.
checkpoint_path = "models/mdc-high-level-model.ckpt"
model = HighLevelModel.load_from_checkpoint(
    checkpoint_path,
    map_location="cpu",
    task_num_classes=MDC_TASK_NUM_CLASSES,
)

# Build the data module with the same task layout as the model, then let it
# prepare its splits via setup().
datamodule_kwargs = dict(
    root_dir=root_dir,
    task_num_classes=MDC_TASK_NUM_CLASSES,
    batch_size=64,
    num_workers=8,
)
datamodule = MultiOutputDataModule(**datamodule_kwargs)
datamodule.setup()

In [4]:
# Toggle: True reuses the predictions cached on disk; False re-runs inference
# on the calibration dataloader and rewrites the cache.
load_preds = True

# Single spelling of the cache location, shared by the load and save paths.
# np.savez only appends ".npz" when the name lacks it, so the file written is
# the same one that is read back.
calib_preds_path = "./models/mdc-high-model-calibpreds.npz"

if load_preds:
    # np.load on an .npz returns a lazily-read NpzFile that keeps the archive
    # open; read every array inside the context manager so the handle is
    # closed instead of leaking for the lifetime of the kernel.
    with np.load(calib_preds_path) as cached_preds:
        calib_preds = [cached_preds[key] for key in cached_preds.files]
else:
    # Put the model in evaluation mode before predicting.
    model.eval()
    trainer = pl.Trainer(accelerator="gpu")
    raw_preds = trainer.predict(model, dataloaders=datamodule.calib_dataloader())
    calib_preds = convert_multitask_preds(raw_preds)
    # Arrays are stored positionally (arr_0, arr_1, ...), which is the order
    # the loading branch reads them back in.
    np.savez(calib_preds_path, *calib_preds)

# Collect the label entry of every calibration sample and stack them along
# axis 1 (assumes each item is an (input, labels) pair with one label per
# task, giving a (num_tasks, num_samples) array -- TODO confirm).
true_labels = np.stack([labels for _, labels in datamodule.datasets["calib"]], axis=1)

In [5]:
# Tell the calibration routine whether the loaded model is the high-level
# variant, then calibrate on the calibration predictions and labels.
is_high_level = isinstance(model, HighLevelModel)
q_hats = calibration(calib_preds, true_labels, high_level=is_high_level)

In [6]:
def convert_numpy_to_native(obj):
    """Recursively replace NumPy values with plain Python equivalents.

    Intended to make nested results safe to pass to ``json.dump``.

    Args:
        obj: Any object. Handled specially:
            * ``np.ndarray`` -> nested ``list`` via ``tolist()``
            * any NumPy scalar (``np.generic``: integers, floats, ``np.bool_``,
              ``np.str_``, ...) -> the matching Python scalar via ``item()``
            * ``dict`` -> new dict with values converted recursively and
              NumPy-scalar keys unwrapped
            * ``list`` / ``tuple`` -> same container type with converted items

    Returns:
        The converted object; anything not listed above is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # np.generic is the common base of all NumPy scalars, so this also
        # covers np.bool_ etc., which json cannot serialise on its own.
        return obj.item()
    if isinstance(obj, dict):
        # json rejects NumPy-scalar keys, so unwrap those as well as values.
        return {
            (key.item() if isinstance(key, np.generic) else key): convert_numpy_to_native(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_numpy_to_native(item) for item in obj]
    if isinstance(obj, tuple):
        # Keep the tuple type but convert the NumPy values nested inside it.
        return tuple(convert_numpy_to_native(item) for item in obj)
    return obj


# Write the calibration result to disk as indented JSON, after converting any
# NumPy values to plain Python types. The path is a plain literal (no
# f-string: it has no placeholders) and the encoding is pinned so the file
# does not depend on the platform default.
with open("models/mdc-high-level-calibration.json", "w", encoding="utf-8") as f:
    json.dump(convert_numpy_to_native(q_hats), f, indent=2)