In [40]:
import os

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import stpsf
from astropy.io import fits
from scipy.fftpack import fft2, ifft2, fftshift
from scipy.signal import fftconvolve
from spectral_cube import SpectralCube


In [41]:
from glob import glob
# Smooth every cube matched by FILE_GLOB with one JWST/MIRI MRS PSF evaluated
# at TARGET_LAMBDA, and write the result to OUT_DIR as "<name>_smooth.fits".
#
# The inputs are line/continuum cubes from a separate CASA-based extraction
# step (that code is not part of this notebook yet). Any line cube can be used
# instead, as long as the Gaussian-smoothed products it will be compared with
# were smoothed to the same target wavelength as the JWST PSF used here.

# ---------- USER PARAMETERS ----------
IN_DIR  = "../../github/astr796_25_V2/NGC253_output/3channel/casa/atomic"
OUT_DIR = "../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_smooth"
REF_CUBE = "../../github/astr796_25_V2/NGC253_output/3channel/casa/atomic/NGC253_PIII_17.9.line.fits"
TARGET_LAMBDA = 17.8846  # micron
FILE_GLOB = os.path.join(IN_DIR, "*.fits")
# ------------------------------------

os.makedirs(OUT_DIR, exist_ok=True)
files = sorted(glob(FILE_GLOB))

# --- Load reference cube to get pixel scale ---
with fits.open(REF_CUBE) as ref:
    hdr_ref = ref[0].header
    # CDELT2 is multiplied by 3600 to go from (presumably) degrees to arcsec
    # per pixel; abs() guards against a negative axis increment.
    pixscale = abs(hdr_ref["CDELT2"]) * 3600.0  # arcsec/pixel

# --- Compute PSF kernel once ---
# Order matters: mode and band are set before the pixel scale is overridden.
miri = stpsf.MIRI()
miri.mode = 'IFU'
miri.band = '3C'
miri.pixelscale = pixscale

psf = miri.calc_psf(monochromatic=TARGET_LAMBDA * u.micron)
psf_kernel = psf[1].data.astype(float)
psf_kernel /= psf_kernel.sum()  # unit sum so the convolution conserves flux

# --- Process each cube ---
for infile in files:
    print(f"Processing: {os.path.basename(infile)}")

    # Read data and header from the SAME HDU in one call. The previous
    # `hdu=1` keyword is not an astropy selector (those are ext/extname/
    # extver), so the two separate look-ups could resolve to different HDUs
    # for files whose primary HDU carries no data.
    flux, hdr = fits.getdata(infile, header=True)

    nlam = flux.shape[0]
    sm_flux = np.zeros_like(flux)

    for i in range(nlam):
        plane = flux[i]
        blank = np.isnan(plane)  # pixels with no data in the input

        # NaNs are zero-filled for the FFT convolution, then restored so the
        # footprint of the cube is unchanged.
        sm = fftconvolve(np.nan_to_num(plane, nan=0.0), psf_kernel, mode='same')
        sm[blank] = np.nan
        sm_flux[i] = sm

    outname = os.path.join(OUT_DIR, os.path.basename(infile).replace(".fits", "_smooth.fits"))
    fits.writeto(outname, sm_flux, hdr, overwrite=True)
    print(f" â†’ Saved: {outname}")

print("All cubes smoothed successfully.")


Processing: NGC253_ArIII_9.0.cont.fits
 â†’ Saved: ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_smooth/NGC253_ArIII_9.0.cont_smooth.fits
Processing: NGC253_ArIII_9.0.line.fits
 â†’ Saved: ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_smooth/NGC253_ArIII_9.0.line_smooth.fits
Processing: NGC253_ArII_7.0.cont.fits
 â†’ Saved: ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_smooth/NGC253_ArII_7.0.cont_smooth.fits
Processing: NGC253_ArII_7.0.line.fits
 â†’ Saved: ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_smooth/NGC253_ArII_7.0.line_smooth.fits
Processing: NGC253_C2H2_13.7.cont.fits
 â†’ Saved: ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_smooth/NGC253_C2H2_13.7.cont_smooth.fits
Processing: NGC253_C2H2_13.7.line.fits
 â†’ Saved: ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_smooth/NGC253_C2H2_13.7.line_smooth.fits
Processing: NGC253_CO2_15.0.cont.fits
 â†’ Saved: 

In [42]:
## Regrid every smoothed cube onto the spatial grid of one reference cube,
## carrying each input's spectral-axis header keywords over to the output.
import glob
from astropy.wcs import WCS
from reproject import reproject_interp
# --- paths ---
smooth_dir = "../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_smooth"
out_dir = "../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_regrid"
os.makedirs(out_dir, exist_ok=True)

