## Load SDSS image data

In [None]:
# Enable IPython's autoreload extension (mode 2: reload imported modules before
# executing each cell) and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from astropy.io import fits
from astropy.wcs import WCS
import numpy as np

# SDSS r-band frame 2583/2/136 (run, camcol, field — see the directory layout).
sdss_frame_path = '/home/regier/bliss/data/sdss/2583/2/136/frame-r-002583-2-0136.fits'
f = fits.open(sdss_frame_path)
w = WCS(f[0].header)  # pixel <-> sky coordinate mapping of this frame

# The 100x100-pixel study area has its lower-left corner at pixel (x=310, y=630);
# display the sky coordinate of that corner.
w.pixel_to_world(310, 630)

In [None]:
from matplotlib import pyplot as plt

# Raw pixel values with a linear stretch (inverted gray colormap, row 0 at the bottom).
frame_pixels = f[0].data
plt.imshow(frame_pixels, cmap='gray_r', origin='lower')
print("Behold, the M2 globular cluster!")

In [None]:
# Log stretch: shift so the faintest pixel maps to log(1) = 0, which brings out
# faint structure that the linear stretch hides.
frame_pixels = f[0].data
logimage = np.log(frame_pixels - frame_pixels.min() + 1)
_ = plt.imshow(logimage, origin='lower', cmap='gray_r')

## Loading/viewing the HST catalog (used as ground truth)

In [None]:
from bliss.catalog import FullCatalog
import torch
import numpy as np

hubble_cat_file = "/home/regier/hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt"
# Keep only columns 9, 21 and 22 (used below as magnitude, RA and Dec),
# skipping the 3 header lines.
hubble_cat = np.loadtxt(hubble_cat_file, skiprows=3, usecols=(9, 21, 22))

hst_r_mag_all, ra, dec = (torch.from_numpy(hubble_cat[:, col]) for col in range(3))

# Project every HST source onto the SDSS frame's pixel grid via its WCS.
plocs = FullCatalog.plocs_from_ra_dec(ra, dec, w)

In [None]:
from matplotlib.patches import Rectangle

# Every HST source (red) on the log-stretched frame; the blue box marks the study area.
# plocs[:, 1] is plotted on the x axis (column) and plocs[:, 0] on the y axis (row).
rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')
src_cols, src_rows = plocs[:, 1], plocs[:, 0]
plt.imshow(logimage, origin='lower', cmap='gray_r')
plt.scatter(src_cols, src_rows, s=10, c='r')
plt.gca().add_patch(rect)

In [None]:
# Sources strictly inside the study area: columns in (310, 410), rows in (630, 730).
x_pix, y_pix = plocs[:, 1], plocs[:, 0]
in_bounds = (x_pix > 310) & (x_pix < 410) & (y_pix > 630) & (y_pix < 730)
in_bounds.sum()

In [None]:
# A disjoint region of the same size directly above the study area
# (same columns 310-410, rows 730-830).
x_pix, y_pix = plocs[:, 1], plocs[:, 0]
oob = (y_pix > 730) & (y_pix < 830) & (x_pix > 310) & (x_pix < 410)
oob.sum()

In [None]:
# Only the sources that fall inside the study area.
x_in = plocs[:, 1][in_bounds]
y_in = plocs[:, 0][in_bounds]
plt.imshow(logimage, origin='lower', cmap='gray_r')
plt.scatter(x_in, y_in, s=10, c='r')
rect = Rectangle((310, 630), 100, 100, linewidth=1, edgecolor='b', facecolor='none')
_ = plt.gca().add_patch(rect)

In [None]:
# Restrict magnitudes and positions to the study area.
# NOTE: this rebinds `plocs`, so re-running the cell would apply the full-length
# `in_bounds` mask to the already-filtered tensor.
hst_r_mag, plocs = hst_r_mag_all[in_bounds], plocs[in_bounds]

In [None]:
from bliss.catalog import convert_mag_to_nmgy, convert_nmgy_to_mag

# Shift positions so the study area's lower-left corner (row 630, col 310) is the origin.
plocs_square = plocs - torch.tensor([630, 310])

hst_r_nmgy = convert_mag_to_nmgy(hst_r_mag)

# The HST F606W filter curve is not the SDSS r filter curve; scale the HST fluxes
# up by 22% to approximate SDSS r-band fluxes.
sdss_r_nmgy = hst_r_nmgy * 1.22
sdss_r_mag = convert_nmgy_to_mag(sdss_r_nmgy)

