# JWST Time-Series Analysis Lab v2  
## Gaussian-Process systematics (celerite2 / george) + Transit Spectroscopy (batman) + Imaging TSO photometry (rateints / calints)

This notebook extends the previous MAST+astroquery workflow with three space-telescope-specific upgrades:

1. **Replace simple detrending with Gaussian Processes** for instrument systematics (prefer **celerite2**, optional **george** fallback).  
2. **Fit a transit model (batman) jointly across wavelength bins** to produce a **transmission spectrum** (Rp/Rs vs wavelength).  
3. For **imaging TSOs** (`rateints` / `calints`), perform **aperture photometry per integration** (similar to differential photometry workflows).

Notes:
- JWST time stamps often live in the `INT_TIMES` table; a common mid-time column is `int_mid_MJD_UTC`.  
- Product availability varies by program/mode; this notebook tries **WHTLT → X1DINTS → RATEINTS/CALINTS** in that order.


In [None]:
# If you're missing dependencies, uncomment and run:
# !pip -q install astroquery astropy numpy matplotlib scipy pandas
# !pip -q install celerite2 george
# !pip -q install batman-package
# !pip -q install photutils

# Fixed marker line printed when this cell has been executed.
print("Ready.")

In [None]:
# -----------------------------
# Imports
# -----------------------------
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import astropy.units as u
from astropy.io import fits
from astropy.table import Table
from astropy.stats import sigma_clip
from astropy.timeseries import LombScargle

from astroquery.mast import Observations

# SciPy (optim + optional filters)
try:
    from scipy.optimize import minimize, least_squares
    from scipy.signal import savgol_filter
    HAS_SCIPY = True
except Exception:
    HAS_SCIPY = False

# Photutils (recommended for imaging TSOs)
try:
    from photutils.aperture import CircularAperture, CircularAnnulus, aperture_photometry
    HAS_PHOTUTILS = True
except Exception:
    HAS_PHOTUTILS = False

# GP backends
try:
    import celerite2
    from celerite2 import terms
    from celerite2 import GaussianProcess
    HAS_CELERITE2 = True
except Exception:
    HAS_CELERITE2 = False

try:
    import george
    from george import kernels
    HAS_GEORGE = True
except Exception:
    HAS_GEORGE = False

# Transit model
try:
    import batman
    HAS_BATMAN = True
except Exception:
    HAS_BATMAN = False

# Optional transit search
try:
    from astropy.timeseries import BoxLeastSquares
    HAS_BLS = True
except Exception:
    HAS_BLS = False

# Report the HAS_* flags set by the guarded imports above, so a missing
# optional package is visible before any cell that depends on it is run.
print("HAS_SCIPY     :", HAS_SCIPY)
print("HAS_PHOTUTILS :", HAS_PHOTUTILS)
print("HAS_CELERITE2 :", HAS_CELERITE2)
print("HAS_GEORGE    :", HAS_GEORGE)
print("HAS_BATMAN    :", HAS_BATMAN)
print("HAS_BLS       :", HAS_BLS)


In [None]:
# -----------------------------
# Configuration
# -----------------------------
# Root directory for everything this notebook writes (downloads, CSV results).
OUT_DIR = Path("jwst_timeseries_outputs_v2")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# MAST downloads are kept in a sub-directory of OUT_DIR.
DOWNLOAD_DIR = OUT_DIR / "mast_downloads"
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Entry points: set either TARGET_NAME (+SEARCH_RADIUS) OR PROGRAM_ID.
# query_jwst() uses PROGRAM_ID whenever it is not None and only falls back to
# the target-name cone search otherwise.
TARGET_NAME = "WASP-39"
SEARCH_RADIUS = 0.03 * u.deg

PROGRAM_ID = None   # e.g. 1366

# Choose observation row after query.
# The index is applied after the observation table has been sorted by t_min
# (when that column exists).
OBS_INDEX = 0

# Product preference order (MAST productSubGroupDescription).
# Used both to filter the product list and, by file name, to pick the loader.
PREFERRED_SUBGROUPS = ["WHTLT", "X1DINTS", "RATEINTS", "CALINTS"]

# Spectroscopic binning (for X1DINTS)
NBINS = 8
# A None entry means "use the min/max of the wavelength grid".
WL_RANGE_MICRON = (None, None)  # e.g. (1.1, 1.7)

# Imaging aperture photometry (pixels): source radius, sky annulus inner/outer radii
AP_R    = 5.0
ANN_RIN = 9.0
ANN_ROUT= 14.0

# Target centroid selection (imaging only).
# "manual" is honoured only when both MANUAL_XY entries are not None;
# otherwise the automatic peak search is used.
TARGET_MODE = "auto"  # "auto" or "manual"
MANUAL_XY = (None, None)  # e.g. (x, y)

# GP systematics settings
GP_BACKEND = "celerite2"   # "celerite2" or "george" (auto-fallback if unavailable)
GP_Q_FIXED = 1/np.sqrt(2)  # SHOTerm Q fixed (celerite2); reduces degeneracy
# Initial values (natural log) for the celerite2 hyper-parameter optimisation.
GP_INIT_LOGS0 = np.log(1e-6)
GP_INIT_LOGW0 = np.log(1.0)  # 1/days
# Initial log-jitter; also used as the starting jitter for the george backend.
GP_INIT_LOGJITTER = np.log(1e-4)

# Transit model settings (batman)
LD_LAW = "quadratic"
LD_U = [0.1, 0.3]   # limb-darkening coefficients (placeholder; tune for instrument+star)
# Exposure-time integration is only enabled when BOTH SUPERSAMPLE > 1 and
# EXP_TIME_DAYS > 0.
SUPERSAMPLE = 1     # increase if cadence is coarse
EXP_TIME_DAYS = 0.0 # set if you want batman exposure time integration (days), else 0

# Joint fit: which shared params to fit (keep simple by default)
# NOTE(review): FIT_SHARED is not read by the white-light fit cell (which always
# fits t0, a and inc) nor by the visible part of the joint fit — confirm whether
# it is meant to be wired in.
FIT_SHARED = dict(t0=True, per=False, a=False, inc=False)
# Provide initial guesses (days, days, stellar radii units, degrees)
TRANSIT_INIT = dict(
    t0=None,        # if None, inferred from median time
    per=1.0,        # set your known period if available
    rp=0.15,        # Rp/Rs (white-light initial)
    a=10.0,         # a/Rs
    inc=88.0,       # deg
    ecc=0.0,
    w=90.0
)

print("Output dir:", OUT_DIR.resolve())


In [None]:
# -----------------------------
# Helpers: MAST query
# -----------------------------
def query_jwst():
    """Query MAST for JWST observations.

    Uses ``PROGRAM_ID`` when it is set; otherwise runs a cone search around
    ``TARGET_NAME`` with ``SEARCH_RADIUS`` and keeps only rows whose
    ``obs_collection`` is ``"JWST"`` (when that column is present).

    Returns
    -------
    The observation table returned by ``astroquery.mast.Observations``.
    """
    if PROGRAM_ID is None:
        print(f"Querying JWST near TARGET_NAME='{TARGET_NAME}' (radius={SEARCH_RADIUS}) ...")
        found = Observations.query_object(TARGET_NAME, radius=SEARCH_RADIUS)
        if "obs_collection" not in found.colnames:
            return found
        # The cone search is mission-agnostic, so restrict it to JWST rows.
        return found[found["obs_collection"] == "JWST"]

    print(f"Querying JWST by PROGRAM_ID={PROGRAM_ID} ...")
    return Observations.query_criteria(obs_collection="JWST", proposal_id=str(PROGRAM_ID))

