In [1]:
import torch
from pathlib import Path
from einops import rearrange
import pickle
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from hydra import initialize, compose
from hydra.utils import instantiate

from pytorch_lightning.utilities import move_data_to_device

from bliss.catalog import FullCatalog
from bliss.surveys.dc2 import DC2, unsqueeze_tile_dict
from case_studies.dc2_cataloging.utils.load_lsst import get_lsst_full_cat

output_dir = Path("./match_selector_output/")
output_dir.mkdir(parents=True, exist_ok=True)

# change this model path according to your training setting
model_path = "../../../bliss_output/DC2_cataloging_exp/exp_06-16-2/checkpoints/best_encoder.ckpt"
# Root directory handed to get_lsst_full_cat; change it when running on another machine.
lsst_root_dir = "/data/scratch/dc2_nfs/"

# Set to True to run on the first CUDA device when one is available.
# The default (False) keeps everything on the CPU.
use_cuda = False
device = torch.device("cuda:0" if use_cuda and torch.cuda.is_available() else "cpu")

In [2]:
# Compose the `notebook_config` Hydra config, searching this notebook's directory.
with initialize(version_base=None, config_path="."):
    notebook_cfg = compose(config_name="notebook_config")

In [3]:
# Select one DC2 image and unpack what later cells need from its plotting sample.
dc2: DC2 = instantiate(notebook_cfg.surveys.dc2)
image_idx = 0
test_sample = dc2.get_plotting_sample(image_idx)

cur_image_true_full_cat: FullCatalog
cur_image_wcs, cur_image_true_full_cat, cur_image_match_id = (
    test_sample["wcs"],
    test_sample["full_catalog"],
    test_sample["match_id"],
)
# Extent of the image along dim 1; later cells use it as the image size.
image_lim = test_sample["image"].shape[1]
# Flux cut taken from the encoder's metric config; handed to the LSST loader below.
r_band_min_flux = notebook_cfg.encoder.min_flux_for_metrics

In [4]:
# LSST catalog for the current image, built from the WCS, image size and flux cut
# chosen above (presumably restricted to this image's footprint — see get_lsst_full_cat).
lsst_full_cat = get_lsst_full_cat(
    lsst_root_dir=lsst_root_dir,
    cur_image_wcs=cur_image_wcs,
    image_lim=image_lim,
    r_band_min_flux=r_band_min_flux,
    device=device,
)

In [5]:
# Build the encoder from the notebook config, then restore the trained weights
# from the Lightning-style checkpoint's "state_dict" entry.
bliss_encoder = instantiate(notebook_cfg.encoder).to(device=device)
pretrained_weights = torch.load(model_path, map_location=device)["state_dict"]
bliss_encoder.load_state_dict(pretrained_weights)
# Switch to evaluation mode; the trailing `;` suppresses the module repr in the notebook.
bliss_encoder.eval();

In [6]:
# Assemble a batch of size 1 from the single test sample. The axis names in the
# einops patterns are labels only: each `rearrange` call just prepends an axis of size 1.
batch = {
    "tile_catalog": unsqueeze_tile_dict(test_sample["tile_catalog"]),
    "images": rearrange(test_sample["image"], "h w nw -> 1 h w nw"),
    "background": rearrange(test_sample["background"], "h w nw -> 1 h w nw"),
    "psf_params": rearrange(test_sample["psf_params"], "h w -> 1 h w"),
}

batch = move_data_to_device(batch, device=device)

# Cache the encoder output on disk. The file name includes image_idx so switching
# to another image never silently reuses predictions made for a different image.
# The cache is NOT keyed on model_path: delete the file after changing the checkpoint.
bliss_output_path = output_dir / f"bliss_output_image_{image_idx}.pkl"

if bliss_output_path.exists():
    # Only load pickles this notebook wrote itself: unpickling can execute arbitrary code.
    with open(bliss_output_path, "rb") as inputp:
        bliss_out_dict = pickle.load(inputp)
else:
    # Inference only: don't build (and then pickle) an autograd graph.
    with torch.no_grad():
        bliss_out_dict = bliss_encoder.predict_step(batch, None)

    with open(bliss_output_path, "wb") as outp:
        pickle.dump(bliss_out_dict, outp, pickle.HIGHEST_PROTOCOL)

