## Setup

In [None]:
import numpy as np
import gzip

from astropy.utils.data import download_file
from astropy.io import fits
from astropy.table import Table
from astropy.wcs import WCS

import matplotlib.pyplot as plt

from bliss.utils.download_utils import download_file_to_dst

def plot_image(hdu):
    """Show an HDU's image on WCS-aware axes with a strong linear stretch.

    The grey scale runs from the data minimum up to 1% of the data range above
    the minimum, which makes faint structure visible.
    """
    axes = plt.subplot(projection=WCS(hdu.header))
    pixels = hdu.data
    lo = pixels.min()
    hi = lo + (pixels.max() - lo) / 100.
    axes.imshow(pixels, cmap='gray', vmin=lo, vmax=hi)

URLBASE = "https://portal.nersc.gov/cfs/cosmo/data/legacysurvey/dr9"
ra, dec = 336.635, -0.96

survey_bricks_filename = download_file(
    f"{URLBASE}/south/survey-bricks-dr9-south.fits.gz",
    cache=True,
    show_progress=True,
    timeout=120,
)

# Brick-table columns used below:
#   ra1 / ra2   - lower / upper RA boundary of each brick
#   dec1 / dec2 - lower / upper Dec boundary of each brick
survey_bricks = Table.read(survey_bricks_filename)

# Basic (RA, Dec) -> brick lookup: keep bricks whose inclusive RA and Dec ranges
# both contain the target, then take the first match.
in_ra = (survey_bricks["ra1"] <= ra) & (survey_bricks["ra2"] >= ra)
in_dec = (survey_bricks["dec1"] <= dec) & (survey_bricks["dec2"] >= dec)
brickname = survey_bricks[in_ra & in_dec]["brickname"][0]
print(f"Brick for RA, Dec ({ra}, {dec}):", brickname)

## DECaLS Co-added Images

Download the brick's co-added g, r and z images

