# Find the high-quality subset

In [None]:
#!/usr/bin/env python3

import os
import json
import numpy as np
import pandas as pd

# =================
# CONFIGURATION
# =================
# Lab / vital CSV inputs ("history" and "overlap" variants of each table).
HIST_LAB_PATH = "./mimic_wav_lab_history.csv"
CURR_LAB_PATH = "./mimic_wav_lab_overlap.csv"
HIST_VITAL_PATH = "./mimic_wav_vital_history.csv"
CURR_VITAL_PATH = "./mimic_wav_vital_overlap.csv"

# Directory that is scanned for waveform ``.npz`` files.
WAVEFORM_DIR = "/opt/localdata100tb/UNIPHY_Plus/dataset/EST/MIMIC3_SPO2_I_40hz_v3"

# JSON file that receives the selected waveform entries.
OUTPUT_LIST = "mimic_high_quality_info_list.json"

# Lab columns whose non-null values are counted per admission.
lab_cols = [
    "Potassium",
    "Calcium",
    "Sodium",
    "Glucose",
    "Lactate",
    "Creatinine",
]

# Vital-sign columns whose non-null values are counted per admission.
vital_cols = [
    "GM",
    "ABPs",
    "ABPd",
    "ABPm",
    "NBPs",
    "NBPd",
    "NBPm",
]

# =================
# STAGE A — LOAD LAB + VITAL TABLES
# =================

print("\n=== Loading lab + vital tables ===")


def _count_by_admission(table, columns, suffix):
    """Count non-null entries of ``columns`` per (SUBJECT_ID, HADM_ID).

    Returns a frame holding the two key columns plus one count column per
    input column, each renamed to ``<column><suffix>``.
    """
    counts = table.groupby(["SUBJECT_ID", "HADM_ID"])[columns].count()
    return counts.reset_index().rename(
        columns={name: f"{name}{suffix}" for name in columns}
    )


# All four tables are read the same way, with CHARTTIME parsed as datetime.
df_hist_lab, df_curr_lab, df_hist_vit, df_curr_vit = (
    pd.read_csv(csv_path, parse_dates=["CHARTTIME"])
    for csv_path in (HIST_LAB_PATH, CURR_LAB_PATH, HIST_VITAL_PATH, CURR_VITAL_PATH)
)


# ---------- LAB summary ----------
df_curr_lab_summary = _count_by_admission(df_curr_lab, lab_cols, "_curr")
df_hist_lab_summary = _count_by_admission(df_hist_lab, lab_cols, "_hist")

# Outer merge keeps admissions present in only one table; their missing
# counts become 0.
lab_all = pd.merge(
    df_curr_lab_summary,
    df_hist_lab_summary,
    on=["SUBJECT_ID", "HADM_ID"],
    how="outer",
).fillna(0)

# "present": at least one value for a lab type; "dense": at least two.
_curr_lab_counts = lab_all[[f"{c}_curr" for c in lab_cols]]
lab_all["curr_lab_types_present"] = _curr_lab_counts.gt(0).sum(axis=1)
lab_all["curr_lab_types_dense"] = _curr_lab_counts.ge(2).sum(axis=1)


# ---------- VITAL summary ----------
df_curr_vit_summary = _count_by_admission(df_curr_vit, vital_cols, "_curr_v")
df_hist_vit_summary = _count_by_admission(df_hist_vit, vital_cols, "_hist_v")

vital_all = pd.merge(
    df_curr_vit_summary,
    df_hist_vit_summary,
    on=["SUBJECT_ID", "HADM_ID"],
    how="outer",
).fillna(0)

# "present": at least one value for a vital type; "dense": at least three.
_curr_vit_counts = vital_all[[f"{c}_curr_v" for c in vital_cols]]
vital_all["curr_vital_types_present"] = _curr_vit_counts.gt(0).sum(axis=1)
vital_all["curr_vital_types_dense"] = _curr_vit_counts.ge(3).sum(axis=1)

# NOTE: the number reported here is the row count of lab_all only.
print("Total SUBJECT/HADM with lab+vital data:", len(lab_all))


# =================
# STAGE B — LOAD WAVEFORM METADATA FROM NPZ FILES
# =================

print("\n=== Scanning waveform files ===")

