In [None]:
# imports
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import tensorflow as tf
from tensorflow.keras import layers as L, models as M
from astropy.io import fits
from reproject import reproject_interp
from scipy.ndimage import gaussian_filter
# !pip install astropy reproject scipy tensorflow

In [None]:
# basic FITS IO
def read_fits(path, ext=0):
    """Read one image HDU from a FITS file.

    Parameters
    ----------
    path : str or Path
        Location of the FITS file.
    ext : int or str, optional
        Index (or name) of the HDU to read. Defaults to 0, the primary HDU,
        which matches the previous hard-coded behaviour.

    Returns
    -------
    data : np.ndarray
        Pixel values as float32 with NaNs replaced by 0.0. Note that
        ``np.nan_to_num`` also maps +/-inf to the largest finite values.
    hdr : astropy.io.fits.Header
        Header of the same HDU.

    Raises
    ------
    ValueError
        If the selected HDU holds no data array (e.g. an empty primary HDU).
    """
    path = Path(path)
    with fits.open(path) as hdul:
        hdu = hdul[ext]
        if hdu.data is None:
            # Fail with a readable message instead of `None.astype(...)`.
            raise ValueError(f"{path}: HDU {ext!r} contains no image data")
        # astype() copies, so the array stays usable after the file is closed.
        data = hdu.data.astype(np.float32)
        hdr = hdu.header
    data = np.nan_to_num(data, nan=0.0)
    return data, hdr

In [None]:
# Matched-field inputs: one counts image and one exposure map per telescope.
chandra_img_path = "data/chandra/field1_img.fits"
chandra_exp_path = "data/chandra/field1_expmap.fits"
xmm_img_path = "data/xmm/field1_img.fits"
xmm_exp_path = "data/xmm/field1_expmap.fits"

# Counts images keep their headers (used later for the WCS reprojection).
ch_img, ch_hdr = read_fits(chandra_img_path)
xmm_img, xmm_hdr = read_fits(xmm_img_path)

# Exposure maps: only the pixel data is kept, the header goes to `_`.
ch_exp, _ = read_fits(chandra_exp_path)
xmm_exp, _ = read_fits(xmm_exp_path)

# Quick shape report for each image / exposure-map pair.
for label, img, exp in (
    ("Chandra img/exp:", ch_img, ch_exp),
    ("XMM img/exp:", xmm_img, xmm_exp),
):
    print(label, img.shape, exp.shape)



In [None]:
# WCS reprojection
# Reproject the Chandra image and exposure map onto the XMM pixel grid.
#
# Each array is paired with the header of the file it was read from. The
# exposure-map header was discarded when the data was loaded, so it is re-read
# here; previously the *image* header was reused for the exposure map, which is
# only correct when both files share exactly the same WCS and shape.
with fits.open(chandra_exp_path) as hdul:
    ch_exp_hdr = hdul[0].header

# NOTE(review): reproject_interp interpolates pixel values and is not
# flux-conserving. Confirm whether the counts image needs a pixel-area rescale
# (or a flux-conserving reprojection) before total fluxes are compared.
reproj_ch_img, _  = reproject_interp((ch_img, ch_hdr), xmm_hdr)
reproj_ch_exp, _  = reproject_interp((ch_exp, ch_exp_hdr), xmm_hdr)

# Pixels outside the Chandra footprint come back as NaN; treat them as empty.
reproj_ch_img = np.nan_to_num(reproj_ch_img, nan=0.0).astype(np.float32)
reproj_ch_exp = np.nan_to_num(reproj_ch_exp, nan=0.0).astype(np.float32)

xmm_img  = np.nan_to_num(xmm_img,  nan=0.0).astype(np.float32)
xmm_exp  = np.nan_to_num(xmm_exp,  nan=0.0).astype(np.float32)

print("Reprojected Chandra:", reproj_ch_img.shape, reproj_ch_exp.shape)
print("XMM:", xmm_img.shape, xmm_exp.shape)


