Skip to content

Commit

Permalink
[ENH] set mirrors for time series classification data loaders (#5260)
Browse files (browse the repository at this point in the history)
This sets the mirrors for the time series classification data loaders:

1. to the new URL of the UEA repository (the upstream location that keeps changing), as the main mirror
2. to the mirror URL on the sktime GitHub, as a backup

Testing of the download utility is re-enabled: the `xfail` marker on `test_load_UEA` is removed.

Related: #4754, #4749
  • Loading branch information
fkiraly committed Sep 17, 2023
1 parent 9896650 commit 894a031
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
10 changes: 8 additions & 2 deletions sktime/datasets/_data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ def _mkdir_if_not_exist(*path):
return full_path


CLASSIF_URLS = ["https://timeseriesclassification.com/ClassificationDownloads"]
CLASSIF_URLS = [
"https://timeseriesclassification.com/aeon-toolkit", # main mirror (UEA)
"https://github.com/sktime/sktime-datasets/raw/main/TSC", # backup mirror (sktime)
]


def _load_dataset(name, split, return_X_y, return_type=None, extract_path=None):
Expand Down Expand Up @@ -290,12 +293,15 @@ def _get_data_from(path):
if extract_path is None:
extract_path = os.path.join(MODULE, "local_data")

# in either case below, we need to ensure the directory exists
_mkdir_if_not_exist(extract_path)

# search if the dataset is already in the extract path after download
if name in _list_available_datasets(extract_path):
return _get_data_from(extract_path)

# now we know the dataset is not in the download/cache path
# so we need to download it
_mkdir_if_not_exist(extract_path)

# download the dataset from CLASSIF_URLS
# will try multiple mirrors if necessary
Expand Down
3 changes: 0 additions & 3 deletions sktime/datasets/tests/test_single_problem_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@ def test_load_numpy2d_multivariate_raises(loader):
X, y = loader(return_type="numpy2d")


@pytest.mark.xfail(
reason="repeated upstream location failures, see 4754. xfail until fixed."
)
def test_load_UEA():
"""Test loading of a random subset of the UEA data, to check API."""
from sktime.datasets.tsc_dataset_names import multivariate, univariate
Expand Down

0 comments on commit 894a031

Please sign in to comment.