# One tuple per usable waveform file; turned into df_wave below.
wave_recs = []
files = [f for f in os.listdir(WAVEFORM_DIR) if f.endswith(".npz")]
print("Total files in waveform dir:", len(files))

for fname in files:
    # File names are parsed as underscore-separated fields:
    # field 0 = SUBJECT_ID, field 1 = HADM_ID, field 2 = clip index string,
    # last field = segment count. Names with fewer than 7 fields are skipped.
    base = fname[:-4]
    parts = base.split("_")
    if len(parts) < 7:
        continue

    try:
        subj_id = int(parts[0])
        hadm_id = int(parts[1])
        clip_str = parts[2]
        nseg = int(parts[-1])
    except ValueError:
        # Non-numeric id / segment count: not a waveform file we understand.
        continue

    # The clip field may hold several "-"-separated integers; an
    # unparsable field is tolerated and recorded as an empty list.
    try:
        clip_raw_index = [int(x) for x in clip_str.split("-")]
    except ValueError:
        clip_raw_index = []

    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Keep it only while WAVEFORM_DIR holds
    # trusted files.
    # The context manager closes the underlying archive; without it every
    # scanned file would leave an open file handle behind.
    with np.load(os.path.join(WAVEFORM_DIR, fname), allow_pickle=True) as arr:
        if "time" not in arr:
            continue
        time_ms = arr["time"]

    if len(time_ms) == 0:
        continue

    # Timestamps are interpreted as epoch milliseconds. The end is the last
    # timestamp plus 30000 ms — presumably one 30 s segment; TODO confirm.
    wave_start_dt = pd.to_datetime(time_ms[0], unit="ms", errors="coerce")
    wave_end_dt = pd.to_datetime(time_ms[-1] + 30000, unit="ms", errors="coerce")

    wave_recs.append((
        subj_id,
        hadm_id,
        fname,
        nseg,
        wave_start_dt,
        wave_end_dt,
        clip_raw_index,
    ))

df_wave = pd.DataFrame(
    wave_recs,
    columns=[
        "SUBJECT_ID", "HADM_ID", "file", "nseg",
        "wave_start_dt", "wave_end_dt", "clip_raw_index",
    ],
)

print("Parsed waveform entries:", df_wave.shape[0])


# =================
# STAGE C — MERGE LAB + VITAL + WAVEFORM
# =================

# Union of the lab and vital summaries; admissions missing from one side
# get 0 in that side's columns.
df_lab_vit = pd.merge(
    lab_all,
    vital_all,
    on=["SUBJECT_ID", "HADM_ID"],
    how="outer",
).fillna(0)

# Inner merge: one row per waveform file whose admission has a summary row.
df_all = pd.merge(
    df_lab_vit,
    df_wave,
    on=["SUBJECT_ID", "HADM_ID"],
    how="inner",
)

print("\nMerged lab+vital+wave entries:", df_all.shape[0])


# =================
# STRICT Lab Overlap + STRICT Vital Overlap
# =================

print("\n=== Computing STRICT Lab–Waveform and Vital–Waveform Overlap ===")


def _charttimes_by_admission(grouped):
    """Map each (SUBJECT_ID, HADM_ID) key to its non-null CHARTTIME array."""
    return {
        key: times.dropna().to_numpy()
        for key, times in grouped["CHARTTIME"]
    }


def _fraction_inside(times, start, end):
    """Return the share of ``times`` that lies within [start, end].

    ``times`` is a datetime64 array or None (admission absent from the
    table); an absent or empty array yields 0.0.
    """
    if times is None or len(times) == 0:
        return 0.0
    inside = (times >= start) & (times <= end)
    return float(inside.sum() / len(times))


curr_lab_by_hadm = df_curr_lab.groupby(["SUBJECT_ID", "HADM_ID"])
curr_vit_by_hadm = df_curr_vit.groupby(["SUBJECT_ID", "HADM_ID"])

# Extract the timestamp arrays once per admission instead of once per
# waveform row; several waveform files can share one admission.
_lab_times = _charttimes_by_admission(curr_lab_by_hadm)
_vit_times = _charttimes_by_admission(curr_vit_by_hadm)

lab_overlap = []
vit_overlap = []

