In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Ray tracing with yt + trident for sightlines defined in los_endpoints_obsOriented.csv

- Dataset: TNG50-1, snap_099.*.hdf5 chunks
- Units for endpoints: absolute ckpc/h (matches Gadget native at z~0)
- For each (SubhaloID, span), pair 'minus' and 'plus' endpoints -> one ray
- Spectra: raw & COS-G130M LSF for H I 1216 (Lyα), C II 1334, Si III 1206
- Column densities computed from number densities:
    H_p0, C_p1, Si_p2 (and 'dl' path length from yt)

Outputs under: /scratch/tsingh65/TNG50-1_snap99/result/
    result/
      rays/
        sid<SID>/span_<SPAN>/ray_<RID>.h5    <-- spectra (raw + LSF) + total column densities
      logs/
      summary_rays.csv

Usage:
    python run_trident_rays.py --mode test \
        --data-dir /scratch/tsingh65/TNG50-1_snap99/data \
        --endpoints /scratch/tsingh65/TNG50-1_snap99/los_endpoints_obsOriented.csv

    python run_trident_rays.py --mode all \
        --data-dir /scratch/tsingh65/TNG50-1_snap99/data \
        --endpoints /scratch/tsingh65/TNG50-1_snap99/los_endpoints_obsOriented.csv
