In [1]:
import os
from steves_utils.utils_v2 import (per_domain_accuracy_from_confusion, get_datasets_base_path)
from steves_utils.torch_utils import get_dataset_metrics, ptn_confusion_by_domain_over_dataloader
from steves_utils.stratified_dataset.episodic_accessor import Episodic_Accessor_Factory
from  easydict import EasyDict

import steves_utils.CORES.utils as CORES
import steves_utils.ORACLE.utils_v2 as ORACLE
import steves_utils.wisig.utils as WISIG


# Stratified-dataset configs keyed by display name, in the order they are iterated below.
# Every entry has the same four fields; only the label list, domain list and pickle file
# name differ between entries. num_examples_per_domain_per_label is -1 for all of them
# (presumably "no per-domain/per-label cap" — TODO confirm against Episodic_Accessor_Factory).
datasets = {
    ds_name: {
        "labels": ds_labels,
        "domains": ds_domains,
        "num_examples_per_domain_per_label": -1,
        "pickle_path": os.path.join(get_datasets_base_path(), pickle_file),
    }
    for ds_name, ds_labels, ds_domains, pickle_file in (
        ("CORES", CORES.ALL_NODES, CORES.ALL_DAYS,
         "cores.stratified_ds.2022A.pkl"),
        ("WISIG", WISIG.ALL_NODES_MINIMUM_100_EXAMPLES, WISIG.ALL_DAYS,
         "wisig.node3-19.stratified_ds.2022A.pkl"),
        ("ORACLE.Run1", ORACLE.ALL_SERIAL_NUMBERS, ORACLE.ALL_DISTANCES_FEET_NARROWED,
         "oracle.Run1_10kExamples_stratified_ds.2022A.pkl"),
        ("ORACLE.Run2", ORACLE.ALL_SERIAL_NUMBERS, ORACLE.ALL_DISTANCES_FEET_NARROWED,
         "oracle.Run2_10kExamples_stratified_ds.2022A.pkl"),
        ("ORACLE.Run1.framed", ORACLE.ALL_SERIAL_NUMBERS, ORACLE.ALL_DISTANCES_FEET_NARROWED,
         "oracle.Run1_framed_2000Examples_stratified_ds.2022A.pkl"),
        ("ORACLE.Run2.framed", ORACLE.ALL_SERIAL_NUMBERS, ORACLE.ALL_DISTANCES_FEET_NARROWED,
         "oracle.Run2_framed_2000Examples_stratified_ds.2022A.pkl"),
    )
}

In [2]:
def get_ds_stats(
    ds,
    n_shot=3,
    n_query=2,
    iterator_seed=1337,
    dataset_seed=1337,
    train_val_test_k_factors=(3, 2, 2),
):
    """Build episodic train/val/test splits for one dataset config and return their metrics.

    Args:
        ds: dataset config dict with keys "labels", "domains" and "pickle_path".
            If it contains "num_examples_per_domain_per_label", that value is passed
            to the factory; otherwise -1 is used (the previous hard-coded value).
        n_shot, n_query, iterator_seed, dataset_seed, train_val_test_k_factors:
            forwarded unchanged to Episodic_Accessor_Factory. The defaults reproduce
            the settings this function previously hard-coded.

    Returns:
        The result of ``get_dataset_metrics(..., "ptn")`` computed over the three
        splits wrapped as ``source.original.{train,val,test}``. Callers in this
        notebook index it as ``result["source"][split]``.
    """
    eaf = Episodic_Accessor_Factory(
        labels=ds["labels"],
        domains=ds["domains"],
        # Honour the per-dataset config instead of silently ignoring it.
        num_examples_per_domain_per_label=ds.get("num_examples_per_domain_per_label", -1),
        iterator_seed=iterator_seed,
        dataset_seed=dataset_seed,
        n_shot=n_shot,
        # n_way spans every label listed in the config.
        n_way=len(ds["labels"]),
        n_query=n_query,
        train_val_test_k_factors=train_val_test_k_factors,
        pickle_path=ds["pickle_path"],
    )

    train, val, test = eaf.get_train(), eaf.get_val(), eaf.get_test()

    # Deliberately not named `datasets`: that would shadow the notebook-level
    # `datasets` config dict and make the code harder to follow.
    split_structure = EasyDict({
        "source": {
            "original": {"train": train, "val": val, "test": test},
        }
    })

    return get_dataset_metrics(split_structure, "ptn")

In [3]:
# For each configured dataset, report n_unique_x, n_unique_y and the episode count
# ("n_batch/episode") of its train, val and test splits.
for ds_name, ds_params in datasets.items():
    stats = get_ds_stats(ds_params)
    print("Dataset:", ds_name)
    for split in ("train", "val", "test"):
        split_stats = stats["source"][split]
        print(
            "    {split}: n_unique_x: {n_unique_x},  n_unique_y: {n_unique_y}, n_episode: {n_episode}".format(
                split=split,
                n_unique_x=split_stats["n_unique_x"],
                n_unique_y=split_stats["n_unique_y"],
                n_episode=split_stats["n_batch/episode"],
            )
        )

Dataset: CORES
    train: n_unique_x: 75852,  n_unique_y: 58, n_episode: 351
    val: n_unique_x: 11850,  n_unique_y: 58, n_episode: 48
    test: n_unique_x: 12315,  n_unique_y: 58, n_episode: 50
Dataset: WISIG
    train: n_unique_x: 79924,  n_unique_y: 130, n_episode: 183
    val: n_unique_x: 12656,  n_unique_y: 130, n_episode: 24
    test: n_unique_x: 12778,  n_unique_y: 130, n_episode: 24
