In [None]:
import torch
import tqdm
import os

import matplotlib.pyplot as plt
import seaborn as sns
from einops import rearrange, repeat
from pathlib import Path

from hydra import initialize, compose
from hydra.utils import instantiate
import pytorch_lightning
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from bliss.catalog import TileCatalog
from case_studies.dc2_new_diffusion.utils.encoder import DiffusionEncoder

In [None]:
# Run and checkpoint to evaluate. Both names are also embedded in the cache
# file name that the sampling cell builds later in this notebook.
model_name = "exp_03-16-3"
model_check_point_name = "encoder_113.ckpt"
# The checkpoint path is relative, so it depends on the notebook's working directory.
model_path = f"../../../bliss_output/DC2_ynet_full_diffusion_exp/{model_name}/checkpoints/{model_check_point_name}"
# NOTE(review): absolute, user-specific scratch directory — other users must edit this.
cached_data_path = Path("/data/scratch/pduan/posterior_cached_files")
# Use GPU index 2 when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): the GPU index is hard-coded; presumably this fails later (when
# the model is moved to the device) on machines with fewer than three GPUs — confirm.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
# Compose the Hydra config for this notebook from ./ynet_diffusion_config.
with initialize(config_path="./ynet_diffusion_config", version_base=None):
    new_diffusion_notebook_cfg = compose("ynet_full_diffusion_notebook_config")

In [None]:
# Fixed seed for reproducibility. The value is also part of the cache file name
# built in the sampling cell, so changing it triggers a fresh sampling run.
seed = 7272
pytorch_lightning.seed_everything(seed=seed)

In [None]:
# Tile-level settings read from the composed config.
# NOTE(review): these three names are not referenced again in the visible cells
# of this notebook.
tile_slen = new_diffusion_notebook_cfg.surveys.dc2.tile_slen
max_sources_per_tile = new_diffusion_notebook_cfg.surveys.dc2.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

# Build the DC2 data module from the config, override its batch size, and set up
# the validation stage. `batch_size` is also embedded in the cache file name
# built in the sampling cell.
batch_size = 128
dc2: DC2DataModule = instantiate(new_diffusion_notebook_cfg.surveys.dc2)
dc2.batch_size = batch_size
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

# Instantiate the encoder on `device`, load the checkpoint's "state_dict" entry
# into it, and switch it to eval mode.
bliss_encoder: DiffusionEncoder = instantiate(new_diffusion_notebook_cfg.encoder).to(device=device)
pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
# The trailing semicolon keeps the notebook from echoing the returned module.
bliss_encoder.eval();

In [None]:
# Show the sampler-related settings currently stored on the loaded encoder,
# one value per line, before the next cell overrides some of them.
for setting_value in (
    bliss_encoder.ddim_steps,
    bliss_encoder.ddim_objective,
    bliss_encoder.ddim_beta_schedule,
    bliss_encoder.detection_diffusion.ddim_sampling_eta,
    bliss_encoder.ddim_self_cond,
    bliss_encoder.catalog_parser.factors[0].threshold,
):
    print(setting_value)

In [None]:
# Override the sampler settings for this analysis (the previous cell printed the
# values they had right after loading the checkpoint).
# NOTE(review): none of these overrides is part of the cache file name built in
# the sampling cell, so a cached result may come from different settings.
ddim_steps = 500
# The step count is written to both the encoder and its detection_diffusion object.
bliss_encoder.ddim_steps = ddim_steps
bliss_encoder.detection_diffusion.sampling_timesteps = ddim_steps
# NOTE(review): presumably eta = 1.0 selects fully stochastic DDIM sampling —
# confirm against the sampler implementation.
bliss_encoder.detection_diffusion.ddim_sampling_eta = 1.0
# NOTE(review): presumably a threshold of 0.0 disables thresholding in the first
# catalog factor — confirm against the catalog parser.
bliss_encoder.catalog_parser.factors[0].threshold = 0.0

In [None]:
# Take the first batch from the validation dataloader and move it to `device`.
one_batch = move_data_to_device(next(iter(dc2_val_dataloader)), device=device)

In [None]:
# Ground-truth quantities for the selected batch.
target_tile_cat = TileCatalog(one_batch["tile_catalog"])
target_images = one_batch["images"]
# The source counts are read from the full catalog, *before* it is replaced by
# the brightest-source-per-tile catalog below.
target_n_sources = target_tile_cat["n_sources"]
target_tile_cat = target_tile_cat.get_brightest_sources_per_tile(band=2)
target_locs, target_fluxes, target_ellipticity = (
    target_tile_cat[field_name] for field_name in ("locs", "fluxes", "ellipticity")
)

