# Compare BLISS and the SDSS Photo catalog against DECaLS on an SDSS image

In [None]:
# Re-import edited modules (e.g. bliss) automatically before each cell runs,
# and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from os import environ
import torch
import numpy as np
import pandas as pd
from matplotlib.markers import MarkerStyle
import matplotlib.pyplot as plt

from pathlib import Path

from hydra import initialize, compose
from hydra.utils import instantiate

from bliss.encoder.encoder import Encoder
from bliss.surveys.sdss import PhotoFullCatalog, SloanDigitalSkySurvey
from bliss.surveys.decals import TractorFullCatalog
from bliss.align import align
from bliss.encoder.plots import plot_plocs
from bliss.encoder.metrics import CatalogMetrics

In [None]:
# Point BLISS_HOME at the directory two levels above the notebook's working directory
# (presumably the repository root referenced by the Hydra config — confirm), then compose it.
repo_root = Path().resolve().parents[1]
environ["BLISS_HOME"] = str(repo_root)
with initialize(config_path=".", version_base=None):
    cfg = compose("config")

## Load model and data

In [None]:
# Build the encoder from the Hydra config and load the pretrained PSF-aware weights.
model: Encoder = instantiate(cfg.encoder)
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a CPU-only machine;
# load_state_dict copies the weights into the model's existing parameters either way.
model.load_state_dict(torch.load("../../data/pretrained_models/psf_aware.pt", map_location="cpu"))
model.eval();  # inference mode (affects dropout / batch-norm layers, if any)

In [None]:
# load SDSS image, background and WCS for a single field
run, camcol, field = 94, 1, 12
# read from the configured SDSS directory -- the same one the Photo catalog is loaded from
# below -- instead of a hardcoded absolute path
sdss = SloanDigitalSkySurvey(cfg.paths.sdss, run, camcol, (field,))
sdss_frame = sdss[0]  # look the field up once instead of once per key
wcs = sdss_frame["wcs"]

# align(...) is applied with the field's WCS; [None] prepends a batch axis
image = align(sdss_frame["image"], wcs)[None]
background = align(sdss_frame["background"], wcs)[None]

