Skip to content

Commit

Permalink
[ENH] predict_proba capability tag for classifiers (#4012)
Browse files Browse the repository at this point in the history
This PR adds a tag `capability:predict_proba` to classifiers, which
signifies whether the `predict_proba` implementation is different from
the default `predict_proba`.

Also fixes some dependent tag logic in the classifier pipelines.

Depends on #4014 for diagnostics.
  • Loading branch information
fkiraly committed Apr 22, 2023
1 parent f5919fe commit 20007f9
Show file tree
Hide file tree
Showing 32 changed files with 48 additions and 5 deletions.
1 change: 1 addition & 0 deletions sktime/classification/base.py
Expand Up @@ -63,6 +63,7 @@ class BaseClassifier(BaseEstimator, ABC):
"capability:train_estimate": False,
"capability:contractable": False,
"capability:multithreading": False,
"capability:predict_proba": False,
"python_version": None, # PEP 440 python version specifier to limit versions
"requires_cython": False, # whether C compiler is required in env, e.g., gcc
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/compose/_column_ensemble.py
Expand Up @@ -22,6 +22,7 @@ class BaseColumnEnsembleClassifier(_HeterogenousMetaEstimator, BaseClassifier):

_tags = {
"capability:multivariate": True,
"capability:predict_proba": True,
"X_inner_mtype": ["nested_univ", "pd-multiindex"],
}

Expand Down
1 change: 1 addition & 0 deletions sktime/classification/compose/_ensemble.py
Expand Up @@ -629,6 +629,7 @@ class WeightedEnsembleClassifier(_HeterogenousMetaEstimator, BaseClassifier):
_tags = {
"capability:multivariate": True,
"capability:missing_values": True,
"capability:predict_proba": True,
"X_inner_mtype": [
"pd-multiindex",
"df-list",
Expand Down
14 changes: 9 additions & 5 deletions sktime/classification/compose/_pipeline.py
Expand Up @@ -99,6 +99,7 @@ class ClassifierPipeline(_HeterogenousMetaEstimator, BaseClassifier):
"capability:train_estimate": False,
"capability:contractable": False,
"capability:multithreading": False,
"capability:predict_proba": True,
}

# no default tag values - these are set dynamically below
Expand Down Expand Up @@ -133,6 +134,8 @@ def __init__(self, classifier, transformers):
unequal = unequal or self.transformers_.get_tag(
"capability:unequal_length:removes", False
)
# predict_proba is same as that of classifier
predict_proba = classifier.get_tag("capability:predict_proba")
# last three tags are always False, since not supported by transformers
tags_to_set = {
"capability:multivariate": multivariate,
Expand All @@ -141,6 +144,7 @@ def __init__(self, classifier, transformers):
"capability:contractable": False,
"capability:train_estimate": False,
"capability:multithreading": False,
"capability:predict_proba": predict_proba,
}
self.set_tags(**tags_to_set)

Expand Down Expand Up @@ -403,12 +407,13 @@ class SklearnClassifierPipeline(_HeterogenousMetaEstimator, BaseClassifier):

_tags = {
"X_inner_mtype": "pd-multiindex", # which type do _fit/_predict accept
"capability:multivariate": False,
"capability:unequal_length": False,
"capability:multivariate": True,
"capability:unequal_length": True,
"capability:missing_values": True,
"capability:train_estimate": False,
"capability:contractable": False,
"capability:multithreading": False,
"capability:predict_proba": True,
}

# no default tag values - these are set dynamically below
Expand All @@ -424,9 +429,8 @@ def __init__(self, classifier, transformers):

super(SklearnClassifierPipeline, self).__init__()

# can handle multivariate iff all transformers can
# sklearn transformers always support multivariate
multivariate = not self.transformers_.get_tag("univariate-only", True)
# all sktime and sklearn transformers always support multivariate
multivariate = True
# can handle missing values iff transformer chain removes missing data
# sklearn classifiers might be able to handle missing data (but no tag there)
# so better set the tag liberally
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/dictionary_based/_boss.py
Expand Up @@ -129,6 +129,7 @@ class BOSSEnsemble(BaseClassifier):
"capability:train_estimate": True,
"capability:multithreading": True,
"classifier_type": "dictionary",
"capability:predict_proba": True,
"python_dependencies": "numba",
}

Expand Down
1 change: 1 addition & 0 deletions sktime/classification/dictionary_based/_cboss.py
Expand Up @@ -135,6 +135,7 @@ class ContractableBOSS(BaseClassifier):
"capability:contractable": True,
"capability:multithreading": True,
"classifier_type": "dictionary",
"capability:predict_proba": True,
"python_dependencies": "numba",
}

