## **Missing-aware Evolving Fuzzy Classifier (MEF-Classifier) on MIMIC-IV**

In [1]:
%load_ext autoreload
%autoreload 2

### **Data Loading**

In [2]:
import numpy as np
import pandas as pd

In [None]:
csv_path = '...'

# Timestamp columns that are parsed as datetimes when they exist in the CSV.
date_cols = [
    "index_time",
    "t_hcc_first", "t_cirr_first", "t_fib_first", "t_lf_first",
    "t_event_min", "t_event_max", "last_followup_time",
    "charttime"
]

# Read only the header row to learn which columns exist. This matches column
# names exactly (a substring test on the raw header line would also accept
# e.g. "index_time_x") and leaves no open file handle behind.
header_cols = set(pd.read_csv(csv_path, nrows=0).columns)

raw = pd.read_csv(
    csv_path,
    parse_dates=[c for c in date_cols if c in header_cols],
    low_memory=False
)

# --- cohort (one row per subject) ---
cohort_cols = [
    "subject_id", "index_hadm_id", "index_time",
    "gender", "age_at_index",
    "t_hcc_first", "t_cirr_first", "t_fib_first", "t_lf_first",
    "t_event_min", "t_event_max", "last_followup_time"
]
cohort_cols = [c for c in cohort_cols if c in raw.columns]
# The patient-level columns are repeated on every lab row; keep the first row per subject.
cohort = raw[cohort_cols].drop_duplicates("subject_id").reset_index(drop=True)

# compute birth_year = year of the index time minus age at index
if ("index_time" in cohort.columns) and ("age_at_index" in cohort.columns):
    cohort["birth_year"] = (cohort["index_time"].dt.year - cohort["age_at_index"]).astype("float")

# gender to numeric: F -> 0, M -> 1; any other value becomes NaN
if "gender" in cohort.columns:
    cohort["gender_num"] = cohort["gender"].map({"F": 0, "M": 1}).astype("float")

# --- labs (long format: one row per measurement) ---
labs_cols = ["subject_id", "hadm_id", "charttime", "itemid", "feature_name", "valuenum", "valueuom"]
labs_cols = [c for c in labs_cols if c in raw.columns]
labs = raw[labs_cols].copy()

# Drop measurements lacking a subject, timestamp, item id or numeric value.
labs = labs.dropna(subset=["subject_id", "charttime", "itemid", "valuenum"])
labs["itemid"] = labs["itemid"].astype(int)
# Coerce non-numeric values to NaN, then drop them.
labs["valuenum"] = pd.to_numeric(labs["valuenum"], errors="coerce")
labs = labs.dropna(subset=["valuenum"])

print("cohort:", cohort.shape, "labs:", labs.shape)


cohort: (2476, 14) labs: (365924, 7)


In [4]:
def convert_units(df):
    """
    Return a copy of `df` with selected lab units harmonised.

    Conversions are applied in this order (later steps see earlier results):
      * Albumin:        mg/dL -> g/dL (value / 1000), then g/dL -> g/L (value * 10)
      * Total protein:  mg/dL -> g/dL (value / 1000), then g/dL -> g/L (value * 10)
      * any feature whose unit contains "#/uL": value / 1000, unit set to "K/uL"

    Albumin is matched on the exact name "Albumin"; total protein is matched
    case-insensitively on "protein, total" / "total protein". Units are
    compared case-insensitively. The input frame is not modified.
    """
    out = df.copy()

    albumin_rows = out["feature_name"].eq("Albumin")
    protein_rows = out["feature_name"].str.lower().isin(["protein, total", "total protein"])

    def to_thousandth(values):
        return values / 1000.0

    def times_ten(values):
        return values * 10.0

    # (feature selector, unit to match in lower case, value transform, new unit label)
    steps = (
        (albumin_rows, "mg/dl", to_thousandth, "g/dL"),
        (albumin_rows, "g/dl", times_ten, "g/L"),
        (protein_rows, "mg/dl", to_thousandth, "g/dL"),
        (protein_rows, "g/dl", times_ten, "g/L"),
    )
    for feature_rows, unit_from, transform, unit_to in steps:
        # The unit column is re-read on every step because the previous step may have rewritten it.
        rows = feature_rows & out["valueuom"].astype(str).str.lower().eq(unit_from)
        out.loc[rows, "valuenum"] = transform(out.loc[rows, "valuenum"])
        out.loc[rows, "valueuom"] = unit_to

    # For features with unit: #/uL -> K/uL
    per_ul_rows = out["valueuom"].astype(str).str.contains(r"#/uL|#/ul", case=False, na=False)
    out.loc[per_ul_rows, "valuenum"] = to_thousandth(out.loc[per_ul_rows, "valuenum"])
    out.loc[per_ul_rows, "valueuom"] = "K/uL"

    return out

# Replace `labs` with a unit-harmonised copy (convert_units does not modify its input).
labs = convert_units(labs)

In [None]:
def _resolve_differential(frame, keyword, std_name, keep):
    """
    Collapse one white-cell differential feature to a single measurement kind.

    Rows whose lower-cased `feature_name` contains `keyword` are split by unit:
    "abs" when the unit contains "#", "/uL" or "/ul" (case-insensitive) and
    "pct" when it contains "%". Rows of the `keep` kind are labelled
    `std_name` in `feature_std`; rows of the other kind are labelled
    f"{std_name}_<kind>" and removed from the returned copy. A row matching
    both kinds ends up labelled `std_name` and is kept.

    `frame` itself receives the label updates in place; the filtered result is
    returned as a new frame.
    """
    units = frame["valueuom"].astype(str)
    is_feature = frame["feature_name"].str.lower().str.contains(keyword, na=False)
    rows_by_kind = {
        "abs": is_feature & units.str.contains(r"#|/uL|/ul", case=False, na=False),
        "pct": is_feature & units.str.contains(r"%", case=False, na=False),
    }

    discard = "pct" if keep == "abs" else "abs"
    discard_label = f"{std_name}_{discard}"

    # Label the unwanted kind first so that the kept kind wins on overlapping rows.
    frame.loc[rows_by_kind[discard], "feature_std"] = discard_label
    frame.loc[rows_by_kind[keep], "feature_std"] = std_name

    return frame[frame["feature_std"] != discard_label].copy()


# Start from the raw feature name; only the differential counts below are relabelled.
labs["feature_std"] = labs["feature_name"].copy()

# Eosinophils keep the percentage measurement; the other differentials keep the absolute count.
for keyword, std_name, keep in (
    ("eosin", "Eosinophils", "pct"),
    ("lymphocyte", "Lymphocytes", "abs"),
    ("neutrophil", "Neutrophils", "abs"),
    ("basophil", "Basophils", "abs"),
):
    labs = _resolve_differential(labs, keyword, std_name, keep)



In [None]:
# month period
labs["MonthPeriod"] = labs["charttime"].dt.to_period("M")

labs = labs.sort_values(by=["subject_id", "charttime"]) # 确保按时间排序
agg = (labs
       .groupby(["subject_id", "MonthPeriod", "feature_std"], as_index=False)["valuenum"]
       .first())

wide = agg.pivot_table(index=["subject_id", "MonthPeriod"],
                       columns="feature_std", values="valuenum", aggfunc="first").reset_index()
wide.columns.name = None

print("monthly wide:", wide.shape)
print("features:", [c for c in wide.columns if c not in ["subject_id", "MonthPeriod"]])


monthly wide: (21014, 10)
features: ['Albumin', 'Basophils', 'Eosinophils', 'Lymphocytes', 'Neutrophils', 'Platelets', 'Total protein', 'White cells count']


In [None]:
def pick_t1(row):
    """
    Choose the end of a patient's extraction period.

    Returns `row["t_event_max"]` when it is present and not null, otherwise
    `row["last_followup_time"]` (or `pd.NaT` when that key is absent too).
    `row` only needs to support `.get(key, default)`.
    """
    event_end = row.get("t_event_max", pd.NaT)
    return row.get("last_followup_time", pd.NaT) if pd.isna(event_end) else event_end

# End of each patient's extraction period: last event time, else last follow-up.
cohort["t1_extract"] = cohort.apply(pick_t1, axis=1)
# Keep only patients with both endpoints known and a strictly positive period.
cohort = cohort[cohort["t1_extract"].notna() & cohort["index_time"].notna()].copy()
cohort = cohort[cohort["t1_extract"] > cohort["index_time"]].copy()

# Express both endpoints as calendar months.
cohort["index_mp"] = cohort["index_time"].dt.to_period("M")
cohort["t1_mp"] = cohort["t1_extract"].dt.to_period("M")

# Number of calendar months covered, counting both the index month and the end month.
cohort["T_max"] = (cohort["t1_mp"].astype("int64") - cohort["index_mp"].astype("int64") + 1).astype(int)
# Require the period to span at least two calendar months.
cohort = cohort[cohort["T_max"] >= 2].copy()

# Attach each patient's month bounds to the monthly lab table; the inner join
# drops patients that were filtered out of the cohort above.
wide_scoped = wide.merge(
    cohort[["subject_id", "index_mp", "t1_mp"]],
    on="subject_id",
    how="inner"
)

# Keep only months inside [index month, end month].
wide_scoped = wide_scoped[
    (wide_scoped["MonthPeriod"] >= wide_scoped["index_mp"]) &
    (wide_scoped["MonthPeriod"] <= wide_scoped["t1_mp"])
].copy()

# 1-based month offset from the index month (the index month itself is TimeUnit 1).
wide_scoped["TimeUnit"] = (
    wide_scoped["MonthPeriod"].astype("int64") - wide_scoped["index_mp"].astype("int64") + 1
).astype(int)

# Final per-patient monthly time series, ordered by patient and time.
ts = (
    wide_scoped
    .drop(columns=["index_mp", "t1_mp"])
    .sort_values(["subject_id", "TimeUnit", "MonthPeriod"])
    .reset_index(drop=True)
)

# Put the key columns first, followed by the feature columns.
ts = ts[['subject_id', 'MonthPeriod', 'TimeUnit'] + \
     [c for c in ts.columns if c not in ['subject_id', 'MonthPeriod', 'TimeUnit']]]

print("ts shape:", ts.shape)



ts shape: (20896, 11)


In [None]:
# Event timestamp columns, restricted to those actually present in the cohort table.
event_cols = [
    c
    for c in ("t_hcc_first", "t_cirr_first", "t_fib_first", "t_lf_first", "t_event_min")
    if c in cohort.columns
]

df_labels = cohort.loc[:, ["subject_id", *event_cols]].copy()

# Patient-level label: 1 when at least one event timestamp is recorded, else 0.
any_event_recorded = df_labels[event_cols].notna().any(axis=1)
df_labels["has_event"] = any_event_recorded.astype(int)

print(df_labels["has_event"].value_counts(dropna=False))

has_event
0    2150
1     208
Name: count, dtype: int64


#### **dataset overview**

In [10]:
## age at index: mean and standard deviation over the cohort
ages_at_index = cohort["age_at_index"]
age_mean = ages_at_index.mean()
age_std = ages_at_index.std()
print("Age mean:", age_mean, "std:", age_std)

## gender distribution (%) of the cohort; missing values are counted as their own category
gender_counts = 100 * cohort["gender_num"].value_counts(normalize=True, dropna=False)
print("Gender distribution (%):\n", gender_counts)

## follow-up duration in years (days / 365.25): mean and standard deviation
followup_span = cohort["t1_extract"] - cohort["index_time"]
followup_durations = followup_span.dt.days / 365.25
followup_mean = followup_durations.mean()
followup_std = followup_durations.std()
print(f"Follow-up duration (years): mean={followup_mean:.2f}, std={followup_std:.2f}")

Age mean: 57.354113655640376 std: 14.308399601297857
Gender distribution (%):
 gender_num
0.0    51.187447
1.0    48.812553
Name: proportion, dtype: float64
Follow-up duration (years): mean=2.92, std=3.02


### **train-test-split**

In [11]:
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit, GroupShuffleSplit

