## Load SDSS image data

In [None]:
from astropy.wcs import WCS
from astropy.io import fits

# SDSS r-band frame (presumably run 2583, camcol 2, field 136, per the path).
# The HDU list `f` is intentionally left open: later cells read f[0].data.
f = fits.open("/home/regier/bliss/data/sdss/2583/2/136/frame-r-002583-2-0136.fits")
w = WCS(header=f[0].header)

# Sky position of the lower-left corner of the 100x100-pixel study area,
# which sits at pixel (x=310, y=630).
w.pixel_to_world(310, 630)

In [None]:
import matplotlib.pyplot as plt

# Show the whole SDSS frame in inverted grayscale with row 0 at the bottom.
plt.imshow(f[0].data, cmap='gray_r', origin='lower')

## Loading/viewing the HST catalog

In [None]:
import numpy as np
import torch

from bliss.catalog import FullCatalog

# HST ACS/WFC catalog of NGC 7089 (M2).
hubble_cat_file = "/home/regier/hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt"
# Skip the 3 header lines and keep only columns 9, 21 and 22, which this
# notebook uses as (r magnitude, RA, Dec).
hubble_cat = np.loadtxt(hubble_cat_file, usecols=(9, 21, 22), skiprows=3)

# One tensor per column; torch.from_numpy shares memory with `hubble_cat`.
rmag, ra, dec = (torch.from_numpy(hubble_cat[:, i]) for i in range(3))

# Project sky coordinates into pixel locations of the SDSS frame via its WCS `w`.
# Later cells treat plocs[:, 0] as the row (y) and plocs[:, 1] as the column (x).
plocs = FullCatalog.plocs_from_ra_dec(ra, dec, w)

In [None]:
from matplotlib.patches import Rectangle

# Overlay every HST source (red) on the SDSS frame and outline the 100x100
# study area (blue). plocs[:, 1] is plotted as x and plocs[:, 0] as y.
plt.imshow(f[0].data, cmap='gray_r', origin='lower')
plt.scatter(x=plocs[:, 1], y=plocs[:, 0], c='r', s=10)
rect = Rectangle(xy=(310, 630), width=100, height=100, facecolor='none', edgecolor='b', linewidth=1)
plt.gca().add_patch(rect)

In [None]:
# Mask of sources strictly inside the study area:
# row (plocs[:, 0]) in (630, 730) and column (plocs[:, 1]) in (310, 410).
in_bounds = (
    (plocs[:, 0] > 630)
    & (plocs[:, 0] < 730)
    & (plocs[:, 1] > 310)
    & (plocs[:, 1] < 410)
)
in_bounds.sum()

In [None]:
# Same overlay as the full-catalog plot, restricted to in-bounds sources.
plt.imshow(f[0].data, cmap='gray_r', origin='lower')
plt.scatter(x=plocs[in_bounds, 1], y=plocs[in_bounds, 0], c='r', s=10)
rect = Rectangle(xy=(310, 630), width=100, height=100, facecolor='none', edgecolor='b', linewidth=1)
plt.gca().add_patch(rect)

In [None]:
# Restrict the HST sources to the study area.
# NOTE: not idempotent. This overwrites `rmag` and `plocs` with their filtered
# versions, so re-run the cells above (from the catalog load onward) before
# running it again.
rmag = rmag[in_bounds]
plocs = plocs[in_bounds]

In [None]:
from bliss.utils.flux_units import convert_mag_to_nmgy

# Re-origin pixel locations at the study area's lower-left corner
# (row 630, column 310).
plocs_square = plocs - torch.tensor([630, 310])

# HST magnitudes -> fluxes (nmgy).
r_fluxes_nmgy = convert_mag_to_nmgy(rmag, nelec_per_nmgy=1)

# The HST F606W filter curve is not exactly the SDSS r filter curve, so fluxes
# derived from these magnitudes are about 22% off; scale them up to compensate.
r_fluxes_nmgy *= 1.22

In [None]:
# Catalog dict for every in-bounds HST source, all treated as stars:
# - star_fluxes: the r-band flux repeated across 5 bands (shape (1, n, 5)),
# - galaxy_fluxes: explicit zeros. The previous `star_fluxes * 0.0` would turn
#   any inf/NaN star flux into NaN and allocated a needless product.
# - source_type: 0 for every source (presumably the star code, consistent
#   with the zero galaxy fluxes; confirm against FullCatalog).
d = {
    "plocs": plocs_square.unsqueeze(0),
    "star_fluxes": r_fluxes_nmgy.unsqueeze(0).unsqueeze(2).expand([-1, -1, 5]),
    "galaxy_fluxes": torch.zeros(
        1, plocs.shape[0], 5, dtype=r_fluxes_nmgy.dtype, device=r_fluxes_nmgy.device
    ),
    "n_sources": torch.tensor([plocs.shape[0]]),
    "source_type": torch.zeros(1, plocs.shape[0], 1, dtype=torch.long),
}

In [None]:
# Full (untiled) catalog of all in-bounds HST sources over the 100x100 study area.
true_cat_all = FullCatalog(100, 100, d)
true_cat_all.n_sources.sum()

In [None]:
# Tiled version of the full catalog. The arguments (2, 11) are presumably the
# tile side length and the max sources per tile; confirm in to_tile_catalog.
true_tile_cat_all = true_cat_all.to_tile_catalog(2, 11)
true_tile_cat_all.n_sources.sum()

In [None]:
# TODO: figure out Bryan's cutoff (1114 stars) and train with the corresponding min_flux_threshold
# Bright sources: r magnitude below 22 (smaller magnitude = brighter).
is_bright = rmag.lt(22)
is_bright.sum()

In [None]:
# Catalog restricted to the bright sources (`is_bright`), all treated as stars:
# - star_fluxes: the r-band flux repeated across 5 bands (shape (1, n, 5)),
# - galaxy_fluxes: explicit zeros. The previous `star_fluxes * 0.0` would turn
#   any inf/NaN star flux into NaN.
# - source_type: 0 for every source (presumably the star code; confirm).
d = {
    "plocs": plocs_square[is_bright].unsqueeze(0),
    "star_fluxes": r_fluxes_nmgy[is_bright].unsqueeze(0).unsqueeze(2).expand([-1, -1, 5]),
    "galaxy_fluxes": torch.zeros(
        1, int(is_bright.sum()), 5, dtype=r_fluxes_nmgy.dtype, device=r_fluxes_nmgy.device
    ),
    # Count the mask directly instead of materializing plocs[is_bright].
    "n_sources": torch.tensor([int(is_bright.sum())]),
    "source_type": torch.zeros(1, int(is_bright.sum()), 1, dtype=torch.long),
}
true_cat = FullCatalog(100, 100, d)
true_cat.n_sources.sum()

In [None]:
# Tiled bright-source catalog. (2, 5) are presumably the tile side length and
# the max sources per tile; confirm in to_tile_catalog.
true_tile_cat = true_cat.to_tile_catalog(2, 5)
true_tile_cat.n_sources.sum()

In [None]:
# Presumably keeps only the brightest true source in each tile, ranked by band
# index 2 (the index this notebook uses for r). Confirm what exclude_num=0 means
# in get_brightest_sources_per_tile.
true_tile_cat1 = true_tile_cat.get_brightest_sources_per_tile(band=2, exclude_num=0)
true_tile_cat1.n_sources.sum()

## Making predictions with BLISS

In [None]:
from os import environ
from pathlib import Path

# Expose only GPU 4 (it then appears as cuda:0). This only takes effect if CUDA
# has not yet been initialized in this kernel.
environ["CUDA_VISIBLE_DEVICES"] = "4"

from hydra import initialize, compose
from bliss.main import predict

# Assumes the kernel's working directory is two levels below the BLISS repo
# root (e.g. case_studies/<study>/); confirm if run from elsewhere.
environ["BLISS_HOME"] = str(Path().resolve().parents[1])
with initialize(config_path="../../case_studies/dependent_tiling/", version_base=None):
    # Hydra expects overrides as a list of strings. The previous set literal
    # worked by accident and made the override order nondeterministic.
    cfg = compose(
        "m2_config",
        overrides=[
            "encoder.tiles_to_crop=3",
            "predict.weight_save_path=/home/regier/bliss/output/sample2fixed/version_0/checkpoints/best_encoder.ckpt",
        ],
    )

bliss_cats = predict(cfg.predict)
# Exactly one predicted catalog is expected; unpacking raises otherwise.
(bliss_cat,) = bliss_cats.values()

In [None]:
# Total source counts: bright HST sources vs. BLISS detections.
true_cat.n_sources.sum(), bliss_cat.n_sources.sum()

In [None]:
from bliss.encoder.metrics import CatalogMetrics

# Matching-based catalog metrics over all five survey bands; slack=1 is
# presumably the match tolerance in pixels (confirm in CatalogMetrics).
metrics = CatalogMetrics(mode="matching", slack=1, survey_bands=list(range(5)))

In [None]:
# Detection recall, precision and F1 of the BLISS catalog against the bright HST catalog.
# TODO: require flux within 0.5 mag (as in Bryan's code) for matches and use 0.5 distance (L2?)
metric = metrics(true_cat, bliss_cat)
metric["detection_recall"], metric["detection_precision"], metric["f1"]

In [None]:
# Same metrics, scored against the per-tile brightest-source truth catalog.
metric = metrics(true_tile_cat1.to_full_catalog(), bliss_cat)
metric["detection_recall"], metric["detection_precision"], metric["f1"]

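### marginal

Each tuple below is (detection recall, detection precision, F1). "(brightest)" rows score against `true_tile_cat1`, the per-tile brightest-source truth catalog.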

marginal: (0.5738866329193115, 0.47767481207847595, 0.521379292011261)

marginal (brightest): (0.5853960514068604, 0.3984835743904114, 0.4741854667663574)

### dependent

dependent: (0.5829959511756897, 0.5092838406562805, 0.5436526536941528)

dependent (brightest): (0.5952970385551453, 0.4252873659133911, 0.4961320757865906)



In [None]:
from hydra.utils import instantiate

# Rebuild the encoder and load the checkpoint that `predict` used.
# map_location="cpu": the checkpoint's tensors may reference a CUDA device index
# that is hidden by CUDA_VISIBLE_DEVICES. load_state_dict copies the weights
# into the encoder's own parameters either way.
encoder = instantiate(cfg.encoder)
encoder.load_state_dict(torch.load(cfg.predict.weight_save_path, map_location="cpu")["state_dict"])
dataset = instantiate(cfg.predict.dataset)
dataset.prepare_data()

## Tune the prior

Pixel-value summary statistics for the real M2 image and for synthetic images generated with the default prior:

### real m2 ###

```
raw_images[0, 0].min().item()
740.6287841796875

raw_images[0, 0].max().item()
123561.09375

(raw_images[0, 0, 0,20:100,20:100] < 740).sum()
tensor(0, device='cuda:0')

(raw_images[0, 0, 0,20:100,20:100] < 800).sum()
tensor(1, device='cuda:0')

(raw_images[0, 0, 0,20:100,20:100] < 900).sum()
tensor(166, device='cuda:0')
```


### synthetic m2 ###

```
[raw_images[i, 0].min().item() for i in range(5)]
[606.47802734375, 609.2470703125, 578.3068237304688, 605.1005859375, 603.1253662109375]

[raw_images[i, 0].max().item() for i in range(5)]
[110249.140625, 114663.9453125, 113192.21875, 108567.078125, 102371

(raw_images[0, 0, 0,20:100,20:100] < 640).sum()
tensor(1, device='cuda:0')

(raw_images[0, 0, 0,20:100,20:100] < 740).sum()
tensor(64, device='cuda:0')

(raw_images[0, 0, 0,20:100,20:100] < 800).sum()
tensor(253, device='cuda:0')

(raw_images[0, 0, 0,20:100,20:100] < 900).sum()
tensor(756, device='cuda:0')
```

In [None]:
# Observed image channel 2 (the index this notebook uses for r) of the first
# dataset item, with a 6-pixel border cropped from each side. Presumably this
# matches the region left after encoder.tiles_to_crop=3; confirm.
obs_image = torch.from_numpy(dataset[0]["image"][2][6:-6, 6:-6])
plt.imshow(obs_image)
# Relative L1 residual of observed image vs. background alone.
(obs_image - dataset[0]["background"][2, 6:-6, 6:-6]).abs().sum() / obs_image.sum()

In [None]:
# Render the full tiled HST catalog with the simulator's image decoder for
# frame (2583, 2, 136); only the first returned value (the images) is kept.
simulator = instantiate(cfg.simulator)
truth_images, _, _, _ = simulator.image_decoder.render_images(true_tile_cat_all, [(2583, 2, 136)])

In [None]:
# Reconstruction = rendered true sources + cropped background. Report its
# relative L1 residual against the observed image.
true_recon = truth_images[0][2] + dataset[0]["background"][2][6:-6, 6:-6]
plt.imshow(true_recon)
(true_recon - obs_image).abs().sum() / obs_image.sum()

In [None]:
# Brightest true star's flux times 856 (presumably this frame's counts per
# nanomaggy; confirm) vs. counts in the 9x9 window rows 74:83, cols 64:73 of the
# rendered truth image (presumably centered on that star).
true_cat_all["star_fluxes"][0, :, 2].max() * 856, truth_images[0][2][74:83, 64:73].sum()

In [None]:
# Median and mean of observed minus rendered-truth image (no background subtracted).
(obs_image - truth_images[0][2]).median(), (obs_image - truth_images[0][2]).mean()

In [None]:
# Background-subtracted observed image (same 6-pixel crop as obs_image).
ss_obs = (obs_image - dataset[0]["background"][2, 6:-6, 6:-6])
# Sum over the same 9x9 window (rows 74:83, cols 64:73) used for the
# rendered-truth check, so the two sums are directly comparable.
# The column slice 63:73 gave a 9x10 window.
ss_obs[74:83, 64:73].sum()

In [None]:
# Re-render using only the bright-source tile catalog and recompute residual
# statistics. Note: this re-instantiates the simulator and overwrites truth_images.
simulator = instantiate(cfg.simulator)
truth_images, _, _, _ = simulator.image_decoder.render_images(true_tile_cat, [(2583, 2, 136)])
(obs_image - truth_images[0][2]).median(), (obs_image - truth_images[0][2]).mean()

In [None]:
# Render the BLISS catalog (tiled with to_tile_catalog(2, 2)) with the same
# decoder and frame as the truth renders.
bliss_tile_cat = bliss_cat.to_tile_catalog(2, 2)
bliss_images, _, _, _ = simulator.image_decoder.render_images(bliss_tile_cat, [(2583, 2, 136)])


In [None]:
# BLISS reconstruction: rendered BLISS sources scaled by a hard-coded 1.1
# (rationale not recorded here; confirm) plus the cropped background.
bliss_recon = 1.1 * bliss_images[0][2] + dataset[0]["background"][2, 6:-6, 6:-6]
(obs_image - bliss_recon).median(), (obs_image - bliss_recon).mean()

In [None]:
# Inspect the cropped background for channel 2.
dataset[0]["background"][2, 6:-6, 6:-6]

In [None]:
# Display the BLISS reconstruction (scaled BLISS render + background).
plt.imshow(X=bliss_recon)
# TODO: use the rescaled bliss catalog to infer that ~710 is a better background; regenerate data; retrain (why?)
# TODO: compare marginal bliss catalog, checkerboard, and multidetect in terms of f1 (synthetic & m2)
# TODO: repeat analysis with samples of the approximate posterior (synthetic & m2)
# TODO: boxplot of diff in f1 for synthetic frames of various methods?
# TODO: require flux match of 0.5 mag

In [None]:
# Fraction of the observed image's total counts attributable to each component.
# 856 converts nanomaggies to counts (presumably this frame's counts per
# nanomaggy; confirm). The magnitude bins are half-open so every source lands in
# exactly one bin: bright rmag < 22, dim 22 <= rmag < 23, very dim rmag >= 23.
# The earlier strict `> 22` / `> 23` bounds silently dropped sources at exactly
# 22 or 23.
obs_total = obs_image.sum()
(
    r_fluxes_nmgy[is_bright].sum() * 856 / obs_total,  # flux prop from bright sources
    r_fluxes_nmgy[(rmag >= 22) & (rmag < 23)].sum() * 856 / obs_total,  # flux prop from dim sources
    r_fluxes_nmgy[rmag >= 23].sum() * 856 / obs_total,  # flux prop from very dim sources
    dataset[0]["background"][2, 6:-6, 6:-6].sum() / obs_total,  # flux prop from background
)

In [None]:
# Total counts (x856) from sources fainter than r = 23, divided by 100.
# NOTE(review): the study area is 100x100 = 10,000 pixels, so a per-pixel value
# would need /10000; confirm what the /100 is meant to normalize by.
(r_fluxes_nmgy[rmag > 23].sum() * 856) / 100

In [None]:
# Histogram of the in-bounds HST r-band fluxes, log-scaled count axis.
plt.hist(x=r_fluxes_nmgy, log=True)

In [None]:
from scipy.stats import truncpareto

# Frozen truncated Pareto with shape b=0.5, supported on [1, c=1014].
rv = truncpareto(b=0.5, c=1014)

In [None]:
# Fluxes -> counts (x856), keeping only sources that contribute more than one count.
r_counts = r_fluxes_nmgy.mul(856)
r_counts = r_counts[r_counts.gt(1)]

In [None]:
# Model probability that a source has at most 10 counts under the truncated Pareto.
rv.cdf(10)

In [None]:
# Number of sources with fewer than 10 counts (among those with more than 1 count).
(r_counts < 10).sum()