In [None]:
# Draw `total_iters` posterior samples for the same batch, or reuse a cached run.
#
# The resulting `diffusion_result_dict` holds, on the CPU:
#   - "init_n_sources": per-tile source count of the first sample,
#   - "n_sources_list" / "locs_list" / "fluxes_list": one entry per sample,
#   - "target_*": the ground-truth tensors extracted from the batch.
total_iters = 100
diffusion_cached_file_name = f"diffusion_posterior_{model_name}_{model_check_point_name}_b_{batch_size}_iter_{total_iters}_seed_{seed}.pt"
diffusion_cached_file_path = cached_data_path / diffusion_cached_file_name
# NOTE(review): the cache key covers model, checkpoint, batch size, iteration
# count and seed only. It does not cover the sampler overrides set earlier
# (ddim steps, eta, threshold); delete the cached file after changing those.
if not diffusion_cached_file_path.is_file():
    print("can't find cached file; rerun the inference")
    # Create the cache directory before sampling so that a missing directory
    # cannot make torch.save fail after the expensive loop has finished.
    cached_data_path.mkdir(parents=True, exist_ok=True)

    init_n_sources = None
    init_n_sources_mask = None
    n_sources_list = []
    locs_list = []
    fluxes_list = []
    for i in tqdm.tqdm(range(total_iters)):
        with torch.inference_mode():
            sample_tile_cat, _ = bliss_encoder.sample(one_batch, return_inter_output=False)
        cur_n_sources = rearrange(sample_tile_cat["n_sources_multi"], "b h w 1 1 -> b h w")
        if init_n_sources is None:
            # The first sample is the reference: it decides which tiles are tracked.
            init_n_sources = cur_n_sources
            init_n_sources_mask = init_n_sources > 0
        # Tiles that had a source in the first sample but have none in this one
        # get NaN locs/fluxes, so the nan-aware quantiles downstream skip them.
        nan_mask = init_n_sources_mask & ~(cur_n_sources > 0)
        nan_mask = rearrange(nan_mask, "b h w -> b h w 1 1")
        n_sources_list.append(cur_n_sources.cpu())
        # Keep only the first source slot of every tile.
        locs = sample_tile_cat["locs"][..., 0:1, :]  # (b, h, w, 1, 2)
        locs_list.append(torch.where(nan_mask, torch.nan, locs).cpu())
        fluxes = sample_tile_cat["fluxes"][..., 0:1, :]  # (b, h, w, 1, 6)
        fluxes_list.append(torch.where(nan_mask, torch.nan, fluxes).cpu())

    diffusion_result_dict = {
        "init_n_sources": init_n_sources.cpu(),
        "n_sources_list": n_sources_list,
        "locs_list": locs_list,
        "fluxes_list": fluxes_list,
        "target_images": target_images.cpu(),
        "target_n_sources": target_n_sources.cpu(),
        "target_locs": target_locs.cpu(),
        "target_fluxes": target_fluxes.cpu(),
        "target_ellipticity": target_ellipticity.cpu(),
    }
    torch.save(diffusion_result_dict, diffusion_cached_file_path)
else:
    print("find the cached file; directly use it")
    with open(diffusion_cached_file_path, "rb") as f:
        diffusion_result_dict = torch.load(f, map_location="cpu")

### Locs CI

