<a href="https://colab.research.google.com/github/wasnaqvi/colab_notebooks/blob/main/Survey.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [66]:
import pandas as pd

class HermesData:
    """
    Minimal stub to mimic src.data.HermesData
    Only requirement: .df attribute
    """
    def __init__(self, df: pd.DataFrame):
        self.df = df


In [67]:
from __future__ import annotations

import math
from typing import Dict, List, Iterable, Optional, Sequence

import numpy as np
import pandas as pd



def compute_leverage(arr: np.ndarray) -> float:
    """
    1D leverage proxy:
        L = sqrt( sum_i (x_i - mean(x))^2 )
    computed on finite values only.

    Notes
    -----
    - This is essentially sqrt(n) * std(x).
      Grows with both spread and sample size.
    """
    arr = np.asarray(arr, float)
    m = np.isfinite(arr)
    arr = arr[m]
    if arr.size < 2:
        return 0.0
    return float(np.sqrt(np.sum((arr - arr.mean()) ** 2)))


def _infer_name_col(df: pd.DataFrame) -> Optional[str]:
    """
    Best-effort inference of a planet-name column.
    Returns the column name if found, else None.
    """
    candidates = [
        "Planet Name",
        "planet_name",
        "pl_name",
        "Name",
        "name",
        "Planet",
        "planet",
        "Target",
        "target",
        "TOI",
        "toi",
    ]
    for c in candidates:
        if c in df.columns:
            return c
    return None


def _extract_names(df: pd.DataFrame, name_col: Optional[str]) -> List[str]:
    """
    Row-aligned extraction of planet names from df.

    If name_col is None or missing, falls back to "row_{i}".
    """
    if name_col is not None and name_col in df.columns:
        # make robust strings; preserve ordering aligned to df rows
        s = df[name_col].astype(str).fillna("")
        # optional: strip whitespace
        names = [x.strip() for x in s.to_list()]
        # if any are empty, replace those with row_i identifiers
        out: List[str] = []
        for i, nm in enumerate(names):
            out.append(nm if nm else f"row_{i}")
        return out

    # fallback: stable identifiers even if there is no name column
    return [f"row_{i}" for i in range(len(df))]



class Survey:
    """
    Survey
    ------
    A `Survey` represents *one sampled subset* of the parent Hermes dataset.

    Think of it as the atomic unit you fit models to:
      - It contains the DataFrame slice (`df`) that your model will ingest
      - It contains metadata about how it was drawn (survey_id, class_label)
      - NEW: it also carries planet/target names forward so you can:
          * inspect "what was actually sampled"
          * label/annotate plots per-survey
          * later overlay posterior predictive / fitted curves with planet names

    What is stored?
    ---------------
    Nothing permanent is written anywhere — names live only in-memory inside
    each Survey instance:

      - `planet_names`: list[str] aligned with df rows
      - `planet_index`: dict[str, list[int]] mapping name -> row indices
        (list because duplicates can happen; e.g., repeated identifiers)
    """

    def __init__(
        self,
        survey_id: int,
        class_label: str,
        df: pd.DataFrame,
        *,
        name_col: Optional[str] = None,
    ):
        self.survey_id = int(survey_id)
        self.class_label = str(class_label)
        self.df = df.reset_index(drop=True)

        # NEW: carry forward names (row-aligned)
        self.name_col = name_col if (name_col in self.df.columns) else None
        self.planet_names: List[str] = _extract_names(self.df, self.name_col)

        # NEW: quick lookup from name -> rows inside this survey
        self.planet_index: Dict[str, List[int]] = {}
        for i, nm in enumerate(self.planet_names):
            self.planet_index.setdefault(nm, []).append(i)

    @property
    def n(self) -> int:
        return len(self.df)

    # targets.

    def targets(self) -> List[str]:
        """Return the planet/target names in this survey (row-aligned)."""
        return list(self.planet_names)

    def target_table(self, cols: Optional[Sequence[str]] = None) -> pd.DataFrame:
        """
        Return a small inspection table with names + selected columns.

        """
        out = pd.DataFrame({"planet_name": self.planet_names})
        if cols:
            cols = [c for c in cols if c in self.df.columns]
            out = pd.concat([out, self.df.loc[:, cols].reset_index(drop=True)], axis=1)
        return out

    def row_for_target(self, name: str) -> List[int]:
        """Return row indices in `df` for a given target name (may be multiple)."""
        return self.planet_index.get(str(name), [])

    # leverage and metrics testing.

    def leverage(self, col: str = "logM") -> float:
        """
        Leverage of the specified column.
        Default: leverage of logM.
        """
        return compute_leverage(self.df[col].to_numpy(float))

    def leverage_2D(self, col_x: str = "logM", col_y: str = "Star Metallicity") -> float:
        """
        2D leverage proxy as quadrature sum of 1D leverages.
        """
        return float(np.sqrt(self.leverage(col_x) ** 2 + self.leverage(col_y) ** 2))

    def leverage_3D(
        self,
        col_x: str = "logM",
        col_y: str = "Star Metallicity",
        col_z: str = "Planet Radius [Re]",
    ) -> float:
        return float(math.cbrt(self.leverage(col_x) ** 2 + self.leverage(col_y) ** 2 + self.leverage(col_z) ** 2))

    def mahalanobis_3D(
        self,
        col_x: str = "logM",
        col_y: str = "Star Metallicity",
        col_z: str = "Planet Radius [Re]",
    ) -> float:
        """
        Mean 3D Mahalanobis distance of points in the specified columns,
        computed within this survey.

        Returns 0 if there are <2 finite rows.
        """
        data = self.df[[col_x, col_y, col_z]].to_numpy(float)
        data = data[np.all(np.isfinite(data), axis=1)]
        if data.shape[0] < 2:
            return 0.0
        mean = np.mean(data, axis=0)
        cov = np.cov(data, rowvar=False)
        inv_cov = np.linalg.inv(cov)
        diff = data - mean
        m_dist = np.sqrt(np.einsum("ij,jk,ik->i", diff, inv_cov, diff))
        return float(np.mean(m_dist))


