Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] classification test scenario with three classes and pd-multiindex mtype #6374

Merged
merged 10 commits into from
May 29, 2024
33 changes: 32 additions & 1 deletion sktime/utils/_testing/scenarios_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
from sktime.base import BaseObject
from sktime.registry import scitype
from sktime.utils._testing.hierarchical import _make_hierarchical
from sktime.utils._testing.panel import _make_classification_y, _make_panel_X
from sktime.utils._testing.panel import (
_make_classification_y,
_make_panel,
_make_panel_X,
)
from sktime.utils._testing.scenarios import TestScenario

# random seed for generating data to keep scenarios exactly reproducible
Expand Down Expand Up @@ -117,6 +121,31 @@ def args(self):
default_arg_sequence = ["fit", "predict", "predict", "predict"]


class ClassifierFitPredictThreeClasses(ClassifierTestScenario):
"""Fit/predict with univariate panel X, pd-multiindex mtype, and three classes."""

_tags = {
"X_univariate": True,
"X_unequal_length": False,
"is_enabled": True,
"n_classes": 3,
fkiraly marked this conversation as resolved.
Show resolved Hide resolved
}

@property
def args(self):
y = _make_classification_y(n_instances=18, n_classes=3, random_state=RAND_SEED)
X = _make_panel(n_instances=18, n_timepoints=20, random_state=RAND_SEED, y=y)
X_test = _make_panel_X(n_instances=5, n_timepoints=20, random_state=RAND_SEED)

return {
"fit": {"y": y, "X": X},
"predict": {"X": X_test},
}

default_method_sequence = ["fit", "predict", "predict_proba", "decision_function"]
default_arg_sequence = ["fit", "predict", "predict", "predict"]


class ClassifierFitPredictNumpy(ClassifierTestScenario):
"""Fit/predict with univariate panel X, numpy3D mtype, and labels y."""

Expand Down Expand Up @@ -216,6 +245,7 @@ def args(self):

scenarios_classification = [
ClassifierFitPredict,
ClassifierFitPredictThreeClasses,
ClassifierFitPredictNumpy,
ClassifierFitPredictMultivariate,
ClassifierFitPredictUnequalLength,
Expand All @@ -224,6 +254,7 @@ def args(self):
# same scenarios used for early classification
scenarios_early_classification = [
ClassifierFitPredict,
ClassifierFitPredictThreeClasses,
ClassifierFitPredictNumpy,
ClassifierFitPredictMultivariate,
ClassifierFitPredictUnequalLength,
Expand Down