In [1]:
import torch
from hydra import initialize, compose
from hydra.utils import instantiate
from bliss.surveys.dc2 import DC2DataModule
import tqdm
from bliss.global_env import GlobalEnv

In [2]:
with initialize(config_path=".", version_base=None):
    notebook_cfg = compose("notebook_config")
notebook_cfg.surveys.dc2.tile_slen = 4
notebook_cfg.surveys.dc2.max_sources_per_tile = 6
notebook_cfg.surveys.dc2.cached_data_path = "/data/scratch/dc2local/dc2_cached_data"
notebook_cfg.surveys.dc2.batch_size = 512

In [3]:
dc2: DC2DataModule = instantiate(notebook_cfg.surveys.dc2)
dc2.prepare_data()
dc2.setup(stage="fit")

dc2_train_dataloader = dc2.train_dataloader()
device = torch.device("cpu")
GlobalEnv.seed_in_this_program = 7272
GlobalEnv.current_encoder_epoch = 1




In [4]:
n_sources_count = torch.zeros(10, device=device)
for batch in tqdm.tqdm(dc2_train_dataloader):
    n_sources_count += batch["tile_catalog"]["n_sources"].to(device=device).flatten().bincount(minlength=10)

100%|██████████| 381/381 [25:32<00:00,  4.02s/it]


In [5]:
print(n_sources_count)

tensor([7.6831e+07, 1.1586e+06, 1.0393e+04, 7.1000e+01, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])


In [6]:
print(n_sources_count / n_sources_count.sum())

tensor([9.8501e-01, 1.4854e-02, 1.3324e-04, 9.1026e-07, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])


In [7]:
print(n_sources_count[1:] / n_sources_count[1:].sum())

tensor([9.9105e-01, 8.8902e-03, 6.0733e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])