In [None]:
def get_locs_ci(on_mask_locs, on_mask_target_locs, both_on_mask,
                ci_cover: list[float], for_vertical_locs: bool):
    """Measure how often central sample intervals contain the true location.

    For every nominal coverage ``c`` in ``ci_cover`` the central interval
    between the ``(1 - c) / 2`` and ``1 - (1 - c) / 2`` sample quantiles is
    computed per source (NaN samples are ignored), and the fraction of sources
    whose target coordinate lies strictly inside that interval is recorded.

    Args:
        on_mask_locs: sampled locations, shape (matched_sources, iter, 2).
        on_mask_target_locs: true locations, shape (matched_sources, 2).
        both_on_mask: boolean mask whose number of True entries is used as the
            denominator of the coverage fraction.
        ci_cover: nominal coverage levels in (0, 1).
        for_vertical_locs: if True use coordinate index 0, otherwise index 1.

    Returns:
        A list with one 0-d tensor per entry of ``ci_cover`` holding the
        empirical coverage.
    """
    locs_index = 0 if for_vertical_locs else 1
    # Only the requested coordinate is needed, so quantiles are taken on it alone.
    sampled_coord = on_mask_locs[..., locs_index]  # (matched_sources, iter)
    target_coord = on_mask_target_locs[:, locs_index]  # (matched_sources,)
    n_matched = both_on_mask.sum()

    actual_ci_cover = []
    for cover in ci_cover:
        lower_q = (1.0 - cover) / 2
        # nanquantile requires q to share dtype and device with its input;
        # a default float32 CPU tensor would fail for float64 or CUDA samples.
        q = torch.tensor([lower_q, 1.0 - lower_q],
                         dtype=on_mask_locs.dtype, device=on_mask_locs.device)
        lower_bound, upper_bound = sampled_coord.nanquantile(q=q, dim=-1)  # each (matched_sources,)
        inside = (target_coord > lower_bound) & (target_coord < upper_bound)
        actual_ci_cover.append(inside.sum() / n_matched)
    return actual_ci_cover

In [None]:
def plot_data_for_locs_ci(result_dict, ci_cover):
    """Compute empirical interval coverage of the sampled locations.

    Only tiles that contain a true source *and* had a source in the first
    posterior sample are used.

    Args:
        result_dict: dictionary with the keys "init_n_sources" (b, h, w),
            "locs_list" (list of (b, h, w, 1, 2) tensors, one per sample),
            "target_n_sources" (b, h, w) and "target_locs" (b, h, w, 1, 2).
        ci_cover: nominal coverage levels in (0, 1).

    Returns:
        ``(v_ci_cover, h_ci_cover)``: the empirical coverage per nominal level
        for location index 0 ("vertical") and index 1 ("horizontal").

    Raises:
        ValueError: if "target_locs" does not hold exactly one source per tile.
    """
    init_n_sources = result_dict["init_n_sources"]
    locs_list = result_dict["locs_list"]
    target_n_sources = result_dict["target_n_sources"]
    target_locs = result_dict["target_locs"]

    if target_locs.shape[-2] != 1:
        raise ValueError(
            "target_locs must hold exactly one source per tile, "
            f"got shape {tuple(target_locs.shape)}"
        )

    both_on_mask = (target_n_sources > 0) & (init_n_sources > 0)  # (b, h, w)
    all_locs = torch.cat(locs_list, dim=-2)  # (b, h, w, iter, 2)

    # A (b, h, w) boolean mask indexes the leading dimensions directly, so no
    # expanded copy of the mask is needed and trailing sizes are preserved.
    on_mask_locs = all_locs[both_on_mask]  # (matched_sources, iter, 2)
    on_mask_target_locs = target_locs[both_on_mask].squeeze(-2)  # (matched_sources, 2)

    v_ci_cover = get_locs_ci(on_mask_locs, on_mask_target_locs, both_on_mask,
                             ci_cover=ci_cover,
                             for_vertical_locs=True)
    h_ci_cover = get_locs_ci(on_mask_locs, on_mask_target_locs, both_on_mask,
                             ci_cover=ci_cover,
                             for_vertical_locs=False)
    return v_ci_cover, h_ci_cover

In [None]:
# Nominal coverage levels 0.05, 0.10, ..., 0.95.
ci_cover = [percent / 100 for percent in range(5, 100, 5)]
diffusion_v_ci_cover, diffusion_h_ci_cover = plot_data_for_locs_ci(
    result_dict=diffusion_result_dict,
    ci_cover=ci_cover,
)

In [None]:
# Coverage curve for the vertical location coordinate: empirical coverage
# against nominal coverage, with the diagonal as the reference line.
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(ci_cover, diffusion_v_ci_cover, label="Diffusion")
ax.plot(ci_cover, ci_cover, linestyle="dashed", label="Expected Coverage")
ax.set_xlabel("Expected Vertical Locs CI Coverage")
ax.set_ylabel("Actual Vertical Locs CI Coverage")
ax.legend()
ax.set_xticks(ci_cover)
ax.set_yticks(ci_cover)
ax.grid()
plt.show()