In [None]:
# Catalog dict for a single image (leading axis of size 1) holding every study-area source.
# The r-band flux is repeated 5 times along the last axis; galaxy fluxes are all zero
# and every source_type is 0 (presumably "star" — confirm against bliss.catalog).
n_study = plocs.shape[0]
flux_5 = sdss_r_nmgy.unsqueeze(0).unsqueeze(2).expand([-1, -1, 5])
d = {
    "plocs": plocs_square.unsqueeze(0),
    "star_fluxes": flux_5,
    "galaxy_fluxes": flux_5 * 0.0,
    "n_sources": torch.tensor([n_study]),
    "source_type": torch.zeros(1, n_study, 1, dtype=torch.long),
}

In [None]:
study_side_px = 100  # the study area is 100x100 pixels
true_cat_all = FullCatalog(study_side_px, study_side_px, d)
true_cat_all.n_sources.sum()

In [None]:
# Assumed argument meaning: (tile side length in pixels, max sources per tile) — confirm.
tile_side, max_per_tile = 2, 11
true_tile_cat_all = true_cat_all.to_tile_catalog(tile_side, max_per_tile)
true_tile_cat_all.n_sources.sum()

In [None]:
# Keep only sources brighter (smaller magnitude) than this SDSS r-band cutoff;
# also display the cutoff expressed in nanomaggies.
bright_mag_cutoff = 22.565
is_bright = sdss_r_mag < bright_mag_cutoff
is_bright.sum(), convert_mag_to_nmgy(bright_mag_cutoff)

In [None]:
# Same single-image catalog construction as for all sources, but restricted to
# the sources selected by `is_bright`.
bright_plocs = plocs_square[is_bright]
bright_flux_5 = sdss_r_nmgy[is_bright].unsqueeze(0).unsqueeze(2).expand([-1, -1, 5])
n_bright = plocs[is_bright].shape[0]
d = {
    "plocs": bright_plocs.unsqueeze(0),
    "star_fluxes": bright_flux_5,
    "galaxy_fluxes": bright_flux_5 * 0.0,
    "n_sources": torch.tensor([n_bright]),
    "source_type": torch.zeros(1, n_bright, 1, dtype=torch.long),
}
true_cat = FullCatalog(100, 100, d)
true_cat.n_sources.sum()

In [None]:
true_tile_cat = true_cat.to_tile_catalog(
    2,  # assumed: tile side length in pixels — confirm
    5,  # assumed: max sources per tile — confirm
)
true_tile_cat.n_sources.sum()

## Making predictions with BLISS

In [None]:
from os import environ
# Restrict this process to GPU 4; this only takes effect if set before CUDA is
# initialized, so it stays ahead of the bliss import below.
environ["CUDA_VISIBLE_DEVICES"] = "4"

from pathlib import Path
from hydra import initialize, compose
from bliss.main import predict

# BLISS_HOME = the directory two levels above the notebook's working directory.
environ["BLISS_HOME"] = str(Path().resolve().parents[1])
with initialize(config_path="../../case_studies/dependent_tiling/", version_base=None):
    # hydra.compose expects `overrides` as a list of strings; a set literal has no
    # defined iteration order, so overrides would be applied in arbitrary order.
    cfg = compose(
        "m2_config",
        overrides=[
            "encoder.tiles_to_crop=3",
            "predict.weight_save_path=/home/regier/bliss/output/NewFluxPrior/version_0/checkpoints/best_encoder.ckpt",
        ],
    )

In [None]:
bliss_cats = predict(cfg.predict)
[bliss_cat] = bliss_cats.values()  # exactly one catalog is expected
(true_cat.n_sources.sum(), bliss_cat.n_sources.sum())

In [None]:
from hydra.utils import instantiate

# Build the catalog matcher and the metric collection from the encoder config.
encoder_cfg = cfg.encoder
matcher = instantiate(encoder_cfg.matcher)
metrics = instantiate(encoder_cfg.metrics)

In [None]:
# Match true and predicted sources, then score the match.
matching = matcher.match_catalogs(true_cat, bliss_cat)
metric = metrics(true_cat, bliss_cat, matching)
tuple(metric[key] for key in ("detection_recall", "detection_precision", "detection_f1"))

In [None]:
# Plot each metric in the collection. Use a distinct loop variable: `metric`
# holds the results dict computed above and must not be overwritten here.
for name, m in metrics.items():
    m.plot()

In [None]:
# Re-run prediction with checkerboard inference disabled; the result is kept as
# the "marginal" catalog.
cfg.encoder.use_checkerboard = False
bliss_cats = predict(cfg.predict)
bliss_cat_marginal, = bliss_cats.values()

# `metrics` still holds the state accumulated while scoring the checkerboard catalog.
# `plot()` without arguments computes from that accumulated state (torchmetrics
# semantics — assumed, given `.items()`/`.plot()`), so reset it first. Otherwise the
# plots would mix both catalogs.
metrics.reset()
matching = matcher.match_catalogs(true_cat, bliss_cat_marginal)
metric = metrics(true_cat, bliss_cat_marginal, matching)
for name, m in metrics.items():
    m.plot()

