## Load SDSS image data

In [None]:
# reload edited modules (e.g. bliss) automatically before each cell runs
%load_ext autoreload
%autoreload 2
# render matplotlib figures inline in the notebook
%matplotlib inline

In [None]:
from astropy.io import fits
from astropy.wcs import WCS
import numpy as np

# r-band SDSS frame 2583/2/136; the HDU list is kept open because later cells read f[0].data
sdss_frame_path = '/home/regier/bliss/data/sdss/2583/2/136/frame-r-002583-2-0136.fits'
f = fits.open(sdss_frame_path)
frame_header = f[0].header
w = WCS(frame_header)

# pixel_to_world takes (x, y) = (column, row): the 100x100-pixel study area
# has its lower-left corner at column 310, row 630
w.pixel_to_world(310, 630)

In [None]:
from matplotlib import pyplot as plt

# full frame with a linear stretch
sdss_pixels = f[0].data
plt.imshow(sdss_pixels, origin='lower', cmap='Greys_r')
print("Behold, the M2 globular cluster!")

In [None]:
# shift so the faintest pixel maps to 1, then take the log to bring out faint structure
sdss_pixels = f[0].data
shifted_pixels = sdss_pixels - sdss_pixels.min() + 1
logimage = np.log(shifted_pixels)
plt.imshow(logimage, origin='lower', cmap='Greys_r');

In [None]:
from matplotlib.patches import Rectangle

plt.imshow(logimage, origin='lower', cmap='Greys_r')
ax = plt.gca()

# outline the 100x100-pixel study area (lower-left corner at column 310, row 630)
rect = Rectangle(
    (310, 630), 100, 100,
    linewidth=2, edgecolor='r', facecolor='none',
)
_ = ax.add_patch(rect)
ax.set_xticks([])
ax.set_yticks([]);

## Loading/viewing HST predictions

In [None]:
from bliss.catalog import FullCatalog
import torch
import numpy as np

# HST ACS photometry catalog for NGC 7089 (M2), per the file name
hubble_cat_file = "/home/regier/hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt"
# keep three columns: magnitude (col 9), RA (col 21), Dec (col 22)
hubble_cat = np.loadtxt(hubble_cat_file, skiprows=3, usecols=(9, 21, 22))

hst_r_mag_all, ra, dec = (torch.from_numpy(hubble_cat[:, col]) for col in range(3))

# project the sky coordinates onto the SDSS frame's pixel grid
plocs_all = FullCatalog.plocs_from_ra_dec(ra, dec, w)

In [None]:
# 100x100 study-area cutout: rows 630-729, columns 310-409
original = f[0].data[630:730, 310:410]
arcsinh_median = np.arcsinh(original - np.median(original))

# clip the brightest 2% of pixels before stretching so they don't dominate the range
clip_level = np.quantile(original, 0.98)
clipped = original.clip(max=clip_level)
arcsinh_clipped = np.arcsinh(clipped - np.median(clipped))

In [None]:


fig, axs = plt.subplots(1, 3, figsize=(10, 10))

images = [original, arcsinh_median, arcsinh_clipped]
# the stretches use np.arcsinh (inverse hyperbolic sine), so label them "arcsinh"
titles = ['original', 'arcsinh', 'arcsinh with clipping']

for ax, img, title in zip(axs, images, titles):
    ax.imshow(img, origin='lower', cmap='Greys_r')
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()

In [None]:
# HST sources whose pixel location falls strictly inside the study area
# (plocs are (row, col): rows 630-730, columns 310-410)
row_in = (plocs_all[:, 0] > 630) & (plocs_all[:, 0] < 730)
col_in = (plocs_all[:, 1] > 310) & (plocs_all[:, 1] < 410)
in_bounds = col_in & row_in
in_bounds.sum()

In [None]:
# restrict the HST pixel locations and magnitudes to the study area
plocs = plocs_all[in_bounds]
hst_r_mag = hst_r_mag_all[in_bounds]

