Skip to content

Commit

Permalink
Merge pull request #2993 from Cadair/sample_data_dir
Browse files Browse the repository at this point in the history
Improvements to sample and download data directories
  • Loading branch information
Cadair committed Apr 4, 2019
2 parents 71acc4c + c69ebcb commit dc0c8cf
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 105 deletions.
2 changes: 2 additions & 0 deletions changelog/2993.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
The default location of the sunpy sample data has changed to the platform-specific
user data directory, as provided by `appdirs <https://github.com/ActiveState/appdirs>`__.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ install_requires =
matplotlib>=1.3
pandas
astropy>=3.1
parfive[ftp]

[options.extras_require]
database = sqlalchemy
Expand All @@ -36,7 +37,6 @@ net =
python-dateutil
zeep
tqdm
parfive[ftp]
asdf = asdf
tests =
hypothesis
Expand Down
23 changes: 23 additions & 0 deletions sunpy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pathlib
import warnings
import importlib
import tempfile

import pytest

Expand Down Expand Up @@ -56,6 +57,28 @@ def undo_config_dir_patch():
os.environ["SUNPY_CONFIGDIR"] = oridir


@pytest.fixture(scope='session', autouse=True)
def tmp_dl_dir(request):
    """
    Globally set the default download directory for the test run to a tmp dir.

    For the duration of the session the ``SUNPY_DOWNLOADDIR`` environment
    variable points at a freshly created temporary directory, whose path is
    yielded. On teardown the variable is put back to the value it had before
    the session started (or removed if it was not set) and the temporary
    directory is deleted.
    """
    # Remember any value that was already exported so the test run does not
    # clobber it for whatever runs after the session.
    previous = os.environ.get("SUNPY_DOWNLOADDIR")
    with tempfile.TemporaryDirectory() as tmpdir:
        os.environ["SUNPY_DOWNLOADDIR"] = tmpdir
        try:
            yield tmpdir
        finally:
            if previous is None:
                # pop() rather than del: the variable may already have been
                # removed by the time teardown runs.
                os.environ.pop("SUNPY_DOWNLOADDIR", None)
            else:
                os.environ["SUNPY_DOWNLOADDIR"] = previous


@pytest.fixture()
def undo_download_dir_patch():
    """
    Provide a way for certain tests to not have tmp download dir.

    While the requesting test runs, ``SUNPY_DOWNLOADDIR`` is absent from the
    environment; the value that was removed is written back afterwards.
    """
    # Read and remove the variable in one step; raises KeyError if it is unset.
    patched_dir = os.environ.pop("SUNPY_DOWNLOADDIR")
    yield
    os.environ["SUNPY_DOWNLOADDIR"] = patched_dir


def pytest_runtest_setup(item):
"""
pytest hook to skip all tests that have the mark 'remotedata' if the
Expand Down
139 changes: 52 additions & 87 deletions sunpy/data/_sample.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
# -*- coding: utf-8 -*-
"""SunPy sample data files"""
import socket
import os.path
import warnings
from shutil import move
from zipfile import ZipFile
from pathlib import Path
from collections import namedtuple
from urllib.parse import urljoin

from astropy.utils.data import download_file
import parfive

from sunpy.util.net import url_exists
from sunpy.util.config import get_and_create_sample_dir
from sunpy.util.exceptions import SunpyUserWarning

__author__ = "Steven Christe"
__email__ = "steven.christe@nasa.gov"

_base_urls = (
'http://data.sunpy.org/sunpy/v1/',
'https://github.com/sunpy/sample-data/raw/master/sunpy/v1/'
'https://github.com/sunpy/sample-data/raw/master/sunpy/v1/',
)

# Shortcut requirements:
Expand All @@ -31,6 +25,8 @@

# the files should include necessary extensions
_sample_files = {
# Do roll image first because it's the largest file.
"AIA_171_ROLL_IMAGE": "aiacalibim5.fits.gz",
"HMI_LOS_IMAGE": "HMI20110607_063211_los_lowres.fits",
"AIA_131_IMAGE": "AIA20110607_063301_0131_lowres.fits",
"AIA_171_IMAGE": "AIA20110607_063302_0171_lowres.fits",
Expand All @@ -50,7 +46,6 @@
# Not in the sample-data repo
# "RHESSI_EVENT_LIST": "hsi_calib_ev_20020220_1106_20020220_1106_25_40.fits",
"SWAP_LEVEL1_IMAGE": "swap_lv1_20110607_063329.fits",
"AIA_171_ROLL_IMAGE": "aiacalibim5.fits.gz",
"EVE_TIMESERIES": "20110607_EVE_L0CS_DIODES_1m.txt",
# Uncomment this if it needs to be used. Commented out to save bandwidth.
# "LYRA_LIGHTCURVE": ("lyra_20110810-000000_lev2_std.fits.gz", ,
Expand All @@ -63,96 +58,66 @@
"NORH_TIMESERIES": "tca110607.fits"
}

# Reverse the dict because we want to use it backwards, but it is nicer to
# write the other way around
_sample_files = {v: k for k, v in _sample_files.items()}

def download_sample_data(show_progress=True):
_error = namedtuple("error", ("filepath_partial", "url", "response"))

def download_sample_data(overwrite=False):
"""
Download all sample data at once. This will overwrite any existing files.
Parameters
----------
show_progress: `bool`
Show a progress bar during download
overwrite: `bool`
Overwrite existing sample data.
Returns
-------
None
"""
for file_name in _sample_files.value():
get_sample_file(file_name, show_progress=show_progress,
url_list=_base_urls, overwrite=True)
# Creating the directory for sample files to be downloaded
sampledata_dir = Path(get_and_create_sample_dir())

dl = parfive.Downloader(overwrite=overwrite)

def get_sample_file(filename, url_list, show_progress=True, overwrite=False,
timeout=None):
"""
Downloads a sample file. Will download a sample data file and move it to
the sample data directory. Also, uncompresses zip files if necessary.
Returns the local file if exists.
first_url = _base_urls[0]

