In [1]:
import json
import math
import numpy as np
import pytorch_lightning as pl
from typing import Union

from src.models.high_level_model import HighLevelModel
from src.models.low_level_model import LowLevelModel
from src.calibration.calibration import calibration
from src.data.multi_output_dataset import MultiOutputDataModule
from src.models.model_utils import convert_multitask_preds

In [2]:
# Number of classes for each UTKFace prediction task; list order is task order
# (task 0 = gender, task 1 = race).
UTKFACE_GENDER, UTKFACE_RACE = 2, 5
UTKFACE_TASK_NUM_CLASSES = [UTKFACE_GENDER, UTKFACE_RACE]
# Relative path to the UTKFace dataset directory.
root_dir = "data/UTKFace"

In [3]:
def convert_numpy_to_native(obj):
    """Recursively convert NumPy values into JSON-serialisable Python builtins.

    ``json.dump`` cannot encode NumPy arrays, NumPy scalars (including
    ``np.bool_``), or dict keys that are NumPy scalars. This walks nested
    dicts, lists and tuples and converts every such value.

    Args:
        obj: Arbitrary, possibly nested, object.

    Returns:
        The same structure with:
          * ``np.ndarray`` turned into (nested) Python lists,
          * any NumPy scalar (``np.generic``: ints, floats, bools, ...) turned
            into the equivalent Python scalar via ``.item()``,
          * dict keys and values, list items and tuple items converted
            recursively (tuples stay tuples, lists stay lists).
        Any other object is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # np.generic covers np.bool_ as well as np.integer / np.floating; np.bool_
    # is not an np.integer subclass and json.dump rejects it.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        # Keys are converted too: json.dump rejects e.g. np.int64 keys.
        return {
            convert_numpy_to_native(key): convert_numpy_to_native(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_numpy_to_native(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_to_native(item) for item in obj)
    return obj

In [4]:
def _load_calib_preds(high_level: bool):
    """Load calibration predictions cached on disk by a previous inference run.

    High-level: returns the cached arrays as a list, in saved order.
    Low-level: the cached entries are stacked back into a single array.
    """
    high_level_string = "high" if high_level else "low"
    # np.load on an .npz returns a lazily-reading NpzFile that keeps the zip
    # file open; the context manager closes it once the arrays are read.
    with np.load(f"./models/utkface-{high_level_string}-model-calibpreds.npz") as archive:
        arrays = [archive[key] for key in archive.files]
    return arrays if high_level else np.array(arrays)


def _predict_calib_preds(model, datamodule: MultiOutputDataModule, high_level: bool, accelerator: str):
    """Run inference over the calibration split, cache the predictions, and return them."""
    high_level_string = "high" if high_level else "low"
    model.eval()
    trainer = pl.Trainer(accelerator=accelerator)
    batch_preds = trainer.predict(model, dataloaders=datamodule.calib_dataloader())
    if high_level:
        calib_preds = convert_multitask_preds(batch_preds)
    else:
        calib_preds = np.concatenate(batch_preds, axis=0)
    # NOTE: `*calib_preds` stores each element as its own `arr_<i>` entry. For
    # the low-level 2-D array that means one entry per row (sample). This is
    # kept as-is so caches written earlier stay loadable by _load_calib_preds.
    np.savez(f"./models/utkface-{high_level_string}-model-calibpreds", *calib_preds)
    return calib_preds


def _calibration_labels(datamodule: MultiOutputDataModule, high_level: bool) -> np.ndarray:
    """Ground-truth labels of the calibration split, in the layout `calibration` receives.

    High-level: per-task labels stacked so axis 0 indexes the task.
    Low-level: the per-task labels of each sample combined into one flat
    joint-class index.
    """
    # Stacking along axis=1 puts tasks on axis 0 (assumes each `labels` is a
    # per-task label vector) -> (num_tasks, num_samples).
    true_labels = np.stack(
        [labels for _, labels in datamodule.datasets["calib"]], axis=1
    )
    if high_level:
        return true_labels
    # Row-major (C-order) mixed-radix encoding:
    #     sum_i label_i * prod(task_num_classes[i + 1:])
    # ravel_multi_index computes exactly this. It also raises ValueError for an
    # out-of-range label instead of silently colliding with another class code.
    return np.ravel_multi_index(
        tuple(true_labels), dims=tuple(datamodule.task_num_classes)
    )


def calibrate_model(
    model: Union[HighLevelModel, LowLevelModel],
    datamodule: MultiOutputDataModule,
    load_preds=False,
    accelerator: str = "auto",
):
    """Calibrate a UTKFace model on the calibration split and save the result as JSON.

    Args:
        model: Trained model. ``HighLevelModel`` instances are handled
            per task ("high"); anything else is treated as a "low"-level
            model whose per-task labels are combined into one joint class.
        datamodule: Data module providing ``calib_dataloader()``,
            ``datasets["calib"]`` and ``task_num_classes``. It is expected to
            be set up already (the notebook calls ``setup()`` first).
        load_preds: If True, reuse predictions cached in
            ``./models/utkface-{high,low}-model-calibpreds.npz`` instead of
            running inference. ``model`` is then only used to choose the level.
        accelerator: Lightning ``Trainer`` accelerator, used only when
            ``load_preds`` is False. "auto" picks a GPU when one is available
            and otherwise falls back to CPU. Previously "gpu" was hard-coded,
            which fails on CPU-only machines.

    Side effects:
        Writes the ``q_hats`` returned by ``calibration`` to
        ``models/utkface-{high,low}-level-calibration.json``. When
        ``load_preds`` is False, it also (over)writes the prediction cache.
    """
    high_level = isinstance(model, HighLevelModel)
    high_level_string = "high" if high_level else "low"

    if load_preds:
        calib_preds = _load_calib_preds(high_level)
    else:
        calib_preds = _predict_calib_preds(model, datamodule, high_level, accelerator)

    true_labels = _calibration_labels(datamodule, high_level)

    q_hats = calibration(
        calib_preds,
        true_labels,
        high_level,
    )

    with open(f"models/utkface-{high_level_string}-level-calibration.json", "w") as f:
        json.dump(convert_numpy_to_native(q_hats), f, indent=2)

In [5]:
# Multi-task ("high-level") model, with its weights mapped onto the CPU.
model = HighLevelModel.load_from_checkpoint(
    "models/utkface-high-level-model.ckpt", map_location="cpu", task_num_classes=UTKFACE_TASK_NUM_CLASSES
)
# Data module that provides the calibration split consumed by calibrate_model.
datamodule = MultiOutputDataModule(
    root_dir=root_dir, task_num_classes=UTKFACE_TASK_NUM_CLASSES, batch_size=64, num_workers=8
)
datamodule.setup()

In [6]:
# Calibrate from the cached high-level predictions instead of re-running inference.
calibrate_model(model, datamodule, load_preds=True)

[Task 0] n_clustering=45, num_clusters=22
[Task 1] n_clustering=17, num_clusters=8
[Task 0] n_clustering=45, num_clusters=22
[Task 1] n_clustering=17, num_clusters=8
[Task 0] n_clustering=45, num_clusters=22
[Task 1] n_clustering=17, num_clusters=8


In [7]:
# Joint-class ("low-level") model, with its weights mapped onto the CPU.
model = LowLevelModel.load_from_checkpoint(
    "models/utkface-low-level-model.ckpt", map_location="cpu", task_num_classes=UTKFACE_TASK_NUM_CLASSES
)
# Fresh data module, built with the same arguments as for the high-level model.
datamodule = MultiOutputDataModule(
    root_dir=root_dir, task_num_classes=UTKFACE_TASK_NUM_CLASSES, batch_size=64, num_workers=8
)
datamodule.setup()

In [8]:
# Calibrate from the cached low-level predictions instead of re-running inference.
calibrate_model(model, datamodule, load_preds=True)