In [None]:
def split_patient_ids(df_ts, df_labels, id_col='ID', test_size=0.2, random_state=42, use_stratify=True, stratify_col='has_event',
                        subset_frac=None, balance=False, balance_range=(0.8, 1.2), min_per_class=1, verbose=True):
    """
    Split patient IDs into disjoint train and test sets.

    Only IDs present in both `df_ts` and `df_labels` are used. Optionally a
    stratified subset of the patients is taken first (`subset_frac`), then the
    classes are undersampled (`balance`), and finally the remaining IDs are
    split with a stratified shuffle split or, when stratification is off or
    impossible, a plain shuffle split.

    Parameters
    ----------
    df_ts : pd.DataFrame
        Time-series table containing `id_col`.
    df_labels : pd.DataFrame
        Patient-level table containing `id_col` and `stratify_col`.
    id_col : str
        Name of the patient identifier column in both tables.
    test_size : float or int
        Forwarded to the scikit-learn splitter.
    random_state : int
        Seed for the splitters and for the balancing random generator.
    use_stratify : bool
        Stratify the split on `stratify_col`. Automatically disabled when the
        smallest class has fewer than two patients.
    stratify_col : str
        Label column; missing labels are treated as 0 and values are cast to int.
    subset_frac : float or None
        When strictly between 0 and 1, keep only this stratified fraction of patients.
    balance : bool
        When True, sample from every class roughly `min_count * U(lo, hi)`
        patients (capped at the class size, at least `min_per_class`), where
        `min_count` is the size of the smallest class.
    balance_range : tuple(float, float)
        The `(lo, hi)` bounds of the uniform factor used when balancing.
    min_per_class : int
        Lower bound on the number of patients sampled per class when balancing.
    verbose : bool
        Print a summary of the split and any stratification fallback.

    Returns
    -------
    (set, set)
        Train IDs and test IDs.

    Raises
    ------
    ValueError
        If `df_ts` and `df_labels` share no patient ID.
    """
    ids_ts = pd.Index(df_ts[id_col].unique())
    ids_lab = pd.Index(df_labels[id_col].unique())
    valid_ids = ids_ts.intersection(ids_lab)

    if len(valid_ids) == 0:
        raise ValueError(f"No common '{id_col}' values between df_ts and df_labels; nothing to split.")

    # Label per valid patient, aligned with valid_ids.
    lab_sr = (df_labels.set_index(id_col)
                        .reindex(valid_ids)[stratify_col]
                        .fillna(0).astype(int))

    ids_arr = valid_ids.to_numpy()

    work_ids = ids_arr.copy()
    y_work = lab_sr.values.copy()
    rng = np.random.default_rng(random_state)

    # Optional stratified sub-sampling of the patient pool.
    if subset_frac is not None and 0 < subset_frac < 1.0:
        sss_sub = StratifiedShuffleSplit(n_splits=1, train_size=subset_frac, random_state=random_state)
        sub_idx, _ = next(sss_sub.split(work_ids.reshape(-1, 1), y_work))
        work_ids = work_ids[sub_idx]
        y_work = y_work[sub_idx]

    # Optional class balancing by undersampling towards the smallest class.
    if balance and len(np.unique(y_work)) > 0:
        by_class = {c: work_ids[y_work == c] for c in np.unique(y_work)}
        min_count = min(len(v) for v in by_class.values()) if by_class else 0
        lo, hi = balance_range

        chosen = []
        for c, arr in by_class.items():
            base_count = int(np.floor(rng.uniform(lo, hi) * max(min_count, 0)))
            base_count = max(min_per_class, min(base_count, len(arr)))
            if base_count > 0:
                chosen.append(rng.choice(arr, size=base_count, replace=False))

        work_ids = np.concatenate(chosen) if chosen else np.array([], dtype=ids_arr.dtype)
        y_work = lab_sr.reindex(work_ids).to_numpy()

    # Stratification needs at least two patients in every class.
    if use_stratify:
        _, cnt = np.unique(y_work, return_counts=True)
        min_cnt = cnt.min() if len(cnt) else 0
        if min_cnt < 2:
            if verbose:
                print(f"[split] stratify disabled (min class count={min_cnt})")
            use_stratify = False

    if use_stratify:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
        tr_idx, te_idx = next(sss.split(work_ids.reshape(-1, 1), y_work))
    else:
        # Every ID is its own group, so this is a plain shuffle split over patients.
        gss = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
        tr_idx, te_idx = next(gss.split(work_ids, groups=work_ids))

    train_ids = set(work_ids[tr_idx])
    test_ids = set(work_ids[te_idx])

    if verbose:
        print(f"Patients original total={len(valid_ids)} | used total={len(train_ids)+len(test_ids)}  | Train={len(train_ids)} | Test={len(test_ids)}")
        if use_stratify:
            # Report the positive rate of the column that was actually stratified on.
            print(f"Pos rate ({stratify_col}) — Train:",
                  lab_sr.loc[list(train_ids)].mean(),
                  " Test:", lab_sr.loc[list(test_ids)].mean())

    return train_ids, test_ids

def slice_by_ids(df_ts: pd.DataFrame, train_ids: set, test_ids: set, id_col='ID'):
    """
    Split a time-series table into train and test copies by patient ID.

    Returns `(df_train_ts, df_test_ts)`, each an independent copy holding the
    rows whose `id_col` value is in `train_ids` / `test_ids` respectively.
    An AssertionError is raised when any ID ends up in both parts.
    """
    row_ids = df_ts[id_col]
    df_train_ts = df_ts.loc[row_ids.isin(train_ids)].copy()
    df_test_ts = df_ts.loc[row_ids.isin(test_ids)].copy()

    # Guard against patient-level leakage between the two parts.
    shared_ids = set(df_train_ts[id_col].unique()) & set(df_test_ts[id_col].unique())
    assert not shared_ids
    return df_train_ts, df_test_ts


In [13]:
# Patient-level split stratified on has_event (the function's default label column);
# all patients are used: no sub-sampling and no class balancing.
train_ids, test_ids = split_patient_ids(ts, df_labels=df_labels, id_col="subject_id", use_stratify=True, subset_frac=None, balance=False)

# Materialise the train/test time-series tables as independent copies.
df_train_ts, df_test_ts = slice_by_ids(ts, train_ids, test_ids, id_col="subject_id")


Patients original total=2358 | used total=2358  | Train=1886 | Test=472
Pos rate (has_event) — Train: 0.088016967126193  Test: 0.08898305084745763


In [None]:
from sklearn import preprocessing

## normalization
# Every column of `ts` that is not an identifier or time key is a feature.
_key_cols = {"subject_id", "MonthPeriod", "TimeUnit"}
feature_cols = [c for c in ts.columns if c not in _key_cols]

# Fit the scaler on the training rows only, so the test rows do not influence the statistics.
scaler = preprocessing.StandardScaler().fit(df_train_ts[feature_cols])

# Apply the same fitted transform to both parts.
for _split_df in (df_train_ts, df_test_ts):
    _split_df[feature_cols] = scaler.transform(_split_df[feature_cols])

### **Create Sliding Windows**

In [None]:
def apply_fill_method(window_data, fill_method, feature_cols):
    """
    Fill missing values (NaN) in the feature columns of one window.

    Parameters
    ----------
    window_data : pd.DataFrame
        Rows of a single window; it is copied, never modified in place.
    fill_method : str
        One of:
          'nan'            leave NaN untouched
          'inf'            replace NaN with +inf
          'zero'           replace NaN with 0
          'forward_fill'   propagate the previous row's value downwards
          'backward_fill'  propagate the next row's value upwards
          'mean'           replace NaN with the column mean of the window
          'median'         replace NaN with the column median of the window
        For 'mean'/'median' a column is left unchanged when the statistic is NaN.
    feature_cols : list of str
        Columns to fill.

    Returns
    -------
    pd.DataFrame
        The filled copy.

    Raises
    ------
    ValueError
        If `fill_method` is not one of the supported names (previously such a
        value was silently treated as "no filling").
    """
    window_data = window_data.copy()

    if fill_method == 'nan':
        pass
    elif fill_method == 'inf':
        window_data[feature_cols] = window_data[feature_cols].fillna(np.inf)
    elif fill_method == 'zero':
        window_data[feature_cols] = window_data[feature_cols].fillna(0)
    elif fill_method == 'forward_fill':
        # DataFrame.ffill()/bfill() replace the deprecated fillna(method=...).
        window_data[feature_cols] = window_data[feature_cols].ffill()
    elif fill_method == 'backward_fill':
        window_data[feature_cols] = window_data[feature_cols].bfill()
    elif fill_method == 'mean':
        # NOTE(review): infinite values in a column (e.g. padding markers)
        # take part in the mean/median - confirm this is intended.
        for col in feature_cols:
            mean_val = window_data[col].mean()
            if not np.isnan(mean_val):
                window_data[col] = window_data[col].fillna(mean_val)
    elif fill_method == 'median':
        for col in feature_cols:
            median_val = window_data[col].median()
            if not np.isnan(median_val):
                window_data[col] = window_data[col].fillna(median_val)
    else:
        raise ValueError(
            f"Unknown fill_method {fill_method!r}; expected one of "
            "'nan', 'inf', 'zero', 'forward_fill', 'backward_fill', 'mean', 'median'."
        )

    return window_data