for subj, hadm, w_start, w_end in zip(
    df_all["SUBJECT_ID"],
    df_all["HADM_ID"],
    df_all["wave_start_dt"],
    df_all["wave_end_dt"],
):
    key = (subj, hadm)
    start64 = w_start.to_datetime64()
    end64 = w_end.to_datetime64()

    # Fraction of the admission's lab / vital timestamps that fall inside
    # this waveform file's time window (bounds inclusive).
    lab_overlap.append(_fraction_inside(_lab_times.get(key), start64, end64))
    vit_overlap.append(_fraction_inside(_vit_times.get(key), start64, end64))

df_all["lab_overlap_fraction"] = lab_overlap
df_all["vital_overlap_fraction"] = vit_overlap


# =================
# VITAL TIME COVERAGE (NEW)
# =================

# Number of rows in the current-vitals table for each admission.
vital_counts = (
    df_curr_vit
    .groupby(["SUBJECT_ID", "HADM_ID"])
    .size()
    .reset_index(name="vital_count")
)

# Attach the count to every waveform row; admissions with no vitals get 0.
df_all = pd.merge(df_all, vital_counts, on=["SUBJECT_ID", "HADM_ID"], how="left")
df_all["vital_count"] = df_all["vital_count"].fillna(0)

# Waveform duration in hours, floored at 1e-6 so the division below never
# hits zero.
_wave_span = df_all["wave_end_dt"] - df_all["wave_start_dt"]
df_all["wave_hours"] = _wave_span.dt.total_seconds() / 3600.0
df_all["wave_hours"] = df_all["wave_hours"].clip(lower=1e-6)

# Vital rows per 4 hours of waveform.
# NOTE(review): vital_count covers the whole admission, not only this file's
# time window — confirm that is the intended numerator.
df_all["vital_per_4hr"] = (df_all["vital_count"] / df_all["wave_hours"]) * 4


# =================
# FINAL QUALITY FILTER (UPDATED)
# =================

print("\n=== Selecting HIGH-QUALITY data (labs + vitals + waveform) ===")

# Active criteria, one boolean mask each.
_dense_labs = df_all["curr_lab_types_dense"].ge(4)   # >= 4 lab types with >= 2 values
_enough_segments = df_all["nseg"].ge(2000)           # >= 2000 segments in the file
_vital_coverage = df_all["vital_per_4hr"].ge(1)      # NEW coverage filter

# Criteria that are currently switched off:
#   df_all["curr_vital_types_dense"] >= 3
#   df_all["vital_overlap_fraction"] >= 0.90
df_quality = df_all[_dense_labs & _enough_segments & _vital_coverage]

print("\nHigh-quality entries:", df_quality.shape[0])


# =================
# GENERATE OUTPUT JSON
# =================

# One entry per selected waveform file:
#   [file name, HADM_ID, 0, nseg - 1, nseg - 1, 0]
info_list = [
    [wave_file, int(hadm_id), 0, int(seg_count) - 1, int(seg_count) - 1, 0]
    for wave_file, hadm_id, seg_count in zip(
        df_quality["file"],
        df_quality["HADM_ID"],
        df_quality["nseg"],
    )
]

with open(OUTPUT_LIST, "w") as f:
    json.dump(info_list, f, indent=2)

print("\nSaved", len(info_list), "entries to", OUTPUT_LIST)
if info_list:
    print("Example entry:", info_list[0])



=== Loading lab + vital tables ===
Total SUBJECT/HADM with lab+vital data: 14329

=== Scanning waveform files ===
Total files in waveform dir: 4282
Parsed waveform entries: 4282

Merged lab+vital+wave entries: 4281

=== Computing STRICT Lab–Waveform and Vital–Waveform Overlap ===

=== Selecting HIGH-QUALITY data (labs + vitals + waveform) ===

High-quality entries: 2694

Saved 2694 entries to mimic_high_quality_info_list.json
Example entry: ['107_174162_0_PLETH40_II120_II500_2673.npz', 174162, 0, 2672, 2672, 0]


In [None]:
#!/usr/bin/env python3

import os
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
os.environ["USE_PYGEOS"] = "0"

import polars as pl
import numpy as np
import json
from sklearn.model_selection import StratifiedGroupKFold

# Table-printing options for polars output.
pl.Config.set_tbl_cols(-1)
pl.Config.set_tbl_width_chars(2000)


# 1. Paths / Config