"""

import os
import sys
import csv
import h5py
import json
import argparse
import numpy as np
import pandas as pd
import yt
import trident

# ----------------------------
# Config defaults (can be args)
# ----------------------------
# Site-specific locations; each can be overridden on the command line.
DEFAULT_DATA_DIR = "/scratch/tsingh65/TNG50-1_snap99/data"
DEFAULT_ENDPOINTS = "/scratch/tsingh65/TNG50-1_snap99/los_endpoints_obsOriented.csv"
DEFAULT_OUT_ROOT = "/scratch/tsingh65/TNG50-1_snap99/result"

# yt is pointed at the first chunk of the multi-file snapshot; the ".ewah"
# file next to it is the yt index cache.
FIRST_CHUNK = "snap_099.0.hdf5"
INDEX_FILENAME = FIRST_CHUNK + ".ewah"

# Absorption line -> (yt number-density field name, ion label).
ION_DENS = {
    "H I 1216":    ("H_p0_number_density",  "H I"),
    "C II 1334":   ("C_p1_number_density",  "C II"),
    "Si III 1206": ("Si_p2_number_density", "Si III"),
}
# Lines to model, in the same order as the ION_DENS keys.
IONS = list(ION_DENS)

# Trident instrument preset name.
LSF_NAME = "COS-G130M"

# ----------------------------
# Helpers
# ----------------------------
def ensure_dir(path: str):
    """
    Create directory `path` (and any missing parents) if it does not exist.

    An empty string denotes the current directory, which always exists, so
    it is returned unchanged instead of being passed to os.makedirs (which
    raises FileNotFoundError for ""). This matters for callers that pass
    os.path.dirname() of a bare filename.

    Returns `path` so calls can be chained, e.g. d = ensure_dir(...).
    """
    if path:
        os.makedirs(path, exist_ok=True)
    return path

def load_dataset(data_dir: str):
    """
    Open the snapshot with yt through its first chunk file.

    The index-cache path <data_dir>/INDEX_FILENAME is handed to yt.load as
    the ``index_filename`` keyword. If yt.load rejects that call with a
    TypeError, the dataset is loaded without the keyword, and the path is
    assigned to ``ds.index_filename`` when that attribute exists.

    Raises FileNotFoundError when <data_dir>/FIRST_CHUNK is not a file.
    """
    first_chunk_path = os.path.join(data_dir, FIRST_CHUNK)
    if not os.path.isfile(first_chunk_path):
        raise FileNotFoundError(f"Cannot find first chunk: {first_chunk_path}")

    index_path = os.path.join(data_dir, INDEX_FILENAME)
    try:
        return yt.load(first_chunk_path, index_filename=index_path)
    except TypeError:
        # Fallback for yt builds that do not take the keyword.
        dataset = yt.load(first_chunk_path)
        if hasattr(dataset, "index_filename"):
            dataset.index_filename = index_path
        return dataset

def read_and_pair_endpoints(csv_path: str):
    """
    Read los_endpoints_obsOriented.csv and pair the +/- points into rays.

    Rows are grouped by (SubhaloID, span). A group forms a ray when its set
    of sign values is exactly {'minus', 'plus'}. Before comparing, sign
    values are stripped of surrounding whitespace and lower-cased, so
    ' Plus' counts as 'plus'. Incomplete groups are skipped with a warning.
    If a sign appears more than once, the first row of each sign is used and
    a warning is printed.

    Parameters
    ----------
    csv_path : path (or any source accepted by pandas.read_csv).

    Returns
    -------
    DataFrame with one row per ray and columns Galaxy, SubhaloID, span,
    inc_deg, phi_deg, rho_kpc, Rvir_kpc, start_ckpch (3-vector from the
    'minus' row), end_ckpch (3-vector from the 'plus' row), ray_id.
    Galaxy and the angle/radius metadata are taken from the 'minus' row.

    Raises
    ------
    KeyError     : a required column is missing.
    RuntimeError : no complete +/- pair was found.
    """
    df = pd.read_csv(csv_path)
    # strictly require these to exist
    req = ["Galaxy", "SubhaloID", "span", "sign", "X_ckpch", "Y_ckpch", "Z_ckpch",
           "inc_deg", "phi_deg", "rho_kpc", "Rvir_kpc"]
    missing = [c for c in req if c not in df.columns]
    if missing:
        raise KeyError(f"Missing required columns in {csv_path}: {missing}")

    # Normalise sign labels so stray whitespace/capitalisation does not make
    # a valid pair look incomplete.
    df = df.assign(sign=df["sign"].astype(str).str.strip().str.lower())

    rays = []
    for (sid, span), g in df.groupby(["SubhaloID", "span"]):
        signs = g["sign"]
        if set(signs.unique()) != {"minus", "plus"}:
            print(f"[WARN] Incomplete +/- pair for sid={sid}, span={span}; skipping.")
            continue

        n_minus = int((signs == "minus").sum())
        n_plus = int((signs == "plus").sum())
        if n_minus > 1 or n_plus > 1:
            print(f"[WARN] Duplicate endpoints for sid={sid}, span={span} "
                  f"({n_minus} minus, {n_plus} plus); using the first of each.")

        g_minus = g[signs == "minus"].iloc[0]
        g_plus = g[signs == "plus"].iloc[0]

        start = np.array([g_minus["X_ckpch"], g_minus["Y_ckpch"], g_minus["Z_ckpch"]], dtype=float)
        end = np.array([g_plus["X_ckpch"], g_plus["Y_ckpch"], g_plus["Z_ckpch"]], dtype=float)

        rays.append(dict(
            Galaxy=str(g_minus["Galaxy"]),
            SubhaloID=int(sid),
            span=span,
            inc_deg=float(g_minus["inc_deg"]),
            phi_deg=float(g_minus["phi_deg"]),
            rho_kpc=float(g_minus["rho_kpc"]),
            Rvir_kpc=float(g_minus["Rvir_kpc"]),
            start_ckpch=start,
            end_ckpch=end,
            ray_id=f"sid{int(sid)}_{span}",
        ))

    if not rays:
        raise RuntimeError("No valid +/- endpoint pairs found. Check your CSV.")
    return pd.DataFrame(rays)

def make_ray(ds, start_ckpch, end_ckpch, fields=None):
    """
    Trace one sightline through `ds` with trident.make_simple_ray.

    Parameters
    ----------
    ds : dataset returned by load_dataset().
    start_ckpch, end_ckpch : length-3 absolute positions in comoving kpc/h.
    fields : optional list of (ftype, name) fields to sample along the ray.
        If falsy, a default set is used: density, temperature, metallicity,
        dl, and the H I / C II / Si III number densities.

    Returns
    -------
    The ray object returned by trident.make_simple_ray. data_filename=None
    is passed, so no separate ray file is requested.
    """
    # Attach units through the dataset's own unit registry. That registry
    # carries this snapshot's Hubble parameter and scale factor. A bare
    # yt.YTArray(..., "kpc/h") uses the default registry, where little-h is
    # 1.0 rather than the simulation's value, so the endpoints would be
    # mis-scaled by a factor of h.
    sp = ds.arr(np.asarray(start_ckpch, dtype=float), "kpccm/h")
    ep = ds.arr(np.asarray(end_ckpch, dtype=float), "kpccm/h")

    print("Making ray from", sp, "to", ep)

    # If the ion fields are not already defined on `ds`, add them first, e.g.
    # trident.add_ion_fields(ds, ions=['H I', 'C II', 'Si III']).
    fields = fields or [
        ("gas", "density"),
        ("gas", "temperature"),
        ("gas", "metallicity"),
        ("gas", "dl"),  # path length per cell
        ("gas", "H_p0_number_density"),
        ("gas", "C_p1_number_density"),
        ("gas", "Si_p2_number_density"),
        # add more fields if you need to debug
    ]

    ray = trident.make_simple_ray(
        ds,
        start_position=sp,
        end_position=ep,
        fields=fields,
        ftype="gas",
        data_filename=None,   # don't dump a separate ray file (we'll pack into our HDF5)
        # NOTE(review): confirm the installed Trident's make_simple_ray accepts pix_frac.
        pix_frac=1.0,
    )
    return ray

def compute_columns(ray):
    """
    Integrate ion number densities along the ray: N = sum(n_ion * dl).

    Parameters
    ----------
    ray : ray object from make_ray(); it must provide ("gas", "dl") and the
        ("gas", <field>) number densities named in ION_DENS.

    Returns
    -------
    dict mapping "N_tot_<line>" (one key per ION_DENS key, e.g.
    "N_tot_H I 1216") to the total column density as a float in cm^-2.
    Only totals are returned; no per-cell arrays.
    """
    out = {}
    dl = ray.r[("gas", "dl")]
    for disp, (fname, _label) in ION_DENS.items():
        n = ray.r[("gas", fname)]
        # Convert explicitly before dropping units, so the float is in cm^-2
        # whatever length unit the ray stores dl in. This raises if the
        # product is not an areal number density.
        col = (n * dl).sum().to("cm**-2")
        out[f"N_tot_{disp}"] = float(col)
    return out



def build_spectrum(ray):
    """
    Generate the absorption spectrum of `ray` for the lines listed in IONS.

    The SpectrumGenerator is created from the LSF_NAME instrument preset.
    make_spectrum() only computes the intrinsic optical depth and flux; the
    instrument line-spread function is applied by the separate apply_lsf()
    call. The unconvolved ("raw") flux is therefore taken first, and the LSF
    is applied afterwards.

    Returns
    -------
    dict with keys "raw" and "lsf", each a dict of numpy arrays:
      "lambda" : wavelength grid of the generator (Å)
      "flux"   : raw = exp(-tau); lsf = flux after apply_lsf()
      "tau"    : intrinsic (unconvolved) optical depth, the same array in both
    """
    line_list = list(IONS)

    sg = trident.SpectrumGenerator(LSF_NAME)
    sg.make_spectrum(ray, lines=line_list)

    lam = np.array(sg.lambda_field)   # Å
    tau = np.array(sg.tau_field)      # intrinsic optical depth on this grid

    # Raw (pre-LSF) flux on the same wavelength grid.
    flux_raw = np.exp(-tau)

    # Convolve with the instrument's LSF; with no arguments Trident uses the
    # kernel of the instrument given to the constructor.
    sg.apply_lsf()
    flux_lsf = np.array(sg.flux_field)

    return {
        "raw": {"lambda": lam, "flux": flux_raw, "tau": tau},
        "lsf": {"lambda": lam, "flux": flux_lsf, "tau": tau},
    }
    
    
def save_ray_hdf5(out_path, meta, ray, spec, col_dict):
    """
    Write one ray's results to `out_path` (mode "w": overwrites any existing file).

    Layout:
      /meta          attrs, one per `meta` item; a value h5py rejects with
                     TypeError is stored as its JSON string instead
      /spectrum/raw  datasets lambda_A, flux, tau (from spec["raw"])
      /spectrum/lsf  datasets lambda_A, flux, tau (from spec["lsf"])
      /columns       attrs, one per `col_dict` item (the column totals)

    `ray` is part of the signature but is not written.
    """
    ensure_dir(os.path.dirname(out_path))
    with h5py.File(out_path, "w") as h5:
        meta_grp = h5.create_group("meta")
        for key in meta:
            value = meta[key]
            try:
                meta_grp.attrs[key] = value
            except TypeError:
                # e.g. lists of strings: fall back to a JSON string attribute
                meta_grp.attrs[key] = json.dumps(value)

        spec_grp = h5.create_group("spectrum")
        dataset_map = (("lambda_A", "lambda"), ("flux", "flux"), ("tau", "tau"))
        for flavour in ("raw", "lsf"):
            arrays = spec[flavour]
            sub_grp = spec_grp.create_group(flavour)
            for ds_name, spec_key in dataset_map:
                sub_grp.create_dataset(ds_name, data=arrays[spec_key])

        h5.create_group("columns").attrs.update(col_dict)

def process_one_ray(ds, row, out_root):
    """
    Trace, analyse and save one paired sightline.

    Parameters
    ----------
    ds : dataset returned by load_dataset().
    row : one row of the DataFrame from read_and_pair_endpoints(); reads
        SubhaloID, span, ray_id, Galaxy, inc_deg, phi_deg, rho_kpc,
        Rvir_kpc, start_ckpch and end_ckpch.
    out_root : output root; the HDF5 is written to
        <out_root>/rays/sid<SID>/span_<SPAN>/ray_<RID>.h5

    Returns
    -------
    (out_h5, srow) : path of the HDF5 written, and a flat summary dict for
        the master CSV. srow holds the metadata plus N_tot_HI, N_tot_CII and
        N_tot_SiIII; a column missing from compute_columns' result becomes NaN.
    """
    sid   = int(row["SubhaloID"])
    span  = str(row["span"])
    rid   = str(row["ray_id"])
    gal   = str(row["Galaxy"])
    inc   = float(row["inc_deg"])
    phi   = float(row["phi_deg"])
    rho   = float(row["rho_kpc"])
    rvir  = float(row["Rvir_kpc"])

    start = np.array(row["start_ckpch"], dtype=float)
    end   = np.array(row["end_ckpch"],   dtype=float)

    ray  = make_ray(ds, start, end)
    spec = build_spectrum(ray)
    cols = compute_columns(ray)

    out_dir = os.path.join(out_root, "rays", f"sid{sid}", f"span_{span}")
    out_h5  = os.path.join(out_dir, f"ray_{rid}.h5")

    # The wavelength grid is chosen by Trident's instrument preset, so record
    # the grid that was actually used for this spectrum.
    lam = np.asarray(spec["lsf"]["lambda"], dtype=float)
    lambda_min = float(lam.min()) if lam.size else float("nan")
    lambda_max = float(lam.max()) if lam.size else float("nan")
    # Median spacing; equals the bin width for a uniform grid.
    dlambda = float(np.median(np.diff(lam))) if lam.size > 1 else float("nan")

    meta = dict(
        Galaxy=gal,
        SubhaloID=sid,
        span=span,
        ray_id=rid,
        inc_deg=inc, phi_deg=phi, rho_kpc=rho, Rvir_kpc=rvir,
        start_ckpch=start.tolist(),
        end_ckpch=end.tolist(),
        instrument=LSF_NAME,
        lambda_min=lambda_min, lambda_max=lambda_max, dlambda=dlambda,
        ions=IONS,
    )

    save_ray_hdf5(out_h5, meta, ray, spec, cols)

    # compute_columns keys its totals by the full line name ("N_tot_H I 1216").
    srow = dict(
        ray_id=rid,
        Galaxy=gal,
        SubhaloID=sid,
        span=span,
        inc_deg=inc, phi_deg=phi, rho_kpc=rho, Rvir_kpc=rvir,
        out_h5=out_h5,
        N_tot_HI=cols.get("N_tot_H I 1216", np.nan),
        N_tot_CII=cols.get("N_tot_C II 1334", np.nan),
        N_tot_SiIII=cols.get("N_tot_Si III 1206", np.nan),
    )
    return out_h5, srow

# ----------------------------
# Single-ray "test" runner
# ----------------------------
def run_test(data_dir, endpoints_csv, out_root):
    """
    Smoke test: trace and save only the first paired ray.

    Writes the ray's HDF5 via process_one_ray, then a one-row summary to
    <out_root>/summary_rays_TEST.csv.
    """
    print("[TEST] Loading dataset…")
    dataset = load_dataset(data_dir)

    print("[TEST] Reading and pairing endpoints…")
    paired = read_and_pair_endpoints(endpoints_csv)
    print(f"[TEST] Paired rays: {len(paired)}")

    # One ray is enough to exercise the whole pipeline end to end.
    first = paired.iloc[0]
    print(f"[TEST] Running 1st ray: ray_id={first['ray_id']} sid={first['SubhaloID']} span={first['span']}")

    written_h5, summary_row = process_one_ray(dataset, first, out_root)
    print(f"[TEST OK] wrote: {written_h5}")

    summary_path = os.path.join(out_root, "summary_rays_TEST.csv")
    pd.DataFrame([summary_row]).to_csv(summary_path, index=False)
    print(f"[TEST OK] wrote: {summary_path}")

# ----------------------------
# Batch runner
# ----------------------------
def run_all(data_dir, endpoints_csv, out_root):
    """
    Trace and save every paired ray from `endpoints_csv`.

    A failing ray does not stop the batch. The error is printed and appended,
    with its full traceback, to <out_root>/logs/errors.txt. Once all rays
    have been attempted, the summary rows of the successful rays are written
    to <out_root>/summary_rays.csv.
    """
    import traceback  # only needed for the error log

    logs_dir = ensure_dir(os.path.join(out_root, "logs"))
    err_path = os.path.join(logs_dir, "errors.txt")
    print("[ALL] Loading dataset…")
    ds = load_dataset(data_dir)

    print("[ALL] Reading and pairing endpoints…")
    rays_df = read_and_pair_endpoints(endpoints_csv)
    n_rays = len(rays_df)
    print(f"[ALL] Paired rays: {n_rays}")

    summary_rows = []
    # enumerate gives a 1-based progress counter independent of index labels.
    for k, (_, row) in enumerate(rays_df.iterrows(), start=1):
        rid, sid, span = row.get("ray_id"), row.get("SubhaloID"), row.get("span")
        print(f"[{k}/{n_rays}] ray_id={rid} sid={sid} span={span}")
        try:
            _out_h5, srow = process_one_ray(ds, row, out_root)
        except Exception as e:  # batch boundary: record the failure, continue
            print(f"[ERROR] ray failed: ray_id={rid} sid={sid} span={span}: {e}")
            with open(err_path, "a", encoding="utf-8") as f:
                f.write(f"{rid} | sid={sid} span={span} | {repr(e)}\n")
                # Full stack, so failures deep inside yt/trident can be diagnosed.
                f.write(traceback.format_exc())
            continue
        summary_rows.append(srow)

    # Save master summary
    if summary_rows:
        master_csv = os.path.join(out_root, "summary_rays.csv")
        pd.DataFrame(summary_rows).to_csv(master_csv, index=False)
        print(f"[ALL OK] wrote: {master_csv}")
    else:
        print("[ALL] No successful rays to summarize.")

# ----------------------------
# CLI
# ----------------------------
def main():
    """Parse the command line, create the output root, and run the chosen mode."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["test", "all"], required=True,
                        help="test: run one ray; all: run all rays")
    parser.add_argument("--data-dir", default=DEFAULT_DATA_DIR,
                        help="Directory containing snap_099.*.hdf5 chunks")
    parser.add_argument("--endpoints", default=DEFAULT_ENDPOINTS,
                        help="Path to los_endpoints_obsOriented.csv")
    parser.add_argument("--out-root", default=DEFAULT_OUT_ROOT,
                        help="Where to save results")
    opts = parser.parse_args()

    ensure_dir(opts.out_root)
    runner = run_test if opts.mode == "test" else run_all
    runner(opts.data_dir, opts.endpoints, opts.out_root)


if __name__ == "__main__":
    main()