# Reference cube: its first two WCS axes and its (ny, nx) define the output grid.
ref_path = os.path.join(smooth_dir, "NGC253_PIII_17.9.line_smooth.fits")
with fits.open(ref_path) as ref_hdul:
    ref_primary = ref_hdul[0]
    ref_hdr = ref_primary.header.copy()
    ref_data = ref_primary.data.copy()
ref_wcs_3d = WCS(ref_hdr)
ref_wcs_2d = ref_wcs_3d.dropaxis(2)  # drop the third (spectral) WCS axis
ref_hdr_2d = ref_wcs_2d.to_header()
ny_ref = ref_data.shape[1]
nx_ref = ref_data.shape[2]
shape_out_ref = (ny_ref, nx_ref)

# --- Input cubes ---
cube_files = sorted(glob.glob(os.path.join(smooth_dir, "*.fits")))
print(f"Found {len(cube_files)} smoothed cubes.")

for infile in cube_files:
    name = os.path.basename(infile)
    outfile = os.path.join(out_dir, name.replace(".fits", "_regrid.fits"))

    # --- Read input cube and build its 2D (spatial-only) header ---
    data_in = fits.getdata(infile)
    hdr_in = fits.getheader(infile)
    hdr_in_2d = WCS(hdr_in).dropaxis(2).to_header()

    nlam = data_in.shape[0]
    print(f"Regridding {name}: {nlam} slices -> {shape_out_ref}")

    # --- Reproject one spectral plane at a time ---
    regridded = np.empty((nlam, ny_ref, nx_ref), dtype=np.float32)
    for i, raw_plane in enumerate(data_in):
        plane = raw_plane.astype(float)       # astype() copies, so the edit below is local
        plane[~np.isfinite(plane)] = np.nan   # infinities become NaN before interpolation
        regridded[i], _ = reproject_interp((plane, hdr_in_2d),
                                           ref_hdr_2d, shape_out=shape_out_ref)

    # --- Build output header: reference header + this cube's third-axis keywords ---
    out_hdr = ref_hdr.copy()
    out_hdr["NAXIS3"] = nlam
    carried_keys = [key for key in hdr_in
                    if "3" in key or key.startswith(("REST", "SPECSYS"))]
    for key in carried_keys:
        out_hdr[key] = hdr_in[key]

    fits.writeto(outfile, regridded, out_hdr, overwrite=True)
    print(f"  â†’ wrote {outfile} shape={regridded.shape}")

print("All smoothed cubes regridded successfully.")



Found 100 smoothed cubes.
Regridding NGC253_ArIII_9.0.cont_smooth.fits: 53 slices -> (93, 75)
  â†’ wrote ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_regrid/NGC253_ArIII_9.0.cont_smooth_regrid.fits shape=(53, 93, 75)
Regridding NGC253_ArIII_9.0.line_smooth.fits: 53 slices -> (93, 75)
  â†’ wrote ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_regrid/NGC253_ArIII_9.0.line_smooth_regrid.fits shape=(53, 93, 75)
Regridding NGC253_ArII_7.0.cont_smooth.fits: 67 slices -> (93, 75)
  â†’ wrote ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_regrid/NGC253_ArII_7.0.cont_smooth_regrid.fits shape=(67, 93, 75)
Regridding NGC253_ArII_7.0.line_smooth.fits: 67 slices -> (93, 75)
  â†’ wrote ../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_regrid/NGC253_ArII_7.0.line_smooth_regrid.fits shape=(67, 93, 75)
Regridding NGC253_C2H2_13.7.cont_smooth.fits: 42 slices -> (93, 75)
  â†’ wrote ../../github/astr796_25_V2/NGC253_output/3

In [43]:

# ---- Paths ----
# input_dir holds the regridded cubes (use this for a frequency-based x axis).
input_dir = "../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_regrid"
output_dir = "../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/V1/line_masked"

os.makedirs(output_dir, exist_ok=True)

# ---- Reference Ch I cube for mask ----
ch1_ref = os.path.join(input_dir, "NGC253_S9_5.0.line_smooth_regrid.fits")  # for frequency-based maps
if not os.path.exists(ch1_ref):
    raise FileNotFoundError(f"Ch 1 reference cube not found: {ch1_ref}")

fe_data = fits.getdata(ch1_ref)
if fe_data.ndim != 3:
    raise RuntimeError(f"Ch I cube not 3D: shape={fe_data.shape}")

# Multiplicative footprint from the first plane: 1 where finite, NaN elsewhere.
mask = np.where(np.isfinite(fe_data[0]), 1.0, np.nan)
print(f"âœ… Mask built from Ch I first plane: shape={mask.shape}")