# Run the query once; later cells read (and re-sort) the module-level `obs` table.
obs = query_jwst()
print("Total observations returned:", len(obs))
# Bare expression: the notebook displays up to the first five rows.
obs[:5]


In [None]:
# -----------------------------
# Pick an observation
# -----------------------------
# Abort early when the query returned nothing to choose from.
if len(obs) == 0:
    raise RuntimeError("No observations found. Try a different TARGET_NAME / SEARCH_RADIUS or set PROGRAM_ID.")

# Keep only the preview columns that this result table actually contains.
cols = [c for c in ["obs_id","target_name","instrument_name","filters","dataproduct_type","t_min","t_max","proposal_id"] if c in obs.colnames]
# Sort by t_min so that OBS_INDEX refers to the sorted row order.
if "t_min" in obs.colnames:
    obs = obs[np.argsort(obs["t_min"])]

print("Preview columns:", cols)
# Bare expression: the notebook displays up to the first 20 rows.
(obs[cols][:20] if cols else obs[:20])


In [None]:
# -----------------------------
# Products + download
# -----------------------------
# Select the chosen observation with a list index ([[...]]); the result is then
# indexed as column-then-row below and passed to get_product_list.
obs_row = obs[[OBS_INDEX]]
print("Selected obs_id:", obs_row["obs_id"][0] if "obs_id" in obs_row.colnames else "(unknown)")

products = Observations.get_product_list(obs_row)
print("Products:", len(products))

if "productSubGroupDescription" in products.colnames:
    # Upper-case the subgroup labels so they compare against PREFERRED_SUBGROUPS,
    # which is written in upper case.
    products["productSubGroupDescription"] = [str(x).upper() for x in products["productSubGroupDescription"]]
    uniq = sorted(set(products["productSubGroupDescription"]))
    print("Unique subgroups (first 60):", uniq[:60])

    wanted = Observations.filter_products(
        products,
        productSubGroupDescription=PREFERRED_SUBGROUPS,
        mrp_only=False
    )
else:
    # No subgroup column to filter on: keep every product.
    wanted = products

print("Filtered products:", len(wanted))
# Bare expression: the notebook displays up to the first 10 rows.
wanted[:10]


In [None]:
# Download the filtered products into DOWNLOAD_DIR.
if len(wanted) == 0:
    raise RuntimeError("No products matched. Inspect 'products' and adjust PREFERRED_SUBGROUPS.")

# `manifest` is read by the next cell through its "Local Path" column.
manifest = Observations.download_products(
    wanted,
    download_dir=str(DOWNLOAD_DIR),
    cache=True
)
print("Downloaded rows:", len(manifest))
# Bare expression: the notebook displays up to the first 10 rows.
manifest[:10]


In [None]:
# -----------------------------
# Helpers: locate downloaded files
# -----------------------------
def local_paths_from_manifest(manifest):
    """Return the files listed in a download manifest that exist on disk.

    Parameters
    ----------
    manifest
        Mapping/table with a ``"Local Path"`` column.

    Returns
    -------
    list of pathlib.Path
        Existing paths, in manifest order. Entries that are falsy, or whose
        text is ``"none"`` (case-insensitive, surrounding whitespace ignored),
        are skipped.
    """
    existing = []
    for entry in manifest["Local Path"]:
        if not entry:
            continue
        if str(entry).strip().lower() == "none":
            continue
        candidate = Path(entry)
        if candidate.exists():
            existing.append(candidate)
    return existing

# Resolve the manifest to files that exist on disk and preview up to 15 names.
local_paths = local_paths_from_manifest(manifest)
print("Local files:", len(local_paths))
for p in local_paths[:15]:
    print(" -", p.name)

def pick_by_subgroup(paths, key):
    """Select the paths whose file name contains ``key`` (case-insensitive).

    Only the final path component (``Path.name``) is searched; input order is
    preserved.
    """
    needle = key.lower()
    matches = []
    for candidate in paths:
        if needle in candidate.name.lower():
            matches.append(candidate)
    return matches

# Group the downloaded files by product type, matched on the file name.
# These four lists drive the loader selection in the next cell.
whtlt_files   = pick_by_subgroup(local_paths, "whtlt")
x1dints_files = pick_by_subgroup(local_paths, "x1dints")
rateints_files= pick_by_subgroup(local_paths, "rateints")
calints_files = pick_by_subgroup(local_paths, "calints")

print("Found:", {"whtlt": len(whtlt_files), "x1dints": len(x1dints_files),
                "rateints": len(rateints_files), "calints": len(calints_files)})


In [None]:
# -----------------------------
# Loaders: WHTLT / X1DINTS / Imaging cubes (RATEINTS/CALINTS)
# -----------------------------
def load_int_times(hdul):
    """Extract integration mid-times from the first ``INT_TIMES`` table.

    Column lookup order (case-insensitive): exactly ``int_mid_mjd_utc``, then
    any name containing both "mid" and "mjd", then any name containing both
    "mid" and "bjd".

    Returns
    -------
    (times, column_name)
        ``times`` is a float array. ``(None, None)`` is returned when there is
        no INT_TIMES binary table, or when the first such table has no
        matching column (later tables are not inspected).
    """
    column_tests = (
        lambda low: low == "int_mid_mjd_utc",
        lambda low: ("mid" in low) and ("mjd" in low),
        lambda low: ("mid" in low) and ("bjd" in low),
    )

    for ext in hdul:
        if not (isinstance(ext, fits.BinTableHDU) and ("INT_TIMES" in ext.name.upper())):
            continue
        table = Table(ext.data)
        for accepts in column_tests:
            hit = next((name for name in table.colnames if accepts(name.lower())), None)
            if hit is not None:
                return np.asarray(table[hit], dtype=float), hit
        # Only the first INT_TIMES table is considered.
        return None, None
    return None, None

def load_whtlt(path):
    """Load a white-light curve table from a WHTLT FITS file.

    The first binary table that has at least one time-like column (name
    contains TIME, MJD or BJD) and one flux-like column (name contains FLUX
    but not ERR) is used. Within it, the first matching column of each kind
    is read; the error column is the first whose name contains both ERR and
    FLUX, if any.

    Returns
    -------
    (t, y, dy, meta, None)
        Float arrays for time and flux, ``dy`` as a float array or ``None``,
        a metadata dict, and ``None`` in the slot other loaders use for extras.

    Raises
    ------
    RuntimeError
        If no table with both a time-like and a flux-like column exists.
    """
    def _is_time(name):
        upper = name.upper()
        return "TIME" in upper or "MJD" in upper or "BJD" in upper

    def _is_flux(name):
        upper = name.upper()
        return "FLUX" in upper and "ERR" not in upper

    def _is_flux_err(name):
        upper = name.upper()
        return "ERR" in upper and "FLUX" in upper

    with fits.open(path, memmap=False) as hdul:
        table_hdu = None
        for ext in hdul:
            if not isinstance(ext, fits.BinTableHDU):
                continue
            names = ext.columns.names
            if any(_is_time(n) for n in names) and any(_is_flux(n) for n in names):
                table_hdu = ext
                break
        if table_hdu is None:
            raise RuntimeError("Could not find a time/flux table in WHTLT file.")

        tab = Table(table_hdu.data)
        colnames = tab.colnames

        tcol = next(n for n in colnames if _is_time(n))
        fcol = next(n for n in colnames if _is_flux(n))
        ecol = next((n for n in colnames if _is_flux_err(n)), None)

        times = np.asarray(tab[tcol], dtype=float)
        flux = np.asarray(tab[fcol], dtype=float)
        flux_err = None if ecol is None else np.asarray(tab[ecol], dtype=float)
        info = {"source":"WHTLT","file":str(path),"tcol":tcol,"fcol":fcol,"ecol":ecol}
        return times, flux, flux_err, info, None