# Input tables.
DEMO_PATH = "./mimic_patient_admission_demo_with_diag.csv"
HIST_PATH = "./mimic_wav_lab_history.csv"
TARGET_PATH = "./mimic_wav_lab_overlap.csv"
VITAL_PATH = "./mimic_wav_vital_overlap.csv"

# Waveform entry list read as input, and the split file written as output.
HQ_JSON = "./mimic_high_quality_info_list.json"
OUT_SPLIT_JSON = "./ppg_split_lists_stratified_hadm_labdemo.json"

lab_cols = [
    "Potassium",
    "Calcium",
    "Sodium",
    "Glucose",
    "Lactate",
    "Creatinine",
]
vital_cols = [
    "ABPs",
    "ABPd",
    "ABPm",
    "NBPs",
    "NBPd",
    "NBPm",
]
# The three NBP columns of vital_cols; their counts are summed later.
nbp_cols = [
    "NBPs",
    "NBPd",
    "NBPm",
]


# 2. Load CSVs

# Every table is read with the same schema-inference window.
demo = pl.read_csv(DEMO_PATH, infer_schema_length=20000)
hist = pl.read_csv(HIST_PATH, infer_schema_length=20000)
target = pl.read_csv(TARGET_PATH, infer_schema_length=20000)

# The vital table is kept raw here; its CHARTTIME column is converted to a
# datetime in the vital-stats step.
vital_raw = pl.read_csv(VITAL_PATH, infer_schema_length=20000)

print("Loaded CSVs:")
for _label, _frame in (
    (" demo   :", demo),
    (" hist   :", hist),
    (" target :", target),
    (" vital_raw  :", vital_raw),
):
    print(_label, _frame.shape)


# 3. Load HQ waveform encounters

with open(HQ_JSON, "r") as f:
    ppg_meta = json.load(f)

# Collect the distinct (SUBJECT_ID, HADM_ID) pairs encoded in the first two
# underscore-separated fields of each entry's file name.
encounters = set()
for entry in ppg_meta:
    fname = entry[0]
    parts = fname.split("_")
    if len(parts) < 2:
        continue
    try:
        encounters.add((int(parts[0]), int(parts[1])))
    except ValueError:
        # Only a non-numeric id field is expected here; anything else
        # should surface rather than be swallowed.
        continue

encounters = sorted(encounters)
print("Unique waveform encounters:", len(encounters))

# Key frame used to restrict every other table to these encounters.
base_enc = pl.DataFrame({
    "SUBJECT_ID": [p[0] for p in encounters],
    "HADM_ID":    [p[1] for p in encounters],
})


# 4. Join encounters
# Keep only the rows of each table that belong to a waveform encounter.
demo_enc, hist_enc, target_enc = (
    table.join(base_enc, on=["SUBJECT_ID", "HADM_ID"], how="inner")
    for table in (demo, hist, target)
)

print("Joined:")
for _label, _frame in (
    (" demo_enc   :", demo_enc),
    (" hist_enc   :", hist_enc),
    (" target_enc :", target_enc),
):
    print(_label, _frame.shape)


# 5. Demographics
# One row per encounter: the first value of each categorical field and the
# mean of age_at_admit.
_first_value_cols = [
    "GENDER",
    "ETHNICITY",
    "INSURANCE",
    "LANGUAGE",
    "MARITAL_STATUS",
    "ICD9_CODE",
]

demo_stats = demo_enc.group_by(["SUBJECT_ID", "HADM_ID"]).agg(
    [pl.col(name).first().alias(name) for name in _first_value_cols]
    + [pl.col("age_at_admit").mean().alias("age_at_admit")]
)


# 6. Lab history stats
# Per-encounter minimum and maximum of every lab in the history table.
_hist_min_exprs = [pl.col(c).min().alias(f"{c}_hist_min") for c in lab_cols]
_hist_max_exprs = [pl.col(c).max().alias(f"{c}_hist_max") for c in lab_cols]

hist_stats = (
    hist_enc
    .select(["SUBJECT_ID", "HADM_ID"] + lab_cols)
    .group_by(["SUBJECT_ID", "HADM_ID"])
    .agg(_hist_min_exprs + _hist_max_exprs)
)

# Range = max - min, appended in lab_cols order.
hist_stats = hist_stats.with_columns([
    (pl.col(f"{lab}_hist_max") - pl.col(f"{lab}_hist_min")).alias(f"{lab}_hist_range")
    for lab in lab_cols
])