# ---- Apply mask to each cube ----
fits_names = [entry for entry in sorted(os.listdir(input_dir)) if entry.endswith(".fits")]
for fname in fits_names:
    infile = os.path.join(input_dir, fname)
    data = fits.getdata(infile)
    hdr = fits.getheader(infile)

    if data.ndim != 3:
        print(f"Skipping {fname} (not 3D)")
        continue

    # Broadcasting the 2D mask over the first axis blanks every plane
    # wherever the Ch I reference is invalid.
    data_masked = data * mask[np.newaxis]

    # Write masked cube
    outfile = os.path.join(output_dir, fname.replace(".fits", "_maskch1.fits"))
    fits.writeto(outfile, data_masked, hdr, overwrite=True)
    print(f"â†’ Masked {fname}  â†’  {os.path.basename(outfile)}")

print("ðŸŽ‰ All cubes masked using ch1 S9 reference.")


âœ… Mask built from Ch I first plane: shape=(93, 75)
â†’ Masked NGC253_ArIII_9.0.cont_smooth_regrid.fits  â†’  NGC253_ArIII_9.0.cont_smooth_regrid_maskch1.fits
â†’ Masked NGC253_ArIII_9.0.line_smooth_regrid.fits  â†’  NGC253_ArIII_9.0.line_smooth_regrid_maskch1.fits
â†’ Masked NGC253_ArII_7.0.cont_smooth_regrid.fits  â†’  NGC253_ArII_7.0.cont_smooth_regrid_maskch1.fits
â†’ Masked NGC253_ArII_7.0.line_smooth_regrid.fits  â†’  NGC253_ArII_7.0.line_smooth_regrid_maskch1.fits
â†’ Masked NGC253_C2H2_13.7.cont_smooth_regrid.fits  â†’  NGC253_C2H2_13.7.cont_smooth_regrid_maskch1.fits
â†’ Masked NGC253_C2H2_13.7.line_smooth_regrid.fits  â†’  NGC253_C2H2_13.7.line_smooth_regrid_maskch1.fits
â†’ Masked NGC253_CO2_15.0.cont_smooth_regrid.fits  â†’  NGC253_CO2_15.0.cont_smooth_regrid_maskch1.fits
â†’ Masked NGC253_CO2_15.0.line_smooth_regrid.fits  â†’  NGC253_CO2_15.0.line_smooth_regrid_maskch1.fits
â†’ Masked NGC253_ClII_14.4.cont_smooth_regrid.fits  â†’  NGC253_ClII_14.4.cont_smooth_regrid_maskc

In [44]:
# Deliberate hard stop: raising here aborts a "Run All" so that the cells
# below are only executed when run by hand.
# NOTE(review): a notebook that raises cannot pass Restart & Run All; consider
# moving the cells below into a separate notebook instead.
raise Exception('stop here')

Exception: stop here

In [None]:


def band_from_wavelength(lam):
    """
    Return the MIRI MRS band label for a wavelength given in microns.

    The wavelength is converted with float() and compared against the
    inclusive [lower, upper] limits of bands 1A through 3C. Adjacent bands
    overlap; a wavelength inside an overlap is assigned to the band listed
    first (the shorter-wavelength one).

    Raises
    ------
    ValueError
        If the wavelength lies in none of the tabulated bands.
    """
    lam = float(lam)

    # (band, lower limit, upper limit) in microns. Order matters: the first
    # match wins, which is how overlap regions are resolved.
    band_limits = (
        # Channel 1: SHORT, MEDIUM, LONG
        ("1A", 4.90, 5.74),
        ("1B", 5.66, 6.63),
        ("1C", 6.53, 7.65),
        # Channel 2: SHORT, MEDIUM, LONG
        ("2A", 7.51, 8.77),
        ("2B", 8.67, 10.13),
        ("2C", 10.01, 11.70),
        # Channel 3: SHORT, MEDIUM, LONG
        ("3A", 11.55, 13.47),
        ("3B", 13.34, 15.57),
        ("3C", 15.41, 17.98),
    )

    for band, lower, upper in band_limits:
        if lower <= lam <= upper:
            return band

    raise ValueError(f"Wavelength {lam} Âµm is outside MIRI MRS range")


