Skip to content

Commit

Permalink
Add enforce_sub_prefix option to AFQDataset and load_afq_data
Browse files Browse the repository at this point in the history
  • Loading branch information
richford committed Dec 2, 2021
1 parent c767485 commit 8352ece
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
39 changes: 38 additions & 1 deletion afqinsight/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,22 @@ def bundles2channels(X, n_nodes, n_channels, channels_last=True):
return output


def standardize_subject_id(sub_id):
    """Standardize a subject ID to start with the prefix "sub-".

    Parameters
    ----------
    sub_id : str
        Subject ID.

    Returns
    -------
    str
        The subject ID, prefixed with "sub-" if it was not already.
    """
    prefix = "sub-"
    if sub_id.startswith(prefix):
        return sub_id
    return prefix + sub_id


def load_afq_data(
fn_nodes="nodes.csv",
fn_subjects="subjects.csv",
Expand All @@ -103,6 +119,7 @@ def load_afq_data(
unsupervised=False,
concat_subject_session=False,
return_bundle_means=False,
enforce_sub_prefix=True,
):
"""Load AFQ data from CSV, transform it, return feature matrix and target.
Expand Down Expand Up @@ -154,6 +171,12 @@ def load_afq_data(
If True, return diffusion metrics averaged along the length of each
bundle.
enforce_sub_prefix : bool, default=True
If True, standardize subject IDs to start with the prefix "sub-".
This is useful, for example, if the subject IDs in the nodes.csv file
have the sub prefix but the subject IDs in the subjects.csv file do
not. Default is True in order to conform to the BIDS standard.
Returns
-------
AFQData : namedtuple
Expand Down Expand Up @@ -219,7 +242,10 @@ def load_afq_data(
mapper = AFQDataFrameMapper(concat_subject_session=concat_subject_session)

X = mapper.fit_transform(nodes)
subjects = mapper.subjects_
subjects = [
standardize_subject_id(sub_id) if enforce_sub_prefix else sub_id
for sub_id in mapper.subjects_
]
groups = mapper.groups_
feature_names = mapper.feature_names_

Expand Down Expand Up @@ -250,6 +276,9 @@ def load_afq_data(
unnamed_cols = [col for col in targets.columns if "Unnamed:" in col]
targets.drop(unnamed_cols, axis="columns", inplace=True)

if enforce_sub_prefix:
targets.index = targets.index.map(standardize_subject_id)

# Drop subjects that are not in the dwi feature matrix
targets = pd.DataFrame(index=subjects).merge(
targets, how="left", left_index=True, right_index=True
Expand Down Expand Up @@ -382,6 +411,12 @@ class AFQDataset:
IDs with the session IDs. This is useful when subjects have multiple
sessions and you wish to disambiguate between them.
enforce_sub_prefix : bool, default=True
If True, standardize subject IDs to start with the prefix "sub-".
This is useful, for example, if the subject IDs in the nodes.csv file
have the sub prefix but the subject IDs in the subjects.csv file do
not. Default is True in order to conform to the BIDS standard.
Attributes
----------
X : array-like of shape (n_samples, n_features)
Expand Down Expand Up @@ -424,6 +459,7 @@ def __init__(
index_col="subjectID",
unsupervised=False,
concat_subject_session=False,
enforce_sub_prefix=True,
):
afq_data = load_afq_data(
fn_nodes=fn_nodes,
Expand All @@ -434,6 +470,7 @@ def __init__(
index_col=index_col,
unsupervised=unsupervised,
concat_subject_session=concat_subject_session,
enforce_sub_prefix=enforce_sub_prefix,
)

self.X = afq_data.X
Expand Down
12 changes: 11 additions & 1 deletion afqinsight/tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
download_sarica,
download_weston_havens,
AFQDataset,
standardize_subject_id,
)

data_path = op.join(afqi.__path__[0], "data")
Expand All @@ -32,6 +33,11 @@ def test_bundles2channels():
bundles2channels(X0, n_nodes=1000, n_channels=7)


def test_standardize_subject_id():
    """Test that the "sub-" prefix is added only when missing."""
    # Already-prefixed IDs pass through unchanged (idempotence)
    assert standardize_subject_id("sub-01") == "sub-01"
    # Bare IDs gain the prefix
    assert standardize_subject_id("01") == "sub-01"
    # Edge case: the empty string still gains the prefix
    assert standardize_subject_id("") == "sub-"


@pytest.mark.parametrize("target_cols", [["class"], ["age", "class"]])
def test_AFQDataset(target_cols):
sarica_dir = download_sarica()
Expand Down Expand Up @@ -148,7 +154,8 @@ def test_AFQDataset(target_cols):


@pytest.mark.parametrize("dwi_metrics", [["md", "fa"], None])
def test_fetch(dwi_metrics):
@pytest.mark.parametrize("enforce_sub_prefix", [True, False])
def test_fetch(dwi_metrics, enforce_sub_prefix):
sarica_dir = download_sarica()

with pytest.raises(ValueError):
Expand All @@ -167,6 +174,7 @@ def test_fetch(dwi_metrics):
dwi_metrics=dwi_metrics,
target_cols=["class"],
label_encode_cols=["class"],
enforce_sub_prefix=enforce_sub_prefix,
)

n_features = 16000 if dwi_metrics is None else 4000
Expand Down Expand Up @@ -259,6 +267,7 @@ def test_load_afq_data(dwi_metrics):
target_cols=["test_class"],
label_encode_cols=["test_class"],
return_bundle_means=False,
enforce_sub_prefix=False,
)

nodes = pd.read_csv(op.join(test_data_path, "nodes.csv"))
Expand All @@ -285,6 +294,7 @@ def test_load_afq_data(dwi_metrics):
target_cols=["test_class"],
label_encode_cols=["test_class"],
return_bundle_means=True,
enforce_sub_prefix=False,
)

means_ref = (
Expand Down

0 comments on commit 8352ece

Please sign in to comment.