## Five-Fold Generalizability Dataset 
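Generates a dataset in which each participant is assigned to one of five folds. Folds are stratified on dementia status, so dementia prevalence is approximately 19.5% within every fold. Demographic balance between folds is then checked by printing per-fold covariate means.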

In [1]:
# Load necessary packages 
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Load mediation dataset
df = pd.read_csv("med_model_dataset_7.8.2025.csv")

# Fold settings: number of folds and a fixed seed so fold membership is reproducible
N_SPLITS = 5
SEED = 42

# Initialize shuffled, seeded stratified folds
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

# Stratify by dementia_event and record each participant's fold.
# skf.split yields *positional* indices, so fill a positional int array
# (-1 = not yet assigned) instead of using them as .loc labels. Using them as
# labels would mis-assign rows if df ever had a non-default index.
# Folds are stored as ints (0..N_SPLITS-1) rather than floats for the mediation step.
fold_labels = np.full(len(df), -1, dtype=int)
for fold_idx, (_, val_idx) in enumerate(skf.split(df, df["dementia_event"])):
    fold_labels[val_idx] = fold_idx
df["fold"] = fold_labels

# Each row appears in exactly one validation split, so none may remain unassigned
assert (df["fold"] >= 0).all(), "some participants were not assigned a fold"

# Print demographics to verify balance.
# Single quotes inside the f-strings keep this valid on Python < 3.12.
print(f"Dementia Counts\n{df.groupby('fold')['dementia_event'].value_counts(normalize=True)} \n")

covariate_tiers = {
    "Tier 1 Covariates": ["age", "sex", "bmi"],
    "Tier 2 Covariates": ["broader_education", "broader_smoking", "broader_alcohol"],
    "Tier 3 Covariates": ["group1_med", "group2_med"],
}
for tier_name, tier_cols in covariate_tiers.items():
    print(f"{tier_name} \n {df.groupby('fold')[tier_cols].mean()} \n")

print(f"Participant Value Counts (per fold) \n{df['fold'].value_counts().sort_index()}")

Dementia Counts
fold  dementia_event
0.0   0                 0.804938
      1                 0.195062
1.0   0                 0.804938
      1                 0.195062
2.0   0                 0.804938
      1                 0.195062
3.0   0                 0.806931
      1                 0.193069
4.0   0                 0.804455
      1                 0.195545
Name: proportion, dtype: float64 

Tier 1 Covariates 
             age       sex        bmi
fold                                
0.0   61.649383  0.518519  28.862866
1.0   61.125926  0.493827  28.524857
2.0   61.239506  0.511111  28.626476
3.0   61.933168  0.490099  28.697938
4.0   61.480198  0.490099  28.629821 

Tier 2 Covariates 
       broader_education  broader_smoking  broader_alcohol
fold                                                     
0.0            0.429630         0.083951         0.543210
1.0            0.451852         0.118519         0.466667
2.0            0.412346         0.125926         0.538272
3.0    

In [2]:
# Export the fold-annotated df to CSV for the mediation analysis.
# Set EXPORT_PATH to the desired filename to write the file. It is left as None
# so re-running the notebook never writes or overwrites a file by accident.
EXPORT_PATH = None  # TODO: choose the output filename for the mediation step
if EXPORT_PATH is not None:
    df.to_csv(EXPORT_PATH, index=False)