def create_sliding_windows(df, window_size, min_win_valid=1, fill_method='nan', id_col='ID', time_col='TimeUnit', patient_info_cols=None, event_date_cols=None,
                           preexpanded=True, features_override=None, agg_select=('first',),
                           df_demo=None, df_labels=None, earliest_date_col='earliest_date', time_horizon_months=6, return_extra=True,
                           balance_windows=False, balance_range=(0.8, 1.2), min_per_class=1, balance_random_state=42):
    """
    Cut per-patient time series into fixed-length sliding windows.

    For every observed row of a patient (time `t`), a window covering the time
    units `t - window_size + 1 .. t` is built. Windows with fewer than
    `min_win_valid` observed rows are skipped. Time units without an observed
    row are inserted as synthetic rows: features are `-inf` when the unit lies
    outside the observed range of the window ("padding") and NaN when it lies
    between observed rows ("gap"). NaN values are then filled with
    `apply_fill_method(..., fill_method, ...)`.

    Parameters
    ----------
    df : pd.DataFrame
        Long table with `id_col`, `time_col` (integer-valued time units),
        optionally 'MonthPeriod', and the feature columns.
    window_size : int
        Number of time units per window.
    min_win_valid : int
        Minimum number of observed rows a window must contain.
    fill_method : str
        Passed to `apply_fill_method`.
    id_col, time_col : str
        Names of the patient-id and time-unit columns.
    patient_info_cols : list of str or None
        Columns of `df_demo` to merge onto the window metadata (None = none).
    event_date_cols : list of str or None
        Event-date columns of `df_labels` that are turned into window labels.
    preexpanded : bool
        When False, `expand_zipped_columns` is called on `df` first.
    features_override : list of str or None
        Explicit feature list (only used when `preexpanded` is True);
        defaults to all columns except `id_col`, `time_col` and 'MonthPeriod'.
    agg_select :
        Forwarded to `expand_zipped_columns` when `preexpanded` is False.
    df_demo, df_labels : pd.DataFrame or None
        Patient-level tables keyed by `id_col`. Window metadata (and therefore
        `win_demo` / `win_labels`) is only produced when at least one is given.
    earliest_date_col : str
        Currently unused; kept for interface compatibility.
    time_horizon_months : int
        A window gets label 1 for an event column when the event date is not
        later than the start of the window's last observed month plus this
        many months (events before the window also count).
    return_extra : bool
        Return `(result_df, win_demo, win_labels)` instead of `result_df` only.
    balance_windows : bool
        Undersample windows per class of the last 'y_*' label column.
    balance_range, min_per_class, balance_random_state :
        Each class keeps about `min_count * U(lo, hi)` windows (at least
        `min_per_class`, never more than the class size), drawn with a
        generator seeded by `balance_random_state`.

    Returns
    -------
    pd.DataFrame or (pd.DataFrame, pd.DataFrame or None, pd.DataFrame or None)
        `result_df` stacks all windows (column 'id_window' identifies them).
        `win_demo` holds 'id_window', `id_col` and, when available, 'gender'
        and 'age_at_window'. `win_labels` holds 'id_window', `id_col` and one
        'y_<event_col>' column per available event column.
    """
    if not preexpanded:
        raw_feature_cols = [c for c in df.columns if c not in [id_col, time_col, 'MonthPeriod']]
        # NOTE(review): expand_zipped_columns is not defined in the visible part
        # of this notebook - it must be provided elsewhere for this branch to run.
        df_expanded, feature_cols = expand_zipped_columns(df, raw_feature_cols, agg_select, drop_original=True)
    else:
        df_expanded = df
        if features_override is not None:
            feature_cols = list(features_override)
        else:
            feature_cols = [c for c in df.columns if c not in [id_col, time_col, 'MonthPeriod']]

    print(f"特征列: {feature_cols}")
    print(f"窗口大小: {window_size}")
    print(f"填充方法: {fill_method}")

    windowed_data = []
    win_meta = []  # {id_window, ID, last_valid_mp(Period), first_mp(Period), window_end_time}
    window_id = 0

    for patient_id in df_expanded[id_col].unique():
        patient_data = df_expanded[df_expanded[id_col] == patient_id].sort_values(time_col).reset_index(drop=True)

        # Month period of every row (all-NaT when the column is absent or unparsable).
        if 'MonthPeriod' in patient_data.columns:
            mp_series = patient_data['MonthPeriod']
            if not isinstance(mp_series.dtype, pd.PeriodDtype):
                try:
                    mp_period = pd.PeriodIndex(pd.to_datetime(mp_series, errors='coerce'), freq='M')
                except Exception:
                    mp_period = pd.PeriodIndex([pd.NaT]*len(patient_data), freq='M')
            else:
                mp_period = mp_series
        else:
            mp_period = pd.PeriodIndex([pd.NaT]*len(patient_data), freq='M')

        first_mp = mp_period[~mp_period.isna()].min() if (~mp_period.isna()).any() else pd.NaT

        patient_time_max = int(patient_data[time_col].max())

        for _, current_row in patient_data.iterrows():
            # iterrows() may upcast the row to float; range() below needs a real int.
            current_time = int(current_row[time_col])

            start_time = current_time - window_size + 1
            end_time = current_time

            window_mask = (patient_data[time_col] >= start_time) & (patient_data[time_col] <= end_time)
            window_data = patient_data[window_mask].copy()

            if len(window_data) < min_win_valid:
                continue

            # Last observed month inside the window; it anchors ages and labels.
            current_window_mps = mp_period[window_mask]
            if (~current_window_mps.isna()).any():
                last_valid_mp = current_window_mps[~current_window_mps.isna()].max()
            else:
                last_valid_mp = pd.NaT

            if len(window_data) < window_size:
                missing_rows = window_size - len(window_data)

                valid_times = window_data[time_col].tolist()

                if valid_times:
                    win_valid_min = min(valid_times)
                    win_valid_max = max(valid_times)
                else:
                    win_valid_min = float('inf')
                    win_valid_max = float('-inf')

                existing_times = set(valid_times)
                missing_times = [t for t in range(start_time, end_time + 1) if t not in existing_times]
                missing_times = missing_times[:missing_rows]

                empty_rows = []
                for missing_time in missing_times:
                    empty_row = {col: np.nan for col in patient_data.columns}
                    empty_row[id_col] = patient_id
                    empty_row[time_col] = missing_time

                    # Outside the observed range -> padding (-inf); inside -> gap (NaN).
                    is_padding = (missing_time < win_valid_min) or (missing_time > win_valid_max)

                    if is_padding:
                        for c in feature_cols:
                            empty_row[c] = -np.inf

                    empty_rows.append(empty_row)

                if empty_rows:
                    empty_df = pd.DataFrame(empty_rows)
                    empty_df = empty_df.astype(window_data.dtypes)
                    for c in feature_cols:
                        empty_df[c] = empty_df[c].astype(float)

                    window_data = pd.concat([empty_df, window_data], ignore_index=True)
                    window_data = window_data.sort_values(time_col).reset_index(drop=True)

            window_data[feature_cols] = window_data[feature_cols].astype(float)
            window_data = apply_fill_method(window_data, fill_method, feature_cols)

            window_data['id_window'] = window_id

            windowed_data.append(window_data)

            if (df_demo is not None) or (df_labels is not None):
                win_meta.append({
                    'id_window': window_id,
                    id_col: patient_id,
                    'window_end_time': end_time,
                    'last_valid_mp': last_valid_mp,
                    'first_mp': first_mp
                })

            window_id += 1

            # Stop at the patient's last time unit (uses time_col, not a hard-coded name).
            if end_time >= patient_time_max:
                break

    if windowed_data:
        result_df = pd.concat(windowed_data, ignore_index=True)
    else:
        # No window qualified: return an empty frame with the expected columns.
        result_df = df_expanded.iloc[0:0].copy()
        result_df['id_window'] = pd.Series(dtype='int64')

    print(f"\n生成了 {window_id} 个窗口")
    print(f"结果数据形状: {result_df.shape}")

    win_demo = None
    win_labels = None
    if len(win_meta) > 0:
        meta_df = pd.DataFrame(win_meta)

        if df_demo is not None:
            demo_cols = [id_col] + [c for c in (patient_info_cols or []) if c in df_demo.columns]
            tmp = meta_df.merge(df_demo[demo_cols], on=id_col, how='left')

            last_ts = tmp['last_valid_mp'].apply(
                lambda p: (p.to_timestamp(how='start') if not pd.isna(p) else pd.NaT)
            )

            if 'birth_year' in tmp.columns:
                by = pd.to_numeric(tmp['birth_year'], errors='coerce')
            else:
                by = pd.Series(np.nan, index=tmp.index)

            # Age at the window's last observed month; NaN when either input is missing.
            tmp['age_at_window'] = np.where(
                by.notna() & last_ts.notna(),
                last_ts.dt.year - by,
                np.nan
            )

            cols_out = ['id_window', id_col]
            if 'gender' in tmp.columns: cols_out.append('gender')
            if 'age_at_window' in tmp.columns: cols_out.append('age_at_window')
            win_demo = tmp[cols_out].copy()

        if df_labels is not None and event_date_cols is not None:
            avaliable_event_cols = [c for c in event_date_cols if c in df_labels.columns]
            lab_cols = [id_col] + avaliable_event_cols
            lab = meta_df.merge(df_labels[lab_cols], on=id_col, how='left')

            last_ts = lab['last_valid_mp'].apply(
                lambda p: (p.to_timestamp(how='start') if not pd.isna(p) else pd.NaT)
            )

            horizon_dt = last_ts + pd.offsets.DateOffset(months=int(time_horizon_months))

            # y = 1 when the event date is known and not later than the horizon.
            y_cols = {}
            for event_col in avaliable_event_cols:
                ed = pd.to_datetime(lab[event_col], errors='coerce')
                y = np.where(ed.notna() & horizon_dt.notna(), (ed <= horizon_dt).astype(int), 0)
                y_cols[f'y_{event_col}'] = y

            win_labels = pd.DataFrame({
                    'id_window': lab['id_window'].values,
                    id_col: lab[id_col].values,
                    **y_cols
                })

    if balance_windows:
        # Check for None before touching .columns so missing labels skip cleanly.
        y_columns = [] if win_labels is None else [c for c in win_labels.columns if c.startswith('y_')]
        if len(y_columns) == 0:
            print("[balance] 未找到窗口级标签（win_labels['y_*']），跳过均衡。")
        else:
            rng = np.random.default_rng(balance_random_state)
            # Balancing is driven by the last label column.
            balanced_col = y_columns[-1]
            y_series = win_labels.set_index('id_window')[balanced_col].astype(int)

            counts = y_series.value_counts().sort_index()

            if len(counts) >= 1:
                min_count = counts.min()
                lo, hi = balance_range

                chosen_win_ids = []
                for cls_current, cls_count in counts.items():
                    wid_cls = y_series.index[y_series.values == cls_current].to_numpy()
                    base_count = int(np.floor(min_count * rng.uniform(lo, hi)))
                    # At least min_per_class, but never more windows than the class has.
                    base_count = min(int(cls_count), max(min_per_class, base_count))
                    if base_count > 0:
                        chosen = rng.choice(wid_cls, size=base_count, replace=False)
                        chosen_win_ids.append(chosen)

                if len(chosen_win_ids) > 0:
                    keep_ids = np.concatenate(chosen_win_ids)
                else:
                    keep_ids = np.array([], dtype=result_df['id_window'].dtype)

                # Apply the same window selection to every output table.
                result_df = result_df[result_df['id_window'].isin(keep_ids)].copy()
                if win_demo is not None:
                    win_demo = win_demo[win_demo['id_window'].isin(keep_ids)].copy()
                win_labels = win_labels[win_labels['id_window'].isin(keep_ids)].copy()

    if return_extra:
        return result_df, win_demo, win_labels

    return result_df

In [None]:
# --- windowing configuration ---
window_size = 6                       # time units (months) per window
min_win_valid = int(window_size*0.5)  # a window needs at least this many observed rows
fill_method = 'inf'  # 'nan', 'inf', 'zero', 'forward_fill', 'backward_fill', 'mean', 'median'
agg_select = ('first',)  # ('last',), ('first','last'), ('first','mean','max'), or {'Albumin':['last'], 'Platelets':['mean','max']}
time_horizon_months = 6  # label horizon passed to create_sliding_windows

id_col = 'subject_id'
time_col = 'TimeUnit'
EVENT_DATE_COLS = event_cols
PATIENT_INFO_COLS = [ "birth_year", "age_at_index", "gender_num" ] 

# Patient-level demographics; create_sliding_windows looks for a column named "gender".
df_demo = cohort[["subject_id"] + PATIENT_INFO_COLS].copy()
df_demo = df_demo.rename(columns={"gender_num": "gender"})
PATIENT_INFO_COLS = [c for c in df_demo.columns if c not in ['subject_id']]

print(f"\n=== 创建滑动窗口 (window_size={window_size}, fill_method={fill_method}) ===")

# Windows are class-balanced for training only; the test windows are left untouched.
balance_windows_train = True 
balance_windows_test = False

df_train_win, demo_train_win, labels_train_win = create_sliding_windows(df_train_ts, window_size=window_size, min_win_valid=min_win_valid, fill_method=fill_method,
                                                                        id_col=id_col, time_col=time_col, 
                                                                        patient_info_cols=PATIENT_INFO_COLS, event_date_cols=EVENT_DATE_COLS,
                                                                        preexpanded=True, features_override=feature_cols, agg_select=agg_select,
                                                                        df_demo=df_demo, df_labels=df_labels, time_horizon_months=time_horizon_months,
                                                                        balance_windows=balance_windows_train)
df_test_win, demo_test_win, labels_test_win = create_sliding_windows(df_test_ts,  window_size=window_size, min_win_valid=min_win_valid, fill_method=fill_method,
                                                                        id_col=id_col, time_col=time_col, 
                                                                        patient_info_cols=PATIENT_INFO_COLS, event_date_cols=EVENT_DATE_COLS,
                                                                        preexpanded=True, features_override=feature_cols, agg_select=agg_select,
                                                                        df_demo=df_demo, df_labels=df_labels, time_horizon_months=time_horizon_months,
                                                                        balance_windows=balance_windows_test)


=== 创建滑动窗口 (window_size=6, fill_method=inf) ===
特征列: ['Albumin', 'Basophils', 'Eosinophils', 'Lymphocytes', 'Neutrophils', 'Platelets', 'Total protein', 'White cells count']
窗口大小: 6
填充方法: inf

生成了 7904 个窗口
结果数据形状: (47424, 12)
特征列: ['Albumin', 'Basophils', 'Eosinophils', 'Lymphocytes', 'Neutrophils', 'Platelets', 'Total protein', 'White cells count']
窗口大小: 6
填充方法: inf

生成了 1807 个窗口
结果数据形状: (10842, 12)


In [18]:
# One row per training window that survived balancing, with the patient it belongs to.
df_train_win[['id_window', 'subject_id']].drop_duplicates()

Unnamed: 0,id_window,subject_id
234,39,10025862
312,52,10025862
378,63,10098875
444,74,10132759
450,75,10132759
...,...,...
46998,7833,19936204
47052,7842,19936204
47106,7851,19936204
47292,7882,19940147


### **Pre-Analysis**

In [19]:
# 1) sort window rows
df_train_win = df_train_win.sort_values(['id_window', 'TimeUnit']).reset_index(drop=True)
df_test_win  = df_test_win.sort_values(['id_window', 'TimeUnit']).reset_index(drop=True)

# 2) sort labels/demo by id_window too
labels_train_win = labels_train_win.sort_values('id_window').reset_index(drop=True)
labels_test_win  = labels_test_win.sort_values('id_window').reset_index(drop=True)

demo_train_win = demo_train_win.sort_values('id_window').reset_index(drop=True)
demo_test_win  = demo_test_win.sort_values('id_window').reset_index(drop=True)

# 3) ensure same id_window order between windowed rows and labels/demo
win_ids_train = df_train_win['id_window'].drop_duplicates().to_numpy()
lab_ids_train = labels_train_win['id_window'].to_numpy()
dem_ids_train = demo_train_win['id_window'].to_numpy()

assert (win_ids_train == lab_ids_train).all(), "Train: id_window order mismatch between df_train_win and labels_train_win"
assert (win_ids_train == dem_ids_train).all(), "Train: id_window order mismatch between df_train_win and demo_train_win"


In [20]:
# Number of rows per window; every window must hold exactly window_size rows
# for the positional reshape below to be valid.
counts_train = df_train_win.groupby('id_window').size()
counts_test  = df_test_win.groupby('id_window').size()

assert counts_train.nunique() == 1 and counts_train.iloc[0] == window_size, \
    f"Train windows not all size={window_size}: {counts_train.value_counts().head()}"

assert counts_test.nunique() == 1 and counts_test.iloc[0] == window_size, \
    f"Test windows not all size={window_size}: {counts_test.value_counts().head()}"


# Stack the feature rows into (windows, time steps, features).
# NOTE(review): this relies on the rows being sorted by (id_window, TimeUnit) - confirm
# the sorting cell ran before this one.
N = counts_train.shape[0]
F = len(feature_cols)
input_train = df_train_win[feature_cols].to_numpy(dtype=float).reshape(N, window_size, F)  # (N,T,F)
print(f"训练集窗口数: {N}, 特征数: {F}, 输入形状: {input_train.shape}")

# N is reused for the test split from here on.
N = counts_test.shape[0]
input_test = df_test_win[feature_cols].to_numpy(dtype=float).reshape(N, window_size, F)  # (N,T,F)
print(f"测试集窗口数: {N}, 特征数: {F}, 输入形状: {input_test.shape}")

# Boolean masks of NaN entries.
# NOTE(review): with fill_method='inf' NaNs were replaced by +inf during windowing,
# so these masks are presumably all False - confirm this is what downstream code expects.
mask_train = np.isnan(input_train)
mask_test = np.isnan(input_test)

