# JWST Time-Series Analysis Lab (MAST + astroquery)

This notebook downloads **public JWST time-series–related products** from **MAST** using `astroquery.mast`, then builds a **light curve** and runs:
- robust cleaning + detrending  
- variability metrics  
- Lomb–Scargle periodogram  
- (optional) Box Least Squares (BLS) transit search  
- (optional) spectroscopic (multi-wavelength) light curves, produced only when the light curve was loaded from `x1dints` (i.e. no `whtlt` file was found)

> Tip: JWST "TSO" (Time-Series Observation) programs often have `whtlt` (white-light time series) and/or `x1dints` (integration-resolved 1D spectra).  


In [None]:
# If you're missing dependencies, uncomment and run
# (pandas is needed by the CSV export in the variability-metrics cell):
# !pip -q install astroquery astropy numpy matplotlib scipy pandas

print("Ready.")

In [None]:
# -----------------------------
# Imports
# -----------------------------
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

from astropy.io import fits
import astropy.units as u
from astropy.stats import sigma_clip
from astropy.timeseries import LombScargle

# Optional: BLS transit search
try:
    from astropy.timeseries import BoxLeastSquares
    HAS_BLS = True
except Exception:
    HAS_BLS = False

# Optional: SciPy detrending helpers
try:
    import scipy
    from scipy.signal import savgol_filter, medfilt
    HAS_SCIPY = True
except Exception:
    HAS_SCIPY = False

from astroquery.mast import Observations

print("HAS_SCIPY:", HAS_SCIPY, "| HAS_BLS:", HAS_BLS)


In [None]:
# -----------------------------
# User configuration
# -----------------------------
# Root folder for everything this notebook writes; created if it does not exist.
OUT_DIR = Path("jwst_timeseries_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Files fetched from MAST are stored below the output folder.
DOWNLOAD_DIR = OUT_DIR / "mast_downloads"
DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)

# Pick ONE of these entrypoints:
# - PROGRAM_ID set    -> query_jwst() searches by JWST proposal ID and ignores TARGET_NAME
# - PROGRAM_ID = None -> query_jwst() searches around TARGET_NAME within SEARCH_RADIUS
TARGET_NAME = "WASP-39"     # typical JWST exoplanet TSO target (public data exists, but not guaranteed)
SEARCH_RADIUS = 0.03 * u.deg

PROGRAM_ID = None          # e.g. 1366, 2736, ... (set to an int/str to search by JWST program)

# After we query, we pick an observation row by index
# (the table is sorted by t_min first when that column is present):
OBS_INDEX = 0

# Product subgroups to keep. The loader cell uses WHTLT first, then X1DINTS;
# CALINTS/RATEINTS have no loader in this notebook.
PREFERRED_SUBGROUPS = ["WHTLT", "X1DINTS", "CALINTS", "RATEINTS"]

# If using X1DINTS spectroscopic time-series, integrate over this wavelength range (micron).
# Keep this a (min, max) tuple; use None for either bound to auto-use the full coverage on that side.
WL_RANGE_MICRON = (None, None)  # e.g. (1.1, 1.7)

# Detrending
DETREND_MODE = "median"   # "median", "savgol", "none"
MED_WIN = 31              # odd (an even value is bumped up by one)
SAVGOL_WIN = 51           # odd (an even value is bumped up by one)
SAVGOL_POLY = 2

# Period search windows (in days).
# NOTE: this assumes the time axis is in days (e.g. MJD). When the X1DINTS loader
# falls back to integration indices, these limits are in index units instead.
MIN_PERIOD = 0.05
MAX_PERIOD = 5.0
N_FREQ = 20000

print("Output dir:", OUT_DIR.resolve())


