In [None]:
import numpy as np
import torch
from pathlib import Path
from astropy.io import fits
from matplotlib import pyplot as plt

from bliss.surveys.sdss import column_to_tensor, SDSSDownloader

# Load Fluxes from SDSS Catalogs

In [None]:
from tqdm import tqdm

def load_color_distribution(sdss_fields, sdss_dir, ref_band=2):
    """Collect reference-band fluxes of SDSS stars and galaxies from photoObj catalogs.

    For every (run, camcol, field) the photoObj catalog is read from
    ``<sdss_dir>/<run>/<camcol>/<field>/photoObj-RRRRRR-C-FFFF.fits``; if the file is
    missing it is first downloaded with ``SDSSDownloader``. An object is kept when
    its ``thing_id`` is not -1, its ``objc_type`` is 6 (star) or 3 (galaxy), and its
    flux is strictly positive in every band. Stars use ``psfflux``; galaxies use
    ``cmodelflux``.

    Args:
        sdss_fields: iterable of (run, camcol, field) integer triples.
        sdss_dir: root directory of the local SDSS data; missing catalogs are
            downloaded into it.
        ref_band: column of the flux arrays to return. The default, 2, is the r-band.

    Returns:
        (star_fluxes, gal_fluxes): 1-D tensors of ``ref_band`` fluxes, one entry per
        selected object, in field order and then catalog-row order.

    Raises:
        FileNotFoundError: a photoObj file is still missing after the download attempt.
        ValueError: no star or no galaxy passed the selection (e.g. empty ``sdss_fields``).
    """
    sdss_path = Path(sdss_dir)  # loop-invariant
    star_chunks = []
    gal_chunks = []

    for run, camcol, field in tqdm(sdss_fields):
        # Set photoObj file path
        # NOTE: This is the necessary directory structure!
        field_dir = sdss_path / str(run) / str(camcol) / str(field)
        po_path = field_dir / f"photoObj-{run:06d}-{camcol:d}-{field:04d}.fits"

        if not po_path.exists():
            rcf = (run, camcol, field)
            SDSSDownloader(rcf, str(sdss_path)).download_catalog(rcf)
        # explicit raise instead of `assert`, which is skipped under `python -O`
        if not po_path.exists():
            raise FileNotFoundError(
                f"{po_path} does not exist. "
                + "Make sure data files are available for fields specified in config."
            )
        po_fits = fits.getdata(po_path)

        # retrieve object-specific information used for the selection
        objc_type = column_to_tensor(po_fits, "objc_type").numpy()
        thing_id = column_to_tensor(po_fits, "thing_id").numpy()

        # objc_type 3 -> galaxy, 6 -> star; objects with thing_id == -1 are skipped
        galaxy_bools = torch.as_tensor((objc_type == 3) & (thing_id != -1))
        star_bools = torch.as_tensor((objc_type == 6) & (thing_id != -1))

        # stars take PSF fluxes, galaxies cmodel fluxes; every other row is zeroed out
        star_fluxes = column_to_tensor(po_fits, "psfflux") * star_bools.reshape(-1, 1)
        gal_fluxes = column_to_tensor(po_fits, "cmodelflux") * galaxy_bools.reshape(-1, 1)
        fluxes = star_fluxes + gal_fluxes

        # vectorized form of the per-object check: strictly positive flux in all bands
        all_positive = (fluxes > 0).all(dim=1)
        star_chunks.append(fluxes[star_bools & all_positive, ref_band])
        gal_chunks.append(fluxes[galaxy_bools & all_positive, ref_band])

    star_ref = torch.cat(star_chunks) if star_chunks else torch.empty(0)
    gal_ref = torch.cat(gal_chunks) if gal_chunks else torch.empty(0)
    if star_ref.numel() == 0 or gal_ref.numel() == 0:
        raise ValueError(
            "No stars or no galaxies passed the selection; check `sdss_fields` and the catalogs."
        )

    return star_ref, gal_ref

In [None]:
# SDSS fields to load: (run, camcol, first field, stop field); every 10th field in [first, stop)
field_ranges = [
    (94, 1, 12, 492),
    (125, 1, 15, 575),
    (752, 1, 15, 695),
    (3900, 6, 16, 606),
]
fields = [
    (run, camcol, f)
    for run, camcol, first, stop in field_ranges
    for f in np.arange(first, stop, 10, dtype=int)
]
# root of the local SDSS data; missing photoObj files get downloaded here
SDSS_DIR = "/data/scratch/sdss"
star_fluxes, gal_fluxes = load_color_distribution(sdss_fields=fields, sdss_dir=SDSS_DIR)

# Determine Reference-Band Star Flux Prior

In [None]:
from bliss.catalog import convert_nmgy_to_mag
from scipy.stats import pareto