Expand Down
1 change: 1 addition & 0 deletions sktime/classification/dictionary_based/_muse.py
Expand Up @@ -122,6 +122,7 @@ class MUSE(BaseClassifier):
_tags = {
"capability:multivariate": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"X_inner_mtype": "numpy3D", # which mtypes do _fit/_predict support for X?
"classifier_type": "dictionary",
"python_dependencies": "numba",
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/dictionary_based/_tde.py
Expand Up @@ -151,6 +151,7 @@ class TemporalDictionaryEnsemble(BaseClassifier):
"capability:train_estimate": True,
"capability:contractable": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "dictionary",
"python_dependencies": "numba",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/dictionary_based/_weasel.py
Expand Up @@ -122,6 +122,7 @@ class WEASEL(BaseClassifier):

_tags = {
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "dictionary",
"python_dependencies": "numba",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/distance_based/_elastic_ensemble.py
Expand Up @@ -91,6 +91,7 @@ class ElasticEnsemble(BaseClassifier):

_tags = {
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "distance",
}

Expand Down
2 changes: 2 additions & 0 deletions sktime/classification/distance_based/_proximity_forest.py
Expand Up @@ -1041,6 +1041,7 @@ class ProximityTree(BaseClassifier):

_tags = {
"capability:multithreading": True,
"capability:predict_proba": True,
"X_inner_mtype": "nested_univ",
"python_dependencies": "numba",
}
Expand Down Expand Up @@ -1283,6 +1284,7 @@ class ProximityForest(BaseClassifier):
_tags = {
"X_inner_mtype": "nested_univ",
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "distance",
"python_dependencies": "numba",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/distance_based/_shape_dtw.py
Expand Up @@ -113,6 +113,7 @@ class ShapeDTW(BaseClassifier):
"""

_tags = {
"capability:predict_proba": True,
"classifier_type": "distance",
}

Expand Down
Expand Up @@ -121,6 +121,7 @@ class KNeighborsTimeSeriesClassifier(BaseClassifier):
"capability:multivariate": True,
"capability:unequal_length": True,
"capability:missing_values": True,
"capability:predict_proba": True,
"X_inner_mtype": ["pd-multiindex", "numpy3D"],
"classifier_type": "distance",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/dummy/_dummy.py
Expand Up @@ -65,6 +65,7 @@ class prior probabilities.
"capability:missing_values": True,
"capability:unequal_length": True,
"capability:multivariate": True,
"capability:predict_proba": True,
}

VALID_STRATEGIES = ["most_frequent", "prior", "stratified", "uniform", "constant"]
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/feature_based/_catch22_classifier.py
Expand Up @@ -84,6 +84,7 @@ class Catch22Classifier(_DelegatedClassifier):
_tags = {
"capability:multivariate": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "feature",
"python_dependencies": "numba",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/feature_based/_fresh_prince.py
Expand Up @@ -64,6 +64,7 @@ class FreshPRINCE(BaseClassifier):
"capability:multivariate": True,
"capability:multithreading": True,
"capability:train_estimate": True,
"capability:predict_proba": True,
"classifier_type": "feature",
"python_version": "<3.10",
"python_dependencies": "tsfresh",
Expand Down
Expand Up @@ -67,6 +67,7 @@ class MatrixProfileClassifier(BaseClassifier):

_tags = {
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "distance",
}

Expand Down
Expand Up @@ -55,6 +55,7 @@ class RandomIntervalClassifier(BaseClassifier):
_tags = {
"capability:multivariate": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "interval",
}

Expand Down
Expand Up @@ -92,6 +92,7 @@ class SignatureClassifier(BaseClassifier):

_tags = {
"capability:multivariate": True,
"capability:predict_proba": True,
"classifier_type": "feature",
"python_dependencies": "esig",
"python_version": "<3.10",
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/feature_based/_summary_classifier.py
Expand Up @@ -67,6 +67,7 @@ class SummaryClassifier(BaseClassifier):
_tags = {
"capability:multivariate": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "feature",
}

Expand Down
1 change: 1 addition & 0 deletions sktime/classification/feature_based/_tsfresh_classifier.py
Expand Up @@ -69,6 +69,7 @@ class TSFreshClassifier(BaseClassifier):
_tags = {
"capability:multivariate": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "feature",
"python_version": "<3.10",
"python_dependencies": "tsfresh",
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/hybrid/_hivecote_v1.py
Expand Up @@ -96,6 +96,7 @@ class HIVECOTEV1(BaseClassifier):

_tags = {
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "hybrid",
}

Expand Down
1 change: 1 addition & 0 deletions sktime/classification/hybrid/_hivecote_v2.py
Expand Up @@ -94,6 +94,7 @@ class HIVECOTEV2(BaseClassifier):
"capability:multivariate": True,
"capability:contractable": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "hybrid",
}

Expand Down
1 change: 1 addition & 0 deletions sktime/classification/interval_based/_cif.py
Expand Up @@ -120,6 +120,7 @@ class CanonicalIntervalForest(BaseClassifier):
_tags = {
"capability:multivariate": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "interval",
"python_dependencies": "numba",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/interval_based/_drcif.py
Expand Up @@ -144,6 +144,7 @@ class DrCIF(BaseClassifier):
"capability:train_estimate": True,
"capability:contractable": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "interval",
"python_dependencies": "numba",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/interval_based/_rise.py
Expand Up @@ -179,6 +179,7 @@ class RandomIntervalSpectralEnsemble(BaseClassifier):

_tags = {
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "interval",
"python_dependencies": "numba",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/interval_based/_stsf.py
Expand Up @@ -88,6 +88,7 @@ class SupervisedTimeSeriesForest(BaseClassifier):

_tags = {
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "interval",
}

Expand Down
2 changes: 2 additions & 0 deletions sktime/classification/interval_based/_tsf.py
Expand Up @@ -85,6 +85,8 @@ class TimeSeriesForestClassifier(

_base_estimator = DecisionTreeClassifier(criterion="entropy")

_tags = {"capability:predict_proba": True}

def __init__(
self,
min_interval=3,
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/kernel_based/_arsenal.py
Expand Up @@ -116,6 +116,7 @@ class Arsenal(BaseClassifier):
"capability:train_estimate": True,
"capability:contractable": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "kernel",
"python_dependencies": "numba",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/kernel_based/_svc.py
Expand Up @@ -107,6 +107,7 @@ class TimeSeriesSVC(BaseClassifier):
"capability:multivariate": True,
"capability:unequal_length": True,
"capability:missing_values": True,
"capability:predict_proba": True,
"X_inner_mtype": ["pd-multiindex", "numpy3D"],
"classifier_type": "kernel",
}
Expand Down
1 change: 1 addition & 0 deletions sktime/classification/shapelet_based/_stc.py
Expand Up @@ -132,6 +132,7 @@ class ShapeletTransformClassifier(BaseClassifier):
"capability:train_estimate": True,
"capability:contractable": True,
"capability:multithreading": True,
"capability:predict_proba": True,
"classifier_type": "shapelet",
"python_dependencies": "numba",
}
Expand Down
7 changes: 7 additions & 0 deletions sktime/registry/_tags.py
Expand Up @@ -199,6 +199,13 @@
"bool",
"does the forecaster implement predict_variance?",
),
(
"capability:predict_proba",
"classifier",
"bool",
"does the classifier implement a non-default predict_proba, "
"i.e., not just 0/1 probabilities obtained from predict?",
),
(
"capability:multivariate",
[
Expand Down

0 comments on commit 20007f9

Please sign in to comment.