In [None]:
# -----------------------------
# Query MAST for JWST observations
# -----------------------------
def query_jwst():
    """Query MAST for JWST observations using the notebook configuration.

    With PROGRAM_ID unset, a positional search around TARGET_NAME is run and
    the result is reduced to rows whose obs_collection is "JWST". With
    PROGRAM_ID set, the search is by proposal ID (first as a string, then as
    an int if the string query raises).

    Returns
    -------
    The table returned by astroquery's ``Observations`` query.
    """
    if PROGRAM_ID is None:
        # Search around a target name (name resolution happens server-side for query_object)
        print(f"Querying JWST near TARGET_NAME='{TARGET_NAME}' (radius={SEARCH_RADIUS}) ...")
        found = Observations.query_object(TARGET_NAME, radius=SEARCH_RADIUS)
        # query_object is not mission-specific, so keep only the JWST rows.
        if "obs_collection" in found.colnames:
            is_jwst = found["obs_collection"] == "JWST"
            found = found[is_jwst]
        return found

    # Search by JWST program/proposal ID
    print(f"Querying JWST by PROGRAM_ID={PROGRAM_ID} ...")
    try:
        return Observations.query_criteria(obs_collection="JWST", proposal_id=str(PROGRAM_ID))
    except Exception as e:
        print("query_criteria with proposal_id failed:", repr(e))
        print("Trying proposal_id as int ...")
    return Observations.query_criteria(obs_collection="JWST", proposal_id=int(PROGRAM_ID))

# Run the query once; the following cells reuse `obs`.
obs = query_jwst()
print("Total observations returned:", len(obs))
# Trailing expression: the notebook renders the first five rows.
obs[:5]


In [None]:
# -----------------------------
# Inspect + pick an observation
# -----------------------------
# Fail early, with a hint on how to widen the search, when the query came back empty.
if not len(obs):
    raise RuntimeError("No observations found. Try a different TARGET_NAME, larger SEARCH_RADIUS, or set PROGRAM_ID.")

# Keep only the summary columns that this result table actually provides.
cols = [
    name
    for name in ("obs_id", "target_name", "instrument_name", "filters",
                 "dataproduct_type", "t_min", "t_max", "proposal_id")
    if name in obs.colnames
]
print("Preview columns:", cols)

# Chronological order makes OBS_INDEX refer to "the n-th earliest observation".
if "t_min" in obs.colnames:
    time_order = np.argsort(obs["t_min"])
    obs = obs[time_order]

# Show at most 20 rows, restricted to the summary columns when any were found.
if cols:
    preview = obs[cols][:20]
else:
    preview = obs[:20]
preview


In [None]:
# -----------------------------
# Get product list + filter to time-series–useful products
# -----------------------------
# Index with a one-element list so the selection stays a table.
obs_row = obs[[OBS_INDEX]]
if "obs_id" in obs_row.colnames:
    print("Selected obs_id:", obs_row["obs_id"][0])
else:
    print("Selected obs_id:", "(unknown)")

products = Observations.get_product_list(obs_row)
print("Products:", len(products))
print("Columns:", products.colnames)

if "productSubGroupDescription" in products.colnames:
    # Show which product subgroups exist for this observation.
    uniq = sorted({str(x).upper() for x in products["productSubGroupDescription"]})
    print("Unique productSubGroupDescription (first 40):")
    print(uniq[:40])

    # Upper-case the column so it compares equal to the (upper-case) preferred list,
    # then keep only the preferred subgroups.
    products["productSubGroupDescription"] = [str(x).upper() for x in products["productSubGroupDescription"]]
    wanted = Observations.filter_products(
        products,
        mrp_only=False,
        productSubGroupDescription=PREFERRED_SUBGROUPS,
    )
else:
    # No subgroup column to filter on: keep every product.
    wanted = products

print("Filtered products:", len(wanted))
wanted[:10]


In [None]:
# -----------------------------
# Download selected products
# -----------------------------
if len(wanted) == 0:
    raise RuntimeError("No products matched PREFERRED_SUBGROUPS. Inspect 'products' and adjust PREFERRED_SUBGROUPS.")

