In [1]:
import os
import numpy as np
import pandas as pd

# Seed NumPy's legacy global RNG. np.random.shuffle in assign_splits draws from it, so the
# random split assignment depends on this seed (and on how many times that cell has been run).
np.random.seed(1)

In [2]:
# Directory whose entries are parsed into connectome measurement keys in the next cell.
# Read from the CONNECTOMES_ROOT environment variable so no local path has to be hard-coded
# in the notebook. When the variable is unset this stays None, and os.listdir(None) then
# lists the current working directory (same as the previous hard-coded None).
CONNECTOMES_ROOT = os.environ.get("CONNECTOMES_ROOT")

In [3]:
# measurement = one session of one given subject.
# The key of each directory entry is made of its "-"-separated fields at positions 1, 3, ...;
# these keys are later compared with (participant_id, str(session_id)) pairs.
connectome_measurements = {
    tuple(entry_name.split("-")[1::2])
    for entry_name in os.listdir(CONNECTOMES_ROOT)
}
len(connectome_measurements)

674

In [4]:
# Tab-separated per-session metadata table; preview the first rows.
combined_df = pd.read_csv("combined.tsv", delimiter="\t")
combined_df.head(n=5)

Unnamed: 0,participant_id,session_id,scan_number,singleton,sedation,birth_age,scan_age,sex,birth_weight,head_circumference_scan,...,qc_fmri_fieldmap_type,qc_fmri_dvars_z,qc_fmri_tsnr_z,qc_fmri_mcdc2sbref_z,qc_fmri_sbref2struct_z,qc_fmri_fmap2struct_z,qc_fmri_standard2struct_z,qc_fmri_flagged,qc_fmri_comment,qc_smri_pipeline_status
0,CC00050XX01,7201,1,S,False,43.0,43.29,female,3.91,37.0,...,,,,,,,,True,failed fmri recon,full
1,CC00051XX02,7702,1,S,False,39.857143,40.0,female,3.31,35.0,...,,,,,,,,True,failed fmri recon,full
2,CC00052XX03,8300,1,S,False,38.0,38.71,female,2.64,33.0,...,,,,,,,,True,failed fmri recon,full
3,CC00053XX04,8607,1,S,False,40.0,40.43,female,3.46,32.0,...,,,,,,,,True,failed fmri recon,full
4,CC00054XX05,8800,1,S,False,41.857143,42.14,male,3.69,35.0,...,,,,,,,,True,failed fmri recon,full


In [5]:
def add_split_info(df: pd.DataFrame, task: str = "birth_age", splits_dir: str = "splits") -> None:
    """Flag which rows of `df` belong to the original train/val/test splits of `task`.

    For each split in ("train", "val", "test") the file `{splits_dir}/{task}_{split}.txt` is read.
    Each line is split on "-" and "_", and fields 1 and 3 are taken as (subject, session),
    e.g. `sub-<participant>_ses-<session>`.
    The following boolean columns are added to (or overwritten in) `df`, in place:
       - `meas_in_{task}_{split}` -> measurement (exactly the sub,ses pair) is in the given split of the task
       - `subj_in_{task}_{split}` -> subj is part of the given split, but not necessarily with this session

    Args:
        df: table with `participant_id` and `session_id` columns; modified in place.
        task: task name used both in the split file names and in the new column names.
        splits_dir: directory holding the split files (default "splits").

    Raises:
        AssertionError: if a split file lists more than one session for the same subject.
    """
    # (participant_id, session_id) keys of every row, built once. Sessions are compared as
    # strings because the split files are parsed as text.
    measurement_keys = list(zip(df["participant_id"], df["session_id"].astype(str)))
    for split in ["train", "val", "test"]:
        # get measurements in split as set of (sub, ses) tuples, and also a set of just the subs.
        # "sub-X_ses-Y" split on "-" or "_" -> ["sub", X, "ses", Y]; keep fields 1 and 3.
        split_path = os.path.join(splits_dir, f"{task}_{split}.txt")
        split_df = pd.read_csv(split_path, header=None)[0].str.split("[-_]", expand=True)[[1, 3]]
        split_measurements = set(split_df.itertuples(index=False, name=None))
        split_subjects = {sub for sub, _ in split_measurements}
        assert len(split_measurements) == len(split_subjects), "There must be at most one session for each subject"
        # Vectorised membership flags (instead of a per-row iterrows + .loc loop).
        df[f"meas_in_{task}_{split}"] = np.array(
            [key in split_measurements for key in measurement_keys], dtype=bool
        )
        df[f"subj_in_{task}_{split}"] = df["participant_id"].isin(split_subjects)