# crop to the central region: rows/cols from 16*(dim//64) (~1/4 of the frame) to 3x that
# (~3/4). Each cropped side is 32*(dim//64) pixels, so it is a multiple of 16.
_, _, height, width = image.shape
min_h, min_w = (height // 64) * 16, (width // 64) * 16
max_h, max_w = min_h * 3, min_w * 3
cropped_image = image[:, :, min_h:max_h, min_w:max_w]
cropped_background = background[:, :, min_h:max_h, min_w:max_w]

### Load PSF parameters

In [None]:
# The simulator's image decoder holds per-field PSF models; look them up for this field.
simulator = instantiate(cfg.simulator)
decoder = simulator.image_decoder
field_key = (run, camcol, field)

psfs = [decoder.psf_galsim[field_key]]
psf_params = decoder.psf_params[field_key][None]  # prepend a batch axis

### Load Photo and DECaLS catalogs based on SDSS image

In [None]:
# Only compare catalogs over the area BLISS predicts on: the crop minus the border of
# tiles_to_crop * tile_slen pixels that is cropped away on each side.
crop_px = model.tiles_to_crop * model.tile_slen
col_bounds = (min_w + crop_px, max_w - crop_px)
row_bounds = (min_h + crop_px, max_h - crop_px)
# two corner points given as (x=column, y=row) with a 0-based pixel origin; wcs[2] is the
# WCS of band index 2
ra_lim, dec_lim = wcs[2].all_pix2world(col_bounds, row_bounds, 0)

photo_cat = PhotoFullCatalog.from_file(
    cfg.paths.sdss, run=run, camcol=camcol, field=field, sdss_obj=sdss
).restrict_by_ra_dec(ra_lim, dec_lim)

# DECaLS Tractor catalog file; presumably the brick covering this SDSS field — confirm
decals_path = Path(cfg.paths.decals) / "tractor-3366m010.fits"
decals_cat = TractorFullCatalog.from_file(decals_path, ra_lim, dec_lim, wcs=wcs[2])

## Make predictions

In [None]:
# Encoder inputs as float32 tensors (keys presumably match what the encoder reads — confirm).
batch = dict(
    images=torch.from_numpy(cropped_image).float(),
    background=torch.from_numpy(cropped_background).float(),
    psf_params=psf_params.float(),
)

In [None]:
# Run the encoder on the cropped image; inference only, so skip autograd bookkeeping.
# (`model` is the encoder instantiated above -- there is no `encoder` variable in this notebook.)
with torch.no_grad():
    x_cat_marginal, _ = model.get_marginal(batch)
    pred = model.get_predicted_dist(x_cat_marginal)
    bliss_cat = model.variational_mode(pred)

# BLISS plocs are relative to the cropped image with its border tiles removed; shift them by
# (row, col) offsets into full-frame pixel coordinates so they line up with the other catalogs.
bliss_cat["plocs"] = bliss_cat["plocs"] + torch.tensor([crop_px + min_h, crop_px + min_w])

In [None]:
# Number of sources in each catalog. Keys are single-quoted inside the f-string: reusing the
# enclosing double quote is only legal from Python 3.12 (PEP 701).
for cat, description in (
    (bliss_cat, "predicted by bliss"),
    (photo_cat, "in PhotoCatalog"),
    (decals_cat, "in DECaLS catalog"),
):
    print(f"{cat['n_sources'].item()} light sources {description}")

### Plot predictions

In [None]:
def _index_set(indices):
    """Convert a sequence of indices (numpy array, torch tensor, or list) to a set of Python ints.

    Plain ints hash and compare by value; 0-d torch tensors hash by object identity, so sets
    of them would silently fail to intersect/difference with each other or with ``range``.
    """
    return {int(i) for i in indices}


def three_way_matching(pred_cat, comp_cat, gt_cat, slack=1):
    """Performs a 3-way matching between two catalogs and a ground truth catalog.

    ``pred_cat`` and ``comp_cat`` are each matched against ``gt_cat`` independently with
    ``CatalogMetrics.match_catalogs``. Only assigned pairs that pass its distance check
    (the boolean mask it returns in position 2) count as matches. Only the first image
    (batch index 0) of each catalog is used.

    Args:
        pred_cat: predicted catalog
        comp_cat: catalog to compare to
        gt_cat: catalog to use as "ground truth"
        slack: distance threshold for matching objects, passed to CatalogMetrics
            (distances are between ``plocs``, so it is in the same units)

    Returns:
        Dict[str, Set[int]]: source indices for each category.
            gt_all: gt sources that matched either pred or comp
            gt_pred_only: gt sources that matched pred but not comp
            gt_comp_only: gt sources that matched comp but not pred
            pred_only: pred sources that did not match gt
            comp_only: comp sources that did not match gt
    """
    gt_locs, pred_locs, comp_locs = gt_cat["plocs"][0], pred_cat["plocs"][0], comp_cat["plocs"][0]

    # compute matches for both catalogs against gt
    metrics = CatalogMetrics("dummy", slack=slack)
    match_gt_pred = metrics.match_catalogs(gt_locs, pred_locs)
    match_gt_comp = metrics.match_catalogs(gt_locs, comp_locs)

    # match_catalogs returns (gt indices, other-catalog indices, keep mask, ...). Apply the
    # keep mask everywhere -- including gt_all -- so pairs that were assigned but lie too far
    # apart are not counted as matches.
    gt_pred_matches = _index_set(match_gt_pred[0][match_gt_pred[2]])
    pred_gt_matches = _index_set(match_gt_pred[1][match_gt_pred[2]])
    gt_comp_matches = _index_set(match_gt_comp[0][match_gt_comp[2]])
    comp_gt_matches = _index_set(match_gt_comp[1][match_gt_comp[2]])

    n_pred = int(pred_cat["n_sources"].item())
    n_comp = int(comp_cat["n_sources"].item())

    return {
        # in gt and (pred or comp)
        "gt_all": gt_pred_matches | gt_comp_matches,
        # in pred and gt, not in comp
        "gt_pred_only": gt_pred_matches - gt_comp_matches,
        # in comp and gt, not in pred
        "gt_comp_only": gt_comp_matches - gt_pred_matches,
        # in pred, not in gt
        "pred_only": set(range(n_pred)) - pred_gt_matches,
        # in comp, not in gt
        "comp_only": set(range(n_comp)) - comp_gt_matches,
    }


In [None]:
# Treat DECaLS as ground truth: BLISS is the prediction, the SDSS Photo catalog the comparison.
# slack=2 is the matching distance threshold.
matches = three_way_matching(pred_cat=bliss_cat, comp_cat=photo_cat, gt_cat=decals_cat, slack=2)

In [None]:
fig, ax = plt.subplots(figsize=(16, 12))
# draw the crop in full-frame pixel coordinates (the frame the BLISS plocs were shifted into);
# the -0.5 offset puts integer coordinates at pixel centers
extent = np.array([min_w, max_w, min_h, max_h]) - 0.5
display_band = cropped_image[0, 2]  # band index 2 of the single image
# log stretch; offsetting so the minimum maps to log(20) keeps the argument positive
ax.imshow(np.log(display_band - display_band.min() + 20), cmap="gray", extent=extent, origin="lower")

# base markers: matched DECaLS sources and every SDSS / BLISS detection
plot_plocs(decals_cat, ax, 0, list(matches["gt_all"]), color="r", marker="o", s=60, edgecolor="black", alpha=0.7, linewidth=0.5, label="DECaLS")
plot_plocs(photo_cat, ax, 0, "all", color="g", marker="X", s=50, edgecolor="black", linewidth=0.5, label="SDSS")
plot_plocs(bliss_cat, ax, 0, "all", color="y", marker="P", s=30, edgecolor="black", linewidth=0.5, label="BLISS")

# hollow circles highlighting where the catalogs disagree
params = {
    "marker": MarkerStyle("o", fillstyle="none"),
    "s": 200,
    "linewidth": 0.7,
}
colors = [
    "#08F7FE",  # cyan
    "#FE53BB",  # pink
    "#F5D300",  # yellow
    "#00ff41",  # matrix green
]
# (catalog whose plocs are circled, key into `matches`, legend label). The first two groups are
# DECaLS sources matched by exactly one of BLISS / SDSS, i.e. an intersection with DECaLS
# minus the other catalog -- hence \cap, not \cup.
highlights = [
    (decals_cat, "gt_pred_only", r"(BLISS $\cap$ DECaLS) - SDSS"),
    (decals_cat, "gt_comp_only", r"(SDSS $\cap$ DECaLS) - BLISS"),
    (bliss_cat, "pred_only", r"BLISS - DECaLS"),
    (photo_cat, "comp_only", r"SDSS - DECaLS"),
]
for color, (cat, key, label) in zip(colors, highlights):
    plot_plocs(cat, ax, 0, list(matches[key]), c=color, label=label, **params)

ax.set_xlabel("column (pixels)")
ax.set_ylabel("row (pixels)")
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=labels, loc="upper center", ncol=7, bbox_to_anchor=(0.0, 0.06, 1, 1), fontsize=10);

Cyan circles mark DECaLS sources that BLISS matched but SDSS did not. Conversely, pink circles mark DECaLS sources that SDSS matched but BLISS missed. Yellow circles are BLISS detections with no DECaLS counterpart, and green circles are SDSS detections with no DECaLS counterpart; both are false positives if DECaLS is taken as ground truth.

## Detection metrics against DECaLS

In [None]:
# Score each catalog against DECaLS (treated as ground truth).
# NOTE(review): the plot above matches with slack=2; confirm slack=3 is intended here.
metrics = CatalogMetrics(mode="matching", slack=3)
bliss_metrics, photo_metrics = (metrics(decals_cat, cat) for cat in (bliss_cat, photo_cat))

# one row per catalog; presumably one column per metric key — confirm
df = pd.DataFrame.from_dict({"bliss": bliss_metrics, "sdss": photo_metrics}, orient="index")
df