In [None]:
# Coverage curve for the horizontal location coordinate: empirical coverage
# against nominal coverage, with the diagonal as the reference line.
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(ci_cover, diffusion_h_ci_cover, label="Diffusion")
ax.plot(ci_cover, ci_cover, linestyle="dashed", label="Expected Coverage")
ax.set_xlabel("Expected Horizontal Locs CI Coverage")
ax.set_ylabel("Actual Horizontal Locs CI Coverage")
ax.legend()
ax.set_xticks(ci_cover)
ax.set_yticks(ci_cover)
ax.grid()
plt.show()

### N sources CI

In [None]:
def plot_data_for_n_sources_ci(result_dict):
    """Bin tiles by sampled detection frequency and compare with the truth.

    For every tile, the fraction of samples in which at least one source was
    detected is computed and assigned to one of ten bins with edges
    -0.1, 0.1, 0.2, ..., 0.9, 1.1. For each bin, the fraction of its tiles
    that truly contain a source is reported (NaN for an empty bin).

    Args:
        result_dict: dictionary with the keys "n_sources_list" (list of
            per-sample source-count tensors of equal shape) and
            "target_n_sources" (true source counts of that same shape).

    Returns:
        ``(bins, true_n_sources_percent, est_n_sources_percent_bin)``: the list
        of (left, right) bin edges, the per-bin true-source fraction as floats,
        and the bin index assigned to every tile.
    """
    sampled_n_sources = torch.stack(result_dict["n_sources_list"], dim=-1)  # (..., iter)
    detection_rate = (sampled_n_sources > 0).float().mean(dim=-1)
    target_is_on = result_dict["target_n_sources"] > 0

    inner_edges = [percent / 100 for percent in range(10, 100, 10)]
    boundaries = torch.tensor([-0.1, *inner_edges, 1.1])
    bins = list(zip(boundaries[:-1].tolist(), boundaries[1:].tolist()))

    # bucketize returns 1-based positions for values above the first edge;
    # shift them so that bin indices start at 0.
    est_n_sources_percent_bin = torch.bucketize(detection_rate, boundaries=boundaries) - 1

    true_n_sources_percent = []
    for bin_index in range(len(bins)):
        in_bin = est_n_sources_percent_bin == bin_index
        true_fraction = target_is_on[in_bin].sum() / in_bin.sum()
        true_n_sources_percent.append(true_fraction.item())
    return bins, true_n_sources_percent, est_n_sources_percent_bin

In [None]:
# Bin every tile by how often the sampler detected a source in it and compute,
# per bin, the fraction of tiles that truly contain a source.
bins, diffusion_true_ns_percent, diffusion_est_ns_percent_bin = plot_data_for_n_sources_ci(diffusion_result_dict)

In [None]:
# Histogram (log-scaled counts) of how many tiles fall into each
# detection-frequency bin.
x = list(range(len(bins)))
bin_labels = (
    ["[0.00, 0.10]"]
    + [f"[{bl:.2f}, {br:.2f}]" for bl, br in bins[1:-1]]
    + ["[0.90, 1.00]"]
)

fig, ax = plt.subplots()
_, _, patches = ax.hist(
    diffusion_est_ns_percent_bin.flatten(),
    bins=x + [len(bins)],
    log=True,
    edgecolor="black",
    facecolor="skyblue",
)
ax.bar_label(patches, fmt="%.1e", fontsize=7)
ax.set_xlabel("Bin")
# Ticks sit at the centre of each unit-width bar.
ax.set_xticks([bin_index + 0.5 for bin_index in x])
ax.set_xticklabels(bin_labels, rotation=45)
ax.set_ylabel("Count")
plt.show()

In [None]:
# Per-bin fraction of tiles that truly contain a source, compared with the
# reference values 0.05, 0.15, ..., 0.95.
refer_n_sources_prob = [percent / 100 for percent in range(5, 96, 10)]
x = list(range(len(bins)))
bin_labels = (
    ["[0.00, 0.10]"]
    + [f"[{bl:.2f}, {br:.2f}]" for bl, br in bins[1:-1]]
    + ["[0.90, 1.00]"]
)

fig, ax = plt.subplots()
ax.plot(x, diffusion_true_ns_percent, label="Diffusion")
ax.plot(x, refer_n_sources_prob, linestyle="dashed", label="Expected")
ax.set_yticks(refer_n_sources_prob)
ax.set_xticks(x)
ax.set_xticklabels(bin_labels, rotation=45)
ax.set_xlabel("Diffusion Estimated N Sources Probability")
ax.set_ylabel("Actual Probability of Bins")
ax.grid()
ax.legend()
plt.show()