
# Patient-level Stratified Train/Test Split (Polars + scikit-learn)

This notebook:
- Loads **three CSVs** (`ml_oneicu.csv`, `ml_eicu.csv`, `ml_mimiciv.csv`) from `../data/machine_learning`
- Performs a **patient-level** (i.e., `icu_stay_id`-level) **train/test split** with **stratification by `outcome_lead`**
- Uses an **80/20** split (configurable)
- Writes **two CSVs per dataset** into `../data/machine_learning`: `ml_<db>_train.csv` and `ml_<db>_test.csv`

**Key points**
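- We split by `icu_stay_id` so that all rows of one ICU stay land on the same side, which avoids leaking a stay's time-series across train and test. Note that the split unit is the ICU stay: if a single patient can have several `icu_stay_id`s in these files, that patient's stays may still be divided between train and test.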
- Stratification is based on a **stay-level** label computed from `outcome_lead` using `max` (any positive row ⇒ positive stay). You can switch to mean/threshold if you prefer.
- If stratification is infeasible (e.g., a class has fewer than 2 stays), we fall back to a grouped random split without stratification and emit a warning via `warnings.warn`.


In [6]:
from pathlib import Path
import logging
import warnings

import polars as pl
from sklearn.model_selection import train_test_split

# ---- parameters (edit here if needed) ----
# the input CSVs are read from, and the train/test CSVs written to, these folders
data_dir = Path("../data/machine_learning")
out_dir = Path("../data/machine_learning")

# column that identifies a stay (the unit of the split) and the row-level outcome column
id_col, label_col = "icu_stay_id", "outcome_lead"

# fraction of stays held out for the test set, and the seed handed to the splitter
test_size, random_state = 0.2, 813

# INFO-level logging with a timestamp so each load/split/write step is visible in the output
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

In [7]:
def read_ml_csvs(data_dir: Path) -> dict[str, pl.DataFrame]:
    """
    load the three ml csvs from `data_dir` into polars dataframes.

    files looked up inside `data_dir`:
      - ml_oneicu.csv
      - ml_eicu.csv
      - ml_mimiciv.csv

    each frame is checked for the module-level `id_col` and `label_col` columns
    (presence only; no dtype casting is performed here).

    returns a dict keyed by dataset name: {"oneicu", "eicu", "mimiciv"}.

    raises:
      - FileNotFoundError if one of the files does not exist
      - KeyError if a file lacks `id_col` or `label_col`
    """
    filenames = {
        "oneicu": "ml_oneicu.csv",
        "eicu": "ml_eicu.csv",
        "mimiciv": "ml_mimiciv.csv",
    }

    loaded: dict[str, pl.DataFrame] = {}
    for dataset, csv_name in filenames.items():
        csv_path = data_dir / csv_name

        # fail early with an explicit message instead of a reader error
        if not csv_path.exists():
            raise FileNotFoundError(f"missing file: {csv_path}")

        frame = pl.read_csv(csv_path, infer_schema_length=None)

        # both the split key and the outcome column must be present
        has_required = id_col in frame.columns and label_col in frame.columns
        if not has_required:
            raise KeyError(f"{csv_path.name} is missing required column(s) '{id_col}' or '{label_col}'")

        loaded[dataset] = frame
        logging.info("loaded %s (%d rows)", csv_name, frame.height)

    return loaded


def stay_level_labels(df, id_col=id_col, label_col=label_col):
    """
    compute stay-level labels from row-level outcome by max aggregation.

    parameters:
      - df: row-level polars dataframe containing `id_col` and `label_col`
      - id_col: column identifying a stay; one output row per distinct value
      - label_col: row-level outcome column, aggregated with `max`
        (any positive row => positive stay)

    returns a dataframe with columns: [id_col, "y"], one row per stay, ordered by
    the first appearance of each stay in `df`. for a 0/1 outcome, y is in {0,1}.
    """
    return (
        df
        # maintain_order=True: without it polars gives no guarantee on the order of
        # the groups, so the id list handed to a seeded splitter could differ from
        # run to run and the "reproducible" split would silently change.
        .group_by(id_col, maintain_order=True)
        .agg(pl.col(label_col).max().alias("y"))
        .select(id_col, "y")
    )


def stratified_split_by_patient(
    df,
    id_col=id_col,
    label_col=label_col,
    test_size=test_size,
    random_state=random_state,
):
    """
    split by icu_stay_id with stratification on stay-level labels (max of outcome).

    all rows sharing an `id_col` value end up on the same side of the split.
    falls back to a split without stratify (and emits a warning) if there is only
    one class, a class has < 2 stays, or some stay has a null label.

    parameters:
      - df: row-level polars dataframe containing `id_col` and `label_col`
      - id_col: column identifying the unit of the split
      - label_col: row-level outcome column used to derive the stay-level label
      - test_size: forwarded to sklearn's train_test_split
      - random_state: seed forwarded to sklearn's train_test_split

    returns (train_df, test_df): the rows of `df` belonging to each side.
    """
    # sort by stay id: group_by gives no ordering guarantee, and the seeded shuffle
    # in train_test_split is only reproducible if its input order is stable.
    stay_y = stay_level_labels(df, id_col=id_col, label_col=label_col).sort(id_col)
    ids = stay_y[id_col].to_list()
    y = stay_y["y"].to_list()

    # class counts
    counts = {}
    for v in y:
        counts[v] = counts.get(v, 0) + 1

    # a null label (stay whose outcome rows are all null) cannot be used as a
    # stratification class, so it is treated like any other infeasible case.
    can_stratify = (
        len(counts) >= 2
        and min(counts.values()) >= 2
        and None not in counts
    )
    stratify = y if can_stratify else None
    if stratify is None:
        warnings.warn(f"cannot stratify: class counts {counts}. proceeding without stratification.")

    train_ids, test_ids = train_test_split(
        ids,
        test_size=test_size,
        random_state=random_state,
        shuffle=True,
        stratify=stratify,
    )

    train_df = df.filter(pl.col(id_col).is_in(train_ids))
    test_df = df.filter(pl.col(id_col).is_in(test_ids))

    # sanity check on the materialised frames: no stay may appear on both sides
    train_stays = set(train_df[id_col].unique().to_list())
    test_stays = set(test_df[id_col].unique().to_list())
    assert train_stays.isdisjoint(test_stays), "leakage detected: some icu_stay_id appear in both train and test."

    # ids holds one entry per stay, so the split lists give the stay counts directly
    logging.info(
        "split by %s: stays train/test = %d/%d, rows train/test = %d/%d",
        id_col,
        len(train_ids),
        len(test_ids),
        train_df.height,
        test_df.height,
    )
    return train_df, test_df


def save_train_test_csvs(
    dfs,
    out_dir,
    prefix="ml_",
    id_col=id_col,
    label_col=label_col,
    test_size=test_size,
    random_state=random_state,
):
    """
    split every dataset in `dfs` and write its train/test parts as csv files.

    for each (name, frame) pair the split is delegated to
    `stratified_split_by_patient`; the two parts are written to
    `out_dir / "<prefix><name>_train.csv"` and `out_dir / "<prefix><name>_test.csv"`.
    `out_dir` is created (including parents) if it does not exist. returns nothing.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    # the same split settings are applied to every dataset
    split_kwargs = dict(
        id_col=id_col,
        label_col=label_col,
        test_size=test_size,
        random_state=random_state,
    )

    for name, df in dfs.items():
        logging.info("processing dataset: %s", name)
        train_part, test_part = stratified_split_by_patient(df, **split_kwargs)

        outputs = (
            (out_dir / f"{prefix}{name}_train.csv", train_part),
            (out_dir / f"{prefix}{name}_test.csv", test_part),
        )

        # write both files first, then report them
        for target, part in outputs:
            part.write_csv(target)
        for target, part in outputs:
            logging.info("wrote: %s (%d rows)", target, part.height)


In [8]:
dfs = read_ml_csvs(data_dir)
# row count of each loaded dataset, shown as the cell's output
{name: frame.height for name, frame in dfs.items()}

2025-09-04 22:56:18,415 INFO loaded ml_oneicu.csv (1785876 rows)
2025-09-04 22:56:20,019 INFO loaded ml_eicu.csv (1176638 rows)
2025-09-04 22:56:20,814 INFO loaded ml_mimiciv.csv (1111068 rows)


{'oneicu': 1785876, 'eicu': 1176638, 'mimiciv': 1111068}

In [9]:
def stay_class_counts(df):
    """
    count stays per stay-level class.

    uses `stay_level_labels` with the module-level `id_col` / `label_col` to get
    one label per stay, then counts how many stays carry each label.

    returns a dict {label: number_of_stays}; labels are converted to int, except
    a null label (if any), which is kept as the key None.
    """
    counts_df = (
        stay_level_labels(df, id_col, label_col)
        .group_by("y")
        # pl.len() replaces the deprecated GroupBy.count(); the alias keeps the
        # column name "count" that the dict construction below reads.
        .agg(pl.len().alias("count"))
    )
    return {
        (None if row["y"] is None else int(row["y"])): int(row["count"])
        for row in counts_df.iter_rows(named=True)
    }

In [10]:
# run the split for every loaded dataset and write the train/test csvs into out_dir.
# the settings are passed explicitly: the function's defaults were bound when it was
# defined, so this makes the current values of the parameter cell take effect.
save_train_test_csvs(
    dfs=dfs,
    out_dir=out_dir,
    prefix="ml_",
    id_col=id_col,
    label_col=label_col,
    test_size=test_size,
    random_state=random_state,
)

# NOTE(review): resolve() prints the absolute path, which puts the local home
# directory / user name into the saved cell output — consider clearing the output
# (or printing out_dir as-is) before sharing the notebook.
print("done. files written to:", out_dir.resolve())

2025-09-04 22:56:20,835 INFO processing dataset: oneicu
2025-09-04 22:56:21,505 INFO split by icu_stay_id: stays train/test = 48853/12214, rows train/test = 1428245/357631
2025-09-04 22:56:23,172 INFO wrote: ../data/machine_learning/ml_oneicu_train.csv (1428245 rows)
2025-09-04 22:56:23,172 INFO wrote: ../data/machine_learning/ml_oneicu_test.csv (357631 rows)
2025-09-04 22:56:23,172 INFO processing dataset: eicu
2025-09-04 22:56:23,329 INFO split by icu_stay_id: stays train/test = 27467/6867, rows train/test = 940257/236381
2025-09-04 22:56:23,614 INFO wrote: ../data/machine_learning/ml_eicu_train.csv (940257 rows)
2025-09-04 22:56:23,614 INFO wrote: ../data/machine_learning/ml_eicu_test.csv (236381 rows)
2025-09-04 22:56:23,614 INFO processing dataset: mimiciv
2025-09-04 22:56:23,705 INFO split by icu_stay_id: stays train/test = 18912/4729, rows train/test = 888052/223016
2025-09-04 22:56:23,852 INFO wrote: ../data/machine_learning/ml_mimiciv_train.csv (888052 rows)
2025-09-04 22:56:2

done. files written to: /Users/kinoshitatakashihiroshi/Dropbox/VS_Code/OneICU_profile_paper/data/machine_learning