In [None]:
from bliss.catalog import convert_mag_to_nmgy, convert_nmgy_to_mag

# move locations into the study-area frame: (row 630, column 310) becomes the origin
plocs_square = plocs - torch.tensor([630, 310])

# The HST F606W filter curve isn't exactly the SDSS r-band curve, so HST
# magnitudes are about 15% off; scale the fluxes to approximate SDSS r.
hst_r_nmgy = convert_mag_to_nmgy(hst_r_mag)
sdss_r_nmgy = hst_r_nmgy * 1.15
sdss_r_mag = convert_nmgy_to_mag(sdss_r_nmgy)

In [None]:
# Catalog of every in-area HST source: the single r-band flux is replicated
# across 5 bands, galaxy fluxes are zero, and every source_type is 0.
n_stars = plocs.shape[0]
star_fluxes = sdss_r_nmgy.unsqueeze(0).unsqueeze(2).expand([-1, -1, 5])
d = {
    "plocs": plocs_square.unsqueeze(0),
    "star_fluxes": star_fluxes,
    "galaxy_fluxes": star_fluxes * 0.0,
    "n_sources": torch.tensor(n_stars).unsqueeze(0),
    "source_type": torch.zeros(n_stars).unsqueeze(0).unsqueeze(2).long(),
}

In [None]:
# full catalog over the 100x100-pixel study area containing every in-area HST source
true_cat_all = FullCatalog(100, 100, d)
true_cat_all.n_sources.sum()

In [None]:
# tile the full catalog; presumably (tile side length 2, at most 11 sources per tile) — TODO confirm
true_tile_cat_all = true_cat_all.to_tile_catalog(2, 11)
# compare with the full-catalog count (presumably to check no sources were dropped by tiling)
true_tile_cat_all.n_sources.sum()

In [None]:
# SDSS r-band magnitude threshold for the sources kept in `true_cat`
r_mag_cutoff = 22.565
is_bright = sdss_r_mag < r_mag_cutoff
is_bright.sum(), convert_mag_to_nmgy(r_mag_cutoff)

In [None]:
# catalog restricted to sources brighter than the cutoff (mask `is_bright`)
bright_plocs = plocs_square[is_bright]
bright_fluxes = sdss_r_nmgy[is_bright].unsqueeze(0).unsqueeze(2).expand([-1, -1, 5])
n_bright = plocs[is_bright].shape[0]
d = {
    "plocs": bright_plocs.unsqueeze(0),
    "star_fluxes": bright_fluxes,
    "galaxy_fluxes": bright_fluxes * 0.0,
    "n_sources": torch.tensor(n_bright).unsqueeze(0),
    "source_type": torch.zeros(n_bright).unsqueeze(0).unsqueeze(2).long(),
}
true_cat = FullCatalog(100, 100, d)
true_cat.n_sources.sum()

In [None]:
# tile the bright catalog; presumably (tile side length 2, at most 5 sources per tile) — TODO confirm
true_tile_cat = true_cat.to_tile_catalog(2, 5)
true_tile_cat.n_sources.sum()

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(10, 10))

# NOTE(review): `true_cat` is built with the cutoff 22.565; confirm that the
# middle panel's 22.065 is intentional and not a typo.
cutoffs = [20, 22.065, 24]

for ax, cutoff in zip(axs, cutoffs):
    # Loop-local mask name: reusing `is_bright` here would overwrite the mask
    # that defines `true_cat`, so re-running that cell afterwards would
    # silently build the catalog with the last cutoff (24) instead.
    brighter_than_cutoff = sdss_r_mag < cutoff
    plocs_square_bright = plocs_square[brighter_than_cutoff]
    ax.imshow(arcsinh_clipped, origin='lower', cmap='Greys_r')
    # plocs are (row, col); scatter takes (x, y) = (col, row)
    ax.scatter(plocs_square_bright[:, 1], plocs_square_bright[:, 0], s=5, c='r')
    ax.set_title(f"magnitude < {cutoff}")
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()