def compute_matching_kernel(psf_native, psf_target, reg=1e-6):
    """
    Build the kernel K for which psf_native convolved with K approximates psf_target.

    The kernel is obtained in Fourier space with a regularised (Wiener-style)
    division, K_hat = conj(F_native) * F_target / (|F_native|^2 + reg), then
    transformed back, shifted with fftshift and scaled to unit sum.

    Parameters
    ----------
    psf_native, psf_target : 2D arrays of identical shape
        No padding is done here; callers must supply matching shapes.
    reg : float
        Constant added to the power spectrum of psf_native to keep the
        division finite where that spectrum is close to zero.

    Returns
    -------
    ndarray
        Real-valued kernel with the same shape as the inputs, summing to 1.

    Raises
    ------
    ValueError
        If the two PSFs do not have the same shape.
    """
    if psf_native.shape != psf_target.shape:
        raise ValueError("PSFs must have the same shape")

    native_ft = fft2(psf_native)
    target_ft = fft2(psf_target)

    # Regularised inverse filter: multiplying by the conjugate and dividing by
    # the (real) power spectrum is the complex ratio F_target / F_native,
    # damped by `reg`.
    native_power = np.abs(native_ft)**2 + reg
    kernel_ft = np.conj(native_ft) * target_ft / native_power

    # Back to image space; fftshift moves the zero-offset element from the
    # array corner to the array centre.
    kernel = fftshift(np.real(ifft2(kernel_ft)))

    # Unit sum keeps the total flux unchanged under convolution.
    return kernel / kernel.sum()

def _pad_centred(psf, shape):
    """Zero-pad a 2D PSF to `shape`, offset by (shape - psf.shape) // 2 on each axis."""
    padded = np.zeros(shape)
    height, width = psf.shape
    y0, x0 = (shape[0] - height) // 2, (shape[1] - width) // 2
    padded[y0:y0 + height, x0:x0 + width] = psf
    return padded


def smooth_cube_to_target_psf(infile, outfile, target_wavelength=17.98):
    """
    Homogenise every spectral slice of a cube to the PSF at target_wavelength.

    For each slice the native PSF (at the slice wavelength) and the target PSF
    are generated with stpsf, a matching kernel is derived from them with
    compute_matching_kernel(), and the slice is convolved with that kernel.
    Pixels that are NaN in the input flux are NaN in both outputs.

    Parameters
    ----------
    infile : str
        Input FITS cube. Flux is read from extension 1 and the uncertainty
        from extension 2; the header is taken from SpectralCube.read(hdu=1).
    outfile : str
        Output FITS file (overwritten if present) with an empty primary HDU
        followed by SCI, ERR and an all-zero DQ extension.
    target_wavelength : float
        Wavelength in microns of the PSF every slice is matched to.

    Raises
    ------
    ValueError
        If CRVAL3/CDELT3/CRPIX3 are missing from the header, if the flux and
        error arrays differ in shape, or if a wavelength falls outside the
        bands known to band_from_wavelength().
    """
    # Header via SpectralCube, arrays straight from the FITS extensions.
    sc = SpectralCube.read(infile, hdu=1)
    sci_hdr = sc.header.copy()

    # Extensions must be selected with `ext=`. The previous `hdu=1` / `hdu=2`
    # keywords are not astropy selectors, so both calls fell back to the same
    # default HDU and `err` received the flux array instead of the ERR one.
    flux = fits.getdata(infile, ext=1)
    err  = fits.getdata(infile, ext=2)
    if err.shape != flux.shape:
        raise ValueError(
            f"Flux {flux.shape} and error {err.shape} arrays differ in shape")

    # Build wavelength array from the linear spectral WCS (CRPIX3 is 1-based).
    if not all(k in sci_hdr for k in ["CRVAL3", "CDELT3", "CRPIX3"]):
        raise ValueError("Spectral WCS keywords missing from header!")
    nlam = sci_hdr["NAXIS3"]
    lambdas = ((np.arange(nlam) - (sci_hdr["CRPIX3"] - 1)) *
               sci_hdr["CDELT3"] + sci_hdr["CRVAL3"])

    # Pixel scale in arcsec (CDELT2 presumably in degrees); abs() matches the
    # line-cube smoothing cell and guards against a negative increment.
    pixscale = abs(sci_hdr["CDELT2"]) * 3600.0

    # Allocate output arrays
    sm_flux = np.empty_like(flux)
    sm_err  = np.empty_like(err)

    # One stpsf instrument per band, built lazily and reused for every slice
    # in that band instead of constructing a new instrument for each slice.
    instruments = {}

    def _instrument_for(band):
        if band not in instruments:
            inst = stpsf.MIRI()
            inst.mode = 'IFU'
            inst.band = band
            inst.pixelscale = pixscale  # set after the band, as before
            instruments[band] = inst
        return instruments[band]

    def _unit_psf(band, wavelength):
        """Monochromatic PSF from extension 1 of the stpsf result, scaled to unit sum."""
        psf_hdul = _instrument_for(band).calc_psf(monochromatic=wavelength * u.micron)
        psf_img = psf_hdul[1].data.astype(float)
        return psf_img / psf_img.sum()

    # --- Pre-compute target PSF (once) ---
    target_band = band_from_wavelength(target_wavelength)
    print(f"Computing target PSF at Î» = {target_wavelength} Âµm (band {target_band})")
    psf_target = _unit_psf(target_band, target_wavelength)

    # Loop over slices
    for i in range(nlam):
        lam = lambdas[i]
        band = band_from_wavelength(lam)

        print(f"Processing slice {i+1}/{nlam}: Î» = {lam:.3f} Âµm, band = {band}")

        # --- Native PSF for this slice ---
        psf_native = _unit_psf(band, lam)

        # --- Matching kernel ---
        # stpsf may return arrays of different size for different wavelengths,
        # so both PSFs are zero-padded (centred) to the larger of the two.
        common_shape = (max(psf_native.shape[0], psf_target.shape[0]),
                        max(psf_native.shape[1], psf_target.shape[1]))
        kernel = compute_matching_kernel(_pad_centred(psf_native, common_shape),
                                         _pad_centred(psf_target, common_shape),
                                         reg=1e-6)

        # --- Convolve flux ---
        blank = np.isnan(flux[i])
        sm_slice = fftconvolve(np.nan_to_num(flux[i], nan=0.0), kernel, mode='same')
        sm_slice[blank] = np.nan  # restore the input footprint
        sm_flux[i] = sm_slice

        # --- Propagate errors (variance) ---
        # The squared kernel is renormalised to unit sum on purpose: without
        # it the downstream PAHFIT step breaks (not needed for line files).
        var = np.nan_to_num(err[i]**2, nan=0.0)
        kernel_sq = kernel**2
        kernel_sq /= kernel_sq.sum()

        sm_var = fftconvolve(var, kernel_sq, mode='same')
        sm_var = np.clip(sm_var, 0, None)  # FFT round-off can go slightly negative
        sm_err[i] = np.sqrt(sm_var)
        sm_err[i][blank] = np.nan

    # Write output FITS file
    primary = fits.PrimaryHDU()
    sci_hdu = fits.ImageHDU(data=sm_flux.astype(np.float32), header=sci_hdr, name='SCI')
    err_hdu = fits.ImageHDU(data=sm_err.astype(np.float32), header=sci_hdr.copy(), name='ERR')
    dq_hdu  = fits.ImageHDU(data=np.zeros_like(sm_flux, dtype=np.uint32),
                            header=sci_hdr.copy(), name='DQ')
    hdul = fits.HDUList([primary, sci_hdu, err_hdu, dq_hdu])
    hdul.writeto(outfile, overwrite=True)

    print(f"Wrote {outfile}, shape={sm_flux.shape}")