In [None]:
# Forward-model baseline: exposure correction + Gaussian PSF blur + optional Poisson noise

def forward_model_baseline(ch_counts, ch_exp, xmm_exp,
                           sigma_px=4.0, poisson_scale=100.0, add_poisson=False,
                           rng=None):
    """
    Predict an XMM-like image from a Chandra image on the same pixel grid.

    Steps: scale the Chandra counts by the per-pixel exposure ratio
    (XMM / Chandra), blur with a Gaussian as a rough stand-in for the XMM PSF,
    clip negatives, and optionally add Poisson noise.

    ch_counts:     Chandra image on XMM grid
    ch_exp:        Chandra expmap on XMM grid
    xmm_exp:       XMM expmap on XMM grid
    sigma_px:      PSF Gaussian sigma in pixels (must be >= 0)
    poisson_scale: factor applied before the Poisson draw and divided out
                   afterwards (must be > 0 when add_poisson is True)
    add_poisson:   if True, return a noisy realisation instead of the mean
    rng:           optional seed or np.random.Generator for a reproducible
                   noise draw; None keeps the legacy global np.random state

    Returns a float32 array. Pixels where either exposure map is not strictly
    positive get an exposure ratio of 0, so they contribute no counts.

    Raises ValueError for a negative sigma_px, or a non-positive poisson_scale
    when add_poisson is True.
    """
    if np.any(np.asarray(sigma_px) < 0):
        raise ValueError("sigma_px must be non-negative")
    if add_poisson and not poisson_scale > 0:
        raise ValueError("poisson_scale must be positive when add_poisson=True")

    ch_counts = np.asarray(ch_counts)
    ch_exp = np.asarray(ch_exp)
    xmm_exp = np.asarray(xmm_exp)

    # Rough exposure correction: scale to XMM exposure. The division is only
    # evaluated where both maps are positive; everything else stays at 0.
    valid = (ch_exp > 0) & (xmm_exp > 0)
    ratio = np.zeros(valid.shape, dtype=np.float32)
    np.divide(xmm_exp, ch_exp, out=ratio, where=valid)
    # A tiny Chandra exposure can overflow float32; drop those pixels as well.
    ratio[~np.isfinite(ratio)] = 0.0

    ch_scaled = ch_counts * ratio

    # PSF blur (approximate XMM PSF)
    blurred = gaussian_filter(ch_scaled, sigma=sigma_px)
    blurred = np.clip(blurred, 0, None).astype(np.float32)

    if not add_poisson:
        return blurred

    lam = blurred * poisson_scale
    if rng is None:
        # Legacy path: honours any np.random.seed() set by the caller.
        draws = np.random.poisson(lam)
    else:
        draws = np.random.default_rng(rng).poisson(lam)

    return draws.astype(np.float32) / poisson_scale

# Noise-free baseline at a 4 px Gaussian PSF width.
baseline = forward_model_baseline(
    reproj_ch_img,
    reproj_ch_exp,
    xmm_exp,
    sigma_px=4.0,
    poisson_scale=100.0,
    add_poisson=False,
)
print("Baseline shape:", baseline.shape)



In [None]:
# Image-quality (PSNR, SSIM) and total-flux metrics
def compute_metrics(pred, true):
    """Compare a predicted image with a reference image.

    Both arrays get a leading and a trailing singleton axis before being
    handed to ``tf.image.psnr`` / ``tf.image.ssim``; the dynamic range passed
    as ``max_val`` is ``max(true) + 1e-6``.
    Assumes `pred` and `true` are 2-D arrays of equal shape -- TODO confirm.

    Returns
    -------
    (psnr, ssim, flux_true, flux_pred, rel_flux_err)
        ``flux_*`` are the plain pixel sums, and ``rel_flux_err`` is
        ``(flux_pred - flux_true) / (flux_true + 1e-8)``.
    """
    def _as_batch(image):
        # (H, W) -> (1, H, W, 1) float32 tensor
        return tf.convert_to_tensor(image[None, ..., None], tf.float32)

    pred_batch = _as_batch(pred)
    true_batch = _as_batch(true)

    dynamic_range = float(np.max(true) + 1e-6)
    psnr_value = tf.image.psnr(pred_batch, true_batch, max_val=dynamic_range)[0].numpy()
    ssim_value = tf.image.ssim(pred_batch, true_batch, max_val=dynamic_range)[0].numpy()

    total_true = float(np.sum(true))
    total_pred = float(np.sum(pred))
    # Small epsilon keeps the ratio defined when the reference flux is zero.
    relative_error = (total_pred - total_true) / (total_true + 1e-8)

    return psnr_value, ssim_value, total_true, total_pred, relative_error