Parameters
----------
filename: `str`
Name of the file
url_list: `str` or `list`
urls where to look for the file
show_progress: `bool`
Show a progress bar during download
overwrite: `bool`
If True download and overwrite an existing file.
timeout: `float`
The timeout in seconds. If `None` the default timeout is used from
`astropy.utils.data.Conf.remote_timeout`.
already_downloaded = []
for file_name in _sample_files.keys():
url = urljoin(first_url, file_name)
fname = sampledata_dir/file_name
# We have to avoid calling download if we already have all the files.
if fname.exists() and not overwrite:
already_downloaded.append(fname)
else:
dl.enqueue_file(url, filename=sampledata_dir/file_name)

Returns
-------
result: `str`
The local path of the file. None if it failed.
"""
if dl.queued_downloads:
results = dl.download()
else:
return already_downloaded

# Creating the directory for sample files to be downloaded
sampledata_dir = get_and_create_sample_dir()
if not results.errors:
return results

if filename[-3:] == 'zip':
uncompressed_filename = filename[:-4]
else:
uncompressed_filename = filename
# check if the (uncompressed) file exists
if not overwrite and os.path.isfile(os.path.join(sampledata_dir,
uncompressed_filename)):
return os.path.join(sampledata_dir, uncompressed_filename)
else:
# check each provided url to find the file
for base_url in url_list:
online_filename = filename
if base_url.count('github'):
online_filename += '?raw=true'
try:
url = urljoin(base_url, online_filename)
exists = url_exists(url)
if exists:
f = download_file(os.path.join(base_url, online_filename),
show_progress=show_progress,
timeout=timeout)
real_name, ext = os.path.splitext(f)

if ext == '.zip':
print("Unpacking: {}".format(real_name))
with ZipFile(f, 'r') as zip_file:
unzipped_f = zip_file.extract(real_name,
sampledata_dir)
os.remove(f)
move(unzipped_f, os.path.join(sampledata_dir,
uncompressed_filename))
return os.path.join(sampledata_dir,
uncompressed_filename)
else:
# move files to the data directory
move(f, os.path.join(sampledata_dir,
uncompressed_filename))
return os.path.join(sampledata_dir,
uncompressed_filename)
except (socket.error, socket.timeout) as e:
warnings.warn("Download failed with error {}. Retrying with different mirror.".format(e), SunpyUserWarning)
# if reach here then file has not been downloaded.
warnings.warn("File {} not found.".format(filename), SunpyUserWarning)
return None
for retry_url in _base_urls[1:]:
for i, err in enumerate(results.errors):
file_name = Path(err.url).name
# Overwrite the parfive error to change the url to a mirror
new_url = urljoin(retry_url, file_name)
results._errors[i] = _error(err.filepath_partial,
new_url,
err.response)

results = dl.retry(results)

if not results.errors:
return results

for err in results.errors:
file_name = Path(err.url).name
warnings.warn(f"File {file_name} not found.", SunpyUserWarning)

return results
18 changes: 13 additions & 5 deletions sunpy/data/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,23 @@
"""
import sys
from ._sample import _base_urls, _sample_files, get_sample_file
from pathlib import Path

from ._sample import _sample_files, download_sample_data

files = download_sample_data()

