In [None]:
import torch
import tqdm
import os
from pathlib import Path

from hydra import initialize, compose
from hydra.utils import instantiate
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from bliss.catalog import TileCatalog
from case_studies.dc2_cataloging.utils.encoder import CalibrationEncoder

In [2]:
# Which trained encoder to evaluate: experiment run directory + checkpoint file inside it.
model_name = "exp_03-30-1"
model_check_point_name = "encoder_44.ckpt"
model_path = f"../../../bliss_output/DC2_cataloging_exp/{model_name}/checkpoints/{model_check_point_name}"

# Directory for cached posterior samples. Defaults to the original scratch volume; set
# BLISS_POSTERIOR_CACHE_DIR to point elsewhere on machines that don't have it.
cached_data_path = Path(
    os.environ.get("BLISS_POSTERIOR_CACHE_DIR", "/data/scratch/pduan/posterior_cached_files")
)

# Preferred GPU index. If this machine exposes fewer GPUs, fall back to the first GPU
# (or CPU when CUDA is unavailable) instead of failing later when the checkpoint is
# loaded / the model is moved to an absent device.
preferred_gpu_index = 2
if torch.cuda.device_count() > preferred_gpu_index:
    device = torch.device(f"cuda:{preferred_gpu_index}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Compose the hydra config that lives next to this notebook, then override encoder flags.
with initialize(config_path="./", version_base=None):
    notebook_cfg = compose("notebook_config")
notebook_cfg.encoder.use_double_detect = True
notebook_cfg.encoder.minimalist_conditioning = False
notebook_cfg.encoder.use_checkerboard = False

In [3]:
# Global seed for this session. The posterior sampling further down is drawn with
# use_mode=False, and this same value is embedded in the cache file name.
seed = 7272
pytorch_lightning.seed_everything(seed)

Seed set to 7272


7272

In [4]:
# ---- Data: DC2 validation split, with the data module's batch size overridden ----
batch_size = 800
dc2: DC2DataModule = instantiate(notebook_cfg.surveys.dc2)
dc2.batch_size = batch_size
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

# ---- Model: build the encoder from the notebook config, place it on `device`,
# then replace its initial parameters with the checkpoint's trained weights ----
bliss_encoder: CalibrationEncoder = instantiate(notebook_cfg.encoder)
bliss_encoder.to(device)
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this Lightning
# checkpoint pickles non-tensor objects the call may need weights_only=False — confirm on upgrade.
pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)  # strict=True by default: keys must match exactly
bliss_encoder.eval();  # inference only: put layers such as dropout / batch-norm into eval mode

In [5]:
# Cache key: experiment run + checkpoint + evaluation scope + seed. The seed matters because
# sampling below uses use_mode=False (presumably a random posterior draw rather than the mode).
# NOTE(review): the encoder flags set in the config cell (use_double_detect, ...) are not part
# of the key — delete the cached file by hand if you change them.
bliss_cached_file_name = f"bliss_posterior_{model_name}_{model_check_point_name}_whole_val_set_" \
                         f"seed_{seed}.pt"
save_path = cached_data_path / bliss_cached_file_name

if not save_path.is_file():
    print(f"can't find cached file [{bliss_cached_file_name}]; rerun the inference")

    # Each list receives one CPU tensor per validation batch.
    bliss_result_dict = {
        "pred_n_sources_list": [],
        "pred_locs_list": [],
        "pred_fluxes_list": [],
        "target_n_sources_list": [],
        "target_locs_list": [],
        "target_fluxes_list": [],
    }

    for one_batch in tqdm.tqdm(dc2_val_dataloader):
        one_batch = move_data_to_device(one_batch, device=device)

        # Ground truth: record the per-tile source count from the full catalog, then reduce it
        # to the brightest source per tile (ranked in band index 2) before keeping locs / fluxes.
        target_tile_cat = TileCatalog(one_batch["tile_catalog"])
        bliss_result_dict["target_n_sources_list"].append(target_tile_cat["n_sources"].cpu())
        target_tile_cat = target_tile_cat.get_brightest_sources_per_tile(band=2)
        bliss_result_dict["target_locs_list"].append(target_tile_cat["locs"].cpu())
        bliss_result_dict["target_fluxes_list"].append(target_tile_cat["fluxes"].cpu())

        with torch.inference_mode():
            sample_tile_cat = bliss_encoder.sample(one_batch, use_mode=False)

        # Predictions: keep only source slot 0 of each tile.
        # NOTE(review): assumes slot 0 is the counterpart of the brightest target source — confirm.
        bliss_result_dict["pred_n_sources_list"].append(sample_tile_cat["n_sources"].cpu())
        bliss_result_dict["pred_locs_list"].append(sample_tile_cat["locs"][..., 0:1, :].cpu())  # (b, h, w, 1, 2)
        bliss_result_dict["pred_fluxes_list"].append(sample_tile_cat["fluxes"][..., 0:1, :].cpu())  # (b, h, w, 1, 6)

    # Ensure the cache directory exists; otherwise torch.save would fail only *after* the
    # whole inference pass has run.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(bliss_result_dict, save_path)
else:
    print(f"find the cached file [{bliss_cached_file_name}]; run nothing")
    # Load the cached results so `bliss_result_dict` is defined on both branches.
    bliss_result_dict = torch.load(save_path)

can't find cached file [bliss_posterior_exp_03-30-1_encoder_44.ckpt_whole_val_set_seed_7272.pt]; rerun the inference


100%|██████████| 32/32 [01:59<00:00,  3.75s/it]