# 7. Target mean
# Per-encounter mean of every lab in the target table.
_target_mean_exprs = [
    pl.col(c).mean().alias(f"{c}_target_mean") for c in lab_cols
]

target_stats = (
    target_enc
    .select(["SUBJECT_ID", "HADM_ID"] + lab_cols)
    .group_by(["SUBJECT_ID", "HADM_ID"])
    .agg(_target_mean_exprs)
)


# 8. Vital stats (NBP-based)
# Parse CHARTTIME (non-strict) and restrict to waveform encounters.
vital = vital_raw.with_columns(
    pl.col("CHARTTIME").str.to_datetime(strict=False)
)
vital_enc = vital.join(base_enc, on=["SUBJECT_ID", "HADM_ID"], how="inner")

# Non-null count of each vital column per encounter.
_vital_count_exprs = [pl.col(c).count().alias(f"{c}_count") for c in vital_cols]
vital_stats = (
    vital_enc
    .select(["SUBJECT_ID", "HADM_ID"] + vital_cols)
    .group_by(["SUBJECT_ID", "HADM_ID"])
    .agg(_vital_count_exprs)
)

# Total NBP measurements = sum of the three NBP count columns.
_nbp_total = sum(pl.col(f"{c}_count") for c in nbp_cols)
vital_stats = vital_stats.with_columns(_nbp_total.alias("NBP_total_count"))

# First and last vital timestamp per encounter.
dur_df = vital_enc.group_by(["SUBJECT_ID", "HADM_ID"]).agg([
    pl.col("CHARTTIME").min().cast(pl.Datetime).alias("vital_start"),
    pl.col("CHARTTIME").max().cast(pl.Datetime).alias("vital_end"),
])

vital_stats = vital_stats.join(dur_df, on=["SUBJECT_ID", "HADM_ID"], how="left")

# Span between first and last timestamp in hours; a null span counts as 0.
_span_seconds = (pl.col("vital_end") - pl.col("vital_start")).dt.total_seconds()
vital_stats = vital_stats.with_columns(
    (_span_seconds.fill_null(0) / 3600.0).alias("vital_hours")
)

# The 1e-6 offset keeps the rate defined when the span is zero.
vital_stats = vital_stats.with_columns(
    (pl.col("NBP_total_count") / (pl.col("vital_hours") + 1e-6)).alias("NBP_per_hour")
)

print(" vital_stats:", vital_stats.shape)


# 9. Merge
# Left-join every stats table onto the encounter keys, then fill nulls with 0.
summary = base_enc
for _stats in (demo_stats, hist_stats, target_stats, vital_stats):
    summary = summary.join(_stats, on=["SUBJECT_ID", "HADM_ID"], how="left")
summary = summary.fill_null(0)

# gender_bin: 0 when GENDER is "M", otherwise 1.
_gender_bin = (
    pl.when(pl.col("GENDER") == "M").then(0).otherwise(1).alias("gender_bin")
)

# age_bin: 1 (<30), 2 (<50), 3 (<70), 4 (<90), 5 (everything else).
_age = pl.col("age_at_admit")
_age_bin = (
    pl.when(_age < 30).then(1)
    .when(_age < 50).then(2)
    .when(_age < 70).then(3)
    .when(_age < 90).then(4)
    .otherwise(5)
    .alias("age_bin")
)

summary = summary.with_columns([_gender_bin, _age_bin])


# 10. Vital bins
nbp_vals = summary["NBP_per_hour"].to_numpy()
finite = np.isfinite(nbp_vals)

if finite.sum() > 10:
    # Quintile bins: thresholds at the 20/40/60/80 % quantiles of the finite
    # rates; bin k holds values below threshold k, bin 4 takes the rest.
    q = np.nanquantile(nbp_vals[finite], [0.2, 0.4, 0.6, 0.8])
    _rate = pl.col("NBP_per_hour")
    _binned = pl.when(_rate < q[0]).then(0)
    for _level, _edge in enumerate(q[1:], start=1):
        _binned = _binned.when(_rate < _edge).then(_level)
    _nbp_bin = _binned.otherwise(4)
else:
    # Too few finite values for quantiles: everyone goes in the middle bin.
    _nbp_bin = pl.lit(2)

