In [1]:
import torch
import tqdm

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from case_studies.dc2_diffusion.utils.catalog_parser import CatalogParser
from bliss.global_env import GlobalEnv

In [2]:
# Preferred GPU for this notebook. The original hard-coded "cuda:2" whenever CUDA was
# available at all, which yields an invalid device ordinal on machines with < 3 GPUs.
# Use the preferred GPU only if it exists; otherwise fall back to the default GPU, then CPU.
CUDA_DEVICE_INDEX = 2
if torch.cuda.is_available() and torch.cuda.device_count() > CUDA_DEVICE_INDEX:
    device = torch.device(f"cuda:{CUDA_DEVICE_INDEX}")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Compose the notebook's Hydra config ("new_diffusion_notebook_config") from the relative
# config path "."; the Hydra context is only needed while composing.
with initialize(version_base=None, config_path="."):
    new_diffusion_notebook_cfg = compose(config_name="new_diffusion_notebook_config")

In [3]:
# DC2 survey section of the config; several values below are read from it.
dc2_cfg = new_diffusion_notebook_cfg.surveys.dc2
tile_slen = dc2_cfg.tile_slen
max_sources_per_tile = dc2_cfg.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

# Build the DC2 data module, override its batch size, and run its "fit"-stage setup.
dc2: DC2DataModule = instantiate(dc2_cfg)
dc2.batch_size = 1024
dc2.setup(stage="fit")

# Process-wide settings on GlobalEnv, assigned before the validation loader is built.
# NOTE(review): presumably consumed by bliss when loading/augmenting data — TODO confirm
# what reads current_encoder_epoch / seed_in_this_program.
GlobalEnv.current_encoder_epoch = 1
GlobalEnv.seed_in_this_program = 7272
dc2_val_dataloader = dc2.val_dataloader()

# Catalog parser as configured for the encoder.
catalog_parser: CatalogParser = instantiate(new_diffusion_notebook_cfg.encoder.catalog_parser)

In [4]:
# For each validation batch, take the L2 norm of batch["images"] over dim=1 and record the
# requested quantiles of those norms (torch.quantile flattens the norm tensor since no dim is given).
# NOTE(review): dim=1 is presumably the band axis of (batch, bands, H, W) images — TODO confirm.
quantile_tensor = torch.tensor([0.001, 0.01, 0.1, 0.2, 0.3, 0.5, 0.8, 0.9, 0.99, 0.999], device=device)
dc2_val_dataloader = dc2.val_dataloader()

per_batch_quantiles = []
for batch in tqdm.tqdm(dc2_val_dataloader):
    batch = move_data_to_device(batch, device=device)
    # torch.norm is deprecated; torch.linalg.vector_norm is its documented replacement.
    pixel_norms = torch.linalg.vector_norm(batch["images"], ord=2, dim=1)
    # torch.quantile requires q to have the same dtype as its input.
    per_batch_quantiles.append(pixel_norms.quantile(q=quantile_tensor.to(pixel_norms.dtype)))

# Shape (num_batches, len(quantile_tensor)); row i holds the quantiles of batch i.
quantiles = torch.stack(per_batch_quantiles)

100%|██████████| 25/25 [00:11<00:00,  2.18it/s]


In [6]:
# Two-column table: column 0 is the quantile level, column 1 is that quantile averaged over batches.
# NOTE(review): an unweighted mean of per-batch quantiles approximates, but is not equal to, the
# quantile of all validation pixels pooled; a smaller final batch (if any) gets equal weight.
torch.stack((quantile_tensor, quantiles.mean(dim=0)), dim=1)

tensor([[1.0000e-03, 4.6561e-02],
        [1.0000e-02, 7.4178e-02],
        [1.0000e-01, 1.3488e-01],
        [2.0000e-01, 1.7251e-01],
        [3.0000e-01, 2.0560e-01],
        [5.0000e-01, 2.7262e-01],
        [8.0000e-01, 4.1535e-01],
        [9.0000e-01, 5.1389e-01],
        [9.9000e-01, 1.7472e+00],
        [9.9900e-01, 2.7046e+01]], device='cuda:2')