In [1]:
import torch
from hydra import initialize, compose
from hydra.utils import instantiate
from bliss.surveys.dc2 import DC2DataModule
import tqdm
from bliss.catalog import TileCatalog
from bliss.global_settings import GlobalSettings

In [2]:
# Load the Hydra config named "notebook_config" from this notebook's directory.
with initialize(config_path=".", version_base=None):
    notebook_cfg = compose(config_name="notebook_config")

# NOTE(review): `device` is not used by any visible cell; kept in case later cells rely on it.
device = torch.device("cpu")

# Build the DC2 data module described by the `surveys.dc2` section of the config,
# prepare its data, then run the "fit" setup stage before asking for the train loader.
# (Ordering presumably follows the Lightning-style DataModule contract — TODO confirm.)
dc2: DC2DataModule = instantiate(notebook_cfg.surveys.dc2)
dc2.prepare_data()
dc2.setup("fit")
dc2_train_dataloader = dc2.train_dataloader()




In [3]:
# First argument passed to TileCatalog for every batch; presumably the tile side length
# (tile_slen) the DC2 tile catalogs were built with.
# NOTE(review): this must agree with the DC2 settings in notebook_config — TODO confirm,
# or read it from the config instead of hard-coding it here.
TILE_SLEN = 4

# Global settings are pinned before iterating the loader so the pass is repeatable.
# Presumably the DC2 dataset reads these values — TODO confirm in bliss.global_settings.
GlobalSettings.seed_in_this_program = 0
GlobalSettings.current_encoder_epoch = 0

# Tally stars and galaxies over every batch of the training dataloader.
# Counters are reset here so re-running this cell does not double count.
n_star = 0
n_galaxy = 0
for batch in tqdm.tqdm(dc2_train_dataloader, desc="counting sources"):
    tile_cat = TileCatalog(TILE_SLEN, batch["tile_catalog"])
    # .sum() over the boolean masks counts matching entries; .item() turns the 0-d tensor into an int.
    n_star += tile_cat.star_bools.sum().item()
    n_galaxy += tile_cat.galaxy_bools.sum().item()

100%|██████████| 3047/3047 [07:23<00:00,  6.87it/s]


In [4]:
# Show the totals accumulated by the counting loop over dc2_train_dataloader.
print(f"star count: {n_star}\ngalaxy count: {n_galaxy}")

star count: 35646
galaxy count: 1170188


In [6]:
# NOTE(review): these two scale factors were bare literals in the original cell.
# Presumably TRAIN_SPLIT_FRACTION is the share of DC2 data that lands in the train split
# (dividing by it extrapolates the train-split count to the whole dataset), and
# N_FULL_IMAGES is the number of full DC2 images — TODO confirm against notebook_config.
TRAIN_SPLIT_FRACTION = 0.8
N_FULL_IMAGES = 100

n_sources = n_star + n_galaxy
# Guard against an empty dataloader, which would otherwise raise ZeroDivisionError.
star_frac = n_star / n_sources if n_sources else float("nan")
galaxy_frac = n_galaxy / n_sources if n_sources else float("nan")

print(f"per full image star: {n_star / TRAIN_SPLIT_FRACTION / N_FULL_IMAGES}")
print(f"per full image galaxy: {n_galaxy / TRAIN_SPLIT_FRACTION / N_FULL_IMAGES}")
# Despite the "percent" label, these values are fractions in [0, 1], not percentages.
print(f"star percent: {star_frac}")
print(f"galaxy percent: {galaxy_frac}")

per full image star: 445.575
per full image galaxy: 14627.35
star percent: 0.029561282896319062
galaxy percent: 0.9704387171036809
