[DOC] Include section on unequal length data in classification notebook #3809

Merged
merged 27 commits into from Nov 23, 2022
27 commits
b169fc0
catch22 and drcif bugs
MatthewMiddlehurst Jan 12, 2022
46b73cf
test update
MatthewMiddlehurst Jan 12, 2022
cbfa4a4
accidental change revert
MatthewMiddlehurst Jan 12, 2022
4a68bcb
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Feb 7, 2022
eccafaf
bugfixes
MatthewMiddlehurst Feb 7, 2022
4fdeec1
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Feb 7, 2022
bd9e486
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Mar 4, 2022
144584f
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Mar 18, 2022
465c5cc
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Jul 7, 2022
71e42d2
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Jul 11, 2022
1a6a807
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Jul 29, 2022
003cb2d
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Aug 6, 2022
2bd870b
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Aug 21, 2022
4f3538a
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Sep 14, 2022
3cb5110
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Oct 3, 2022
487c46e
classification notebook cleanup
MatthewMiddlehurst Oct 3, 2022
1344467
classification notebook cleanup 2
MatthewMiddlehurst Oct 3, 2022
acc5c25
fixes
MatthewMiddlehurst Oct 4, 2022
91923d6
typo
MatthewMiddlehurst Oct 4, 2022
3d26dbd
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Oct 24, 2022
e286d7b
Merge branch 'classifier_notebook' of https://github.com/alan-turing-…
MatthewMiddlehurst Oct 24, 2022
c0e91f7
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Nov 17, 2022
e425825
classification notebook
MatthewMiddlehurst Nov 17, 2022
3038404
unequal length notebook example
MatthewMiddlehurst Nov 17, 2022
e72a090
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Nov 18, 2022
80b336d
Merge branch 'main' of https://github.com/alan-turing-institute/sktim…
MatthewMiddlehurst Nov 21, 2022
65862f6
tag
MatthewMiddlehurst Nov 21, 2022
96 changes: 82 additions & 14 deletions examples/02_classification.ipynb
@@ -46,7 +46,7 @@
},
"outputs": [],
"source": [
"# Imports used in this notebook\n",
"# Plotting and data loading imports used in this notebook\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from sktime.datasets import (\n",
@@ -165,7 +165,7 @@
"collapsed": false
},
"source": [
"Some data sets have unequal length series. Two data sets with this characteristic are shipped with sktime: PLAID (univariate) and JapaneseVowels (multivariate). We cannot store unequal length series in numpy arrays. Instead, we use nested pandas data frames, where each cell is a pandas Series. This is the default return type for all single problem loaders."
"Some data sets have unequal length series. Two data sets with this characteristic are shipped with sktime: PLAID (univariate) and JapaneseVowels (multivariate). We cannot store unequal length series in `numpy` arrays. Instead, we use a nested `pandas` `DataFrame`, where each cell is a `pandas` `Series`. This is the default return type for all single problem loaders."
]
},
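As a quick aside, here is a minimal sketch (not part of this PR) of what the nested return type looks like in practice; it assumes only the `load_plaid` loader used in the cell below.

```python
# Inspect the nested pandas DataFrame returned by a single problem loader.
# Each cell of the DataFrame holds one (possibly unequal length) pandas Series.
from sktime.datasets import load_plaid

plaid_X, plaid_y = load_plaid()

print(type(plaid_X))               # pandas.core.frame.DataFrame
first_series = plaid_X.iloc[0, 0]  # one instance's series, as a pd.Series
print(type(first_series), len(first_series))
print(len(plaid_X.iloc[1, 0]))     # a different instance, with a different length
```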
{
@@ -197,6 +197,8 @@
"outputs": [],
"source": [
"plaid_X, plaid_y = load_plaid()\n",
"plaid_train_X, plaid_train_y = load_plaid(split=\"train\")\n",
"plaid_test_X, plaid_test_y = load_plaid(split=\"test\")\n",
"print(type(plaid_X))\n",
"\n",
"plt.title(\" Four instances of PLAID dataset\")\n",
@@ -251,6 +253,7 @@
"arrow_test_X_2d, arrow_test_y_2d = load_arrow_head(split=\"test\", return_type=\"numpy2d\")\n",
"classifier.fit(arrow_train_X_2d, arrow_train_y_2d)\n",
"y_pred = classifier.predict(arrow_test_X_2d)\n",
"\n",
"accuracy_score(arrow_test_y_2d, y_pred)"
]
},
@@ -276,6 +279,7 @@
"rocket = RocketClassifier(num_kernels=2000)\n",
"rocket.fit(arrow_train_X, arrow_train_y)\n",
"y_pred = rocket.predict(arrow_test_X)\n",
"\n",
"accuracy_score(arrow_test_y, y_pred)"
]
},
@@ -301,6 +305,7 @@
"hc2 = HIVECOTEV2(time_limit_in_minutes=0.2)\n",
"hc2.fit(arrow_train_X, arrow_train_y)\n",
"y_pred = hc2.predict(arrow_test_X)\n",
"\n",
"accuracy_score(arrow_test_y, y_pred)"
]
},
@@ -338,14 +343,15 @@
"\n",
"pipe.fit(arrow_train_X, arrow_train_y)\n",
"y_pred = pipe.predict(arrow_test_X)\n",
"\n",
"accuracy_score(arrow_test_y, y_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Other transformations such as the TSFreshFeatureExtractor for the tsfresh feature set, SummaryTransformer for basic summary statistics, and the RandomShapeletTransform for the shapelet transform can also be used in pipelines following the same creation and fit/predict structure.\n",
"Other transformations such as the `TSFreshFeatureExtractor` for the [tsfresh](https://www.sciencedirect.com/science/article/pii/S0925231218304843) feature set, `SummaryTransformer` for basic summary statistics, and the `RandomShapeletTransform` for the [shapelet transform](https://link.springer.com/chapter/10.1007/978-3-662-55608-5_2) can also be used in pipelines following the same creation and fit/predict structure.\n",
"\n",
"In the following example, we pipeline an `sktime` transformer with an `sktime` time series classifier using the `*` dunder operator, which is a shorthand for `make_pipeline`. Estimators on the right are pipelined after estimators on the left of the operator:"
]
@@ -365,6 +371,7 @@
"\n",
"pipe_sktime.fit(arrow_train_X, arrow_train_y)\n",
"y_pred = pipe_sktime.predict(arrow_test_X)\n",
"\n",
"accuracy_score(arrow_test_y, y_pred)"
]
},
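For reference, a hedged sketch (not part of this diff) of how such a pipeline is constructed with the `*` shorthand; the specific transformer and classifier chosen here are illustrative only, not the ones used in the collapsed cell above.

```python
# Build a classifier pipeline with the * dunder, a shorthand for make_pipeline.
# The transformer on the left runs before the classifier on the right.
from sktime.classification.interval_based import DrCIF
from sktime.transformations.series.exponent import ExponentTransformer

pipe_sktime = ExponentTransformer() * DrCIF(n_estimators=10, n_intervals=5)
```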
@@ -449,6 +456,7 @@
"\n",
"parameter_tuning_method.fit(arrow_train_X, arrow_train_y)\n",
"y_pred = parameter_tuning_method.predict(arrow_test_X)\n",
"\n",
"accuracy_score(arrow_test_y, y_pred)"
]
},
@@ -476,24 +484,21 @@
"\n",
"calibrated_drcif.fit(arrow_train_X, arrow_train_y)\n",
"y_pred = calibrated_drcif.predict(arrow_test_X)\n",
"\n",
"accuracy_score(arrow_test_y, y_pred)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## Multivariate Classification\n",
"\n",
"Many classifiers, including ROCKET and HC2, are configured to work with multivariate input. For example:"
]
},
@@ -510,6 +515,7 @@
"rocket = RocketClassifier(num_kernels=2000)\n",
"rocket.fit(motions_train_X, motions_train_y)\n",
"y_pred = rocket.predict(motions_test_X)\n",
"\n",
"accuracy_score(motions_test_y, y_pred)"
]
},
@@ -520,9 +526,10 @@
"source": [
"from sktime.classification.hybrid import HIVECOTEV2\n",
"\n",
"HIVECOTEV2(time_limit_in_minutes=0.25)\n",
"HIVECOTEV2(time_limit_in_minutes=0.2)\n",
"hc2.fit(motions_train_X, motions_train_y)\n",
"y_pred = hc2.predict(motions_test_X)\n",
"\n",
"accuracy_score(motions_test_y, y_pred)"
],
"metadata": {
@@ -556,7 +563,9 @@
"\n",
"clf = ColumnConcatenator() * DrCIF(n_estimators=10, n_intervals=5)\n",
"clf.fit(motions_train_X, motions_train_y)\n",
"clf.score(motions_test_X, motions_test_y)"
"y_pred = clf.predict(motions_test_X)\n",
"\n",
"accuracy_score(motions_test_y, y_pred)"
]
},
{
@@ -580,17 +589,76 @@
"from sktime.classification.interval_based import DrCIF\n",
"from sktime.classification.kernel_based import RocketClassifier\n",
"\n",
"clf = ColumnEnsembleClassifier(\n",
"col = ColumnEnsembleClassifier(\n",
" estimators=[\n",
" (\"DrCIF0\", DrCIF(n_estimators=10, n_intervals=5), [0]),\n",
" (\"ROCKET3\", RocketClassifier(num_kernels=1000), [3]),\n",
" ]\n",
")\n",
"\n",
"clf.fit(motions_train_X, motions_train_y)\n",
"clf.score(motions_test_X, motions_test_y)"
"col.fit(motions_train_X, motions_train_y)\n",
"y_pred = col.predict(motions_test_X)\n",
"\n",
"accuracy_score(motions_test_y, y_pred)"
]
},
{
"cell_type": "markdown",
"source": [
"## Classification with Unequal Length Series\n",
"\n",
"A common trait in time series data is absence of a uniform series length, as seen in the PLAID and JapaneseVowels datasets introduced previously. None of the `numpy` data formats support ragged arrays, as such one of the `pandas` `DataFrame` formats must be used for unequal length data.\n",
"\n",
"At the time of writing the number of classifiers which natively support unequal length series is limited. The following outputs the current classifiers which support unequal length data."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from sktime.registry import all_estimators\n",
"\n",
"# search for all classifiers which can handle unequal length data. This may give some\n",
"# UserWarnings if soft dependencies are not installed.\n",
"all_estimators(\n",
" filter_tags={\"capability:unequal_length\": True}, estimator_types=\"classifier\"\n",
")"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "markdown",
"source": [
"Certain `sktime` transformers such as the `PaddingTransformer` and `TruncationTransformer` can be used in a pipeline to process unequal length data for use in a wider range of classification algorithms. Transformers which equalise the length of seres can be found using the `\"capability:unequal_length:removes\"` tag.\n"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from sktime.classification.feature_based import RandomIntervalClassifier\n",
"from sktime.transformations.panel.padder import PaddingTransformer\n",
"\n",
"padded_clf = PaddingTransformer() * RandomIntervalClassifier(n_intervals=5)\n",
"padded_clf.fit(plaid_train_X, plaid_test_y)\n",
"y_pred = padded_clf.predict(plaid_test_X)\n",
"\n",
"accuracy_score(plaid_test_y, y_pred)"
],
"metadata": {
"collapsed": false
}
},
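A small sketch (not part of this PR) of using the tag mentioned above to list the transformers that equalise series length; it assumes the same `all_estimators` registry utility used earlier in the notebook.

```python
# List transformers that remove the unequal length property, e.g. by padding
# or truncating, so their output can feed a wider range of classifiers.
from sktime.registry import all_estimators

all_estimators(
    filter_tags={"capability:unequal_length:removes": True},
    estimator_types="transformer",
)
```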
{
"cell_type": "markdown",
"metadata": {
Expand All @@ -604,7 +672,7 @@
"One nearest neighbour (1-NN) classification with Dynamic Time Warping (DTW) is one of the oldest TSC approaches, and is commonly used as a performance benchmark.\n",
"\n",
"#### RocketClassifier\n",
"The RocketClassifier is based on a pipeline combination of the ROCKET transformation (transformations.panel.rocket) and the sklearn RidgeClassifierCV classifier. The RocketClassifier is configurable to use variants MiniRocket and MultiRocket. ROCKET is based on generating random convolutions. A large number are generated then the classifier performs a feature selection.\n",
"The RocketClassifier is based on a pipeline combination of the ROCKET transformation (transformations.panel.rocket) and the sklearn RidgeClassifierCV classifier. The RocketClassifier is configurable to use variants MiniRocket and MultiRocket. ROCKET is based on generating random convolutional kernels. A large number are generated, then a linear classifier is built on the output.\n",
"\n",
"[1] Dempster, Angus, François Petitjean, and Geoffrey I. Webb. \"Rocket: exceptionally fast and accurate time series classification using random convolutional kernels.\" Data Mining and Knowledge Discovery (2020)\n",
"[arXiv version](https://arxiv.org/abs/1910.13051)\n",
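Since the 1-NN DTW benchmark described above has no accompanying cell in this diff, here is a hedged sketch of what it might look like; `KNeighborsTimeSeriesClassifier` and its `distance` parameter are assumed from sktime's distance-based module, and the ArrowHead split mirrors the earlier cells.

```python
# 1-NN with Dynamic Time Warping on the ArrowHead data used earlier.
from sklearn.metrics import accuracy_score

from sktime.classification.distance_based import KNeighborsTimeSeriesClassifier
from sktime.datasets import load_arrow_head

arrow_train_X, arrow_train_y = load_arrow_head(split="train")
arrow_test_X, arrow_test_y = load_arrow_head(split="test")

knn_dtw = KNeighborsTimeSeriesClassifier(n_neighbors=1, distance="dtw")
knn_dtw.fit(arrow_train_X, arrow_train_y)
y_pred = knn_dtw.predict(arrow_test_X)

accuracy_score(arrow_test_y, y_pred)
```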
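Similarly, a sketch of switching the RocketClassifier to its MiniRocket variant, as mentioned in the description above; the `rocket_transform` parameter name is an assumption and should be checked against the sktime API.

```python
# Configure RocketClassifier to use the MiniRocket transform instead of ROCKET.
from sktime.classification.kernel_based import RocketClassifier

mini_rocket = RocketClassifier(num_kernels=2000, rocket_transform="minirocket")
mini_rocket.fit(arrow_train_X, arrow_train_y)  # reuses the ArrowHead split above
y_pred = mini_rocket.predict(arrow_test_X)

accuracy_score(arrow_test_y, y_pred)
```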