## Imports

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from os import environ
# Pin the visible GPU here, ahead of the `import torch` further down in this
# cell: CUDA reads this variable when it initialises, so setting it afterwards
# may have no effect.  The device index "1" is machine-specific.
environ["CUDA_VISIBLE_DEVICES"] = "1"

from astropy.io import fits
from astropy.wcs import WCS

from hydra import initialize, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

from bliss.main import predict
from bliss.catalog import TileCatalog, FullCatalog

import torch
from pytorch_lightning.callbacks import Callback

import numpy as np
from matplotlib import pyplot as plt
from astropy.visualization import make_lupton_rgb

# Inference-only notebook: switch autograd off globally.
torch.set_grad_enabled(False)

# NOTE(review): absolute, user-specific checkpoint path -- point this at a local
# copy of the pretrained encoder weights before running on another machine.
ckpt = "/home/regier/bliss/tests/data/base_config_trained_encoder.pt"

# Compose the base Hydra config named "config" (looked up under config_path=".").
# The overrides are passed as a list: that is the type `compose` documents, and
# it keeps them in the order written, whereas a set literal iterates in
# arbitrary order.
with initialize(config_path=".", version_base=None):
    cfg0 = compose("config", [
        f"train.pretrained_weights={ckpt}",
        f"predict.weight_save_path={ckpt}",
        "cached_simulator.splits=0:80/80:90/99:100",
        "cached_simulator.num_workers=0",
    ])

# Encoder variants derived from the base config.  Each one only merges different
# sampler settings over the top-level "encoder" node.
cfg_c4 = OmegaConf.merge(cfg0, {"encoder": {
    "use_checkerboard": True,
    "n_sampler_colors": 4,
}})
cfg_c2 = OmegaConf.merge(cfg0, {"encoder": {
    "use_checkerboard": True,
    "n_sampler_colors": 2,
}})
cfg_c1 = OmegaConf.merge(cfg0, {"encoder": {
    "use_checkerboard": False,
}})

## Load and view the SDSS field

In [None]:
# Build the SDSS survey object from the base config, asking it to load pixel data.
sdss = instantiate(cfg0.surveys.sdss, load_image_data=True)
sdss.prepare_data()

# Single-element unpacking: this fails loudly unless the predict dataloader
# yields exactly one item.
[sdss_frame] = sdss.predict_dataloader()
obs_image = sdss_frame["images"][0]

In [None]:
# Colour composite: channels 3, 2 and 1 of the observed image drive the red,
# green and blue planes respectively.
rgb = make_lupton_rgb(
    image_r=obs_image[3],
    image_g=obs_image[2],
    image_b=obs_image[1],
    Q=0,
    stretch=0.1,
)
plt.imshow(rgb, origin="lower")

## Load and view SDSS predictions

In [None]:
from pathlib import Path
from bliss.surveys.sdss import PhotoFullCatalog

# Identify the frame from the first run/camcol/field entry in the survey config.
rcf = cfg0.surveys.sdss.fields[0]
run = rcf["run"]
camcol = rcf["camcol"]
field = rcf["fields"][0]

# photoObj file lives at <paths.sdss>/<run>/<camcol>/<field>/<po_fn>.
po_fn = f"photoObj-{run:06d}-{camcol}-{field:04d}.fits"
po_path = Path(cfg0.paths.sdss).joinpath(str(run), str(camcol), str(field), po_fn)

# Index 2 is used for both the WCS entry and the image channel; that channel's
# shape is forwarded as the remaining positional arguments.
sdss_wcs = sdss[0]["wcs"][2]
photo_cat = PhotoFullCatalog.from_file(po_path, sdss_wcs, *obs_image[2].shape)
photo_cat["n_sources"]

## Load and view DECaLS predictions

In [None]:
from bliss.surveys.des import TractorFullCatalog

# Same WCS lookup as the SDSS catalog cell (entry 2 of the first survey item).
sdss_wcs = sdss[0]["wcs"][2]
decals_path = Path(cfg0.paths.des).joinpath("336", "3366m010", "tractor-3366m010.fits")

# NOTE(review): 1488 and 2048 are hard-coded here, while the SDSS catalog cell
# passes obs_image[2].shape -- confirm the two agree.
decals_cat = TractorFullCatalog.from_file(decals_path, sdss_wcs, 1488, 2048)
decals_cat["n_sources"]

## Make and plot predictions with BLISS

In [None]:
import gc

# Collect unreachable Python objects first, then ask PyTorch to release its
# cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

# Cell output: memory currently allocated by PyTorch on the GPU, in GB (bytes / 1e9).
torch.cuda.memory_allocated() / 1e9