def load_x1dints(path, wl_range_micron=(None,None)):
    """Load an X1DINTS file and build a white-light curve from its spectra.

    Parameters
    ----------
    path
        FITS file to open.
    wl_range_micron : tuple
        ``(wmin, wmax)`` limits for the white-light sum. A ``None`` entry is
        replaced by the min/max of the wavelength grid.
        NOTE(review): the limits are compared directly against the WAVELENGTH
        column, so they must be in that column's units (presumably microns).

    Returns
    -------
    (t, y, None, meta, extra)
        ``t`` are INT_TIMES mid-times, or integration indices when no times
        were found; ``y`` is the per-integration NaN-ignoring sum of flux in
        the wavelength range; no uncertainties are returned; ``extra`` holds
        ``{"wave": wave1d, "flux2d": flux2d}`` with flux2d as (nint, nwave).

    Raises
    ------
    RuntimeError
        If no suitable table is found, FLUX is 1-D, or the wavelength
        selection is empty.
    """
    with fits.open(path, memmap=False) as hdul:
        # times
        t, tcol = load_int_times(hdul)

        # extract1d table: prefer the extension named EXTRACT1D ...
        h = None
        for hdu in hdul:
            if isinstance(hdu, fits.BinTableHDU) and (hdu.name.upper() == "EXTRACT1D"):
                h = hdu
                break
        # ... otherwise fall back to the first table with FLUX and WAVELENGTH columns.
        if h is None:
            for hdu in hdul:
                if not isinstance(hdu, fits.BinTableHDU):
                    continue
                cols = [c.upper() for c in hdu.columns.names]
                if "FLUX" in cols and "WAVELENGTH" in cols:
                    h = hdu
                    break
        if h is None:
            raise RuntimeError("Could not find EXTRACT1D-like table with FLUX/WAVELENGTH in X1DINTS.")

        tab = Table(h.data)
        wcol = next(c for c in tab.colnames if c.upper() == "WAVELENGTH")
        fcol = next(c for c in tab.colnames if c.upper() == "FLUX")

        wave = np.asarray(tab[wcol])
        flux = np.asarray(tab[fcol])

        # Normalize shapes: want flux2d = (nint, nwave), wave1d = (nwave,)
        flux = np.asarray(flux)
        wave = np.asarray(wave)

        if flux.ndim == 1:
            raise RuntimeError("FLUX is 1D; not integration-resolved. Try another file/obs.")

        if wave.ndim == 1:
            # 1-D wavelength grid: orient flux so its second axis matches it.
            if flux.shape[1] == wave.shape[0]:
                flux2d = flux
                wave1d = wave
            elif flux.shape[0] == wave.shape[0]:
                flux2d = flux.T
                wave1d = wave
            else:
                # Shapes disagree: flatten per integration and substitute a
                # placeholder 0..1 grid (not a physical wavelength).
                flux2d = flux.reshape(flux.shape[0], -1)
                wave1d = np.linspace(0, 1, flux2d.shape[1])
        else:
            # Per-integration wavelength rows: the first row is used as the
            # grid for every integration.
            wave1d = wave[0] if wave.ndim >= 2 else wave
            flux2d = flux if flux.shape[1] == len(wave1d) else flux.T

        nint, nw = flux2d.shape
        if t is None:
            # No usable INT_TIMES: fall back to the integration index as "time".
            t = np.arange(nint, dtype=float)
            tcol = "index"

        wmin, wmax = wl_range_micron
        if wmin is None: wmin = float(np.nanmin(wave1d))
        if wmax is None: wmax = float(np.nanmax(wave1d))
        mask = (wave1d >= wmin) & (wave1d <= wmax)
        if not np.any(mask):
            raise RuntimeError("Wavelength mask empty; adjust WL_RANGE_MICRON.")

        # white-light curve (NaN flux values are ignored by nansum)
        y = np.nansum(flux2d[:, mask], axis=1)
        meta = {"source":"X1DINTS","file":str(path),"tcol":tcol,"wmin":wmin,"wmax":wmax,"nint":nint,"nwave":nw}
        return t, y, None, meta, {"wave": wave1d, "flux2d": flux2d}

def _find_sci_hdu(hdul):
    """Locate the science image HDU.

    Returns the first image extension named ``SCI``; failing that, the first
    primary/image HDU whose data has at least two dimensions; otherwise
    ``None``.
    """
    named = next(
        (ext for ext in hdul
         if isinstance(ext, fits.ImageHDU) and ext.name.upper() == "SCI"),
        None,
    )
    if named is not None:
        return named

    # Fallback: the primary HDU can also carry the image.
    for ext in hdul:
        if not isinstance(ext, (fits.PrimaryHDU, fits.ImageHDU)):
            continue
        if ext.data is None:
            continue
        if np.asarray(ext.data).ndim >= 2:
            return ext
    return None

def detect_integration_axis(arr):
    """Identify the integration axis and the two spatial axes of a cube.

    The two trailing axes are taken as spatial, following the FITS/JWST array
    layout ``(..., ny, nx)``. A size-based guess ("the two largest axes are
    spatial") is deliberately not used: time-series exposures routinely have
    more integrations than subarray pixels per side (e.g. ``(100, 32, 32)``),
    which makes that guess select the wrong axis.

    Among the remaining leading axes, the largest is reported as the
    integration axis (ties resolve to the first), so a ``(nint, ngroups, ny,
    nx)`` cube yields axis 0 when ``nint >= ngroups``.

    Parameters
    ----------
    arr : array-like

    Returns
    -------
    None
        If ``arr`` has fewer than three dimensions.
    (int, list of int)
        Integration axis index and the two spatial axis indices.
    """
    arr = np.asarray(arr)
    if arr.ndim < 3:
        return None
    spatial_axes = [arr.ndim - 2, arr.ndim - 1]
    leading_axes = range(arr.ndim - 2)
    # max() returns the first maximal element, so ties favour the lower axis.
    integration_axis = max(leading_axes, key=lambda ax: arr.shape[ax])
    return integration_axis, spatial_axes

def iter_integration_frames(arr):
    """Return an imaging cube arranged as ``(nint, ny, nx)``.

    The integration axis reported by ``detect_integration_axis`` is moved to
    the front and the two spatial axes to the end (keeping their relative
    order). Any other axis (e.g. a group axis) is collapsed by keeping its
    last plane.

    Parameters
    ----------
    arr : array-like
        Image data with at least three dimensions.

    Returns
    -------
    numpy.ndarray
        Three-dimensional array, integrations first.

    Raises
    ------
    RuntimeError
        If no integration axis can be detected (e.g. a 2-D image) or the
        reduced array is not three-dimensional.
    """
    arr = np.asarray(arr)
    detected = detect_integration_axis(arr)
    # The detector returns None (not a tuple) when it cannot decide, so check
    # before unpacking to raise the intended error instead of a TypeError.
    if detected is None:
        raise RuntimeError("Could not detect integration axis in imaging cube.")
    iax, spatial_axes = detected
    if iax is None:
        raise RuntimeError("Could not detect integration axis in imaging cube.")

    # Integration axis first, spatial axes last; leftover axes end up between.
    row_axis, col_axis = sorted(spatial_axes)
    x = np.moveaxis(arr, [iax, row_axis, col_axis], [0, -2, -1])

    # Collapse leftover non-spatial axes by selecting their last plane.
    while x.ndim > 3:
        x = x[:, -1]
    if x.ndim != 3:
        raise RuntimeError("Unexpected imaging cube shape after reduction.")
    return x