# PREFERRED_SUBGROUPS is a preference order, so download only the best subgroup
# that is actually available. Downloading everything would also pull the
# CALINTS/RATEINTS image cubes, which can be very large and which no later cell reads
# when a WHTLT or X1DINTS product exists.
to_download = wanted
if "productSubGroupDescription" in wanted.colnames:
    available = np.array([str(s).upper() for s in wanted["productSubGroupDescription"]])
    for subgroup in PREFERRED_SUBGROUPS:
        is_subgroup = available == str(subgroup).upper()
        if is_subgroup.any():
            to_download = wanted[is_subgroup]
            print(f"Downloading subgroup {subgroup}: {len(to_download)} product(s)")
            break

manifest = Observations.download_products(
    to_download,
    download_dir=str(DOWNLOAD_DIR),
    cache=True
)

print("Downloaded rows:", len(manifest))
manifest[:10]


In [None]:
# -----------------------------
# Load a time series from downloaded files:
# - Prefer WHTLT (white-light time series) if present
# - Else use X1DINTS (spectra per integration) and integrate over wavelength to make a light curve
# - Otherwise only CALINTS/RATEINTS are left; these are usually image cubes (one frame per integration) and would need aperture photometry, which is not implemented here
# -----------------------------
from astropy.table import Table

def _find_local_paths(manifest):
    paths = []
    for p in manifest["Local Path"]:
        if p and str(p).strip().lower() != "none":
            paths.append(Path(p))
    return [p for p in paths if p.exists()]

# Resolve the manifest to files that are really on disk; later cells reuse `local_paths`.
local_paths = _find_local_paths(manifest)
print("Local files:", len(local_paths))
# Print only the first ten names to keep the output short.
for p in local_paths[:10]:
    print(" -", p.name)

def _is_subgroup(path, subgroup):
    return subgroup.lower() in path.name.lower()

# Bucket the downloaded files by the product keyword embedded in their names.
whtlt_files, x1dints_files, calints_files, rateints_files = (
    [p for p in local_paths if _is_subgroup(p, tag)]
    for tag in ("whtlt", "x1dints", "calints", "rateints")
)

# Report how many files of each kind are available.
counts = {
    "whtlt": len(whtlt_files),
    "x1dints": len(x1dints_files),
    "calints": len(calints_files),
    "rateints": len(rateints_files),
}
print("Found:", counts)

def load_whtlt(path):
    """Load a white-light time series (time, flux, optional error) from a WHTLT product.

    JWST white-light products are distributed as ECSV text tables, so files
    with an ``.ecsv`` suffix are read with ``Table.read``. Any other file is
    treated as FITS and the first binary table that has both a time-like and
    a flux-like column is used.

    Parameters
    ----------
    path : str or Path
        Location of the WHTLT file.

    Returns
    -------
    t, y : numpy arrays of float (time as stored in the file, flux)
    dy : numpy array of float, or None when no flux-error column exists
    meta : dict with the source tag, the chosen column names and the file path

    Raises
    ------
    RuntimeError
        If no table with a time-like and a flux-like column can be found.
    """
    path = Path(path)

    if path.suffix.lower() == ".ecsv":
        tab = Table.read(str(path), format="ascii.ecsv")
    else:
        with fits.open(path, memmap=False) as hdul:
            # Find first binary table with a plausible time+flux
            best = None
            for hdu in hdul:
                if not isinstance(hdu, fits.BinTableHDU):
                    continue
                cols = [c.upper() for c in hdu.columns.names]
                has_time = any(("TIME" in c or "MJD" in c or "BJD" in c) for c in cols)
                has_flux = any(("FLUX" in c and "ERR" not in c) for c in cols)
                if has_time and has_flux:
                    best = hdu
                    break
            if best is None:
                raise RuntimeError("Could not find a time/flux table in WHTLT file.")
            tab = Table(best.data)

    named = [(c, c.upper()) for c in tab.colnames]
    time_like = [c for c, cu in named if ("TIME" in cu or "MJD" in cu or "BJD" in cu)]
    flux_like = [c for c, cu in named if ("FLUX" in cu and "ERR" not in cu)]
    if not time_like or not flux_like:
        # Explicit error instead of a bare StopIteration from next().
        raise RuntimeError("Could not find a time/flux table in WHTLT file.")

    # Prefer an exact, conventional time column name; otherwise the first time-like column.
    tcol = next((c for c, cu in named if cu in ("TIME", "MJD", "BJD", "TMID", "T")), time_like[0])
    fcol = flux_like[0]
    ecol = next((c for c, cu in named if "ERR" in cu and "FLUX" in cu), None)

    t = np.asarray(tab[tcol], dtype=float)
    y = np.asarray(tab[fcol], dtype=float)
    dy = np.asarray(tab[ecol], dtype=float) if ecol is not None else None

    return t, y, dy, {"source": "WHTLT", "tcol": tcol, "fcol": fcol, "ecol": ecol, "file": str(path)}