# -----------------------
# Example usage: PSF-match the channel 1, 2 and 3 cubes to 17.98 micron
# -----------------------
outdir = "./psf_smooth_cubes"
os.makedirs(outdir, exist_ok=True)

for ch in (1, 2, 3):
    infile = f"../../github/astr796_25_V2/files/NGC253_sky_v1_17_1_ch{ch}-shortmediumlong_s3d.fits"
    outfile = os.path.join(outdir, f"NGC253_sky_v1_17_1_ch{ch}-smooth.fits")
    smooth_cube_to_target_psf(infile=infile, outfile=outfile, target_wavelength=17.98)



In [None]:
# Comparison of a Gaussian-smoothed cube with the JWST-PSF-smoothed cube at one wavelength.

# Choose your own files; make sure the Gaussian smoothing used a kernel for the
# same target wavelength as the PSF smoothing. (A repository showing how to
# implement the Gaussian step will be linked here later.)
gauss_file = "../../github/astr796_25_V2/Task5/3channel/final_cube/NGC253_sky_v1_17_1_ch1-shortmediumlong_s3d_smooth.regrid-ch3.mask.fits"
psf_file   = "../../github/astr796_25_V2/Task5/3channel/psf/final_cube/NGC253_sky_v1_17_1_ch1-shortmediumlong_s3d_smooth.regrid-ch3.mask.fits"

#psf_file   = "../../github/astr796_25_V2/Task5/3channel/psf_smooth_cubes/NGC253_ch1_psfmatched.fits"

# Read inside `with` so the file handles are released; copy so the arrays and
# header stay usable after the files are closed.
with fits.open(gauss_file) as gauss_hdul:
    gauss_flux = gauss_hdul['SCI'].data.copy()
    hdr        = gauss_hdul['SCI'].header.copy()
with fits.open(psf_file) as psf_hdul:
    psf_flux = psf_hdul['SCI'].data.copy()

