In [None]:
import torch
import tqdm

import matplotlib.pyplot as plt
from einops import rearrange, repeat

from hydra import initialize, compose
from hydra.utils import instantiate
from pytorch_lightning.utilities import move_data_to_device

from bliss.surveys.dc2 import DC2DataModule
from bliss.catalog import TileCatalog
from case_studies.dc2_new_diffusion.utils.encoder import DiffusionEncoder

In [None]:
# Which training run / checkpoint to evaluate.
model_name = "exp_02-05-1"
model_check_point_name = "encoder_48.ckpt"
# <output root>/<experiment>/checkpoints/<checkpoint file>, relative to this notebook's directory.
model_path = "/".join(
    ["../../../bliss_output/DC2_ynet_diffusion_exp", model_name, "checkpoints", model_check_point_name]
)

# NOTE(review): GPU index 4 is host-specific; change it for the machine running this notebook.
device = torch.device("cuda:4") if torch.cuda.is_available() else torch.device("cpu")

# Compose the notebook's hydra config from the local config directory.
with initialize(config_path="./ynet_diffusion_config", version_base=None):
    new_diffusion_notebook_cfg = compose(config_name="ynet_diffusion_notebook_config")

In [None]:
# Survey-level settings; the scalars are kept as globals for later cells.
dc2_cfg = new_diffusion_notebook_cfg.surveys.dc2
tile_slen = dc2_cfg.tile_slen
max_sources_per_tile = dc2_cfg.max_sources_per_tile
r_band_min_flux = new_diffusion_notebook_cfg.notebook_var.r_band_min_flux

# Validation data module with a batch size of 128.
dc2: DC2DataModule = instantiate(dc2_cfg)
dc2.batch_size = 128
dc2.setup(stage="validate")
dc2_val_dataloader = dc2.val_dataloader()

# Build the encoder on `device`, then restore the weights stored under the checkpoint's "state_dict" key.
# NOTE(review): torch.load unpickles the file — only point model_path at trusted checkpoints.
bliss_encoder: DiffusionEncoder = instantiate(new_diffusion_notebook_cfg.encoder).to(device=device)
pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
bliss_encoder.eval();

In [None]:
# Show the DDIM sampler settings carried by the loaded encoder, one per line.
for ddim_attr in ("ddim_steps", "ddim_objective", "ddim_beta_schedule"):
    print(getattr(bliss_encoder, ddim_attr))

In [None]:
# Threshold currently configured on the first catalog-parser factor.
first_factor = bliss_encoder.catalog_parser.factors[0]
print(first_factor.threshold)

In [None]:
# Override the first catalog-parser factor's threshold for this experiment.
# NOTE(review): what a 0.0 threshold means for detection depends on the factor implementation — confirm.
first_factor = bliss_encoder.catalog_parser.factors[0]
first_factor.threshold = 0.0

In [None]:
# Use the same DDIM step count on the encoder and on its detection diffusion sampler
# (chained assignment sets bliss_encoder.ddim_steps first, then sampling_timesteps).
ddim_steps = 1000
bliss_encoder.ddim_steps = bliss_encoder.detection_diffusion.sampling_timesteps = ddim_steps

In [None]:
# First validation batch, with all of its tensors moved to the sampling device.
one_batch = move_data_to_device(next(iter(dc2_val_dataloader)), device=device)

In [None]:
# NOTE(review): eta = 1.0 presumably gives fully stochastic DDIM updates — confirm in the sampler.
detection_diffusion = bliss_encoder.detection_diffusion
detection_diffusion.ddim_sampling_eta = 1.0

In [None]:
# Ground-truth tile catalog reduced by get_brightest_sources_per_tile using band index 2.
target_tile_cat = TileCatalog(one_batch["tile_catalog"]).get_brightest_sources_per_tile(band=2)
target_n_sources, target_locs = target_tile_cat["n_sources"], target_tile_cat["locs"]

In [None]:
# Draw `total_iters` stochastic samples from the encoder for the same batch and collect their locations.
# Tiles that are on (n_sources > 0) in the FIRST sample form the reference set. If a later sample
# turns such a tile off, that draw's location is stored as NaN so the nan-aware quantiles skip it.
# NOTE(review): assumes bliss_encoder.sample draws from torch's global RNG, so the seed below makes
# the draws (largely) reproducible; some CUDA kernels may still be nondeterministic.
sampling_seed = 0
torch.manual_seed(sampling_seed)

