diff --git a/sktime/datasets/_data_io.py b/sktime/datasets/_data_io.py
index 4345cd45a4c..d00c5e92c7c 100644
--- a/sktime/datasets/_data_io.py
+++ b/sktime/datasets/_data_io.py
@@ -62,7 +62,6 @@ def _alias_mtype_check(return_type):
     return return_type
 
 
-# time series classification data sets
 def _download_and_extract(url, extract_path=None):
     """Download and unzip datasets (helper function).
 
@@ -151,6 +150,92 @@ def _list_available_datasets(extract_path, origin_repo=None):
     return datasets
 
 
+def _cache_dataset(url, name, extract_path=None, repeats=1, verbose=False):
+    """Download and unzip datasets from multiple mirrors or fallback sources.
+
+    If url is a string, attempts to download and unzip from url to extract_path.
+    If url is a list of str, tries the urls in order until a download succeeds.
+
+    Parameters
+    ----------
+    url : string or list of string
+        URL pointing to the file to download
+        files are expected to be at f"{url}/{name}.zip" for a string url
+        or f"{url[i]}/{name}.zip" for a list of string urls
+    extract_path : string, optional (default: None)
+        path to extract downloaded zip to, None defaults
+        to sktime/datasets/data
+    repeats : int, optional (default: 1)
+        number of times to try downloading from each url
+    verbose : bool, optional (default: False)
+        whether to print progress
+
+    Returns
+    -------
+    extract_path : string or None
+        if successful, string containing the path of the extracted file
+    u : string
+        url from which the dataset was downloaded
+    repeat : int
+        0-based index of the attempt on which the download from u succeeded
+        If none of the attempts are successful, a RuntimeError is raised instead
+    """
+    if isinstance(url, str):
+        url = [url]
+
+    for u in url:
+        name_url = f"{u}/{name}.zip"
+        for repeat in range(repeats):
+            if verbose:
+                print(  # noqa: T201
+                    f"Downloading dataset {name} from {u} to {extract_path} "
+                    f"(attempt {repeat + 1} of {repeats} total)."
+                )
+
+            try:
+                _download_and_extract(name_url, extract_path=extract_path)
+                return extract_path, u, repeat
+
+            except zipfile.BadZipFile:
+                if verbose:
+                    if repeat < repeats - 1:
+                        print(  # noqa: T201
+                            "Download failed, continuing with next attempt. "
+                        )
+                    else:
+                        print(  # noqa: T201
+                            "All attempts for mirror failed, "
+                            "continuing with next mirror."
+                        )
+
+    raise RuntimeError(
+        f"Dataset {name} could not be downloaded from any of the mirrors."
+    )
+
+
+def _mkdir_if_not_exist(*path):
+    """Shortcut for making a directory if it does not exist.
+
+    Parameters
+    ----------
+    path : tuple of strings
+        Directory path to create
+        If multiple strings are given, they are joined before creation
+
+    Returns
+    -------
+    os.path.join(*path) : string
+        Directory path created
+    """
+    full_path = os.path.join(*path)
+    if not os.path.exists(full_path):
+        os.makedirs(full_path)
+    return full_path
+
+
+CLASSIF_URLS = ["https://timeseriesclassification.com/ClassificationDownloads"]
+
+
 def _load_dataset(name, split, return_X_y, return_type=None, extract_path=None):
     """Load time series classification datasets (helper function).
 
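Not part of the diff: a minimal usage sketch of the new caching helpers introduced above, assuming this branch's `sktime.datasets._data_io` module and network access. The cache directory name and the `GunPoint` dataset are illustrative choices, not values fixed by the PR.

```python
import os

from sktime.datasets._data_io import (
    CLASSIF_URLS,
    _cache_dataset,
    _mkdir_if_not_exist,
)

# create (or reuse) a cache directory; the helper joins the parts and returns the path
cache_dir = _mkdir_if_not_exist(os.path.expanduser("~"), ".sktime_dataset_cache")

# try each mirror in CLASSIF_URLS in order, with up to `repeats` attempts per mirror;
# a RuntimeError is raised only if every attempt on every mirror fails
path, mirror, attempt = _cache_dataset(
    CLASSIF_URLS, "GunPoint", extract_path=cache_dir, repeats=2, verbose=True
)
print(f"extracted to {path}, downloaded from {mirror}, attempt index {attempt}")
```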
@@ -174,7 +259,10 @@ def _load_dataset(name, split, return_X_y, return_type=None, extract_path=None):
         "numpy2d"/"np2d"/"numpyflat": 2D np.ndarray (instance, time index)
         "pd-multiindex": pd.DataFrame with 2-level (instance, time) MultiIndex
         Exception is raised if the data cannot be stored in the requested type.
-    extract_path : todo author: please fill in docstring
+    extract_path : string, optional (default: None)
+        path to look for the dataset, and to extract any downloaded zip to
+        None defaults to sktime/datasets/data if the dataset exists there, otherwise
+        to sktime/datasets/local_data, where the dataset is downloaded if necessary
 
     Returns
     -------
@@ -185,46 +273,39 @@ def _load_dataset(name, split, return_X_y, return_type=None, extract_path=None):
         If return_X_y is False, y is appended to X instead.
     """
     # Allow user to have non standard extract path
-    if extract_path is not None:
-        local_module = os.path.dirname(extract_path)
-        local_dirname = extract_path
-    else:
-        local_module = MODULE
-        local_dirname = "data"
-
-    if not os.path.exists(os.path.join(local_module, local_dirname)):
-        os.makedirs(os.path.join(local_module, local_dirname))
-    if name not in _list_available_datasets(extract_path):
-        if extract_path is None:
-            local_dirname = "local_data"
-        if not os.path.exists(os.path.join(local_module, local_dirname)):
-            os.makedirs(os.path.join(local_module, local_dirname))
-        if name not in _list_available_datasets(
-            os.path.join(local_module, local_dirname)
-        ):
-            # Dataset is not already present in the datasets directory provided.
-            # If it is not there, download and install it.
-            url = (
-                "https://timeseriesclassification.com/"
-                f"ClassificationDownloads/{name}.zip"
-            )
-            # This also tests the validitiy of the URL, can't rely on the html
-            # status code as it always returns 200
-            try:
-                _download_and_extract(
-                    url,
-                    extract_path=extract_path,
-                )
-            except zipfile.BadZipFile as e:
-                raise ValueError(
-                    f"Invalid dataset name ={name} is not available on extract path ="
-                    f"{extract_path}. Nor is it available on "
-                    f"https://timeseriesclassification.com/.",
-                ) from e
-
-    return _load_provided_dataset(
-        name, split, return_X_y, return_type, local_module, local_dirname
-    )
+    if extract_path is None:
+        # default for first check is sktime/datasets/data
+        check_path = os.path.join(MODULE, "data")
+    else:
+        check_path = extract_path
+
+    def _get_data_from(path):
+        return _load_provided_dataset(name, split, return_X_y, return_type, path)
+
+    # if the dataset exists in check_path, retrieve it from there
+    if name in _list_available_datasets(check_path):
+        return _get_data_from(check_path)
+
+    # now we know the dataset is not in check_path
+    # so we need to check whether it is already in the download/cache path
+    # the download path is extract_path if given, else sktime/datasets/local_data
+    if extract_path is None:
+        extract_path = os.path.join(MODULE, "local_data")
+
+    if name in _list_available_datasets(extract_path):
+        return _get_data_from(extract_path)
+
+    # now we know the dataset is not in the download/cache path either
+    # so we need to download it
+    _mkdir_if_not_exist(extract_path)
+
+    # download the dataset from the mirrors in CLASSIF_URLS
+    # this tries each mirror in turn if necessary
+    # and raises a RuntimeError if all of them fail
+    _cache_dataset(CLASSIF_URLS, name, extract_path=extract_path)
+
+    # if we reach this point, the data has been downloaded, so load it
+    return _get_data_from(extract_path)
 
 
 def _load_provided_dataset(
@@ -232,8 +313,7 @@ def _load_provided_dataset(
     split=None,
     return_X_y=True,
     return_type=None,
-    local_module=MODULE,
-    local_dirname=DIRNAME,
+    extract_path=None,
 ):
     """Load baked in time series classification datasets (helper function).
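Not part of the diff: the lookup order that the rewritten `_load_dataset` implements, sketched through the public `load_UCR_UEA_dataset` loader that wraps it (assuming the loader keeps forwarding `extract_path` to this helper). Dataset names and the cache path are examples only.

```python
from sktime.datasets import load_UCR_UEA_dataset

# 1) bundled datasets are found in sktime/datasets/data and never trigger a download
X, y = load_UCR_UEA_dataset("ArrowHead", split="train", return_X_y=True)

# 2) a non-bundled dataset is downloaded once, into sktime/datasets/local_data by
#    default or into extract_path if given, and is reused from there on later calls
X, y = load_UCR_UEA_dataset("Chinatown", extract_path="/tmp/sktime_cache")
```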
@@ -259,8 +339,8 @@ def _load_provided_dataset(
         "numpy2d"/"np2d"/"numpyflat": 2D np.ndarray (instance, time index)
         "pd-multiindex": pd.DataFrame with 2-level (instance, time) MultiIndex
         Exception is raised if the data cannot be stored in the requested type.
-    local_module: default = os.path.dirname(__file__),
-    local_dirname: default = "data"
+    extract_path : string, optional (default: os.path.join(MODULE, DIRNAME))
+        path to the folder containing the dataset folders (sktime/datasets/data)
 
     Returns
     -------
@@ -270,21 +350,24 @@ def _load_provided_dataset(
         The class labels for each time series instance in X
         If return_X_y is False, y is appended to X instead.
     """
+    if extract_path is None:
+        extract_path = os.path.join(MODULE, DIRNAME)
+
     if isinstance(split, str):
         split = split.upper()
 
     if split in ("TRAIN", "TEST"):
         fname = name + "_" + split + ".ts"
-        abspath = os.path.join(local_module, local_dirname, name, fname)
+        abspath = os.path.join(extract_path, name, fname)
         X, y = load_from_tsfile(abspath, return_data_type="nested_univ")
 
     # if split is None, load both train and test set
     elif split is None:
         fname = name + "_TRAIN.ts"
-        abspath = os.path.join(local_module, local_dirname, name, fname)
+        abspath = os.path.join(extract_path, name, fname)
         X_train, y_train = load_from_tsfile(abspath, return_data_type="nested_univ")
         fname = name + "_TEST.ts"
-        abspath = os.path.join(local_module, local_dirname, name, fname)
+        abspath = os.path.join(extract_path, name, fname)
         X_test, y_test = load_from_tsfile(abspath, return_data_type="nested_univ")
 
         X = pd.concat([X_train, X_test])
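Not part of the diff: how an internal caller migrates from the old `(local_module, local_dirname)` pair to the single `extract_path` argument of `_load_provided_dataset`, assuming the signature introduced above. `MODULE` and `DIRNAME` are the module-level constants already defined in `_data_io.py`.

```python
import os

from sktime.datasets._data_io import MODULE, _load_provided_dataset

# before: _load_provided_dataset("ArrowHead", "train", True, None, MODULE, "data")
# after: one path argument pointing at the folder that contains the "ArrowHead" folder
X, y = _load_provided_dataset(
    "ArrowHead",
    split="train",
    return_X_y=True,
    extract_path=os.path.join(MODULE, "data"),
)
print(X.shape, y.shape)
```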