class SurveySampler:
    """
    SurveySampler
    -------------
    Builds nested mass classes (S1..S4) from HermesData and draws many Survey
    objects from them over an N-grid.

    Core idea:
      - HermesData is the "parent population" (your ARIEL MCS or synthetic set)
      - SurveySampler constructs *class-conditional subsets* (S1..S4)
      - sample_grid draws many Survey realizations without replacement
        for each (class, N) combination.

    NEW:
      - You can specify `name_col` (planet-name column) once in the sampler.
      - Each Survey produced will carry those names forward in-memory.

    Why this matters:
      - Compute survey-level metrics (leverage, WAIC diffs, etc.)
        and still have full traceability to the *exact targets* that drove
        that result.
    """

    def __init__(
        self,
        hermes: HermesData,
        rng_seed: Optional[int] = None,
        *,
        name_col: Optional[str] = None,
    ):
        self.hermes = hermes
        self.rng = np.random.default_rng(rng_seed)

        # NEW: choose / infer planet name column once (used for all surveys)
        if name_col is None:
            name_col = _infer_name_col(self.hermes.df)
        self.name_col = name_col

        # build nested mass classes based on logM quantiles
        self.mass_classes: Dict[str, pd.DataFrame] = self._build_mass_classes()

    def _build_mass_classes(self) -> Dict[str, pd.DataFrame]:
        df = self.hermes.df
        q25, q50, q75 = df["logM"].quantile([0.25, 0.5, 0.75])

        classes: Dict[str, pd.DataFrame] = {}
        classes["S1"] = df.copy()
        classes["S2"] = df[df["logM"] >= q25].copy()
        classes["S3"] = df[df["logM"] >= q50].copy()
        classes["S4"] = df[df["logM"] >= q75].copy()
        return classes

    def sample_grid(
        self,
        N_grid: Iterable[int],
        n_reps_per_combo: int = 10,
        class_order: Optional[List[str]] = None,
    ) -> List[Survey]:
        """
        For each class in class_order and each N in N_grid,
        draw n_reps_per_combo surveys without replacement.

        Returns
        -------
        surveys : list[Survey]
            Flat list of Survey objects. Each Survey contains:
              - df: sampled targets
              - planet_names, planet_index: NEW traceability layer
        """
        if class_order is None:
            class_order = ["S1", "S2", "S3", "S4"]

        surveys: List[Survey] = []
        survey_id = 1

        for label in class_order:
            if label not in self.mass_classes:
                continue
            subset = self.mass_classes[label]
            n_available = len(subset)

            for N in N_grid:
                if N > n_available:
                    continue

                for _ in range(n_reps_per_combo):
                    rs = int(self.rng.integers(0, 2**32 - 1))
                    sample_df = subset.sample(n=N, replace=False, random_state=rs)
                    surveys.append(
                        Survey(
                            survey_id,
                            label,
                            sample_df,
                            name_col=self.name_col,
                        )
                    )
                    survey_id += 1

        return surveys