# Restrict the fit to stars with magnitude below 23 (tune this cut; see the top-k check below).
star_mags = convert_nmgy_to_mag(star_fluxes)
stars_to_use = star_mags < 23.0
star_fluxes_to_use = star_fluxes[stars_to_use]

# scipy's `pareto.fit` returns (shape, loc, scale)
star_pareto_params = pareto.fit(star_fluxes_to_use)
star_alpha, star_loc, star_scale = star_pareto_params

star_scale, star_alpha, star_loc

In [None]:
from scipy.stats import truncpareto
# truncated-Pareto fit; scipy returns (exponent b, truncation c, loc, scale)
star_tp_params = truncpareto.fit(star_fluxes_to_use)
star_exponent, star_truncation, star_loc_tp, star_scale_tp = star_tp_params

(star_exponent, star_truncation, star_loc_tp, star_scale_tp)

In [None]:
# Compare the fitted Pareto / truncated-Pareto densities with the star fluxes used in the fit.
star_log_flux = star_fluxes_to_use.log10()
# np.logspace takes exponents; scaling the upper exponent by 1.05 extends the grid past the
# largest flux when its log10 is positive
x = np.logspace(star_log_flux.min().item(), 1.05 * star_log_flux.max().item(), num=200)

pdf_vals = pareto.pdf(x, star_alpha, star_loc, star_scale)
pdf_vals_trunc = truncpareto.pdf(x, star_exponent, star_truncation, star_loc_tp, star_scale_tp)

fig, ax = plt.subplots()
ax.plot(x, pdf_vals, "r-", lw=5, alpha=0.6, label="pareto pdf")
ax.plot(x, pdf_vals_trunc, "g.", lw=5, alpha=0.6, label="truncpareto pdf")
ax.hist(star_fluxes_to_use, log=True, bins=200, label="star fluxes histogram", density=True)
ax.set(xlabel="r-band flux (nmgy)", ylabel="density", title="Star r-band flux prior fits")
ax.legend();

In [None]:
from torch.distributions import Pareto
# Draw as many samples as stars used in the fit and compare the k largest of each.
# torch's Pareto(scale, alpha) has support [scale, inf); adding loc matches scipy's
# pareto(alpha, loc, scale) parametrization.
torch.manual_seed(0)  # reproducible draws, so the top-k comparison is stable across runs
n = star_fluxes_to_use.size(0)
samples = Pareto(star_scale, star_alpha).sample((n,)) + star_loc

# tune `stars_to_use` to make these distributions match for various `k`
k = 5
star_fluxes.topk(k, largest=True)[0], samples.topk(k, largest=True)[0]

# Determine Reference-Band Galaxy Flux Prior

In [None]:
# Same treatment as stars: keep galaxies with magnitude below 23 and fit a Pareto.
gal_mags = convert_nmgy_to_mag(gal_fluxes)
gals_to_use = gal_mags < 23.0
gal_fluxes_to_use = gal_fluxes[gals_to_use]

gal_pareto_params = pareto.fit(gal_fluxes_to_use)  # (shape, loc, scale)
gal_alpha, gal_loc, gal_scale = gal_pareto_params
(gal_alpha, gal_loc, gal_scale)

In [None]:
# truncated-Pareto fit; scipy returns (exponent b, truncation c, loc, scale)
gal_tp_params = truncpareto.fit(gal_fluxes_to_use)
gal_exponent, gal_trunc, gal_loc_tp, gal_scale_tp = gal_tp_params
(gal_exponent, gal_trunc, gal_loc_tp, gal_scale_tp)

In [None]:
# Compare fitted galaxy densities with the fluxes used in the fit. The grid spans
# `gal_fluxes_to_use` (matching the histogram and the star plot), not all galaxies.
gal_log_flux = gal_fluxes_to_use.log10()
x = np.logspace(gal_log_flux.min().item(), 1.05 * gal_log_flux.max().item(), num=200)

pdf_vals = pareto.pdf(x, gal_alpha, gal_loc, gal_scale)
pdf_vals_trunc = truncpareto.pdf(x, gal_exponent, gal_trunc, gal_loc_tp, gal_scale_tp)

fig, ax = plt.subplots()
ax.plot(x, pdf_vals, "r-", lw=5, alpha=0.6, label="pareto pdf")
ax.plot(x, pdf_vals_trunc, "g.", lw=5, alpha=0.6, label="truncpareto pdf")
ax.hist(gal_fluxes_to_use, log=True, bins=200, label="galaxy fluxes histogram", density=True)
ax.set(xlabel="r-band flux (nmgy)", ylabel="density", title="Galaxy r-band flux prior fits")
ax.legend();

In [None]:
# Sample as many galaxies as were used in the fit, as done for stars. Drawing
# `gal_fluxes.size(0)` samples would also count the faint galaxies excluded from the fit
# and distort the extreme order statistics.
torch.manual_seed(0)  # reproducible draws
n = gal_fluxes_to_use.size(0)
sample = Pareto(gal_scale, gal_alpha).sample((n,)) + gal_loc