metric["detection_recall"], metric["detection_precision"], metric["detection_f1"]

## Assess the model and BLISS fit visually

In [None]:

from hydra.utils import instantiate

# The same dataset that `predict` ran on.
predict_dataset_cfg = cfg.predict.dataset
dataset = instantiate(predict_dataset_cfg)
dataset.prepare_data()

In [None]:
# Observed image for band index 2 (presumably r), with a 6-pixel border trimmed
# from every side (presumably the cropped tiles — confirm).
band2_pixels = dataset[0]["image"][2]
obs_image = torch.from_numpy(band2_pixels[6:-6, 6:-6])
plt.imshow(obs_image)
_ = plt.colorbar()

In [None]:
simulator = instantiate(cfg.simulator)
# Render the tiled catalog of all study-area sources for the frame loaded above
# (2583, 2, 136).
frame_ids = [(2583, 2, 136)]
truth_images, _, _, _ = simulator.image_decoder.render_images(true_tile_cat_all, frame_ids)

In [None]:
# Rendered sources plus the (identically cropped) band-2 background.
band2_background = dataset[0]["background"][2][6:-6, 6:-6]
true_recon_all = truth_images[0][2] + band2_background
plt.imshow(true_recon_all)
_ = plt.colorbar()

In [None]:
# Fresh simulator from the same config; render only the bright-source catalog.
simulator = instantiate(cfg.simulator)
truth_images, _, _, _ = simulator.image_decoder.render_images(
    true_tile_cat,
    [(2583, 2, 136)],
)

In [None]:
band2_background = dataset[0]["background"][2][6:-6, 6:-6]
true_recon = truth_images[0][2] + band2_background  # bright sources + background
plt.imshow(true_recon)
_ = plt.colorbar()

In [None]:
# Tile the BLISS catalog with the same arguments as the bright truth catalog, then render it.
bliss_tile_cat = bliss_cat.to_tile_catalog(2, 5)
bliss_images, _, _, _ = simulator.image_decoder.render_images(
    bliss_tile_cat,
    [(2583, 2, 136)],
)

In [None]:
band2_background = dataset[0]["background"][2][6:-6, 6:-6]
bliss_recon = bliss_images[0, 2] + band2_background  # BLISS sources + background
plt.imshow(bliss_recon)
_ = plt.colorbar()

# Flux Prior Elicitation

In [None]:
# Training data for the flux prior: HST stars from the region above the study area,
# converted to approximate SDSS r-band fluxes with the same 1.22 scaling as before.
hst_oob = hst_r_mag_all[oob]
hst_oob_nmgy = 1.22 * convert_mag_to_nmgy(hst_oob)
hst_oob_mag = convert_nmgy_to_mag(hst_oob_nmgy)
brighter_than_24 = hst_oob_mag < 24
training_data = hst_oob_nmgy[brighter_than_24]
training_data.shape

In [None]:
from scipy.stats import pareto
# Fit a Pareto distribution; scipy returns (shape b, loc, scale).
pareto_params = pareto.fit(training_data)
alpha, loc, scale = pareto_params
alpha, loc, scale

In [None]:
from scipy.stats import truncpareto

# Compare the fitted ("new") Pareto prior and the previous truncated-Pareto prior
# (scipy truncpareto with b=0.5, c=1014, loc=0, scale=0.63) with the normalized
# histogram of all out-of-bounds star fluxes.
log_flux = hst_oob_nmgy.log10()
x = np.logspace(log_flux.min(), log_flux.max(), num=100)

prior_curves = [
    (pareto.pdf(x, alpha, loc, scale), 'r-', 'new prior'),
    (truncpareto.pdf(x, 0.5, 1014, 0, 0.63), 'g-', 'old prior'),
]
for pdf_vals, fmt, curve_label in prior_curves:
    _ = plt.plot(x, pdf_vals, fmt, lw=3, alpha=0.7, label=curve_label)
_ = plt.hist(hst_oob_nmgy, log=True, bins=100, label='star_fluxes histogram', density=True)
plt.legend()

In [None]:
# The two priors alone, on log-log axes.
log10_lo, log10_hi = hst_oob_nmgy.log10().min(), hst_oob_nmgy.log10().max()
x = np.logspace(log10_lo, log10_hi, num=100)

new_prior_pdf = pareto.pdf(x, alpha, loc, scale)
old_prior_pdf = truncpareto.pdf(x, 0.5, 1014, 0, 0.63)
_ = plt.plot(x, new_prior_pdf, 'r-', lw=3, alpha=0.7, label='new prior')
_ = plt.plot(x, old_prior_pdf, 'g-', lw=3, alpha=0.7, label='old prior')
plt.legend()
plt.loglog()