## Load SDSS image data

In [None]:
from astropy.io import fits
from astropy.wcs import WCS

# SDSS r-band frame (run/camcol/field are encoded in the directory path)
frame_path = '/home/regier/bliss/data/sdss/2583/2/136/frame-r-002583-2-0136.fits'
f = fits.open(frame_path)
w = WCS(f[0].header)  # pixel <-> sky mapping for this frame

# Sky position of the study area's lower-left corner, pixel (x, y) = (310, 630);
# the study area itself is 100x100 pixels.
w.pixel_to_world(310, 630)

In [None]:
from matplotlib import pyplot as plt

# Full SDSS r-band frame. origin='lower' puts pixel row 0 at the bottom, so the
# horizontal axis is x (column) and the vertical axis is y (row).
fig, ax = plt.subplots()
ax.imshow(f[0].data, origin='lower', cmap='gray_r')
ax.set(title='SDSS r-band frame', xlabel='x (column)', ylabel='y (row)');

## Loading/viewing the HST catalog (ground truth)

In [None]:
import pandas as pd

# Hubble Source Catalog query exported from https://catalogs.mast.stsci.edu/hsc/
#   target (RA, Dec): 323.3357504, -0.807026
#   radius: 32 arcsec; catalog v3; "detailed"; MagAuto <= 24
hst_catalog_path = '/home/regier/bliss/case_studies/dependent_tiling/HSC-10_20_2023.csv'
hst_catalog = pd.read_csv(hst_catalog_path)
hst_catalog.head()

In [None]:
from bliss.catalog import FullCatalog
import torch

# Project every HST source onto the SDSS frame's pixel grid through the frame's WCS.
# In the cells below, plocs[:, 0] is used as the row (y) and plocs[:, 1] as the column (x).
ra, dec = (
    torch.from_numpy(hst_catalog[col].to_numpy()) for col in ('MatchRA', 'MatchDec')
)
plocs = FullCatalog.plocs_from_ra_dec(ra, dec, w)

In [None]:
from matplotlib.patches import Rectangle

# All HST sources (red) over the SDSS frame. The blue box is the 100x100-pixel
# study area whose lower-left corner is at (x, y) = (310, 630).
fig, ax = plt.subplots()
ax.imshow(f[0].data, origin='lower', cmap='gray_r')
# plocs rows are (y, x): column 1 is x, column 0 is y
ax.scatter(plocs[:, 1], plocs[:, 0], s=10, c='r')
rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')
ax.add_patch(rect)
ax.set(title='HST sources on the SDSS frame', xlabel='x (column)', ylabel='y (row)');

In [None]:
# HST sources strictly inside the 100x100-pixel study area:
# x in (310, 410) and y in (630, 730); plocs rows are (y, x).
x_in_area = (plocs[:, 1] > 310) & (plocs[:, 1] < 410)
y_in_area = (plocs[:, 0] > 630) & (plocs[:, 0] < 730)
in_bounds = x_in_area & y_in_area

In [None]:
# Same view as above, but only the HST sources selected by `in_bounds` are drawn,
# to check visually that the selection matches the blue study-area box.
fig, ax = plt.subplots()
ax.imshow(f[0].data, origin='lower', cmap='gray_r')
# plocs rows are (y, x): column 1 is x, column 0 is y
ax.scatter(plocs[:, 1][in_bounds], plocs[:, 0][in_bounds], s=10, c='r')
rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')
ax.add_patch(rect)
ax.set(title='HST sources inside the study area', xlabel='x (column)', ylabel='y (row)');

In [None]:
# in_bounds is a torch boolean mask; convert it to numpy to select DataFrame rows
in_bounds_np = in_bounds.numpy()
hst_inbounds = hst_catalog.loc[in_bounds_np]
hst_inbounds

In [None]:
# In-bounds HST detections measured in the F814W filter
is_f814w = hst_inbounds["Filter"] == "F814W"
hcat_f814w = hst_inbounds.loc[is_f814w]
hcat_f814w.shape

In [None]:
# F606W is HST's closest analogue to the SDSS r band, though somewhat broader
is_f606w = hst_inbounds["Filter"] == "F606W"
hcat_f606w = hst_inbounds.loc[is_f606w]
hcat_f606w.shape

In [None]:
import torch

# The F814W detections serve as the ground-truth catalog.
hcat = hcat_f814w
# TODO: determine Bryan's magnitude cutoff (which yields 1114 stars) and train with the
# corresponding min_flux_threshold
hcat_bright = hcat.loc[hcat['MagAper2'] < 22.2]

ra, dec, mag = (
    torch.from_numpy(hcat_bright[col].values) for col in ('MatchRA', 'MatchDec', 'MagAper2')
)

# Pixel locations in the full frame, then shifted so that the study area's
# lower-left corner (row 630, column 310) becomes the origin.
plocs = FullCatalog.plocs_from_ra_dec(ra, dec, w)
plocs_square = plocs - torch.tensor([630, 310])

In [None]:
# (rows, columns) of the in-bounds F814W catalog vs. its MagAper2 < 22.2 subset
(hcat.shape, hcat_bright.shape)

In [None]:
# Ground-truth FullCatalog for the 100x100-pixel study area, built from the bright
# F814W sources selected above. All sources are entered as stars (galaxy fluxes are zero).
#
# FullCatalog takes *fluxes*, but `mag` holds HST magnitudes (MagAper2), where a smaller
# value means a brighter source. Passing magnitudes directly inverts the brightness
# ordering, so e.g. get_brightest_sources_per_tile would keep the faintest source per tile.
# Convert with the 22.5 zero point (nanomaggies). This assumes nanomaggies are BLISS's
# SDSS flux unit; TODO confirm.
n_true = plocs_square.shape[0]
star_flux = 10 ** ((22.5 - mag) / 2.5)
# the single HST band's flux is copied into all 5 bands: shape (1, n_true, 5)
star_fluxes = star_flux.unsqueeze(0).unsqueeze(2).expand([-1, -1, 5])
d = {
    "plocs": plocs_square.unsqueeze(0),
    "star_fluxes": star_fluxes,
    "galaxy_fluxes": torch.zeros_like(star_fluxes),
    "n_sources": torch.tensor([n_true]),
    # source_type 0 for every source (presumably the star code; TODO confirm)
    "source_type": torch.zeros(1, n_true, 1, dtype=torch.long),
}
true_cat = FullCatalog(100, 100, d)
true_cat.n_sources.sum()

In [None]:
# Bin the true catalog into tiles. The positional args are presumably
# (tile_slen=2, max_sources_per_tile=5); confirm against FullCatalog.to_tile_catalog.
true_tile_cat = true_cat.to_tile_catalog(2, 5)
n_tiled_sources = true_tile_cat.n_sources.sum()
n_tiled_sources

In [None]:
# Reduce the tiled truth to the brightest source(s) in each tile, ranked on band index 2
true_tile_cat1 = true_tile_cat.get_brightest_sources_per_tile(exclude_num=0, band=2)
n_brightest_sources = true_tile_cat1.n_sources.sum()
n_brightest_sources

## Making predictions with BLISS

In [None]:
from os import environ

# Restrict this kernel to GPU 1. This only takes effect if CUDA has not yet been
# initialized in this kernel (earlier cells import torch but, as far as is visible here,
# never touch the GPU).
environ["CUDA_VISIBLE_DEVICES"] = "1"

from pathlib import Path
from hydra import initialize, compose
from bliss.main import predict


# Assumes the working directory is two levels below the BLISS repo root, consistent
# with the "../../case_studies/..." config path below.
environ["BLISS_HOME"] = str(Path().resolve().parents[1])
with initialize(config_path="../../case_studies/dependent_tiling/", version_base=None):
    # Hydra expects overrides as a list of "key=value" strings; a set has no defined
    # order, which matters once more than one override is given.
    cfg = compose("config", overrides=["encoder.tiles_to_crop=3"])

bliss_cats = predict(cfg.predict)
# this unpacking requires that predict produced exactly one catalog
(bliss_cat,) = bliss_cats.values()

In [None]:
from bliss.encoder.metrics import BlissMetrics, MetricsMode

# Metrics over survey bands 0-4; slack is presumably the positional matching
# tolerance (TODO confirm units).
metrics = BlissMetrics(survey_bands=list(range(5)), slack=0.5, mode=MetricsMode.FULL)

In [None]:
# TODO: require flux within 0.5 mag (as in Bryan's code) for matches and use L2 distance?
metric = metrics(true_cat, bliss_cat)
summary_keys = ("detection_recall", "detection_precision", "f1")
tuple(metric[key] for key in summary_keys)

In [None]:
# Same comparison, but the truth keeps only the brightest HST source(s) per tile
brightest_true_cat = true_tile_cat1.to_full_catalog()
metric = metrics(brightest_true_cat, bliss_cat)
tuple(metric[key] for key in ("detection_recall", "detection_precision", "f1"))

### marginal

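Each tuple is (detection recall, detection precision, F1); "brightest" scores against only the brightest HST source(s) per tile.

marginal: (0.5738866329193115, 0.47767481207847595, 0.521379292011261)

marginal (brightest): (0.5853960514068604, 0.3984835743904114, 0.4741854667663574)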

### dependent

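Each tuple is (detection recall, detection precision, F1); "brightest" scores against only the brightest HST source(s) per tile.

dependent: (0.5829959511756897, 0.5092838406562805, 0.5436526536941528)

dependent (brightest): (0.5952970385551453, 0.4252873659133911, 0.4961320757865906)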



In [None]:
from hydra.utils import instantiate

encoder = instantiate(cfg.encoder)
encoder.load_state_dict(torch.load(cfg.predict.weight_save_path))
dataset = instantiate(cfg.predict.dataset)
dataset.prepare_data()

In [None]:
# Run the encoder's second-stage network by hand on the first dataset item.
# eval() switches off training-mode layer behavior (e.g. batch-norm batch statistics
# or dropout, if the encoder has any), and no_grad() avoids building an autograd graph
# for this pure inference.
encoder.eval()
first_item = dataset[0]  # fetch once so image and background come from the same item
batch = {
    "images": torch.from_numpy(first_item["image"]).unsqueeze(0),
    "background": torch.from_numpy(first_item["background"]).unsqueeze(0),
}
with torch.no_grad():
    x_features = encoder.get_features(batch)
    x_cat_second = encoder.second_net(x_features)
    pred_second = encoder.make_layer(x_cat_second)
    est_cat2_uncropped = pred_second.sample(use_mode=True)
est_cat2 = est_cat2_uncropped.symmetric_crop(cfg.encoder.tiles_to_crop)


In [None]:
# Score the hand-computed second-stage estimate against the full HST truth
est_cat2_full = est_cat2.to_full_catalog()
metric = metrics(true_cat, est_cat2_full)
tuple(metric[key] for key in ("detection_recall", "detection_precision", "f1"))