
Merge pull request #76 from richford/enh/datasets
ENH: Add dataset fetchers
richford committed Jun 22, 2021
2 parents a97a74b + 629377a commit 0e7e7d5
Showing 4 changed files with 219 additions and 1 deletion.
183 changes: 183 additions & 0 deletions afqinsight/datasets.py
@@ -1,15 +1,20 @@
"""Generate samples of synthetic data sets or extract AFQ data."""
import hashlib
import numpy as np
import os
import os.path as op
import pandas as pd
import requests

from groupyr.transform import GroupRemover
from collections import namedtuple
from shutil import copyfile
from sklearn.preprocessing import LabelEncoder

from .transform import AFQDataFrameMapper

__all__ = ["load_afq_data", "output_beta_to_afq"]
DATA_DIR = op.join(op.expanduser("~"), ".afq-insight")


def load_afq_data(
@@ -324,3 +329,181 @@ def output_beta_to_afq(
    OutputFiles = namedtuple("OutputFiles", "nodes_file subjects_file")

    return OutputFiles(nodes_file=fn_nodes_out, subjects_file=fn_subjects_out)


def _download_url_to_file(url, output_fn, encoding="utf-8"):
    fn_abs = op.abspath(output_fn)
    base = op.splitext(fn_abs)[0]
    os.makedirs(op.dirname(fn_abs), exist_ok=True)

    # Read the expected checksum from the *.md5 sidecar file, if it exists
    if op.isfile(base + ".md5"):
        with open(base + ".md5", "r") as md5file:
            md5sum = md5file.read().replace("\n", "")
    else:
        md5sum = None

    # Skip the download if the file exists and matches the MD5 checksum
    if (
        op.isfile(fn_abs)
        and hashlib.md5(open(fn_abs, "rb").read()).hexdigest() == md5sum
    ):
        print(f"File {op.relpath(fn_abs)} exists.")
    else:
        print(f"Downloading {url} to {op.relpath(fn_abs)}.")
        # Download from url and save to file
        with requests.Session() as s:
            download = s.get(url)
            with open(fn_abs, "w") as fp:
                fp.write(download.content.decode(encoding))

        # Write MD5 checksum to file
        with open(base + ".md5", "w") as md5file:
            md5file.write(hashlib.md5(open(fn_abs, "rb").read()).hexdigest())
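
The helper above implements a small download cache: each file gets an *.md5 sidecar, and the download is skipped whenever the existing file's checksum matches. A minimal sketch of a call (the URL and output path here are hypothetical):

_download_url_to_file(
    "https://example.com/data/nodes.csv",
    op.join(DATA_DIR, "example_data", "nodes.csv"),
)
# First call downloads the file and writes nodes.md5 next to it;
# later calls see a matching checksum and skip the download.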


def _download_afq_dataset(dataset, data_home):
    urls_files = {
        "sarica": [
            {
                "url": "https://github.com/yeatmanlab/Sarica_2017/raw/gh-pages/data/nodes.csv",
                "file": op.join(data_home, "sarica_data", "nodes.csv"),
            },
            {
                "url": "https://github.com/yeatmanlab/Sarica_2017/raw/gh-pages/data/subjects.csv",
                "file": op.join(data_home, "sarica_data", "subjects.csv"),
            },
        ],
        "weston_havens": [
            {
                "url": "https://yeatmanlab.github.io/AFQBrowser-demo/data/nodes.csv",
                "file": op.join(data_home, "weston_havens_data", "nodes.csv"),
            },
            {
                "url": "https://yeatmanlab.github.io/AFQBrowser-demo/data/subjects.csv",
                "file": op.join(data_home, "weston_havens_data", "subjects.csv"),
            },
        ],
    }

    for dict_ in urls_files[dataset]:
        _download_url_to_file(dict_["url"], dict_["file"])
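
Registering another dataset would only require one more entry in urls_files; a hypothetical sketch (dataset name and URL are made up):

# "my_dataset": [
#     {
#         "url": "https://example.com/data/nodes.csv",
#         "file": op.join(data_home, "my_dataset_data", "nodes.csv"),
#     },
# ],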


def fetch_sarica(data_home=None):
    """Fetch the ALS classification dataset from Sarica et al. [1]_.

    Parameters
    ----------
    data_home : str, default=None
        Specify another download and cache folder for the datasets. By default,
        all afq-insight data is stored in "~/.afq-insight" subfolders.

    Returns
    -------
    X : array-like of shape (48, 3600)
        The feature samples.
    y : array-like of shape (48,)
        Target values.
    groups : list of numpy.ndarray
        Feature indices for each feature group.
    feature_names : list of tuples
        The multi-indexed columns of X.
    group_names : list of tuples
        The multi-indexed groups of X.
    subjects : list
        Subject IDs.
    classes : dict
        Class labels for ALS diagnosis.

    References
    ----------
    .. [1] Alessia Sarica, et al.
        "The Corticospinal Tract Profile in Amyotrophic Lateral Sclerosis"
        Human Brain Mapping, vol. 38, pp. 727-739, 2017
        DOI: 10.1002/hbm.23412
    """
    data_home = data_home if data_home is not None else DATA_DIR
    _download_afq_dataset("sarica", data_home=data_home)
    X, y, groups, feature_names, group_names, subjects, classes = load_afq_data(
        workdir=op.join(data_home, "sarica_data"),
        dwi_metrics=["md", "fa"],
        target_cols=["class"],
        label_encode_cols=["class"],
    )

    gr = GroupRemover(
        select=["Right Cingulum Hippocampus", "Left Cingulum Hippocampus"],
        groups=groups,
        group_names=group_names,
    )
    X = gr.fit_transform(X)

    # Keep the 36 groups (18 remaining tracts x 2 dwi metrics) left after
    # dropping the Cingulum Hippocampus bundles, and filter the metadata to match
    groups = groups[:36]
    group_names = [grp for grp in group_names if "Cingulum Hippocampus" not in grp[1]]
    feature_names = [fn for fn in feature_names if "Cingulum Hippocampus" not in fn[1]]

    return X, y, groups, feature_names, group_names, subjects, classes
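
A minimal sketch of calling the fetcher (files land under ~/.afq-insight/sarica_data unless data_home is given). The shapes follow from the grouping: 18 remaining tracts x 2 metrics = 36 groups of 100 nodes each, i.e. 3600 features:

from afqinsight.datasets import fetch_sarica

X, y, groups, feature_names, group_names, subjects, classes = fetch_sarica()
# X.shape == (48, 3600); y holds the label-encoded ALS diagnosis for 48 subjects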


def fetch_weston_havens(data_home=None):
    """Load the age prediction dataset from Weston-Havens [1]_.

    Parameters
    ----------
    data_home : str, default=None
        Specify another download and cache folder for the datasets. By default,
        all afq-insight data is stored in "~/.afq-insight" subfolders.

    Returns
    -------
    X : array-like of shape (77, 3600)
        The feature samples.
    y : array-like of shape (77,)
        Target values (subject ages).
    groups : list of numpy.ndarray
        Feature indices for each feature group.
    feature_names : list of tuples
        The multi-indexed columns of X.
    group_names : list of tuples
        The multi-indexed groups of X.
    subjects : list
        Subject IDs.

    References
    ----------
    .. [1] Jason D. Yeatman, Brian A. Wandell, & Aviv A. Mezer,
        "Lifespan maturation and degeneration of human brain white matter"
        Nature Communications, vol. 5:1, pp. 4932, 2014
        DOI: 10.1038/ncomms5932
    """
    data_home = data_home if data_home is not None else DATA_DIR
    _download_afq_dataset("weston_havens", data_home=data_home)
    X, y, groups, feature_names, group_names, subjects, classes = load_afq_data(
        workdir=op.join(data_home, "weston_havens_data"),
        dwi_metrics=["md", "fa"],
        target_cols=["Age"],
    )

    gr = GroupRemover(
        select=["Right Cingulum Hippocampus", "Left Cingulum Hippocampus"],
        groups=groups,
        group_names=group_names,
    )
    X = gr.fit_transform(X)

    # Keep the 36 groups (18 remaining tracts x 2 dwi metrics) left after
    # dropping the Cingulum Hippocampus bundles, and filter the metadata to match
    groups = groups[:36]
    group_names = [grp for grp in group_names if "Cingulum Hippocampus" not in grp[1]]
    feature_names = [fn for fn in feature_names if "Cingulum Hippocampus" not in fn[1]]

    return X, y, groups, feature_names, group_names, subjects
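
The age-prediction fetcher mirrors fetch_sarica, but its target is continuous and no classes mapping is returned:

from afqinsight.datasets import fetch_weston_havens

X, y, groups, feature_names, group_names, subjects = fetch_weston_havens()
# X.shape == (77, 3600); y holds subject ages, making this a regression dataset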
35 changes: 34 additions & 1 deletion afqinsight/tests/test_datasets.py
@@ -2,14 +2,47 @@
import os.path as op
import pandas as pd
import pytest
import tempfile

import afqinsight as afqi
from afqinsight.datasets import load_afq_data
from afqinsight.datasets import load_afq_data, fetch_sarica, fetch_weston_havens

data_path = op.join(afqi.__path__[0], "data")
test_data_path = op.join(data_path, "test_data")


def test_fetch():
    X, y, groups, feature_names, group_names, subjects, _ = fetch_sarica()
    assert X.shape == (48, 3600)
    assert y.shape == (48,)
    assert len(groups) == 36
    assert len(feature_names) == 3600
    assert len(group_names) == 36
    assert len(subjects) == 48
    assert op.isfile(op.join(afqi.datasets.DATA_DIR, "sarica_data", "nodes.csv"))
    assert op.isfile(op.join(afqi.datasets.DATA_DIR, "sarica_data", "subjects.csv"))

    X, y, groups, feature_names, group_names, subjects = fetch_weston_havens()
    assert X.shape == (77, 3600)
    assert y.shape == (77,)
    assert len(groups) == 36
    assert len(feature_names) == 3600
    assert len(group_names) == 36
    assert len(subjects) == 77
    assert op.isfile(op.join(afqi.datasets.DATA_DIR, "weston_havens_data", "nodes.csv"))
    assert op.isfile(
        op.join(afqi.datasets.DATA_DIR, "weston_havens_data", "subjects.csv")
    )

    with tempfile.TemporaryDirectory() as td:
        _ = fetch_sarica(data_home=td)
        _ = fetch_weston_havens(data_home=td)
        assert op.isfile(op.join(td, "sarica_data", "nodes.csv"))
        assert op.isfile(op.join(td, "sarica_data", "subjects.csv"))
        assert op.isfile(op.join(td, "weston_havens_data", "nodes.csv"))
        assert op.isfile(op.join(td, "weston_havens_data", "subjects.csv"))
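
These assertions download the real CSVs from GitHub, so the test needs network access. One possible guard, not part of this commit, would be a module-level skip when the network is unreachable:

# Hypothetical guard (not in this diff):
# import socket
# def _online(host="github.com", port=443, timeout=3):
#     try:
#         socket.create_connection((host, port), timeout=timeout).close()
#         return True
#     except OSError:
#         return False
# pytestmark = pytest.mark.skipif(not _online(), reason="requires network access")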


def test_load_afq_data_smoke():
    output = load_afq_data(
        workdir=test_data_path,
1 change: 1 addition & 0 deletions setup.cfg
@@ -36,6 +36,7 @@ install_requires =
    matplotlib
    numpy
    pandas>=1.1.0
    requests
    seaborn
    scikit-learn>=0.23.1,<0.24
    sklearn_pandas>=2.0.0
1 change: 1 addition & 0 deletions tox.ini
@@ -14,6 +14,7 @@ deps =
    matplotlib
    numpy
    pandas>=1.1.0
    requests
    seaborn
    scipy>=1.2.0,<1.6.0
    sklearn_pandas>=2.0.0