In [6]:
# Flag original-split membership for both prediction tasks (scan_age first, then birth_age).
for split_task in ("scan_age", "birth_age"):
    add_split_info(combined_df, task=split_task)

In [7]:
def add_has_connectome(df: pd.DataFrame, connectome_measurements: set[tuple[str, ...]]) -> None:
    """Add a boolean `has_connectome` column to `df`, in place.

    A row is flagged when its `(participant_id, str(session_id))` pair is a member of
    `connectome_measurements`. Any existing `has_connectome` column is overwritten.

    Args:
        df: table with `participant_id` and `session_id` columns.
        connectome_measurements: set of measurement-key tuples; only 2-tuples of the form
            `(participant_id, session_id_as_str)` can ever match a row.
    """
    # Vectorised membership test (instead of a per-row iterrows + .loc loop). The explicit
    # bool dtype keeps the column boolean even for an empty frame.
    row_keys = zip(df["participant_id"], df["session_id"].astype(str))
    df["has_connectome"] = np.array(
        [key in connectome_measurements for key in row_keys], dtype=bool
    )

In [8]:
# Flag every row whose (participant_id, session_id) pair appears among the connectome entries.
add_has_connectome(df=combined_df, connectome_measurements=connectome_measurements)

In [9]:
# no. of measurements that are both in the original split and have connectome
for task in ("birth_age", "scan_age"):
    for split in ("train", "val", "test"):
        in_split_col = f"meas_in_{task}_{split}"
        num = (combined_df[in_split_col] & combined_df["has_connectome"]).sum()
        print(f"{task}\t{split}:\t{num}")

birth_age	train:	371
birth_age	val:	51
birth_age	test:	49
scan_age	train:	387
scan_age	val:	52
scan_age	test:	51


In [10]:
# no. of measurements/unique subjects that were in neither split but have connectome
for task in ("birth_age", "scan_age"):
    subj_cols = [f"subj_in_{task}_{split_name}" for split_name in ("train", "val", "test")]
    # De Morgan: "not in train and not in val and not in test" == "not in any split".
    # The & builds a new Series, so the has_connectome column itself is never modified.
    mask = combined_df["has_connectome"] & ~combined_df[subj_cols].any(axis=1)
    print(f"{task}\t#meas:\t{mask.sum()}")
    n_unique = combined_df.loc[mask, "participant_id"].nunique()
    print(f"{task}\t#subs:\t{n_unique}")


birth_age	#meas:	149
birth_age	#subs:	136
scan_age	#meas:	139
scan_age	#subs:	118


