In [None]:
# reusing imports from baseline as well
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import tensorflow as tf
from astropy.io import fits
from reproject import reproject_interp
from scipy.ndimage import gaussian_filter


# additional imports for the U-Net model (tf and np are re-imported; harmless)
import tensorflow as tf
from tensorflow.keras import layers as L, models as M
import numpy as np


In [None]:
# Small U-Net; the input is multichannel (2 channels):
# Chandra counts + Chandra exposure

def build_unet(input_shape=(128, 128, 2)):
    """Assemble a three-level U-Net with a single-channel linear output.

    The encoder applies two 3x3 ReLU convolutions per level (32, 64, 128
    filters), each followed by 2x2 max pooling.  A 256-filter bottleneck
    sits at the coarsest scale, and the decoder mirrors the encoder with
    2x upsampling, a skip concatenation and two more convolutions per level.

    Args:
        input_shape: (height, width, channels) of one input sample.  Because
            the network pools three times and upsamples three times, height
            and width need to be divisible by 8 for the skip concatenations
            to line up.

    Returns:
        A Keras ``Model`` named ``"unet_chandra_to_xmm"``.
    """

    def double_conv(tensor, filters):
        # Two stacked 3x3 same-padded ReLU convolutions with `filters` channels.
        for _ in range(2):
            tensor = L.Conv2D(filters, 3, padding="same", activation="relu")(tensor)
        return tensor

    inputs = L.Input(shape=input_shape)

    # Encoder: remember each pre-pooling feature map for the skip connections.
    skips = []
    x = inputs
    for filters in (32, 64, 128):
        x = double_conv(x, filters)
        skips.append(x)
        x = L.MaxPool2D(2)(x)

    # Bottleneck at the coarsest resolution.
    x = double_conv(x, 256)

    # Decoder: upsample, merge with the matching encoder level, refine.
    for filters, skip in zip((128, 64, 32), reversed(skips)):
        x = L.UpSampling2D(2)(x)
        x = L.Concatenate()([x, skip])
        x = double_conv(x, filters)

    # 1x1 convolution collapses the features to one unbounded output channel.
    outputs = L.Conv2D(1, 1, padding="same", activation="linear")(x)
    return M.Model(inputs, outputs, name="unet_chandra_to_xmm")

# Instantiate the network with the default (128, 128, 2) input shape.
# `unet` is the module-level model that train_step and the evaluation cell use.
unet = build_unet()
unet.summary()


In [None]:
# Build training patches and tf.data pipelines from the aligned arrays
# reproj_ch_img, reproj_ch_exp and xmm_img produced by the baseline notebook.
PATCH = 128  # side length (pixels) of the square training patches
BATCH = 4    # mini-batch size used by the tf.data pipelines

def extract_patches(ch_img, ch_exp, xmm_img, n_patches=256, patch_size=PATCH):
    """Sample random, spatially aligned square patches from three 2-D images.

    The same window is cut from all three images for every sample, so the
    input and target patches stay pixel-aligned.  Positions are drawn with the
    global ``np.random`` state; seed it beforehand for reproducible patches.

    Args:
        ch_img: 2-D array, first input channel.
        ch_exp: 2-D array, second input channel; same shape as ``ch_img``.
        xmm_img: 2-D array, regression target; same shape as ``ch_img``.
        n_patches: Number of patches to draw.  Patches may overlap.
        patch_size: Side length of each square patch in pixels.

    Returns:
        ``(X, Y)`` as float32 arrays of shape
        ``(n_patches, patch_size, patch_size, 2)`` and
        ``(n_patches, patch_size, patch_size, 1)``.

    Raises:
        ValueError: If the three images differ in shape, are not 2-D, or are
            smaller than ``patch_size`` along either axis.
    """
    if not (ch_img.shape == ch_exp.shape == xmm_img.shape):
        raise ValueError(
            "extract_patches expects aligned images of identical shape, got "
            f"{ch_img.shape}, {ch_exp.shape}, {xmm_img.shape}"
        )

    H, W = ch_img.shape
    if patch_size > H or patch_size > W:
        raise ValueError(
            f"patch_size={patch_size} does not fit inside an image of shape {(H, W)}"
        )

    X_list, Y_list = [], []

    for _ in range(n_patches):
        # randint's upper bound is exclusive: the +1 keeps the last valid
        # offset reachable and lets an image of exactly patch_size work.
        i = np.random.randint(0, H - patch_size + 1)
        j = np.random.randint(0, W - patch_size + 1)

        ch_patch  = ch_img[i:i+patch_size, j:j+patch_size]
        exp_patch = ch_exp[i:i+patch_size, j:j+patch_size]
        xmm_patch = xmm_img[i:i+patch_size, j:j+patch_size]

        # stack channels: [ch_img patch, ch_exp patch]
        inp = np.stack([ch_patch, exp_patch], axis=-1)  # (H,W,2)
        X_list.append(inp.astype(np.float32))
        Y_list.append(xmm_patch[..., None].astype(np.float32))  # (H,W,1)

    if not X_list:
        # np.stack rejects an empty list; return well-shaped empty arrays instead.
        return (
            np.empty((0, patch_size, patch_size, 2), dtype=np.float32),
            np.empty((0, patch_size, patch_size, 1), dtype=np.float32),
        )

    X = np.stack(X_list, axis=0)
    Y = np.stack(Y_list, axis=0)
    return X, Y