## Making predictions on M2 with BLISS

In [None]:
from os import environ
# Restrict this kernel to one GPU. This only takes effect if CUDA has not been
# initialized yet (torch is already imported in an earlier cell, so make sure
# no GPU work has run before this cell).
environ["CUDA_VISIBLE_DEVICES"] = "2"

from pathlib import Path
from hydra import initialize, compose
from bliss.main import predict

# assumes the notebook's working directory is two levels below the repo root — TODO confirm
environ["BLISS_HOME"] = str(Path().resolve().parents[1])
with initialize(config_path="../../case_studies/dependent_tiling/", version_base=None):
    # overrides are passed as an ordered list (Hydra's documented type), not a
    # set, whose iteration order is not defined
    cfg = compose("m2_config", [
        "encoder.tiles_to_crop=3",
        "predict.weight_save_path=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt",
        # "encoder.double_detect=false",
    ])

In [None]:
bliss_cats = predict(cfg.predict)
# predict returns a single-entry mapping; unpack its only value
[bliss_cat_pair] = list(bliss_cats.values())
mode_tile_cat = bliss_cat_pair["mode_cat"]
bliss_cat = mode_tile_cat.to_full_catalog()

# source counts: bright ground truth vs. BLISS mode catalog
(true_cat.n_sources.sum(), bliss_cat.n_sources.sum())

In [None]:
from hydra.utils import instantiate

# build the catalog matcher and the metric collection from the encoder config
matcher = instantiate(cfg.encoder.matcher)
metrics = instantiate(cfg.encoder.metrics)

In [None]:
# match the mode catalog to the bright ground truth and score detections
matching = matcher.match_catalogs(true_cat, bliss_cat)
metric = metrics(true_cat, bliss_cat, matching)
tuple(metric[key] for key in ("detection_recall", "detection_precision", "detection_f1"))

In [None]:
# Hard-coded reference recall/precision values, presumably reported for
# StarNet (one entry per bin) — TODO cite the source.
starnet = {
    "recall": [0.95, 0.91, 0.79, 0.7, 0.7, 0.62, 0.59, 0.4],
    "precision": [0.96, 0.97, 0.79, 0.8, 0.68, 0.6, 0.45, 0.35]
}

starnet_recall = np.array(starnet["recall"])
starnet_precision = np.array(starnet["precision"])
starnet["f1"] = 2 * starnet_recall * starnet_precision / (starnet_recall + starnet_precision)

# Distinct loop-variable name so the `metric` results dict from the previous
# cell is not overwritten by the last metric object.
for name, metric_obj in metrics.items():
    metric_obj.plot()


In [None]:
# Check calibration: draw repeated sample catalogs and look at the spread of their total source counts.

In [None]:
%%capture
counts = []

for i in range(15):
    bliss_cats = predict(cfg.predict)
    bliss_cat_pair, = bliss_cats.values()
    bliss_cat = bliss_cat_pair["sample_cat"].to_full_catalog()
    counts.append(bliss_cat.n_sources.sum())

counts

In [None]:
# per-sample total source counts as floats (quantile needs a floating dtype)
cs = torch.tensor([n.item() for n in counts], dtype=torch.float)
(cs.mean(), cs.quantile(0.05), cs.quantile(0.95))

### Independent tiling (baseline)

In [None]:
from copy import deepcopy
# copy of the prediction config with checkerboard tiling disabled (the baseline)
# NOTE(review): assumes cfg2.predict picks up cfg2.encoder (e.g. via interpolation) — TODO confirm
cfg2 = deepcopy(cfg)
cfg2.encoder.use_checkerboard = False

In [None]:
bliss_cats = predict(cfg2.predict)
bliss_cat_pair, = bliss_cats.values()
bliss_cat_marginal = bliss_cat_pair["mode_cat"].to_full_catalog()
matching = matcher.match_catalogs(true_cat, bliss_cat_marginal)