# One target column per 'y_*' window label.
target_cols = [c for c in labels_train_win.columns if c.startswith('y_')]
target_train = labels_train_win[target_cols].to_numpy(dtype=int)  # (N, num_targets)
target_test  = labels_test_win[target_cols].to_numpy(dtype=int)    # (N, num_targets)
print(f"训练集标签形状: {target_train.shape}, 测试集标签形状: {target_test.shape}")
# Positive rate of every target in both splits.
for i, col in enumerate(target_cols):
    print(f"  目标列 {col} 的正类比例: 训练集 {target_train[:,i].mean():.4f}, 测试集 {target_test[:,i].mean():.4f}")

# Window-level demographics as an integer array.
# NOTE(review): the int cast truncates fractional ages and would fail on NaN
# (e.g. a patient without birth_year or gender) - verify neither occurs.
demo_cols = ['gender', 'age_at_window']
demo_train = demo_train_win[demo_cols].to_numpy(dtype=int)  # (N, 2)
demo_test  = demo_test_win[demo_cols].to_numpy(dtype=int)    # (N, 2)
print(f"训练集人口统计信息形状: {demo_train.shape}, 测试集人口统计信息形状: {demo_test.shape}")



训练集窗口数: 831, 特征数: 8, 输入形状: (831, 6, 8)
测试集窗口数: 1807, 特征数: 8, 输入形状: (1807, 6, 8)
训练集标签形状: (831, 5), 测试集标签形状: (1807, 5)
  目标列 y_t_hcc_first 的正类比例: 训练集 0.0241, 测试集 0.0177
  目标列 y_t_cirr_first 的正类比例: 训练集 0.2238, 测试集 0.0133
  目标列 y_t_fib_first 的正类比例: 训练集 0.0048, 测试集 0.0039
  目标列 y_t_lf_first 的正类比例: 训练集 0.2960, 测试集 0.0343
  目标列 y_t_event_min 的正类比例: 训练集 0.4741, 测试集 0.0526
训练集人口统计信息形状: (831, 2), 测试集人口统计信息形状: (1807, 2)


In [None]:
import numpy as np
import pandas as pd
from typing import Optional, Tuple, Dict, Any

def analyze_and_select_valid_distribution(
    input_arr: np.ndarray,
    target_arr: Optional[np.ndarray],
    labels_win: pd.DataFrame,
    df_win: pd.DataFrame,
    demo_arr: Optional[np.ndarray] = None,
    *,
    desired: Optional[str] = None,
    main_target_name: Optional[str] = None,
    ratio_tol: float = 0.05,
    random_state: int = 42,
    front_thr: float = 0.4,
    back_thr: float = 0.6,
    verbose: bool = True
) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[np.ndarray], pd.DataFrame, pd.DataFrame, pd.DataFrame, Dict[str, Any], Optional[np.ndarray]]:
    """
    Classify each window by where its observed (non-NaN) values sit in time and,
    optionally, keep only the windows of one class.

    Parameters
    ----------
    input_arr : (N, T, F) float array with missing values encoded as NaN.
    target_arr : optional array whose first axis matches ``input_arr``.
    labels_win : label table with an ``id_window`` column, one row per window.
    df_win : long table with an ``id_window`` column; the order of first
        appearance of each id defines the row order of ``input_arr``.
    demo_arr : optional array whose first axis matches ``input_arr``.
    desired : ``None`` (analysis only) or one of ``'front'``, ``'back'``,
        ``'average'`` (the class of windows to keep).
    main_target_name : label column used to preserve the positive ratio while
        selecting. Defaults to the last column of ``labels_win`` other than
        ``id_window`` / ``dist_class``.
    ratio_tol, random_state : forwarded to ``_stratified_keep_indices``.
    front_thr, back_thr : thresholds on the normalised centre of mass,
        forwarded to ``_classify_by_com``.
    verbose : print class counts and selection summary.

    Returns
    -------
    tuple
        ``(input, target, demo, df_win, labels_win, labels_win_aug, stats, keep_idx)``.
        With ``desired=None`` the first five entries are the inputs unchanged and
        ``keep_idx`` is ``None``; otherwise they are restricted to the kept
        windows. ``labels_win_aug`` is the aligned label table plus ``dist_class``.

    Raises
    ------
    ValueError
        If the labels do not hold exactly one row per window, or if no window
        of the requested class exists.
    """
    assert input_arr.ndim == 3, "input_arr 应为 (N, T, F) 且缺失为 np.nan"
    N = input_arr.shape[0]

    win_ids_order = df_win["id_window"].drop_duplicates().to_numpy()
    assert len(win_ids_order) == N, "unique id_window 数量必须等于 N (= input_arr.shape[0])"

    valid_mask = ~np.isnan(input_arr)           # (N, T, F)
    valid_counts_t = valid_mask.sum(axis=2)     # (N, T)

    com_norm = _valid_center_of_mass(valid_counts_t)                     # (N,)
    dist_class = _classify_by_com(com_norm, front_thr=front_thr, back_thr=back_thr)
    avg_valid_ratio_t = valid_mask.mean(axis=(0, 2))                     # (T,)

    labels_aligned = _align_labels_by_id_window(labels_win, win_ids_order)
    # Everything below indexes labels positionally, so row i of the labels must
    # describe row i of input_arr. Fail loudly instead of silently misaligning.
    if not np.array_equal(labels_aligned["id_window"].to_numpy(), win_ids_order):
        raise ValueError(
            "labels_win must contain exactly one row for every id_window present in df_win"
        )
    labels_win_aug = labels_aligned.copy()
    labels_win_aug["dist_class"] = dist_class

    uniq, cnts = np.unique(dist_class, return_counts=True)
    counts = dict(zip(uniq.tolist(), cnts.tolist()))
    if verbose:
        print("[Distribution] counts:", counts)

    stats = {
        "com_norm": com_norm,
        "dist_class": dist_class,
        "avg_valid_ratio_t": avg_valid_ratio_t,
        "counts": counts,
    }

    if desired is None:
        return (input_arr, target_arr, demo_arr, df_win, labels_win, labels_win_aug, stats, None)

    assert desired in {"front", "back", "average"}, "desired 必须是 'front'/'back'/'average' 或 None"

    if main_target_name is None:
        # 'dist_class' was appended above and holds strings; it must never be
        # picked as the default target (it used to be, as the last column).
        cand_cols = [c for c in labels_win_aug.columns if c not in ("id_window", "dist_class")]
        assert len(cand_cols) > 0, "labels_win 至少应包含一个除 'id_window' 外的目标列"
        main_target_name = cand_cols[-1]
    assert main_target_name in labels_win_aug.columns, f"{main_target_name} 不在 labels_win 中"

    y_main = labels_win_aug[main_target_name].to_numpy()
    if y_main.dtype.kind not in "iu":
        y_main = y_main.astype(int)

    desired_mask = (dist_class == desired)
    keep_idx = _stratified_keep_indices(y_main, desired_mask, ratio_tol=ratio_tol, random_state=random_state)

    if keep_idx.size == 0:
        raise ValueError(f"没有符合分布 '{desired}' 的窗口。")

    kept_win_ids = win_ids_order[keep_idx]

    input_new  = input_arr[keep_idx]
    target_new = target_arr[keep_idx] if target_arr is not None else None
    demo_new   = demo_arr[keep_idx]   if demo_arr   is not None else None

    df_win_new = df_win[df_win["id_window"].isin(kept_win_ids)].copy()
    # The aligned tables are in win_ids_order and keep_idx is sorted, so a
    # positional take yields the kept labels in window order.
    labels_win_new  = labels_aligned.iloc[keep_idx].copy()
    labels_win_aug2 = labels_win_aug.iloc[keep_idx].copy()

    if verbose:
        p_orig = y_main.mean()
        p_new  = y_main[keep_idx].mean()
        print(f"[Selection] desired='{desired}', kept {keep_idx.size}/{N} windows.")
        print(f"[Selection] pos ratio: original={p_orig:.4f}, new={p_new:.4f}")

    return (input_new, target_new, demo_new, df_win_new, labels_win_new, labels_win_aug2, stats, keep_idx)


def _valid_center_of_mass(valid_counts_t: np.ndarray) -> np.ndarray:
    N, T = valid_counts_t.shape
    t_idx = np.arange(T, dtype=float)[None, :]                 # (1, T)
    totals = valid_counts_t.sum(axis=1)                        # (N,)
    with np.errstate(invalid="ignore", divide="ignore"):
        com = (valid_counts_t * t_idx).sum(axis=1) / totals    # (N,)
    if T > 1:
        com_norm = com / (T - 1.0)
    else:
        com_norm = np.where(totals > 0, 0.5, np.nan)           # 单步时间：非空置 0.5
    com_norm[totals == 0] = np.nan
    return com_norm

def _classify_by_com(com_norm: np.ndarray, *, front_thr: float, back_thr: float) -> np.ndarray:
    cls = np.full(com_norm.shape, "average", dtype=object)
    finite = ~np.isnan(com_norm)
    cls[ finite & (com_norm < front_thr)] = "front"
    cls[ finite & (com_norm > back_thr) ] = "back"
    cls[~finite] = "empty"
    return cls

def _stratified_keep_indices(
    labels: np.ndarray, mask_candidate: np.ndarray, *, ratio_tol: float, random_state: int
) -> np.ndarray:

    rng = np.random.RandomState(random_state)
    idx_all = np.arange(labels.shape[0])
    idx_desired = idx_all[mask_candidate]
    if idx_desired.size == 0:
        return idx_desired

    p_orig = labels.mean()
    y_cand = labels[idx_desired]
    p_cand = y_cand.mean()

    if np.abs(p_cand - p_orig) <= ratio_tol:
        return idx_desired

    cls0_idx = idx_desired[y_cand == 0]
    cls1_idx = idx_desired[y_cand == 1]

    K = idx_desired.size
    n1_target = int(round(p_orig * K))
    n0_target = K - n1_target

    n0 = min(n0_target, cls0_idx.size)
    n1 = min(n1_target, cls1_idx.size)
    if n0 == 0 and cls0_idx.size > 0: n0 = 1
    if n1 == 0 and cls1_idx.size > 0: n1 = 1

    keep0 = rng.choice(cls0_idx, size=n0, replace=False) if n0 > 0 else np.array([], dtype=int)
    keep1 = rng.choice(cls1_idx, size=n1, replace=False) if n1 > 0 else np.array([], dtype=int)
    keep = np.sort(np.concatenate([keep0, keep1], axis=0))
    return keep

def _align_labels_by_id_window(df_labels: pd.DataFrame, id_order: np.ndarray) -> pd.DataFrame:
    out = df_labels.copy()
    assert "id_window" in out.columns, "labels/df_win 必须包含 'id_window'"
    mapper = pd.Categorical(out["id_window"], categories=id_order, ordered=True)
    out = out.loc[~mapper.isna()].copy()
    out["__ord__"] = mapper
    out = out.sort_values("__ord__").drop(columns="__ord__")
    return out


In [None]:
# Analysis-only run on the training windows: with desired=None the function
# returns its input arrays/tables unchanged plus the augmented labels and stats,
# and keep_idx is None.
(input_train2, target_train2, demo_train2,
 df_train_win2, labels_train_win2, labels_train_win_aug,
 stats, keep_idx) = analyze_and_select_valid_distribution(
    input_arr=input_train,
    target_arr=target_train,
    labels_win=labels_train_win,
    df_win=df_train_win,
    demo_arr=demo_train,
    desired=None,                  
    main_target_name=None,         
    ratio_tol=0.05, random_state=42,
    front_thr=0.4, back_thr=0.6,
    verbose=True
)

# Same analysis on the test windows.
# NOTE(review): `stats` and `keep_idx` are rebound here, so the training-set values are lost.
# NOTE(review): the thresholds (0.2 / 0.8) differ from the training call (0.4 / 0.6) — confirm this is intended.
(input_test2, target_test2, demo_test2,
 df_test_win2, labels_test_win2, labels_test_win_aug,
 stats, keep_idx) = analyze_and_select_valid_distribution(
    input_arr=input_test,
    target_arr=target_test,
    labels_win=labels_test_win,
    df_win=df_test_win,
    demo_arr=demo_test,
    desired=None,                  
    main_target_name=None,         
    ratio_tol=0.05, random_state=42,
    front_thr=0.2, back_thr=0.8,
    verbose=True
)

[Distribution] counts: {'average': 831}
[Distribution] counts: {'average': 1807}


In [23]:
# Rebuild the (N, T, F) tensors from the long window tables (same construction as input_train / input_test).
N = counts_train.shape[0]
F = len(feature_cols)
input_train_3d = df_train_win[feature_cols].to_numpy(dtype=float).reshape(N, window_size, F)  # (N,T,F)

N = counts_test.shape[0]
input_test_3d = df_test_win[feature_cols].to_numpy(dtype=float).reshape(N, window_size, F)  # (N,T,F)

# True where a value is finite (observed) — the opposite polarity of mask_train / mask_test, which use np.isnan.
mask_train_3d = np.isfinite(input_train_3d)
mask_test_3d = np.isfinite(input_test_3d)

