[ENH] Alternative design for dataset as an object #5270

Draft · wants to merge 6 commits into base: main
7 changes: 7 additions & 0 deletions sktime/datasets/_data_io.py
@@ -51,6 +51,7 @@
MODULE = os.path.dirname(__file__)


# TODO: if user gives numpyflat or numpy3D this code will break
# Return appropriate return_type in case an alias was used
def _alias_mtype_check(return_type):
if return_type is None:
@@ -62,6 +63,7 @@ def _alias_mtype_check(return_type):
return return_type


# TODO: Can deprecate this function in favor of dataset object
def _download_and_extract(url, extract_path=None):
"""Download and unzip datasets (helper function).

@@ -108,6 +110,7 @@ def _download_and_extract(url, extract_path=None):
)


# TODO: Can deprecate this function in favor of dataset object
def _list_available_datasets(extract_path, origin_repo=None):
"""Return a list of all the currently downloaded datasets.

@@ -150,6 +153,7 @@ def _list_available_datasets(extract_path, origin_repo=None):
return datasets


# TODO: Can deprecate this function in favor of dataset object
def _cache_dataset(url, name, extract_path=None, repeats=1, verbose=False):
"""Download and unzip datasets from multiple mirrors or fallback sources.

@@ -213,6 +217,7 @@ def _cache_dataset(url, name, extract_path=None, repeats=1, verbose=False):
)


# TODO: Can deprecate this function in favor of dataset object
def _mkdir_if_not_exist(*path):
"""Shortcut for making a directory if it does not exist.

@@ -239,6 +244,7 @@ def _mkdir_if_not_exist(*path):
]


# TODO: Can deprecate this function in favor of dataset object
def _load_dataset(name, split, return_X_y, return_type=None, extract_path=None):
"""Load time series classification datasets (helper function).

@@ -312,6 +318,7 @@ def _get_data_from(path):
return _get_data_from(extract_path)


# TODO: Can deprecate this function in favor of dataset object
def _load_provided_dataset(
name,
split=None,
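For context, a minimal before/after sketch of what the deprecation TODOs above imply (not part of the diff; "ArrowHead" is just an illustrative UCR/UEA dataset name):

# current helper-based loading:
from sktime.datasets._data_io import _load_dataset

X, y = _load_dataset("ArrowHead", split="train", return_X_y=True)

# proposed object-based loading from this PR:
from sktime.datasets.classification.tsc import TSCDatasetLoader

dataset = TSCDatasetLoader(name="ArrowHead", split="train")
X, y = dataset.load()  # downloads on first call, then loads from disk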
6 changes: 6 additions & 0 deletions sktime/datasets/base/__init__.py
@@ -0,0 +1,6 @@
"""Base Class for datasets."""

__all__ = ["BaseDataset"]


from sktime.datasets.base._base import BaseDataset
111 changes: 111 additions & 0 deletions sktime/datasets/base/_base.py
@@ -0,0 +1,111 @@
import shutil
import tempfile
import zipfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Tuple
from urllib.request import urlretrieve

import pandas as pd

from sktime.datasets.base._metadata import BaseDatasetMetadata


class BaseDataset(ABC):
"""Base class for all sktime datasets."""

def __init__(
self,
metadata: BaseDatasetMetadata,
        save_dir: Path,
return_data_type: str,
) -> None:
super().__init__()
self._metadata = metadata
self._save_dir = save_dir
self._return_data_type = return_data_type

@property
def save_dir(self):
"""Return the save directory."""
return self._save_dir

@abstractmethod
def _load(self) -> Tuple[pd.DataFrame, pd.Series]:
"""Load the dataset."""
raise NotImplementedError()

def load(self) -> Tuple[pd.DataFrame, pd.Series]:
"""Load the dataset. If not exists, download it first."""
self.download()
return self._load()

    def download(self, repeats: int = 2, verbose: bool = False):
"""Download the dataset."""
name = self._metadata.name
        file_format = self._metadata.download_file_format  # avoid shadowing built-in format
        urls = [f"{url}/{name}.{file_format}" for url in self._metadata.url]
        if not self._save_dir.exists():
            self._save_dir.mkdir(parents=True, exist_ok=True)
            self._fallback_download(urls, repeats, verbose)

def delete(self):
"""Delete the dataset."""
shutil.rmtree(self._save_dir)

def _download_extract(self, url: str) -> None:
"""Download zip file to a temp directory and extract it."""
temp_dir = tempfile.mkdtemp() # create a temp directory
zip_file_save_to = Path(temp_dir, self._metadata.name)
urlretrieve(url, zip_file_save_to)
        try:
            # context manager ensures the zip file handle is closed
            with zipfile.ZipFile(zip_file_save_to, "r") as zip_file:
                zip_file.extractall(self._save_dir)
        except zipfile.BadZipFile:
            raise zipfile.BadZipFile(
                "Could not unzip dataset. Please make sure the URL is valid."
            )
        finally:
            shutil.rmtree(temp_dir)  # delete temp directory with all its contents

def _fallback_download(self, urls, repeats, verbose) -> None:
"""Download the dataset from a fallback URL."""
for url in urls:
for repeat in range(repeats):
if verbose:
                    print(  # noqa: T201
                        f"Downloading dataset {self._metadata.name} from {url} "
                        f"to {self._save_dir} (attempt {repeat + 1} of {repeats} total)."
                    )
try:
self._download_extract(url)
return # exit loop when download is successful
except Exception as error:
if verbose:
if repeat < repeats - 1:
print( # noqa: T201
"Download failed, continuing with next attempt. "
)
else:
print( # noqa: T201
"All attempts for mirror failed, "
"continuing with next mirror."
)
shutil.rmtree(self._save_dir) # delete directory
print(f"Exception occurred: {error}") # noqa: T201


class CSVDatasetLoader(BaseDataset):
"""Base class for .csv format datasets."""

    def __init__(self, metadata, save_dir, return_data_type):
        super().__init__(metadata, save_dir, return_data_type)

def _load(self):
"""Load the dataset."""
dataset = pd.read_csv(self._save_dir)
return self._preprocess(dataset)

@abstractmethod
def _preprocess(self, dataset):
"""Preprocess the dataset."""
pass
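For illustration, a minimal sketch of a concrete subclass of CSVDatasetLoader (class name, URL, and column layout are hypothetical); a CSV dataset only needs to supply its metadata and implement _preprocess:

from pathlib import Path

from sktime.datasets.base._base import CSVDatasetLoader
from sktime.datasets.base._metadata import ExternalDatasetMetadata


class ExampleCSVDataset(CSVDatasetLoader):
    """Hypothetical loader for a single-table CSV dataset."""

    def __init__(self, save_dir="data/example_csv"):
        metadata = ExternalDatasetMetadata(
            name="example_csv",
            task_type="regression",
            download_file_format="csv",
            url=["https://example.com/datasets"],  # placeholder mirror list
            citation="",
        )
        super().__init__(metadata, Path(save_dir), return_data_type="pd.DataFrame")

    def _preprocess(self, dataset):
        # split off the last column as the target variable
        return dataset.iloc[:, :-1], dataset.iloc[:, -1]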
28 changes: 28 additions & 0 deletions sktime/datasets/base/_metadata.py
@@ -0,0 +1,28 @@
"""Metadata for sktime datasets."""
from dataclasses import dataclass


@dataclass
class BaseDatasetMetadata:
"""Base class for dataset metadata."""

name: str
task_type: str # classification, regression, forecasting
download_file_format: str # zip, .ts, .arff, .csv
# is_univariate: bool
# dimensions: tuple


@dataclass
class ExternalDatasetMetadata(BaseDatasetMetadata):
"""Metadata for external datasets."""

    url: List[str]
citation: str


@dataclass
class ForecastingDatasetMetadata(ExternalDatasetMetadata):
"""Metadata for forecasting datasets."""

record_number: int
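For illustration, instantiating the most specific dataclass above (all field values are placeholders):

meta = ForecastingDatasetMetadata(
    name="example_dataset",
    task_type="forecasting",
    download_file_format="zip",
    url=["https://zenodo.org/record/1234567/files"],  # placeholder record
    citation="",
    record_number=1234567,
)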
1 change: 1 addition & 0 deletions sktime/datasets/classification/__init__.py
@@ -0,0 +1 @@
"""Classifaction Datasets."""
68 changes: 68 additions & 0 deletions sktime/datasets/classification/tsc.py
@@ -0,0 +1,68 @@
"""Time series classification datasets."""

from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd

from sktime.datasets._data_io import load_from_tsfile
from sktime.datasets.base._base import BaseDataset
from sktime.datasets.base._metadata import ExternalDatasetMetadata

DEFAULT_PATH = Path.cwd().parent / "data"
CITATION = ""


class TSCDatasetLoader(BaseDataset):
"""Classification dataset from UCR UEA time series archive."""

def __init__(
self,
name: str,
split: Optional[str] = None,
save_dir: Optional[str] = None,
return_data_type: str = "nested_univ",
):
metadata = ExternalDatasetMetadata(
name=name,
task_type="classification",
url=[
"https://timeseriesclassification.com/aeon-toolkit",
"https://github.com/sktime/sktime-datasets/raw/main/TSC",
],
download_file_format="zip",
citation=CITATION,
)
if save_dir is None:
save_dir = Path(DEFAULT_PATH, name)
else:
save_dir = Path(save_dir, name)
super().__init__(metadata, save_dir, return_data_type)
self._split = split.upper() if split is not None else split

def _load_from_file(self, split: str):
"""Load .ts format dataset."""
file_path = Path(self.save_dir, f"{self._metadata.name}_{split}.ts")
X, y = load_from_tsfile(
full_file_path_and_name=file_path, return_data_type=self._return_data_type
)
return X, y

def _preprocess(self, X_train, y_train, X_test, y_test):
"""Preprocess the dataset."""
X = pd.concat([X_train, X_test])
X = X.reset_index(drop=True)
y = np.concatenate([y_train, y_test])
return X, y

def _load(self):
"""Load the dataset into memory."""
        X_train, y_train = self._load_from_file("TRAIN")
if self._split == "TRAIN":
return X_train, y_train
        X_test, y_test = self._load_from_file("TEST")
if self._split == "TEST":
return X_test, y_test

return self._preprocess(X_train, y_train, X_test, y_test)
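A minimal usage sketch for the loader above ("GunPoint" is an example dataset name from the archive; the first call downloads, later calls read from disk):

from sktime.datasets.classification.tsc import TSCDatasetLoader

gunpoint = TSCDatasetLoader(name="GunPoint", split="train")
X_train, y_train = gunpoint.load()  # downloads the zip on first call
gunpoint.delete()  # remove the cached copy when no longer needed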
1 change: 1 addition & 0 deletions sktime/datasets/forecasting/__init__.py
@@ -0,0 +1 @@
"""Forecasting Datasets."""
52 changes: 52 additions & 0 deletions sktime/datasets/forecasting/tsf.py
@@ -0,0 +1,52 @@
"""Time series forecasting datasets."""
from pathlib import Path
from typing import Optional, Union

from sktime.datasets._data_io import load_tsf_to_dataframe
from sktime.datasets.base._base import BaseDataset
from sktime.datasets.base._metadata import ExternalDatasetMetadata
from sktime.datasets.tsf_dataset_names import tsf_all

DEFAULT_PATH = Path.cwd().parent / "data"
CITATION = ""


class TSFDatasetLoader(BaseDataset):
"""Forecasting datasets from Monash Time Series Forecasting Archive."""

def __init__(
self,
name: str,
save_dir: Optional[str] = None,
return_data_type: str = "default_tsf",
replace_missing_vals: Union[str, float] = "NaN",
):
metadata = ExternalDatasetMetadata(
name=name,
task_type="forecasting",
url=[f"https://zenodo.org/record/{tsf_all[name]}/files"],
download_file_format="zip",
citation=CITATION,
)
if save_dir is None:
save_dir = Path(DEFAULT_PATH, name)
else:
save_dir = Path(save_dir, name)
super().__init__(metadata, save_dir, return_data_type)
self._missing_val = replace_missing_vals

def _load(self):
"""Load the dataset into memory."""
file_path = Path(self.save_dir, f"{self._metadata.name}.tsf")
y, self._info = load_tsf_to_dataframe(
full_file_path_and_name=file_path,
replace_missing_vals_with=self._missing_val,
return_type=self._return_data_type,
)
return y

@property
def metadata(self):
"""Return the dataset metadata."""
        # TODO: combine with self._metadata class
return self._info
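A matching usage sketch for the forecasting loader ("m1_yearly_dataset" is an example key assumed to be present in tsf_all):

from sktime.datasets.forecasting.tsf import TSFDatasetLoader

loader = TSFDatasetLoader(name="m1_yearly_dataset")
y = loader.load()  # downloads and parses the .tsf file on first call
print(loader.metadata)  # metadata parsed from the file header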