# `metrics` was already called on the checkerboard catalog. Calling the
# collection also updates its accumulated state (assumes torchmetrics
# MetricCollection semantics), so reset first. Otherwise `m.plot()` and the
# counts read from `m` afterwards mix both runs.
metrics.reset()
metric = metrics(true_cat, bliss_cat_marginal, matching)

m = metrics["DetectionPerformance"]
m.plot()

metric["detection_recall"], metric["detection_precision"], metric["detection_f1"]

In [None]:
# detection metrics from the DetectionPerformance counts (same keys as `starnet`)
precision = m.n_est_matches / m.n_est_sources
recall = m.n_true_matches / m.n_true_sources
f1 = 2 * precision * recall / (precision + recall)
real = dict(recall=recall, precision=precision, f1=f1)

## BLISS performance on synthetic data

In [None]:
with initialize(config_path="../../case_studies/dependent_tiling/", version_base=None):
    # max_epochs=0 with pretrained weights: presumably an evaluation-only run on
    # the cached synthetic data. Overrides are an ordered list, as Hydra expects.
    cfg3 = compose("m2_config", [
        "train.trainer.logger=null",
        "train.trainer.max_epochs=0",
        "train.pretrained_weights=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt",
        "cached_simulator.cached_data_path=/data/scratch/regier/toy_m2",
        "+train.trainer.num_sanity_val_steps=0",
        # "encoder.double_detect=false",
    ])

from bliss.main import train
train(cfg3.train)

In [None]:
from copy import deepcopy
# same synthetic-data run as cfg3 but with checkerboard tiling disabled (baseline)
cfg4 = deepcopy(cfg3)
cfg4.train.encoder.use_checkerboard = False
train(cfg4.train)

## Assess the model and BLISS fit visually

In [None]:

from hydra.utils import instantiate

# build the prediction dataset from the config and run its data-preparation step
dataset = instantiate(cfg.predict.dataset)
dataset.prepare_data()

In [None]:
# observed image in band index 2, with a 6-pixel border cropped from each side
band_image = dataset[0]["image"][2]
obs_image = torch.from_numpy(band_image[6:-6, 6:-6])
plt.imshow(obs_image, origin='lower', cmap='Greys_r')
_ = plt.colorbar()

In [None]:
simulator = instantiate(cfg.simulator)
# render the tiled catalog of all in-area HST sources for SDSS frame (2583, 2, 136)
truth_images, _, _, _ = simulator.image_decoder.render_images(true_tile_cat_all, [(2583, 2, 136)])

In [None]:
# rendered image of all HST sources (band index 2) plus the border-cropped background
rendered_band = truth_images[0][2]
background_crop = dataset[0]["background"][2][6:-6, 6:-6]
true_recon_all = rendered_band + background_crop
plt.imshow(true_recon_all, origin='lower', cmap='Greys_r')
_ = plt.colorbar()

In [None]:
# NOTE(review): re-instantiates the simulator from the same `cfg.simulator`
# as the earlier render cell; reusing that instance would likely suffice — confirm.
simulator = instantiate(cfg.simulator)
# render only the bright-source tiled catalog (`true_tile_cat`)
truth_images, _, _, _ = simulator.image_decoder.render_images(true_tile_cat, [(2583, 2, 136)])

In [None]:
# rendered image of the bright sources (band index 2) plus the border-cropped background
rendered_band = truth_images[0][2]
background_crop = dataset[0]["background"][2][6:-6, 6:-6]
true_recon = rendered_band + background_crop
plt.imshow(true_recon, origin='lower', cmap='Greys_r')
_ = plt.colorbar()

In [None]:
# tile the BLISS catalog with the same arguments as `true_tile_cat` and render it.
# NOTE(review): make sure `bliss_cat` still holds the mode catalog from the first
# BLISS prediction cell; any later cell that reassigns `bliss_cat` changes what is rendered here.
bliss_tile_cat = bliss_cat.to_tile_catalog(2, 5)
bliss_images, _, _, _ = simulator.image_decoder.render_images(bliss_tile_cat, [(2583, 2, 136)])