def load_x1dints(path, wl_range_micron=(None, None)):
    """Build a white-light curve from an X1DINTS (per-integration 1D spectra) FITS file.

    Two table layouts are handled:
    * one table whose FLUX column is 2D (all integrations in one table), and
    * several EXTRACT1D tables with 1D FLUX columns (presumably one per
      integration); these are stacked into a 2D array. Only tables that share
      the first table's SPORDER header value and row count are stacked.

    Parameters
    ----------
    path : str or Path
        Location of the X1DINTS file.
    wl_range_micron : (float or None, float or None) or None
        Wavelength window to integrate over. ``None`` for a bound (or for the
        whole argument) means "use the full coverage on that side".

    Returns
    -------
    t : integration times from INT_TIMES, or 0..nint-1 when unavailable/inconsistent
    y : flux summed over the wavelength window, one value per integration
    dy : quadrature-summed flux error per integration, or None
    meta : dict describing the source file, window and array sizes
    w1 : 1D wavelength grid
    f2 : flux array shaped (nint, nwave)
    err : error array aligned with `f2` when that is possible, else as read (or None)

    Raises
    ------
    RuntimeError
        If no usable spectral table exists, the file holds a single 1D
        spectrum, or the wavelength window selects nothing.
    """
    if wl_range_micron is None:
        wl_range_micron = (None, None)

    with fits.open(path, memmap=False) as hdul:
        tables = [hdu for hdu in hdul if isinstance(hdu, fits.BinTableHDU)]

        # Find EXTRACT1D; fallback: first table with FLUX + WAVELENGTH
        spec_hdus = [hdu for hdu in tables if hdu.name.upper() == "EXTRACT1D"]
        if not spec_hdus:
            for hdu in tables:
                cols = [c.upper() for c in hdu.columns.names]
                if "FLUX" in cols and "WAVELENGTH" in cols:
                    spec_hdus = [hdu]
                    break
        if not spec_hdus:
            raise RuntimeError("Could not find EXTRACT1D-like table with FLUX/WAVELENGTH in X1DINTS.")

        first = spec_hdus[0]
        tab = Table(first.data)
        # Columns vary by mode; try common names
        wcol = next((c for c in tab.colnames if c.upper() == "WAVELENGTH"), None)
        fcol = next((c for c in tab.colnames if c.upper() == "FLUX"), None)
        if wcol is None or fcol is None:
            raise RuntimeError("Could not find EXTRACT1D-like table with FLUX/WAVELENGTH in X1DINTS.")
        ecol = next(
            (c for c in tab.colnames if c.upper() in ("ERROR", "FLUX_ERROR", "ERR", "FLUXERR")),
            None,
        )

        wave = np.asarray(tab[wcol])
        flux = np.asarray(tab[fcol])
        err = np.asarray(tab[ecol]) if ecol is not None else None

        if flux.ndim == 1 and len(spec_hdus) > 1:
            # One spectrum per table: stack the tables into (nint, nwave).
            # Restrict to tables that match the first one so that different
            # spectral orders or lengths are never mixed.
            order0 = first.header.get("SPORDER")
            npts = len(flux)
            per_int = [
                hdu for hdu in spec_hdus
                if hdu.header.get("SPORDER") == order0
                and hdu.header.get("NAXIS2") == npts
                and fcol in hdu.columns.names
                and (ecol is None or ecol in hdu.columns.names)
            ]
            if len(per_int) > 1:
                flux = np.vstack([np.asarray(hdu.data[fcol], dtype=float) for hdu in per_int])
                if ecol is not None:
                    err = np.vstack([np.asarray(hdu.data[ecol], dtype=float) for hdu in per_int])

        # Try to read integration mid-times from INT_TIMES extension if present
        t = None
        for hdu in tables:
            if "INT_TIMES" not in hdu.name.upper():
                continue
            ttab = Table(hdu.data)
            # Prefer a mid-time in MJD, then any other mid-time column.
            for wanted in ("mjd", "time"):
                for cname in ttab.colnames:
                    cu = cname.lower()
                    if "mid" in cu and wanted in cu:
                        t = np.asarray(ttab[cname], dtype=float)
                        break
                if t is not None:
                    break
            break  # only the first INT_TIMES table is consulted

    if flux.ndim == 1:
        # A single 1D spectrum has no time axis.
        raise RuntimeError("FLUX is 1D in this file; not integration-resolved. Try another observation/product.")

    # Normalize orientation: we want flux shaped (nint, nwave).
    if wave.ndim == 1:
        w1 = wave
        if flux.shape[1] == wave.shape[0]:
            f2 = flux
        elif flux.shape[0] == wave.shape[0]:
            f2 = flux.T
        else:
            # worst-case: flatten trailing axes and use a placeholder wavelength grid
            f2 = flux.reshape(flux.shape[0], -1)
            w1 = np.linspace(0, 1, f2.shape[1])
    else:
        # wave is 2D (or higher): take the first row as the wavelength grid
        w1 = wave[0]
        f2 = flux if flux.shape[1] == len(w1) else flux.T

    nint, nw = f2.shape

    if t is not None and len(t) != nint:
        # Times that do not line up with the integrations would silently
        # mislabel every point, so fall back to the integration index.
        print(f"INT_TIMES has {len(t)} rows but {nint} integrations were read; using integration index as time.")
        t = None
    if t is None:
        t = np.arange(nint, dtype=float)  # fallback: index

    # wavelength filtering
    wmin, wmax = wl_range_micron
    if wmin is None:
        wmin = np.nanmin(w1)
    if wmax is None:
        wmax = np.nanmax(w1)
    mask = (w1 >= wmin) & (w1 <= wmax)
    if not np.any(mask):
        raise RuntimeError("Wavelength mask is empty; adjust WL_RANGE_MICRON.")

    # White-light: sum over wavelength bins (ignore NaNs)
    y = np.nansum(f2[:, mask], axis=1)

    dy = None
    if err is not None:
        e2 = np.asarray(err)
        if e2.ndim == 2:
            if e2.shape != f2.shape:
                e2 = e2.T
            if e2.shape == f2.shape:
                err = e2  # hand back errors in the same orientation as f2
                dy = np.sqrt(np.nansum((e2[:, mask])**2, axis=1))

    meta = {
        "source": "X1DINTS",
        "file": str(path),
        "wmin": float(wmin),
        "wmax": float(wmax),
        "nint": int(nint),
        "nwave": int(nw),
    }
    return t, y, dy, meta, w1, f2, err