file_list = []
file_dict = {}
for _key in _sample_files:
f = get_sample_file(_sample_files[_key], _base_urls)
setattr(sys.modules[__name__], _key, f)
for f in files:
name = Path(f).name
_key = _sample_files.get(name, None)
if not _key:
continue

setattr(sys.modules[__name__], _key, str(f))
file_list.append(f)
file_dict.update({_key: f})
__doc__ += '* ``{}``\n'.format(_key)

__all__ = list(_sample_files.keys()) + ['file_dict', 'file_list']
__all__ = list(_sample_files.values()) + ['file_dict', 'file_list']
8 changes: 4 additions & 4 deletions sunpy/data/sunpyrc
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ time_format = %Y-%m-%d %H:%M:%S
; Location to save download data to. Path should be specified relative to the
; SunPy working directory.
; Default value: data/
;download_dir = /tmp
download_dir = data

; Location where the sample data will be downloaded. Path should be specified
; relative to the SunPy working directory.
sample_dir = data/sample_data
; Location where the sample data will be downloaded. If not specified, it will be
; downloaded to the platform-specific user data directory.
; The default directory is determined by appdirs (https://github.com/ActiveState/appdirs)
; sample_dir = /data/sample_data

;;;;;;;;;;;;
; Database ;
Expand Down
11 changes: 8 additions & 3 deletions sunpy/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ def load_config():
if not config.has_option('database', 'url'):
config.set('database', 'url', "sqlite:///" + str(Path.home() / "sunpy" / "sunpydb.sqlite"))

# Set the download_dir to be relative to the working_dir
working_dir = Path(config.get('general', 'working_dir'))
sample_dir = Path(config.get('downloads', 'sample_dir'))
download_dir = Path(config.get('downloads', 'download_dir'))
config.set('downloads', 'sample_dir', str((working_dir / sample_dir).expanduser().resolve()))
config.set('downloads', 'download_dir', str((working_dir / download_dir).expanduser().resolve()))
sample_dir = config.get('downloads', 'sample_dir', fallback=dirs.user_data_dir)
config.set('downloads', 'sample_dir', Path(sample_dir).expanduser().resolve().as_posix())
config.set('downloads', 'download_dir', (working_dir / download_dir).expanduser().resolve().as_posix())

return config

Expand Down Expand Up @@ -71,6 +72,10 @@ def get_and_create_download_dir():
"""
Get the config of download directory and create one if not present.
"""
download_dir = os.environ.get('SUNPY_DOWNLOADDIR')
if download_dir:
return download_dir

download_dir = Path(sunpy.config.get('downloads', 'download_dir')).expanduser().resolve()
if not _is_writable_dir(download_dir):
raise RuntimeError(f'Could not write to SunPy downloads directory="{download_dir}"')
Expand Down
11 changes: 6 additions & 5 deletions sunpy/util/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import io
import os
import sys
from pathlib import Path

from sunpy import config
from sunpy.util.config import (get_and_create_sample_dir, get_and_create_download_dir,
CONFIG_DIR, print_config,
_find_config_files, _get_user_configdir, _is_writable_dir)
_find_config_files, _get_user_configdir, _is_writable_dir, dirs)

USER = os.path.expanduser('~')

Expand Down Expand Up @@ -36,7 +37,7 @@ def test_get_user_configdir(tmpdir, tmp_path, undo_config_dir_patch):
del os.environ["SUNPY_CONFIGDIR"]


def test_print_config_files():
def test_print_config_files(undo_download_dir_patch):
# TODO: Tidy this up.
stdout = sys.stdout
out = io.StringIO()
Expand All @@ -51,10 +52,10 @@ def test_print_config_files():
assert get_and_create_sample_dir() in printed


def test_get_and_create_download_dir():
def test_get_and_create_download_dir(undo_download_dir_patch):
# test default config
path = get_and_create_download_dir()
assert path == os.path.join(USER, 'sunpy', 'data')
assert Path(path) == Path(USER) / 'sunpy' / 'data'
# test updated config
new_path = os.path.join(USER, 'sunpy_data_here_please')
config.set('downloads', 'download_dir', new_path)
Expand All @@ -68,7 +69,7 @@ def test_get_and_create_download_dir():
def test_get_and_create_sample_dir():
# test default config
path = get_and_create_sample_dir()
assert path == os.path.join(USER, 'sunpy', 'data', 'sample_data')
assert Path(path) == Path(dirs.user_data_dir)
# test updated config
new_path = os.path.join(USER, 'sample_data_here_please')
config.set('downloads', 'sample_dir', new_path)
Expand Down

0 comments on commit dc0c8cf

Please sign in to comment.