# smallest draws: these should sit near the fitted lower bound loc + scale
sample.topk(10, largest=False)[0]

In [None]:
# the ten largest reference-band galaxy fluxes in the loaded catalogs
gal_fluxes.topk(k=10, largest=True).values

In [None]:
from bliss.catalog import convert_mag_to_nmgy

# Shift loc downward so that the galaxy prior's lower support bound (loc + scale for scipy's
# Pareto) lands at 23 mag, to simulate galaxies a bit below the detection threshold.
# (Stars are not shifted: their fitted bound is said to be at 23 mag already — TODO confirm
# from the star fit output.)
min_flux = gal_scale + gal_loc
threshold_flux = convert_mag_to_nmgy(23.0)
gal_loc2 = gal_loc - (min_flux - threshold_flux)


In [None]:
# Overlay the shifted galaxy prior on the fluxes of all loaded galaxies.
# NOTE: the histogram also contains galaxies fainter than 23 mag, below the prior's support
# (loc2 + scale sits at 23 mag), so the two densities are not normalized over the same range.
gal_log_flux_all = gal_fluxes.log10()
x = np.logspace(gal_log_flux_all.min().item(), gal_log_flux_all.max().item(), num=1000)
pdf_vals = pareto.pdf(x, gal_alpha, gal_loc2, gal_scale)

fig, ax = plt.subplots()
ax.plot(x, pdf_vals, "r-", lw=5, alpha=0.6, label="pareto pdf")
ax.hist(gal_fluxes, log=True, bins=200, label="galaxy fluxes histogram", density=True)
ax.set(xlabel="r-band flux (nmgy)", ylabel="density", title="Shifted galaxy flux prior")
ax.legend();

In [None]:
# Resample from the shifted galaxy prior, drawing as many samples as galaxies used in the
# fit, and compare the largest draws with the largest observed fluxes (`gal_fluxes.topk`).
torch.manual_seed(0)  # reproducible draws
n = gal_fluxes_to_use.size(0)
sample = Pareto(gal_scale, gal_alpha).sample((n,)) + gal_loc2

sample.topk(10, largest=True)[0]

In [None]:
# final galaxy prior parameters: (scale, shape, shifted loc)
(gal_scale, gal_alpha, gal_loc2)

# Determine Mean Number of Sources per Field

In [None]:
# number of SDSS fields whose catalogs were loaded
n_fields = len(fields)
n_fields

In [None]:
# Sources brighter than this magnitude are treated as reliably cataloged; their counts are
# extrapolated to the whole population further below using the fitted priors.
easy_detection_threshold = 21.5
n_easy_stars, n_easy_gals = (
    (mags < easy_detection_threshold).sum() for mags in (star_mags, gal_mags)
)
n_easy_stars, n_easy_gals


In [None]:
# flux (nmgy) corresponding to the easy-detection magnitude threshold
x = convert_mag_to_nmgy(easy_detection_threshold)

In [None]:
# Fraction of each fitted (truncated-Pareto) population brighter than the easy-detection flux.
# `sf` is scipy's survival function (1 - cdf); using it avoids a hand-rolled complement.
star_easy_prop = truncpareto.sf(x, star_exponent, star_truncation, star_loc_tp, star_scale_tp)
gal_easy_prop = truncpareto.sf(x, gal_exponent, gal_trunc, gal_loc_tp, gal_scale_tp)
star_easy_prop, gal_easy_prop

In [None]:
# Extrapolate the easy-detection counts to the full population: count / P(easy).
implied_star_count = n_easy_stars / star_easy_prop
implied_gal_count = n_easy_gals / gal_easy_prop
implied_n_sources = implied_star_count + implied_gal_count
implied_n_sources

In [None]:
# share of the implied population that are galaxies
implied_gal_count = n_easy_gals / gal_easy_prop
implied_prop_galaxy = implied_gal_count / implied_n_sources
implied_prop_galaxy

In [None]:
# implied sources, averaged over the loaded fields
mean_sources_per_field = implied_n_sources / n_fields
mean_sources_per_field

In [None]:
# More direct calculation, but it assumes the SDSS catalog is accurate: simply count every
# star and galaxy that passed the selection in `load_color_distribution`.
n_cataloged = len(star_fluxes) + len(gal_fluxes)
n_cataloged / n_fields

In [None]:
# SDSS field size in pixels and the side length of a square tile (pixels, assumed)
field_shape = (2048, 1489)
tile_side = 4
tiles_per_field = (field_shape[0] * field_shape[1]) / (tile_side * tile_side)
tiles_per_field

In [None]:
# average number of sources expected in a single tile
mean_sources_per_tile = mean_sources_per_field / tiles_per_field
mean_sources_per_tile