init_n_sources = None
locs_list = []
total_iters = 50
for _sample_idx in tqdm.tqdm(range(total_iters)):
    with torch.no_grad():
        sample_tile_cat, _ = bliss_encoder.sample(one_batch, return_inter_output=False)
    cur_n_sources = sample_tile_cat["n_sources"] > 0  # (b, h, w)
    if init_n_sources is None:
        init_n_sources = cur_n_sources  # reference on/off pattern from the first sample
    # On in the first sample but off in this one (same as (~(init & cur)) & init).
    nan_mask = rearrange(init_n_sources & ~cur_n_sources, "b h w -> b h w 1 1")

    locs = sample_tile_cat["locs"]  # (b, h, w, 1, 2)
    locs_list.append(torch.where(nan_mask, torch.nan, locs).cpu())

In [None]:
# Tiles that have a source in the target catalog AND were on in the first sample.
both_on_mask = ((target_n_sources > 0) & init_n_sources).cpu()  # (b, h, w)
# Stack the per-draw locations along the source axis.
all_locs = torch.cat(locs_list, dim=-2)  # (b, h, w, iter, 2)
n_draws, n_coords = all_locs.shape[-2], all_locs.shape[-1]
# Masks broadcast to the sampled-locs shape and to the single-source target-locs shape.
both_on_mask_iter_repeated = repeat(both_on_mask, "b h w -> b h w iter k", iter=n_draws, k=n_coords)  # (b, h, w, iter, 2)
both_on_mask_single_repeated = repeat(both_on_mask, "b h w -> b h w 1 k", k=n_coords)  # (b, h, w, 1, 2)

In [None]:
# Sampled and target locations for the tiles selected by both_on_mask, in row-major (b, h, w) order.
# Boolean indexing with the (b, h, w) mask keeps the trailing dims, so the draw axis comes from the
# data itself rather than from `total_iters` (which may not match len(locs_list) after a partial run).
on_mask_locs = all_locs[both_on_mask]  # (matched_sources, iter, 2)
on_mask_target_locs = target_locs.cpu()[both_on_mask].reshape(-1, 2)  # (matched_sources, 2)
assert on_mask_locs.shape[0] == on_mask_target_locs.shape[0], "expected exactly one target source per matched tile"

In [None]:
# Sanity check: sampled vs. target location tensor shapes.
tuple(t.shape for t in (on_mask_locs, on_mask_target_locs))

In [None]:
# Number of matched tiles (on in both the target and the first sample).
both_on_mask.count_nonzero()

In [None]:
# Per-source empirical [5%, 95%] interval of the sampled locations, ignoring NaN draws.
# nanquantile requires q to share the input's dtype (and device), so build it from on_mask_locs.
# Its output puts the quantile axis first, (q, matched_sources, 2); permute to
# (matched_sources, coord, q) so [:, c, 0] / [:, c, 1] are the lower / upper bounds for coordinate c.
interval_quantiles = torch.tensor([0.05, 0.95], dtype=on_mask_locs.dtype, device=on_mask_locs.device)
on_mask_locs_q = on_mask_locs.nanquantile(q=interval_quantiles, dim=-2).permute([1, 2, 0])  # (matched_sources, 2, 2)

In [None]:
# Fraction of matched sources whose target coordinate 0 lies strictly inside its sampled
# [5%, 95%] interval (empirical coverage of the nominal 90% interval).
target_coord0 = on_mask_target_locs[:, 0]
lower_q, upper_q = on_mask_locs_q[:, 0, 0], on_mask_locs_q[:, 0, 1]
((lower_q < target_coord0) & (target_coord0 < upper_q)).sum() / both_on_mask.sum()

In [None]:
# Spot check: target coordinate 0 for matched sources 50..59.
on_mask_target_locs[50:60][:, 0]

In [None]:
# Spot check: sampled (lower, upper) quantiles of coordinate 0 for the same matched sources.
on_mask_locs_q[50:60][:, 0]