In [1]:
import torch
import tqdm
import os

from pathlib import Path

from hydra import initialize, compose
from hydra.utils import instantiate
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device

from bliss.cached_dataset import CachedSimulatedDataModule
from bliss.catalog import TileCatalog
from case_studies.dc2_mdt.utils.encoder import M2BlissEncoder

In [2]:
# Experiment identifiers: which training run / checkpoint to evaluate.
model_name = "exp_04-20-3"
model_check_point_name = "encoder_31.ckpt"
model_path = f"../../../bliss_output/M2_mdt_exp/{model_name}/checkpoints/{model_check_point_name}"
# Directory holding cached posterior samples. The default is a machine-specific
# scratch path; set BLISS_POSTERIOR_CACHE_DIR to relocate it on other hosts.
cached_data_path = Path(
    os.environ.get("BLISS_POSTERIOR_CACHE_DIR", "/data/scratch/pduan/posterior_cached_files")
)
# The GPU index is machine-specific; set BLISS_DEVICE (e.g. "cuda:0") to override.
# Without CUDA we fall back to CPU, exactly as before.
device = torch.device(
    os.environ.get("BLISS_DEVICE", "cuda:3") if torch.cuda.is_available() else "cpu"
)
# Compose the training config from the (relative) hydra config directory.
with initialize(config_path="./m2_mdt_config", version_base=None):
    cfg = compose("m2_cond_true_bliss_train_config")

In [3]:
# NOTE(review): the config-driven seed is intentionally bypassed; the explicit seed
# set in the following cell is used instead. Consider deleting this dead cell.
# seed = cfg.train.seed
# pytorch_lightning.seed_everything(seed=seed)

In [4]:
# Fix the global RNGs for everything that follows. This value is also embedded in
# the cached-posterior filename, so changing it forces a fresh sampling run.
seed = 9999
pytorch_lightning.seed_everything(seed)

Seed set to 9999


9999

In [5]:
# Validation data: build the cached-simulator data module from the config, then
# override its batch size before setup so the dataloader picks it up.
batch_size = 800
m2: CachedSimulatedDataModule = instantiate(cfg.cached_simulator)
m2.batch_size = batch_size  # overrides whatever the config specified
m2.setup(stage="validate")
m2_val_dataloader = m2.val_dataloader()

In [6]:
# Build the encoder from the config on `device`, then load the trained weights.
my_encoder: M2BlissEncoder = instantiate(cfg.encoder).to(device=device)
# Load the checkpoint onto CPU: load_state_dict copies each tensor into the
# parameters that already live on `device`. Mapping straight to the GPU would
# materialize the whole checkpoint there and keep a second, redundant copy of
# every weight resident in GPU memory through `pretrained_weights`.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True; if the
# checkpoint stores non-tensor objects, pass weights_only=False (trusted files only).
pretrained_weights = torch.load(model_path, map_location="cpu")["state_dict"]
my_encoder.load_state_dict(pretrained_weights)
my_encoder.eval();

In [7]:
# Take the first validation batch and move every tensor it contains to `device`.
one_batch = move_data_to_device(next(iter(m2_val_dataloader)), device=device)

In [8]:
# Ground truth for this batch: the raw images plus the true tile-level catalog.
target_images = one_batch["images"]
target_tile_cat = TileCatalog(one_batch["tile_catalog"])
target_n_sources, target_locs, target_fluxes = (
    target_tile_cat[key] for key in ("n_sources", "locs", "fluxes")
)

In [9]:
# Draw `total_iters` posterior samples for the fixed batch and cache them on disk.
# The cache filename encodes model, checkpoint, batch size, iteration count and
# seed, so changing any of them triggers a fresh run.
total_iters = 500
cached_file_name = (
    "m2_cond_true_bliss_posterior_"
    f"{model_name}_{model_check_point_name}_"
    f"b_{batch_size}_"
    f"iter_{total_iters}_"
    f"seed_{seed}.pt"
)
save_path = cached_data_path / cached_file_name
if not save_path.is_file():
    print("can't find cached file; rerun the inference")
    # Create the cache directory up front so a missing/unwritable directory fails
    # fast instead of after the full sampling loop.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    n_sources_list = []
    locs_list = []
    fluxes_list = []
    # Inference only: a single no_grad context covers the whole loop.
    with torch.no_grad():
        for _ in tqdm.tqdm(range(total_iters)):
            # use_mode=False: presumably a random draw rather than the posterior
            # mode -- verify against M2BlissEncoder.sample.
            sample_tile_cat = my_encoder.sample(one_batch, use_mode=False)
            cur_n_sources = sample_tile_cat["n_sources"]
            n_sources_list.append(cur_n_sources.cpu())
            locs = sample_tile_cat["locs"]  # (b, h, w, 2, 2)
            locs_list.append(locs.cpu())
            fluxes = sample_tile_cat["fluxes"]  # (b, h, w, 2, 6)
            fluxes_list.append(fluxes.cpu())

    bliss_result_dict = {
        "n_sources_list": n_sources_list,
        "locs_list": locs_list,
        "fluxes_list": fluxes_list,
        "target_images": target_images.cpu(),
        "target_n_sources": target_n_sources.cpu(),
        "target_locs": target_locs.cpu(),
        "target_fluxes": target_fluxes.cpu(),
    }
    # Write to a temporary file, then atomically rename, so an interrupted save
    # never leaves a truncated file that a later run would mistake for a cache hit.
    tmp_path = save_path.with_name(save_path.name + ".tmp")
    torch.save(bliss_result_dict, tmp_path)
    os.replace(tmp_path, save_path)
else:
    print("find the cached file; run nothing")

can't find cached file; rerun the inference


100%|██████████| 500/500 [02:15<00:00,  3.70it/s]
