ENH: Add dataset fetchers #76

Merged 5 commits on Jun 22, 2021
183 changes: 183 additions & 0 deletions afqinsight/datasets.py
@@ -1,15 +1,20 @@
"""Generate samples of synthetic data sets or extract AFQ data."""
import hashlib
import numpy as np
import os
import os.path as op
import pandas as pd
import requests

from groupyr.transform import GroupRemover
from collections import namedtuple
from shutil import copyfile
from sklearn.preprocessing import LabelEncoder

from .transform import AFQDataFrameMapper

__all__ = ["load_afq_data", "output_beta_to_afq", "fetch_sarica", "fetch_weston_havens"]
DATA_DIR = op.join(op.expanduser("~"), ".afq-insight")


def load_afq_data(
@@ -316,3 +321,181 @@ def output_beta_to_afq(
    OutputFiles = namedtuple("OutputFiles", "nodes_file subjects_file")

    return OutputFiles(nodes_file=fn_nodes_out, subjects_file=fn_subjects_out)


def _download_url_to_file(url, output_fn, encoding="utf-8"):
    """Download `url` to `output_fn`, skipping the download when a cached
    copy matches the checksum stored in a sidecar ``*.md5`` file."""
    fn_abs = op.abspath(output_fn)
    base = op.splitext(fn_abs)[0]
    # Use the absolute path so makedirs always receives a real directory name
    os.makedirs(op.dirname(fn_abs), exist_ok=True)

    def _md5(path):
        with open(path, "rb") as fp:
            return hashlib.md5(fp.read()).hexdigest()

    # Read the expected hash from the sidecar *.md5 file, if it exists
    if op.isfile(base + ".md5"):
        with open(base + ".md5", "r") as md5file:
            md5sum = md5file.read().strip()
    else:
        md5sum = None

    # Skip the download if the cached file matches the stored MD5 hash
    if op.isfile(fn_abs) and _md5(fn_abs) == md5sum:
        print(f"File {op.relpath(fn_abs)} exists.")
    else:
        print(f"Downloading {url} to {op.relpath(fn_abs)}.")
        # Download from url and save to file, failing loudly on HTTP errors
        with requests.Session() as s:
            download = s.get(url)
            download.raise_for_status()
        with open(fn_abs, "w") as fp:
            fp.write(download.content.decode(encoding))

        # Write the MD5 checksum of the fresh download to the sidecar file
        with open(base + ".md5", "w") as md5file:
            md5file.write(_md5(fn_abs))
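
For illustration, a minimal sketch of the caching behavior; the URL and path below are hypothetical placeholders, not part of this diff. The first call downloads and writes a sidecar checksum; a repeated call finds the matching MD5 and skips the network.

# Hypothetical example: download once, then hit the MD5 cache.
example_url = "https://example.com/data/nodes.csv"  # placeholder URL
example_file = op.join(DATA_DIR, "example_data", "nodes.csv")  # placeholder path
_download_url_to_file(example_url, example_file)  # downloads, writes nodes.md5
_download_url_to_file(example_url, example_file)  # prints "File ... exists." and skips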


def _download_afq_dataset(dataset, data_home):
    """Download the nodes and subjects CSV files for the named dataset."""
    urls_files = {
        "sarica": [
            {
                "url": "https://github.com/yeatmanlab/Sarica_2017/raw/gh-pages/data/nodes.csv",
                "file": op.join(data_home, "sarica_data", "nodes.csv"),
            },
            {
                "url": "https://github.com/yeatmanlab/Sarica_2017/raw/gh-pages/data/subjects.csv",
                "file": op.join(data_home, "sarica_data", "subjects.csv"),
            },
        ],
        "weston_havens": [
            {
                "url": "https://yeatmanlab.github.io/AFQBrowser-demo/data/nodes.csv",
                "file": op.join(data_home, "weston_havens_data", "nodes.csv"),
            },
            {
                "url": "https://yeatmanlab.github.io/AFQBrowser-demo/data/subjects.csv",
                "file": op.join(data_home, "weston_havens_data", "subjects.csv"),
            },
        ],
    }

    for dict_ in urls_files[dataset]:
        _download_url_to_file(dict_["url"], dict_["file"])
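
For reference, a sketch of the on-disk layout these helpers produce, assuming the default cache directory; each CSV gets an MD5 sidecar from `_download_url_to_file`:

~/.afq-insight/
├── sarica_data/
│   ├── nodes.csv
│   ├── nodes.md5
│   ├── subjects.csv
│   └── subjects.md5
└── weston_havens_data/
    ├── nodes.csv
    ├── nodes.md5
    ├── subjects.csv
    └── subjects.md5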


def fetch_sarica(data_home=None):
    """Fetch the ALS classification dataset from Sarica et al. [1]_.

    Parameters
    ----------
    data_home : str, default=None
        Specify another download and cache folder for the datasets. By
        default, all afq-insight data is stored in '~/.afq-insight'
        subfolders.

    Returns
    -------
    X : array-like of shape (48, 3600)
        The feature samples.

    y : array-like of shape (48,)
        Target values.

    groups : list of numpy.ndarray
        The feature indices for each feature group.

    feature_names : list of tuples
        The multi-indexed columns of X.

    group_names : list of tuples
        The multi-indexed groups of X.

    subjects : list
        Subject IDs.

    classes : dict
        Class labels for ALS diagnosis.

    References
    ----------
    .. [1] Alessia Sarica, et al.
        "The Corticospinal Tract Profile in Amyotrophic Lateral Sclerosis"
        Human Brain Mapping, vol. 38, pp. 727-739, 2017
        DOI: 10.1002/hbm.23412
    """
    data_home = data_home if data_home is not None else DATA_DIR
    _download_afq_dataset("sarica", data_home=data_home)
    X, y, groups, feature_names, group_names, subjects, classes = load_afq_data(
        workdir=op.join(data_home, "sarica_data"),
        dwi_metrics=["md", "fa"],
        target_cols=["class"],
        label_encode_cols=["class"],
    )

    # Remove the Cingulum Hippocampus bundles from the feature matrix
    gr = GroupRemover(
        select=["Right Cingulum Hippocampus", "Left Cingulum Hippocampus"],
        groups=groups,
        group_names=group_names,
    )
    X = gr.fit_transform(X)

    # Drop the removed bundles from the group and feature metadata as well;
    # 36 of the original 40 groups (2 metrics x 20 bundles) remain.
    groups = groups[:36]
    group_names = [grp for grp in group_names if "Cingulum Hippocampus" not in grp[1]]
    feature_names = [fn for fn in feature_names if "Cingulum Hippocampus" not in fn[1]]

    return X, y, groups, feature_names, group_names, subjects, classes
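
A short usage sketch; the shapes mirror the assertions in the tests below, and the first call downloads the CSV files into `~/.afq-insight/sarica_data`:

X, y, groups, feature_names, group_names, subjects, classes = fetch_sarica()
# X.shape == (48, 3600): 48 subjects x (36 groups x 100 nodes per profile)
# y.shape == (48,): label-encoded ALS diagnosis, with the mapping in `classes`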


def fetch_weston_havens(data_home=None):
    """Fetch the age prediction dataset from Weston-Havens [1]_.

    Parameters
    ----------
    data_home : str, default=None
        Specify another download and cache folder for the datasets. By
        default, all afq-insight data is stored in '~/.afq-insight'
        subfolders.

    Returns
    -------
    X : array-like of shape (77, 3600)
        The feature samples.

    y : array-like of shape (77,)
        Target values (subject age).

    groups : list of numpy.ndarray
        The feature indices for each feature group.

    feature_names : list of tuples
        The multi-indexed columns of X.

    group_names : list of tuples
        The multi-indexed groups of X.

    subjects : list
        Subject IDs.

    References
    ----------
    .. [1] Jason D. Yeatman, Brian A. Wandell, & Aviv A. Mezer,
        "Lifespan maturation and degeneration of human brain white matter"
        Nature Communications, vol. 5:1, pp. 4932, 2014
        DOI: 10.1038/ncomms5932
    """
    data_home = data_home if data_home is not None else DATA_DIR
    _download_afq_dataset("weston_havens", data_home=data_home)
    X, y, groups, feature_names, group_names, subjects, classes = load_afq_data(
        workdir=op.join(data_home, "weston_havens_data"),
        dwi_metrics=["md", "fa"],
        target_cols=["Age"],
    )

    # Remove the Cingulum Hippocampus bundles from the feature matrix
    gr = GroupRemover(
        select=["Right Cingulum Hippocampus", "Left Cingulum Hippocampus"],
        groups=groups,
        group_names=group_names,
    )
    X = gr.fit_transform(X)

    # Drop the removed bundles from the group and feature metadata as well;
    # 36 of the original 40 groups (2 metrics x 20 bundles) remain.
    groups = groups[:36]
    group_names = [grp for grp in group_names if "Cingulum Hippocampus" not in grp[1]]
    feature_names = [fn for fn in feature_names if "Cingulum Hippocampus" not in fn[1]]

    return X, y, groups, feature_names, group_names, subjects
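
And similarly for the age-prediction dataset; a minimal sketch, with shapes again matching the test assertions below:

X, y, groups, feature_names, group_names, subjects = fetch_weston_havens()
# X.shape == (77, 3600); y.shape == (77,) and holds each subject's age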
35 changes: 34 additions & 1 deletion afqinsight/tests/test_datasets.py
@@ -2,14 +2,47 @@
import os.path as op
import pandas as pd
import pytest
import tempfile

import afqinsight as afqi
from afqinsight.datasets import load_afq_data
from afqinsight.datasets import load_afq_data, fetch_sarica, fetch_weston_havens

data_path = op.join(afqi.__path__[0], "data")
test_data_path = op.join(data_path, "test_data")


def test_fetch():
    X, y, groups, feature_names, group_names, subjects, _ = fetch_sarica()
    assert X.shape == (48, 3600)
    assert y.shape == (48,)
    assert len(groups) == 36
    assert len(feature_names) == 3600
    assert len(group_names) == 36
    assert len(subjects) == 48
    assert op.isfile(op.join(afqi.datasets.DATA_DIR, "sarica_data", "nodes.csv"))
    assert op.isfile(op.join(afqi.datasets.DATA_DIR, "sarica_data", "subjects.csv"))

    X, y, groups, feature_names, group_names, subjects = fetch_weston_havens()
    assert X.shape == (77, 3600)
    assert y.shape == (77,)
    assert len(groups) == 36
    assert len(feature_names) == 3600
    assert len(group_names) == 36
    assert len(subjects) == 77
    assert op.isfile(op.join(afqi.datasets.DATA_DIR, "weston_havens_data", "nodes.csv"))
    assert op.isfile(
        op.join(afqi.datasets.DATA_DIR, "weston_havens_data", "subjects.csv")
    )

    with tempfile.TemporaryDirectory() as td:
        _ = fetch_sarica(data_home=td)
        _ = fetch_weston_havens(data_home=td)
        assert op.isfile(op.join(td, "sarica_data", "nodes.csv"))
        assert op.isfile(op.join(td, "sarica_data", "subjects.csv"))
        assert op.isfile(op.join(td, "weston_havens_data", "nodes.csv"))
        assert op.isfile(op.join(td, "weston_havens_data", "subjects.csv"))


def test_load_afq_data_smoke():
    output = load_afq_data(
        workdir=test_data_path,
1 change: 1 addition & 0 deletions setup.cfg
@@ -36,6 +36,7 @@ install_requires =
    matplotlib
    numpy
    pandas>=1.1.0
    requests
    seaborn
    scikit-learn>=0.23.1,<0.24
    sklearn_pandas>=2.0.0
1 change: 1 addition & 0 deletions tox.ini
@@ -14,6 +14,7 @@ deps =
    matplotlib
    numpy
    pandas>=1.1.0
    requests
    seaborn
    scipy>=1.2.0,<1.6.0
    sklearn_pandas>=2.0.0