# Draw 512 aligned patches from the reprojected Chandra image/exposure and the XMM image.
# NOTE(review): reproj_ch_img, reproj_ch_exp and xmm_img are not defined in the
# visible cells — presumably they come from the baseline notebook; confirm they
# exist on a fresh kernel.
X_all, Y_all = extract_patches(
    reproj_ch_img, reproj_ch_exp, xmm_img, n_patches=512, patch_size=PATCH
)
print("All patches:", X_all.shape, Y_all.shape)

# 80/20 train/validation split along the patch axis.
# NOTE(review): patches are sampled at random positions of the same image, so
# validation patches can overlap training patches.
N = len(X_all)
split = int(0.8 * N)
X_train, X_val = np.split(X_all, [split])
Y_train, Y_val = np.split(Y_all, [split])

print("Train:", X_train.shape, Y_train.shape)
print("Val:", X_val.shape, Y_val.shape)

# Training pipeline: reshuffled on every pass, then batched and prefetched.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, Y_train))
    .shuffle(512)
    .batch(BATCH)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation pipeline: fixed order, no shuffling.
val_ds = (
    tf.data.Dataset.from_tensor_slices((X_val, Y_val))
    .batch(BATCH)
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
# physics loss: reconstruction + flux; L1 reconstruction (counts), flux term = absolute difference of total flux per patch, plus an SSIM term
def flux_loss(y_true, y_pred):
    """Mean absolute difference between per-sample summed target and prediction.

    Each sample is summed over axes 1, 2 and 3 (so both tensors must be 4-D
    with the batch on axis 0); the absolute gap between the two sums is then
    averaged over the batch.
    """
    non_batch_axes = [1, 2, 3]
    flux_gap = (
        tf.reduce_sum(y_true, axis=non_batch_axes)
        - tf.reduce_sum(y_pred, axis=non_batch_axes)
    )
    return tf.reduce_mean(tf.abs(flux_gap))

def recon_loss_l1(y_true, y_pred):
    """L1 reconstruction loss: mean absolute element-wise error over all elements."""
    abs_error = tf.abs(y_true - y_pred)
    return tf.reduce_mean(abs_error)

def ssim_loss(y_true, y_pred):
    """Structural dissimilarity: one minus the mean SSIM of the batch.

    ``max_val`` for ``tf.image.ssim`` is taken from the data: the largest
    value in ``y_true`` (over the whole batch) plus 1e-6 so it is never zero.
    NOTE(review): this treats the maximum as the dynamic range, i.e. it
    assumes the targets are non-negative — confirm for the reprojected images.
    """
    dynamic_range = tf.reduce_max(y_true) + 1e-6
    per_image_ssim = tf.image.ssim(y_true, y_pred, max_val=dynamic_range)
    mean_ssim = tf.reduce_mean(per_image_ssim)
    return 1.0 - mean_ssim  # minimize 1-SSIM


In [None]:
# combined
# Weights of the auxiliary loss terms relative to the L1 reconstruction term.
lambda_flux = 1.0
lambda_ssim = 0.1  # small

def total_loss(y_true, y_pred):
    """Combined objective: L1 reconstruction + weighted flux + weighted SSIM terms.

    The weights are read from the module-level ``lambda_flux`` and
    ``lambda_ssim`` each time this Python function runs.
    NOTE(review): when called from the ``@tf.function``-wrapped ``train_step``
    these plain Python floats are presumably frozen into the traced graph, so
    changing them afterwards likely requires re-running the ``train_step``
    cell — confirm before running ablations.
    """
    weighted_flux = lambda_flux * flux_loss(y_true, y_pred)
    weighted_ssim = lambda_ssim * ssim_loss(y_true, y_pred)
    return recon_loss_l1(y_true, y_pred) + weighted_flux + weighted_ssim


In [None]:
# ablation: uncomment to override the loss weights (lambda_flux = 0.0 disables the flux term)
#lambda_flux = 0.0
#lambda_ssim = 0.1

In [None]:
# training step
# Adam optimizer with a fixed learning rate of 2e-4, shared by every training step.
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4)

