[ENH] add Monash Forecasting Repository data loader #4826

Merged · 9 commits · Jul 14, 2023
Changes from 5 commits
2 changes: 2 additions & 0 deletions sktime/datasets/__init__.py
@@ -32,6 +32,7 @@
"load_unit_test_tsf",
"load_solar",
"load_covid_3month",
"fetch_forecastingorg",
"write_panel_to_tsfile",
"write_dataframe_to_tsfile",
"write_ndarray_to_tsfile",
@@ -55,6 +56,7 @@
write_tabular_transformation_to_arff,
)
from sktime.datasets._single_problem_loaders import (
fetch_forecastingorg,
load_acsf1,
load_airline,
load_arrow_head,
22 changes: 17 additions & 5 deletions sktime/datasets/_data_io.py
@@ -109,16 +109,24 @@ def _download_and_extract(url, extract_path=None):
)


def _list_available_datasets(extract_path):
def _list_available_datasets(extract_path, origin_repo=None):
"""Return a list of all the currently downloaded datasets.

To count as available, each directory <dir_name> in the extract_path must contain
files called <dir_name>_TRAIN.ts and <dir_name>_TEST.ts.
Forecastingorg datasets are in the format <dataset_name>.tsf, while classification
datasets are in the format <dataset_name>_TRAIN.ts and <dataset_name>_TEST.ts.
To count as available, each directory <dir_name>
in the extract_path must contain files called
1. <dir_name>_TRAIN.ts and <dir_name>_TEST.ts if the dataset is from the classification repo,
2. <dir_name>.tsf if the dataset is from the forecasting repo.

Parameters
----------
extract_path: string
root directory in which to look for files; if None, defaults to sktime/datasets/data
origin_repo: string, optional (default=None)
if None, returns all available classification datasets in extract_path;
if "forecastingorg", returns all available forecastingorg datasets
in extract_path.

Returns
-------
@@ -134,8 +142,12 @@ def _list_available_datasets(extract_path):
sub_dir = os.path.join(data_dir, name)
if os.path.isdir(sub_dir):
all_files = os.listdir(sub_dir)
if name + "_TRAIN.ts" in all_files and name + "_TEST.ts" in all_files:
datasets.append(name)
if origin_repo == "forecastingorg":
if name + ".tsf" in all_files:
datasets.append(name)
else:
if name + "_TRAIN.ts" in all_files and name + "_TEST.ts" in all_files:
datasets.append(name)
return datasets
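
A minimal usage sketch of the extended helper (assuming the bundled UnitTest dataset is present under the default sktime/datasets/data directory, as the tests below do):

    from sktime.datasets._data_io import _list_available_datasets

    # classification-format datasets: <name>_TRAIN.ts and <name>_TEST.ts
    print(_list_available_datasets(extract_path=None))

    # forecasting-format datasets: <name>.tsf
    print(_list_available_datasets(extract_path=None, origin_repo="forecastingorg"))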


103 changes: 103 additions & 0 deletions sktime/datasets/_single_problem_loaders.py
@@ -41,17 +41,21 @@
]

import os
import zipfile
from urllib.error import HTTPError, URLError
from warnings import warn

import numpy as np
import pandas as pd

from sktime.datasets._data_io import (
_download_and_extract,
_list_available_datasets,
_load_dataset,
_load_provided_dataset,
load_tsf_to_dataframe,
)
from sktime.datasets.tsf_dataset_names import tsf_all, tsf_all_datasets
from sktime.utils.validation._dependencies import _check_soft_dependencies

DIRNAME = "data"
@@ -1267,3 +1271,102 @@ def load_covid_3month(split=None, return_X_y=True):
"""
name = "Covid3Month"
return _load_dataset(name, split, return_X_y)


def fetch_forecastingorg(
name,
replace_missing_vals="NAN",
value_column_name="series_value",
return_type="default_tsf",
extract_path=None,
):
"""Fetch forecasting datasets from Monash Time Series Forecasting Archive.

Downloads and extracts dataset if not already downloaded. Data is assumed to be
in the standard .tsf format. See https://forecastingdata.org/ for more details.

Parameters
----------
name: str
Name of the dataset. If a dataset listed in tsf_all_datasets is given,
this function will look in the extract_path first and, if it is not present,
attempt to download the data from https://forecastingdata.org/, saving it to
the extract_path.
replace_missing_vals: str, default="NAN"
The term used to indicate missing values in the series of the returned dataframe.
value_column_name: str, default="series_value"
Preferred name for the column containing series values in the returned dataframe.
return_type : str - "default_tsf" (default), "pd_multiindex_hier", or a valid sktime
mtype string for an in-memory data container format; specifies the
return type:
- "default_tsf" = container that faithfully mirrors the tsf format of the original
implementation in
https://github.com/rakshitha123/TSForecasting/blob/master/utils/data_loader.py
- "pd_multiindex_hier" = pd.DataFrame of sktime type `pd_multiindex_hier`
- other valid mtype strings are Panel or Hierarchical mtypes in
datatypes.MTYPE_REGISTER. If a Panel or Hierarchical mtype string is given, a
conversion to that mtype will be attempted.
For tutorials and detailed specifications, see
examples/AA_datatypes_and_datasets.ipynb
extract_path : str, optional (default=None)
the path to look for the data. If no path is provided, the function
looks in `sktime/datasets/data/`. If a path is given, it can be absolute,
e.g. C:/Temp or relative, e.g. Temp or ./Temp.

Returns
-------
loaded_data: pd.DataFrame
The converted dataframe containing the time series.
metadata: dict
The metadata for the forecasting problem. The dictionary keys are:
"frequency", "forecast_horizon", "contain_missing_values",
"contain_equal_length"
"""
# Allow user to have a non-standard extract path
if extract_path is not None:
local_module = os.path.dirname(extract_path)
local_dirname = extract_path
else: # this is the default path for downloaded dataset
local_module = MODULE
local_dirname = DIRNAME

if not os.path.exists(os.path.join(local_module, local_dirname)):
os.makedirs(os.path.join(local_module, local_dirname))

path_to_data_dir = os.path.join(local_module, local_dirname)
# TODO should create a function to check if dataset exists
if name not in _list_available_datasets(path_to_data_dir, "forecastingorg"):
# Dataset is not already present in the datasets directory provided.
# If it is not there, download and install it.

# TODO: create a registry function to look up valid dataset names
# for the classification, regression, and forecasting dataset repos
if name not in tsf_all_datasets:
raise ValueError(
f"{name} is not a valid dataset name. The list of valid dataset names "
"can be found at sktime.datasets.tsf_dataset_names.tsf_all_datasets"
)

url = f"https://zenodo.org/record/{tsf_all[name]}/files/{name}.zip"

# This also tests the validity of the URL; we can't rely on the HTTP
# status code alone, as it always returns 200
try:
_download_and_extract(
url,
extract_path=path_to_data_dir,
)
except zipfile.BadZipFile as e:
raise ValueError(
f"Invalid dataset name: {name} is not available at extract path "
f"{extract_path}, nor is it available at https://forecastingdata.org/."
) from e

path_to_file = os.path.join(path_to_data_dir, f"{name}/{name}.tsf")
return load_tsf_to_dataframe(
path_to_file, replace_missing_vals, value_column_name, return_type
)
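
A short usage sketch of the new loader; the dataset name is illustrative (any key of tsf_all works), and the first call downloads the zip from Zenodo before loading:

    from sktime.datasets import fetch_forecastingorg

    # fetches m1_yearly_dataset.tsf on first use, then loads it from disk
    df, metadata = fetch_forecastingorg(name="m1_yearly_dataset")
    print(metadata)  # frequency, forecast_horizon, contain_missing_values, ...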
12 changes: 12 additions & 0 deletions sktime/datasets/data/UnitTest/UnitTest.tsf
@@ -0,0 +1,12 @@
# Dataset Information
# This is a dummy dataset used to test tsf file format
#
@relation Unknown
@attribute series_name string
@attribute start_timestamp date
@frequency yearly
@horizon 4
@missing false
@equallength false
@data
T1:1979-01-01 00-00-00:25092.2284,24271.5134,25828.9883,27697.5047,27956.2276,29924.4321,30216.8321,32613.4968,36053.1674,38472.7532,38420.894,36555.6156,37385.6371,38431.9699,40345.33
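
For orientation, the @-prefixed header lines above carry exactly the fields returned in the loader's metadata dict; a rough standalone sketch of that mapping (illustrative only, not the actual parser in _data_io.py):

    metadata = {}
    with open("sktime/datasets/data/UnitTest/UnitTest.tsf") as f:
        for line in f:
            line = line.strip()
            if line.startswith("@data"):
                break  # series records follow: <name>:<start timestamp>:<values>
            if line.startswith("@frequency"):
                metadata["frequency"] = line.split(" ")[1]
            elif line.startswith("@horizon"):
                metadata["forecast_horizon"] = int(line.split(" ")[1])
            elif line.startswith("@missing"):
                metadata["contain_missing_values"] = line.split(" ")[1] == "true"
            elif line.startswith("@equallength"):
                metadata["contain_equal_length"] = line.split(" ")[1] == "true"

    # -> {'frequency': 'yearly', 'forecast_horizon': 4,
    #     'contain_missing_values': False, 'contain_equal_length': False}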
19 changes: 19 additions & 0 deletions sktime/datasets/tests/test_data_io.py
@@ -28,6 +28,7 @@
from sktime.datasets._data_io import (
MODULE,
_convert_tsf_to_hierarchical,
_list_available_datasets,
_load_provided_dataset,
)
from sktime.datatypes import check_is_mtype, scitype_to_mtype
@@ -1463,3 +1464,21 @@ def test_convert_tsf_to_multiindex(freq):
_convert_tsf_to_hierarchical(input_df, metadata, freq=freq),
check_dtype=False,
)


@pytest.mark.parametrize("origin_repo", [None, "forecastingorg"])
def test_list_available_datasets(origin_repo):
"""Test function for listing available datasets.

Check for the two dataset repo format types:
1. https://www.timeseriesclassification.com/
2. https://forecastingdata.org/

"""
dataset_name = "UnitTest"
available_datasets = _list_available_datasets(
extract_path=None, origin_repo=origin_repo
)
assert (
dataset_name in available_datasets
), f"{dataset_name} dataset should be available." # noqa: E501
38 changes: 38 additions & 0 deletions sktime/datasets/tests/test_single_problem_loaders.py
@@ -1,9 +1,12 @@
"""Test single problem loaders with varying return types."""
from urllib.request import Request, urlopen

import numpy as np
import pandas as pd
import pytest

from sktime.datasets import ( # Univariate; Unequal length; Multivariate
fetch_forecastingorg,
load_acsf1,
load_arrow_head,
load_basic_motions,
@@ -14,6 +17,7 @@
load_UCR_UEA_dataset,
load_unit_test,
)
from sktime.datasets.tsf_dataset_names import tsf_all, tsf_all_datasets

UNIVARIATE_PROBLEMS = [
load_acsf1,
@@ -93,3 +97,37 @@ def test_load_UEA():

for mult_name in mult_names:
load_UCR_UEA_dataset(mult_name)


def test_fetch_forecastingorg():
"""Test loading downloaded dataset from forecasting.org."""
file = "UnitTest"
loaded_datasets, metadata = fetch_forecastingorg(name=file)
assert len(loaded_datasets) == 1
assert metadata["frequency"] == "yearly"
assert metadata["forecast_horizon"] == 4
assert metadata["contain_missing_values"] is False
assert metadata["contain_equal_length"] is False


@pytest.mark.parametrize("name", tsf_all_datasets)
def test_check_link_downloadable(name):
"""Test dataset URL from forecasting.org is downloadable and exits."""
url = f"https://zenodo.org/record/{tsf_all[name]}/files/{name}.zip"

# Send a HEAD request to check that the link exists without downloading the file
req = Request(url, method="HEAD")
response = urlopen(req)

# Check if the response status code is 200 (OK)
assert (
response.status == 200
), f"URL is not valid or does not exist. Error code {response.status}."

# Check if the response headers indicate that the content is downloadable
content_type = response.headers.get("Content-Type")
content_disposition = response.headers.get("Content-Disposition")

assert "application/octet-stream" in content_type, "URL is not downloadable."
assert "attachment" in content_disposition, "URL is not downloadable."
66 changes: 66 additions & 0 deletions sktime/datasets/tsf_dataset_names.py
@@ -0,0 +1,66 @@
"""
A dictionary of all datasets in the Monash Forecasting Repository.

Dictionary keys are dataset names; values are Zenodo record numbers.
"""

tsf_all = {
"m1_yearly_dataset": 4656193,
"m1_quarterly_dataset": 4656154,
"m1_monthly_dataset": 4656159,
"m3_yearly_dataset": 4656222,
"m3_quarterly_dataset": 4656262,
"m3_monthly_dataset": 4656298,
"m3_other_dataset": 4656335,
"m4_yearly_dataset": 4656379,
"m4_quarterly_dataset": 4656410,
"m4_monthly_dataset": 4656480,
"m4_weekly_dataset": 4656522,
"m4_daily_dataset": 4656548,
"m4_hourly_dataset": 4656589,
"tourism_yearly_dataset": 4656103,
"tourism_quarterly_dataset": 4656093,
"tourism_monthly_dataset": 4656096,
"cif_2016_dataset": 4656042,
"london_smart_meters_dataset_with_missing_values": 4656072,
"london_smart_meters_dataset_without_missing_values": 4656091,
"australian_electricity_demand_dataset": 4659727,
"wind_farms_minutely_dataset_with_missing_values": 4654909,
"wind_farms_minutely_dataset_without_missing_values": 4654858,
"dominick_dataset": 4654802,
"bitcoin_dataset_with_missing_values": 5121965,
"bitcoin_dataset_without_missing_values": 5122101,
"pedestrian_counts_dataset": 4656626,
"vehicle_trips_dataset_with_missing_values": 5122535,
"vehicle_trips_dataset_without_missing_values": 5122537,
"kdd_cup_2018_dataset_with_missing_values": 4656719,
"kdd_cup_2018_dataset_without_missing_values": 4656756,
"weather_dataset": 4654822,
"nn5_daily_dataset_with_missing_values": 4656110,
"nn5_daily_dataset_without_missing_values": 4656117,
"nn5_weekly_dataset": 4656125,
"kaggle_web_traffic_dataset_with_missing_values": 4656080,
"kaggle_web_traffic_dataset_without_missing_values": 4656075,
"kaggle_web_traffic_weekly_dataset": 4656664,
"solar_10_minutes_dataset": 4656144,
"solar_weekly_dataset": 4656151,
"solar_4_seconds_dataset": 4656027,
"electricity_hourly_dataset": 4656140,
"electricity_weekly_dataset": 4656141,
"car_parts_dataset_with_missing_values": 4656022,
"car_parts_dataset_without_missing_values": 4656021,
"fred_md_dataset": 4654833,
"traffic_hourly_dataset": 4656132,
"traffic_weekly_dataset": 4656135,
"rideshare_dataset_with_missing_values": 5122114,
"rideshare_dataset_without_missing_values": 5122232,
"hospital_dataset": 4656014,
"covid_deaths_dataset": 4656009,
"sunspot_dataset_with_missing_values": 4654773,
"sunspot_dataset_without_missing_values": 4654722,
"saugeenday_dataset": 4656058,
"us_births_dataset": 4656049,
"wind_4_seconds_dataset": 4656032,
}

tsf_all_datasets = list(tsf_all.keys())
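
These record numbers are what fetch_forecastingorg interpolates into the Zenodo URL template; for example:

    from sktime.datasets.tsf_dataset_names import tsf_all

    name = "hospital_dataset"
    url = f"https://zenodo.org/record/{tsf_all[name]}/files/{name}.zip"
    # -> https://zenodo.org/record/4656014/files/hospital_dataset.zip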