# Pick the most direct product: a ready-made white-light curve first,
# otherwise integrate the per-integration spectra.
if whtlt_files:
    t, y, dy, meta = load_whtlt(whtlt_files[0])
    # No spectral cube in this case.
    wave = None
    flux2d = None
    err2d = None
elif x1dints_files:
    t, y, dy, meta, wave, flux2d, err2d = load_x1dints(x1dints_files[0], WL_RANGE_MICRON)
else:
    raise RuntimeError("No WHTLT or X1DINTS found. Try another observation, or adjust filters/subgroups.")

# Quick sanity summary of what was loaded.
print("Loaded:", meta)
t_lo = float(np.nanmin(t))
t_hi = float(np.nanmax(t))
print("t range:", t_lo, "to", t_hi, "| N =", len(t))
print("y median:", float(np.nanmedian(y)))


In [None]:
# -----------------------------
# Plot raw light curve
# -----------------------------
# Work with plain float arrays from here on.
t = np.asarray(t, dtype=float)
y = np.asarray(y, dtype=float)
if dy is not None:
    dy = np.asarray(dy, dtype=float)

# Basic normalization: scale by the median so the light curve sits around 1.
y0 = np.nanmedian(y)
yn = y / y0

fig, ax = plt.subplots(figsize=(10,4))
ax.plot(t, yn, ".", ms=3)
ax.set_xlabel("Time (as stored; often MJD or index)")
ax.set_ylabel("Normalized flux")
ax.set_title(f"Raw light curve ({meta.get('source','?')})")
fig.tight_layout()
plt.show()