In [11]:
def assign_splits(df: pd.DataFrame, task: str, n_samples: list[int]) -> None:
    """Assign measurements with a connectome to the train/val/test splits of `task`, in place.

    Adds (or overwrites) the boolean columns `conn_{task}_train`, `conn_{task}_val` and
    `conn_{task}_test`:
      * measurements that have a connectome and were in an original split of `task`
        (`meas_in_{task}_{split}`) are flagged for that split;
      * subjects with a connectome that are in no original split of `task` ("admissible")
        contribute one measurement each. These are randomly shuffled so that exactly
        `n_samples[0]`, `n_samples[1]` and `n_samples[2]` of them go to train, val and test.

    Args:
        df: table with `participant_id`, `scan_age`, `has_connectome`,
            `subj_in_{task}_{split}` and `meas_in_{task}_{split}` columns.
        task: for "birth_age" each subject's admissible session with the largest `scan_age`
            is used; for any other task the one with the smallest `scan_age`.
        n_samples: number of additional measurements for (train, val, test).

    Raises:
        ValueError: if `n_samples` does not have exactly three entries, or does not sum to
            the number of admissible subjects. Raised before `df` is modified.

    Randomness comes from NumPy's legacy global RNG (`np.random.shuffle`).
    """
    splits = ["train", "val", "test"]
    new_cols = [f"conn_{task}_{split}" for split in splits]

    # mask: connectomes whose subject isn't in any original split of that task
    # (=admissible measurements). copy() so the in-place &= never touches df["has_connectome"].
    mask = df["has_connectome"].copy()
    for split in splits:
        mask &= ~df[f"subj_in_{task}_{split}"]

    # When a subject has >1 admissible sessions, select one per subject based on the task:
    # BA - the later measurement, SA - the earlier one.
    sa_group = df[mask].groupby("participant_id")["scan_age"]
    additional_candidate_index = sa_group.idxmax() if task == "birth_age" else sa_group.idxmin()

    # Validate before touching df so a wrong `n_samples` fails with a clear message.
    if len(n_samples) != len(splits):
        raise ValueError(
            f"n_samples must have {len(splits)} entries (train, val, test), got {len(n_samples)}"
        )
    if sum(n_samples) != len(additional_candidate_index):
        raise ValueError(
            f"n_samples sums to {sum(n_samples)}, but there are "
            f"{len(additional_candidate_index)} admissible subjects for task {task!r}"
        )

    # One shuffled split label per selected measurement.
    split_labels = np.repeat(splits, n_samples)
    np.random.shuffle(split_labels)

    # One-hot encode explicitly (rather than with pd.get_dummies) so that all three columns
    # always exist with bool dtype, even when a split receives zero additional samples.
    df[new_cols] = False
    df.loc[additional_candidate_index, new_cols] = np.column_stack(
        [split_labels == split for split in splits]
    )

    # Connectomes whose measurement was in an original split get added automatically.
    for split, col in zip(splits, new_cols):
        is_conn_in_original = df["has_connectome"] & df[f"meas_in_{task}_{split}"]
        df.loc[is_conn_in_original, col] = True


In [22]:
# number of additional (admissible) measurements to assign to (train, val, test), per task
split_additional_samples = dict(
    birth_age=[115, 9, 12],
    scan_age=[99, 9, 10],
)

for task, n_additional in split_additional_samples.items():
    assign_splits(combined_df, task=task, n_samples=n_additional)
    # pretty print: size of each split and its share of all conn_{task} assignments
    n_total_in_split = combined_df.filter(like=f"conn_{task}").to_numpy().sum()
    for split in ("train", "val", "test"):
        samples = combined_df[f"conn_{task}_{split}"].sum()
        ratio = samples / n_total_in_split
        print(f"{task}\t{split}\t{samples}\t[{ratio:.2f}]")


birth_age	train	486	[0.80]
birth_age	val	60	[0.10]
birth_age	test	61	[0.10]
scan_age	train	486	[0.80]
scan_age	val	61	[0.10]
scan_age	test	61	[0.10]


In [36]:
def save_split(
    df: "pd.DataFrame | None" = None,
    out_dir: str = "splits",
    tasks: tuple[str, ...] = ("birth_age", "scan_age"),
) -> None:
    """Write the connectome-based splits to text files, one file per (task, split).

    For every task in `tasks` and split in ("train", "val", "test"), the rows of `df` whose
    boolean column `conn_{task}_{split}` is True are written to
    `{out_dir}/connectome_{task}_{split}.txt`, one `sub-<participant_id>_ses-<session_id>`
    per line, with no header and no index.

    Args:
        df: table with `participant_id`, `session_id` and `conn_{task}_{split}` columns.
            Defaults to the notebook-global `combined_df` (the previous implicit behaviour).
        out_dir: output directory; created if it does not exist.
        tasks: task names to export.
    """
    if df is None:
        df = combined_df  # notebook global, kept as the default for backward compatibility
    os.makedirs(out_dir, exist_ok=True)
    for task in tasks:
        for split in ["train", "val", "test"]:
            split_df = df[df[f"conn_{task}_{split}"]]
            series = "sub-" + split_df["participant_id"] + "_ses-" + split_df["session_id"].astype(str)
            out_path = os.path.join(out_dir, f"connectome_{task}_{split}.txt")
            series.to_csv(out_path, header=False, index=False)

In [37]:
# Export the connectome splits: writes splits/connectome_{task}_{split}.txt for
# birth_age and scan_age, each with a train, val and test file.
save_split()