# Disabled option 1: repeat the demographics at every time step and append them as extra features.
# ## adding demo version 1
# ## -----------------------------
# input_train_3d = np.concatenate([input_train_3d, np.repeat(demo_train[:, None, :], window_size, axis=1)], axis=2)
# input_test_3d = np.concatenate([input_test_3d, np.repeat(demo_test[:, None, :], window_size, axis=1)], axis=2)


# Flatten each window to one row of T*F values; the time step is the slower-varying axis.
input_train_flat = input_train_3d.reshape(input_train_3d.shape[0],-1)
input_test_flat = input_test_3d.reshape(input_test_3d.shape[0],-1)
mask_train_flat = mask_train_3d.reshape(mask_train_3d.shape[0],-1)
mask_test_flat = mask_test_3d.reshape(mask_test_3d.shape[0],-1)


# Disabled option 2: append the demographics once, after the flattened window.
# ## adding demo version 2
# ## ------------------------------
# input_train_flat = np.concatenate([input_train_flat, demo_train], axis=1)
# input_test_flat = np.concatenate([input_test_flat, demo_test], axis=1)


# Use only the last target column as the prediction target, as an (N, 1) column vector.
# NOTE(review): per the printed target list this is presumably y_t_event_min — confirm.
target_train_final = target_train[:,-1].reshape(-1,1)
target_test_final = target_test[:,-1].reshape(-1,1)

In [25]:
# Stack train rows followed by test rows into single arrays.
# NOTE(review): `input` shadows the Python builtin of the same name for the rest of the session.
input = np.vstack((input_train_flat, input_test_flat))
target = np.vstack((target_train_final, target_test_final))
mask = np.vstack((mask_train_flat, mask_test_flat))

# Number of features per time step (last axis of the 3-D tensor).
num_variables = input_train.shape[-1]

# input_fill = np.where(mask, input, 0.0)

# 1-indexed position of the first test row; train() / eval_system() subtract 1 to get the slice start.
start_test = len(input_train_flat) + 1

print(f'input shape: {input.shape}, target shape: {target.shape}, mask shape: {mask.shape}, num_variables: {num_variables}, start_test: {start_test}')

input shape: (2638, 48), target shape: (2638, 1), mask shape: (2638, 48), num_variables: 8, start_test: 832


### **MEF-Classifier**

In [26]:
import numpy as np
from enfis_functions_irregular_classifier import mar_trainOnline

def train(data_input, data_target, num_vars, max_cluster, half_life, threshold_mf, min_rule_weight, start_test, end_test=None, start_external=None,
          ablation=False, interpret=False, evo=False, mode='base', ds_W=2.0, base_lambda=1.0, tau_match=0.7, score_q=0.0, rule_partial_match=False):
    """
    Run ``mar_trainOnline`` over all windows and attach convenience slices to its result.

    ``start_test`` and ``start_external`` are 1-indexed row positions; they are
    passed to ``mar_trainOnline`` unchanged and converted to 0-based slice
    starts here. ``end_test`` is only used for slicing: the test slice is
    ``[start_test - 1, end_test - 1)``, or runs to the end of the predictions
    when ``end_test`` is None. The external slice runs from
    ``start_external - 1`` to the end.

    Returns the ``system`` dict with these keys added: ``num_rules``,
    ``test_start_idx``, ``test_end_idx``, ``probs_test``, ``trues_test``,
    ``uncertainty_test`` (only if ``system['net']`` has ``'uncertainty'``) and,
    when ``start_external`` is given, ``external_start_idx``,
    ``probs_external``, ``trues_external`` and ``uncertainty_external``.
    """
    system = mar_trainOnline(
        data_input, data_target, num_vars, max_cluster, half_life, threshold_mf, min_rule_weight,
        ablation, evo, interpret, mode, ds_W, base_lambda, tau_match, score_q, rule_partial_match,
        start_test, start_external,
    )
    print('training & testing finished')

    net = system['net']
    system['num_rules'] = float(np.mean(net['ruleCount']))
    print('num_rule = ', system['num_rules'])

    # Test slice: 1-indexed bounds -> 0-based half-open range.
    test_lo = int(start_test) - 1
    test_hi = system['predicted'].shape[0] if end_test is None else int(end_test) - 1
    system['test_start_idx'] = test_lo
    system['test_end_idx'] = test_hi

    has_uncertainty = 'uncertainty' in net
    test_rows = slice(test_lo, test_hi)
    system['probs_test'] = system['predicted'][test_rows].copy()
    system['trues_test'] = data_target[test_rows].copy()
    if has_uncertainty:
        system['uncertainty_test'] = net['uncertainty'][test_rows].copy()

    if start_external is None:
        return system

    # External slice: everything from start_external to the end.
    ext_lo = int(start_external) - 1
    ext_rows = slice(ext_lo, None)
    system['external_start_idx'] = ext_lo
    system['probs_external'] = system['predicted'][ext_rows].copy()
    system['trues_external'] = data_target[ext_rows].copy()
    if has_uncertainty:
        system['uncertainty_external'] = net['uncertainty'][ext_rows].copy()

    return system


### **Evaluation Metrics**

In [None]:
import numpy as np
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    classification_report, roc_auc_score, average_precision_score
)

def eval_system(system, y_true, patient_all=None, start_test=1, end_test=None, start_external=None, thr=0.5, mode='ds', risk_k_grid=None,
                recall_targets=(0.8,), u_taus=(0.9, 0.8, 0.75, 0.7)):
    """
    Print and return window-level metrics for the test slice, plus optional
    external-slice and patient-level metrics.

    Parameters
    ----------
    system : dict
        Output of train() / mar_trainOnline wrapper; must contain
        system['predicted']. In 'ds' mode system['net']['uncertainty'] and
        system['net']['base_rate'] are used when present.
    y_true : array-like
        Full target array, shape (N, 1) or (N,).
    patient_all : sequence, optional
        One patient id per row of system['predicted']; enables patient-level
        aggregation (max pooling and last window).
    start_test : int
        1-indexed position of the first test row.
    end_test : int, optional
        1-indexed bound of the test slice; rows [start_test - 1, end_test - 1)
        are evaluated. Defaults to the end of the predictions.
    start_external : int, optional
        1-indexed position of the first external row; the external slice runs
        to the end of the predictions.
    thr : float
        Default decision threshold.
    mode : str
        'base' or 'ds'; 'ds' adds the risk-multiple scan and uncertainty reports.
    risk_k_grid : array-like, optional
        Multiples k of the base rate; each is evaluated at threshold k * base_rate.
    recall_targets : sequence of float
        For each target, report the threshold with the best precision among
        those reaching at least that recall.
    u_taus : sequence of float
        Uncertainty cut-offs for selective prediction (keep u <= tau).

    Returns
    -------
    dict
        Always: 'auroc', 'auprc', 'probs', 'trues', 'uncertainty', 'best_f1',
        'risk_scan'. With start_external: 'auroc_external', 'auprc_external'.
        With patient_all: 'auroc_pat_max', 'auprc_pat_max', 'auroc_pat_last',
        'auprc_pat_last', and with both: the four '*_pat_ext_*' keys.
        AUROC/AUPRC values are NaN when a slice does not contain both classes.
    """
    nan = float('nan')

    def _ranking_metrics(y, s):
        """Return (AUROC, AUPRC); (nan, nan) unless y holds exactly two classes."""
        if len(np.unique(y)) != 2:
            return nan, nan
        return roc_auc_score(y, s), average_precision_score(y, s)

    def _at_threshold(threshold):
        """Return (precision, recall, f1, n_predicted_positive) for probs >= threshold."""
        pred = (probs >= threshold).astype(int)
        return (
            float(precision_score(trues, pred, zero_division=0)),
            float(recall_score(trues, pred, zero_division=0)),
            float(f1_score(trues, pred, zero_division=0)),
            int(pred.sum()),
        )

    def _patient_level_metrics(df_inst):
        """Aggregate window scores per patient (max and last window) and print/return the AUCs."""
        by_patient = df_inst.groupby('patient_id')
        # True label per patient: max over their windows
        y_pat = by_patient['trues'].max()

        # ===== 1) Max-pooling =====
        s_max = by_patient['p_hat'].max()
        df_pat_max = pd.DataFrame({'y': y_pat, 's': s_max}).reset_index(drop=True)
        auroc_max, auprc_max = _ranking_metrics(
            np.asarray(df_pat_max['y']).astype(int), np.asarray(df_pat_max['s']).astype(float))

        # ===== 2) Last-window =====
        # "last" = last occurrence inside the slice, i.e. it relies on the
        # windows of a patient being stored in time order.
        idx_last = by_patient['inst_idx'].idxmax()
        s_last = df_inst.loc[idx_last, ['patient_id', 'p_hat']].set_index('patient_id')['p_hat']
        df_pat_last = pd.DataFrame({'y': y_pat, 's': s_last}).reset_index(drop=True)
        auroc_last, auprc_last = _ranking_metrics(
            np.asarray(df_pat_last['y']).astype(int), np.asarray(df_pat_last['s']).astype(float))

        print(f"Per-patient MAX pooling:  AUROC={auroc_max:.4f}, AUPRC={auprc_max:.4f} "
              f"(N_pat={df_pat_max.shape[0]}, pos_rate={df_pat_max['y'].mean():.3f})")
        print(f"Per-patient LAST window:  AUROC={auroc_last:.4f}, AUPRC={auprc_last:.4f} "
              f"(N_pat={df_pat_last.shape[0]}, pos_rate={df_pat_last['y'].mean():.3f})")

        return auroc_max, auprc_max, auroc_last, auprc_last

    start_idx = int(start_test) - 1
    if end_test is not None:
        end_idx = int(end_test) - 1
    else:
        end_idx = system['predicted'].shape[0]

    probs = system['predicted'][start_idx:end_idx].reshape(-1).astype(float)
    trues = np.asarray(y_true)[start_idx:end_idx].reshape(-1).astype(int)

    print(f"[Eval] test windows={len(trues)}, prevalence={trues.mean():.4f}")

    # --- ranking metrics ---
    # Stay NaN (instead of being left undefined) when only one class is present,
    # so the result dict can always be built.
    both_classes = len(np.unique(trues)) == 2
    auroc, auprc = _ranking_metrics(trues, probs)
    if both_classes:
        print(f"AUROC: {auroc:.4f}")
        print(f"AUPRC: {auprc:.4f}")
    else:
        print("AUROC/AUPRC skipped: only one class present in trues.")

    # --- threshold metrics at given thr ---
    preds = (probs >= thr).astype(int)
    acc = accuracy_score(trues, preds)
    prec = precision_score(trues, preds, zero_division=0)
    rec = recall_score(trues, preds, zero_division=0)
    f1 = f1_score(trues, preds, zero_division=0)

    print(f"\n[Threshold thr={thr:.3f}]")
    print(f"Accuracy : {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall   : {rec:.4f}")
    print(f"F1-Score : {f1:.4f}")
    print("Classification Report:\n", classification_report(trues, preds, digits=4, zero_division=0))

    # --- scan thresholds once; reused for best-F1 and every recall target ---
    ts = np.linspace(0.0, 1.0, 501)
    scan = [(float(t),) + _at_threshold(t) for t in ts]

    # best F1 (first threshold reaching the maximum)
    best = None
    for t, p_t, r_t, f_t, npos in scan:
        if best is None or f_t > best['f1']:
            best = {'t': t, 'precision': p_t, 'recall': r_t, 'f1': f_t, 'pos_pred': npos}
    print(f"\n[Best-F1 scan] {best}")

    # meet each recall target, then maximize precision
    for target_R in recall_targets:
        cands = [row for row in scan if row[2] >= target_R]
        if len(cands) == 0:
            print(f"[Recall>={target_R}] no threshold achieves target recall.")
        else:
            t, p_t, r_t, f_t, npos = max(cands, key=lambda x: x[1])
            print(f"[Recall>={target_R}] t={t:.3f}, precision={p_t:.3f}, recall={r_t:.3f}, f1={f_t:.3f}, pos_pred={npos}")

    risk_scan = None
    if mode == 'ds':
        a = float(system.get('net', {}).get('base_rate', trues.mean()))
        print(f"\n[Risk-multiple scan] base_rate(a)={a:.4f}")

        max_prob = float(np.max(probs)) if probs.size else 1.0
        max_thr = min(1.0, max_prob + 1e-6)

        if a <= 0:
            print("  (skip) base_rate<=0")
            risk_scan = {'a': a, 'k_grid': None, 'rows': [], 'best_f1': None, 'recall_best': {}}
        else:
            # largest useful multiple: beyond it the threshold exceeds every score
            k_max = max(max_thr / a, 0.0)

            if risk_k_grid is None:
                # 0.5x to k_max (capped at 10) with step 0.05
                k_cap = min(10.0, k_max)
                risk_k_grid = np.round(np.arange(0.5, k_cap + 1e-9, 0.05), 3)

            rows = []
            best_f1_k = None

            for k in risk_k_grid:
                thr_k = float(k) * a
                if thr_k < 0 or thr_k > 1:
                    continue
                p_k, r_k, f_k, npos = _at_threshold(thr_k)

                row = {'k': float(k), 'thr': thr_k, 'precision': p_k, 'recall': r_k, 'f1': f_k, 'pos_pred': npos}
                rows.append(row)

                if best_f1_k is None or row['f1'] > best_f1_k['f1']:
                    best_f1_k = row

            # print a compact summary instead of spamming all rows
            if best_f1_k is not None:
                print(f"  [Best-F1 by k] k={best_f1_k['k']:.3f} thr={best_f1_k['thr']:.3f} "
                      f"P={best_f1_k['precision']:.3f} R={best_f1_k['recall']:.3f} F1={best_f1_k['f1']:.3f} "
                      f"pos_pred={best_f1_k['pos_pred']}")

            # recall-constrained best by k (maximize precision under recall>=target)
            recall_best = {}
            for target_R in recall_targets:
                cand = [r for r in rows if r['recall'] >= float(target_R)]
                if len(cand) == 0:
                    recall_best[f"recall_ge_{target_R}"] = None
                    print(f"  [k for Recall>={target_R}] none")
                else:
                    bestP = max(cand, key=lambda r: r['precision'])
                    recall_best[f"recall_ge_{target_R}"] = bestP
                    print(f"  [k for Recall>={target_R}] k={bestP['k']:.3f} thr={bestP['thr']:.3f} "
                          f"P={bestP['precision']:.3f} R={bestP['recall']:.3f} F1={bestP['f1']:.3f} "
                          f"pos_pred={bestP['pos_pred']}")

            # np.asarray: a caller-supplied grid may be a plain list/tuple
            risk_scan = {'a': a, 'k_grid': np.asarray(risk_k_grid).tolist(), 'rows': rows,
                         'best_f1': best_f1_k, 'recall_best': recall_best}

    # --- uncertainty selective prediction (ds only) ---
    u = None
    if mode == 'ds' and ('net' in system) and ('uncertainty' in system['net']):
        u = system['net']['uncertainty'][start_idx:end_idx].reshape(-1).astype(float)
        print("\n[Uncertainty stats]")
        print(f"  mean={u.mean():.4f}, std={u.std():.4f}, min={u.min():.4f}, max={u.max():.4f}")

        if both_classes:
            print("\n[Selective prediction: keep u<=tau]")
            for tau in u_taus:
                keep = (u <= tau)
                if keep.sum() < 50 or len(np.unique(trues[keep])) < 2:
                    print(f"  tau={tau}: kept={keep.mean()*100:.1f}% (skip metrics: too few / one class)")
                    continue
                auroc_k = roc_auc_score(trues[keep], probs[keep])
                auprc_k = average_precision_score(trues[keep], probs[keep])
                print(f"  tau={tau}: kept={keep.mean()*100:.1f}%, AUROC={auroc_k:.3f}, AUPRC={auprc_k:.3f}")

    result = {
        'auroc': auroc, 'auprc': auprc,
        'probs': probs, 'trues': trues, 'uncertainty': u,
        'best_f1': best,
        'risk_scan': risk_scan,
    }

    # --- external slice: window-level ranking metrics ---
    start_ext_idx = None
    if start_external is not None:
        start_ext_idx = int(start_external) - 1
        probs_ext = system['predicted'][start_ext_idx:].reshape(-1).astype(float)
        trues_ext = np.asarray(y_true)[start_ext_idx:].reshape(-1).astype(int)

        print(f"\n[Eval External] windows={len(trues_ext)}, prevalence={trues_ext.mean():.4f}")

        auroc_ext, auprc_ext = _ranking_metrics(trues_ext, probs_ext)
        if len(np.unique(trues_ext)) == 2:
            print(f"External AUROC: {auroc_ext:.4f}")
            print(f"External AUPRC: {auprc_ext:.4f}")
        else:
            print("External AUROC/AUPRC skipped: only one class present in trues.")

        result['auroc_external'] = auroc_ext
        result['auprc_external'] = auprc_ext

    # --- patient-level aggregation ---
    if patient_all is not None:
        print("\n[Patient-level aggregation]")
        pred_all = system['predicted']
        assert len(pred_all) == len(patient_all), "patient_all length mismatch"

        df_inst_test = pd.DataFrame({
            'patient_id': patient_all[start_idx:end_idx],
            'p_hat': probs,
            'trues': trues,
            'inst_idx': np.arange(len(probs)),  # row order stands in for time
        })
        (result['auroc_pat_max'], result['auprc_pat_max'],
         result['auroc_pat_last'], result['auprc_pat_last']) = _patient_level_metrics(df_inst_test)

        if start_ext_idx is not None:
            df_inst_ext = pd.DataFrame({
                'patient_id': patient_all[start_ext_idx:],
                'p_hat': probs_ext,
                'trues': trues_ext,
                'inst_idx': np.arange(len(probs_ext)),  # row order stands in for time; last = last occurrence
            })

            print("\n[External Patient-level aggregation]")
            (result['auroc_pat_ext_max'], result['auprc_pat_ext_max'],
             result['auroc_pat_ext_last'], result['auprc_pat_ext_last']) = _patient_level_metrics(df_inst_ext)

    return result