In [None]:
# -----------------------------
# Clean + detrend
# -----------------------------
def rolling_median(y, win):
    """Centred running median that ignores NaNs.

    Parameters
    ----------
    y : sequence of numbers
    win : window length; converted with int(). Values below 3 give a constant
        array equal to the overall NaN-median. The window spans win // 2
        samples on each side of a point (so an even value acts like win + 1)
        and is truncated at the ends of the series.

    Returns
    -------
    numpy array with one median per input sample.
    """
    win = int(win)
    if win < 3:
        return np.nanmedian(y) * np.ones_like(y)

    # An even window bumped to the next odd number has the same half-width.
    half = win // 2
    n = len(y)
    medians = [
        np.nanmedian(y[max(0, i - half):min(n, i + half + 1)])
        for i in range(n)
    ]
    return np.array(medians, dtype=float)

# Sigma-clip outliers on normalized flux.
# getmaskarray always yields a full boolean array (a scalar "no mask" would
# break the indexing below), and non-finite points are dropped explicitly.
clipped = sigma_clip(yn, sigma=5.0, maxiters=5)
mask_ok = ~np.ma.getmaskarray(clipped) & np.isfinite(yn)

t_clean = t[mask_ok]
y_clean = yn[mask_ok]
dy_clean = (dy[mask_ok]/y0 if dy is not None else None)

# Detrend
if DETREND_MODE == "none":
    trend = np.ones_like(y_clean)
elif DETREND_MODE == "savgol" and HAS_SCIPY:
    # savgol_filter needs an odd window that is no longer than the series
    # (mode="interp") and larger than the polynomial order.
    win = min(int(SAVGOL_WIN) | 1, len(y_clean))
    if win % 2 == 0:
        win -= 1
    if win > SAVGOL_POLY:
        trend = savgol_filter(y_clean, window_length=win, polyorder=SAVGOL_POLY, mode="interp")
    else:
        print("Too few points for Savitzky-Golay detrending; using median.")
        trend = rolling_median(y_clean, MED_WIN)
elif DETREND_MODE == "median":
    trend = rolling_median(y_clean, MED_WIN)
else:
    print("Requested DETREND_MODE not available; using median.")
    trend = rolling_median(y_clean, MED_WIN)

# For transit-like work, division detrend is common:
y_detr = y_clean / trend

plt.figure(figsize=(10,4))
plt.plot(t_clean, y_clean, ".", ms=3, label="clean")
plt.plot(t_clean, trend, "-", lw=2, label="trend")
plt.xlabel("Time")
plt.ylabel("Normalized flux")
plt.title("Cleaned + trend")
plt.legend()
plt.tight_layout()
plt.show()

plt.figure(figsize=(10,4))
plt.plot(t_clean, y_detr, ".", ms=3)
plt.axhline(1.0, ls="--")
plt.xlabel("Time")
plt.ylabel("Detrended flux")
plt.title("Detrended light curve")
plt.tight_layout()
plt.show()