def load_imaging_cube(path):
    """Load an imaging time-series cube (RATEINTS/CALINTS style).

    Returns
    -------
    (t, cube3, meta)
        ``t`` are INT_TIMES mid-times, or integration indices when none were
        found; ``cube3`` is the science data reduced by
        ``iter_integration_frames``; ``meta`` records source, file, time
        column and cube shape.

    Raises
    ------
    RuntimeError
        If no science image HDU is found (or the cube cannot be reduced).
    """
    with fits.open(path, memmap=False) as hdul:
        times, time_col = load_int_times(hdul)
        sci_hdu = _find_sci_hdu(hdul)
        if sci_hdu is None:
            raise RuntimeError("Could not find SCI image data in file.")

        frames = iter_integration_frames(np.asarray(sci_hdu.data))  # (nint, ny, nx)
        if times is None:
            # No usable INT_TIMES table: use the integration index as "time".
            times, time_col = np.arange(frames.shape[0], dtype=float), "index"

        info = {"source":"IMAGING", "file":str(path), "tcol":time_col, "shape":frames.shape}
        return times, frames, info

# Choose the best available product, in order: WHTLT, X1DINTS, RATEINTS, CALINTS.
extra = None
if whtlt_files:
    t, y, dy, meta, extra = load_whtlt(whtlt_files[0])
elif x1dints_files:
    t, y, dy, meta, extra = load_x1dints(x1dints_files[0], WL_RANGE_MICRON)
elif rateints_files or calints_files:
    # RATEINTS takes precedence over CALINTS; both go through the same loader.
    # The light curve (y, dy) is produced later by aperture photometry.
    t, cube3, meta = load_imaging_cube((rateints_files or calints_files)[0])
    y = dy = None
    extra = {"cube3": cube3}
else:
    raise RuntimeError("No WHTLT/X1DINTS/RATEINTS/CALINTS found for this observation.")

print("Loaded meta:", meta)
print("Time length:", len(t), "| min..max:", float(np.nanmin(t)), float(np.nanmax(t)))


In [None]:
# -----------------------------
# If imaging cube: aperture photometry per integration -> light curve
# -----------------------------
def sanitize_finite(img):
    """Return a float copy of ``img`` with non-finite pixels replaced.

    NaN/inf entries are filled with the median of the finite entries, or with
    0.0 when there are no finite entries at all.
    """
    data = np.asarray(img, dtype=float)
    good = np.isfinite(data)
    if np.any(good):
        replacement = np.nanmedian(data[good])
    else:
        replacement = 0.0
    return np.where(good, data, replacement)

def auto_centroid_peak(img, box_halfsize=15):
    """Locate the brightest pixel in a box around the image centre.

    Non-finite pixels are filled via ``sanitize_finite`` first. The search
    window spans ``box_halfsize`` pixels either side of the central pixel,
    clipped to the image bounds.

    Returns
    -------
    (x, y) : tuple of float
        Column and row index of the peak pixel in full-image coordinates.
    """
    clean = sanitize_finite(img)
    n_rows, n_cols = clean.shape
    row_mid = n_rows // 2
    col_mid = n_cols // 2

    row_lo = max(0, row_mid - box_halfsize)
    row_hi = min(n_rows, row_mid + box_halfsize + 1)
    col_lo = max(0, col_mid - box_halfsize)
    col_hi = min(n_cols, col_mid + box_halfsize + 1)

    window = clean[row_lo:row_hi, col_lo:col_hi]
    # Convert the flat argmax index back to (row, col) within the window.
    peak_row, peak_col = divmod(int(np.argmax(window)), window.shape[1])
    return float(col_lo + peak_col), float(row_lo + peak_row)

def aperture_photometry_manual(img, x0, y0, r, rin, rout):
    """Whole-pixel aperture photometry with a median sky annulus.

    Pixels whose centres lie within ``r`` of ``(x0, y0)`` form the aperture;
    pixels with centre distance in ``[rin, rout]`` form the annulus. The sky
    level is the annulus median (0.0 if the annulus selects no pixels).

    Returns
    -------
    (flux, bkg)
        Sky-subtracted aperture sum and the per-pixel sky level.
    """
    clean = sanitize_finite(img)
    n_rows, n_cols = clean.shape
    # Open grids broadcast to the same distance map as full index arrays.
    rows, cols = np.ogrid[:n_rows, :n_cols]
    dist = np.sqrt((cols - x0) ** 2 + (rows - y0) ** 2)

    in_aperture = dist <= r
    in_annulus = (dist >= rin) & (dist <= rout)

    if np.any(in_annulus):
        sky = np.nanmedian(clean[in_annulus])
    else:
        sky = 0.0
    total = np.nansum(clean[in_aperture] - sky)
    return total, sky

def extract_lightcurve_from_cube(cube3, t, ap_r=5.0, ann_rin=9.0, ann_rout=14.0):
    """Run fixed-position aperture photometry on every integration.

    The aperture centre comes from ``MANUAL_XY`` when ``TARGET_MODE`` is
    ``"manual"`` and both coordinates are set; otherwise from
    ``auto_centroid_peak`` on the first frame. The same centre is used for
    all integrations.

    With photutils available, the sky level is the annulus sum divided by the
    annulus area; without it, ``aperture_photometry_manual`` is used.

    Returns
    -------
    (t, flux, None, meta)
        Times as a float array, sky-subtracted flux per integration, no
        uncertainties, and a dict with the centre and aperture radii.
    """
    frames = np.asarray(cube3)
    n_frames = frames.shape[0]
    first_frame = frames[0]

    use_manual = (
        TARGET_MODE == "manual"
        and (MANUAL_XY[0] is not None)
        and (MANUAL_XY[1] is not None)
    )
    if use_manual:
        x0, y0 = float(MANUAL_XY[0]), float(MANUAL_XY[1])
    else:
        x0, y0 = auto_centroid_peak(first_frame)

    flux = np.zeros(n_frames, dtype=float)
    bkg = np.zeros(n_frames, dtype=float)

    if not HAS_PHOTUTILS:
        for k, frame in enumerate(frames):
            flux[k], bkg[k] = aperture_photometry_manual(frame, x0, y0, ap_r, ann_rin, ann_rout)
    else:
        centre = [(x0, y0)]
        source_ap = CircularAperture(centre, r=ap_r)
        sky_ann = CircularAnnulus(centre, r_in=ann_rin, r_out=ann_rout)
        source_area = source_ap.area
        sky_area = sky_ann.area

        for k, frame in enumerate(frames):
            clean = sanitize_finite(frame)
            source_tbl = aperture_photometry(clean, source_ap)
            sky_tbl = aperture_photometry(clean, sky_ann)
            # Mean sky per pixel, scaled to the source aperture area.
            sky_level = float(sky_tbl["aperture_sum"][0]) / sky_area
            flux[k] = float(source_tbl["aperture_sum"][0]) - sky_level * source_area
            bkg[k] = sky_level

    info = {"x0":x0,"y0":y0,"ap_r":ap_r,"ann_rin":ann_rin,"ann_rout":ann_rout}
    return np.asarray(t, dtype=float), flux, None, info

