In [1]:
import numpy as np
import pytorch_lightning as pl
import json

from src.models.high_level_model import HighLevelModel
from src.calibration.calibration import calibration
from src.data.multi_output_dataset import MultiOutputDataModule
from src.models.model_utils import convert_multitask_preds

In [2]:
# Class counts for the two prediction tasks. The list order (color, type) is
# the order passed as `task_num_classes` to the model and the data module.
MDC_COLOR, MDC_TYPE = 12, 11
MDC_TASK_NUM_CLASSES = [MDC_COLOR, MDC_TYPE]

# Directory handed to the data module as its `root_dir`.
root_dir = "data"

In [3]:
# Load the high-level model from its checkpoint with tensors mapped to the CPU;
# the per-task class counts are passed explicitly alongside the checkpoint.
checkpoint_path = "models/mdc-high-level-model.ckpt"
model = HighLevelModel.load_from_checkpoint(
    checkpoint_path, map_location="cpu", task_num_classes=MDC_TASK_NUM_CLASSES
)

# Build the data module with the same task layout. setup() is invoked here,
# before later cells read `datamodule.datasets["calib"]`.
datamodule_options = {
    "root_dir": root_dir,
    "task_num_classes": MDC_TASK_NUM_CLASSES,
    "batch_size": 64,
    "num_workers": 8,
}
datamodule = MultiOutputDataModule(**datamodule_options)
datamodule.setup()

In [4]:
# Calibration-set predictions are read from a cached .npz archive. The
# commented-out lines are kept as a record of what appears to be the GPU
# inference run that originally produced them.
#model.eval()
#trainer = pl.Trainer(accelerator="gpu")
#calib_preds = trainer.predict(model, dataloaders=datamodule.calib_dataloader())
#calib_preds = convert_multitask_preds(calib_preds)

# Use the archive as a context manager so its file handle is closed as soon as
# the arrays are in memory (a bare np.load() on an .npz keeps the file open).
with np.load("./models/mdc-high-model-calibpreds.npz") as calib_archive:
    # One array per stored entry, in the archive's own key order.
    calib_preds = [calib_archive[key] for key in calib_archive.files]

# Gather the label entry of every calibration sample; stacking on axis=1 puts
# the sample index on the second axis.
# NOTE(review): presumably each `labels` holds one value per task, giving
# shape (num_tasks, num_samples) -- confirm against the dataset.
true_labels = np.stack([labels for _, labels in datamodule.datasets["calib"]], axis=1)

In [5]:
# Inspect the loaded calibration predictions.
# NOTE(review): this prints the repr of every array in the list; printing only
# shapes/dtypes would keep the cell output short.
print(calib_preds)
# One-off export, left disabled: targets the same path that the previous cell
# loads (np.savez appends ".npz" to the given file name).
#np.savez("./models/mdc-high-model-calibpreds", *calib_preds)

[array([[8.22556019e-01, 7.46676000e-04, 2.59488697e-08, ...,
        9.59193130e-07, 1.69652703e-09, 5.70246117e-09],
       [9.56030369e-01, 3.25049344e-03, 3.80841941e-10, ...,
        1.97434780e-09, 7.52362200e-11, 3.62321534e-10],
       [6.41980410e-01, 1.85987493e-03, 1.14738071e-07, ...,
        5.28904594e-11, 2.08485377e-12, 4.00961708e-09],
       ...,
       [4.47997361e-10, 3.55223750e-12, 1.34482370e-05, ...,
        8.44685495e-01, 5.72019987e-10, 2.66193578e-11],
       [1.13884257e-13, 2.75420242e-10, 1.44258636e-04, ...,
        9.21593666e-01, 7.46730677e-09, 1.47014049e-07],
       [4.54582249e-12, 2.49659308e-07, 3.27027960e-10, ...,
        9.85211372e-01, 1.05621805e-11, 7.89443566e-10]],
      shape=(4812, 12), dtype=float32), array([[2.63850063e-01, 2.54477002e-03, 1.09256616e-05, ...,
        2.21990533e-02, 1.14108017e-03, 3.97044687e-06],
       [8.55591178e-01, 1.84196346e-02, 2.56511257e-08, ...,
        6.33316636e-02, 1.92780746e-03, 6.46899112e-09],
  

In [6]:
# Flag telling calibration() whether the loaded model is a HighLevelModel.
is_high_level = isinstance(model, HighLevelModel)

# Run calibration on the cached predictions against the calibration labels.
# NOTE(review): the name suggests per-task quantile thresholds -- confirm
# against src.calibration.calibration.
q_hats = calibration(calib_preds, true_labels, high_level=is_high_level)

In [7]:
def convert_numpy_to_native(obj):
    """Recursively convert NumPy values inside ``obj`` to built-in Python types.

    The result is meant to be handed to ``json.dump``, which rejects NumPy
    arrays and most NumPy scalars.

    Args:
        obj: Any value. Arrays, NumPy scalars, dicts, lists and tuples are
            processed; anything else is returned unchanged.

    Returns:
        A structure of the same shape where
        * ``np.ndarray`` becomes a (nested) list via ``tolist()``,
        * every NumPy scalar (``np.generic``: integers, floats, ``np.bool_``,
          ``np.str_``, ...) becomes its native equivalent via ``item()``,
        * dict values are converted recursively and NumPy scalar keys are
          converted too (``json`` raises ``TypeError`` on e.g. ``np.int64`` keys),
        * lists and tuples are converted element-wise, keeping list/tuple kind.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # np.generic is the common base of all NumPy scalar types, so this also
    # covers np.bool_ and np.str_, not only integers and floats.
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): convert_numpy_to_native(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_numpy_to_native(item) for item in obj]
    if isinstance(obj, tuple):
        # Tuples may hide NumPy scalars as well; json serialises them as arrays.
        return tuple(convert_numpy_to_native(item) for item in obj)
    return obj


# Convert and serialise fully *before* opening the output file: opening with
# "w" truncates immediately, so a failure during conversion or encoding would
# otherwise leave an empty or partial calibration file behind.
calibration_json = json.dumps(convert_numpy_to_native(q_hats), indent=2)
with open("models/mdc-high-level-calibration.json", "w") as f:
    f.write(calibration_json)