In [68]:
df = pd.read_csv("Ariel_MCS_Known_2024-07-09.csv")
print(df.columns)
print(df.head())
print("N =", len(df))

df['logM']=np.log10(df['Planet Mass [Me]'])


Index(['Star Name', 'Star Mass [Ms]', 'Star Mass Error Lower [Ms]',
       'Star Mass Error Upper [Ms]', 'Star Temperature [K]',
       'Star Temperature Error Lower [K]', 'Star Temperature Error Upper [K]',
       'Star Radius [Rs]', 'Star Radius Error Lower [Rs]',
       'Star Radius Error Upper [Rs]',
       ...
       'Tier 2 Eclipses', 'Tier 3 Eclipses', 'Preferred Method',
       'Tier 1 Observations', 'Tier 2 Observations', 'Tier 3 Observations',
       'FGS1_Flag', 'FGS2_Flag', 'FGS_Flag', 'Max Tier'],
      dtype='object', length=197)
  Star Name  Star Mass [Ms]  Star Mass Error Lower [Ms]  \
0   WASP-43            0.69                       -0.04   
1   WASP-47            1.11                       -0.49   
2  TOI-5704            0.73                       -0.08   
3   TOI-672            0.54                       -0.02   
4   TOI-199            0.94                       -0.01   

   Star Mass Error Upper [Ms]  Star Temperature [K]  \
0                        0.04           

In [79]:
[c for c in df.columns if "name" in c.lower()]

['Star Name', 'Planet Name']

In [70]:
hermes=HermesData(df)
sampler=SurveySampler(hermes,rng_seed=123,name_col="logM")

surveys = sampler.sample_grid(
    N_grid=[5,10,15],
    n_reps_per_combo=2,
)

len(surveys)


24

In [77]:
for s in surveys:
    print(s.survey_id, s.class_label, s.n)

surveys[10].df

1 S1 5
2 S1 5
3 S1 10
4 S1 10
5 S1 15
6 S1 15
7 S2 5
8 S2 5
9 S2 10
10 S2 10
11 S2 15
12 S2 15
13 S3 5
14 S3 5
15 S3 10
16 S3 10
17 S3 15
18 S3 15
19 S4 5
20 S4 5
21 S4 10
22 S4 10
23 S4 15
24 S4 15


Unnamed: 0,Star Name,Star Mass [Ms],Star Mass Error Lower [Ms],Star Mass Error Upper [Ms],Star Temperature [K],Star Temperature Error Lower [K],Star Temperature Error Upper [K],Star Radius [Rs],Star Radius Error Lower [Rs],Star Radius Error Upper [Rs],...,Tier 3 Eclipses,Preferred Method,Tier 1 Observations,Tier 2 Observations,Tier 3 Observations,FGS1_Flag,FGS2_Flag,FGS_Flag,Max Tier,logM
0,TOI-5126,1.24,-0.05,0.05,6150.0,-130.0,110.0,1.24,-0.03,0.03,...,200.0,Transit,3,28,54,0,0,0,1,1.362105
1,WASP-72,1.39,-0.06,0.06,6250.0,-100.0,100.0,1.98,-0.24,0.24,...,5.0,Eclipse,1,3,5,0,0,0,2,2.69143
2,TOI-1842,1.46,-0.03,0.03,6230.0,-50.0,50.0,2.02,-0.05,0.05,...,52.0,Transit,1,7,14,0,0,0,2,1.832606
3,TOI-1288,0.89,-0.02,0.04,5225.0,-27.0,23.0,1.01,-0.01,0.01,...,63.0,Eclipse,5,33,63,0,0,0,1,1.62326
4,HAT-P-22,0.92,-0.04,0.04,5302.0,-80.0,80.0,1.04,-0.04,0.04,...,1.0,Eclipse,1,1,1,0,0,0,3,2.834025
5,NGTS-8,0.89,-0.04,0.05,5241.0,-50.0,50.0,0.98,-0.02,0.02,...,30.0,Eclipse,3,15,30,1,0,1,2,2.470676
6,WASP-98,0.81,-0.06,0.06,5473.0,-121.0,121.0,0.74,-0.02,0.02,...,16.0,Eclipse,2,8,16,1,0,1,2,2.466924
7,KELT-7,1.53,-0.05,0.07,6789.0,-49.0,50.0,1.73,-0.04,0.04,...,1.0,Eclipse,1,1,1,0,0,0,3,2.609403
8,TOI-3688A,1.2,-0.08,0.07,5950.0,-100.0,100.0,1.3,-0.04,0.04,...,13.0,Eclipse,1,7,13,0,0,0,2,2.493419
9,WASP-138,1.22,-0.05,0.05,6272.0,-96.0,96.0,1.36,-0.05,0.05,...,16.0,Eclipse,2,8,16,0,0,0,2,2.588553