# Imaging input only: replace (t, y, dy) with the aperture-photometry light
# curve and record the photometry settings in meta.
if meta.get("source") == "IMAGING":
    tt, yy, ddy, phot_meta = extract_lightcurve_from_cube(extra["cube3"], t, AP_R, ANN_RIN, ANN_ROUT)
    t, y, dy = tt, yy, ddy
    meta.update({"photometry": phot_meta})
    print("Aperture photometry centroid:", phot_meta["x0"], phot_meta["y0"])


In [None]:
# -----------------------------
# Plot raw light curve (any source)
# -----------------------------
# Coerce to float arrays; dy stays None when the loader provided no errors.
t = np.asarray(t, dtype=float)
y = np.asarray(y, dtype=float)
dy = (np.asarray(dy, dtype=float) if dy is not None else None)

# Normalise by the NaN-ignoring median; y0 is reused later to scale dy.
y0 = np.nanmedian(y)
yn = y / y0

plt.figure(figsize=(10,4))
plt.plot(t, yn, ".", ms=3)
plt.xlabel(f"Time ({meta.get('tcol','time')})")
plt.ylabel("Normalized flux")
plt.title(f"Raw light curve ({meta.get('source','?')})")
plt.tight_layout()
plt.show()


In [None]:
# -----------------------------
# Clean (sigma-clip) and define a transit mask (optional)
# -----------------------------
# Outlier clipping
# 5-sigma clipping (up to 5 iterations) on the normalised flux.
cl = sigma_clip(yn, sigma=5.0, maxiters=5)
# mask_ok is True for points that survive clipping; it is reused later to
# select the same integrations from the spectral cube.
mask_ok = ~cl.mask if hasattr(cl, "mask") else np.isfinite(cl)
t_clean = t[mask_ok]
y_clean = yn[mask_ok]
# Errors are divided by the same median (y0) used to normalise the flux.
dy_clean = (dy[mask_ok]/y0 if dy is not None else None)

# Define a transit mask (exclude transit) if you know approximate ephemeris.
# This is helpful for learning GP systematics using out-of-transit data.
# If unknown, set TRANSIT_MASK = None and we fit everything.
# Note: this writes the inferred t0 back into the shared TRANSIT_INIT dict.
if TRANSIT_INIT["t0"] is None:
    TRANSIT_INIT["t0"] = float(np.nanmedian(t_clean))

# Example: mask a window +/- 0.05 days around t0 (edit!)
TRANSIT_HALF_WINDOW_DAYS = 0.05
TRANSIT_MASK = np.abs(t_clean - TRANSIT_INIT["t0"]) > TRANSIT_HALF_WINDOW_DAYS  # True = keep

print("Clean points:", len(t_clean))
print("Transit mask keeps:", int(np.sum(TRANSIT_MASK)), "points")


In [None]:
# -----------------------------
# GP systematics model (celerite2 preferred; george fallback)
# We fit: y(t) = mean_model(t) + GP(t) + noise
# Here mean_model defaults to a constant 1.0 (normalized flux baseline)
# -----------------------------
def gp_systematics_fit(t, y, yerr=None, mean_model=None, mask=None):
    """Fit a GP systematics model and evaluate it over the whole light curve.

    The model is ``y(t) = mean_model(t) + GP(t) + noise``. Hyper-parameters
    (plus a white-noise jitter term) are optimised on the training points
    selected by ``mask``; the fitted model is then evaluated at every finite
    time stamp, including points excluded from training (e.g. in-transit).

    The returned curve includes the mean model, so it sits on the data's
    baseline and callers can form ``y - mu + 1.0`` to get a curve around 1.

    Parameters
    ----------
    t, y : array-like
        Times and (normalised) fluxes, same length.
    yerr : array-like or None
        Per-point uncertainties; a constant 1e-3 is assumed when None.
    mean_model : callable or None
        Function of a time array returning the mean flux; constant 1.0 when
        None.
    mask : array-like of bool or None
        True marks points used for training. Non-finite points are always
        excluded from training. None means "train on all finite points".

    Returns
    -------
    mu_full : numpy.ndarray
        ``mean_model + GP conditional mean`` at each finite ``t`` (NaN where
        ``t`` is not finite); same shape as ``t``.
    meta : dict
        Backend name, fitted log hyper-parameters and optimiser success flag.

    Raises
    ------
    ValueError
        If no finite training points remain.
    RuntimeError
        If scipy or both GP backends are unavailable.
    """
    if not HAS_SCIPY:
        raise RuntimeError("scipy is required for the GP hyper-parameter optimisation.")

    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    if yerr is None:
        yerr = 1e-3 * np.ones_like(y)
    else:
        yerr = np.asarray(yerr, dtype=float)

    if mean_model is None:
        def mean_model(times):
            return np.ones_like(times)

    finite = np.isfinite(t) & np.isfinite(y) & np.isfinite(yerr)
    train = finite if mask is None else (np.asarray(mask, dtype=bool) & finite)
    if not np.any(train):
        raise ValueError("No finite training points selected for the GP fit.")

    tt = t[train]
    yy = y[train]
    ee = yerr[train]

    # Predict wherever the time stamp itself is usable, not only where we trained.
    predict_at = np.isfinite(t)

    # stabilize large MJDs
    t_ref = np.nanmin(tt)
    x = tt - t_ref
    x_pred = t[predict_at] - t_ref

    # The mean model has no free parameters here, so the residuals are
    # constant across likelihood evaluations.
    resid = yy - mean_model(tt)

    def _assemble(gp_mean):
        # Add the mean model back so the curve shares the data's baseline.
        mu_full = np.full_like(t, np.nan, dtype=float)
        mu_full[predict_at] = mean_model(t[predict_at]) + gp_mean
        return mu_full

    backend = GP_BACKEND
    if backend == "celerite2" and not HAS_CELERITE2 and HAS_GEORGE:
        backend = "george"
    if backend == "george" and not HAS_GEORGE and HAS_CELERITE2:
        backend = "celerite2"

    if backend == "celerite2" and HAS_CELERITE2:
        def _build_celerite(p):
            logS0, logw0, logjit = p
            kernel = terms.SHOTerm(S0=np.exp(logS0), w0=np.exp(logw0), Q=GP_Q_FIXED)
            gp = GaussianProcess(kernel, mean=0.0)
            gp.compute(x, yerr=np.sqrt(ee**2 + np.exp(logjit)**2))
            return gp

        def nll(p):
            return -_build_celerite(p).log_likelihood(resid)

        p0 = np.array([GP_INIT_LOGS0, GP_INIT_LOGW0, GP_INIT_LOGJITTER], dtype=float)
        sol = minimize(nll, p0, method="L-BFGS-B")
        logS0, logw0, logjit = sol.x
        gp = _build_celerite(sol.x)
        # NOTE(review): x_pred keeps the ordering of t; this assumes the input
        # times are sorted, which compute() already requires of x.
        gp_mean = gp.predict(resid, t=x_pred, return_cov=False)
        meta = {"backend":"celerite2","logS0":float(logS0),"logw0":float(logw0),"logjitter":float(logjit),"success":bool(sol.success)}
        return _assemble(gp_mean), meta

    if backend == "george" and HAS_GEORGE:
        # Simple ExpSquared kernel as a generic systematics model
        def _build_george(p):
            log_amp, log_tau, logjit = p
            k = np.exp(log_amp) * kernels.ExpSquaredKernel(np.exp(log_tau)**2)
            gp = george.GP(k)
            gp.compute(x, np.sqrt(ee**2 + np.exp(logjit)**2))
            return gp

        def nll(p):
            return -_build_george(p).log_likelihood(resid)

        p0 = np.array([np.log(1e-3), np.log(0.1), GP_INIT_LOGJITTER], dtype=float)
        sol = minimize(nll, p0, method="L-BFGS-B")
        log_amp, log_tau, logjit = sol.x
        gp = _build_george(sol.x)
        # Only the mean is needed, so skip the full predictive covariance.
        gp_mean = gp.predict(resid, x_pred, return_cov=False)
        meta = {"backend":"george","log_amp":float(log_amp),"log_tau":float(log_tau),"logjitter":float(logjit),"success":bool(sol.success)}
        return _assemble(gp_mean), meta

    raise RuntimeError("No GP backend available. Install celerite2 or george.")