In [None]:
# rendered BLISS catalog (band index 2) plus the border-cropped background
rendered_band = bliss_images[0, 2]
background_crop = dataset[0]["background"][2][6:-6, 6:-6]
bliss_recon = rendered_band + background_crop
plt.imshow(bliss_recon, origin='lower', cmap='Greys_r')
_ = plt.colorbar()

# Flux Prior Elicitation

In [None]:
# 300x300-pixel neighborhood around the study area (rows 530-830, cols 210-510),
# excluding the study area itself
plocs_row, plocs_col = plocs_all[:, 0], plocs_all[:, 1]
in_neighborhood = (plocs_col > 210) & (plocs_col < 510) & (plocs_row > 530) & (plocs_row < 830)
oob = in_neighborhood & ~in_bounds
oob.sum() # some of this region (about half) is outside of our HST cat coverage

In [None]:
# out-of-area fluxes, scaled by 1.15 from HST to approximate SDSS r band
hst_oob = hst_r_mag_all[oob]
hst_oob_nmgy = convert_mag_to_nmgy(hst_oob) * 1.15
hst_oob_mag = convert_nmgy_to_mag(hst_oob_nmgy)

# fit the prior only to sources brighter than magnitude 24
is_fit_source = hst_oob_mag < 24
training_data = hst_oob_nmgy[is_fit_source]
(training_data.shape[0], training_data.max().item())

In [None]:
from scipy.stats import truncpareto

# fit a truncated Pareto (shape, truncation, loc, scale) to the bright out-of-area fluxes
fit_params = truncpareto.fit(training_data)
alpha, trunc, loc, scale = fit_params
alpha, trunc, loc, scale

In [None]:
from scipy.stats import truncpareto

# parameters (b, c, loc, scale) of the old flux prior, for comparison
old_prior_params = (0.5, 1014, 0, 0.63)

# evaluate both densities on a log-spaced grid spanning the observed fluxes
log10_flux = hst_oob_nmgy.log10()
x = np.logspace(log10_flux.min(), log10_flux.max(), num=100)

# NB: `alpha` (fitted Pareto shape) is unrelated to matplotlib's alpha= transparency
line_style = {"lw": 3, "alpha": 0.7}
_ = plt.plot(x, truncpareto.pdf(x, alpha, trunc, loc, scale), 'r-', label='new prior', **line_style)
_ = plt.plot(x, truncpareto.pdf(x, *old_prior_params), 'g-', label='old prior', **line_style)
_ = plt.hist(hst_oob_nmgy, log=True, bins=100, label='star_fluxes histogram', density=True)
plt.legend()

In [None]:
# new vs. old prior densities on log-log axes
new_pdf = truncpareto.pdf(x, alpha, trunc, loc, scale)
old_pdf = truncpareto.pdf(x, 0.5, 1014, 0, 0.63)
for pdf, fmt, label in ((new_pdf, 'r-', 'new prior'), (old_pdf, 'g-', 'old prior')):
    _ = plt.plot(x, pdf, fmt, lw=3, alpha=0.7, label=label)
plt.legend()
plt.loglog()

In [None]:
# draw 1500 fluxes from the fitted prior and inspect the ten largest
samples = truncpareto.rvs(alpha, trunc, loc, scale, size=1500)
ten_largest = sorted(samples, reverse=True)[:10]
ten_largest


In [None]:
# draw one sample from the configured prior and list its 100 largest on-fluxes
# at index 2 of the last axis (presumably the band axis — TODO confirm)
prior = instantiate(cfg.prior)
prior.sample().on_fluxes[0, :, :, :, 2].view(-1).topk(100)[0]