In [None]:
# -----------------------------
# Variability metrics (quick astrophysics-style summary)
# -----------------------------
def robust_mad(x):
    """Return the median absolute deviation of `x` about its median, ignoring NaNs."""
    values = np.asarray(x)
    deviations = np.abs(values - np.nanmedian(values))
    return np.nanmedian(deviations)

# Scatter summary of the detrended light curve.
pp = np.nanpercentile(y_detr, [1, 5, 50, 95, 99])
mad = robust_mad(y_detr)
robust_rms = mad * 1.4826  # ~sigma for normal distribution

print("Detrended flux percentiles [1,5,50,95,99]:", pp)
print("MAD:", mad)
print("Robust RMS (1.4826*MAD):", robust_rms)

# Save a CSV with the cleaned and detrended series.
import pandas as pd
csv_path = OUT_DIR / "lightcurve.csv"
df = pd.DataFrame(dict(t=t_clean, flux_norm=y_clean, flux_detr=y_detr))
df.to_csv(csv_path, index=False)
print("Saved:", csv_path)


In [None]:
# -----------------------------
# Lomb–Scargle periodogram (good for stellar variability)
# -----------------------------
# LombScargle expects finite values
ok = np.isfinite(t_clean) & np.isfinite(y_detr)

# Uncertainties are used as weights, so NaN or non-positive values would turn
# the whole periodogram into NaN/inf. Drop such points, or run unweighted when
# no point has a usable uncertainty.
ddy = None
if dy_clean is not None:
    dy_ok = np.isfinite(dy_clean) & (dy_clean > 0)
    if np.any(ok & dy_ok):
        ok &= dy_ok
        ddy = dy_clean[ok]
    else:
        print("No usable flux uncertainties; running unweighted.")

tt = t_clean[ok]
yy = y_detr[ok]
if tt.size < 3:
    raise RuntimeError("Too few finite points for a periodogram.")

# If time is big (e.g., MJD ~ 60000), subtract a reference for numerical stability
tref = np.nanmin(tt)
tt0 = tt - tref

# Uniform frequency grid between the configured period limits.
min_f = 1.0 / MAX_PERIOD
max_f = 1.0 / MIN_PERIOD
freq = np.linspace(min_f, max_f, N_FREQ)

ls = LombScargle(tt0, yy, dy=ddy)
power = ls.power(freq)

best_f = freq[np.argmax(power)]
best_period = 1.0 / best_f
print("Best LS period (days):", best_period)

plt.figure(figsize=(10,4))
plt.plot(1.0/freq, power)
plt.gca().invert_xaxis()
plt.xlabel("Period (days)")
plt.ylabel("LS Power")
plt.title("Lomb–Scargle Periodogram")
plt.tight_layout()
plt.show()

# Fold
phase = (tt0 % best_period) / best_period
order = np.argsort(phase)

plt.figure(figsize=(8,4))
plt.plot(phase[order], yy[order], ".", ms=3)
plt.xlabel("Phase")
plt.ylabel("Detrended flux")
plt.title(f"Folded light curve (P={best_period:.5f} d)")
plt.tight_layout()
plt.show()


In [None]:
# -----------------------------
# Optional: BLS transit search (Box Least Squares)
# -----------------------------
if not HAS_BLS:
    print("BLS not available in this astropy; skipping.")