In [71]:
surveys[20].df.head()

Unnamed: 0,Star Name,Star Mass [Ms],Star Mass Error Lower [Ms],Star Mass Error Upper [Ms],Star Temperature [K],Star Temperature Error Lower [K],Star Temperature Error Upper [K],Star Radius [Rs],Star Radius Error Lower [Rs],Star Radius Error Upper [Rs],...,Tier 3 Eclipses,Preferred Method,Tier 1 Observations,Tier 2 Observations,Tier 3 Observations,FGS1_Flag,FGS2_Flag,FGS_Flag,Max Tier,logM
0,TOI-481,1.14,-0.01,0.02,5735.0,-72.0,72.0,1.66,-0.02,0.02,...,17.0,Eclipse,2,9,17,0,0,0,2,2.686884
1,K2-237,1.26,-0.07,0.05,6360.0,-200.0,190.0,1.26,-0.03,0.03,...,2.0,Eclipse,1,1,2,0,0,0,3,2.636689
2,HATS-52,1.11,-0.05,0.05,6010.0,-150.0,150.0,1.05,-0.06,0.06,...,10.0,Eclipse,1,5,10,1,0,1,2,2.852441
3,WASP-82,1.48,-0.37,0.37,6480.0,-90.0,90.0,2.1,-0.16,0.16,...,1.0,Eclipse,1,1,1,0,0,0,3,2.570379
4,TOI-1820,1.04,-0.13,0.13,5734.0,-50.0,50.0,1.51,-0.06,0.06,...,13.0,Eclipse,1,7,13,0,0,0,2,2.863921


In [73]:
s = surveys[0]

print("Survey ID:", s.survey_id)
print("Class:", s.class_label)
print("N:", s.n)
print("Targets:", s.targets())


Survey ID: 1
Class: S1
N: 5
Targets: ['2.020706650581999', '1.0806264869218056', '2.0536427083955098', '0.21231007576265515', '2.4165358677014988']


In [None]:
s.target_table(
    cols=["logM", "Star Metallicity", "Planet Radius"]
)


In [None]:
name = s.targets()[4]
print("Planet:", name)
print("Row indices:", s.row_for_target(name))
print(s.df.iloc[s.row_for_target(name)])


In [None]:
print("L1:", s.leverage("logM"))
print("L2:", s.leverage_2D())
print("L3:", s.leverage_3D())
print("Mahalanobis:", s.mahalanobis_3D())


In [78]:
ARIEL_CSV = "Ariel_MCS_Known_2024-07-09.csv"
OUT_CSV   = "hermes_synthetic_data_0.3.0.csv"

rng = np.random.default_rng(42)

df = pd.read_csv(ARIEL_CSV)


# Required columns in ARIEL Known
name_col = "Planet Name"
mass_col = "Planet Mass [Mjup]"
feh_col  = "Star Metallicity"


tess_col  = "Star TESS Mag"        # brighter -> better precision
depth_col = "Transit Depth [%]"    # deeper -> better precision
rj_col    = "Planet Radius [Rjup]" # might be useful later
method_col = "Preferred Method"
required = [name_col, mass_col, feh_col]
missing = [c for c in required if c not in df.columns]
if missing:
    raise KeyError(f"ARIEL file is missing required columns: {missing}")

work = df.copy()

# Keep only finite mass + metallicity + name
work = work[np.isfinite(work[mass_col]) & np.isfinite(work[feh_col])]
work = work[work[mass_col] > 0]
work = work[work[name_col].notna()]

# Compute logM = log10(M/Mjup)
work["logM"] = np.log10(work[mass_col].to_numpy(float))