In [None]:
# estimate rate with oob data: sources brighter than magnitude 24 per pixel,
# over the ~half of the 300*300 - 100*100 pixel out-of-area region that HST covers
covered_area_px = 0.5 * (300 * 300 - 100 * 100)
(hst_oob_mag < 24).sum() / covered_area_px

## Assess the two-point correlation function

In [None]:
from hydra import initialize, compose
from bliss.main import train

with initialize(config_path="../../case_studies/dependent_tiling/", version_base=None):
    # overrides are passed as an ordered list (Hydra's documented type), not a set
    cfg5 = compose("m2_config", [
        "train.trainer.logger=null",
        "train.trainer.max_epochs=0",
        "train.pretrained_weights=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt",
        "cached_simulator.cached_data_path=/data/scratch/regier/toy_m2",
        "+train.trainer.num_sanity_val_steps=0",
        "cached_simulator.splits=0:10/10:20/0:100",
        # "encoder.double_detect=false",
    ])

# evaluate with only the two-point correlation metric
cfg5.train.encoder.metrics.metrics = [{'_target_': 'case_studies.dependent_tiling.two_point_metric.TwoPointMetric'}]
train(cfg5.train)

Findings from spot-checking the sources with nonzero two-point correlation above:
* Two pairs of sources with modes within 1e-4 of each other: the sources sit in diagonal (catty-corner) tiles, which share a color in a 2-color checkerboard and so are not conditioned on each other; this calls for a 4-color checkerboard.
* One pair of sources with modes within 1e-2: in consecutive columns of the same row, near a corner; one of the two is a second detection within its tile. The second detection needs conditioning information.
* One pair of sources with sampled modes within 0.1: a double detection solidly within one tile, with high uncertainty about whether the second source exists (it doesn't). The first source is correctly identified; the second detected source hovers around the pixel (of the tile's 4) containing the true source. Again, the second detection needs conditioning information.

In [None]:
from copy import deepcopy

# same two-point-correlation run as cfg5 but with checkerboard tiling disabled
cfg6 = deepcopy(cfg5)
cfg6.train.encoder.use_checkerboard = False
train(cfg6.train)

## Semi-synthetic M2 inference

In [None]:
with initialize(config_path="../../case_studies/dependent_tiling/", version_base=None):
    # overrides are passed as an ordered list (Hydra's documented type), not a set
    cfg = compose("m2_config", [
        "encoder.tiles_to_crop=3",
        "predict.weight_save_path=/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt",
        # "encoder.double_detect=false",
    ])

# Shift the true catalog into the padded 112x112 frame (6 pixels of padding per
# side). Deep-copy first in case to_dict() shares tensors with true_cat_all,
# since += modifies in place.
d2 = deepcopy(true_cat_all.to_dict())
d2["plocs"] += 6
true_cat_pad = FullCatalog(112, 112, d2)

truth_images, _, _, _ = simulator.image_decoder.render_images(
    true_cat_pad.to_tile_catalog(2, 11), [(2583, 2, 136)]
)

In [None]:
from torch.distributions import Normal

# rendered sources + background, then Gaussian noise whose variance equals the
# pixel mean (presumably a normal approximation to Poisson noise)
true_recon_all = truth_images[0] + dataset[0]["background"]
true_recon_all = Normal(true_recon_all, true_recon_all.sqrt()).sample()
# origin='lower' so the image has the same orientation as the other image plots
plt.imshow(true_recon_all[2], origin='lower', cmap='Greys_r')
_ = plt.colorbar()

In [None]:
encoder = instantiate(cfg.encoder)
# map_location="cpu": this cell never moves the encoder to a GPU and the batch
# below is built from CPU tensors. Loading to CPU also avoids deserialization
# errors when the GPU the checkpoint was saved from is not visible
# (CUDA_VISIBLE_DEVICES is restricted earlier in this notebook).
enc_state_dict = torch.load(
    "/home/regier/bliss/output/new_log_transforms/version_0/checkpoints/best_encoder.ckpt",
    map_location="cpu",
)
enc_state_dict = enc_state_dict["state_dict"]
encoder.load_state_dict(enc_state_dict)
encoder.eval()