# The slice-by-slice arithmetic below needs both cubes on the same grid.
if gauss_flux.shape != psf_flux.shape:
    raise ValueError(
        f"Cube shapes differ: Gaussian {gauss_flux.shape} vs PSF {psf_flux.shape}")

nlam   = hdr['NAXIS3']
crval3 = hdr['CRVAL3']   # wavelength at the reference pixel
cdelt3 = hdr['CDELT3']   # wavelength step
crpix3 = hdr['CRPIX3']   # reference pixel (1-based)

idx = np.arange(nlam)
lambdas_um = (idx - (crpix3 - 1)) * cdelt3 + crval3   # wavelength array in microns

# NOTE: the variable names and printed labels say "[Ar II]" (6.985 micron),
# but the wavelength is currently set to 5.525 micron to inspect another slice.
#arii_wave = 6.985  # microns, adjust if you use a slightly different value
arii_wave = 5.525
i_arii = np.argmin(np.abs(lambdas_um - arii_wave))
print("Closest slice to [Ar II]:", i_arii, "lambda =", lambdas_um[i_arii])

slice_gauss = gauss_flux[i_arii]
slice_psf   = psf_flux[i_arii]
diff = slice_psf - slice_gauss

# --- Side-by-side images; the first two share the Gaussian slice's 5-95% stretch ---
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

vmin = np.nanpercentile(slice_gauss, 5)
vmax = np.nanpercentile(slice_gauss, 95)

ax[0].imshow(slice_gauss, origin='lower', vmin=vmin, vmax=vmax)
ax[0].set_title(f"Gaussian @ [Ar II] ~ {lambdas_um[i_arii]:.3f} Âµm")

ax[1].imshow(slice_psf, origin='lower', vmin=vmin, vmax=vmax)
ax[1].set_title("JWST PSF smoothed")

ax[2].imshow(diff, origin='lower')
ax[2].set_title("Difference (PSF âˆ’ Gaussian)")

plt.tight_layout()
plt.show()

# --- Scalar metrics ---
rms_diff = np.nanstd(diff)
print("RMS difference at [Ar II]:", rms_diff)

mean_signal = np.nanmean(slice_gauss)
rel_rms = rms_diff / mean_signal
print("Relative RMS difference (RMS / mean signal):", rel_rms)

# Correlation over pixels that are finite in both slices
mask = np.isfinite(slice_gauss) & np.isfinite(slice_psf)
g = slice_gauss[mask].ravel()
p = slice_psf[mask].ravel()

corr = np.corrcoef(g, p)[0, 1]
print("Pixel-wise correlation at [Ar II]:", corr)

# --- Aperture flux in a circle centred on the middle pixel of the image ---
ny, nx = slice_gauss.shape
yc, xc = ny // 2, nx // 2
r = 5  # radius in pixels

y, x = np.indices(slice_gauss.shape)
mask_ap = (x - xc)**2 + (y - yc)**2 <= r**2

flux_gauss_ap = np.nansum(slice_gauss[mask_ap])
flux_psf_ap   = np.nansum(slice_psf[mask_ap])

print("Aperture flux (Gaussian):", flux_gauss_ap)
print("Aperture flux (PSF):     ", flux_psf_ap)
print("Relative difference:", (flux_psf_ap - flux_gauss_ap) / flux_gauss_ap)


