In [None]:
import torch
from pathlib import Path
import numpy as np

import matplotlib.pyplot as plt

from hydra import initialize, compose
from hydra.utils import instantiate

from case_studies.dc2_cataloging.utils.load_full_cat import get_full_cat
from case_studies.dc2_cataloging.utils.notebook_variables import NoteBookVariables

output_dir = Path("./plot_output/detection_selector_output/")
output_dir.mkdir(parents=True, exist_ok=True)

# change this model path according to your training setting
model_path = "../../../bliss_output/DC2_cataloging_exp/exp_08-02-1/checkpoints/best_encoder.ckpt"
# data root passed to get_full_cat; this is a cluster-specific path, adjust for your machine
lsst_root_dir = "/data/scratch/dc2_nfs/"

# Preferred GPU index. torch.device("cuda:7") does not validate the index, so on a machine
# with fewer GPUs the failure would only surface later; fall back to cuda:0 instead, and to
# the CPU when CUDA is unavailable.
preferred_gpu_index = 7
if torch.cuda.is_available():
    gpu_index = preferred_gpu_index if preferred_gpu_index < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

In [None]:
# Compose the Hydra config named "notebook_config", searching the current directory.
with initialize(version_base=None, config_path="."):
    notebook_cfg = compose(config_name="notebook_config")

In [None]:
test_image_idx = 0
# Load the multi-band test image plus the target, BLISS and LSST full catalogs for it.
(
    test_image,
    test_image_cat,
    bliss_full_cat,
    lsst_full_cat,
) = get_full_cat(notebook_cfg, test_image_idx, model_path, lsst_root_dir, device)
# Side length in pixels, read before the band axis (dim 0) is dropped; later cells use it
# for both axes, so the image is presumably square.
image_lim = test_image.shape[1]
# Keep only the r band (band index 2) for display.
test_image = test_image[2]

In [None]:
matcher = instantiate(notebook_cfg.encoder.matcher)
plocs_box_len = 100  # side length (pixels) of each square sub-image that is inspected
# Only plot boxes where |#BLISS matches - #LSST matches| / #true sources reaches this fraction.
min_matching_diff_frac = 0.4
# Offset subtracted from pixel values before the arcsinh display stretch.
display_offset = 0.0073
first_legend = True  # the legend is drawn only on the first saved figure


def _unmatched_indices(n_sources, matched_indices):
    """Return the sorted indices in ``range(n_sources)`` absent from ``matched_indices``.

    Args:
        n_sources: number of entries in the catalog being indexed.
        matched_indices: tensor of catalog indices that the matcher paired up.

    Returns:
        A sorted list of the unpaired indices, usable directly as a tensor index.
    """
    return sorted(set(range(n_sources)) - set(matched_indices.int().tolist()))


for i in range(0, image_lim, plocs_box_len):
    for j in range(0, image_lim, plocs_box_len):
        plocs_box_origin = torch.tensor([i, j])

        cur_target_full_cat = test_image_cat.filter_by_ploc_box(plocs_box_origin, plocs_box_len)
        cur_bliss_full_cat = bliss_full_cat.filter_by_ploc_box(plocs_box_origin, plocs_box_len)
        cur_lsst_full_cat = lsst_full_cat.filter_by_ploc_box(plocs_box_origin, plocs_box_len)
        # [0] selects the single image in the batch. Each matching is presumably a
        # (target_indices, estimated_indices) pair -- confirm against the matcher.
        bliss_matching = matcher.match_catalogs(cur_target_full_cat, cur_bliss_full_cat)[0]
        lsst_matching = matcher.match_catalogs(cur_target_full_cat, cur_lsst_full_cat)[0]

        n_target = cur_target_full_cat["plocs"].shape[1]
        n_bliss_matching = len(bliss_matching[1])
        n_lsst_matching = len(lsst_matching[1])
        bliss_lsst_matching_diff = (
            abs(n_bliss_matching - n_lsst_matching) / n_target if n_target != 0 else 0
        )
        if bliss_lsst_matching_diff < min_matching_diff_frac:
            continue

        # Partition the target sources by which method(s) matched them.
        target_set = set(range(n_target))
        bliss_match_set = set(bliss_matching[0].int().tolist())
        lsst_match_set = set(lsst_matching[0].int().tolist())
        missing_match = list(target_set - (bliss_match_set | lsst_match_set))
        only_bliss_match = list(bliss_match_set - lsst_match_set)
        only_lsst_match = list(lsst_match_set - bliss_match_set)
        both_match = list(lsst_match_set & bliss_match_set)

        # Estimated detections that were not paired with any target source (false positives).
        # These index into the estimated catalogs, not the target catalog.
        # NOTE(review): like n_target, this assumes plocs.shape[1] is the source count (no padding).
        bliss_error = _unmatched_indices(cur_bliss_full_cat["plocs"].shape[1], bliss_matching[1])
        lsst_error = _unmatched_indices(cur_lsst_full_cat["plocs"].shape[1], lsst_matching[1])

        fig, ax = plt.subplots(figsize=NoteBookVariables.figsize)
        image_sub = test_image[i:(i + plocs_box_len), j:(j + plocs_box_len)]
        ax.imshow(np.arcsinh(image_sub - display_offset), cmap="viridis")

        # plocs[..., 1] is plotted on x and plocs[..., 0] on y, matching imshow's (col, row) axes.
        target_plocs = cur_target_full_cat["plocs"][0]
        for indices, color, label in (
            (only_bliss_match, "aqua", "Only BLISS"),
            (only_lsst_match, "orange", "Only LSST"),
            (both_match, "lime", "Both"),
            (missing_match, "red", "Neither"),
        ):
            ax.scatter(target_plocs[indices, 1], target_plocs[indices, 0],
                       facecolors="none", edgecolors=color,
                       alpha=1, s=130, linewidth=3, label=label)
        for est_full_cat, indices, color, label in (
            (cur_bliss_full_cat, bliss_error, "aqua", "BLISS Error"),
            (cur_lsst_full_cat, lsst_error, "orange", "LSST Error"),
        ):
            est_plocs = est_full_cat["plocs"][0]
            ax.scatter(est_plocs[indices, 1], est_plocs[indices, 0],
                       marker="X", facecolors=color, edgecolors=color,
                       alpha=1, s=100, linewidth=1, label=label)
        ax.tick_params(labelsize=NoteBookVariables.font_size)

        if first_legend:
            ax.legend(loc="lower right", fontsize=NoteBookVariables.font_size)
            first_legend = False
        fig.savefig(output_dir / f"image_{test_image_idx}_{i}_{j}.pdf",
                    bbox_inches="tight", dpi=NoteBookVariables.dpi)
        plt.close(fig)