diff --git a/.gitignore b/.gitignore index 346afd4..4e5f94b 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,4 @@ notebooks/20_catch_errors.ipynb # Custom notebooks/debug* +notebooks/my_hypno.csv \ No newline at end of file diff --git a/docs/changelog.rst b/docs/changelog.rst index c89a363..d3b230b 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -29,18 +29,23 @@ which comes with several pre-built functions (aka methods) and attributes. See f hyp.duration # Total duration of the hypnogram, in minutes hyp.sampling_frequency # Sampling frequency of the hypnogram hyp.mapping # Mapping from strings to integers + hyp.proba # Probability of each sleep stage, if specified # Below are some class methods hyp.sleep_statistics() # Calculate the sleep statistics hyp.plot_hypnogram() # Plot the hypnogram hyp.upsample_to_data() # Upsample to data -Please see the documentation of :py:class:`yasa.Hypnogram` for more details. +This brings along critical changes to several YASA functions, for example: -.. important:: - The adoption of object-oriented :py:class:`yasa.Hypnogram` usage brings along critical changes to several YASA function, for example: +* :py:meth:`yasa.SleepStaging.predict` now returns a :py:class:`yasa.Hypnogram` instead of a :py:class:`numpy.ndarray`. The probability of each sleep stage for each epoch can now be accessed with :py:attr:`yasa.Hypnogram.proba`. +* :py:func:`yasa.simulate_hypnogram` now returns a :py:class:`yasa.Hypnogram` instead of a :py:class:`numpy.ndarray`. +* The suggested approach to plotting hypnograms is through the :py:meth:`yasa.Hypnogram.plot_hypnogram` method. The old function :py:func:`yasa.plot_hypnogram` still exists, but now *requires* a :py:class:`yasa.Hypnogram` instance as input. + +**Other improvements** - * :py:func:`yasa.simulate_hypnogram` now returns a :py:class:`yasa.Hypnogram` instead of a :py:class:`numpy.ndarray`. 
- * The suggested approach to plotting hypnograms is through the :py:meth:`yasa.Hypnogram.plot_hypnogram` method. The old function :py:func:`yasa.plot_hypnogram` still exists, but now *requires* a :py:class:`yasa.Hypnogram` instance as input. +* Added helpful string representation (__repr__) to :py:class:`yasa.SleepStaging`. +* :py:func:`yasa.simulate_hypnogram` now returns a :py:class:`yasa.Hypnogram` instead of a :py:class:`numpy.ndarray`. +* The suggested approach to plotting hypnograms is through the :py:meth:`yasa.Hypnogram.plot_hypnogram` method. The old function :py:func:`yasa.plot_hypnogram` still exists, but now *requires* a :py:class:`yasa.Hypnogram` instance as input. ---------------------------------------------------------------------------------------- diff --git a/notebooks/14_automatic_sleep_staging.ipynb b/notebooks/14_automatic_sleep_staging.ipynb index 3aad21e..706c138 100644 --- a/notebooks/14_automatic_sleep_staging.ipynb +++ b/notebooks/14_automatic_sleep_staging.ipynb @@ -49,24 +49,29 @@ { "data": { "text/html": [ - "\n", - "\n", "\n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", " \n", - "\n", + " \n", + " \n", + " \n", " \n", " \n", - "\n", + " \n", + " \n", + " \n", " \n", " \n", " \n", + " \n", " \n", + " \n", " \n", " \n", " \n", @@ -83,28 +88,36 @@ " \n", " \n", " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - "\n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", + " \n", " \n", " \n", - " \n", + " \n", " \n", - "
Measurement dateJanuary 15, 2016 14:01:00 GMT
ExperimenterUnknownUnknown
ParticipantUnknownUnknown
Digitized points15 points
Good channels
ECG channelsNot available
Sampling frequency100.00 Hz
Highpass0.00 Hz
Lowpass50.00 Hz
Filenamessub-02_mne_raw.fif
Duration00:48:59 (HH:MM:SS)00:48:60 (HH:MM:SS)
\n" + "" ], "text/plain": [ "" @@ -149,8 +162,66 @@ ], "source": [ "# Let's now load the human-scored hypnogram, where each value represents a 30-sec epoch.\n", - "hypno = np.loadtxt('sub-02_hypno_30s.txt', dtype=str)\n", - "hypno" + "hyp = np.loadtxt('sub-02_hypno_30s.txt', dtype=str)\n", + "hyp" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Epoch\n", + "0 WAKE\n", + "1 WAKE\n", + "2 WAKE\n", + "3 WAKE\n", + "4 WAKE\n", + " ... \n", + "93 WAKE\n", + "94 WAKE\n", + "95 WAKE\n", + "96 WAKE\n", + "97 WAKE\n", + "Name: Stage, Length: 98, dtype: category\n", + "Categories (7, object): ['WAKE', 'N1', 'N2', 'N3', 'REM', 'ART', 'UNS']" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Convert it to a Hypnogram instance, which is the preferred way to manipulate hypnograms since v0.7\n", + "hyp = yasa.Hypnogram(hyp, freq=\"30s\")\n", + "# The hypnogram values can be obtained with\n", + "hyp.hypno" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Let's plot it\n", + "fig, ax = plt.subplots(1, 1, figsize=(7, 3), constrained_layout=True, dpi=80)\n", + "ax = hyp.plot_hypnogram(fill_color=\"gainsboro\", ax=ax)" ] }, { @@ -164,9 +235,20 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# We first need to specify the channel names and, optionally, the age and sex of the participant\n", "# - \"raw\" is the name of the variable containing the polysomnography data loaded with MNE.\n", @@ -174,38 +256,35 @@ "# - \"eog_name\" is the name of the EOG channel (e.g. LOC-M1). This is optional.\n", "# - \"eog_name\" is the name of the EOG channel (e.g. EMG1-EMG3). This is optional.\n", "# - \"metadata\" is a dictionary containing the age and sex of the participant. This is optional.\n", - "sls = yasa.SleepStaging(raw, eeg_name=\"C4\", eog_name=\"EOG1\", emg_name=\"EMG1\", metadata=dict(age=21, male=False))" + "sls = yasa.SleepStaging(raw, eeg_name=\"C4\", eog_name=\"EOG1\", emg_name=\"EMG1\", metadata=dict(age=21, male=False))\n", + "sls" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Users/raphael/.pyenv/versions/3.8.3/lib/python3.8/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 0.24.2 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. 
For more info please refer to:\n", - "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n", + "/Users/raphael/.pyenv/versions/3.9.6/lib/python3.9/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 0.24.2 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n", + "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ - "array(['W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',\n", - " 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',\n", - " 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'N2', 'N2', 'N2', 'N2',\n", - " 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2',\n", - " 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2',\n", - " 'N2', 'N2', 'N2', 'N3', 'N3', 'N3', 'N3', 'N2', 'N3', 'N3', 'N3',\n", - " 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3',\n", - " 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'W', 'W', 'W', 'W',\n", - " 'W', 'W', 'W', 'W'], dtype=object)" + "\n", + " - Use `.hypno` to get the string values as a pandas.Series\n", + " - Use `.as_int()` to get the integer values as a pandas.Series\n", + " - Use `.plot_hypnogram()` to plot the hypnogram\n", + "See the online documentation for more details." ] }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -218,21 +297,122 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Epoch\n", + "0 WAKE\n", + "1 WAKE\n", + "2 WAKE\n", + "3 WAKE\n", + "4 WAKE\n", + " ... 
\n", + "93 WAKE\n", + "94 WAKE\n", + "95 WAKE\n", + "96 WAKE\n", + "97 WAKE\n", + "Name: Stage, Length: 98, dtype: category\n", + "Categories (7, object): ['WAKE', 'N1', 'N2', 'N3', 'REM', 'ART', 'UNS']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pred.hypno" + ] + }, + { + "cell_type": "code", + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The overall agreement is 0.837\n" + "The overall agreement is 83.67%\n" ] } ], "source": [ "# What is the accuracy of the prediction, compared to the human scoring\n", - "accuracy = (hypno == y_pred).sum() / y_pred.size\n", - "print(\"The overall agreement is %.3f\" % accuracy)" + "accuracy = 100 * (hyp.hypno == y_pred.hypno).mean()\n", + "print(f\"The overall agreement is {accuracy:.2f}%\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Plot and sleep statistics**" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot the predicted hypnogram\n", + "fig, ax = plt.subplots(1, 1, figsize=(7, 3), constrained_layout=True, dpi=80)\n", + "ax = y_pred.plot_hypnogram(fill_color=\"gainsboro\", ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'TIB': 49.0,\n", + " 'SPT': 28.0,\n", + " 'WASO': 0.0,\n", + " 'TST': 28.0,\n", + " 'SE': 57.1429,\n", + " 'SME': 100.0,\n", + " 'SFI': 1.0714,\n", + " 'SOL': 17.0,\n", + " 'SOL_5min': 17.0,\n", + " 'Lat_REM': nan,\n", + " 'WAKE': 21.0,\n", + " 'N1': 0.0,\n", + " 'N2': 15.0,\n", + " 'N3': 13.0,\n", + " 'REM': 0.0,\n", + " '%N1': 0.0,\n", + " '%N2': 53.5714,\n", + " '%N3': 46.4286,\n", + " '%REM': 0.0}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Calculate the summary sleep statistics of the predicted hypnogram\n", + "y_pred.sleep_statistics()" ] }, { @@ -244,9 +424,17 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 12, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/raphael/GitHub/yasa/yasa/staging.py:475: FutureWarning: The `predict_proba` function is deprecated and will be removed in v0.8. 
The predicted probabilities can now be accessed with `yasa.Hypnogram.proba` instead, e.g `SleepStaging.predict().proba`\n", + " warnings.warn(\n" + ] + }, { "data": { "text/html": [ @@ -271,11 +459,11 @@ " N1\n", " N2\n", " N3\n", - " R\n", - " W\n", + " REM\n", + " WAKE\n", " \n", " \n", - " epoch\n", + " Epoch\n", " \n", " \n", " \n", @@ -378,8 +566,8 @@ "" ], "text/plain": [ - " N1 N2 N3 R W\n", - "epoch \n", + " N1 N2 N3 REM WAKE\n", + "Epoch \n", "0 0.002202 0.005040 0.000703 1.875966e-18 0.992055\n", "1 0.003362 0.003284 0.001926 8.279263e-05 0.991345\n", "2 0.004078 0.003225 0.000095 7.688612e-04 0.991833\n", @@ -395,31 +583,30 @@ "[98 rows x 5 columns]" ] }, - "execution_count": 7, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# What are the predicted probabilities of each sleep stage at each epoch?\n", - "sls.predict_proba()" + "proba = sls.predict_proba()\n", + "proba" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -430,13 +617,13 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "epoch\n", + "Epoch\n", "0 0.992055\n", "1 0.991345\n", "2 0.991833\n", @@ -451,14 +638,14 @@ "Length: 98, dtype: float64" ] }, - "execution_count": 9, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# From the probabilities, we can extract a confidence level (ranging from 0 to 1) for each epoch.\n", - "confidence = sls.predict_proba().max(1)\n", + "confidence = proba.max(1)\n", "confidence" ] }, @@ -471,7 +658,17 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# The predicted sleep stages can be exported to a CSV file with:\n", + "hyp.hypno.to_csv(\"my_hypno.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -499,7 +696,7 @@ " Confidence\n", " \n", " \n", - " epoch\n", + " Epoch\n", " \n", " \n", " \n", @@ -507,32 +704,32 @@ " \n", " \n", " 0\n", - " W\n", + " WAKE\n", " 0.992055\n", " \n", " \n", " 1\n", - " W\n", + " WAKE\n", " 0.991345\n", " \n", " \n", " 2\n", - " W\n", + " WAKE\n", " 0.991833\n", " \n", " \n", " 3\n", - " W\n", + " WAKE\n", " 0.995557\n", " \n", " \n", " 4\n", - " W\n", + " WAKE\n", " 0.988994\n", " \n", " \n", " 5\n", - " W\n", + " WAKE\n", " 0.986805\n", " \n", " \n", @@ -541,27 +738,35 @@ ], "text/plain": [ " Stage Confidence\n", - "epoch \n", - "0 W 0.992055\n", - "1 W 0.991345\n", - "2 W 0.991833\n", - "3 W 0.995557\n", - "4 W 0.988994\n", - "5 W 0.986805" + "Epoch \n", + "0 WAKE 0.992055\n", + "1 WAKE 0.991345\n", + "2 WAKE 0.991833\n", + "3 WAKE 0.995557\n", + "4 WAKE 0.988994\n", + "5 WAKE 0.986805" ] }, - "execution_count": 10, + "execution_count": 16, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ - "# Let's first create a dataframe with the predicted stages and confidence\n", - "df_pred = pd.DataFrame({'Stage': y_pred, 'Confidence': confidence})\n", - "df_pred.head(6)\n", - "\n", + "# We can also add the confidence level:\n", + "df_pred = hyp.hypno.to_frame()\n", + "df_pred[\"Confidence\"] = confidence\n", + "df_pred.head(6)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ "# Now export to a CSV file\n", - "# df_pred.to_csv(\"my_hypno.csv\")" + "df_pred.to_csv(\"my_hypno.csv\")" ] }, { @@ -573,33 +778,38 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Users/raphael/.pyenv/versions/3.8.3/lib/python3.8/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 0.24.2 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n", - "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n", + "/Users/raphael/.pyenv/versions/3.9.6/lib/python3.9/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 0.24.2 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. 
For more info please refer to:\n", + "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ - "array(['W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',\n", - " 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',\n", - " 'W', 'W', 'N1', 'N2', 'W', 'W', 'N2', 'N2', 'R', 'N2', 'R', 'R',\n", - " 'N2', 'R', 'R', 'N2', 'R', 'R', 'R', 'R', 'R', 'R', 'R', 'R', 'N2',\n", - " 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2',\n", - " 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N3', 'N2', 'N2',\n", - " 'N3', 'N2', 'N2', 'N3', 'N2', 'N3', 'N2', 'N2', 'N2', 'N3', 'N3',\n", - " 'N3', 'N2', 'N3', 'N2', 'N3', 'N3', 'W', 'N3', 'W', 'W', 'W', 'W',\n", - " 'W', 'W'], dtype=object)" + "Epoch\n", + "0 WAKE\n", + "1 WAKE\n", + "2 WAKE\n", + "3 WAKE\n", + "4 WAKE\n", + " ... \n", + "93 WAKE\n", + "94 WAKE\n", + "95 WAKE\n", + "96 WAKE\n", + "97 WAKE\n", + "Name: Stage, Length: 98, dtype: category\n", + "Categories (7, object): ['WAKE', 'N1', 'N2', 'N3', 'REM', 'ART', 'UNS']" ] }, - "execution_count": 11, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -607,13 +817,13 @@ "source": [ "# Using just an EEG channel (= no EOG or EMG)\n", "y_pred = yasa.SleepStaging(raw, eeg_name=\"C4\").predict()\n", - "y_pred" + "y_pred.hypno" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -627,7 +837,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.3" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/yasa/hypno.py b/yasa/hypno.py index d67bf82..c52a8b1 100644 --- a/yasa/hypno.py +++ b/yasa/hypno.py @@ -74,6 +74,10 @@ class Hypnogram: scorer : str An optional string indicating the scorer name. 
If specified, this will be set as the name of the :py:class:`pandas.Series`, otherwise the name will be set to "Stage". + proba : :py:class:`pandas.DataFrame` + An optional dataframe with the probability of each sleep stage for each epoch in hypnogram. + Each row must sum to 1. This is automatically included if the hypnogram is created with + :py:class:`yasa.SleepStaging`. Examples -------- @@ -215,7 +219,7 @@ class Hypnogram: '%REM': 8.9713} """ - def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None): + def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None, proba=None): assert isinstance( values, (list, np.ndarray, pd.Series) ), "`values` must be a list, numpy.array or pandas.Series" @@ -228,6 +232,9 @@ def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None): assert isinstance( scorer, (type(None), str, int) ), "`scorer` must be either None, or a string or an integer." + assert isinstance( + proba, (pd.DataFrame, type(None)) + ), "`proba` must be either None or a pandas.DataFrame" if n_stages == 2: accepted = ["W", "WAKE", "S", "SLEEP", "ART", "UNS"] mapping = {"WAKE": 0, "SLEEP": 1, "ART": -1, "UNS": -2} @@ -268,6 +275,19 @@ def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None): fake_dt = pd.date_range(start="2022-12-03 00:00:00", freq=freq, periods=hypno.shape[0]) hypno.index.name = "Epoch" timedelta = fake_dt - fake_dt[0] + # Validate proba + if proba is not None: + assert proba.shape[1] > 0, "`proba` must have at least one column." + assert proba.shape[0] == hypno.shape[0], "`proba` must have the same length as `values`" + assert np.allclose(proba.sum(1), 1), "Each row of `proba` must sum to 1." + in_proba_but_not_labels = np.setdiff1d(proba.columns, labels) + # in_labels_but_not_proba = np.setdiff1d(labels, proba.columns) + assert not len(in_proba_but_not_labels), ( + f"Invalid stages in `proba`: {in_proba_but_not_labels}. The accepted stages are: " + f"{labels}." 
+ ) + # Ensure same order as `labels` + proba = proba.reindex(columns=labels).dropna(how="all", axis=1) # Set attributes self._hypno = hypno self._n_epochs = hypno.shape[0] @@ -280,6 +300,7 @@ def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None): self._labels = labels self._mapping = mapping self._scorer = scorer + self._proba = proba def __repr__(self): # TODO v0.8: Keep only the text between < and > @@ -294,15 +315,7 @@ def __repr__(self): ) def __str__(self): - text_scorer = f", scored by {self.scorer}" if self.scorer is not None else "" - return ( - f"\n" - " - Use `.hypno` to get the string values as a pandas.Series\n" - " - Use `.as_int()` to get the integer values as a pandas.Series\n" - " - Use `.plot_hypnogram()` to plot the hypnogram\n" - "See the online documentation for more details." - ) + return self.__repr__() @property def hypno(self): @@ -391,6 +404,14 @@ def scorer(self): """The scorer name.""" return self._scorer + @property + def proba(self): + """ + If specified, a :py:class:`pandas.DataFrame` with the probability of each sleep stage + for each epoch in hypnogram. + """ + return self._proba + # CLASS METHODS BELOW def as_annotations(self): @@ -558,6 +579,7 @@ def consolidate_stages(self, new_n_stages): freq=self.freq, start=self.start, scorer=self.scorer, + proba=None, # TODO: Combine stages probability? 
) def copy(self): @@ -568,6 +590,7 @@ def copy(self): freq=self.freq, start=self.start, scorer=self.scorer, + proba=self.proba, ) def find_periods(self, threshold="5min", equal_length=False): @@ -1055,6 +1078,7 @@ def upsample(self, new_freq, **kwargs): freq=new_freq, start=self.start, scorer=self.scorer, + proba=None, # NOTE: Do not upsample probability ) def upsample_to_data(self, data, sf=None, verbose=True): diff --git a/yasa/staging.py b/yasa/staging.py index 9a59ffa..8b42705 100644 --- a/yasa/staging.py +++ b/yasa/staging.py @@ -4,6 +4,7 @@ import glob import joblib import logging +import warnings import numpy as np import pandas as pd import antropy as ant @@ -104,9 +105,9 @@ class SleepStaging: In addition with the predicted sleep stages, YASA can also return the predicted probabilities of each sleep stage at each epoch. This can be used to derive a confidence score at each epoch. - .. important:: The predictions should ALWAYS be double-check by a trained - visual scorer, especially for epochs with low confidence. A full - inspection should be performed in the following cases: + .. important:: The predictions should ALWAYS be double-checked by a trained visual scorer, + especially for epochs with low confidence. A full inspection should be performed in the + following cases: * Nap data, because the classifiers were exclusively trained on full-night recordings. * Participants with sleep disorders. @@ -122,13 +123,11 @@ class SleepStaging: If you use YASA's default classifiers, these are the main references for the `National Sleep Research Resource `_: - * Dean, Dennis A., et al. "Scaling up scientific discovery in sleep - medicine: the National Sleep Research Resource." Sleep 39.5 (2016): - 1151-1164. + * Dean, Dennis A., et al. "Scaling up scientific discovery in sleep medicine: the National + Sleep Research Resource." Sleep 39.5 (2016): 1151-1164. - * Zhang, Guo-Qiang, et al. "The National Sleep Research Resource: towards - a sleep data commons." 
Journal of the American Medical Informatics - Association 25.10 (2018): 1351-1358. + * Zhang, Guo-Qiang, et al. "The National Sleep Research Resource: towards a sleep data + commons." Journal of the American Medical Informatics Association 25.10 (2018): 1351-1358. Examples -------- @@ -143,12 +142,15 @@ class SleepStaging: >>> sls = yasa.SleepStaging(raw, eeg_name="C4-M1", eog_name="LOC-M2", ... emg_name="EMG1-EMG2", ... metadata=dict(age=29, male=True)) + >>> # Print some basic info + >>> sls >>> # Get the predicted sleep stages - >>> hypno = sls.predict() + >>> hyp = sls.predict() + >>> hyp.hypno >>> # Get the predicted probabilities - >>> proba = sls.predict_proba() + >>> hyp.proba >>> # Get the confidence - >>> confidence = proba.max(axis=1) + >>> confidence = hyp.proba.max(axis=1) >>> # Plot the predicted probabilities >>> sls.plot_predict_proba() @@ -159,10 +161,10 @@ class SleepStaging: def __init__(self, raw, eeg_name, *, eog_name=None, emg_name=None, metadata=None): # Type check - assert isinstance(eeg_name, str) - assert isinstance(eog_name, (str, type(None))) - assert isinstance(emg_name, (str, type(None))) - assert isinstance(metadata, (dict, type(None))) + assert isinstance(eeg_name, str), "`eeg_name` must be a string." + assert isinstance(eog_name, (str, type(None))), "`eog_name` must be a string or None." + assert isinstance(emg_name, (str, type(None))), "`emg_name` must be a string or None." + assert isinstance(metadata, (dict, type(None))), "`metadata` must be a dictionary or None." # Validate metadata if isinstance(metadata, dict): @@ -173,7 +175,7 @@ def __init__(self, raw, eeg_name, *, eog_name=None, emg_name=None, metadata=None assert metadata["male"] in [0, 1], "male must be 0 or 1." # Validate Raw instance and load data - assert isinstance(raw, mne.io.BaseRaw), "raw must be a MNE Raw object." + assert isinstance(raw, mne.io.BaseRaw), "`raw` must be a MNE Raw object." 
sf = raw.info["sfreq"] ch_names = np.array([eeg_name, eog_name, emg_name]) ch_types = np.array(["eeg", "eog", "emg"]) @@ -210,6 +212,22 @@ def __init__(self, raw, eeg_name, *, eog_name=None, emg_name=None, metadata=None self.data = data self.metadata = metadata + def __repr__(self): + n_samples = self.data.shape[-1] + duration = (n_samples / self.sf) / 60 + return ( + f"" + ) + + def __str__(self): + n_samples = self.data.shape[-1] + duration = n_samples / self.sf + return ( + f"" + ) + def fit(self): """Extract features from data. @@ -419,9 +437,13 @@ def predict(self, path_to_model="auto"): Returns ------- - pred : :py:class:`numpy.ndarray` - The predicted sleep stages. + pred : :py:class:`yasa.Hypnogram` + The predicted sleep stages. Since YASA v0.7, the predicted sleep stages are now + returned as a :py:class:`yasa.Hypnogram` instance, which also includes the + probability of each sleep stage for each epoch. """ + from yasa.hypno import Hypnogram + if not hasattr(self, "_features"): self.fit() # Load and validate pre-trained classifier @@ -430,10 +452,15 @@ def predict(self, path_to_model="auto"): X = self._features.copy()[clf.feature_name_] # Predict the sleep stages and probabilities self._predicted = clf.predict(X) - proba = pd.DataFrame(clf.predict_proba(X), columns=clf.classes_) - proba.index.name = "epoch" + # Predict the probabilities + classes = clf.classes_.copy() + classes[classes == "W"] = "WAKE" # Compat for yasa.Hypnogram + classes[classes == "R"] = "REM" + proba = pd.DataFrame(clf.predict_proba(X), columns=classes) + proba.index.name = "Epoch" self._proba = proba - return self._predicted.copy() + # Convert to a `yasa.Hypnogram` instance (including `proba`) + return Hypnogram(values=self._predicted.copy(), freq="30s", n_stages=5, proba=proba.copy()) def predict_proba(self, path_to_model="auto"): """ @@ -454,6 +481,12 @@ def predict_proba(self, path_to_model="auto"): proba : :py:class:`pandas.DataFrame` The predicted probability for each sleep stage 
for each 30-sec epoch of data. """ + warnings.warn( + "The `predict_proba` function is deprecated and will be removed in v0.8. " + "The predicted probabilities can now be accessed with `yasa.Hypnogram.proba` instead, " + "e.g `SleepStaging.predict().proba`", + FutureWarning, + ) if not hasattr(self, "_proba"): self.predict(path_to_model) return self._proba.copy() @@ -475,19 +508,18 @@ def plot_predict_proba( If True, probabilities of the non-majority classes will be set to 0. """ if proba is None and not hasattr(self, "_features"): - raise ValueError("Must call .predict_proba before this function") + raise ValueError("Must call `.predict` before this function") if proba is None: proba = self._proba.copy() else: - assert isinstance(proba, pd.DataFrame), "proba must be a dataframe" + assert isinstance(proba, pd.DataFrame), "`proba` must be a pandas.DataFrame" if majority_only: cond = proba.apply(lambda x: x == x.max(), axis=1) proba = proba.where(cond, other=0) ax = proba.plot(kind="area", color=palette, figsize=(10, 5), alpha=0.8, stacked=True, lw=0) # Add confidence # confidence = proba.max(1) - # ax.plot(confidence, lw=1, color='k', ls='-', alpha=0.5, - # label='Confidence') + # ax.plot(confidence, lw=1, color='k', ls='-', alpha=0.5, label='Confidence') ax.set_xlim(0, proba.shape[0]) ax.set_ylim(0, 1) ax.set_ylabel("Probability") diff --git a/yasa/tests/test_staging.py b/yasa/tests/test_staging.py index a87edd8..e396431 100644 --- a/yasa/tests/test_staging.py +++ b/yasa/tests/test_staging.py @@ -3,6 +3,7 @@ import unittest import numpy as np import matplotlib.pyplot as plt +from yasa.hypno import Hypnogram from yasa.staging import SleepStaging ############################################################################## @@ -11,7 +12,7 @@ # MNE Raw raw = mne.io.read_raw_fif("notebooks/sub-02_mne_raw.fif", preload=True, verbose=0) -hypno = np.loadtxt("notebooks/sub-02_hypno_30s.txt", dtype=str) +y_true = Hypnogram(np.loadtxt("notebooks/sub-02_hypno_30s.txt", 
dtype=str)) class TestStaging(unittest.TestCase): @@ -22,12 +23,18 @@ def test_sleep_staging(self): sls = SleepStaging( raw, eeg_name="C4", eog_name="EOG1", emg_name="EMG1", metadata=dict(age=21, male=False) ) + print(sls) + print(str(sls)) sls.get_features() y_pred = sls.predict() + assert isinstance(y_pred, Hypnogram) + assert y_pred.proba is not None proba = sls.predict_proba() - assert y_pred.size == hypno.size + assert y_pred.hypno.size == y_true.hypno.size + assert y_true.duration == y_pred.duration + assert y_true.n_stages == y_pred.n_stages # Check that the accuracy is at least 80% - accuracy = (hypno == y_pred).sum() / y_pred.size + accuracy = (y_true.hypno == y_pred.hypno).mean() assert accuracy > 0.80 # Plot