In [None]:
'''
def smooth_cube_with_error_3d_scipy(infile, outfile):
    """
    Smooth each wavelength slice in 2D (y, x) using the JWST MIRI 3C PSF
    at 17.98 micron instead of a Gaussian.
    """

    from scipy.signal import fftconvolve

    # Read cube
    sc = SpectralCube.read(infile, hdu=1)
    flux = fits.getdata(infile, hdu=1)
    err  = fits.getdata(infile, hdu=2)
    sci_hdr = sc.header.copy()

    # Wavelength array (spectral WCS)
    if all(k in sci_hdr for k in ["CRVAL3", "CDELT3", "CRPIX3"]):
        nlam = sci_hdr["NAXIS3"]
        lambdas = ((np.arange(nlam) - (sci_hdr["CRPIX3"] - 1)) *
                   sci_hdr["CDELT3"] + sci_hdr["CRVAL3"])
    else:
        raise ValueError("Spectral WCS keywords missing from SCI header!")

    # Pixel scale (deg -> arcsec)
    pixscale = sci_hdr["CDELT2"] * 3600.0

    #https://stpsf.readthedocs.io/en/latest/jwst_ifu_datacubes.html#MIRI-MRS-example
    # Allocate output arrays
    sm_flux = np.empty_like(flux)
    sm_err  = np.empty_like(err)

    # --- PRECOMPUTE PSF OUTSIDE LOOP ---
    miri = stpsf.MIRI()
    miri.mode = 'IFU'
    miri.band = '3C'
    miri.pixelscale = pixscale

    psf = miri.calc_psf(monochromatic=17.98 * u.micron) #https://stpsf.readthedocs.io/en/latest/api/stpsf.JWInstrument.html#stpsf.JWInstrument.calc_psf
    psf_kernel = psf[1].data.astype(float) #0-header only 1- data
    psf_kernel /= psf_kernel.sum() #normalization

    # Loop over slices and apply PSF convolution
    for i in range(nlam):

        img = np.nan_to_num(flux[i], nan=0.0) #extract slice and replace nan
        valid_mask = (~np.isnan(flux[i])).astype(float) #build mask of valid pixel

        # Convolve flux #takes the img input, blur it using jwst psf and return same size output
        sm_slice = fftconvolve(img, psf_kernel, mode='same') #https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.fftconvolve.html
        sm_slice *= valid_mask #this is a memory and the slice will have nan mapped
        sm_slice[valid_mask == 0] = np.nan #this replaces any value in slice with nan which was nan in input
        sm_flux[i] = sm_slice #this is loop over slices 

        #Because I will be using the data into PAHFIT I need to propagate error too
        #If you dont need to do so, look for the line fine method which doesnot have the below step

        # Convolve variance
        var = np.nan_to_num(err[i]**2, nan=0.0) 
        psf_sq = psf_kernel**2
        psf_sq /= psf_sq.sum()

        sm_var = fftconvolve(var, psf_sq, mode='same')
        sm_var = np.clip(sm_var, 0, None)
        sm_err[i] = np.sqrt(sm_var)
        sm_err[i][valid_mask == 0] = np.nan

    # ---- Write FITS in JWST-like structure ----
    primary = fits.PrimaryHDU()
    sci_hdu = fits.ImageHDU(data=sm_flux.astype(np.float32), header=sci_hdr, name='SCI')
    err_hdu = fits.ImageHDU(data=sm_err.astype(np.float32), header=sci_hdr.copy(), name='ERR')
    dq_hdu  = fits.ImageHDU(data=np.zeros_like(sm_flux, dtype=np.uint32),
                            header=sci_hdr.copy(), name='DQ')

    hdul = fits.HDUList([primary, sci_hdu, err_hdu, dq_hdu])
    hdul.writeto(outfile, overwrite=True)

    print(f"Wrote {outfile}, shape={sm_flux.shape}")
    print("  spectral keys:",
          {k: sci_hdr.get(k) for k in ['CRVAL3','CRPIX3','CDELT3','CTYPE3','CUNIT3']})

#Sorry the directories are crazzy on my side please choose to use your own directory
# Ensure output directory exists
outdir = "../../github/astr796_25_V2/Task5/3channel/psf_smooth_cubes"
os.makedirs(outdir, exist_ok=True)

#I cant upload more than 100mb file on github, please get the data from mast or choose the directory of your data
# ---- RUN FOR CHANNELS 1â€“3 ----
for ch in range(1, 4):
    infile  = f"../../github/astr796_25_V2/files/NGC253_sky_v1_17_1_ch{ch}-shortmediumlong_s3d.fits"
    outfile = f"{outdir}/NGC253_sky_v1_17_1_ch{ch}-shortmediumlong_s3d_smooth_3d.fits"
    smooth_cube_with_error_3d_scipy(infile, outfile)

'''

### For individual line files

In [None]:

# Compare the Gaussian-smoothed and JWST-PSF-smoothed versions of one line
# cube on its brightest slice.
line_name = "NGC253_FeII_5.3.line_smooth.fits"

gauss_file = f"../../github/astr796_25_V2/NGC253_output/3channel/casa/line_smooth/{line_name}"
psf_file   = f"../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/line_smooth/{line_name}"

# Load cubes (the header is taken from the Gaussian product)
gauss = fits.getdata(gauss_file)
psf   = fits.getdata(psf_file)
hdr   = fits.getheader(gauss_file)

# --- Spectral grid from the linear WCS of axis 3 (CRPIX3 is 1-based) ---
nlam, crval3, cdelt3, crpix3 = (hdr[key] for key in ("NAXIS3", "CRVAL3", "CDELT3", "CRPIX3"))

idx = np.arange(nlam)
lam = (idx - (crpix3 - 1)) * cdelt3 + crval3   # spectral-axis values

# --- Slice selection ---
# TODO: still being refined. The slice with the highest mean is used on the
# assumption that, for a clean line cube, the brightest slice is the line peak.
i_slice = np.nanargmax(np.nanmean(gauss, axis=(1, 2)))

print("Comparing slice index:", i_slice)

# The two slices and their difference
g, p = gauss[i_slice], psf[i_slice]
d = p - g