### **LOOP + External Validation**

In [34]:
import numpy as np
import pandas as pd
import os

from sklearn import preprocessing
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit, GroupShuffleSplit

In [35]:
# Folder layout, resolved relative to the notebook's working directory.
fp_project_folder = "../"
fp_code_folder = os.path.join(fp_project_folder, "code")
fp_data_folder = os.path.join(fp_code_folder, "data")
fp_checkpoint_folder = os.path.join(fp_code_folder, "checkpoints")
fp_figure_folder = os.path.join(fp_code_folder, "figures")

# Input files of the external cohort: time-series table and per-patient baseline table.
fp_ts = os.path.join(fp_data_folder, "Preprocess/data_compact_2.csv")
fp_demo = os.path.join(fp_data_folder, "Baseline.HCC.livercomp.FIB4.PLT.csv")

In [36]:
def print_descr(input_df, id_name):
    """Print the number of distinct ids, the total row count and the largest number of rows sharing one id."""
    distinct_ids = input_df[id_name].unique()
    rows_per_id = input_df.groupby([id_name]).size()
    print("Unique records: ", len(distinct_ids), ", Total num rows: ", len(input_df), ", Max length: ", max(rows_per_id))
# --- External time-series table ---
external_df_rawTS = pd.read_csv(fp_ts)

# Lab features kept for the external cohort.
external_feature_cols = [
    'Albumin', 'Platelets', ## tier 0
    'Lymphocytes', 'White cells count', 'Neutrophils', 'Basophils', 'Eosinophils', 'Total protein' ## tier 1
]

# Keep id/time columns plus the selected features; MonthPeriod becomes a monthly pandas Period.
external_df_rawTS = external_df_rawTS[['ID', 'TimeUnit', 'MonthPeriod'] + external_feature_cols]
external_df_rawTS['MonthPeriod'] = pd.to_datetime(external_df_rawTS['MonthPeriod']).dt.to_period('M')
print_descr(external_df_rawTS, id_name='ID')

# --- External baseline (per-patient) table ---
external_df_rawDemo = pd.read_csv(fp_demo)
external_df_rawDemo = external_df_rawDemo.iloc[:, 1:]  # remove the first column which is just index
# Replace the dotted column names with underscore names.
external_df_rawDemo = external_df_rawDemo.rename(columns={'Study.ID':'ID', 'birth.year':'birth_year', 'age.entry':'age_entry', 
                                        'Liver.failure':'Liver_failure', 'Date_Liver.failure':'Date_Liver_failure'})

external_demo_cols = [
    'birth_year', 'age_entry', 'gender', 'earlieststeatosisentrydate', 
    'HCC', 'HCC_date', 'Cirrhosis', 'Date_Cirrhosis', 'Fibrosis', 'Date_Fibrosis', 'Liver_failure', 'Date_Liver_failure'
]
external_df_rawDemo = external_df_rawDemo[['ID'] + external_demo_cols]
# Parse the date columns day-first; unparseable values become NaT (errors='coerce').
external_df_rawDemo['earlieststeatosisentrydate'] = pd.to_datetime(external_df_rawDemo['earlieststeatosisentrydate'], dayfirst=True, errors='coerce')
external_df_rawDemo['HCC_date'] = pd.to_datetime(external_df_rawDemo['HCC_date'], dayfirst=True, errors='coerce')
external_df_rawDemo['Date_Cirrhosis'] = pd.to_datetime(external_df_rawDemo['Date_Cirrhosis'], dayfirst=True, errors='coerce')
external_df_rawDemo['Date_Fibrosis'] = pd.to_datetime(external_df_rawDemo['Date_Fibrosis'], dayfirst=True, errors='coerce')
external_df_rawDemo['Date_Liver_failure'] = pd.to_datetime(external_df_rawDemo['Date_Liver_failure'], dayfirst=True, errors='coerce')
external_df_rawDemo = external_df_rawDemo.sort_values(by=['ID']).reset_index(drop=True)
print_descr(external_df_rawDemo, id_name='ID')

Unique records:  3106 , Total num rows:  40256 , Max length:  98
Unique records:  3192 , Total num rows:  3192 , Max length:  1


  external_df_rawTS = pd.read_csv(fp_ts)


In [None]:
## remove the row with all NaN
# A row is flagged when every selected feature column is NaN.
external_all_nan_mask = external_df_rawTS[external_feature_cols].isna().all(axis=1)

external_problematic_rows = external_df_rawTS[external_all_nan_mask]

# Report the flagged rows (ID / TimeUnit) before removing them.
if not external_problematic_rows.empty:
    print("找到以下行的特征列全部为 NaN:")
    print(external_problematic_rows[['ID', 'TimeUnit']])
    print(f'rows of nan: {len(external_problematic_rows)}')
else:
    print("没有找到任何特征列全部为 NaN 的行。")

# Row/ID counts before and after the drop.
# NOTE(review): external_df_rawTS is overwritten here, so re-running this cell sees the already cleaned frame.
print(f"原始 DataFrame 的行数: {len(external_df_rawTS)}, 原始DataFrame的ID数: {len(external_df_rawTS['ID'].unique())}")
external_df_rawTS = external_df_rawTS.dropna(subset=external_feature_cols, how='all')
print(f"清理后 DataFrame 的行数: {len(external_df_rawTS)}, 清理后DataFrame的ID数: {len(external_df_rawTS['ID'].unique())}")

找到以下行的特征列全部为 NaN:
             ID  TimeUnit
11     NASH0002        50
44     NASH0004        34
46     NASH0004        37
47     NASH0004        39
61     NASH0007       112
...         ...       ...
40222  NASH3191        77
40227  NASH3191        86
40234  NASH3191        96
40238  NASH3191       100
40239  NASH3191       101

[8591 rows x 2 columns]
rows of nan: 8591
原始 DataFrame 的行数: 40256, 原始DataFrame的ID数: 3106
清理后 DataFrame 的行数: 31665, 清理后DataFrame的ID数: 3061


In [38]:
## remove the rows in external_df_rawTS whose MonthPeriod is before earlieststeatosisentrydate in external_df_rawDemo
## remove the rows in external_df_rawTS whose ID is not in external_df_rawDemo
# Left merge attaches each patient's entry date to every time-series row.
external_df_filteredTS = pd.merge(external_df_rawTS, external_df_rawDemo[['ID', 'earlieststeatosisentrydate']], on='ID', how='left')
# Compare at month granularity.
# NOTE(review): rows without an entry date (NaT) presumably compare as False and are dropped here — confirm this is intended.
external_mask_valid_time = external_df_filteredTS['MonthPeriod'] >= external_df_filteredTS['earlieststeatosisentrydate'].dt.to_period('M')
external_df_filteredTS = external_df_filteredTS[external_mask_valid_time]
external_df_filteredTS = external_df_filteredTS.drop(columns=['earlieststeatosisentrydate'])
# Keep only IDs that exist in the baseline table.
external_mask_valid_id = external_df_filteredTS['ID'].isin(external_df_rawDemo['ID'])
external_df_filteredTS = external_df_filteredTS[external_mask_valid_id]
print_descr(external_df_filteredTS, id_name='ID')

Unique records:  3025 , Total num rows:  24556 , Max length:  97


In [39]:
# Column groups of the external baseline table.
external_EVENT_DATE_COLS = ['HCC_date', 'Date_Cirrhosis', 'Date_Fibrosis', 'Date_Liver_failure']
external_PATIENT_INFO_COLS = ['birth_year', 'age_entry', 'gender', 'earlieststeatosisentrydate']
external_ID_COL = 'ID'

external_idx_pred_event = [0,1,2,3] # predict the combined risk of the n selected events
external_predicted_EVENT_DATE_COLS = [external_EVENT_DATE_COLS[i] for i in external_idx_pred_event]

