Skip to content

Commit

Permalink
Merge branch 'main' into reduce-proba
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Jan 13, 2024
2 parents 0546795 + ee03099 commit d90a679
Show file tree
Hide file tree
Showing 19 changed files with 770 additions and 119 deletions.
22 changes: 20 additions & 2 deletions docs/source/api_reference/classification.rst
Expand Up @@ -23,6 +23,17 @@ Composition
SklearnClassifierPipeline
MultiplexClassifier

Model selection and tuning
--------------------------

.. currentmodule:: sktime.classification.model_selection

.. autosummary::
:toctree: auto_generated/
:template: class.rst

TSCGridSearchCV

Ensembles
---------

Expand Down Expand Up @@ -182,8 +193,15 @@ Shapelet-based
MrSEQL
MrSQM

sklearn
-------

sklearn classifiers
-------------------

This section contains classifiers which are not time series classifiers but
simple tabular classifiers in ``sklearn`` compatible API.

They are used internally in time series classifiers, but can also be used
directly in a tabular setting.

.. currentmodule:: sktime.classification.sklearn

Expand Down
12 changes: 6 additions & 6 deletions sktime/base/_base.py
Expand Up @@ -622,12 +622,12 @@ def _get_fitted_params_default(self, obj=None):

# default retrieves all self attributes ending in "_"
# and returns them with keys that have the "_" removed
fitted_params = [attr for attr in dir(obj) if attr.endswith("_")]
fitted_params = [x for x in fitted_params if not x.startswith("_")]
fitted_params = [x for x in fitted_params if hasattr(obj, x)]
fitted_param_dict = {p[:-1]: getattr(obj, p) for p in fitted_params}

return fitted_param_dict
fitted_params = {
attr[:-1]: getattr(obj, attr)
for attr in dir(obj)
if attr.endswith("_") and not attr.startswith("_") and hasattr(obj, attr)
}
return fitted_params

def _get_fitted_params(self):
"""Get fitted parameters.
Expand Down
6 changes: 6 additions & 0 deletions sktime/classification/model_selection/__init__.py
@@ -0,0 +1,6 @@
"""Tuning of time series classifiers."""
# copyright: sktime developers, BSD-3-Clause License (see LICENSE file)

from sktime.classification.model_selection._tune import TSCGridSearchCV

__all__ = ["TSCGridSearchCV"]

0 comments on commit d90a679

Please sign in to comment.