In [1]:
import torch
from hydra import initialize, compose
from hydra.utils import instantiate
import tqdm
import math

from bliss.surveys.dc2 import DC2DataModule
from bliss.global_env import GlobalEnv

In [2]:
with initialize(config_path="./ynet_diffusion_config", version_base=None):
    notebook_cfg = compose("ynet_full_diffusion_notebook_config")

In [3]:
GlobalEnv.current_encoder_epoch = 1
GlobalEnv.seed_in_this_program = 7272

In [4]:
dc2: DC2DataModule = instantiate(notebook_cfg.surveys.dc2)
dc2.prepare_data()
dc2.setup(stage="fit")

dc2_train_dataloader = dc2.train_dataloader()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")




In [5]:
n_sources_count = torch.zeros(10, device=device)
flux_max = 0
for batch in tqdm.tqdm(dc2_train_dataloader):
    n_sources_count += batch["tile_catalog"]["n_sources"].to(device=device).flatten().bincount(minlength=10)
    cur_flux_max = batch["tile_catalog"]["fluxes"].to(device=device).max()
    if cur_flux_max > flux_max:
        flux_max = cur_flux_max

100%|██████████| 1524/1524 [04:07<00:00,  6.16it/s]


In [6]:
print(flux_max)
print(math.log(flux_max + 1))

tensor(5.9038e+09, device='cuda:0')
22.498856031849975


In [7]:
print(list(zip(range(10), n_sources_count.tolist())))

[(0, 76836320.0), (1, 1153332.0), (2, 10308.0), (3, 67.0), (4, 0.0), (5, 0.0), (6, 0.0), (7, 0.0), (8, 0.0), (9, 0.0)]


In [8]:
print(n_sources_count / n_sources_count.sum())

tensor([9.8508e-01, 1.4786e-02, 1.3215e-04, 8.5897e-07, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00], device='cuda:0')


In [9]:
print(n_sources_count[1:] / n_sources_count[1:].sum())

tensor([9.9108e-01, 8.8579e-03, 5.7575e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00], device='cuda:0')