# Fit GP systematics using out-of-transit points (TRANSIT_MASK).
# mean model is just 1.0 baseline here; we will fit transit separately later.
# Best-effort: any failure is reported and leaves mu_sys = None, which the
# following cells treat as "no GP correction".
try:
    mu_sys, gp_meta = gp_systematics_fit(t_clean, y_clean, yerr=(dy_clean if dy_clean is not None else None),
                                         mean_model=lambda tt: np.ones_like(tt),
                                         mask=TRANSIT_MASK)
    print("GP meta:", gp_meta)
except Exception as e:
    mu_sys, gp_meta = None, None
    print("GP fit skipped / failed:", repr(e))

# y_gp_corrected is only defined when the GP fit succeeded; later cells check
# mu_sys before using it.
if mu_sys is not None:
    y_gp_corrected = y_clean - mu_sys + 1.0  # remove systematics, keep around baseline 1
    # Figure 1: cleaned data with the GP systematics curve overlaid.
    plt.figure(figsize=(10,4))
    plt.plot(t_clean, y_clean, ".", ms=3, label="clean")
    plt.plot(t_clean, mu_sys, "-", lw=2, label="GP systematics")
    plt.xlabel("Time")
    plt.ylabel("Flux (normalized)")
    plt.title("GP systematics fit (out-of-transit)")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Figure 2: the corrected light curve with a reference line at 1.0.
    plt.figure(figsize=(10,4))
    plt.plot(t_clean, y_gp_corrected, ".", ms=3)
    plt.axhline(1.0, ls="--")
    plt.xlabel("Time")
    plt.ylabel("Flux (GP-corrected)")
    plt.title("Light curve after GP systematics correction")
    plt.tight_layout()
    plt.show()


In [None]:
# -----------------------------
# Lomb–Scargle on GP-corrected residuals (optional variability search)
# -----------------------------
# Use the GP-corrected curve when available, otherwise the cleaned curve.
yy = y_gp_corrected if (mu_sys is not None) else y_clean
# Drop non-finite samples and remove the median before the periodogram.
ok = np.isfinite(t_clean) & np.isfinite(yy)
tt = t_clean[ok]
ff = yy[ok] - np.nanmedian(yy[ok])

# If MJD-like, subtract for stability
tref = np.nanmin(tt)
x = tt - tref

# Uniform frequency grid covering periods between min_period and max_period
# (same units as the time axis).
min_period = 0.05
max_period = 10.0
freq = np.linspace(1/max_period, 1/min_period, 20000)

ls = LombScargle(x, ff)
power = ls.power(freq)
# Report the period at the highest periodogram peak on the grid.
best_f = freq[np.argmax(power)]
best_period = 1/best_f

print("Best LS period ~", best_period, "days")

plt.figure(figsize=(10,4))
plt.plot(1/freq, power)
plt.gca().invert_xaxis()
plt.xlabel("Period (days)")
plt.ylabel("LS Power")
plt.title("Lomb–Scargle Periodogram (GP-corrected residuals)")
plt.tight_layout()
plt.show()


In [None]:
# -----------------------------
# Transit model with batman (white-light fit)
# We'll fit a simple transit to the GP-corrected light curve.
# -----------------------------
# White-light transit fit. Defines t0_w, rp_w, a_w, inc_w, c0_w and
# p_shared_fit at module level; the transmission-spectrum cell reuses them.
if not HAS_BATMAN or not HAS_SCIPY:
    print("Need batman-package and scipy to fit transits. Skipping.")
else:
    tfit = np.asarray(t_clean, dtype=float)
    yfit = np.asarray(y_gp_corrected if mu_sys is not None else y_clean, dtype=float)
    # Fall back to a constant 1e-3 uncertainty when no errors are available.
    efit = (np.asarray(dy_clean, dtype=float) if dy_clean is not None else 1e-3*np.ones_like(yfit))
    # NOTE(review): no finite-value filtering is applied to yfit/efit here; any
    # NaN in the input curve propagates into the residuals.

    # Use days and stabilize time
    t0ref = np.nanmin(tfit)
    x = tfit - t0ref

    def batman_model(x, p_shared, rp, u=LD_U):
        """Evaluate a batman light curve on the shifted time axis ``x``.

        ``p_shared`` supplies t0 (absolute; shifted by t0ref here), per, a,
        inc and optionally ecc/w; ``rp`` is the radius ratio and ``u`` the
        limb-darkening coefficients for LD_LAW.
        """
        params = batman.TransitParams()
        params.t0 = float(p_shared["t0"]) - t0ref
        params.per = float(p_shared["per"])
        params.rp  = float(rp)
        params.a   = float(p_shared["a"])
        params.inc = float(p_shared["inc"])
        params.ecc = float(p_shared.get("ecc", 0.0))
        params.w   = float(p_shared.get("w", 90.0))
        params.limb_dark = LD_LAW
        params.u = list(u)

        # Exposure-time integration only when both settings request it.
        if EXP_TIME_DAYS and EXP_TIME_DAYS > 0 and SUPERSAMPLE > 1:
            m = batman.TransitModel(params, x, supersample_factor=SUPERSAMPLE, exp_time=EXP_TIME_DAYS)
        else:
            m = batman.TransitModel(params, x)
        return m.light_curve(params)

    # Initial guess
    p_shared = dict(TRANSIT_INIT)
    if p_shared["t0"] is None:
        p_shared["t0"] = float(np.nanmedian(tfit))
    # per is often known; set TRANSIT_INIT['per'] accordingly for real data

    def residuals_white(theta):
        """Error-weighted residuals for least_squares; period/ecc/w stay fixed."""
        # theta = [t0, rp, a, inc, c0]
        t0, rp, a, inc, c0 = theta
        ps = dict(p_shared)
        ps["t0"], ps["a"], ps["inc"] = t0, a, inc
        # c0 is a multiplicative baseline on the transit model.
        model = c0 * batman_model(x, ps, rp)
        return (yfit - model) / efit

    # NOTE(review): FIT_SHARED is not consulted; t0, a and inc are always free.
    theta0 = np.array([p_shared["t0"], p_shared["rp"], p_shared["a"], p_shared["inc"], 1.0], dtype=float)
    # mild bounds (order matches theta: t0, rp, a, inc, c0)
    lo = [theta0[0]-1.0, 0.001, 1.0, 60.0, 0.8]
    hi = [theta0[0]+1.0, 0.5, 200.0, 90.0, 1.2]

    sol = least_squares(residuals_white, theta0, bounds=(lo,hi))
    t0_w, rp_w, a_w, inc_w, c0_w = sol.x

    print("White-light fit:")
    print(" t0 :", t0_w)
    print(" rp :", rp_w)
    print(" a  :", a_w)
    print(" inc:", inc_w)
    print(" c0 :", c0_w)

    # Shared parameters with the fitted values substituted in.
    p_shared_fit = dict(p_shared)
    p_shared_fit.update({"t0": float(t0_w), "a": float(a_w), "inc": float(inc_w)})

    m_white = c0_w * batman_model(x, p_shared_fit, rp_w)

    plt.figure(figsize=(10,4))
    plt.plot(tfit, yfit, ".", ms=3, label="data")
    plt.plot(tfit, m_white, "-", lw=2, label="batman fit")
    plt.xlabel("Time")
    plt.ylabel("Flux (GP-corrected, normalized)")
    plt.title("White-light transit fit")
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Save white-light params
    import pandas as pd
    pd.DataFrame([{
        "t0": t0_w, "per": p_shared_fit["per"], "rp": rp_w, "a": a_w, "inc": inc_w,
        "c0": c0_w, "gp_backend": (gp_meta.get("backend") if gp_meta else None)
    }]).to_csv(OUT_DIR/"white_light_fit.csv", index=False)
    print("Saved:", OUT_DIR/"white_light_fit.csv")


