In [1]:
import numpy as np, pandas as pd
from pathlib import Path
from sklearn.model_selection import GroupShuffleSplit

# ------------------------------------------------------------------ #
# 1.  Load + basic cleaning                                          
# ------------------------------------------------------------------ #
df_all  = pd.read_csv(
    "/data1/home/srinivasana/peds_agents/llm_2stage_fc_predictions_full.csv",
    dtype=str   # keep everything as string → easier NA handling
).rename(columns={"app_id": "canon_id"})

df_gold = (
    pd.read_excel(
        "/data1/home/srinivasana/peds_agents/data/manual_annotated_labels.xlsx",
        engine="openpyxl", dtype=str
    )
    .rename(columns={"resovled_label_A": "manual_label_resolved"})  # typo in sheet
    [["canon_id", "manual_label_resolved"]]
    .drop_duplicates("canon_id")
)

def _clean_label(col: pd.Series) -> pd.Series:
    """empty / whitespace → NaN → 'NotExtrapolated'"""
    return (
        col.replace(r"^\s*$", np.nan, regex=True)
           .fillna("NotExtrapolated")
           .str.strip()
    )

df_all ["resolved_label"]        = _clean_label(df_all ["resolved_label"])
df_gold["manual_label_resolved"] = _clean_label(df_gold["manual_label_resolved"])

# ------------------------------------------------------------------ #
# 2.  Merge – each row now has BOTH labels (manual may be NaN)       
# ------------------------------------------------------------------ #
df_all = df_all.merge(df_gold, on="canon_id", how="left", validate="m:1")

# final label: manual if present else LLM prediction -----------------
df_all["label"] = df_all["manual_label_resolved"].fillna(df_all["resolved_label"])

# ------------------------------------------------------------------ #
# 3.  Identify the 100 manual rows (one per canon_id)                 
# ------------------------------------------------------------------ #
gold_rows = df_all[df_all["manual_label_resolved"].notna()].copy()

# ------------------------------------------------------------------ #
# 4.  70 / 30 split on GOLD rows, stratified by true class & grouped  
# ------------------------------------------------------------------ #
gss = GroupShuffleSplit(n_splits=1, test_size=0.30, random_state=42)

train_idx, test_idx = next(
    gss.split(gold_rows,
              gold_rows["manual_label_resolved"],   # stratify
              groups=gold_rows["canon_id"].astype(str))
)

gold_train = gold_rows.iloc[train_idx].reset_index(drop=True)
gold_test  = gold_rows.iloc[test_idx ].reset_index(drop=True)

# sanity: all four labels in each split ------------------------------
assert gold_train["manual_label_resolved"].nunique() == gold_rows["manual_label_resolved"].nunique()
assert gold_test ["manual_label_resolved"].nunique() == gold_rows["manual_label_resolved"].nunique()

# ------------------------------------------------------------------ #
# 5.  Add high-confidence pseudo LLM rows to the TRAIN split          
# ------------------------------------------------------------------ #
df_pseudo_high = df_all[
    df_all["manual_label_resolved"].isna() & df_all["confidence"].eq("high")
].copy()

df_train = pd.concat([gold_train, df_pseudo_high], ignore_index=True)
df_train["is_gold"] = df_train["manual_label_resolved"].notna().astype(int)

# mark gold rows in test ---------------------------------------------
gold_test["is_gold"] = 1

# ------------------------------------------------------------------ #
# 6.  Persist                                                         
# ------------------------------------------------------------------ #
out_dir = Path("/data1/home/srinivasana/peds_agents/agents/notebooks/splits")
out_dir.mkdir(parents=True, exist_ok=True)

df_train.to_csv(out_dir / "train.csv", index=False)
gold_test.to_csv(out_dir / "test.csv",  index=False)

print("✓ splits written →", out_dir)
print("  train rows:", len(df_train), "  (gold:", df_train['is_gold'].sum(), ")")
print("  test  rows:", len(gold_test))
print("\nlabel distribution (train gold only):")
print(gold_train["manual_label_resolved"].value_counts())
print("\nlabel distribution (test):")
print(gold_test["manual_label_resolved"].value_counts())


✓ splits written → /data1/home/srinivasana/peds_agents/agents/notebooks/splits
  train rows: 709   (gold: 76 )
  test  rows: 33

label distribution (train gold only):
manual_label_resolved
NotExtrapolated    39
Partial            32
Unlabeled           4
Full                1
Name: count, dtype: int64

label distribution (test):
manual_label_resolved
NotExtrapolated    21
Partial             9
Unlabeled           2
Full                1
Name: count, dtype: int64
