In [21]:
import glob
import numpy as np
import gzip as gz
import pickle
from pathlib import Path
from scipy.io import savemat
import pandas as pd
from tqdm import tqdm
import h5py

In [22]:
# Root folder; files are looked up under <base_path>/<dataset>/<model>/.
base_path = Path("assets/results")

# Two groups of model folder names.
# NOTE(review): "cl" / "en" are not expanded anywhere in this notebook —
# confirm what the abbreviations stand for.
cl = [
    "resnet_ce_mse",
    "resnet_edl_mse",
    "resnet50_vicreg_ce",
]
en = [
    "flow_ss_vcr_mse",
    "resnet_mse",
]

# Per dataset: split tag ("F", "A", "B") -> filename prefix of that split's
# result files. "F" maps to the empty prefix.
# NOTE(review): the colon-separated values look like class indices — confirm.
ood_maps = {
    "CIFAR10":      {"F": "", "A": "5:6:7:8:9", "B": "0:1:2:3:4"},
    "MNIST":        {"F": "", "A": "5:6:7:8:9", "B": "0:1:2:3:4"},
    "QPM_species":  {"F": "", "A": "1:4", "B": "0:2:3"},
    "QPM2_species": {"F": "", "A": "2:3", "B": "0:1"},
    "rbc_phase":    {"F": "", "A": "1", "B": "0"},
    "rbc2_phase":   {"F": "", "A": "1", "B": "0"},
}

# Per dataset: list of (split tag, model names) pairs describing which runs
# exist. The first four datasets use every model on every split; the two rbc
# datasets use only the `en` models on splits "A" and "B".
datasets = {
    **{
        name: [(split, cl + en) for split in ("F", "A", "B")]
        for name in ("CIFAR10", "MNIST", "QPM_species", "QPM2_species")
    },
    **{
        name: [("F", cl + en), ("A", en), ("B", en)]
        for name in ("rbc_phase", "rbc2_phase")
    },
}

In [26]:
def to_dict(npz):
    """Copy every entry of an npz-like archive into a plain ``dict``.

    Args:
        npz: Object exposing a ``files`` iterable of entry names and item
            access by name (the callers in this notebook pass the result of
            ``np.load`` on a ``.npz`` file).

    Returns:
        dict mapping each name in ``npz.files`` to ``npz[name]``, in the
        order given by ``npz.files``.
    """
    return {key: npz[key] for key in npz.files}

def gen_dataframes(dataset: str, vartype: str):
    """Lazily yield ``(ood, model, data)`` for every configured run of ``dataset``.

    Despite the name, ``data`` is a plain ``dict`` produced by ``to_dict``,
    not a pandas DataFrame.

    Files are read from ``base_path / dataset / model`` and are named
    ``"<prefix>_<suffix>"`` where ``<prefix>`` is ``ood_maps[dataset][ood]``:

    * ``"val"``  -> ``<prefix>_val.npz``
    * ``"ind"``  -> ``<prefix>_ind.npz``
    * ``"ood"``  -> ``<prefix>_ood.npz`` (runs whose prefix is empty are skipped)
    * ``"stat"`` -> ``<prefix>_stats.gz`` (gzip-compressed pickle)

    Args:
        dataset: Key into the module-level ``ood_maps`` and ``datasets``.
        vartype: One of ``"val"``, ``"ind"``, ``"ood"``, ``"stat"``.

    Yields:
        Tuples ``(ood, model, data)``; a tqdm progress bar advances once per
        configured run, including skipped ones.

    Raises:
        KeyError: If ``dataset`` is not configured.
        ValueError: If ``vartype`` is not one of the supported values.
        Because this is a generator, both are raised on first iteration,
        not at call time.
    """
    ood_map = ood_maps[dataset]
    runs = [(ood, model) for (ood, models) in datasets[dataset] for model in models]

    if vartype not in ("val", "ind", "ood", "stat"):
        raise ValueError(vartype)

    for (ood, model) in tqdm(runs):
        prefix = ood_map[ood]

        # A split with an empty prefix has no "ood" file to load.
        if vartype == "ood" and not prefix:
            continue

        run_dir = base_path / dataset / model

        if vartype == "stat":
            with gz.open(run_dir / f"{prefix}_stats.gz", "rb") as f:
                # SECURITY: pickle.load can execute arbitrary code; only use
                # this on stats files produced by a trusted pipeline.
                data = pickle.load(f)
            # NOTE(review): to_dict needs `.files` and item access — confirm
            # the pickled stats object actually provides both.
            cols = to_dict(data)
        else:
            # np.load on an .npz keeps the archive open until closed; copy the
            # entries out inside the context manager so the handle is released
            # before the consumer does its (potentially slow) work.
            with np.load(run_dir / f"{prefix}_{vartype}.npz") as data:
                cols = to_dict(data)

        yield (ood, model, cols)

In [27]:
# Re-export every run's "val" / "ind" / "ood" arrays as an HDF5 file named
# "<vartype>_<ood>.h5" inside the run's own folder. Existing files are
# overwritten (mode "w").
for dataset in datasets:
    print(dataset)

    for vartype in ("val", "ind", "ood"):
        for (ood, model, data) in gen_dataframes(dataset, vartype=vartype):
            with h5py.File(base_path / dataset / model / f"{vartype}_{ood}.h5", "w") as f:
                for k, v in data.items():
                    if np.ndim(v) == 0:
                        # h5py rejects chunk/filter options on scalar
                        # datasets, so store 0-d entries uncompressed.
                        f.create_dataset(k, data=v)
                    else:
                        # Maximum gzip level: smallest files, slowest writes.
                        f.create_dataset(k, data=v, compression="gzip", compression_opts=9)

CIFAR10


100%|██████████| 15/15 [02:35<00:00, 10.34s/it]
100%|██████████| 15/15 [02:34<00:00, 10.29s/it]
100%|██████████| 15/15 [01:17<00:00,  5.14s/it]


MNIST


100%|██████████| 15/15 [02:15<00:00,  9.04s/it]
100%|██████████| 15/15 [01:53<00:00,  7.57s/it]
100%|██████████| 15/15 [00:46<00:00,  3.07s/it]


QPM_species


100%|██████████| 15/15 [22:23<00:00, 89.57s/it] 
100%|██████████| 15/15 [09:06<00:00, 36.43s/it]
100%|██████████| 15/15 [04:48<00:00, 19.26s/it]


QPM2_species


100%|██████████| 15/15 [30:19<00:00, 121.27s/it]
100%|██████████| 15/15 [49:44<00:00, 198.94s/it] 
100%|██████████| 15/15 [24:06<00:00, 96.44s/it] 


rbc_phase


100%|██████████| 9/9 [00:25<00:00,  2.79s/it]
100%|██████████| 9/9 [00:22<00:00,  2.54s/it]
100%|██████████| 9/9 [00:30<00:00,  3.37s/it]


rbc2_phase


100%|██████████| 9/9 [00:18<00:00,  2.04s/it]
100%|██████████| 9/9 [00:26<00:00,  2.95s/it]
100%|██████████| 9/9 [00:28<00:00,  3.21s/it]
