In [1]:
import torch
from pathlib import Path
import numpy as np

import matplotlib.pyplot as plt

from hydra import initialize, compose
from hydra.utils import instantiate

from case_studies.dc2_cataloging.utils.load_full_cat import get_full_cat

# Paths that depend on where training wrote its checkpoints and where the DC2 data lives;
# edit both to match your setup.
model_path = "../../../bliss_output/DC2_cataloging_exp/exp_06-16-2/checkpoints/best_encoder.ckpt"
lsst_root_dir = "/data/scratch/dc2_nfs/"

# Every figure produced by this notebook is written below this directory.
output_dir = Path("detection_selector_output")
output_dir.mkdir(exist_ok=True, parents=True)

# Inference runs on CPU. For a GPU, use
# torch.device("cuda:0" if torch.cuda.is_available() else "cpu") instead.
device = torch.device("cpu")

In [2]:
# Build the notebook's Hydra config from the `notebook_config` config found in this directory.
with initialize(version_base=None, config_path="."):
    notebook_cfg = compose(config_name="notebook_config")

In [3]:
image_idx = 0  # index of the DC2 test image to load
(
    test_image,
    test_image_cat,
    bliss_full_cat,
    lsst_full_cat,
) = get_full_cat(notebook_cfg, image_idx, model_path, lsst_root_dir, device)

# test_image is presumably (band, H, W) at this point (TODO confirm against get_full_cat).
# Record the spatial size before dropping the band axis, then keep only band 2 (r band).
image_lim = test_image.shape[1]
test_image = test_image[2]

In [8]:
# Tile the r-band test image into plocs_box_len x plocs_box_len boxes. For each box, match the
# ground-truth catalog against the BLISS and LSST catalogs. When the two detectors' match counts
# differ strongly, save a figure that circles every true object according to which detector(s)
# matched it.
matcher = instantiate(notebook_cfg.encoder.matcher)
color_list = plt.rcParams["axes.prop_cycle"].by_key()["color"][0:4]
plocs_box_len = 100
# Only plot boxes where |n_bliss_matched - n_lsst_matched| / n_target reaches this fraction.
match_diff_threshold = 0.5
# Added to the min-subtracted image before log(). The log argument is therefore >= this value,
# so the display stretch stays finite and the faintest pixels are compressed together.
log_stretch_offset = 80
marker_kwargs = dict(facecolors="none", alpha=1, s=130, linewidth=3)
output_img_dir = output_dir / "images"
output_img_dir.mkdir(parents=True, exist_ok=True)

# test_image is 2-D here (imshow below needs a 2-D slice). Use both spatial extents so a
# non-square image is tiled over its full height AND width.
n_rows, n_cols = test_image.shape[-2:]
for i in range(0, n_rows, plocs_box_len):
    for j in range(0, n_cols, plocs_box_len):
        plocs_box_origin = torch.tensor([i, j])

        cur_target_full_cat = test_image_cat.filter_full_catalog_by_ploc_box(plocs_box_origin, plocs_box_len)
        cur_bliss_full_cat = bliss_full_cat.filter_full_catalog_by_ploc_box(plocs_box_origin, plocs_box_len)
        cur_lsst_full_cat = lsst_full_cat.filter_full_catalog_by_ploc_box(plocs_box_origin, plocs_box_len)
        # match_catalogs presumably returns one (target_idx, pred_idx) pair per batch element;
        # a single image is processed here, so element 0 is taken.
        bliss_matching = matcher.match_catalogs(cur_target_full_cat, cur_bliss_full_cat)[0]
        lsst_matching = matcher.match_catalogs(cur_target_full_cat, cur_lsst_full_cat)[0]

        target_plocs = cur_target_full_cat["plocs"][0]
        n_target = target_plocs.shape[0]
        n_bliss_matching = len(bliss_matching[1])
        n_lsst_matching = len(lsst_matching[1])
        # Empty boxes count as "no disagreement" and are skipped.
        bliss_lsst_matching_diff = abs(n_bliss_matching - n_lsst_matching) / n_target if n_target != 0 else 0
        if bliss_lsst_matching_diff < match_diff_threshold:
            continue

        target_set = set(range(n_target))
        bliss_match_set = set(bliss_matching[0].int().tolist())
        lsst_match_set = set(lsst_matching[0].int().tolist())
        # Each true object falls in exactly one category. sorted() makes the plotted order
        # deterministic (set iteration order is not).
        match_categories = (
            ("Missing Objects", sorted(target_set - (bliss_match_set | lsst_match_set))),
            ("Only BLISS Match", sorted(bliss_match_set - lsst_match_set)),
            ("Only LSST Match", sorted(lsst_match_set - bliss_match_set)),
            ("Both Match", sorted(bliss_match_set & lsst_match_set)),
        )

        fig, ax = plt.subplots(figsize=(8, 8))
        image_sub = test_image[i:(i + plocs_box_len), j:(j + plocs_box_len)]
        ax.imshow(np.log((image_sub - image_sub.min()) + log_stretch_offset), cmap="viridis")
        # plocs appear to be (row, col), matching the [i, j] origin used to slice image_sub,
        # while scatter takes (x, y). Hence the [:, 1] / [:, 0] order.
        # NOTE(review): assumes filter_full_catalog_by_ploc_box returns plocs relative to
        # plocs_box_origin so the markers line up with image_sub. Confirm.
        for color_idx, (label, obj_idx) in enumerate(match_categories):
            ax.scatter(target_plocs[obj_idx, 1], target_plocs[obj_idx, 0],
                       edgecolors=color_list[color_idx], label=label, **marker_kwargs)

        ax.set_title(f"Image {image_idx}, box origin ({i}, {j})")
        ax.set_xlabel("column (pixels)")
        ax.set_ylabel("row (pixels)")
        ax.legend()
        fig.savefig(output_img_dir / f"image_{image_idx}_{i}_{j}.pdf", bbox_inches="tight")
        # Release this specific figure. Many are created inside the loop.
        plt.close(fig)