In [None]:
# -----------------------------
# Build wavelength-binned light curves (X1DINTS only)
# Then: common-mode correction using GP systematics from white-light,
# and a JOINT transit fit across bins to estimate Rp/Rs(λ) (transmission spectrum).
# -----------------------------
if meta.get("source") != "X1DINTS":
    print("Not an X1DINTS dataset; skipping transmission spectrum steps.")
elif not (HAS_BATMAN and HAS_SCIPY):
    print("Need batman + scipy. Skipping.")
else:
    wave = np.asarray(extra["wave"], dtype=float)
    flux2d = np.asarray(extra["flux2d"], dtype=float)  # (nint, nwave)
    # Restrict to clean indices (mask_ok computed earlier on white-light)
    idx_all = np.arange(len(t))
    idx_clean = idx_all[mask_ok]
    t_bin = t_clean
    x = t_bin - np.nanmin(t_bin)

    flux2d_c = flux2d[idx_clean, :]

    # Define wavelength bins
    wmin, wmax = WL_RANGE_MICRON
    if wmin is None: wmin = float(np.nanmin(wave))
    if wmax is None: wmax = float(np.nanmax(wave))
    wmask = (wave >= wmin) & (wave <= wmax)
    w_use = wave[wmask]

    edges = np.linspace(np.nanmin(w_use), np.nanmax(w_use), NBINS+1)
    centers = 0.5*(edges[:-1] + edges[1:])

    # White-light GP systematics vector on clean points (mu_sys)
    # If no GP, use zeros (no correction)
    sys = mu_sys if mu_sys is not None else np.zeros_like(t_clean)

    # Construct binned light curves
    Y = []
    for i in range(NBINS):
        m = (wave >= edges[i]) & (wave < edges[i+1])
        m = m & wmask
        if not np.any(m):
            Y.append(np.full_like(t_bin, np.nan))
            continue
        lc = np.nansum(flux2d_c[:, m], axis=1)
        lc /= np.nanmedian(lc)
        # common-mode systematics correction (subtract systematics and re-add baseline)
        lc_corr = lc - sys + 1.0
        Y.append(lc_corr)

    Y = np.array(Y)  # (nbin, ntime)

    # Simple per-point error proxy
    E = 1e-3*np.ones_like(Y)

    # Joint fit parameters:
    # shared: (t0, [optional per,a,inc]) and per-bin rp and per-bin baseline c0
    # We'll use the white-light fitted shared parameters if available.
    # If previous cell wasn't run, fall back to TRANSIT_INIT.
    try:
        t0_shared = float(t0_w)
        per_shared = float(p_shared_fit["per"])
        a_shared = float(a_w)
        inc_shared = float(inc_w)
    except Exception:
        t0_shared = float(TRANSIT_INIT["t0"])
        per_shared = float(TRANSIT_INIT["per"])
        a_shared = float(TRANSIT_INIT["a"])
        inc_shared = float(TRANSIT_INIT["inc"])

    # Build batman model callable
    def batman_curve(x, t0_abs, per, rp, a, inc, u=LD_U):
        """Evaluate a batman transit light curve on the time grid ``x``.

        ``t0_abs`` is shifted by ``np.nanmin(t_bin)`` before being handed to
        batman, so ``x`` is expected to be measured from that same reference
        (NOTE(review): confirm against where ``x`` is built). ``per``, ``rp``,
        ``a`` and ``inc`` are copied straight onto ``batman.TransitParams``.
        Eccentricity and ``w`` come from ``TRANSIT_INIT`` (defaults 0.0 and
        90.0), the limb-darkening law from ``LD_LAW`` and its coefficients
        from ``u``.

        Returns whatever ``TransitModel.light_curve`` produces for ``x``.
        """
        tp = batman.TransitParams()
        tp.t0 = t0_abs - np.nanmin(t_bin)  # in the same shifted system as x
        tp.per, tp.rp, tp.a, tp.inc = per, rp, a, inc
        tp.ecc = float(TRANSIT_INIT.get("ecc", 0.0))
        tp.w = float(TRANSIT_INIT.get("w", 90.0))
        tp.limb_dark = LD_LAW
        tp.u = list(u)

        # Exposure-time supersampling is requested only when an exposure time
        # is configured and the supersampling factor exceeds 1.
        extra = {}
        if EXP_TIME_DAYS and EXP_TIME_DAYS > 0 and SUPERSAMPLE > 1:
            extra = {"supersample_factor": SUPERSAMPLE, "exp_time": EXP_TIME_DAYS}
        transit = batman.TransitModel(tp, x, **extra)
        return transit.light_curve(tp)

    # Parameter packing
    # theta = [t0] + [rp_i]*NBINS + [c0_i]*NBINS
    def pack(theta_t0, rps, c0s):
        """Flatten ``(t0, rps, c0s)`` into one vector: t0 first, then rps, then c0s."""
        head = [theta_t0]
        radii = np.asarray(rps)
        baselines = np.asarray(c0s)
        return np.concatenate([head, radii, baselines])

    def unpack(theta):
        """Split a flat parameter vector into ``(t0, rps, c0s)``.

        Element 0 is t0, the next ``NBINS`` entries are the per-bin rp values
        and the following ``NBINS`` entries are the per-bin baselines.
        """
        vec = np.asarray(theta, dtype=float)
        rp_end = 1 + NBINS
        return vec[0], vec[1:rp_end], vec[rp_end:rp_end + NBINS]

    # Starting point: every bin begins at the white-light rp when that fit
    # exists in the notebook namespace, otherwise at TRANSIT_INIT["rp"];
    # every baseline begins at 1.
    if "rp_w" in globals():
        rp0 = np.full(NBINS, float(rp_w))
    else:
        rp0 = np.full(NBINS, float(TRANSIT_INIT["rp"]))
    c00 = np.full(NBINS, 1.0)
    theta0 = pack(t0_shared, rp0, c00)

    # Box constraints: t0 within +/-0.2 of the shared value, rp in
    # [0.001, 0.5], baseline in [0.8, 1.2].
    lo = pack(t0_shared - 0.2, np.full(NBINS, 0.001), np.full(NBINS, 0.8))
    hi = pack(t0_shared + 0.2, np.full(NBINS, 0.5), np.full(NBINS, 1.2))

    def resid_joint(theta):
        """Return the error-weighted residuals of all bins as one flat vector.

        ``theta`` is the flat vector understood by ``unpack``. For each bin the
        model is that bin's baseline times ``batman_curve`` evaluated with the
        shared t0 and the bin's own rp. Only samples whose flux and error are
        both finite contribute. If no bin has usable samples, a single zero is
        returned so the optimiser still receives an array.
        """
        t0, rps, c0s = unpack(theta)
        pieces = []
        for row, (flux_i, err_i) in enumerate(zip(Y, E)):
            good = np.isfinite(flux_i) & np.isfinite(err_i)
            if good.any():
                curve = batman_curve(x[good], t0, per_shared, rps[row], a_shared, inc_shared)
                pieces.append((flux_i[good] - c0s[row] * curve) / err_i[good])
        if not pieces:
            return np.array([0.0])
        return np.concatenate(pieces)

    # Bounded non-linear least squares over all bins at once
    # (shared t0, per-bin rp and per-bin baseline).
    sol = least_squares(resid_joint, theta0, bounds=(lo, hi))
    t0_j, rps_j, c0s_j = unpack(sol.x)

    # Uncertainty estimate (rough): diagonal from (J^T J)^-1 scaled by residual variance.
    # rp_err stays NaN if the estimate cannot be computed.
    rp_err = np.full(NBINS, np.nan)
    try:
        J = sol.jac
        # covariance ~ s^2 * (J^T J)^-1, built from the SVD of J so that
        # near-singular directions can be discarded instead of blowing up.
        _, svals, VT = np.linalg.svd(J, full_matrices=False)
        tol = np.finfo(float).eps * max(J.shape) * svals[0]
        svals = svals[svals > tol]
        VT = VT[:len(svals)]
        cov = (VT.T / (svals**2)) @ VT
        dof = max(1, len(sol.fun) - len(sol.x))
        # Residual variance (reduced chi-square) is sum(fun^2) / dof.
        # least_squares reports cost = 0.5 * sum(fun^2), so this equals
        # 2 * sol.cost / dof; multiplying the raw sum of squares by 2 would
        # double the variance and inflate every error bar by sqrt(2).
        s2 = np.sum(sol.fun**2) / dof
        cov *= s2
        # theta layout is [t0, rp_0..rp_{NBINS-1}, c0_0..]; pick the rp block.
        rp_err = np.sqrt(np.diag(cov)[1:1+NBINS])
    except Exception as e:
        # Best effort only: the fit itself is still usable without error bars.
        print("Could not estimate uncertainties:", repr(e))

    print("Joint fit done. t0 =", t0_j, "| per fixed =", per_shared)

    # Stacked diagnostic plot: each bin's data and its best-fit model, shifted
    # upwards by a constant step so the curves do not overlap.
    plt.figure(figsize=(10, 6))
    offset = 0.0
    for i, yi in enumerate(Y):
        ok = np.isfinite(yi)
        if ok.any():
            t_ok, y_ok = t_bin[ok], yi[ok]
            model = c0s_j[i] * batman_curve(x[ok], t0_j, per_shared, rps_j[i], a_shared, inc_shared)
            plt.plot(t_ok, y_ok + offset, ".", ms=2)
            plt.plot(t_ok, model + offset, "-", lw=1.5)
            plt.text(t_ok[0], y_ok[0] + offset, f"{centers[i]:.2f} μm", fontsize=8, va="bottom")
            # Only bins that were actually drawn take up a vertical slot.
            offset += 0.03
    plt.xlabel("Time")
    plt.ylabel("Flux (stacked)")
    plt.title("Spectroscopic bins: data + joint transit fits (after GP common-mode correction)")
    plt.tight_layout()
    plt.show()

    # Transmission spectrum: depth = (Rp/Rs)^2, with the rp error propagated
    # to first order as d(depth) = 2 * rp * d(rp).
    depth = np.square(rps_j)
    depth_err = 2*rps_j*rp_err

    import pandas as pd
    columns = {
        "wl_center_micron": centers,
        "rp_rs": rps_j,
        "rp_rs_err": rp_err,
        "depth": depth,
        "depth_err": depth_err,
    }
    spec = pd.DataFrame(columns)
    spec_path = OUT_DIR / "transmission_spectrum.csv"
    spec.to_csv(spec_path, index=False)
    print("Saved:", spec_path)

    # Two views of the same spectrum: radius ratio first, then transit depth.
    for yvals, yerrs, ylabel, title in (
        (rps_j, rp_err, "Rp/Rs", "Transmission spectrum (Rp/Rs vs wavelength)"),
        (depth, depth_err, "Transit depth (Rp/Rs)^2", "Transmission spectrum (depth)"),
    ):
        plt.figure(figsize=(8, 4))
        plt.errorbar(centers, yvals, yerr=yerrs, fmt="o")
        plt.xlabel("Wavelength (micron)")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.tight_layout()
        plt.show()


