In [4]:
import pandas as pd
import pickle
from sklearn.model_selection import GroupKFold

# ===== Step 1: Load and Inspect the Original Data =====
# Read the flat, one-row-per-sample CSV and report the participant IDs exactly
# as stored on disk (whitespace normalisation only happens in Step 2).
csv_file = "processed_dataset_with_participant_id.csv"
df = pd.read_csv(csv_file)

raw_ids = df["participant_id"]
print("=== Original Data ===")
print("Total rows:", len(df))
print("Unique participant_ids (raw):", raw_ids.unique())
print("Counts per participant:\n", raw_ids.value_counts())
print("-" * 50)

# ===== Step 2: Group Data Solely by 'participant_id' =====
# Build two parallel lists: group_ids[k] is a participant ID and
# grouped_data[k] is the sub-DataFrame holding all of that participant's rows.
grouped_data = []
group_ids = []

# Make sure there are no hidden differences (like leading/trailing spaces), so
# that e.g. "combined_10" and "combined_10 " end up in the same group.
# Only string values are stripped: `.str.strip()` would raise on a numeric
# column and would silently turn non-string entries of an object column into
# NaN (which groupby then drops). Missing IDs stay missing.
df["participant_id"] = df["participant_id"].map(
    lambda value: value.strip() if isinstance(value, str) else value
)

# groupby() drops rows whose key is NaN by default; make that loss visible
# instead of letting those rows disappear from the saved dataset unnoticed.
n_missing_ids = int(df["participant_id"].isna().sum())
if n_missing_ids:
    print(f"WARNING: {n_missing_ids} row(s) have no participant_id and are excluded from grouping")

for participant, group in df.groupby("participant_id"):
    print(f"Grouping participant: {participant} with {len(group)} rows")
    grouped_data.append(group)
    group_ids.append(participant)

# Every row with an ID must belong to exactly one participant group.
assert sum(len(g) for g in grouped_data) == len(df) - n_missing_ids, "Grouping lost rows"

print("Total groups formed:", len(group_ids))
print("Unique group_ids:", set(group_ids))
print("-" * 50)

# ===== Step 3: Save the Grouped Data =====
# Persist the participant IDs, the matching per-participant DataFrames (same
# order), and the column names of the first group. The column list is empty
# when no groups were formed.
if grouped_data:
    column_names = grouped_data[0].columns.tolist()
else:
    column_names = []

grouped_dict = {"group_id": group_ids, "data": grouped_data, "columns": column_names}

processed_file = "processed_2d_data_new.pkl"
with open(processed_file, "wb") as f:
    pickle.dump(grouped_dict, f)

print("Grouped data saved to", processed_file)
print("-" * 50)

# ===== Step 4: Check the GroupKFold Splitting =====
# Reload the pickle written in Step 3 so the split below runs on exactly what
# was persisted rather than on the in-memory objects.
with open(processed_file, "rb") as f:
    grouped_dict = pickle.load(f)

data, group_ids = grouped_dict["data"], grouped_dict["group_id"]

print("=== Processed Data Info ===")
print("Number of unique participants:", len(set(group_ids)))
print("List of participant IDs:", group_ids)
print("-" * 50)

# Each element of `data` is one participant's DataFrame and `group_ids` holds
# the matching IDs, so GroupKFold assigns whole participants to test folds.
# The (train_idx, test_idx) pairs are recorded in fold_indices for reuse.
gkf = GroupKFold(n_splits=5)
fold_indices = []
for fold_no, (train_idx, test_idx) in enumerate(gkf.split(data, groups=group_ids), start=1):
    print(f"Fold {fold_no}:")
    print("  Train indices:", train_idx)
    print("  Test indices:", test_idx)
    print("  Train participant IDs:", [group_ids[k] for k in train_idx])
    print("  Test participant IDs:", [group_ids[k] for k in test_idx])
    print("-" * 30)
    fold_indices.append((train_idx, test_idx))

# ===== Step 5: Save Each Fold and Print Participant IDs =====
# Iterate over the (train_idx, test_idx) pairs recorded in Step 4 instead of
# calling gkf.split() a second time. The files written here are then
# guaranteed to be the exact folds printed above, even if the splitter is
# later made non-deterministic (e.g. shuffling without a fixed seed), and the
# split is only computed once.
for fold, (train_idx, test_idx) in enumerate(fold_indices):
    train_data = [data[i] for i in train_idx]
    test_data = [data[i] for i in test_idx]

    train_file_path = f"train_2d_fold{fold+1}.pkl"
    test_file_path = f"test_2d_fold{fold+1}.pkl"

    train_dict = {
        "data": train_data,
        "columns": grouped_dict["columns"],
        "group_ids": [group_ids[i] for i in train_idx],
    }
    test_dict = {
        "data": test_data,
        "columns": grouped_dict["columns"],
        "group_ids": [group_ids[i] for i in test_idx],
    }

    # Refuse to write a fold in which any participant is on both sides.
    leaked = set(train_dict["group_ids"]) & set(test_dict["group_ids"])
    assert not leaked, f"Participant leakage in fold {fold+1}: {leaked}"

    with open(train_file_path, "wb") as f:
        pickle.dump(train_dict, f)
    with open(test_file_path, "wb") as f:
        pickle.dump(test_dict, f)

    print(f"Saved fold {fold+1}:")
    print("  Train participant IDs:", train_dict["group_ids"])
    print("  Test participant IDs:", test_dict["group_ids"])
    print("-" * 50)