Dataset: ORACLE.Run1
    train: n_unique_x: 896000,  n_unique_y: 16, n_episode: 33600
    val: n_unique_x: 192000,  n_unique_y: 16, n_episode: 4800
    test: n_unique_x: 192000,  n_unique_y: 16, n_episode: 4800
Dataset: ORACLE.Run2
    train: n_unique_x: 896000,  n_unique_y: 16, n_episode: 33600
    val: n_unique_x: 192000,  n_unique_y: 16, n_episode: 4800
    test: n_unique_x: 192000,  n_unique_y: 16, n_episode: 4800
Dataset: ORACLE.Run1.framed
    train: n_unique_x: 179200,  n_unique_y: 16, n_episode: 6720
    val: n_unique_x: 38400,  n_unique_y: 16, n_episode: 960
    test: n_unique_x: 38400,  n

In [9]:
from steves_utils.iterable_aggregator import Iterable_Aggregator
from steves_utils.utils_v2 import to_hash

def every_example_in_dl_generator(dl):
    """Flatten an episodic loader into individual ``(x, label, u)`` triples.

    Each item of ``dl`` is unpacked as
    ``(u, (support_x, support_y, query_x, query_y, real_classes))``. For that item,
    every support example is yielded first, then every query example, each as
    ``(x, real_classes[y], u)``. ``y`` is mapped through ``real_classes``
    (presumably an episode-local class index -> real label; TODO confirm).
    """
    for u, episode in dl:
        support_x, support_y, query_x, query_y, real_classes = episode
        # Support pairs first, then query pairs, matching the loader's layout.
        for xs, ys in ((support_x, support_y), (query_x, query_y)):
            for x, y in zip(xs, ys):
                yield x, real_classes[y], u



def get_extended_stats(
    ds,
    n_shot=3,
    n_query=2,
    iterator_seed=1337,
    dataset_seed=1337,
    train_val_test_k_factors=(3, 2, 2),
):
    """Count distinct examples per label over the train, val and test splits of one dataset.

    Args:
        ds: dataset config dict with keys "labels", "domains" and "pickle_path".
            If it contains "num_examples_per_domain_per_label", that value is passed
            to the factory; otherwise -1 is used (the previous hard-coded value).
        n_shot, n_query, iterator_seed, dataset_seed, train_val_test_k_factors:
            forwarded unchanged to Episodic_Accessor_Factory. The defaults reproduce
            the settings this function previously hard-coded.

    Returns:
        A list of ``(label, n_distinct_examples)`` tuples sorted by ascending count.
        Labels with equal counts keep the order in which they were first seen. The
        same pairs are printed one per line as ``"label: count"``.
    """
    eaf = Episodic_Accessor_Factory(
        labels=ds["labels"],
        domains=ds["domains"],
        num_examples_per_domain_per_label=ds.get("num_examples_per_domain_per_label", -1),
        iterator_seed=iterator_seed,
        dataset_seed=dataset_seed,
        n_shot=n_shot,
        n_way=len(ds["labels"]),
        n_query=n_query,
        train_val_test_k_factors=train_val_test_k_factors,
        pickle_path=ds["pickle_path"],
    )

    # Walk train, val and test together so distinctness is counted across all splits.
    it = Iterable_Aggregator([eaf.get_train(), eaf.get_val(), eaf.get_test()])

    # label -> set of to_hash(x); identical examples therefore contribute only once.
    hashes_by_label = {}
    for x, y, _ in every_example_in_dl_generator(it):
        hashes_by_label.setdefault(y, set()).add(to_hash(x))

    counts = [(label, len(hashes)) for label, hashes in hashes_by_label.items()]
    # list.sort is stable, so equal counts keep first-seen label order.
    counts.sort(key=lambda label_count: label_count[1])

    for label, n_distinct in counts:
        print("{}: {}".format(label, n_distinct))

    return counts
        
    
# For every configured dataset: print its name, then its per-label distinct-example counts.
for ds_name in datasets:
    ds_params = datasets[ds_name]
    print(ds_name)
    get_extended_stats(ds_params)

CORES
39: 992
26: 1055
41: 1115
49: 1117
38: 1138
14: 1146
34: 1150
37: 1178
28: 1183
32: 1190
31: 1195
7: 1231
33: 1252
2: 1277
35: 1320
42: 1333
18: 1715
22: 1780
36: 1784
1: 1818
17: 1820
24: 1822
10: 1826
0: 1848
12: 1871
6: 1874
45: 1893
30: 1893
54: 1900
3: 1932
44: 1938
47: 1943
4: 1943
48: 1952
57: 1952
5: 1953
21: 1956
52: 1958
43: 1959
23: 1960
11: 1965
29: 1970
20: 1975
46: 1976
53: 1977
13: 1983
16: 1985
50: 1988
40: 1989
19: 1989
9: 1990
51: 1992
56: 1993
8: 1995
15: 2005
25: 2018
27: 2023
55: 2042
WISIG
56: 468
35: 468
47: 477
59: 504
79: 516
115: 518
65: 534
45: 543
80: 545
77: 555
55: 560
38: 560
48: 567
75: 568
13: 569
123: 570
25: 571
46: 572
68: 573
67: 578
51: 579
50: 580
74: 584
76: 587
61: 588
37: 590
58: 590
63: 590
17: 595
100: 595
20: 597
41: 597
57: 597
64: 599
112: 601
103: 604
22: 604
104: 607
54: 608
53: 612
66: 612
31: 615
119: 616
28: 625
120: 628
52: 640
97: 653
96: 661
49: 703
23: 722
128: 742
73: 743
60: 744
91: 746
78: 757
42: 769
129: 769
6: 773
2: 7