# Synthetic 1D "true" relation for log(X_H2O)
# Synthetic 2D "true relation"

# need to look at notes for this one!!
# eventually need to write down these values and check in with welbanks and that new paper.
alpha = -3.0     # baseline abundance in log10
beta  = -0.30    # mass trend coefficient vs logM
gamma =  0.60    # stellar metallicity trend coefficient vs [Fe/H]
sigma_intr = 0.55  # intrinsic scatter (dex)

mu = alpha + beta * work["logM"].to_numpy(float) + gamma * work[feh_col].to_numpy(float)
log_x_h2o = mu + rng.normal(0.0, sigma_intr, size=len(work))

# Make synthetic measurement uncertainties
# (use depth + TESS mag if present; otherwise fallback)
base = 0.45  # dex, typical-ish

# default modifiers
mag_term   = np.zeros(len(work))
depth_term = np.zeros(len(work))

if tess_col in work.columns:
    tess = work[tess_col].to_numpy(float)
    # fainter stars -> larger errors
    mag_term = 0.06 * np.clip(tess - 9.0, -3.0, 6.0)

if depth_col in work.columns:
    depth = work[depth_col].to_numpy(float)
    # deeper transits -> smaller errors (log depth)
    # protect against nonpositive depth
    depth_safe = np.where(np.isfinite(depth) & (depth > 0), depth, np.nan)
    logd = np.log10(depth_safe)
    depth_term = -0.10 * np.nan_to_num(logd, nan=0.0)  # deeper => more negative => smaller unc

# asymmetric errors
sigma_meas = base + mag_term + depth_term
sigma_meas = np.clip(sigma_meas, 0.15, 1.20)

unc_lower = sigma_meas * rng.uniform(0.85, 1.15, size=len(work))
unc_upper = sigma_meas * rng.uniform(0.85, 1.15, size=len(work))

# Assemble HERMES synthetic table
# match  0.2.0 schema + add a few Ariel columns in the future??

out = pd.DataFrame({
    "Planet Name": work[name_col].astype(str).to_numpy(),
    "logM": work["logM"].to_numpy(float),
    "log(X_H2O)": log_x_h2o.astype(float),
    "uncertainty_lower": unc_lower.astype(float),
    "uncertainty_upper": unc_upper.astype(float),
    "Star Metallicity": work[feh_col].to_numpy(float),
    "Planet Radius": work[rj_col].to_numpy(float),

})

# carry metallicity error bars if available
for c in ["Star Metallicity Error Lower", "Star Metallicity Error Upper"]:
    if c in work.columns:
        out[c] = work[c].to_numpy(float)

#carry radius / period / etc for later plotting (safe to include)
optional_keep = [
    rj_col,
    "Planet Period [days]",
    "Star TESS Mag",
    "Transit Depth [%]",
]
for c in optional_keep:
    if c in work.columns and c not in out.columns:
        out[c] = work[c].to_numpy()

core = ["logM", "log(X_H2O)", "uncertainty_lower", "uncertainty_upper", "Star Metallicity"]
out = out[np.all(np.isfinite(out[core].to_numpy(float)), axis=1)].reset_index(drop=True)
out['Preferred Method']=df['Preferred Method']
out.to_csv(OUT_CSV, index=False)

print("Wrote:", OUT_CSV)
print("Rows:", len(out), "Cols:", len(out.columns))
print(out.head())


Wrote: hermes_synthetic_data_0.3.0.csv
Rows: 655 Cols: 14
  Planet Name      logM  log(X_H2O)  uncertainty_lower  uncertainty_upper  \
0    WASP-43b  0.300595   -2.928584           0.459125           0.551982   
1    WASP-47b  0.082785   -3.380827           0.577575           0.652239   
2   TOI-5704b -1.363574   -1.921380           0.654179           0.605118   
3    TOI-672b -1.079603   -2.584809           0.591149           0.546080   
4    TOI-199b -0.769551   -3.710204           0.498838           0.558719   

   Star Metallicity  Planet Radius  Star Metallicity Error Lower  \
0            -0.010          1.006                        -0.150   
1             0.360          1.150                        -0.050   
2             0.428          0.288                        -0.100   
3            -0.710          0.469                        -0.625   
4             0.220          0.810                        -0.030   

   Star Metallicity Error Upper  Planet Radius [Rjup]  Planet Period [