In [None]:
# Build the encoder on the GPU, then restore the pretrained weights into it.
encoder = instantiate(cfg0.train.encoder).cuda()

enc_state_dict = torch.load(cfg0.train.pretrained_weights)
# For ".ckpt" files the weights are taken from the "state_dict" entry; any
# other file is used exactly as loaded.
if cfg0.train.pretrained_weights.endswith(".ckpt"):
    enc_state_dict = enc_state_dict["state_dict"]
encoder.load_state_dict(enc_state_dict)
encoder.eval()

# One-element batch: prepend a batch axis to the image and move both inputs to the GPU.
batch = {
    "images": obs_image[None, :, :, :].cuda(),
    "psf_params": sdss_frame["psf_params"].cuda(),
}

In [None]:
from bliss.catalog import convert_mag_to_nmgy

# NOTE(review): use_mode=True presumably returns the modal catalog instead of a
# random draw -- confirm against the encoder.
bliss_tile_cat = encoder.sample(batch, use_mode=True)

# Flux threshold handed to filter_by_flux: magnitude 23 converted by convert_mag_to_nmgy.
min_flux = convert_mag_to_nmgy(23)
bliss_flux_filter_cat = bliss_tile_cat.filter_by_flux(min_flux)

# Tile catalog -> full catalog, then move to the CPU.
# NOTE(review): the argument 4 is presumably the tile side length in pixels -- confirm.
bliss_cat = bliss_flux_filter_cat.to_full_catalog(4)
bliss_cat = bliss_cat.to("cpu")
bliss_cat["n_sources"]

## Three-way performance scoring

In [None]:
# Apply the same box filter (origin torch.zeros(2), size 1488) to all three
# catalogs so they are compared over one common region.
photo_cat_box, decals_cat_box, bliss_cat_box = (
    cat.filter_by_ploc_box(torch.zeros(2), 1488)
    for cat in (photo_cat, decals_cat, bliss_cat)
)

In [None]:
from bliss.encoder.metrics import CatalogMatcher

# One matcher with default settings, reused for every pairing below.
matcher = CatalogMatcher()

# DECaLS is the reference ("gt") catalog in both pairings:
#   gt vs. pred -> BLISS,  gt vs. comp -> SDSS photo.
match_gt_pred, match_gt_comp = (
    matcher.match_catalogs(decals_cat_box, other_cat)
    for other_cat in (bliss_cat_box, photo_cat_box)
)


In [None]:
# Number of matched sources between BLISS and SDSS photo: the length of the
# first index sequence of the first (only) batch entry.
bliss_photo_matching = matcher.match_catalogs(bliss_cat_box, photo_cat_box)
len(bliss_photo_matching[0][0])

In [None]:
# Three-way bookkeeping with DECaLS as the reference ("gt") catalog, BLISS as
# "pred" and SDSS photo as "comp".
#
# Each match_catalogs result is indexed per batch element; entry 0 is unpacked
# into a pair of index sequences.
# NOTE(review): the pair is assumed to be (indices into the first catalog,
# indices into the second catalog), following the variable names used in this
# cell -- confirm against CatalogMatcher.match_catalogs.
gt_pred_matches, pred_gt_matches = match_gt_pred[0]
gt_comp_matches, comp_gt_matches = match_gt_comp[0]

# Convert to sets of plain ints so that the set algebra compares index values
# rather than array/tensor objects.
gt_in_pred = {int(i) for i in gt_pred_matches}
gt_in_comp = {int(i) for i in gt_comp_matches}
pred_in_gt = {int(i) for i in pred_gt_matches}
comp_in_gt = {int(i) for i in comp_gt_matches}

# Source counts of the box-filtered catalogs that were actually matched above.
n_pred = int(bliss_cat_box["n_sources"].item())
n_comp = int(photo_cat_box["n_sources"].item())

matches = {
    # gt sources matched by pred or by comp
    "gt_all": gt_in_pred | gt_in_comp,
    # gt sources matched by pred but not by comp
    "gt_pred_only": gt_in_pred - gt_in_comp,
    # gt sources matched by comp but not by pred
    "gt_comp_only": gt_in_comp - gt_in_pred,
    # pred sources with no gt match
    "pred_only": set(range(n_pred)) - pred_in_gt,
    # comp sources with no gt match
    "comp_only": set(range(n_comp)) - comp_in_gt,
}

In [None]:
# Source counts inside the box: SDSS photo first, then BLISS.
tuple(cat["n_sources"] for cat in (photo_cat_box, bliss_cat_box))