In [1]:
import torch
import tqdm
import os

from pathlib import Path

from hydra import initialize, compose
from hydra.utils import instantiate
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device

from bliss.cached_dataset import CachedSimulatedDataModule
from bliss.catalog import TileCatalog
from case_studies.dc2_mdt.utils.encoder import M2BlissEncoder

In [2]:
# Which trained checkpoint to evaluate. The checkpoint path is derived from these two names
# (and the posterior cache file name later uses them too), so the loaded weights and the
# cache label cannot silently drift apart.
model_name = "exp_07-03-1"
model_check_point_name = "encoder_43.ckpt"
model_path = f"../../../bliss_output/m2_ori_bliss_{model_name}_{model_check_point_name}"
# Directory where cached posterior samples are written / read.
cached_data_path = Path("/data/scratch/pduan/posterior_cached_files")
device = torch.device("cuda:7")
with initialize(config_path="./m2_mdt_config", version_base=None):
    cfg = compose("m2_ori_bliss_train_config")

In [3]:
# Seed the RNGs from the training config. `seed_everything` returns the seed it actually
# applied; record that value (it is used in the cache file name) rather than the raw config
# entry, which may differ, e.g. when the config leaves the seed unset.
seed = pytorch_lightning.seed_everything(seed=cfg.train.seed)
seed

Seed set to 7272


7272

In [4]:
# Validation data: instantiate the cached-simulator datamodule from the Hydra config,
# replace whatever batch size the config supplied with the evaluation batch size,
# and set up only the validation split.
batch_size = 800
m2: CachedSimulatedDataModule = instantiate(cfg.cached_simulator)
m2.batch_size = batch_size  # override after construction, before setup / dataloader creation
m2.setup(stage="validate")
m2_val_dataloader = m2.val_dataloader()

In [5]:
# Build the encoder with checkerboard sampling switched off (NOTE(review): presumably this
# makes `sample` predict every tile in a single pass — confirm in M2BlissEncoder), move it to
# `device`, and restore the trained weights from the checkpoint's "state_dict" entry.
cfg.encoder.use_checkerboard = False
my_encoder: M2BlissEncoder = instantiate(cfg.encoder)
my_encoder = my_encoder.to(device=device)  # nn.Module.to moves in place and returns the module

pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
my_encoder.load_state_dict(pretrained_weights)
my_encoder.eval();  # evaluation mode; the trailing `;` suppresses the module repr

In [6]:
def _active_source_mask(n_sources: torch.Tensor, max_sources: int) -> torch.Tensor:
    """Return a bool mask marking which per-tile source slots are occupied.

    Slot ``k`` (0-based) of a tile is active iff ``k < n_sources`` for that tile.

    Args:
        n_sources: integer tensor of per-tile source counts, e.g. ``(b, h, w)``.
        max_sources: number of source slots per tile.

    Returns:
        Bool tensor of shape ``n_sources.shape + (max_sources, 1)``; the trailing singleton
        dim lets it broadcast against per-slot parameters such as ``locs`` / ``fluxes``.
    """
    slot_index = torch.arange(max_sources, device=n_sources.device)
    return (slot_index < n_sources.unsqueeze(-1)).unsqueeze(-1)


def _append_catalog(
    result_dict: dict,
    prefix: str,
    n_sources: torch.Tensor,
    locs: torch.Tensor,
    fluxes: torch.Tensor,
) -> None:
    """Move one batch of catalog tensors to CPU and append them under ``{prefix}_*_list``."""
    result_dict[f"{prefix}_n_sources_list"].append(n_sources.cpu())
    result_dict[f"{prefix}_locs_list"].append(locs.cpu())
    result_dict[f"{prefix}_fluxes_list"].append(fluxes.cpu())


# Cache of posterior samples over the whole validation set, keyed by checkpoint and seed.
# NOTE(review): a cache file written before the source-mask fix below contains incorrectly
# masked pred_locs / pred_fluxes; delete it so it is regenerated.
bliss_cached_file_name = (
    f"m2_bliss_posterior_{model_name}_{model_check_point_name}_whole_val_set_"
    f"seed_{seed}.pt"
)
save_path = cached_data_path / bliss_cached_file_name

if not save_path.is_file():
    print(f"can't find cached file [{bliss_cached_file_name}]; rerun the inference")

    bliss_result_dict = {
        "pred_n_sources_list": [],
        "pred_locs_list": [],
        "pred_fluxes_list": [],
        "target1_n_sources_list": [],
        "target1_locs_list": [],
        "target1_fluxes_list": [],
        "target2_n_sources_list": [],
        "target2_locs_list": [],
        "target2_fluxes_list": [],
    }

    for one_batch in tqdm.tqdm(m2_val_dataloader):
        one_batch = move_data_to_device(one_batch, device=device)

        # Reference catalogs ranked by band 0. NOTE(review): exclude_num=k presumably skips the
        # k brightest sources, making target1 / target2 the brightest / second-brightest source
        # per tile — confirm against TileCatalog.get_brightest_sources_per_tile.
        target_tile_cat = TileCatalog(one_batch["tile_catalog"])
        for prefix, exclude_num in (("target1", 0), ("target2", 1)):
            target_cat = target_tile_cat.get_brightest_sources_per_tile(
                band=0, exclude_num=exclude_num
            )
            _append_catalog(
                bliss_result_dict,
                prefix,
                target_cat["n_sources"],
                target_cat["locs"],
                target_cat["fluxes"],
            )

        with torch.inference_mode():
            sample_tile_cat = my_encoder.sample(one_batch, use_mode=False)

        pred_n_sources = sample_tile_cat["n_sources"]  # (b, h, w)
        pred_locs = sample_tile_cat["locs"]  # (b, h, w, max_sources, 2)
        # Zero the parameters of unoccupied slots so only slots k < n_sources carry values.
        # (The previous mask compared in the wrong direction: it kept both slots for empty
        # tiles and zeroed the first slot of two-source tiles.)
        slot_mask = _active_source_mask(pred_n_sources, pred_locs.shape[-2])
        _append_catalog(
            bliss_result_dict,
            "pred",
            pred_n_sources,
            pred_locs * slot_mask,
            sample_tile_cat["fluxes"] * slot_mask,
        )

    save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(bliss_result_dict, save_path)
else:
    # Load the cache so `bliss_result_dict` is defined for later cells on a fresh kernel.
    print(f"find the cached file [{bliss_cached_file_name}]; load it")
    bliss_result_dict = torch.load(save_path)

can't find cached file [m2_bliss_posterior_exp_07-03-1_encoder_43.ckpt_whole_val_set_seed_7272.pt]; rerun the inference


100%|██████████| 31/31 [00:24<00:00,  1.26it/s]