external_df = external_df_rawDemo.copy()
## --df_demo: patient demographics (each patient/ID one row)
# NOTE: external_demo_cols is rebound here to the id column plus the patient-info columns.
external_demo_cols = [external_ID_COL] + external_PATIENT_INFO_COLS
external_demo_cols = [c for c in external_demo_cols if c in external_df.columns]
external_df_demo = external_df[external_demo_cols].drop_duplicates(subset=[external_ID_COL]).reset_index(drop=True)
## --df_label: patient labels (each patient/ID one row)
external_label_cols = [external_ID_COL] + external_EVENT_DATE_COLS
external_label_cols = [c for c in external_label_cols if c in external_df.columns]
external_df_labels = external_df[external_label_cols].drop_duplicates(subset=[external_ID_COL]).reset_index(drop=True)
# Earliest date among the selected event columns (NaT when none is set); has_event is 1 when such a date exists.
external_df_labels['earliest_event_date'] = external_df_labels[external_predicted_EVENT_DATE_COLS].min(axis=1, skipna=True)
external_df_labels['has_event'] = external_df_labels['earliest_event_date'].notna().astype(int)

In [40]:
# Work on copies so the later steps do not modify the filtered time-series / demo / label frames built above.
external_df_ts = external_df_filteredTS.copy()
external_df_demo = external_df_demo.copy()
external_df_labels = external_df_labels.copy()

In [None]:
import ast

# Position of each aggregate statistic inside a zipped cell; expand_zipped_columns uses it to pick one element.
STAT_INDEX = {'first': 0, 'last': 1, 'mean': 2, 'max': 3, 'min': 4, 'count': 5}

def _extract_from_cell(cell, idx):
    if pd.isna(cell):
        return np.nan
    if isinstance(cell, (tuple, list)):
        return cell[idx] if len(cell) > idx else np.nan
    if isinstance(cell, str):
        try:
            t = ast.literal_eval(cell)
            if isinstance(t, (tuple, list)) and len(t) > idx:
                return t[idx]
        except Exception:
            pass
        return np.nan
    return cell

def expand_zipped_columns(df, feature_cols, selection, drop_original=True):
    """
    Expand "zipped" feature columns (cells holding a tuple/list of aggregate
    statistics, or its string form) into one scalar column per requested statistic.

    Parameters
    ----------
    df : DataFrame containing ``feature_cols``; it is not modified.
    feature_cols : names of the columns to inspect.
    selection : statistic name(s) from ``STAT_INDEX``. Either a single name, an
        iterable of names applied to every column, or a dict mapping a column
        to its name(s); columns missing from the dict use ``('last',)``.
    drop_original : drop the zipped column once it has been expanded.

    Returns
    -------
    (df_out, new_cols)
        ``new_cols`` lists, in order, the columns of ``df_out`` that now carry
        the features. A column is treated as zipped when its first non-null
        value is a tuple, list or str; other columns are passed through.
        Expanded columns are named ``{col}_{stat}``, except that a single
        statistic takes over the original name when ``drop_original`` is True.
    """
    default_stats = ('last',)

    def _as_stats(value):
        # A bare string is one statistic name, not an iterable of characters.
        return (value,) if isinstance(value, str) else tuple(value)

    if isinstance(selection, dict):
        sel_map = {col: _as_stats(selection.get(col, default_stats)) for col in feature_cols}
    else:
        shared_stats = _as_stats(selection)
        sel_map = {col: shared_stats for col in feature_cols}

    df_out = df.copy()
    new_cols = []
    for col in feature_cols:
        stats_for_col = sel_map[col]

        sample = df_out[col].dropna()
        looks_zipped = len(sample) > 0 and isinstance(sample.iloc[0], (tuple, list, str))
        if not looks_zipped:
            new_cols.append(col)
            continue

        expanded_names = []
        for stat in stats_for_col:
            idx = STAT_INDEX[stat]
            new_name = f"{col}_{stat}"
            df_out[new_name] = df_out[col].map(lambda x, idx=idx: _extract_from_cell(x, idx))
            expanded_names.append(new_name)

        if drop_original:
            df_out = df_out.drop(columns=[col])
            if len(expanded_names) == 1:
                # A single statistic replaces the original column under its own
                # name; report that name so new_cols matches df_out. Renaming is
                # only safe once the original is gone, otherwise two columns
                # would share the same name.
                df_out = df_out.rename(columns={expanded_names[0]: col})
                expanded_names = [col]

        new_cols.extend(expanded_names)

    return df_out, new_cols


In [None]:
def window_to_patient_list(df_win, id_col, win_col):
    """List the owning patient of every window, ordered by window id.

    Parameters
    ----------
    df_win : pandas.DataFrame
        Windowed rows containing at least ``win_col`` and ``id_col``.
    id_col : str
        Name of the patient identifier column.
    win_col : str
        Name of the window identifier column.

    Returns
    -------
    (list, list, pandas.DataFrame)
        Patient ids and window ids as parallel lists, plus the de-duplicated
        ``[win_col, id_col]`` frame they were read from (sorted by ``win_col``,
        index reset).
    """
    unique_pairs = df_win.loc[:, [win_col, id_col]].drop_duplicates()
    meta = unique_pairs.sort_values(by=win_col).reset_index(drop=True)
    return meta[id_col].tolist(), meta[win_col].tolist(), meta

In [None]:
# Statistic(s) to pull out of every zipped cell.
agg_select = ('first',)

external_id_col = 'ID'
external_time_col = 'TimeUnit'

# Everything except the identifier / time columns is treated as a raw feature.
external_raw_feature_cols = [
    name
    for name in external_df_ts.columns
    if name not in (external_id_col, external_time_col, 'MonthPeriod')
]

external_ts_pre, external_feature_cols = expand_zipped_columns(
    external_df_ts,
    external_raw_feature_cols,
    selection=agg_select,
    drop_original=True,
)

df_external_ts = external_ts_pre.copy()
df_external_demo = external_df_demo.copy()

# Rename the external event-date columns to the t_* names used for the cohort.
rename_map = dict(
    HCC_date="t_hcc_first",
    Date_Cirrhosis="t_cirr_first",
    Date_Fibrosis="t_fib_first",
    Date_Liver_failure="t_lf_first",
    earliest_event_date="t_event_min",
)
df_external_labels = external_df_labels.copy().rename(columns=rename_map)

# Same names, same order, as the rename targets above.
external_EVENT_DATE_COLS = list(rename_map.values())

In [None]:
import numpy as np
import pandas as pd
import random

def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and (when usable) PyTorch.

    The PyTorch part is best effort: any exception raised while importing or
    configuring torch is swallowed, so the function also works without it.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    try:
        import torch

        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
    except Exception:
        pass