else:
    # Use tt0 in days (float). BLS is sensitive to window choice.
    periods = np.linspace(MIN_PERIOD, MAX_PERIOD, 5000)

    # BoxLeastSquares rejects any duration that is not shorter than the
    # shortest trial period, so cap the duration grid at half of MIN_PERIOD
    # (at most 0.15 d = 3.6 hr; the lower end is 0.005 d = 7.2 min when it fits).
    max_duration = min(0.15, 0.5 * MIN_PERIOD)
    min_duration = min(0.005, 0.5 * max_duration)
    durations = np.linspace(min_duration, max_duration, 20)  # days

    bls = BoxLeastSquares(tt0, yy, dy=ddy)
    res = bls.power(periods, durations)

    # Parameters of the strongest peak.
    i_best = np.argmax(res.power)
    p_best = res.period[i_best]
    d_best = res.duration[i_best]
    t0_best = res.transit_time[i_best]
    print("Best BLS period:", float(p_best), "days")
    print("Best duration:", float(d_best), "days")
    print("Best t0 (relative):", float(t0_best), "days")

    plt.figure(figsize=(10,4))
    plt.plot(res.period, res.power)
    plt.xlabel("Period (days)")
    plt.ylabel("BLS Power")
    plt.title("BLS Periodogram")
    plt.tight_layout()
    plt.show()

    # Fold at best period and show transit window (phase 0 = mid-transit)
    phase = ((tt0 - t0_best + 0.5*p_best) % p_best) / p_best - 0.5
    order = np.argsort(phase)

    plt.figure(figsize=(8,4))
    plt.plot(phase[order], yy[order], ".", ms=3)
    plt.axvspan(-0.5*float(d_best/p_best), 0.5*float(d_best/p_best), alpha=0.2)
    plt.xlabel("Phase (centered on transit)")
    plt.ylabel("Detrended flux")
    plt.title(f"Folded (BLS) P={float(p_best):.5f} d")
    plt.tight_layout()
    plt.show()


In [None]:
# -----------------------------
# Optional: Spectroscopic (multi-wavelength) light curves if X1DINTS
# -----------------------------
if meta.get("source") != "X1DINTS":
    print("No X1DINTS loaded; skipping spectroscopic light curves.")
else:
    # flux2d shape: (nint, nwave)
    w = np.asarray(wave, dtype=float)
    f2 = np.asarray(flux2d, dtype=float)
    nint, nw = f2.shape

    # Define bins in wavelength
    NBINS = 6
    wmin, wmax = np.nanmin(w), np.nanmax(w)
    edges = np.linspace(wmin, wmax, NBINS+1)

    # Use the same clean mask as before (mask_ok) so the rows line up with
    # t_clean and with the white-light trend.
    idx_all = np.arange(len(t))
    idx_clean = idx_all[mask_ok]
    f2c = f2[idx_clean, :]

    lcs = []
    labels = []
    for i in range(NBINS):
        lo, hi = edges[i], edges[i+1]
        # Bins are half-open, except the last one, which must include the
        # longest wavelength (otherwise that channel is silently dropped).
        upper = (w <= hi) if i == NBINS - 1 else (w < hi)
        m = (w >= lo) & upper
        if not np.any(m):
            continue
        lc = np.nansum(f2c[:, m], axis=1)
        med = np.nanmedian(lc)
        if not np.isfinite(med) or med == 0:
            # A bin without usable flux cannot be normalized.
            continue
        lcs.append(lc / med)
        labels.append(f"{edges[i]:.2f}-{edges[i+1]:.2f} μm")

    if not lcs:
        print("No wavelength bin contains usable flux; skipping spectroscopic light curves.")
    else:
        lcs = np.array(lcs)  # (nbin, ntime)
        # Common-mode correction using white-light trend
        common = trend  # from earlier (same time base)
        lcs_cm = lcs / common[None, :]

        plt.figure(figsize=(10,6))
        offset = 0.0
        for i in range(lcs_cm.shape[0]):
            plt.plot(t_clean, lcs_cm[i] + offset, ".", ms=2)
            plt.text(t_clean[0], (lcs_cm[i][0] + offset), labels[i], fontsize=9, va="bottom")
            offset += 0.03  # stack
        plt.xlabel("Time")
        plt.ylabel("Flux (common-mode corrected, stacked)")
        plt.title("Spectroscopic light curves (stacked)")
        plt.tight_layout()
        plt.show()


## Next ideas (fun “space-y” upgrades)

- Replace detrending with **Gaussian Processes** (e.g., `celerite2` / `george`) for instrument systematics.  
- Fit a **transit model** (`batman-package`) jointly across wavelength bins to get a transmission spectrum.  
- For imaging TSOs (`calints` / `rateints`), do **aperture photometry per integration** (very similar to ground-based differential photometry).  