@tf.function
def train_step(x, y):
    """Run one gradient update of the module-level ``unet`` on a batch.

    Args:
        x: Batch of network inputs.
        y: Batch of matching targets.

    Returns:
        The scalar ``total_loss`` evaluated on this batch before the update.
    """
    weights = unet.trainable_variables
    with tf.GradientTape() as tape:
        # Forward pass in training mode, recorded for differentiation.
        loss = total_loss(y, unet(x, training=True))
    gradients = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    return loss


In [None]:
# training loop
EPOCHS = 5
# Per-epoch mean losses, appended once per epoch.
train_history, val_history = [], []

for epoch in range(EPOCHS):
    # training: one optimizer update per batch; .numpy() pulls each scalar loss to the host
    train_losses = [
        train_step(x_batch, y_batch).numpy() for x_batch, y_batch in train_ds
    ]
    train_loss = float(np.mean(train_losses))

    # validation: forward passes only, no weight update
    val_losses = [
        total_loss(y_batch, unet(x_batch, training=False)).numpy()
        for x_batch, y_batch in val_ds
    ]
    val_loss_mean = float(np.mean(val_losses))

    # Each epoch value is the unweighted mean of the per-batch losses.
    train_history.append(train_loss)
    val_history.append(val_loss_mean)
    print(f"Epoch {epoch+1}/{EPOCHS} - train: {train_loss:.4f}  val: {val_loss_mean:.4f}")

In [None]:
#plt.figure()
#plt.plot(train_history, label="train")
#plt.plot(val_history, label="val")
#plt.xlabel("Epoch")
#plt.ylabel("Loss")
#plt.legend()
#plt.title("U-Net training vs validation loss")
#plt.show()

In [None]:
# Compare the baseline forward model and the U-Net on one fixed window of the full image.
# NOTE(review): training patches are sampled from this same image, so the
# window below may overlap data the network was trained on.
i0, j0 = 100, 100  # choose something inside
patch = PATCH

# One shared window so every map is cropped identically.
window = (slice(i0, i0 + patch), slice(j0, j0 + patch))
ch_patch = reproj_ch_img[window]
exp_patch = reproj_ch_exp[window]
xmm_patch = xmm_img[window]

# Baseline prediction on the window.
# NOTE(review): forward_model_baseline, compute_metrics and xmm_exp are not
# defined in the visible cells — presumably from the baseline notebook; confirm.
baseline_patch = forward_model_baseline(
    ch_patch,
    exp_patch,
    xmm_exp[window],
    sigma_px=4.0,
    poisson_scale=100.0,
    add_poisson=False,
)

# U-Net prediction: stack the two channels, add a batch axis, drop batch/channel afterwards.
inp = np.stack((ch_patch, exp_patch), axis=-1).astype(np.float32)[np.newaxis]  # (1,H,W,2)
unet_pred = unet(inp, training=False).numpy()[0, :, :, 0]  # (H,W)

# metrics (baseline compute_metrics)
psnr_b, ssim_b, flux_t, flux_b, ferr_b = compute_metrics(baseline_patch, xmm_patch)
psnr_u, ssim_u, _, flux_u, ferr_u = compute_metrics(unet_pred, xmm_patch)

# Same report layout for both predictions.
for row_label, row_psnr, row_ssim, row_ferr in (
    ("Baseline vs XMM:", psnr_b, ssim_b, ferr_b),
    ("U-Net vs XMM:", psnr_u, ssim_u, ferr_u),
):
    print(row_label)
    print(f"  PSNR: {row_psnr:.3f} dB, SSIM: {row_ssim:.3f}, flux err: {row_ferr:.2%}")


In [None]:
#vmin = np.percentile(xmm_patch, 5)
#vmax = np.percentile(xmm_patch, 99)

#fig, ax = plt.subplots(1, 4, figsize=(16, 4))
#for a in ax: a.set_axis_off()

#ax[0].imshow(xmm_patch, origin="lower", vmin=vmin, vmax=vmax)
#ax[0].set_title("True XMM patch")

#ax[1].imshow(baseline_patch, origin="lower", vmin=vmin, vmax=vmax)
#ax[1].set_title("Baseline")

#ax[2].imshow(unet_pred, origin="lower", vmin=vmin, vmax=vmax)
#ax[2].set_title("U-Net")

#resid_unet = xmm_patch - unet_pred
#im = ax[3].imshow(resid_unet, origin="lower")
#ax[3].set_title("Residual (XMM - U-Net)")

#plt.tight_layout()
#plt.show()