# ===== Step 6: Verify No Participant Leakage in the Saved Folds =====
# Reads the fold files written in Step 5 back from disk and checks that:
#   1. within each fold, train and test share no participant;
#   2. the test sets of different folds are pairwise disjoint;
#   3. every participant appears in some test set.
# Together, 2 and 3 mean each participant is tested exactly once.
# Train sets of different folds are deliberately NOT compared. With k folds,
# each participant sits in k-1 training sets, so those sets overlap by design
# and their intersection says nothing about leakage.
n_folds = gkf.get_n_splits()
fold_files = [f"train_2d_fold{i}.pkl" for i in range(1, n_folds + 1)]
test_fold_files = [f"test_2d_fold{i}.pkl" for i in range(1, n_folds + 1)]
fold_participants = {}

for fold_no, (train_file, test_file) in enumerate(zip(fold_files, test_fold_files), start=1):
    with open(train_file, "rb") as f:
        train_ids = set(pickle.load(f)["group_ids"])
    with open(test_file, "rb") as f:
        test_ids = set(pickle.load(f)["group_ids"])
    fold_participants[train_file] = train_ids
    fold_participants[test_file] = test_ids
    print(f"{train_file} has participant IDs:", train_ids)
    print(f"{test_file} has participant IDs:", test_ids)

    # Check 1: no participant on both sides of the same fold.
    overlap = train_ids & test_ids
    print(f"Fold {fold_no} train/test overlap: {overlap}")
    assert not overlap, f"Leakage found in fold {fold_no}: {overlap}"

print("=" * 50)
# Check 2: test sets of different folds must not share participants.
for i in range(len(test_fold_files)):
    for j in range(i + 1, len(test_fold_files)):
        common_ids = fold_participants[test_fold_files[i]] & fold_participants[test_fold_files[j]]
        print(f"Common participants between {test_fold_files[i]} and {test_fold_files[j]}: {common_ids}")
        assert not common_ids, f"{test_fold_files[i]} and {test_fold_files[j]} share participants: {common_ids}"

# Check 3: every participant must be held out in some fold.
all_tested = set().union(*(fold_participants[path] for path in test_fold_files))
never_tested = set(group_ids) - all_tested
assert not never_tested, f"Participants missing from every test fold: {never_tested}"
print("All folds verified: no train/test leakage; each participant is tested exactly once.")


=== Original Data ===
Total rows: 5351
Unique participant_ids (raw): ['combined_10' 'combined_17' 'combined_16' 'combined_11' 'combined_12'
 'combined_13' 'combined_14' 'combined_18' 'combined_20' 'combined_200'
 'combined_21' 'combined_22' 'combined_23' 'combined_24' 'combined_25'
 'combined_26' 'combined_27' 'combined_28' 'combined_29' 'combined_30'
 'combined_300' 'combined_31' 'combined_32' 'combined_33' 'combined_34'
 'combined_35' 'combined_36' 'combined_37' 'combined_38' 'combined_39'
 'combined_40' 'combined_41' 'combined_42' 'combined_43']
Counts per participant:
 participant_id
combined_18     177
combined_12     176
combined_30     175
combined_16     174
combined_20     172
combined_24     172
combined_13     172
combined_32     170
combined_36     170
combined_27     170
combined_31     170
combined_37     169
combined_35     169
combined_41     169
combined_38     169
combined_26     168
combined_300    168
combined_25     168
combined_11     166
combined_40     165
combi

In [5]:
# Re-derive every fold and fail fast if any participant appears on both the
# train side and the test side of the same fold.
for fold_no, (tr_rows, te_rows) in enumerate(gkf.split(data, groups=group_ids), start=1):
    train_group_ids = [group_ids[r] for r in tr_rows]
    test_group_ids = [group_ids[r] for r in te_rows]

    no_shared_participants = set(train_group_ids).isdisjoint(test_group_ids)
    assert no_shared_participants, f"Leakage found in fold {fold_no}!"

    print(f"Fold {fold_no} - Train IDs: {train_group_ids}")
    print(f"Fold {fold_no} - Test IDs: {test_group_ids}")


Fold 1 - Train IDs: ['combined_10', 'combined_11', 'combined_12', 'combined_14', 'combined_16', 'combined_17', 'combined_18', 'combined_200', 'combined_21', 'combined_22', 'combined_23', 'combined_25', 'combined_26', 'combined_27', 'combined_28', 'combined_30', 'combined_300', 'combined_31', 'combined_32', 'combined_34', 'combined_35', 'combined_36', 'combined_37', 'combined_39', 'combined_40', 'combined_41', 'combined_42']
Fold 1 - Test IDs: ['combined_13', 'combined_20', 'combined_24', 'combined_29', 'combined_33', 'combined_38', 'combined_43']
Fold 2 - Train IDs: ['combined_10', 'combined_11', 'combined_13', 'combined_14', 'combined_16', 'combined_17', 'combined_20', 'combined_200', 'combined_21', 'combined_22', 'combined_24', 'combined_25', 'combined_26', 'combined_27', 'combined_29', 'combined_30', 'combined_300', 'combined_31', 'combined_33', 'combined_34', 'combined_35', 'combined_36', 'combined_38', 'combined_39', 'combined_40', 'combined_41', 'combined_43']
Fold 2 - Test IDs: 