def run_once_with_external(seed: int, df_external_ts, use_demo=False, interpret=False):
    """Run one seeded train / test / external-validation experiment.

    Steps: seed the RNGs, split the patients of the notebook-level ``ts``
    frame, standardise the features with a scaler fitted on the training
    split only, cut the train/test/external series into sliding windows,
    flatten every window, stack the three splits (train, then test, then
    external) and pass them to ``train``; the result is scored with
    ``eval_system``.

    Besides its arguments the function reads names defined in other cells:
    ``ts``, ``df_labels``, ``cohort``, ``event_cols``, ``preprocessing``,
    ``split_patient_ids``, ``slice_by_ids``, ``create_sliding_windows``,
    ``train``, ``eval_system``, ``external_id_col``, ``external_time_col``,
    ``external_PATIENT_INFO_COLS``, ``external_EVENT_DATE_COLS``,
    ``df_external_demo`` and ``df_external_labels``.

    Parameters
    ----------
    seed : int
        Seed for ``set_seed`` and for the patient split.
    df_external_ts : pandas.DataFrame
        External time series. It is copied before scaling and must contain
        every feature column of ``ts``.
    use_demo : bool, default False
        Append the ``['gender', 'age_at_window']`` demographics to each
        flattened window.
    interpret : bool, default False
        Forwarded to ``train``; also selects the return value (see below).

    Returns
    -------
    tuple
        ``interpret=True``: ``(metrics, system_loop, scaler,
        [start_test, start_external], feature_cols, window_size, patient_all,
        target)``.
        Otherwise the 12 metrics ``auroc, auprc, auroc_external,
        auprc_external, auroc_pat_max, auprc_pat_max, auroc_pat_last,
        auprc_pat_last, auroc_pat_ext_max, auprc_pat_ext_max,
        auroc_pat_ext_last, auprc_pat_ext_last`` taken from ``eval_system``.

    Raises
    ------
    AssertionError
        If the external frame lacks feature columns, if windows and their
        labels/demographics are not in the same ``id_window`` order for any
        split, if a window does not have ``window_size`` rows, if the
        external ``y_`` columns differ from the training ones, or if the
        number of predictions differs from the number of windows.
    """
    set_seed(seed)

    train_ids, test_ids = split_patient_ids(ts, df_labels=df_labels, id_col="subject_id", random_state=seed, use_stratify=True, subset_frac=None, balance=False)
    df_train_ts, df_test_ts = slice_by_ids(ts, train_ids, test_ids, id_col="subject_id")
    # The external id column is the same one handed to create_sliding_windows below.
    print(f"[Seed {seed}] Train patients: {len(train_ids)}, Test patients: {len(test_ids)}, External patients: {len(df_external_ts[external_id_col].unique())}")

    ## --- normalization (fitted on the training split only) ---
    feature_cols = [c for c in ts.columns if c not in ["subject_id", "MonthPeriod", "TimeUnit"]]
    scaler = preprocessing.StandardScaler()
    scaler.fit(df_train_ts[feature_cols])
    df_train_ts[feature_cols] = scaler.transform(df_train_ts[feature_cols])
    df_test_ts[feature_cols]  = scaler.transform(df_test_ts[feature_cols])

    # Copy so the caller's external frame is not scaled in place.
    df_external_ts = df_external_ts.copy()
    missing = [c for c in feature_cols if c not in df_external_ts.columns]
    assert len(missing) == 0, f"External data missing columns: {missing}"
    df_external_ts[feature_cols] = scaler.transform(df_external_ts[feature_cols])

    # sliding window parameters
    window_size = 12
    time_horizon_months = 12
    min_win_valid = int(window_size*0.5)
    fill_method = 'inf'  # 'nan', 'inf', 'zero', 'forward_fill', 'backward_fill', 'mean', 'median'
    agg_select = ('first',)  # ('last',), ('first','last'), ('first','mean','max'), or {'Albumin':['last'], 'Platelets':['mean','max']}

    id_col = 'subject_id'
    time_col = 'TimeUnit'
    EVENT_DATE_COLS = event_cols
    PATIENT_INFO_COLS = [ "birth_year", "age_at_index", "gender_num" ]

    df_demo = cohort[["subject_id"] + PATIENT_INFO_COLS].copy()
    df_demo = df_demo.rename(columns={"gender_num": "gender"})
    PATIENT_INFO_COLS = [c for c in df_demo.columns if c not in ['subject_id']]

    print(f"\n=== 创建滑动窗口 (window_size={window_size}, fill_method={fill_method}) ===")

    balance_windows_train = True
    balance_windows_test = False

    df_train_win, demo_train_win, labels_train_win = create_sliding_windows(df_train_ts, window_size=window_size, min_win_valid=min_win_valid, fill_method=fill_method,
                                                                            id_col=id_col, time_col=time_col,
                                                                            patient_info_cols=PATIENT_INFO_COLS, event_date_cols=EVENT_DATE_COLS,
                                                                            preexpanded=True, features_override=feature_cols, agg_select=agg_select,
                                                                            df_demo=df_demo, df_labels=df_labels, time_horizon_months=time_horizon_months,
                                                                            balance_windows=balance_windows_train)
    df_test_win, demo_test_win, labels_test_win = create_sliding_windows(df_test_ts,  window_size=window_size, min_win_valid=min_win_valid, fill_method=fill_method,
                                                                        id_col=id_col, time_col=time_col,
                                                                        patient_info_cols=PATIENT_INFO_COLS, event_date_cols=EVENT_DATE_COLS,
                                                                        preexpanded=True, features_override=feature_cols, agg_select=agg_select,
                                                                        df_demo=df_demo, df_labels=df_labels, time_horizon_months=time_horizon_months,
                                                                        balance_windows=balance_windows_test)
    ## --- windows for the external dataset ---
    df_external_win, demo_external_win, labels_external_win = create_sliding_windows(df_external_ts, window_size=window_size, min_win_valid=min_win_valid, fill_method=fill_method,
                                                                        id_col=external_id_col, time_col=external_time_col,
                                                                        patient_info_cols=external_PATIENT_INFO_COLS, event_date_cols=external_EVENT_DATE_COLS,
                                                                        preexpanded=True, features_override=feature_cols, agg_select=agg_select,
                                                                        df_demo=df_external_demo, df_labels=df_external_labels, time_horizon_months=time_horizon_months,
                                                                        balance_windows=balance_windows_test)

    ## patient id of every window, in id_window order (train, then test, then external)
    patient_train, win_train, meta_train = window_to_patient_list(df_train_win, id_col=id_col, win_col='id_window')
    patient_test,  win_test,  meta_test  = window_to_patient_list(df_test_win, id_col=id_col, win_col='id_window')
    patient_external, win_external, meta_external = window_to_patient_list(df_external_win, id_col=external_id_col, win_col='id_window')
    patient_all = patient_train + patient_test + patient_external

    ## --- prepare the model input / output arrays ---
    # 1) sort window rows
    df_train_win = df_train_win.sort_values(['id_window', 'TimeUnit']).reset_index(drop=True)
    df_test_win  = df_test_win.sort_values(['id_window', 'TimeUnit']).reset_index(drop=True)
    df_external_win = df_external_win.sort_values(['id_window', 'TimeUnit']).reset_index(drop=True)

    # 2) sort labels/demo by id_window too
    labels_train_win = labels_train_win.sort_values('id_window').reset_index(drop=True)
    labels_test_win  = labels_test_win.sort_values('id_window').reset_index(drop=True)
    labels_external_win = labels_external_win.sort_values('id_window').reset_index(drop=True)

    demo_train_win = demo_train_win.sort_values('id_window').reset_index(drop=True)
    demo_test_win  = demo_test_win.sort_values('id_window').reset_index(drop=True)
    demo_external_win = demo_external_win.sort_values('id_window').reset_index(drop=True)

    # 3) ensure same id_window order between windowed rows and labels/demo.
    #    Checked for every split, because windows, labels and demographics are
    #    paired purely by position further down.
    def _assert_aligned(split, key, win_df, lab_df, dem_df):
        win_ids = win_df['id_window'].drop_duplicates().to_numpy()
        # np.array_equal also copes with differing lengths (plain == would not).
        assert np.array_equal(win_ids, lab_df['id_window'].to_numpy()), \
            f"{split}: id_window order mismatch between df_{key}_win and labels_{key}_win"
        assert np.array_equal(win_ids, dem_df['id_window'].to_numpy()), \
            f"{split}: id_window order mismatch between df_{key}_win and demo_{key}_win"

    _assert_aligned("Train", "train", df_train_win, labels_train_win, demo_train_win)
    _assert_aligned("Test", "test", df_test_win, labels_test_win, demo_test_win)
    _assert_aligned("External", "external", df_external_win, labels_external_win, demo_external_win)

    counts_train = df_train_win.groupby('id_window').size()
    counts_test  = df_test_win.groupby('id_window').size()
    count_external = df_external_win.groupby('id_window').size()

    assert counts_train.nunique() == 1 and counts_train.iloc[0] == window_size, \
        f"Train windows not all size={window_size}: {counts_train.value_counts().head()}"

    assert counts_test.nunique() == 1 and counts_test.iloc[0] == window_size, \
        f"Test windows not all size={window_size}: {counts_test.value_counts().head()}"

    assert count_external.nunique() == 1 and count_external.iloc[0] == window_size, \
        f"External windows not all size={window_size}: {count_external.value_counts().head()}"

    F = len(feature_cols)

    def _as_3d(df_win, n_windows):
        # Rows are sorted by (id_window, TimeUnit) and every window has
        # window_size rows, so a plain reshape yields (N, T, F).
        return df_win[feature_cols].to_numpy(dtype=float).reshape(n_windows, window_size, F)

    N = counts_train.shape[0]
    input_train = _as_3d(df_train_win, N)
    print(f"训练集窗口数: {N}, 特征数: {F}, 输入形状: {input_train.shape}")

    N = counts_test.shape[0]
    input_test = _as_3d(df_test_win, N)
    print(f"测试集窗口数: {N}, 特征数: {F}, 输入形状: {input_test.shape}")

    N = count_external.shape[0]
    input_external = _as_3d(df_external_win, N)
    print(f"外部集窗口数: {N}, 特征数: {F}, 输入形状: {input_external.shape}")

    target_cols = [c for c in labels_train_win.columns if c.startswith('y_')]
    target_train = labels_train_win[target_cols].to_numpy(dtype=int)  # (N, num_targets)
    target_test  = labels_test_win[target_cols].to_numpy(dtype=int)    # (N, num_targets)

    target_cols_ext = [c for c in labels_external_win.columns if c.startswith("y_")]
    assert set(target_cols_ext) == set(target_cols), \
        f"External y_ cols mismatch. train={target_cols}, external={target_cols_ext}"

    # Index with the training column list so the column order matches train/test.
    target_external  = labels_external_win[target_cols].to_numpy(dtype=int)    # (N, num_targets)
    print(f"训练集标签形状: {target_train.shape}, 测试集标签形状: {target_test.shape}, 外部集标签形状: {target_external.shape}")
    for i, col in enumerate(target_cols):
        print(f"  目标列 {col} 的正类比例: 训练集 {target_train[:,i].mean():.4f}, 测试集 {target_test[:,i].mean():.4f}, 外部集 {target_external[:,i].mean():.4f}")

    demo_cols = ['gender', 'age_at_window']
    demo_train = demo_train_win[demo_cols].to_numpy(dtype=int)  # (N, 2)
    demo_test  = demo_test_win[demo_cols].to_numpy(dtype=int)    # (N, 2)
    demo_external  = demo_external_win[demo_cols].to_numpy(dtype=int)    # (N, 2)
    print(f"训练集人口统计信息形状: {demo_train.shape}, 测试集人口统计信息形状: {demo_test.shape}, 外部集人口统计信息形状: {demo_external.shape}")

    # True where a value is finite (i.e. neither NaN nor +/-inf).
    mask_train_flat = np.isfinite(input_train).reshape(input_train.shape[0], -1)
    mask_test_flat = np.isfinite(input_test).reshape(input_test.shape[0], -1)
    mask_external_flat = np.isfinite(input_external).reshape(input_external.shape[0], -1)

    # Flatten each window from (T, F) to a single row of T*F values.
    input_train_flat = input_train.reshape(input_train.shape[0], -1)
    input_test_flat = input_test.reshape(input_test.shape[0], -1)
    input_external_flat = input_external.reshape(input_external.shape[0], -1)

    if use_demo:
        input_train_flat = np.concatenate([input_train_flat, demo_train], axis=1)
        input_test_flat = np.concatenate([input_test_flat, demo_test], axis=1)
        input_external_flat = np.concatenate([input_external_flat, demo_external], axis=1)

    # Only the last y_ column is used as the prediction target.
    target_train_final = target_train[:,-1].reshape(-1,1)
    target_test_final = target_test[:,-1].reshape(-1,1)
    target_external_final = target_external[:,-1].reshape(-1,1)

    # Named input_all rather than `input` to avoid shadowing the builtin.
    input_all = np.vstack((input_train_flat, input_test_flat, input_external_flat))
    target = np.vstack((target_train_final, target_test_final, target_external_final))
    # NOTE(review): with use_demo=True `input_all` gains len(demo_cols) columns
    # that `mask` does not have; `mask` is only used for the shape print below.
    mask = np.vstack((mask_train_flat, mask_test_flat, mask_external_flat))

    num_variables = input_train.shape[-1]

    # NOTE(review): these offsets look 1-based (first test row = n_train + 1);
    # confirm against the indexing convention of train()/eval_system().
    start_test = len(input_train_flat) + 1
    end_test = len(input_train_flat) + len(input_test_flat)
    start_external = len(input_train_flat) + len(input_test_flat) + 1

    print(f'input shape: {input_all.shape}, target shape: {target.shape}, mask shape: {mask.shape}, num_variables: {num_variables}, \
          start_test: {start_test}, start_external: {start_external}')

    max_cluster = 50
    half_life = 100
    threshold_mf = 0.7
    min_rule_weight = 0.6

    system_loop = train(
        input_all, target, num_variables, max_cluster, half_life, threshold_mf, min_rule_weight,
        start_test, end_test, start_external,
        ablation=False, interpret=interpret, evo=False,
        mode='ds', ds_W=2.0, base_lambda=1.0,
        tau_match=0.7, score_q=0.0, rule_partial_match=False
    )

    pred_all = np.asarray(system_loop["predicted"]).reshape(-1)
    assert len(pred_all) == len(patient_all), "predicted 长度和 window 数量不一致"

    metrics = eval_system(system_loop, y_true=target, patient_all=patient_all, start_test=start_test, end_test=end_test, start_external=start_external, thr=0.5, mode='ds')

    if interpret:
        return metrics, system_loop, scaler, [start_test, start_external], feature_cols, window_size, patient_all, target

    return metrics['auroc'], metrics['auprc'], metrics['auroc_external'], metrics['auprc_external'], \
           metrics['auroc_pat_max'], metrics['auprc_pat_max'], metrics['auroc_pat_last'], metrics['auprc_pat_last'], \
           metrics['auroc_pat_ext_max'], metrics['auprc_pat_ext_max'], metrics['auroc_pat_ext_last'], metrics['auprc_pat_ext_last']



In [None]:
# ==========================================
#  Repeated runs over several seeds + mean/std summary
# ==========================================
seeds = [0, 1, 2, 3, 4]
rows = []

use_demo = True

for s in seeds:
    (auroc, auprc, auroc_ext, auprc_ext,
     auroc_pat_max, auprc_pat_max, auroc_pat_last, auprc_pat_last,
     auroc_pat_ext_max, auprc_pat_ext_max, auroc_pat_ext_last, auprc_pat_ext_last) = run_once_with_external(
        s, df_external_ts, use_demo=use_demo
    )
    # One result row per seed; key order fixes the column order of `res`.
    rows.append(dict(
        seed=s,
        AUROC=auroc,
        AUPRC=auprc,
        AUROC_external=auroc_ext,
        AUPRC_external=auprc_ext,
        AUROC_pat_max=auroc_pat_max,
        AUPRC_pat_max=auprc_pat_max,
        AUROC_pat_last=auroc_pat_last,
        AUPRC_pat_last=auprc_pat_last,
        AUROC_pat_ext_max=auroc_pat_ext_max,
        AUPRC_pat_ext_max=auprc_pat_ext_max,
        AUROC_pat_ext_last=auroc_pat_ext_last,
        AUPRC_pat_ext_last=auprc_pat_ext_last,
    ))

res = pd.DataFrame(rows)
display(res)

# Mean and standard deviation across seeds for the window-level metrics.
summary = res.loc[:, ["AUROC", "AUPRC", "AUROC_external", "AUPRC_external"]].agg(["mean", "std"]).transpose()
display(summary)

print("\n".join(
    f"{name}: {summary.loc[name,'mean']:.4f} ± {summary.loc[name,'std']:.4f}"
    for name in ("AUROC", "AUPRC", "AUROC_external", "AUPRC_external")
))

Patients original total=2358 | used total=2358  | Train=1886 | Test=472
Pos rate (has_event) — Train: 0.088016967126193  Test: 0.08898305084745763
[Seed 0] Train patients: 1886, Test patients: 472, External patients: 3025

=== 创建滑动窗口 (window_size=12, fill_method=inf) ===
特征列: ['Albumin', 'Basophils', 'Eosinophils', 'Lymphocytes', 'Neutrophils', 'Platelets', 'Total protein', 'White cells count']
窗口大小: 12
填充方法: inf

生成了 4769 个窗口
结果数据形状: (57228, 12)
特征列: ['Albumin', 'Basophils', 'Eosinophils', 'Lymphocytes', 'Neutrophils', 'Platelets', 'Total protein', 'White cells count']
窗口大小: 12
填充方法: inf

生成了 1208 个窗口
结果数据形状: (14496, 12)
特征列: ['Albumin', 'Basophils', 'Eosinophils', 'Lymphocytes', 'Neutrophils', 'Platelets', 'Total protein', 'White cells count']
窗口大小: 12
填充方法: inf

生成了 6398 个窗口
结果数据形状: (76776, 12)
训练集窗口数: 613, 特征数: 8, 输入形状: (613, 12, 8)
测试集窗口数: 1208, 特征数: 8, 输入形状: (1208, 12, 8)
外部集窗口数: 6398, 特征数: 8, 输入形状: (6398, 12, 8)
训练集标签形状: (613, 5), 测试集标签形状: (1208, 5), 外部集标签形状: (6398, 5)
  目标列 y_t

Unnamed: 0,seed,AUROC,AUPRC,AUROC_external,AUPRC_external,AUROC_pat_max,AUPRC_pat_max,AUROC_pat_last,AUPRC_pat_last,AUROC_pat_ext_max,AUPRC_pat_ext_max,AUROC_pat_ext_last,AUPRC_pat_ext_last
0,0,0.8082,0.434588,0.809813,0.426371,0.791353,0.416602,0.788847,0.421214,0.749237,0.437898,0.768192,0.474209
1,1,0.838448,0.464281,0.811411,0.419858,0.742241,0.455149,0.804741,0.478176,0.773943,0.446478,0.76732,0.448985
2,2,0.772336,0.322545,0.812719,0.415605,0.829406,0.49154,0.865266,0.547329,0.757647,0.461771,0.766144,0.475631
3,3,0.791022,0.303536,0.81728,0.434989,0.823268,0.497549,0.807374,0.485515,0.75915,0.460124,0.78658,0.488002
4,4,0.872477,0.567705,0.812675,0.433109,0.916667,0.674619,0.886574,0.494718,0.757255,0.458472,0.75573,0.422137


Unnamed: 0,mean,std
AUROC,0.816497,0.039635
AUPRC,0.418531,0.108445
AUROC_external,0.81278,0.002781
AUPRC_external,0.425987,0.008325


AUROC: 0.8165 ± 0.0396
AUPRC: 0.4185 ± 0.1084
AUROC_external: 0.8128 ± 0.0028
AUPRC_external: 0.4260 ± 0.0083