batch = {
    "images": true_recon_all.unsqueeze(0),
    "background": torch.from_numpy(dataset[0]["background"]).unsqueeze(0),
}

In [None]:
# encoder as configured (checkerboard setting from cfg), scored against all true sources
with torch.no_grad():
    predictions = encoder.predict_step(batch, 0)
mode_cat, sample_cat = predictions.values()

mode_cat = mode_cat.to_full_catalog()
matching = matcher.match_catalogs(true_cat_all, mode_cat)
metric = metrics(true_cat_all, mode_cat, matching)
tuple(metric[key] for key in ("detection_recall", "detection_precision", "detection_f1"))

In [None]:
# re-render the mode catalog, keeping at most two sources per tile
mode_images, _, _, _ = simulator.image_decoder.render_images(
    mode_cat.to_tile_catalog(2, 2), [(2583, 2, 136)]
)
# origin='lower' so the image has the same orientation as the other image plots
plt.imshow(mode_images[0, 2], origin='lower', cmap='Greys_r')
plt.colorbar();

In [None]:
# independent tiling: disable the checkerboard and score the same batch again
encoder.use_checkerboard = False

with torch.no_grad():
    predictions = encoder.predict_step(batch, 0)
mode_cat, sample_cat = predictions.values()

mode_cat = mode_cat.to_full_catalog()
matching = matcher.match_catalogs(true_cat_all, mode_cat)
metric = metrics(true_cat_all, mode_cat, matching)
tuple(metric[key] for key in ("detection_recall", "detection_precision", "detection_f1"))

In [None]:
### filtered true catalog
# Tile the padded catalog, filter by flux 0.63 (presumably dropping fainter
# sources), then keep at most the two brightest (band=2) sources per tile.
true_tile_cat_pad = (
    true_cat_pad.to_tile_catalog(2, 11)
    .filter_tile_catalog_by_flux(0.63)
    .get_brightest_sources_per_tile(band=2, top_k=2, exclude_num=0)
)

truth_images, _, _, _ = simulator.image_decoder.render_images(true_tile_cat_pad, [(2583, 2, 136)])

In [None]:
from torch.distributions import Normal

# rendered filtered sources + background, then Gaussian noise whose variance
# equals the pixel mean (presumably a normal approximation to Poisson noise)
true_recon_all = truth_images[0] + dataset[0]["background"]
true_recon_all = Normal(true_recon_all, true_recon_all.sqrt()).sample()
# origin='lower' so the image has the same orientation as the other image plots
plt.imshow(true_recon_all[2], origin='lower', cmap='Greys_r')
_ = plt.colorbar()

In [None]:
# encoder input: the noisy image and its background, each with a leading batch dimension
batch = dict(
    images=true_recon_all.unsqueeze(0),
    background=torch.from_numpy(dataset[0]["background"]).unsqueeze(0),
)

In [None]:
# re-enable checkerboard tiling and run the encoder on the filtered image
encoder.use_checkerboard = True

with torch.no_grad():
    # unpacked positionally — assumes predict_step yields the mode catalog
    # first, then the sample catalog; TODO confirm
    mode_cat, sample_cat = encoder.predict_step(batch, 0).values()

# target_cat = true_tile_cat_pad.symmetric_crop(3).to_full_catalog()

mode_cat = mode_cat.to_full_catalog()
# NOTE(review): the image scored here was rendered from the filtered
# `true_tile_cat_pad`, but the metrics compare against the unfiltered
# `true_cat_all`. The commented-out `target_cat` line suggests the filtered
# catalog may have been the intended target — confirm.
matching = matcher.match_catalogs(true_cat_all, mode_cat)
metric = metrics(true_cat_all, mode_cat, matching)
metric["detection_recall"], metric["detection_precision"], metric["detection_f1"]