In [None]:
import torch
import tqdm
import os

from einops import rearrange
from pathlib import Path

from hydra import initialize, compose
from hydra.utils import instantiate
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from bliss.catalog import TileCatalog
from case_studies.dc2_mdt.utils.encoder import DiffusionEncoder

In [None]:
# Which trained model / checkpoint to evaluate.
model_tag_name = "mdt"
model_name = "exp_04-12-1"
model_check_point_name = "encoder_46.ckpt"
model_path = f"../../../bliss_output/DC2_mdt_exp/{model_name}/checkpoints/{model_check_point_name}"
# Directory for cached posterior samples. Set BLISS_POSTERIOR_CACHE_DIR to use a
# different location on machines where the default scratch path does not exist.
cached_data_path = Path(
    os.environ.get("BLISS_POSTERIOR_CACHE_DIR", "/data/scratch/pduan/posterior_cached_files")
)
# GPU to run on; set BLISS_CUDA_DEVICE (e.g. "cuda:0") on hosts with fewer GPUs.
# Falls back to CPU when CUDA is unavailable.
device = torch.device(
    os.environ.get("BLISS_CUDA_DEVICE", "cuda:5") if torch.cuda.is_available() else "cpu"
)
with initialize(config_path="./mdt_config", version_base=None):
    cfg = compose("mdt_train_config")
# Number of images in the single batch we sample on, and number of posterior draws.
infer_batch_size = 720
infer_total_iters = 500

In [None]:
# Seed every RNG up front: the posterior sampling below is stochastic, and the
# seed is also part of the cache file name.
seed = cfg.train.seed
print("using seed {}".format(seed))
pytorch_lightning.seed_everything(seed)

In [None]:
# Validation data module; the batch size is overridden so a single batch holds
# `infer_batch_size` images.
dc2: DC2DataModule = instantiate(cfg.surveys.dc2)
dc2.batch_size = infer_batch_size
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

# Build the encoder on the target device, then load the trained weights.
bliss_encoder: DiffusionEncoder = instantiate(cfg.encoder).to(device=device)
# Deserialize on CPU: the checkpoint may hold more than the weights, and keeping
# `pretrained_weights` on CPU avoids a second GPU-resident copy of the model.
# load_state_dict copies the tensors onto the module's own device.
pretrained_weights = torch.load(model_path, map_location="cpu")["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
bliss_encoder.eval();

In [None]:
# Turn on the network's fast inference path when this architecture provides one.
if not hasattr(bliss_encoder.my_net, "fast_inference_mode"):
    print("no fast inference mode")
else:
    bliss_encoder.my_net.fast_inference_mode = True
    print("enable fast inference mode")

In [None]:
# Sampler settings used for every posterior draw below.
# NOTE(review): these values are not part of the posterior cache file name, so
# cached results produced with different settings would be silently reused.
sampling_time_steps = 500
ddim_eta = 1.0
bliss_encoder.reconfig_sampling(new_sampling_time_steps=sampling_time_steps, new_ddim_eta=ddim_eta)

In [None]:
# A single, fixed validation batch; every posterior draw below conditions on it.
one_batch = move_data_to_device(next(iter(dc2_val_dataloader)), device=device)

In [None]:
# Ground truth for the fixed batch. `target_n_sources` is read from the full
# catalog; locs, fluxes and ellipticity come from the brightest source per tile.
target_images = one_batch["images"]
target_tile_cat = TileCatalog(one_batch["tile_catalog"])
target_n_sources = target_tile_cat["n_sources"]
# NOTE(review): band=2 selects the band used to rank brightness — confirm which filter.
target_tile_cat = target_tile_cat.get_brightest_sources_per_tile(band=2)
target_locs, target_ellipticity = target_tile_cat["locs"], target_tile_cat["ellipticity"]
max_fluxes = bliss_encoder.max_fluxes
print(f"max_fluxes: {max_fluxes}")
# Cap true fluxes at the encoder's `max_fluxes` (presumably so targets live on
# the same range as the sampled fluxes).
target_fluxes = target_tile_cat["fluxes"].clamp(max=max_fluxes)

In [None]:
# Cache key: results are reused only when model, checkpoint, batch size, number
# of posterior draws and seed all match.
# NOTE(review): the sampler settings passed to `reconfig_sampling` (time steps,
# DDIM eta) are not in the key — change the file name if you vary them.
diffusion_cached_file_name = (
    f"{model_tag_name}_"
    f"posterior_{model_name}_{model_check_point_name}_"
    f"b_{infer_batch_size}_"
    f"iter_{infer_total_iters}_"
    f"seed_{seed}.pt"
)
save_path = cached_data_path / diffusion_cached_file_name
if not save_path.is_file():
    print(f"can't find cached file [{diffusion_cached_file_name}]; rerun the inference")
    # `init_n_sources` is taken from the first draw, so at least one is required.
    assert infer_total_iters > 0, "infer_total_iters must be positive"
    # Create the cache directory before the (long) sampling loop so a bad path
    # fails immediately rather than after all iterations.
    save_path.parent.mkdir(parents=True, exist_ok=True)

    init_n_sources = None
    n_sources_list = []
    locs_list = []
    fluxes_list = []
    for _ in tqdm.tqdm(range(infer_total_iters)):
        with torch.inference_mode():
            sample_tile_cat = bliss_encoder.sample(one_batch)

        # Counts may arrive as "n_sources_multi" with two trailing singleton dims;
        # squeeze them so both layouts yield a (b, h, w) tensor.
        if "n_sources_multi" in sample_tile_cat:
            cur_n_sources = rearrange(sample_tile_cat["n_sources_multi"], "b h w 1 1 -> b h w")
        else:
            cur_n_sources = sample_tile_cat["n_sources"]
        # The first draw is additionally stored on its own as `init_n_sources`.
        if init_n_sources is None:
            init_n_sources = cur_n_sources

        n_sources_list.append(cur_n_sources.cpu())
        # Keep only the first slot along the (presumed) per-tile source axis.
        locs = sample_tile_cat["locs"][..., 0:1, :]  # (b, h, w, 1, 2)
        locs_list.append(locs.cpu())
        fluxes = sample_tile_cat["fluxes"][..., 0:1, :]  # (b, h, w, 1, 6)
        fluxes_list.append(fluxes.cpu())

    diffusion_result_dict = {
        "init_n_sources": init_n_sources.cpu(),
        "n_sources_list": n_sources_list,
        "locs_list": locs_list,
        "fluxes_list": fluxes_list,
        "target_images": target_images.cpu(),
        "target_n_sources": target_n_sources.cpu(),
        "target_locs": target_locs.cpu(),
        "target_fluxes": target_fluxes.cpu(),
        "target_ellipticity": target_ellipticity.cpu(),
    }
    # Write to a temporary file, then atomically rename: an interrupted save must
    # not leave a truncated file that a later run would accept as a valid cache.
    tmp_save_path = save_path.with_name(save_path.name + ".tmp")
    torch.save(diffusion_result_dict, tmp_save_path)
    os.replace(tmp_save_path, save_path)
else:
    print(f"find the cached file [{diffusion_cached_file_name}]; run nothing")