In [1]:
import torch
import tqdm

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from case_studies.dc2_diffusion.utils.catalog_parser import CatalogParser
from bliss.global_env import GlobalEnv

In [2]:
# Compute device for the analysis below. GPU_INDEX selects a specific card on a
# multi-GPU host. If CUDA is unavailable, or this host has no device with that
# index, fall back to CPU instead of failing later on the first `.to(device)`.
GPU_INDEX = 2
if torch.cuda.is_available() and GPU_INDEX < torch.cuda.device_count():
    device = torch.device(f"cuda:{GPU_INDEX}")
else:
    device = torch.device("cpu")
# Compose the notebook's Hydra config from the local `ynet_diffusion_config` directory
# (Hydra resolves `config_path` relative to the caller).
with initialize(version_base=None, config_path="./ynet_diffusion_config"):
    new_diffusion_notebook_cfg = compose(config_name="ynet_diffusion_notebook_config")

In [3]:
# Read the DC2 survey settings used later in the notebook out of the composed config.
dc2_survey_cfg = new_diffusion_notebook_cfg.surveys.dc2
tile_slen = dc2_survey_cfg.tile_slen
max_sources_per_tile = dc2_survey_cfg.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

# Build the DC2 data module, override its batch size, then prepare the "fit" stage.
dc2: DC2DataModule = instantiate(dc2_survey_cfg)
dc2.batch_size = 256
dc2.setup(stage="fit")

# GlobalEnv values are set before the validation loader is created.
# NOTE(review): presumably the loader or datasets read these values — TODO confirm.
GlobalEnv.current_encoder_epoch = 1
GlobalEnv.seed_in_this_program = 7272
dc2_val_dataloader = dc2.val_dataloader()

catalog_parser: CatalogParser = instantiate(new_diffusion_notebook_cfg.encoder.catalog_parser)

In [4]:
# Pixel-value quantiles of the validation images, computed batch by batch.
# Each row of `quantiles` holds the `quantile_tensor` levels of one batch's
# flattened "images" tensor, so the result has shape (num_batches, len(quantile_tensor)).
# NOTE(review): averaging these per-batch rows (next cell) only approximates the
#   quantiles of the whole validation set. A smaller final batch is also weighted
#   the same as full batches.
# NOTE(review): torch.quantile rejects very large inputs (reported limit is 2**24
#   elements). Increasing dc2.batch_size may make this loop fail — verify if it changes.
dc2_val_dataloader = dc2.val_dataloader()  # fresh loader keeps the cell re-runnable on its own
quantile_tensor = torch.tensor([0.001, 0.01, 0.1, 0.2, 0.3, 0.5, 0.8, 0.9, 0.99, 0.999], device=device)

per_batch_quantiles = []
for batch in tqdm.tqdm(dc2_val_dataloader):
    # Only the images are needed, so move just that tensor instead of the whole batch dict.
    images = batch["images"].to(device)
    per_batch_quantiles.append(images.quantile(q=quantile_tensor))
quantiles = torch.stack(per_batch_quantiles)

100%|██████████| 98/98 [00:14<00:00,  6.93it/s]


In [5]:
# Two-column table: requested quantile level | that quantile averaged over validation batches.
torch.stack([quantile_tensor, quantiles.mean(dim=0)], dim=1)

tensor([[ 1.0000e-03, -6.2135e-01],
        [ 1.0000e-02, -3.9222e-01],
        [ 1.0000e-01, -1.0976e-01],
        [ 2.0000e-01, -4.8438e-02],
        [ 3.0000e-01, -2.3607e-02],
        [ 5.0000e-01,  7.4611e-03],
        [ 8.0000e-01,  7.6441e-02],
        [ 9.0000e-01,  1.6448e-01],
        [ 9.9000e-01,  7.0327e-01],
        [ 9.9900e-01,  9.4697e+00]], device='cuda:2')