# Astronomy Time-Series Analysis Notebook (Template)

This notebook is a **general workflow** for astronomy time-series (light curves / radial velocities / photometry).

**What you can do here**
- Load a time series from **CSV / TSV / FITS**
- Clean & inspect: gaps, outliers, uncertainties
- Detrend / normalize
- Search periodicity: **Lomb–Scargle**
- Phase-fold, bin, and export results

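**Tip:** If you see an error like `NameError: name 'PROGRAM_ID' is not defined`, use **Kernel → Restart & Run All** so the configuration cell runs before the helpers. Helper default arguments (e.g. `SIGMA_CLIP`, `MIN_PERIOD`) are fixed when the helpers cell runs, so re-run that cell after changing the configuration.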


In [None]:
# Optional setup: uncomment the pip line below to install packages.
# numpy, pandas, matplotlib and astropy are imported unconditionally in the
# next cell; scipy and astroquery only enable optional features.
# !pip -q install numpy pandas matplotlib scipy astropy astroquery lightkurve celerite2 statsmodels

print("Environment ready.")


In [None]:
# -----------------------------
# Imports
# -----------------------------
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from astropy.time import Time
from astropy.table import Table
import astropy.units as u
from astropy.stats import sigma_clip
from astropy.timeseries import LombScargle

# Optional: MAST query for JWST/HST/etc.
try:
    from astroquery.mast import Observations
    HAVE_MAST = True
except Exception as e:
    HAVE_MAST = False
    print("astroquery.mast not available (MAST querying disabled):", repr(e))

# Optional: smoothing/detrending
try:
    from scipy.signal import savgol_filter
    HAVE_SCIPY = True
except Exception as e:
    HAVE_SCIPY = False
    print("scipy not available (some detrending disabled):", repr(e))


In [None]:
# =============================
# Configuration
# =============================
# --- Input, route A: a local file (recommended) ---
DATA_PATH = ""   # path to a .csv / .tsv / .fits file, e.g. "my_lightcurve.csv"

# Column names looked up in table-like inputs by load_timeseries()
TIME_COL = "time"
FLUX_COL = "flux"
FLUX_ERR_COL = "flux_err"   # set to None when the file has no uncertainty column

# How to interpret the time values (passed to astropy Time by to_astropy_time()):
#   TIME_FORMAT: e.g. "jd", "mjd", "isot"   (see astropy Time formats)
#   TIME_SCALE:  e.g. "utc", "tdb", "tai"   (choose the one your data really uses)
TIME_FORMAT = "mjd"
TIME_SCALE = "utc"

# --- Input, route B (optional): query JWST on MAST ---
TARGET_NAME = "WASP-39"
SEARCH_RADIUS = 0.03 * u.deg
PROGRAM_ID = None   # integer program id (e.g. 1366) to query by program instead of by target

# --- Cleaning: sigma threshold for flux outlier clipping ---
SIGMA_CLIP = 5.0

# --- Lomb–Scargle period search range (days) and grid size ---
MIN_PERIOD, MAX_PERIOD = 0.05, 20.0
N_FREQ = 20_000

print("Config loaded.")


In [None]:
# -----------------------------
# Helpers
# -----------------------------
def load_timeseries(path, time_col=TIME_COL, flux_col=FLUX_COL, flux_err_col=FLUX_ERR_COL):
    """Load a time series from CSV/TSV/FITS into a DataFrame with columns time, flux, flux_err.

    Parameters
    ----------
    path : str or os.PathLike
        File to read. The reader is chosen from the case-insensitive file
        extension: ``.csv`` -> comma-separated; ``.tsv``/``.tab``/``.txt`` ->
        tab-separated; ``.fits``/``.fit``/``.fz`` (also ``.fits.gz`` /
        ``.fit.gz``) -> ``astropy.table.Table.read``.
    time_col, flux_col : str
        Names of the required time and flux columns.
    flux_err_col : str or None
        Optional uncertainty column. When None or absent, ``flux_err`` is NaN.

    Returns
    -------
    pandas.DataFrame
        Columns ``time``, ``flux``, ``flux_err`` in that order.

    Raises
    ------
    ValueError
        If ``path`` is empty or its extension is not supported.
    KeyError
        If ``time_col`` or ``flux_col`` is not present in the file.
    """
    from pathlib import Path

    if not path:
        raise ValueError("Set DATA_PATH to a CSV/TSV/FITS file, or use the MAST query section below.")

    # Path(...).suffixes looks only at the file name (dots in directory names
    # are ignored) and accepts both str and pathlib.Path.
    suffixes = [s.lower().lstrip(".") for s in Path(path).suffixes]
    if len(suffixes) >= 2 and suffixes[-1] == "gz" and suffixes[-2] in ("fits", "fit"):
        ext = suffixes[-2]  # gzip-compressed FITS, e.g. "lc.fits.gz"
    else:
        ext = suffixes[-1] if suffixes else ""

    tab = None
    if ext == "csv":
        df = pd.read_csv(path)
        available = list(df.columns)
    elif ext in ("tsv", "tab", "txt"):
        df = pd.read_csv(path, sep="\t")
        available = list(df.columns)
    elif ext in ("fits", "fit", "fz"):
        tab = Table.read(path, format="fits")
        available = list(tab.colnames)
    else:
        raise ValueError(f"Unsupported file extension: .{ext}")

    missing = [c for c in (time_col, flux_col) if c not in available]
    if missing:
        raise KeyError(f"Missing required column(s) {missing}. Found: {available[:30]} ...")

    if tab is not None:
        # Convert only the needed columns: Table.to_pandas() refuses tables
        # that contain multidimensional columns, which unrelated columns of a
        # FITS product may be.
        wanted = [time_col, flux_col]
        if flux_err_col and flux_err_col in available:
            wanted.append(flux_err_col)
        df = tab[list(dict.fromkeys(wanted))].to_pandas()

    out = pd.DataFrame({
        "time": df[time_col].to_numpy(),
        "flux": df[flux_col].to_numpy(),
    })
    if flux_err_col and flux_err_col in df.columns:
        out["flux_err"] = df[flux_err_col].to_numpy()
    else:
        out["flux_err"] = np.nan
    return out


def sanitize_timeseries(df, sigma=SIGMA_CLIP):
    """Drop non-finite rows, sort by time, and sigma-clip flux outliers.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``time`` and ``flux``; it is not modified.
    sigma : float
        Clipping threshold passed to ``astropy.stats.sigma_clip``
        (run with up to 5 iterations).

    Returns
    -------
    pandas.DataFrame
        Rows with finite ``time`` and ``flux`` that survived clipping, sorted
        by time, with a fresh 0..n-1 index. ``±inf`` in any column becomes
        NaN, but only rows with NaN ``time``/``flux`` are dropped.
    """
    x = df.replace([np.inf, -np.inf], np.nan).dropna(subset=["time", "flux"])
    # mergesort is stable: rows sharing a timestamp keep their input order.
    x = x.sort_values("time", kind="mergesort").reset_index(drop=True)
    if x.empty:
        # Nothing to clip; avoid calling sigma_clip on an empty array.
        return x

    clipped = sigma_clip(x["flux"].to_numpy(dtype=float), sigma=sigma, maxiters=5, masked=True)
    # getmaskarray always returns a full boolean array, even when the masked
    # array carries the scalar `nomask` (no points clipped).
    keep = ~np.ma.getmaskarray(clipped)
    return x.loc[keep].reset_index(drop=True)


def to_astropy_time(time_array, format=TIME_FORMAT, scale=TIME_SCALE):
    """Wrap ``time_array`` (numbers or strings) in an ``astropy.time.Time``.

    ``format`` and ``scale`` are forwarded unchanged to ``Time``. Their
    defaults are the TIME_FORMAT / TIME_SCALE values in effect when this
    function was defined.
    """
    time_kwargs = {"format": format, "scale": scale}
    return Time(time_array, **time_kwargs)


def plot_lightcurve(df, title="Light curve", show_err=True):
    """Plot flux against time, with error bars when usable uncertainties exist.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``time`` and ``flux``; ``flux_err`` is optional.
    title : str
        Figure title.
    show_err : bool
        Draw error bars when True and at least one ``flux_err`` is finite.
    """
    t = df["time"].to_numpy()
    y = df["flux"].to_numpy()
    # Build the errors as a float array. The previous
    # df.get("flux_err", np.full_like(y, np.nan)) always evaluated its default,
    # and for integer flux (e.g. raw counts) full_like cannot hold NaN, so it
    # produced integer sentinel values that were then drawn as error bars.
    if "flux_err" in df.columns:
        yerr = df["flux_err"].to_numpy(dtype=float)
    else:
        yerr = np.full(len(y), np.nan)

    plt.figure()
    if show_err and np.isfinite(yerr).any():
        plt.errorbar(t, y, yerr=yerr, fmt=".", markersize=3, linewidth=0.5)
    else:
        plt.plot(t, y, ".", markersize=3)
    plt.xlabel("Time")
    plt.ylabel("Flux")
    plt.title(title)
    plt.tight_layout()
    plt.show()


def normalize_flux(df, method="median"):
    """Divide flux (and its uncertainties) by the median or mean flux.

    Adds ``flux_norm``. When ``flux_err`` exists and has at least one finite
    value, it also adds ``flux_err_norm = flux_err / |scale|``, so the
    uncertainties stay non-negative even for a negative scale. The input
    DataFrame is not modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``flux``; ``flux_err`` is optional.
    method : {"median", "mean"}
        Statistic used as the scale; NaNs are ignored.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` with the normalized column(s) added.

    Raises
    ------
    ValueError
        If ``method`` is unknown, or the scale is zero or non-finite (empty
        or all-NaN flux, or zero-centred data). Dividing by such a scale would
        silently fill the result with inf/NaN.
    """
    reducers = {"median": np.nanmedian, "mean": np.nanmean}
    if method not in reducers:
        raise ValueError("method must be 'median' or 'mean'")

    x = df.copy()
    flux = x["flux"].to_numpy(dtype=float)
    # Skip the reducer on empty input to avoid an "empty slice" warning.
    scale = reducers[method](flux) if flux.size else np.nan
    if not np.isfinite(scale) or scale == 0:
        raise ValueError(f"Cannot normalize: {method} flux is {scale!r}.")

    x["flux_norm"] = x["flux"] / scale
    if "flux_err" in x.columns and np.isfinite(x["flux_err"]).any():
        x["flux_err_norm"] = x["flux_err"] / abs(scale)
    return x


def detrend_savgol(df, window_length=101, polyorder=2):
    """Detrend flux by dividing out a Savitzky–Golay smooth (requires scipy).

    Uses ``flux_norm`` when present, otherwise ``flux``. Adds ``trend`` and
    ``flux_detrended = y / trend``. The input DataFrame is not modified, and
    the uncertainty columns are left unchanged.

    Parameters
    ----------
    df : pandas.DataFrame
        Time-sorted series containing ``flux`` (and optionally ``flux_norm``).
    window_length : int
        Requested filter window, counted in samples rather than time units, so
        gaps in the time axis are not accounted for. It is rounded up to an
        odd number, then clamped into the range savgol_filter accepts: at
        least 5 and greater than ``polyorder``, at most ``len(df)``.
    polyorder : int
        Polynomial order of the local fits.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with ``trend`` and ``flux_detrended`` columns.

    Raises
    ------
    ImportError
        If scipy is not available.
    ValueError
        If the series is too short for any valid window at this ``polyorder``.
    """
    if not HAVE_SCIPY:
        raise ImportError("scipy is required for Savitzky–Golay detrending.")

    x = df.copy()
    y = x["flux_norm"].to_numpy() if "flux_norm" in x.columns else x["flux"].to_numpy()
    n = len(y)
    polyorder = int(polyorder)

    # Smallest admissible odd window: strictly above polyorder, and never below 5.
    min_wl = max(5, polyorder + 1 if polyorder % 2 == 0 else polyorder + 2)
    # Largest odd window that still fits inside the data.
    max_wl = n if n % 2 == 1 else n - 1
    if max_wl < min_wl:
        raise ValueError(
            f"Savitzky–Golay detrending needs at least {min_wl} points "
            f"for polyorder={polyorder}; got {n}."
        )

    wl = int(window_length)
    if wl % 2 == 0:
        wl += 1
    # Apply the lower bound before the upper one, so the result never exceeds len(y).
    wl = min(max(wl, min_wl), max_wl)

    trend = savgol_filter(y, window_length=wl, polyorder=polyorder, mode="interp")
    x["trend"] = trend
    x["flux_detrended"] = y / trend
    return x


def lomb_scargle_periodogram(df, min_period=MIN_PERIOD, max_period=MAX_PERIOD, n_freq=N_FREQ):
    """Compute Lomb–Scargle power over a uniform frequency grid.

    Signal and uncertainty columns
    ------------------------------
    The signal is the first column present among ``flux_detrended``,
    ``flux_norm`` and ``flux``. Uncertainties come from whichever of
    ``flux_err_norm`` / ``flux_err`` first has at least one finite value, and
    are passed as ``dy``. Rows with non-finite time or signal are excluded.
    When uncertainties are used, rows whose uncertainty is non-finite or
    <= 0 are also excluded. Otherwise a single NaN or zero weight would
    propagate NaN into every power value.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``time`` and at least one of the signal columns.
    min_period, max_period : float
        Period search range, in the units of ``time`` (days per the config).
    n_freq : int
        Number of frequencies, evenly spaced between 1/max_period and 1/min_period.

    Returns
    -------
    period : numpy.ndarray
        ``1 / frequency``. Decreasing, because the frequency grid increases.
    power : numpy.ndarray
        Lomb–Scargle power at each grid frequency.
    best_period : float
        Period with the highest power.

    Raises
    ------
    ValueError
        If the period range or ``n_freq`` is invalid, or fewer than 3 usable
        points remain.
    """
    if not (0 < min_period < max_period):
        raise ValueError(f"Need 0 < min_period < max_period; got {min_period!r}, {max_period!r}.")
    n_freq = int(n_freq)
    if n_freq < 1:
        raise ValueError(f"n_freq must be >= 1; got {n_freq}.")

    t = df["time"].to_numpy(dtype=float)
    ycol = next((c for c in ("flux_detrended", "flux_norm") if c in df.columns), "flux")
    y = df[ycol].to_numpy(dtype=float)

    yerr = None
    for errcol in ("flux_err_norm", "flux_err"):
        if errcol in df.columns:
            candidate = df[errcol].to_numpy(dtype=float)
            if np.isfinite(candidate).any():
                yerr = candidate
                break

    good = np.isfinite(t) & np.isfinite(y)
    if yerr is not None:
        with np.errstate(invalid="ignore"):  # NaN > 0 is simply False here
            good &= np.isfinite(yerr) & (yerr > 0)
    # An offset plus one sinusoid has three free parameters.
    if np.count_nonzero(good) < 3:
        raise ValueError(
            "Need at least 3 points with finite time and flux "
            "(and positive finite uncertainties, if used)."
        )
    t, y = t[good], y[good]
    if yerr is not None:
        yerr = yerr[good]

    # Frequencies in cycles per unit of `time` (1/day with the default config).
    freq = np.linspace(1.0 / max_period, 1.0 / min_period, n_freq)
    power = LombScargle(t, y, dy=yerr).power(freq)
    period = 1.0 / freq

    best_period = period[int(np.nanargmax(power))]
    return period, power, best_period


def phase_fold(df, period, t0=None):
    """Add a ``phase`` column and return the rows sorted by phase.

    ``phase = ((time - t0) / period) % 1.0``. Values lie in [0, 1), up to
    floating-point rounding.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``time`` column; it is not modified.
    period : float
        Folding period, in the same units as ``time``; must be positive.
    t0 : float, optional
        Epoch of phase 0. Defaults to the earliest (non-NaN) time in ``df``.
        Pass e.g. a transit mid-time to centre the fold on an event.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with ``phase``, sorted by phase, with a fresh 0..n-1 index.

    Raises
    ------
    ValueError
        If ``period`` is not a positive finite number.
    """
    if not (np.isfinite(period) and period > 0):
        raise ValueError(f"period must be a positive finite number, got {period!r}")

    x = df.copy()
    t = x["time"].to_numpy(dtype=float)
    if t0 is None:
        # Empty input has no minimum; any epoch yields an empty phase column.
        t0 = np.nanmin(t) if t.size else 0.0
    x["phase"] = ((t - t0) / period) % 1.0
    # Stable sort keeps time order among points with equal phase.
    return x.sort_values("phase", kind="mergesort").reset_index(drop=True)


def plot_periodogram(period, power, best_period=None):
    """Plot Lomb–Scargle power against period, optionally marking one period.

    The x-axis is inverted so that longer periods appear on the left and
    shorter periods on the right. ``best_period``, if given, is drawn as a
    dashed vertical line.
    """
    plt.figure()
    plt.plot(period, power)
    ax = plt.gca()
    ax.invert_xaxis()
    ax.set_xlabel("Period (days)")
    ax.set_ylabel("Lomb–Scargle Power")
    ax.set_title("Periodogram")
    if best_period is not None:
        ax.axvline(best_period, linestyle="--")
    plt.tight_layout()
    plt.show()


def plot_phase(df, ycol="flux_detrended", title="Phase-folded"):
    """Scatter ``df[ycol]`` against ``df["phase"]`` (e.g. the output of phase_fold).

    The y-axis label is the column name itself.
    """
    plt.figure()
    ax = plt.gca()
    ax.plot(df["phase"], df[ycol], ".", markersize=3)
    ax.set_xlabel("Phase")
    ax.set_ylabel(ycol)
    ax.set_title(title)
    plt.tight_layout()
    plt.show()


## 1) Load and inspect your time series

In [None]:
# If you're loading a local file:
# df = load_timeseries(DATA_PATH)
# df = sanitize_timeseries(df)
# df = normalize_flux(df)
# plot_lightcurve(df, title="Raw light curve")

# If you're querying JWST from MAST (optional):
def query_jwst(program_id=None, target_name=None, search_radius=None):
    """Query MAST for JWST observations, by program id or around a target.

    Arguments left as None fall back to PROGRAM_ID, TARGET_NAME and
    SEARCH_RADIUS, read at call time. A program id, when available, takes
    precedence: the query passes ``obs_collection="JWST"`` and the id as
    ``proposal_id``. Otherwise a ``query_object`` cone search is run around
    the target and the result is filtered to rows whose ``obs_collection``
    is "JWST".

    Returns
    -------
    The table returned by astroquery (filtered as described above).

    Raises
    ------
    ImportError
        If astroquery.mast could not be imported.
    ValueError
        If neither a program id nor a non-empty target name is available.
    """
    if not HAVE_MAST:
        raise ImportError("astroquery.mast is not available in this environment.")

    program_id = program_id if program_id is not None else PROGRAM_ID
    target_name = target_name if target_name is not None else TARGET_NAME
    search_radius = search_radius if search_radius is not None else SEARCH_RADIUS

    if program_id is not None:
        print(f"Querying JWST by PROGRAM_ID={program_id} ...")
        return Observations.query_criteria(obs_collection="JWST", proposal_id=str(program_id))

    # Fail with a clear message instead of passing an empty/None name to MAST.
    if target_name is None or not str(target_name).strip():
        raise ValueError("Set PROGRAM_ID or TARGET_NAME (or pass program_id/target_name) to query MAST.")

    print(f"Querying JWST near TARGET_NAME='{target_name}' (radius={search_radius}) ...")
    obs = Observations.query_object(target_name, radius=search_radius)
    if "obs_collection" not in obs.colnames:
        # Cannot restrict to JWST without this column; say so rather than
        # silently returning observations from every mission.
        print("Warning: 'obs_collection' column missing; results are not filtered to JWST.")
        return obs
    return obs[obs["obs_collection"] == "JWST"]

# Example:
# obs = query_jwst()
# obs[:5]


## 2) Detrend / normalize

In [None]:
# df = detrend_savgol(df, window_length=101, polyorder=2)
# plot_lightcurve(df.assign(flux=df["flux_detrended"]), title="Detrended (Savgol)", show_err=False)


## 3) Period search (Lomb–Scargle)

In [None]:
# period, power, best_period = lomb_scargle_periodogram(df)
# print("Best period (days):", best_period)
# plot_periodogram(period, power, best_period=best_period)


## 4) Phase-fold and visualize

In [None]:
# folded = phase_fold(df, best_period)
# plot_phase(folded, ycol="flux_detrended", title=f"Phase-folded at P={best_period:.6f} d")


## 5) Export results

In [None]:
# Export: uncomment once `df` and `folded` exist from the steps above
# (both files are written without the pandas index).
# df.to_csv("cleaned_timeseries.csv", index=False)
# folded.to_csv("phase_folded_timeseries.csv", index=False)
print("Done.")