# Score the noise-free baseline against the observed XMM image and report it.
baseline_metrics = compute_metrics(baseline, xmm_img)
psnr_b, ssim_b, flux_t, flux_b, ferr_b = baseline_metrics

print(
    f"Baseline PSNR: {psnr_b:.3f} dB",
    f"Baseline SSIM: {ssim_b:.3f}",
    f"True flux: {flux_t:.3e}",
    f"Baseline flux: {flux_b:.3e}",
    f"Baseline relative flux error: {ferr_b:.3%}",
    sep="\n",
)


In [None]:
# PSF-width sweep: rebuild the noise-free baseline at each Gaussian sigma and
# score it against the XMM image.
sigmas = [2.0, 4.0, 6.0]
sweep_kwargs = dict(poisson_scale=100.0, add_poisson=False)
results = []

for s in sigmas:
    base = forward_model_baseline(reproj_ch_img, reproj_ch_exp, xmm_exp,
                                  sigma_px=s, **sweep_kwargs)
    psnr_s, ssim_s, flux_t_s, flux_b_s, ferr_s = compute_metrics(base, xmm_img)
    # Only sigma, PSNR, SSIM and the relative flux error are kept per row.
    results.append((s, psnr_s, ssim_s, ferr_s))

print("Baseline performance vs PSF sigma:")
for s, psnr_s, ssim_s, ferr_s in results:
    summary_line = (f"  sigma={s:.1f}: PSNR={psnr_s:.3f} dB, "
                    f"SSIM={ssim_s:.3f}, flux err={ferr_s:.2%}")
    print(summary_line)


In [None]:
# Comparison figure: reprojected Chandra, observed XMM, baseline model, residual.
# The first three panels share a display range taken from the XMM image.
vmin = np.percentile(xmm_img, 5)
vmax = np.percentile(xmm_img, 99)
if not vmax > vmin:
    # When the 5th and 99th percentiles coincide (e.g. a mostly-zero image),
    # every panel would render as a flat colour; fall back to the full range.
    vmin, vmax = float(np.min(xmm_img)), float(np.max(xmm_img))
    if not vmax > vmin:
        vmax = vmin + 1.0

fig, ax = plt.subplots(1, 4, figsize=(16, 4))
for a in ax:
    a.set_axis_off()

ax[0].imshow(reproj_ch_img, origin="lower", vmin=vmin, vmax=vmax)
ax[0].set_title("Chandra (reproj)")

ax[1].imshow(xmm_img, origin="lower", vmin=vmin, vmax=vmax)
ax[1].set_title("True XMM")

ax[2].imshow(baseline, origin="lower", vmin=vmin, vmax=vmax)
ax[2].set_title("Forward-model baseline")

# The residual uses its own colour scale, so attach a colourbar to make the
# magnitude readable.
resid = xmm_img - baseline
im = ax[3].imshow(resid, origin="lower")
ax[3].set_title("Residual (XMM - model)")
fig.colorbar(im, ax=ax[3], fraction=0.046, pad=0.04)

plt.tight_layout()
Path("figures").mkdir(exist_ok=True, parents=True)
plt.savefig("figures/forward_baseline_triptych_field1.png", dpi=250)
plt.show()