summary = summary.with_columns(_nbp_bin.alias("NBP_bin"))


# 11. Final strata (now fully cleaned)
combo_cols = [
    "age_bin",
    "gender_bin",
    "ETHNICITY",
    "INSURANCE",
    "LANGUAGE",
    "MARITAL_STATUS",
    "ICD9_CODE",
    "NBP_bin",
]

def combine_fields(df, cols):
    """Join the given columns of ``df`` into one "_"-separated string Series.

    Each column is cast to string and a null is written as "UNK", so one
    missing field does not turn the whole combined key into null.
    """
    out = df[cols[0]].cast(pl.Utf8).fill_null("UNK")
    for c in cols[1:]:
        out = out + "_" + df[c].cast(pl.Utf8).fill_null("UNK")
    return out

# ---- CLEAN categorical values for sklearn ----
categorical_cols = [
    "ETHNICITY", "INSURANCE", "LANGUAGE",
    "MARITAL_STATUS", "ICD9_CODE", "strata"
]

# Clean the source categoricals BEFORE building the strata key. Building the
# key first let any null field null out the whole key, and every such
# encounter was then lumped into the single stratum "UNK".
summary = summary.with_columns([
    pl.col(c).cast(pl.Utf8).fill_null("UNK")
    for c in categorical_cols
    if c != "strata"
])

summary = summary.with_columns(
    combine_fields(summary, combo_cols).alias("strata")
)

summary_pd = summary.to_pandas()
summary_pd["strata"] = summary_pd["strata"].astype(str)

print("Final summary shape:", summary.shape)
print(summary.head(5))


# 12. StratifiedGroupKFold
hadm_ids = summary_pd["HADM_ID"].to_numpy()
strata = summary_pd["strata"].to_numpy()

# Only the first of the 4 folds is used as the train/test split.
# NOTE(review): groups are HADM_IDs, so two admissions of the same
# SUBJECT_ID can land on different sides — confirm this is intended.
sgkf = StratifiedGroupKFold(n_splits=4, shuffle=True, random_state=42)
_folds = sgkf.split(hadm_ids, strata, groups=hadm_ids)
train_idx, test_idx = next(_folds)

train_hadm = set(hadm_ids[train_idx])
test_hadm = set(hadm_ids[test_idx])

print("Train encounters:", len(train_hadm))
print("Test encounters :", len(test_hadm))


# 13. Assign waveform files
# Each entry goes to the list matching its HADM_ID (entry[1]); train
# membership is checked first, and entries in neither set are dropped.
_train_entries = [e for e in ppg_meta if int(e[1]) in train_hadm]
_test_entries = [
    e for e in ppg_meta
    if int(e[1]) not in train_hadm and int(e[1]) in test_hadm
]

split_dict = {
    "train_control_list": _train_entries,
    "test_control_list": _test_entries,
}

with open(OUT_SPLIT_JSON, "w") as f:
    json.dump(split_dict, f, indent=4)

print("Saved:", OUT_SPLIT_JSON)
print("Train:", len(split_dict["train_control_list"]))
print("Test :", len(split_dict["test_control_list"]))


Loaded CSVs:
 demo   : (6737, 11)
 hist   : (640371, 9)
 target : (79383, 13)
 vital_raw  : (845265, 14)
Unique waveform encounters: 2694
Joined:
 demo_enc   : (2692, 11)
 hist_enc   : (83309, 9)
 target_enc : (25648, 13)
 vital_stats: (2694, 13)
Final summary shape: (2694, 48)
shape: (5, 48)
┌────────────┬─────────┬────────┬────────────────────────┬───────────┬──────────┬────────────────┬───────────┬──────────────┬────────────────────┬──────────────────┬─────────────────┬──────────────────┬──────────────────┬─────────────────────┬────────────────────┬──────────────────┬─────────────────┬──────────────────┬──────────────────┬─────────────────────┬──────────────────────┬────────────────────┬───────────────────┬────────────────────┬────────────────────┬───────────────────────┬───────────────────────┬─────────────────────┬────────────────────┬─────────────────────┬─────────────────────┬────────────────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬─



Train encounters: 2020
Test encounters : 674
Saved: ./ppg_split_lists_stratified_hadm_labdemo.json
Train: 2020
Test : 674