In [None]:
# Co-added g/r/z images of the brick (cache=False: astropy downloads each run).
image_g_filename = download_file(f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-image-g.fits.fz", cache=False)
image_r_filename = download_file(f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-image-r.fits.fz", cache=False)
image_z_filename = download_file(f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-image-z.fits.fz", cache=False)

In [None]:
# Open the three co-added band images; the HDULists stay open for later cells.
image_g, image_r, image_z = (
    fits.open(path) for path in (image_g_filename, image_r_filename, image_z_filename)
)

In [None]:
# Show the BRICK keyword of extension 1 as a sanity check on the download
# (for a .fits.fz file this is presumably the tile-compressed image HDU).
image_g[1].header["BRICK"]

Inspect SDSS images and catalog

In [None]:
# NOTE(review): hard-coded, user-specific path to a local SDSS g-band frame file
# (presumably run 94, camcol 1, field 12 per the filename); adjust for your setup.
fitsfile = fits.open("/home/zhteoh/871-decals-e2e/data/sdss/94/1/12/frame-g-000094-1-0012.fits")

In [None]:
# Handles to the first four HDUs of the SDSS frame file.
hdu0, hdu1, hdu2, hdu3 = fitsfile[0], fitsfile[1], fitsfile[2], fitsfile[3]

In [None]:
# Quick-look plot of the primary HDU image on its WCS axes.
plot_image(hdu0)

In [None]:
hdu0.data # primary image array; sky-subtracted nmgy/pixel (per the SDSS frame data model)

## DECaLS PSF

Download the brick's Tractor model images and PSF-size (FWHM) maps for each band

In [None]:
# Tractor model images of the brick, one per band (kept in astropy's cache).
brick_model_g_filename = download_file(
    f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-model-g.fits.fz",
    cache=True,
    show_progress=True,
    timeout=120,
)
brick_model_r_filename = download_file(
    f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-model-r.fits.fz",
    cache=True,
    show_progress=True,
    timeout=120,
)
brick_model_z_filename = download_file(
    f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-model-z.fits.fz",
    cache=True,
    show_progress=True,
    timeout=120,
)

In [None]:
# Per-band PSF-size maps of the brick (kept in astropy's cache).
psfsize_g_filename = download_file(
    f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-psfsize-g.fits.fz",
    cache=True,
    show_progress=True,
    timeout=120,
)
psfsize_r_filename = download_file(
    f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-psfsize-r.fits.fz",
    cache=True,
    show_progress=True,
    timeout=120,
)
psfsize_z_filename = download_file(
    f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-psfsize-z.fits.fz",
    cache=True,
    show_progress=True,
    timeout=120,
)

In [None]:
# Open each band's Tractor model image (each handle uses its own band's file).
tractor_g_fitsfile = fits.open(brick_model_g_filename)
tractor_r_fitsfile = fits.open(brick_model_r_filename)
tractor_z_fitsfile = fits.open(brick_model_z_filename)

PSF via params in `ccds-annotated-decam-dr9.fits.gz`

In [None]:
# WARNING: DOWNLOAD TAKES A LONG TIME - 1.71GB
from bliss.utils.download_utils import download_file_to_dst
import gzip

# REPLACE THIS WITH YOUR ${workspaceFolder}
WORKSPACE_FOLDER = "/home/zhteoh/871-decals-e2e"
# The payload is passed through `gzip.decompress`, so the destination is a plain
# `.fits` file, written where the next cell reads it from (`data/decals/`).
# NOTE(review): assumes download_file_to_dst(url, dst, preprocess) writes
# preprocess(downloaded bytes) to `dst` and that `data/decals/` exists; confirm.
dst_filename = f"{WORKSPACE_FOLDER}/data/decals/ccds-annotated-decam-dr9.fits"
ccds_annotated_filename = download_file_to_dst(f"{URLBASE}/ccds-annotated-decam-dr9.fits.gz", 
                                               dst_filename,
                                               gzip.decompress)

In [None]:
# Read the decompressed annotated-CCDs table. The path is built from
# WORKSPACE_FOLDER so only the one "REPLACE THIS" setting needs editing.
ccds_annotated = Table.read(f"{WORKSPACE_FOLDER}/data/decals/ccds-annotated-decam-dr9.fits")

In [None]:
# Keep the identifying columns plus every column whose name starts with
# "psf", "gal" or "gauss"; preview the first five rows.
select_cols = ["ccdname", "ra", "dec"]
psf_prefixes = ("psf", "gal", "gauss")
psf_cols = [
    name
    for name in ccds_annotated.colnames
    if name.startswith(psf_prefixes) or name in select_cols
]
ccds_annotated[psf_cols][:5].show_in_notebook()

Get CCDs used for brick

In [None]:
# Table of the CCDs used for this brick.
brick_ccds_filename = download_file(
    f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-ccds.fits",
    cache=False,
)
brick_ccds = Table.read(brick_ccds_filename)

In [None]:
# CCD-name column of the brick's CCD table.
# NOTE(review): a ccdname names a chip position on the camera, so the same value
# presumably recurs across exposures; on its own it is not a unique CCD id.
ccds_for_brick = brick_ccds["ccdname"]

In [None]:
# Select the rows of ccds_annotated that are the CCDs used for this brick.
# A ccdname (e.g. "S28") only names a chip position, so matching on it alone would
# pull in that chip from every exposure in the survey. Match on the
# (exposure number, ccdname) pair instead.
# NOTE(review): assumes both tables carry an `expnum` column (as in the DR9 CCD
# file documentation); confirm.
def _ccd_key(expnum, ccdname):
    """Return a hashable (exposure number, stripped CCD name) identifier."""
    name = ccdname.decode() if isinstance(ccdname, bytes) else str(ccdname)
    return int(expnum), name.strip()

brick_ccd_keys = {_ccd_key(e, c) for e, c in zip(brick_ccds["expnum"], brick_ccds["ccdname"])}
# Cheap vectorised pre-filter on exposure number (ccds_annotated is very large),
# then an exact pair check on the few surviving rows.
candidate_idx = np.flatnonzero(np.isin(ccds_annotated["expnum"], brick_ccds["expnum"]))
mask = np.zeros(len(ccds_annotated), dtype=bool)
mask[candidate_idx] = [
    _ccd_key(e, c) in brick_ccd_keys
    for e, c in zip(ccds_annotated["expnum"][candidate_idx], ccds_annotated["ccdname"][candidate_idx])
]
ccds_psf_r = ccds_annotated[mask & (ccds_annotated["filter"] == 'r')][psf_cols]

Create PSF model based on these CCDs

In [None]:
# Median of each PSF-shape, PSF-normalisation and depth column over the selected
# r-band CCDs (NaN, with a warning, if no CCDs were selected).
psf_mx2, psf_my2, psf_mxy, psf_a, psf_b, psf_theta, psf_ell = (
    np.median(ccds_psf_r[col])
    for col in ("psf_mx2", "psf_my2", "psf_mxy", "psf_a", "psf_b", "psf_theta", "psf_ell")
)

psfnorm_mean, psfnorm_std = (
    np.median(ccds_psf_r[col]) for col in ("psfnorm_mean", "psfnorm_std")
)

psfdepth, galdepth, gausspsfdepth, gaussgaldepth = (
    np.median(ccds_psf_r[col])
    for col in ("psfdepth", "galdepth", "gausspsfdepth", "gaussgaldepth")
)

Alternatively, use the brick's PSF FWHM (psfsize) maps

In [None]:
# Per-band PSF FWHM (psfsize) maps. These are the same URLs the psfsize cell
# fetches with cache=True, so reading through the astropy cache avoids
# downloading them a second time.
brick_fwhm_g_filename = download_file(f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-psfsize-g.fits.fz", cache=True, show_progress=True, timeout=120)
brick_fwhm_r_filename = download_file(f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-psfsize-r.fits.fz", cache=True, show_progress=True, timeout=120)
brick_fwhm_z_filename = download_file(f"{URLBASE}/south/coadd/{brickname[:3]}/{brickname}/legacysurvey-{brickname}-psfsize-z.fits.fz", cache=True, show_progress=True, timeout=120)

In [None]:
# Open the per-band FWHM maps and summarise the r-band map by its median.
# NOTE(review): if the map holds fill values where there is no coverage, they
# are included in this median; confirm.
brick_fwhm_g, brick_fwhm_r, brick_fwhm_z = (
    fits.open(fn) for fn in (brick_fwhm_g_filename, brick_fwhm_r_filename, brick_fwhm_z_filename)
)

r_fwhm_map = brick_fwhm_r[1].data
psf_fwhm = np.median(r_fwhm_map)

In [None]:
# np.median over the whole map (no axis) yields a NumPy scalar, so this shows ().
psf_fwhm.shape

In [None]:
# DECaLS PSF rendering settings: pixel scale (presumably arcsec/pixel) and PSF
# stamp side length in pixels.
pixel_scale, psf_slen = 0.262, 63

In [None]:
import numpy as np
import torch

def _get_mgrid():
    """Construct the base grid for the PSF."""
    offset = (psf_slen - 1) / 2
    x, y = np.mgrid[-offset : (offset + 1), -offset : (offset + 1)]
    mgrid = torch.tensor(np.dstack((y, x))) / offset
    return mgrid.float()

In [None]:
import galsim
import math

# Create the PSF model from the brick-level median CCD parameters and the median
# psfsize FWHM computed in the cells above.

# Inner Moffat profile
# NOTE(review): `psf_b` is the median of the `psf_b` CCD column (listed next to
# psf_a / psf_theta / psf_ell, so presumably a semi-minor axis, not a Moffat
# slope). galsim also requires beta > 1.1 for an untruncated Moffat. Confirm.
psf_inner = galsim.Moffat(beta=psf_b, fwhm=psf_fwhm)

# Outer profile - Moffat or Power-law
band = 'r'
if band == 'z':
    alpha, beta, weight = 17.650, 1.7, 0.0145
    compact_z_ccds = ['N20', 'S8', 'S10', 'S18', 'S21', 'S27']
    # `ccds_for_brick` is a whole column of names, so `column in list` is
    # ambiguous (it raises ValueError for more than one CCD). Use the compact-CCD
    # parameters only when every CCD contributing to the brick is a compact one.
    # NOTE(review): "every" vs "any" is a modelling choice; confirm.
    brick_ccd_names = np.char.strip(np.asarray(ccds_for_brick).astype(str))
    if brick_ccd_names.size > 0 and np.isin(brick_ccd_names, compact_z_ccds).all():
        alpha, beta, weight = 16, 2.3, 0.0095
    # Create the first Moffat PSF component
    moffat1 = galsim.Moffat(beta=beta, fwhm=2 * alpha * math.sqrt(2**(1/beta) - 1))
    # Create the second Moffat PSF component
    moffat2 = galsim.Moffat(beta=beta, fwhm=2 * alpha * math.sqrt(2**(1/beta) - 1) / weight)
    # Combine the two Moffat components using Moffat weighting
    weighted_moffat = weight * moffat1 + (1 - weight) * moffat2
    psf_outer = weighted_moffat
else:
    assert band in ['g', 'r', 'i']
    grid = _get_mgrid() * (psf_slen - 1) / 2
    # Pixel radius from the stamp centre. The centre pixel has r = 0, where r^-2
    # is infinite and would turn the image (and its flux normalisation) into
    # inf/NaN, so radii are clamped to at least one pixel.
    radii_grid = (grid**2).sum(2).sqrt().clamp(min=1.0)
    halo_amplitude = {'g': 0.00045, 'r': 0.00033, 'i': 0.00033}[band]
    outer = halo_amplitude * radii_grid**(-2)
    psf_outer = galsim.InterpolatedImage(galsim.Image(outer.numpy(), scale=pixel_scale)).withFlux(1.0)

# Combine the inner and outer profiles
# NOTE(review): a wing/halo term is usually *added* to the core (the SDSS
# `_psf_fun` in this notebook sums its terms); here the two are convolved. Confirm intent.
psf_combined = galsim.Convolve([psf_inner, psf_outer])

# Apply ellipticity and position angle to the PSF model.
# NOTE(review): galsim's `e=` is the distortion (a^2-b^2)/(a^2+b^2); check that
# the `psf_ell` column uses that definition and that `psf_theta` is in degrees.
psf_combined = psf_combined.shear(e=psf_ell, beta=psf_theta * galsim.degrees)

# galsim profiles are rendered with drawImage (they are not torch tensors).
psf_image = psf_combined.drawImage(nx=psf_slen, ny=psf_slen, scale=pixel_scale)
psf_image /= psfnorm_mean  # NOTE: cancelled by withFlux(1.0) below

psf = galsim.InterpolatedImage(psf_image).withFlux(1.0)


## SDSS PSF

Inspect SDSS PSF

In [None]:
# SDSS settings: all five band indices, the SDSS pixel scale (presumably
# arcsec/pixel) and the PSF stamp side length in pixels.
bands = list(range(5))
pixel_scale, psf_slen = 0.396, 25

In [None]:
import torch

def _psf_fun(r, sigma1, sigma2, sigmap, beta, b, p0):
    term1 = torch.exp(-(r**2) / (2 * sigma1))
    term2 = b * torch.exp(-(r**2) / (2 * sigma2))
    term3 = p0 * (1 + r**2 / (beta * sigmap)) ** (-beta / 2)
    return (term1 + term2 + term3) / (1 + b + p0)

In [None]:
import numpy as np

def _get_mgrid():
    """Construct the base grid for the PSF."""
    offset = (psf_slen - 1) / 2
    x, y = np.mgrid[-offset : (offset + 1), -offset : (offset + 1)]
    mgrid = torch.tensor(np.dstack((y, x))) / offset
    return mgrid.float()

In [None]:
from einops import rearrange, reduce
import torch

def get_psf(params):
    """Construct PSF images from parameters. This is the main entry point for generating the psf.

    Args:
        params: per-band psf parameters, one row of six values
            ``(sigma1, sigma2, sigmap, beta, b, p0)`` per band, as loaded by
            _get_fit_file_psf_params.

    Returns:
        images (List[InterpolatedImage]): one unit-flux galsim PSF per band, in
        the order of ``params``.

    Raises:
        AssertionError: if any PSF pixel is non-positive, or the stamp is not
            square with an odd side length.
    """
    n_bands = len(params)

    # The radial grid depends only on the module-level stamp size, so build it
    # once. `psf_slen` must not be reassigned in this function; otherwise
    # Python treats it as an unbound local.
    grid = _get_mgrid() * (psf_slen - 1) / 2
    radii_grid = (grid**2).sum(2).sqrt()

    # get psf in each band
    psf = torch.cat([_psf_fun(radii_grid, *params[i]).unsqueeze(0) for i in range(n_bands)])
    assert (psf > 0).all()

    # ensure each band's stamp sums to one
    norm = reduce(psf, "b m k -> b", "sum")
    psf *= rearrange(1 / norm, "b -> b 1 1")

    # check format
    n_psf_bands, stamp_slen, _ = psf.shape
    assert n_psf_bands == n_bands and (stamp_slen % 2) == 1 and stamp_slen == psf.shape[2]

    # convert to galsim images
    psf_np = psf.detach().numpy()
    images = []
    for i in range(n_bands):
        psf_image = galsim.Image(psf_np[i], scale=pixel_scale)
        images.append(galsim.InterpolatedImage(psf_image).withFlux(1.0))

    return images

In [None]:
from typing import Tuple
from pathlib import Path

def _get_fit_file_psf_params(psf_fit_file: str, bands: Tuple[int, ...]):
    """Load psf parameters from fits file.

    See https://data.sdss.org/datamodel/files/PHOTO_REDUX/RERUN/RUN/objcs/CAMCOL/psField.html
    for details on the parameters.

    Args:
        psf_fit_file (str): file to load from
        bands (Tuple[int, ...]): SDSS bands to load

    Returns:
        psf_params: tensor of parameters for each band
    """
    msg = (
        f"{psf_fit_file} does not exist. "
        + "Make sure data files are available for fields specified in config."
    )
    assert Path(psf_fit_file).exists(), msg
    # HDU 6 contains the PSF header (after primary and eigenimages)
    data = fits.open(psf_fit_file, ignore_missing_end=True).pop(6).data
    psf_params = torch.zeros(len(bands), 6)
    for i, band in enumerate(bands):
        sigma1 = data["psf_sigma1"][0][band] ** 2
        sigma2 = data["psf_sigma2"][0][band] ** 2
        sigmap = data["psf_sigmap"][0][band] ** 2
        beta = data["psf_beta"][0][band]
        b = data["psf_b"][0][band]
        p0 = data["psf_p0"][0][band]

        psf_params[i] = torch.tensor([sigma1, sigma2, sigmap, beta, b, p0])

    return psf_params

# NOTE(review): hard-coded, user-specific path to the SDSS psField file matching
# the frame opened earlier (same 94/1/12 directory); adjust for your setup.
psf_fit_file = "/home/zhteoh/871-decals-e2e/data/sdss/94/1/12/psField-000094-1-0012.fits"

In [None]:
# Peek at the raw PSF table (HDU 6) of the psField file. The table is copied
# before the file is closed, so `data` stays usable afterwards.
with fits.open(psf_fit_file, ignore_missing_end=True) as psfield_hdul:
    data = psfield_hdul[6].data.copy()
psf_params = torch.zeros(len(bands), 6)

data["psf_sigma2"][0]

## DECaLS Prior via single-exposure DECam CCDs

In [None]:
# Upper-case alias of the brick found in the setup cell, used in this section.
BRICKNAME = brickname
BRICKNAME

In [None]:
# Re-fetch the brick's CCD table (same file as earlier) and preview two rows.
brick_ccds_filename = download_file(
    f"{URLBASE}/south/coadd/{BRICKNAME[:3]}/{BRICKNAME}/legacysurvey-{BRICKNAME}-ccds.fits",
    cache=False,
    show_progress=True,
    timeout=120,
)
brick_ccds = Table.read(brick_ccds_filename)
brick_ccds.show_in_notebook(display_length=2)

In [None]:
# Chip (CCD) to study, i.e. a given image, and the bands to fetch it in.
# NOTE: make sure images exist for these filters
CCDNAME, BANDS = "S28", ["g", "r", "z"]

In [None]:
brick_ccds_fits = fits.open(brick_ccds_filename)  # NOTE(review): not used in the cells shown

# For each band, find the brick-CCD row for chip CCDNAME and record that
# exposure's image filename with the ".fits.fz" extension removed.
image_basenames_without_ext = {}
matching_rows = []
for band in BANDS:
    candidate_rows = np.where((brick_ccds["ccdname"] == CCDNAME) & (brick_ccds["filter"] == band))[0]
    if len(candidate_rows) == 0:
        raise ValueError(
            f"No CCD named {CCDNAME!r} with filter {band!r} in the CCD table of brick {BRICKNAME}; "
            "choose another CCDNAME or adjust BANDS."
        )
    # The same chip may appear in several exposures of this band; use the first.
    matching_row = candidate_rows[0]
    image_basenames_without_ext[band] = brick_ccds["image_filename"][matching_row].replace(".fits.fz", "")
    matching_rows.append(matching_row)

print(image_basenames_without_ext)
brick_ccds[matching_rows].show_in_notebook()

Get single calibrated images (for each band) for this CCD

In [None]:
from numpy.core.defchararray import find

from pyvo.dal import sia

DEF_ACCESS_URL = "https://datalab.noirlab.edu/sia/calibrated_all"
svc = sia.SIAService(DEF_ACCESS_URL)

# Search for calibrated images at the target position, then, for each band, pick
# the first result whose access URL contains that band's exposure filename stem.
imgTable = svc.search((ra,dec)).to_table()
access_url_strs = imgTable["access_url"].astype(str)
img_access_urls = {} # indexed by band
for b, basename in image_basenames_without_ext.items():
    # Bare filename (text after the last '/'), cut just after the "_<band>" marker.
    stem = basename.split("/")[-1]
    img_filename = stem.split(f"_{b}")[0] + f"_{b}"
    hits = imgTable[find(access_url_strs, img_filename) != -1]
    b_access_url = hits[0]["access_url"]
    img_access_urls[b] = b_access_url

img_access_urls

In [None]:
# Display the full SIA search-result table.
imgTable

In [None]:
# Download single images, one per band, and keep each file's primary-HDU data.
# The files are deliberately left open: the next cells inspect `img_fits`, i.e.
# the file opened in the last loop iteration.
ccd_images = {} # indexed by band
ccd_image_filenames = {} # indexed by band

for band, img_access_url in img_access_urls.items():
    img_fits_filename = download_file(img_access_url, cache=True, show_progress=True, timeout=120)
    ccd_image_filenames[band] = img_fits_filename
    img_fits = fits.open(img_fits_filename)
    ccd_images[band] = img_fits[0].data  # NOTE(review): assumes the image data is in the primary HDU


In [None]:
# Primary header of the image opened last in the download loop above.
img_fits[0].header

In [None]:
# Shape of each single image (shown here for the g band)
ccd_images["g"].shape

In [None]:
# Compare the g-band zero-point keywords in the downloaded image header with the
# `ccdzpt` value in the brick's CCD table. The header is read once via
# fits.getheader, which opens and closes the file.
g_header = fits.getheader(ccd_image_filenames["g"], 0)
print("g-band MAGZERO from downloaded image header:", g_header["MAGZERO"])
print("g-band MAGZPT from downloaded image header:", g_header["MAGZPT"])

print("g-band CCDZPT from brick-ccds:", brick_ccds[np.where((brick_ccds["ccdname"] == CCDNAME) & (brick_ccds["filter"] == "g"))[0][0]]["ccdzpt"])

### DECaLS PSF via single-exposure PSFEx FITS

In [None]:
# Get PSFEx fits, for each band, reduced in place to the single row for CCDNAME.
# This CCD's per-row PSFEx parameters are also copied into the table header.

# Per-row PSFEx parameters to copy into the header (loop-invariant)
params = ["polnaxis", "polzero1", "polzero2", "polscal1", "polscal2", "polname1", "polname2", "polngrp", "polgrp1", "polgrp2", "poldeg1", "psfnaxis", "psfaxis1", "psfaxis2", "psfaxis3", "psf_samp"]

psfex_hdus = {}
for band in BANDS:
    psfex_fits_fn = download_file(f"{URLBASE}/calib/psfex/{image_basenames_without_ext[band]}-psfex.fits", cache=False)

    # Edit the downloaded (cache=False, temporary) file in place; a single
    # handle is used for both reading and updating.
    with fits.open(psfex_fits_fn, mode="update") as hdul:
        bintable = hdul[1]

        # Get `row` corresponding to DECam image (i.e., CCD)
        rows = np.where(bintable.data["ccdname"] == CCDNAME)[0]
        assert len(rows) == 1
        row = rows[0]

        # Add to HDU header
        for param in params:
            bintable.header[param.upper()] = bintable.data[row][param]

        # Keep only this CCD's row. NAXIS1 is the row width in *bytes* and the
        # columns are unchanged, so only NAXIS2 (the row count) needs updating.
        bintable.data = bintable.data[row:row+1]
        bintable.header["NAXIS2"] = 1

        hdul.flush()

    psfex_fits = fits.open(psfex_fits_fn)
    psfex_table_hdu = psfex_fits[1]
    psfex_hdus[band] = psfex_table_hdu

psfex_hdus

In [None]:
# Print each band's PSFEx table column names in sorted order.
for b, hdu in psfex_hdus.items():
    t = Table.read(hdu)
    cols = sorted(t.colnames)
    print(cols)

Create the PSF model, passing in the single-CCD PSFEx table HDU and the path to the matching image file

In [None]:
# NOTE(review): hard-coded, user-specific path (note the different "-des"
# workspace); this r-band image file is passed to DES_PSFEx below.
image_filename_r = "/home/zhteoh/871-decals-e2e-des/data/des/336/3366m010/c4d_170927_025457_ooi_r.fits"

image_r_fits = fits.open(image_filename_r)  # NOTE(review): not used in the cells shown

In [None]:
# Open the r-band single-exposure image downloaded earlier.
# NOTE(review): not used in the cells shown.
ccd_r_fits = fits.open(ccd_image_filenames["r"])

In [None]:
# Local path of the downloaded r-band single-exposure image.
ccd_image_filenames["r"]

In [None]:
import galsim.des

# galsim PSFEx model from the single-row r-band PSFEx table HDU; the image file
# is presumably used by galsim for its WCS.
des_psfex_r = galsim.des.DES_PSFEx(psfex_hdus["r"], image_filename_r)

In [None]:
# Evaluate the PSFEx model at image position (0, 0).
# NOTE(review): (0, 0) is an image corner rather than the CCD centre; confirm the
# intended position and galsim's pixel-coordinate convention.
image_pos = galsim.PositionD(0, 0)
psf_r = des_psfex_r.getPSF(image_pos)

In [None]:
# Display the resulting galsim PSF object.
psf_r

### DECaLS background for CCD

In [None]:
# Get the splinesky (background) table for each band, reduced in place to the
# single row for CCDNAME.
background_hdus = {}
for band in BANDS:
    background_fits_fn = download_file(f"{URLBASE}/calib/sky/{image_basenames_without_ext[band]}-splinesky.fits", cache=False)

    # Edit the downloaded (cache=False, temporary) file in place; a single
    # handle is used for both reading and updating.
    with fits.open(background_fits_fn, mode="update") as hdul:
        bintable = hdul[1]

        # Get `row` corresponding to DECam image (i.e., CCD)
        rows = np.where(bintable.data["ccdname"] == CCDNAME)[0]
        assert len(rows) == 1
        row = rows[0]

        # Keep only this CCD's row. NAXIS1 is the row width in *bytes* and the
        # columns are unchanged, so only NAXIS2 (the row count) needs updating.
        bintable.data = bintable.data[row:row+1]
        bintable.header["NAXIS2"] = 1

        hdul.flush()

    background_fits = fits.open(background_fits_fn)
    background_table_hdu = background_fits[1]
    background_hdus[band] = background_table_hdu

background_hdus

In [None]:
# Read the r-band sky table (reduced to this CCD's row above) and show that row.
bg_r = Table.read(background_hdus["r"])

bg_r[0]

In [None]:
from scipy.interpolate import RectBivariateSpline
from scipy.ndimage import zoom

# Spline sky-model parameters for this CCD in the r band.
splinesky_params = bg_r[0]
gridw = splinesky_params["gridw"]
gridh = splinesky_params["gridh"]
gridvals = splinesky_params["gridvals"]
xgrid = splinesky_params["xgrid"]
ygrid = splinesky_params["ygrid"]
order = splinesky_params["order"]

print("Spline grid shape:", (gridw, gridh))
print("Spline interpolation order:", order)

# Example image data
# NOTE(review): placeholder size, not read from the CCD image; use the actual
# image shape when evaluating the real background.
image_w, image_h = 3600, 3600
image = np.random.rand(image_w, image_h)  # placeholder; not used below

# Spline through the sky values at the knot positions (ygrid, xgrid).
# gridvals is indexed [y, x]. The knots are presumably pixel coordinates; the
# table may also carry x0/y0 offsets, which are not applied here. TODO confirm.
splinesky_x = RectBivariateSpline(ygrid, xgrid, gridvals, kx=order, ky=order)

# Evaluate the spline at every pixel directly. Evaluating at the knot *indices*
# (0..gridw-1) and zooming up would only sample the image corner near the origin.
background_values = splinesky_x(np.arange(image_h), np.arange(image_w))

print("===")
print("B-spline interpolated background shape:", background_values.shape)
background_values