In [7]:
# Convert the encoder's "mode_cat" prediction into a FullCatalog so it can be
# filtered by ploc box and matched against the truth catalog below.
bliss_full_cat: FullCatalog = (
    bliss_out_dict["mode_cat"]
    .to_full_catalog()
)

In [8]:
# Split the image into square boxes. For each box, match the truth catalog against the
# BLISS and LSST catalogs and save an overlay of the matched detections as a PDF.
matcher = instantiate(notebook_cfg.encoder.matcher)
plocs_box_len = 100  # side length (pixels) of each box
log_stretch_offset = 80  # added to the min-shifted cutout so np.log never sees values <= 0
output_img_dir = output_dir / "images"
output_img_dir.mkdir(exist_ok=True)

# First entry along the image's leading axis (presumably a band — TODO confirm which).
# It does not depend on the box, so select it once instead of on every iteration.
image = test_sample["image"][0]

# Rows are tiled up to image_lim (= test_sample["image"].shape[1]). Columns are tiled
# over the image's own width, so non-square images are covered fully.
for i in range(0, image_lim, plocs_box_len):
    for j in range(0, image.shape[-1], plocs_box_len):
        plocs_box_origin = torch.tensor([i, j])

        # NOTE(review): the scatter calls below draw plocs directly on the cropped
        # image_sub. That assumes filter_full_catalog_by_ploc_box returns plocs relative
        # to plocs_box_origin — confirm.
        cur_target_full_cat = cur_image_true_full_cat.filter_full_catalog_by_ploc_box(plocs_box_origin, plocs_box_len)
        cur_bliss_full_cat = bliss_full_cat.filter_full_catalog_by_ploc_box(plocs_box_origin, plocs_box_len)
        cur_lsst_full_cat = lsst_full_cat.filter_full_catalog_by_ploc_box(plocs_box_origin, plocs_box_len)

        # [0] selects the matching for the first image in the catalog batch. Element [1]
        # of a matching indexes into the second (estimated) catalog — presumably the
        # matched estimated sources; TODO confirm against matcher.match_catalogs.
        bliss_matching = matcher.match_catalogs(cur_target_full_cat, cur_bliss_full_cat)[0]
        lsst_matching = matcher.match_catalogs(cur_target_full_cat, cur_lsst_full_cat)[0]
        bliss_matched_idx = bliss_matching[1].tolist()
        lsst_matched_idx = lsst_matching[1].tolist()

        fig, ax = plt.subplots(figsize=(8, 8))
        image_sub = image[i:(i + plocs_box_len), j:(j + plocs_box_len)]
        ax.imshow(np.log((image_sub - image_sub.min()) + log_stretch_offset), cmap="viridis")

        # plocs[..., 0] is used as the row and plocs[..., 1] as the column,
        # so scatter gets x=column, y=row.
        # NOTE(review): every plocs row is plotted for the truth catalog. If the
        # filtered catalog pads beyond its source count, padded entries would show up.
        ax.scatter(cur_target_full_cat["plocs"][0, :, 1], cur_target_full_cat["plocs"][0, :, 0],
                   facecolors="none", edgecolors="r",
                   alpha=1, s=130, linewidth=3, label="Truth Objects")
        ax.scatter(cur_bliss_full_cat["plocs"][0, bliss_matched_idx, 1],
                   cur_bliss_full_cat["plocs"][0, bliss_matched_idx, 0],
                   marker="X", facecolors="lime", edgecolors="k",
                   alpha=1, s=100, linewidth=1, label="BLISS Detection")
        ax.scatter(cur_lsst_full_cat["plocs"][0, lsst_matched_idx, 1],
                   cur_lsst_full_cat["plocs"][0, lsst_matched_idx, 0],
                   marker="P", facecolors="y", edgecolors="k",
                   alpha=1, s=80, linewidth=1, label="LSST Detection")

        ax.set(title=f"Image {image_idx}, box origin (row={i}, col={j})",
               xlabel="column (px)", ylabel="row (px)")
        ax.legend()
        fig.savefig(output_img_dir / f"image_{image_idx}_{i}_{j}.pdf", bbox_inches="tight")
        # Close this specific figure so figures don't accumulate across loop iterations.
        plt.close(fig)