# --- Plot: both smoothed images share the Gaussian slice's 5-95% stretch ---
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

vmin = np.nanpercentile(g, 5)
vmax = np.nanpercentile(g, 95)
stretch = dict(vmin=vmin, vmax=vmax)

panels = (
    (g, "Gaussian smoothed", stretch),
    (p, "JWST PSF smoothed", stretch),
    (d, "Difference (PSF âˆ’ Gaussian)", {}),
)
for axis, (image, title, limits) in zip(ax, panels):
    axis.imshow(image, origin='lower', **limits)
    axis.set_title(title)

plt.tight_layout()
plt.show()

# --- Metrics ---
rms = np.nanstd(d)

both_finite = np.isfinite(g) & np.isfinite(p)
corr = np.corrcoef(g[both_finite], p[both_finite])[0, 1]

print("RMS difference:", rms)
print("Correlation:", corr)

# --- Aperture flux comparison (circle of radius r around the image centre) ---
ny, nx = g.shape
yc, xc = ny // 2, nx // 2
r = 5

Y, X = np.ogrid[:ny, :nx]   # broadcast to the full (ny, nx) grid below
mask = (X - xc)**2 + (Y - yc)**2 <= r**2

flux_g = np.nansum(g[mask])
flux_p = np.nansum(p[mask])

print("Aperture flux (Gaussian):", flux_g)
print("Aperture flux (PSF):     ", flux_p)
print("Relative difference:", (flux_p - flux_g)/flux_g)


In [None]:
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt

# Compare the Gaussian-smoothed and JWST-PSF-smoothed versions of one line
# cube on the slice closest to a chosen wavelength.
line_name = "NGC253_FeII_5.3.line_smooth.fits"

gauss_file = f"../../github/astr796_25_V2/NGC253_output/3channel/casa/line_smooth/{line_name}"
psf_file   = f"../../github/astr796_25_V2/NGC253_output/3channel/jwst_psf/line_smooth/{line_name}"

# Load cubes (the header is taken from the Gaussian product)
gauss = fits.getdata(gauss_file)
psf   = fits.getdata(psf_file)
hdr   = fits.getheader(gauss_file)

# --- Spectral grid from the linear WCS of axis 3 (CRPIX3 is 1-based) ---
nlam, crval3, cdelt3, crpix3 = (hdr[key] for key in ("NAXIS3", "CRVAL3", "CDELT3", "CRPIX3"))

idx = np.arange(nlam)
lam = (idx - (crpix3 - 1)) * cdelt3 + crval3

# Convert metres -> microns when the header says the axis is in metres
if hdr.get("CUNIT3", "").lower() == "m":
    lam *= 1e6

# FIXME: this selection is not working yet. The CASA tools altered the headers
# while extracting the line files, so the targeted wavelength is not located
# reliably from the header values above.

# --- Pick slice closest to a given wavelength ---
target_wavelength = 5.34  # microns (Fe II line)
i_slice = np.argmin(np.abs(lam - target_wavelength))

print("Comparing slice index:", i_slice)
print("Wavelength of this slice:", lam[i_slice], "micron")

# The two slices and their difference
g, p = gauss[i_slice], psf[i_slice]
d = p - g

# --- Plot: both smoothed images share the Gaussian slice's 5-95% stretch ---
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

vmin = np.nanpercentile(g, 5)
vmax = np.nanpercentile(g, 95)
stretch = dict(vmin=vmin, vmax=vmax)

panels = (
    (g, "Gaussian smoothed", stretch),
    (p, "JWST PSF smoothed", stretch),
    (d, "Difference (PSF âˆ’ Gaussian)", {}),
)
for axis, (image, title, limits) in zip(ax, panels):
    axis.imshow(image, origin='lower', **limits)
    axis.set_title(title)

plt.tight_layout()
plt.show()

# --- Metrics ---
rms = np.nanstd(d)

both_finite = np.isfinite(g) & np.isfinite(p)
corr = np.corrcoef(g[both_finite], p[both_finite])[0, 1]

print("RMS difference:", rms)
print("Correlation:", corr)

# --- Aperture flux comparison (circle of radius r around the image centre) ---
ny, nx = g.shape
yc, xc = ny // 2, nx // 2
r = 5

Y, X = np.ogrid[:ny, :nx]   # broadcast to the full (ny, nx) grid below
mask = (X - xc)**2 + (Y - yc)**2 <= r**2

flux_g = np.nansum(g[mask])
flux_p = np.nansum(p[mask])

print("Aperture flux (Gaussian):", flux_g)
print("Aperture flux (PSF):     ", flux_p)
print("Relative difference:", (flux_p - flux_g)/flux_g)
