In [1]:
import os
from steves_utils.utils_v2 import (per_domain_accuracy_from_confusion, get_datasets_base_path)
from steves_utils.torch_utils import get_dataset_metrics, ptn_confusion_by_domain_over_dataloader
from steves_utils.stratified_dataset.episodic_accessor import Episodic_Accessor_Factory
from  easydict import EasyDict

import steves_utils.CORES.utils as CORES
import steves_utils.ORACLE.utils_v2 as ORACLE
import steves_utils.wisig.utils as WISIG


# Registry of the stratified datasets summarised by this notebook, keyed by the
# name printed in the report. Every entry has the same four fields; the pickle
# file name is resolved against get_datasets_base_path().
# NOTE(review): -1 for "num_examples_per_domain_per_label" presumably means
# "no per-domain/per-label cap" -- confirm against Episodic_Accessor_Factory.
datasets = {
    dataset_name: {
        "labels": label_set,
        "domains": domain_set,
        "num_examples_per_domain_per_label": -1,
        "pickle_path": os.path.join(get_datasets_base_path(), pickle_file),
    }
    for dataset_name, label_set, domain_set, pickle_file in (
        (
            "CORES",
            CORES.ALL_NODES,
            CORES.ALL_DAYS,
            "cores.stratified_ds.2022A.pkl",
        ),
        (
            "WISIG",
            WISIG.ALL_NODES_MINIMUM_100_EXAMPLES,
            WISIG.ALL_DAYS,
            "wisig.node3-19.stratified_ds.2022A.pkl",
        ),
        (
            "ORACLE.Run1",
            ORACLE.ALL_SERIAL_NUMBERS,
            ORACLE.ALL_DISTANCES_FEET_NARROWED,
            "oracle.Run1_10kExamples_stratified_ds.2022A.pkl",
        ),
        (
            "ORACLE.Run2",
            ORACLE.ALL_SERIAL_NUMBERS,
            ORACLE.ALL_DISTANCES_FEET_NARROWED,
            "oracle.Run2_10kExamples_stratified_ds.2022A.pkl",
        ),
        (
            "ORACLE.Run1.framed",
            ORACLE.ALL_SERIAL_NUMBERS,
            ORACLE.ALL_DISTANCES_FEET_NARROWED,
            "oracle.Run1_framed_2000Examples_stratified_ds.2022A.pkl",
        ),
        (
            "ORACLE.Run2.framed",
            ORACLE.ALL_SERIAL_NUMBERS,
            ORACLE.ALL_DISTANCES_FEET_NARROWED,
            "oracle.Run2_framed_2000Examples_stratified_ds.2022A.pkl",
        ),
    )
}

In [2]:
def get_ds_stats(ds):
    """Compute dataset metrics for the train/val/test splits of one dataset.

    An Episodic_Accessor_Factory is built from the entry with fixed episode
    settings (n_shot=3, n_query=2, n_way equal to the number of labels,
    k-factors (3, 2, 2), both seeds 1337). Its three splits are wrapped as
    {"source": {"original": {...}}} and handed to get_dataset_metrics in
    "ptn" mode.

    Args:
        ds: Mapping with the keys "labels", "domains" and "pickle_path".
            The optional key "num_examples_per_domain_per_label" is forwarded
            to the factory; when absent, -1 is used (the value this function
            previously hard-coded).

    Returns:
        Whatever get_dataset_metrics returns for the wrapped splits.
        NOTE(review): the reporting cell indexes the result as
        result["source"][split][metric]; confirm against get_dataset_metrics.
    """
    accessor_factory = Episodic_Accessor_Factory(
        labels=ds["labels"],
        domains=ds["domains"],
        # Honour the per-dataset setting instead of silently ignoring it.
        num_examples_per_domain_per_label=ds.get("num_examples_per_domain_per_label", -1),
        iterator_seed=1337,
        dataset_seed=1337,
        n_shot=3,
        n_way=len(ds["labels"]),
        n_query=2,
        train_val_test_k_factors=(3,2,2),
        pickle_path=ds["pickle_path"],
    )

    train = accessor_factory.get_train()
    val = accessor_factory.get_val()
    test = accessor_factory.get_test()

    # Named so that it does not shadow the module-level `datasets` registry.
    wrapped_splits = EasyDict({
        "source": {
            "original": {"train":train, "val":val, "test":test},
        }
    })

    return get_dataset_metrics(wrapped_splits, "ptn")

In [3]:
# Print one header per dataset, then one summary line per split.
for name, params in datasets.items():
    s = get_ds_stats(params)
    print("Dataset:", name)
    for split in ("train", "val", "test"):
        # Look the split's metrics up once rather than once per field.
        split_stats = s["source"][split]
        summary_line = "    {split}: n_unique_x: {n_unique_x},  n_unique_y: {n_unique_y}, n_episode: {n_episode}".format(
            split=split,
            n_unique_x=split_stats["n_unique_x"],
            n_unique_y=split_stats["n_unique_y"],
            n_episode=split_stats["n_batch/episode"],
        )
        print(summary_line)

Dataset: CORES
    train: n_unique_x: 75852,  n_unique_y: 58, n_episode: 351
    val: n_unique_x: 11850,  n_unique_y: 58, n_episode: 48
    test: n_unique_x: 12315,  n_unique_y: 58, n_episode: 50
Dataset: WISIG
    train: n_unique_x: 79924,  n_unique_y: 130, n_episode: 183
    val: n_unique_x: 12656,  n_unique_y: 130, n_episode: 24
    test: n_unique_x: 12778,  n_unique_y: 130, n_episode: 24
Dataset: ORACLE.Run1
    train: n_unique_x: 896000,  n_unique_y: 16, n_episode: 33600
    val: n_unique_x: 192000,  n_unique_y: 16, n_episode: 4800
    test: n_unique_x: 192000,  n_unique_y: 16, n_episode: 4800
Dataset: ORACLE.Run2
    train: n_unique_x: 896000,  n_unique_y: 16, n_episode: 33600
    val: n_unique_x: 192000,  n_unique_y: 16, n_episode: 4800
    test: n_unique_x: 192000,  n_unique_y: 16, n_episode: 4800
Dataset: ORACLE.Run1.framed
    train: n_unique_x: 179200,  n_unique_y: 16, n_episode: 6720
    val: n_unique_x: 38400,  n_unique_y: 16, n_episode: 960
    test: n_unique_x: 38400,  n