In [None]:
# -----------------------------
# Optional: BLS transit search (useful if ephemeris unknown; best on white-light)
# -----------------------------
if not HAS_BLS:
    print("BLS not available in this astropy; skipping.")
else:
    # Use GP-corrected flux when a GP systematics model exists, otherwise the
    # cleaned flux as-is.
    y_use = y_gp_corrected if mu_sys is not None else y_clean
    ok = np.isfinite(t_clean) & np.isfinite(y_use)
    tt = t_clean[ok]
    yy = y_use[ok]
    # Fall back to a flat 1e-3 error when no per-point uncertainties exist.
    ee = (dy_clean[ok] if dy_clean is not None else 1e-3*np.ones_like(yy))

    # Work in time relative to the first valid sample.
    tref = np.nanmin(tt)
    x = tt - tref

    periods = np.linspace(0.2, 10.0, 4000)
    durations = np.linspace(0.005, 0.2, 25)  # days
    # BoxLeastSquares rejects a grid whose longest trial duration is not
    # strictly shorter than the shortest trial period ("The maximum transit
    # duration must be shorter than the minimum period"). The raw grids meet
    # exactly at 0.2 d, so drop any duration that reaches the minimum period.
    durations = durations[durations < periods.min()]

    bls = BoxLeastSquares(x, yy, dy=ee)
    res = bls.power(periods, durations)
    # Highest-power trial period and its associated duration / transit time.
    i_best = np.argmax(res.power)
    p_best = res.period[i_best]
    d_best = res.duration[i_best]
    t0_best = res.transit_time[i_best]

    print("Best BLS period:", float(p_best), "days | duration:", float(d_best), "days")

    plt.figure(figsize=(10,4))
    plt.plot(res.period, res.power)
    plt.xlabel("Period (days)")
    plt.ylabel("BLS Power")
    plt.title("BLS Periodogram")
    plt.tight_layout()
    plt.show()

    # Fold on the best period with the transit centred at phase 0
    # (phase runs from -0.5 to 0.5).
    phase = ((x - t0_best + 0.5*p_best) % p_best) / p_best - 0.5
    order = np.argsort(phase)
    plt.figure(figsize=(8,4))
    plt.plot(phase[order], yy[order], ".", ms=2)
    # Shade the best-fit transit duration, expressed in phase units.
    plt.axvspan(-0.5*float(d_best/p_best), 0.5*float(d_best/p_best), alpha=0.2)
    plt.xlabel("Phase (centered)")
    plt.ylabel("Flux")
    plt.title("Folded at best BLS period")
    plt.tight_layout()
    plt.show()


## Practical upgrades you can add next

- Replace the simple uncertainty model with JWST-provided variance arrays (`VAR_POISSON`, `VAR_RNOISE`) where available in `rateints`.  
- Fit **GP + transit simultaneously** (instead of common-mode GP then transit), using MCMC (e.g., `emcee`) if you want robust uncertainties.  
- Use mode-specific limb darkening (e.g., `ExoTiC-LD`, `ldtk`, or tables) for more realistic transmission spectra.
