diff --git a/.gitignore b/.gitignore
index 346afd4..4e5f94b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -143,3 +143,4 @@ notebooks/20_catch_errors.ipynb
# Custom
notebooks/debug*
+notebooks/my_hypno.csv
\ No newline at end of file
diff --git a/docs/changelog.rst b/docs/changelog.rst
index c89a363..d3b230b 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -29,18 +29,23 @@ which comes with several pre-built functions (aka methods) and attributes. See f
hyp.duration # Total duration of the hypnogram, in minutes
hyp.sampling_frequency # Sampling frequency of the hypnogram
hyp.mapping # Mapping from strings to integers
+ hyp.proba # Probability of each sleep stage, if specified
# Below are some class methods
hyp.sleep_statistics() # Calculate the sleep statistics
hyp.plot_hypnogram() # Plot the hypnogram
hyp.upsample_to_data() # Upsample to data
-Please see the documentation of :py:class:`yasa.Hypnogram` for more details.
+This brings along critical changes to several YASA function, for example:
-.. important::
- The adoption of object-oriented :py:class:`yasa.Hypnogram` usage brings along critical changes to several YASA function, for example:
+* :py:class:`yasa.SleepStaging` now returns a :py:class:`yasa.Hypnogram` instead of a :py:class:`numpy.ndarray`. The probability of each sleep stage for each epoch can now be accessed with :py:attr:`yasa.Hypnogram.proba`.
+* :py:func:`yasa.simulate_hypnogram` now returns a :py:class:`yasa.Hypnogram` instead of a :py:class:`numpy.ndarray`.
+* The suggested approach to plotting hypnograms is through the :py:meth:`yasa.Hypnogram.plot_hypnogram` method. The old function :py:func:`yasa.plot_hypnogram` still exists, but now *requires* a :py:class:`yasa.Hypnogram` instance as input.
+
+**Other improvements**
- * :py:func:`yasa.simulate_hypnogram` now returns a :py:class:`yasa.Hypnogram` instead of a :py:class:`numpy.ndarray`.
- * The suggested approach to plotting hypnograms is through the :py:meth:`yasa.Hypnogram.plot_hypnogram` method. The old function :py:func:`yasa.plot_hypnogram` still exists, but now *requires* a :py:class:`yasa.Hypnogram` instance as input.
+* Added helpful string representation (__repr__) to :py:class:`yasa.SleepStaging`.
+* :py:func:`yasa.simulate_hypnogram` now returns a :py:class:`yasa.Hypnogram` instead of a :py:class:`numpy.ndarray`.
+* The suggested approach to plotting hypnograms is through the :py:meth:`yasa.Hypnogram.plot_hypnogram` method. The old function :py:func:`yasa.plot_hypnogram` still exists, but now *requires* a :py:class:`yasa.Hypnogram` instance as input.
----------------------------------------------------------------------------------------
diff --git a/notebooks/14_automatic_sleep_staging.ipynb b/notebooks/14_automatic_sleep_staging.ipynb
index 3aad21e..706c138 100644
--- a/notebooks/14_automatic_sleep_staging.ipynb
+++ b/notebooks/14_automatic_sleep_staging.ipynb
@@ -49,24 +49,29 @@
{
"data": {
"text/html": [
- "\n",
- "\n",
"
\n",
"
\n",
"
Measurement date
\n",
+ " \n",
"
January 15, 2016 14:01:00 GMT
\n",
" \n",
"
\n",
"
\n",
"
Experimenter
\n",
- "
Unknown
\n",
+ " \n",
+ "
Unknown
\n",
+ " \n",
"
\n",
"
Participant
\n",
- "
Unknown
\n",
+ " \n",
+ "
Unknown
\n",
+ " \n",
" \n",
"
\n",
"
Digitized points
\n",
+ " \n",
"
15 points
\n",
+ " \n",
"
\n",
"
\n",
"
Good channels
\n",
@@ -83,28 +88,36 @@
"
\n",
"
ECG channels
\n",
"
Not available
\n",
+ " \n",
"
\n",
"
Sampling frequency
\n",
"
100.00 Hz
\n",
"
\n",
+ " \n",
+ " \n",
"
\n",
"
Highpass
\n",
"
0.00 Hz
\n",
"
\n",
+ " \n",
+ " \n",
"
\n",
"
Lowpass
\n",
"
50.00 Hz
\n",
"
\n",
- "\n",
+ " \n",
+ " \n",
+ " \n",
"
\n",
"
Filenames
\n",
"
sub-02_mne_raw.fif
\n",
"
\n",
+ " \n",
"
\n",
"
Duration
\n",
- "
00:48:59 (HH:MM:SS)
\n",
+ "
00:48:60 (HH:MM:SS)
\n",
"
\n",
- "
\n"
+ ""
],
"text/plain": [
""
@@ -149,8 +162,66 @@
],
"source": [
"# Let's now load the human-scored hypnogram, where each value represents a 30-sec epoch.\n",
- "hypno = np.loadtxt('sub-02_hypno_30s.txt', dtype=str)\n",
- "hypno"
+ "hyp = np.loadtxt('sub-02_hypno_30s.txt', dtype=str)\n",
+ "hyp"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Epoch\n",
+ "0 WAKE\n",
+ "1 WAKE\n",
+ "2 WAKE\n",
+ "3 WAKE\n",
+ "4 WAKE\n",
+ " ... \n",
+ "93 WAKE\n",
+ "94 WAKE\n",
+ "95 WAKE\n",
+ "96 WAKE\n",
+ "97 WAKE\n",
+ "Name: Stage, Length: 98, dtype: category\n",
+ "Categories (7, object): ['WAKE', 'N1', 'N2', 'N3', 'REM', 'ART', 'UNS']"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Convert it to a Hypnogram instance, which is the preferred way to manipulate hypnograms since v0.7\n",
+ "hyp = yasa.Hypnogram(hyp, freq=\"30s\")\n",
+ "# The hypnogram values can be obtained with\n",
+ "hyp.hypno"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjkAAAD5CAYAAADFnCTwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAxOAAAMTgF/d4wjAAAZK0lEQVR4nO3dfXAU9eHH8U9CCKLyJCBRQrgCCRRIcmh4FExigCKCUIcH5aE8jGjotDE9QKoQRYIPKHM82AeHNjwVpKgopSJW0MTSAj4wpOVZwnAhISF0UhDCUxLY3x8O9zMlypnkspdv3q+ZzGR3b/f7uexs5jO7e7dBlmVZAgAAMEyw3QEAAAD8gZIDAACMRMkBAABGouQAAAAjUXIAAICRKDkAAMBIlBwAAGCkeltyunTpYncEAADgR/W25JSUlNgdAQAA+FG9LTkAAMBslBwAAGAkSg4AADASJQcAABiJkgMAAIxEyQEAAEai5AAAACNRcgAAgJEoOQAAwEiUHAAAYCRKDgAAMJJfS84LL7ygxx9/3Dv9j3/8Q0FBQcrKyvLOS05OVlpamiRp0qRJatq0qS5cuFBhOw6HQ9nZ2ZKky5cva8SIERo9erRKS0s1efJktW3bVk6n0/uzcuVKf74tAABQB/i15CQmJlYoNJmZmerdu/cN8x544AGdO3dOf/3rXxUbG6u333670u2dP39eDz74oNq0aaMNGzYoNDRUkjRr1ixlZ2d7f6ZMmeLPtwUAAOqAEH9uvE+fPiooKFB+fr7Cw8OVlZWl5557Tq+++qokqbCwUCdOnFDfvn21evVqDRw4UI899pjcbrcmT55cYVvFxcWaNm2akpKS9Morr1Q7W1FRkbp161bt7cAeHTt21ObNmyvMe/jhh3Xs2DG/j+OLH5KlqmP4w3flDqSMAOzjj/+z1XHgwIHvXe7XkhMaGqp+/fopMzNTY8aM0fHjxzV06FClpKTo8uXLyszMVN++fXXLLbcoIyND8+fPV1JSkqZPn64jR46oc+fO3m2NHTtW06ZN08svv3zDOK+99ppWrVrlnX799dc1YMCACq9xu91yu93e6WvXrtX8G0atyMnJUWlp6Q0H2qFDh3TixAl16tTJr+P4wtcs1RnDHyrLHWgZAdinpv/P+ptfS470/5es2rdvr169ekn65gzPrl27lJWVpcTERO3bt0+FhYUaPHiwgoODNWHCBK1YsUILFy70buehhx7SO++8o5///Odq165dhTFmzZql1NTU783hcrnkcrm80+Hh4TdtgAhM3bp1U2lpaaXLOnXqVGP79fvG8YUvWao7hj/8b+5AzAjAPjX5f9bfaqXkZGRkKCIiQgkJCZKk+Ph4ZWZmKjMzU6tWrVJGRobOnz+vDh06SJLKysp07do1vfjiiwoJ+Sbir371K8XGxiohIUGZmZmKiIjwd3QAAFCH+f0j5D179tTp06e1bt26CiXnz3/+swoLCxUXF6e1a9dq9+7d8ng88ng8OnnypCIiIrRly5YK23K5XPrlL3+phIQE5ebm+js6AACow/xecho2bKj+/fvr/Pnz6tKliyQpKipK58+fV//+/fWXv/xF7du39y67bvz48crIyLhhe6mpqUpNTVV8fLyOHz8u6Zt7cr79EfLXXnvN328LAAAEOL9frpKkDz/88IZ5hYWF3t/HjBlzw/KUlBSlpKRIkjwez3cu+/YNxwAAANfxjccAAMBIlBwAAGAkSg4AADASJQcAABiJkgMAAIxEyQEAAEai5AAAACNRcgAAgJEoOQAAwEiUHAAAYCRKDgAAMBIlBwAAGImSAwAAjETJAQAARqLkAAAAI1FyAACAkSg5AADASJQcAABgJEoOAAAwEiUHAAAYiZIDAACMRMkBAABGouQAAAAjUXIAAICRKDkAAMBIlBwAAGAkSg4AADASJQcAABiJkgMAAIxEyQEAAEai5AAAACNRcgAAgJEoOQAAwEi2lhyHw6HOnTvL6XSqc+fOeuWVVyRJHo9HDRo0kNPp9P707t3buywoKEgjRoyosK3nn39eQUFB2rRpU22/DQAAEIBC7A6wYcMGOZ1OnTx5Ul27dtUDDzygO++8U02aNFF2dnal6zRr1kxfffWVioqK1KZNG127dk3r169XdHR07YYHAAABK2AuV7Vt21ZdunRRbm6uT6+fMGGC1qxZI0navn27evTooTvuuMOfEQEAQB0SMCXn8OHDKi4uVkJCgiTp/PnzFS5XjR8/vsLrJ02apNWrV0uSVqxYoalTp37v9t1ut8LDw70/JSUlfnkfAAAgMNh+uWrs2LEKDg7WkSNHtHjxYrVu3VoXLlz43stVkrxl5f3339eePXv05ptv6uWXX/7O17tcLrlcrgrrAwAAc9l+JmfDhg06dOiQPvroI/3617/Wvn37fF53ypQpmjJlih599FEFB9v+VgAAQAAJmGYwcOBATZ8+XXPnzvV5nZEjR2rmzJlKTk72YzIAAFAX2X656tvS0tLUqVMnFRcXe+/J+bYdO3ZUmG7UqJFmz55diwkBAEBdYWvJ8Xg8FaZbtGih4uJiSdLVq1crXadJkyY6e/ZspcuysrJqMB0AAKjLAuZyFQAAQE2i5AAAACNRcgAAgJEoOQAAwEiUHAAAYCRKDgAAMBIlBwAAGImSAwAAjETJAQAARqLkAAAAI1FyAACAkSg5AADASJQcAABgJEoOAAAwEiUHAAAYiZIDAACMRMkBAABGouQAAAAjUXIAAICRKDkAAMBIlBwAAGAkSg4AADASJQcAABjJ55Lz9ddf6xe/+IWGDRsmSTp48KDWr1/vt2AAAADV4XPJefLJJxUWFiaPxyNJ+tGPfqSFCxf6KxcAAEC1+FxyvvrqK82dO1cNGzaUJDVu3FiWZfktGAAAQHX4XHJCQ0MrTF+6dImSAwAAApbPJScxMVEvvviiLl++rO3bt2vUqFF65JFH/JkNAACgynwuOenp6QoODlbTpk317LPP6r777lNaWpo/swEAAFRZiM8vDAnRM888o2eeecafeQAAAGqEzyVn/vz5N8xr3ry5+vbtq549e9ZoKAAAgOry+XLVoUOH9Pvf/155eXnKz8/XG2+8oaysLI0fP17Lli3zZ0YAAIAfzOczOWfOnFF2drbatGkjSSoqKtLEiRO1e/duDRgwQCkpKX4LCQAA8EP5fCYnPz/fW3AkqU2bNiooKNAdd9zh/e4cf3A4HOrSpYvKy8u98+Li4pSVlaUtW7bo3nvvVaNGjZSamuq3DAAAoO7xueS0bdtWL7zwgvLy8pSXl6f58+fr7rvv1tWrVxUUFOTPjLpy5YoyMjJumB8ZGakVK1Zo1qxZfh0fAADUPT5frlq9erVSUlLkdDolSUlJSVq1apXKysq0Zs0af+WTJM2bN09z5szRxIkTdeutt3rnR0VFSZLee+89v46PwJObm6shQ4bcMC8yMtLv4/i6nq9ZqjrG94mIiNDy5ctrbHuVZazpMQCgpvlccsLCwvTWW29Vuiw6OrrGAlUmNjZWiYmJWrx4sebMmVOlbbjdbrndbu90SUlJTcVDLevYsWOl8yMjI79zWU2O4wtfs9Rk3utycnJqdHuVZazpMQDAH3wuOZJUUFCg/fv36/Lly955Dz/8cI2Hqkx6erp69eql5OTkKq3vcrnkcrm80+Hh4TUVDbVs8+bNxozjjzG6deum0tLSGtteZRlregwA8AefS86KFSs0f/58/fe//1VkZKT+9a9/qU+fPrVWchwOh8aNG6cFCxbUyngAAKBu8/nG48WLF2vv3r3q2LGj9uzZo08++cR7T0xtmTt3rtauXauCgoJaHRcAANQ9P+gp5C1atPB+lPv+++9Xdna2v3JVqlWrVkpJSVFhYaEk6eOPP1Z4eLjcbrcyMjIUHh5ea5cyAABAYPP5clWjRo1kWZaioqK0ZMkStW/fvlZu3vV4PBWm09LSKjwYND8/3+8ZAABA3eNzyVmwYIHOnTunV199VcnJyTp79qx+97vf+TMbAABAlflcclq1aqVmzZqpWbNm2rZtmyTp3//+t9+CAQAAVIfP9+RMnjzZp3kAAACB4KZnck6fPq1Tp07p0qVL2rdvnyzLkiSdPXtWFy5c8HtAAACAqrhpyVm/fr2WLFmigoIC73fiBAUFqWnTpnr66af9HhAAAKAqblpynnrqKT311FNKT09XWlqajh07ps2bN6tTp04aPnx4bWQEAAD4wW56T87AgQOVnZ2ttLQ0FRQUqGfPnvroo4/09NNPa+HChbWREQAA4Ae7ack5efKk98njb775puLj47V161bt3LlT69at83c+AACAKrlpyWncuLH39507d2ro0KGSpBYtWigk5Ac93xMAAKDW3LTkBAcHKz8/XyUlJfr0008VHx/vXXbx4kW/hgMAAKiqm56KefbZZ9WjRw+FhIQoMTHR+1DOnTt3yuFw+DsfAABAldy05DzyyCPq16+fioqKFBMT453vcDi0fPlyv4YDAACoKp9uqgkLC1NYWFiFeXfffbdfAgEAANQEnx/rAAAAUJdQcgAAgJEoOQAAwEiUHAAAYCRKDgAAMBIlBwAAGImSAwAAjETJAQAARqLkAAAAI1FyAACAkSg5AADASJQcAABgJEoOAAAwEiUHAAAYiZIDAACMRMkBAABGouQAAAAjUXIAAICRKDkAAMBIAV9yHA6HunTpovLycu+8uLg4ZWVladmyZerevbuio6MVExOjtWvX2pgUAAAEkoAvOZJ05coVZWRk3DC/W7du+uc//6l9+/Zpy5YtSk1N1bFjx2xICAAAAk2dKDnz5s1Tenq6Ll68WGF+UlKSmjVrJklq166dwsLClJeXZ0dEAAAQYELsDuCL2NhYJSYmavHixZozZ06lr9m+fbvOnDmjnj171nI6IPDk5uZqyJAhVVovMjLSr2MAERERWr58ud0xUA/UiZIjSenp6erVq5eSk5NvWLZv3z5NmTJFGzZs0G233Vbp+m63W2632ztdUlLit6yAnTp27FjldSMjI31avzpjoH7LycmxOwLqkTpTchwOh8aNG6cFCxZUmH/w4EENGzZMK1asUP/+/b9zfZfLJZfL5Z0ODw/3W1bATps3bzZiDJipW7duKi0ttTsG6ok6U3Ikae7cufrxj3+shg0bSpIOHTqkoUOHavny5Ro0aJDN6QAAQCCpEzceX9eqVSulpKSosLBQkpSSkqKvv/5as2fPltPplNPp1N/+9jebUwIAgEAQZFmWZXcIO4SHhys/P9/uGABQr1y/XPXhhx/aHQVVMGTIEIWGhurAgQN2R/FJnTqTAwAA4CtKDgAAMBIlBwAAGImSAwAAjETJAQAARqLkAAAAI1FyAACAkSg5AADASJQcAABgJEoOAAAwEiUHAAAYiZIDAACMRMkBAABGouQAAAAjUXIAAICRKDkAAMBIlBwAAGAkSg4AADASJQcAABiJkgMAAIxEyQEAAEai5AAAACNRcgAAgJEoOQAAwEiUHAAAYCRKDgAAMBIlBwAAGImSAwAAjETJAQAARqLkAAAAI1FyAACAkSg5AADASJQcAABgpIAvOQ6HQ126dFF5ebl3XlxcnLKysvTb3/5W0dHRcjqd6t69u5YtW2ZjUgAAEEgCvuRI0pUrV5SRkXHD/AkTJmjfvn3Kzs7Wzp07tWjRIu3du9eGhAAAINDUiZIzb948paen6+LFixXmN2vWzPv7hQsXVFZWVtvRAABAgAqxO4AvYmNjlZiYqMWLF2vOnDkVlr3zzjt6/vnnlZOTo5deekk9evSwKSUAwBe5ubkaMmSI3TFQBbm5uYqMjLQ7hs/qRMmRpPT0dPXq1UvJyckV5o8aNUqjRo2Sx+PRT3/6Uw0bNkydO3e+YX232y232+2dLikp8XtmAEBFHTt2tDsCqiEyMrJO7cM6U3IcDofGjRunBQsWfOfy3r176/3336+05LhcLrlcLu90eHi437ICACq3efNmuyOgHqkT9+RcN3fuXK1du1YFBQWSpIMHD3qX/ec//9Enn3yimJgYu+IBAIAAUqdKTqtWrZSSkqLCwkJJ0tKlS9W1a1c5nU4NHDhQqampGjRokM0pAQBAIAiyLMuyO4QdwsPDlZ+fb3cMAADgJ3XqTA4AAICvKDkAAMBIlBwAAGAkSg4AADASJQcAABiJkgMAAIxEyQEAAEai5AAAACNRcgAAgJEoOQAAwEiUHAAAYCRKDgAAMFK9fUBnSEiIwsLC7I6B/1FSUqLbb7/d7hj4H+yXwMR+CUzsl9pz++236/Dhw9+5PKQWswSUsLAwnkIegHg6fGBivwQm9ktgYr8EDi5XAQAAI1FyAACAkeptyXG5XHZHQCXYL4GJ/RKY2C+Bif0SOOrtjccAAMBs9fZMDgAAMBslBwAAGKlelpyjR4+qX79+ioqKUs+ePXXgwAG7I9U7KSkpcjgcCgoKUnZ2tnc++8Zely9f1siRIxUVFaXY2FgNGjRIOTk5kqTTp09ryJAhioyMVPfu3fX3v//d5rT1y+DBgxUTEyOn06kBAwZo7969kjhmAsHKlSsVFBSkTZs2SeJYCShWPZSYmGitXLnSsizLevvtt624uDh7A9VDn376qZWXl2e1b9/e2rt3r3c++8Zely5dsrZs2WJdu3bNsizLev311634+HjLsixrypQp1vPPP29ZlmV9/vnnVtu2ba3S0lKbktY/Z86c8f7+7rvvWjExMZZlcczY7fjx41bfvn2tPn36WO+9955lWRwrgaTelZyioiKrSZMmVllZmWVZlnXt2jWrTZs21tGjR21OVj99u+SwbwLPF198YbVv396yLMu67bbbrMLCQu+ynj17Wtu2bbMpWf22cuVKKzY2lmPGZlevXrWSkpKsL7/80oqPj/eWHI6VwFHvLlfl5eXprrvuUkjIN1/2HBQUpIiICJ04ccLmZGDfBJ6lS5dqxIgRKi4uVllZWYVHoTgcDvZNLfvZz36mdu3aKS0tTX/60584Zmzmdrt133336d577/XO41gJLPX2sQ4Avt9LL72knJwcffzxx7p06ZLdcSBpzZo1kqTVq1dr9uzZSk9PtzlR/bV//35t3LiR+20CXL07k9OuXTsVFhaqvLxckmRZlk6cOKGIiAibk4F9EzgWLVqkd999V1u3btWtt96qli1bKiQkRKdOnfK+xuPxsG9sMmnSJGVmZio8PJxjxiY7duyQx+NRZGSkHA6Hdu/erSeeeEJvvfUWx0oAqXcl584779Q999yjtWvXSpI2btyo8PBwderUyeZkYN8EBrfbrfXr12vbtm1q3ry5d/7o0aP1xhtvSJK++OILnTx5UvHx8TalrF/Onj2rgoIC7/SmTZvUsmVLjhkbTZ8+XYWFhfJ4PPJ4POrTp4+WL1+u6dOnc6wEkHr5jcdHjhzR5MmTVVxcrKZNm2rlypWKjo62O1a98uSTT2rLli06deqUWrZsqSZNmignJ4d9Y7P8/Hy1a9dOHTp0UJMmTSRJjRo10meffaaioiJNnDhRx48fV2hoqH7zm98oMTHR5sT1Q25urkaPHq1Lly4pODhYrVu31qJFi+R0OjlmAkRCQoJSU1M1cuRIjpUAUi9LDgAAMF+9u1wFAADqB0oOAAAwEiUHAAAYiZIDAACMRMkBAABGouQAAAAjUXIA+IXT6ZTT6VTXrl3VoEED7/TYsWP13HPPad26dX4bOysrS40bN5bT6dTp06d/8PpDhw7VkSNHqjz++PHjFRYWptTU1CpvA0D18ewqAH6RnZ0t6ZuvtHc6nd7p2tK5c+cqj/nBBx9Ua+x169Zp3rx5Onv2bLW2A6B6OJMDoNZNnjxZS5YskSTNmzdPY8aM0fDhwxUVFaVhw4Zp//79+slPfqKoqCg99thjunbtmiTp/PnzmjZtmnr16qWYmBg98cQTKi0t9WnMhIQEzZgxQ/fff78iIiKUlpamDz74QP3795fD4ZDb7fa+1uFweAtSQkKCZs6cqQEDBqhjx45KTk72vu6Pf/yjunbtKqfTqejoaH322Wc18wcCUCM4kwPAdl9++aX27Nmj5s2bKyEhQY8//ri2bdumxo0bKy4uTlu3btVDDz2kGTNmaMCAAfrDH/4gy7I0bdo0LV26VLNmzfJpnNzcXGVmZurcuXNyOBw6c+aMduzYoYKCAnXu3FlTp06t8Lyu644dO6bMzEyVlZWpa9eu2rVrl/r27asZM2bo8OHDuuuuu1RWVqYrV67U8F8GQHVQcgDYbvDgwWrRooUk6Z577lGjRo28z87q0aOHjh49KumbB1Pu2rXLe9bl0qVLatCggc/jjBo1Sg0aNFCLFi3UoUMHDRs2TEFBQWrbtq1at27tvbT2v8aOHauQkBCFhITI6XTq2LFj6tu3r5KSkjRx4kQNHz5cDz74oKKioqr5lwBQkyg5AGx3yy23eH9v0KDBDdPl5eWSJMuytHHjxiqXCV/Hudl611+3ceNG7dmzR1lZWRo6dKgWLFigRx99tErZANQ87skBUGeMHDlSCxcu9JaMM2fOKCcnx5Ys5eXlOnbsmOLi4jRz5kyNGjVKn3/+uS1ZAFSOkgOgzli8eLH3o+ExMTFKSkqSx+OxJcvVq1c1depUde/eXU6nU3v27JHL5bIlC4DKBVmWZdkdAgBqUlZWllJTU2v9Y+vfdv0j5Nc/RQag9nEmB4BxQkNDVVxcXOUvA6yu8ePHa+3atWratGmtjw3g/3EmBwAAGIkzOQAAwEiUHAAAYCRKDgAAMBIlBwAAGImSAwAAjETJAQAARvo/LcBvvqp03J8AAAAASUVORK5CYII=\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Let's plot it\n",
+ "fig, ax = plt.subplots(1, 1, figsize=(7, 3), constrained_layout=True, dpi=80)\n",
+ "ax = hyp.plot_hypnogram(fill_color=\"gainsboro\", ax=ax)"
]
},
{
@@ -164,9 +235,20 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 9,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
"# We first need to specify the channel names and, optionally, the age and sex of the participant\n",
"# - \"raw\" is the name of the variable containing the polysomnography data loaded with MNE.\n",
@@ -174,38 +256,35 @@
"# - \"eog_name\" is the name of the EOG channel (e.g. LOC-M1). This is optional.\n",
"# - \"eog_name\" is the name of the EOG channel (e.g. EMG1-EMG3). This is optional.\n",
"# - \"metadata\" is a dictionary containing the age and sex of the participant. This is optional.\n",
- "sls = yasa.SleepStaging(raw, eeg_name=\"C4\", eog_name=\"EOG1\", emg_name=\"EMG1\", metadata=dict(age=21, male=False))"
+ "sls = yasa.SleepStaging(raw, eeg_name=\"C4\", eog_name=\"EOG1\", emg_name=\"EMG1\", metadata=dict(age=21, male=False))\n",
+ "sls"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "/Users/raphael/.pyenv/versions/3.8.3/lib/python3.8/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 0.24.2 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
- "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n",
+ "/Users/raphael/.pyenv/versions/3.9.6/lib/python3.9/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 0.24.2 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
+ "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
- "array(['W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',\n",
- " 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',\n",
- " 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'N2', 'N2', 'N2', 'N2',\n",
- " 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2',\n",
- " 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2',\n",
- " 'N2', 'N2', 'N2', 'N3', 'N3', 'N3', 'N3', 'N2', 'N3', 'N3', 'N3',\n",
- " 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3',\n",
- " 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'N3', 'W', 'W', 'W', 'W',\n",
- " 'W', 'W', 'W', 'W'], dtype=object)"
+ "\n",
+ " - Use `.hypno` to get the string values as a pandas.Series\n",
+ " - Use `.as_int()` to get the integer values as a pandas.Series\n",
+ " - Use `.plot_hypnogram()` to plot the hypnogram\n",
+ "See the online documentation for more details."
]
},
- "execution_count": 5,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -218,21 +297,122 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Epoch\n",
+ "0 WAKE\n",
+ "1 WAKE\n",
+ "2 WAKE\n",
+ "3 WAKE\n",
+ "4 WAKE\n",
+ " ... \n",
+ "93 WAKE\n",
+ "94 WAKE\n",
+ "95 WAKE\n",
+ "96 WAKE\n",
+ "97 WAKE\n",
+ "Name: Stage, Length: 98, dtype: category\n",
+ "Categories (7, object): ['WAKE', 'N1', 'N2', 'N3', 'REM', 'ART', 'UNS']"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_pred.hypno"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "The overall agreement is 0.837\n"
+ "The overall agreement is 83.67%\n"
]
}
],
"source": [
"# What is the accuracy of the prediction, compared to the human scoring\n",
- "accuracy = (hypno == y_pred).sum() / y_pred.size\n",
- "print(\"The overall agreement is %.3f\" % accuracy)"
+ "accuracy = 100 * (hyp.hypno == y_pred.hypno).mean()\n",
+ "print(f\"The overall agreement is {accuracy:.2f}%\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Plot and sleep statistics**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjkAAAD5CAYAAADFnCTwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAxOAAAMTgF/d4wjAAAYoUlEQVR4nO3de1DVdf7H8ReImlveUpMS8axyWxU4FpiYBoS6Zppuo1mp62XScGeXWNTcUsrELpZzUNtL4y7eVnOtbF13zd20oHVX7eLILl4TxiMgiDusJigKyPf3R9P5xWp5RA7n8OH5mGGG7/ec8/2+8TunnvP9noufZVmWAAAADOPv7QEAAAA8gcgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJFabORERER4ewQAAOBBLTZyKisrvT0CAADwoBYbOQAAwGxEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIzk0ch58cUX9eSTT7qW//GPf8jPz085OTmudcnJyUpPT5ckTZ06VR06dNCFCxfqbcdmsyk3N1eSdOnSJY0dO1YTJkxQdXW1pk2bph49eshut7t+1qxZ48k/CwAANAMejZzExMR6QZOdna177733qnUPPPCAzp8/rz//+c+Kjo7WO++8c83tVVRU6MEHH1T37t21efNmtWnTRpI0b9485ebmun6mT5/uyT8LAAA0AwGe3PigQYNUUlKi4uJiBQUFKScnR88//7xee+01SVJpaakKCwsVFxendevWadiwYXr88cflcDg0bdq0etsqLy/XzJkzlZSUpFdfffWmZysrK1O/fv1uejtAU+rTp4+2bdvm7TEAtAAPP/ywCgoKvD3Gdzp06NB33u7RyGnTpo0GDx6s7OxsPfroozpx4oRGjRqllJQUXbp0SdnZ2YqLi9Mtt9yirKwsLV68WElJSZo9e7aOHTum8PBw17YmTpyomTNn6pVXXrlqP6+//rrWrl3rWn7jjTc0dOjQevdxOBxyOByu5bq6usb/gwEPys/PV3V1tc//RweAGY4cOaLCwkKFhIR4e5QG82jkSP9/yapXr14aOHCgpK/O8Ozdu1c5OTlKTExUXl6eSktLNWLECPn7+2vy5MlavXq1li5d6trOQw89pHfffVc/+clP1LNnz3r7mDdvnlJTU79zjrS0NKWlpbmWg4KCrluAgC/p16+fqqurvT0GgBYkJCSkWf+/skkiJysrS8HBwUpISJAkxcfHKzs7W9nZ2Vq7dq2ysrJUUVGh3r17S5JqampUV1enl156SQEBX43485//XNHR0UpISFB2draCg4M9PToAAGjGPP4W8tjYWJ05c0YbN26sFzl/+MMfVFpaqpiYGG3YsEH79u2T0+mU0+nUqVOnFBwcrO3bt9fbVlpamn72s58pISFBJ0+e9PToAACgGfN45LRu3VpDhgxRRUWFIiIiJElhYWGqqKjQkCFD9Kc//Um9evVy3fa1SZMmKSsr66rtpaamKjU1VfHx8Tpx4oSkr16T8823kL/++uue/rMAAICP87Msy/L2EN4QFBSk4uJib48BuO3r1+T89a9/9fYoAFqAkSNHqk2bNs36NTl84jEAADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwklcjx2azKTw8XHa7XeHh4Xr11VclSU6nU61atZLdbnf93Hvvva7b/Pz8NHbs2HrbeuGFF+Tn56etW7c29Z8BAAB8UIC3B9i8ebPsdrtOnTqlvn376oEHHtAdd9yh9u3bKzc395qP6dixo7744guVlZWpe/fuqqur06ZNmxQZGdm0wwMAAJ/lM5erevTooYiICJ08edKt+0+ePFnr16+XJO3atUsDBgzQ7bff7skRAQBAM+IzkXP06FGVl5crISFBklRRUVHvctWkSZPq3X/q1Klat26dJGn16tWaMWPGd27f4XAoKCjI9VNZWemRvwMAAPgGr1+umjhxovz9/XXs2DFlZmaqW7duunDhwnderpLkipW//OUv2r9/v9566y298sor33r/tLQ0paWl1Xs8AAAwl9fP5GzevFlHjhzRBx98oF/84hfKy8tz+7HTp0/X9OnT9dhjj8nf3+t/CgAA8CE+UwbDhg3T7NmztXDhQrcfM27cOM2dO1fJyckenAwAADRHXr9c9U3p6ekKCQlReXm56zU537R79+56y23bttX8+fObcEIAANBceDVynE5nveXOnTurvLxcknTlypVrPqZ9+/Y6d+7cNW/LyclpxOkAAEBz5jOXqwAAABoTkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjuR05X375pX76059q9OjRkqTDhw9r06ZNHhsMAADgZrgdOU899ZQCAwPldDolSd///ve1dOlST80FAABwU9yOnC+++EILFy5U69atJUnt2rWTZVkeGwwAAOBmuB05bdq0qbdcVVVF5AAAAJ/lduQkJibqpZde0qVLl7Rr1y6NHz9ejzzyiCdnAwAAaDC3IycjI0P+/v7q0KGDnnvuOd13331KT0/35GwAAAANFuD2HQMC9Oyzz+rZZ5/15DwAAACNwu3IWbx48VXrOnXqpLi4OMXGxjbqUAAAADfL7ctVR44c0W9+8xsVFRWpuLhYb775pnJycjRp0iStXLnSkzMCAADcMLfP5Jw9e1a5ubnq3r27JKmsrExTpkzRvn37NHToUKWkpHhsSAAAgBvl9pmc4uJiV+BIUvfu3VVSUqLbb7/d9dk5nmCz2RQREaHa2lrXupiYGOXk5Gj79u2655571LZtW6WmpnpsBgAA0Py4HTk9evTQiy++qKKiIhUVFWnx4sW66667dOXKFfn5+XlyRl2+fFlZWVlXrQ8NDdXq1as1b948j+4fAAA0P25Hzrp163To0CHZ7XbZ7XYdPHhQa9euVU1NjdavX+/JGbVo0SJlZGTo4sWL9daHhYUpOjpaAQFuX3UDAAAthNt1EBgYqLfffvuat0VGRjbaQNcSHR2txMREZWZmasGCBQ3ahsPhkMPhcC1XVlY21ngAAMAH3dApkJKSEh08eFCXLl1yrXv44YcbfahrycjI0MCBA5WcnNygx6elpSktLc21HBQU1FijAQAAH+R25KxevVqLFy/Wf//7X4WGhupf//qXBg0a1GSRY7PZ9MQTT2jJkiVNsj8AANC8uf2anMzMTB04cEB9+vTR/v379dFHHyksLMyTs11l4cKF2rBhg0pKSpp0vwAAoPm5oW8h79y5s+ut3Pfff79yc3M9Ndc1de3aVSkpKSotLZUkffjhhwoKCpLD4VBWVpaCgoK0bdu2Jp0JAAD4JrcvV7Vt21aWZSksLEzLly9Xr169muTFu06ns95yenp6vS8GLS4u9vgMAACg+XE7cpYsWaLz58/rtddeU3Jyss6dO6df//rXnpwNAACgwdyOnK5du6pjx47q2LGjdu7cKUn697//7bHBAAAAbobbr8mZNm2aW+sAAAB8wXXP5Jw5c0anT59WVVWV8vLyZFmWJOncuXO6cOGCxwcEAABoiOtGzqZNm7R8+XKVlJS4PhPHz89PHTp00DPPPOPxAQEAABriupHz9NNP6+mnn1ZGRobS09NVUFCgbdu2KSQkRGPGjGmKGQEAAG7YdV+TM2zYMOXm5io9PV0lJSWKjY3VBx98oGeeeUZLly5tihkBAABu2HUj59SpU7Lb7ZKkt956S/Hx8dqxY4f27NmjjRs3eno+AACABrlu5LRr1871+549ezRq1ChJUufOnRUQcEPf7wkAANBkrhs5/v7+Ki4uVmVlpT7++GPFx8e7brt48aJHhwMAAGio656Kee655zRgwAAFBAQoMTHR9aWce/bskc1m8/R8AAAADXLdyHnkkUc0ePBglZWVKSoqyrXeZrNp1apVHh0OAACgodx6UU1gYKACAwPrrbvrrrs8MhAAAEBjcPtrHQAAAJoTIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkYgcAABgJCIHAAAYicgBAABGInIAAICRiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARvL5yLHZbIqIiFBtba1rXUxMjHJycrRy5Ur1799fkZGRioqK0oYNG7w4KQAA8CU+HzmSdPnyZWVlZV21vl+/fvrnP/+pvLw8bd++XampqSooKPDChAAAwNc0i8hZtGiRMjIydPHixXrrk5KS1LFjR0lSz549FRgYqKKiIm+MCAAAfEyAtwdwR3R0tBITE5WZmakFCxZc8z67du3S2bNnFRsb28TTAU3n5MmTGjlypLfHgAGCg4O1atWqRtverFmzVFhY6PH9ADeiWUSOJGVkZGjgwIFKTk6+6ra8vDxNnz5dmzdv1q233nrNxzscDjkcDtdyZWWlx2YFPKFPnz7eHgGGyM/Pb/RtFhYWqrCwUCEhIR7dD3Ajmk3k2Gw2PfHEE1qyZEm99YcPH9bo0aO1evVqDRky5Fsfn5aWprS0NNdyUFCQx2YFPGHbtm3eHgGG6Nevn6qrqxt9uyEhITp06JDH9wO4q9lEjiQtXLhQP/jBD9S6dWtJ0pEjRzRq1CitWrVKw4cP9/J0AADAlzSLFx5/rWvXrkpJSVFpaakkKSUlRV9++aXmz58vu90uu92uv/3tb16eEgAA+AKfP5PjdDrrLaenpys9PV2SlJCQ0PQDAQCAZqFZnckBAABwF5EDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACMROQAAwEhEDgAAMJLPR47NZlNERIRqa2td62JiYpSTk6Nf/epXioyMlN1uV//+/bVy5UovTgoAAHyJz0eOJF2+fFlZWVlXrZ88ebLy8vKUm5urPXv2aNmyZTpw4IAXJgQAAL6mWUTOokWLlJGRoYsXL9Zb37FjR9fvFy5cUE1NTVOPBgAAfFSAtwdwR3R0tBITE5WZmakFCxbUu+3dd9/VCy+8oPz8fL388ssaMGCAl6YEgObj5MmTGjlyZKNuLzQ01OP7QdP5tmPanDSLyJGkjIwMDRw4UMnJyfXWjx8/XuPHj5fT6dSPfvQjjR49WuHh4Vc93uFwyOFwuJYrKys9PjMA+KI+ffo0+jZDQ0Ov2q4n9oOmc61j2tz4WZZleXuI72Kz2bR161bZ7XalpqbKz89Pu3fv1rJly5SQkFDvvsnJyQoNDdWcOXOuu92goCAVFxd7aGoAAOBtzeI1OV9buHChNmzYoJKSEknS4cOHXbf95z//0UcffaSoqChvjQcAAHxIs4qcrl27KiUlRaWlpZKkFStWqG/fvrLb7Ro2bJhSU1M1fPhwL08JAAB8gc9frvIULlcBAGC2ZnUmBwAAwF1EDgAAMBKRAwAAjETkAAAAIxE5AADASEQOAAAwEpEDAACMROQAAAAjETkAAMBIRA4AADASkQMAAIxE5AAAACO12C/oDAgIUGBgoLfHwP+orKzUbbfd5u0x8D84Lr6J4+KbOC5N57bbbtPRo0e/9faAJpzFpwQGBvIt5D6Ib4f3TRwX38Rx8U0cF9/B5SoAAGAkIgcAABipxUZOWlqat0fANXBcfBPHxTdxXHwTx8V3tNgXHgMAALO12DM5AADAbEQOAAAwUouMnOPHj2vw4MEKCwtTbGysDh065O2RWpyUlBTZbDb5+fkpNzfXtZ5j412XLl3SuHHjFBYWpujoaA0fPlz5+fmSpDNnzmjkyJEKDQ1V//799fe//93L07YsI0aMUFRUlOx2u4YOHaoDBw5I4jnjC9asWSM/Pz9t3bpVEs8Vn2K1QImJidaaNWssy7Ksd955x4qJifHuQC3Qxx9/bBUVFVm9evWyDhw44FrPsfGuqqoqa/v27VZdXZ1lWZb1xhtvWPHx8ZZlWdb06dOtF154wbIsy/r000+tHj16WNXV1V6atOU5e/as6/f33nvPioqKsiyL54y3nThxwoqLi7MGDRpk/fGPf7Qsi+eKL2lxkVNWVma1b9/eqqmpsSzLsurq6qzu3btbx48f9/JkLdM3I4dj43s+++wzq1evXpZlWdatt95qlZaWum6LjY21du7c6aXJWrY1a9ZY0dHRPGe87MqVK1ZSUpL1+eefW/Hx8a7I4bniO1rc5aqioiLdeeedCgj46sOe/fz8FBwcrMLCQi9PBo6N71mxYoXGjh2r8vJy1dTU1PsqFJvNxrFpYj/+8Y/Vs2dPpaen6/e//z3PGS9zOBy67777dM8997jW8VzxLS32ax0AfLeXX35Z+fn5+vDDD1VVVeXtcSBp/fr1kqR169Zp/vz5ysjI8PJELdfBgwe1ZcsWXm/j41rcmZyePXuqtLRUtbW1kiTLslRYWKjg4GAvTwaOje9YtmyZ3nvvPe3YsUPf+9731KVLFwUEBOj06dOu+zidTo6Nl0ydOlXZ2dkKCgriOeMlu3fvltPpVGhoqGw2m/bt26dZs2bp7bff5rniQ1pc5Nxxxx26++67tWHDBknSli1bFBQUpJCQEC9PBo6Nb3A4HNq0aZN27typTp06udZPmDBBb775piTps88+06lTpxQfH++lKVuWc+fOqaSkxLW8detWdenSheeMF82ePVulpaVyOp1yOp0aNGiQVq1apdmzZ/Nc8SEt8hOPjx07pmnTpqm8vFwdOnTQmjVrFBkZ6e2xWpSnnnpK27dv1+nTp9WlSxe1b99e+fn5HBsvKy4uVs+ePdW7d2+1b99ektS2bVt98sknKisr05QpU3TixAm1adNGv/zlL5WYmOjliVuGkydPasKECaqqqpK/v7+6deumZcuWyW6385zxEQkJCUpNTdW4ceN4rviQFhk5AADAfC3uchUAAGgZiBwAAGAkIgcAABiJyAEAAEYicgAAgJGIHAAAYCQiB4BH2O122e129e3bV61atXItT5w4Uc8//7w2btzosX3n5OSoXbt2stvtOnPmzA0/ftSoUTp27FiD9z9p0iQFBgYqNTW1wdsAcPP47ioAHpGbmyvpq4+0t9vtruWmEh4e3uB9vv/++ze1740bN2rRokU6d+7cTW0HwM3hTA6AJjdt2jQtX75ckrRo0SI9+uijGjNmjMLCwjR69GgdPHhQP/zhDxUWFqbHH39cdXV1kqSKigrNnDlTAwcOVFRUlGbNmqXq6mq39pmQkKA5c+bo/vvvV3BwsNLT0/X+++9ryJAhstlscjgcrvvabDZXICUkJGju3LkaOnSo+vTpo+TkZNf9fve736lv376y2+2KjIzUJ5980jj/QAAaBWdyAHjd559/rv3796tTp05KSEjQk08+qZ07d6pdu3aKiYnRjh079NBDD2nOnDkaOnSofvvb38qyLM2cOVMrVqzQvHnz3NrPyZMnlZ2drfPnz8tms+ns2bPavXu3SkpKFB4erhkzZtT7vq6vFRQUKDs7WzU1Nerbt6/27t2ruLg4zZkzR0ePHtWdd96pmpoaXb58uZH/ZQDcDCIHgNeNGDFCnTt3liTdfffdatu2reu7swYMGKDjx49L+uqLKffu3es661JVVaVWrVq5vZ/x48erVatW6ty5s3r37q3Ro0fLz89PPXr0ULdu3VyX1v7XxIkTFRAQoICAANntdhUUFCguLk5JSUmaMmWKxowZowcffFBhYWE3+S8BoDEROQC87pZbbnH93qpVq6uWa2trJUmWZWnLli0Njgl393O9x319vy1btmj//v3KycnRqFGjtGTJEj322GMNmg1A4+M1OQCajXHjxmnp0qWuyDh79qzy8/O9Mkttba0KCgoUExOjuXPnavz48fr000+9MguAayNyADQbmZmZrreGR0VFKSkpSU6n0yuzXLlyRTNmzFD//v1lt9u1f/9+paWleWUWANfmZ1mW5e0hAKAx5eTkKDU1tcnftv5NX7+F/Ot3kQFoepzJAWCcNm3aqLy8vMEfBnizJk2apA0bNqhDhw5Nvm8A/48zOQAAwEicyQEAAEYicgAAgJGIHAAAYCQiBwAAGInIAQAARiJyAACAkf4PFrQF/9W8l1QAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Plot the predicted hypnogram\n",
+ "fig, ax = plt.subplots(1, 1, figsize=(7, 3), constrained_layout=True, dpi=80)\n",
+ "ax = y_pred.plot_hypnogram(fill_color=\"gainsboro\", ax=ax)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'TIB': 49.0,\n",
+ " 'SPT': 28.0,\n",
+ " 'WASO': 0.0,\n",
+ " 'TST': 28.0,\n",
+ " 'SE': 57.1429,\n",
+ " 'SME': 100.0,\n",
+ " 'SFI': 1.0714,\n",
+ " 'SOL': 17.0,\n",
+ " 'SOL_5min': 17.0,\n",
+ " 'Lat_REM': nan,\n",
+ " 'WAKE': 21.0,\n",
+ " 'N1': 0.0,\n",
+ " 'N2': 15.0,\n",
+ " 'N3': 13.0,\n",
+ " 'REM': 0.0,\n",
+ " '%N1': 0.0,\n",
+ " '%N2': 53.5714,\n",
+ " '%N3': 46.4286,\n",
+ " '%REM': 0.0}"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Calculate the summary sleep statistics of the predicted hypnogram\n",
+ "y_pred.sleep_statistics()"
]
},
{
@@ -244,9 +424,17 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 12,
"metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/raphael/GitHub/yasa/yasa/staging.py:475: FutureWarning: The `predict_proba` function is deprecated and will be removed in v0.8. The predicted probabilities can now be accessed with `yasa.Hypnogram.proba` instead, e.g `SleepStaging.predict().proba`\n",
+ " warnings.warn(\n"
+ ]
+ },
{
"data": {
"text/html": [
@@ -271,11 +459,11 @@
"
\n",
" \n",
@@ -541,27 +738,35 @@
],
"text/plain": [
" Stage Confidence\n",
- "epoch \n",
- "0 W 0.992055\n",
- "1 W 0.991345\n",
- "2 W 0.991833\n",
- "3 W 0.995557\n",
- "4 W 0.988994\n",
- "5 W 0.986805"
+ "Epoch \n",
+ "0 WAKE 0.992055\n",
+ "1 WAKE 0.991345\n",
+ "2 WAKE 0.991833\n",
+ "3 WAKE 0.995557\n",
+ "4 WAKE 0.988994\n",
+ "5 WAKE 0.986805"
]
},
- "execution_count": 10,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "# Let's first create a dataframe with the predicted stages and confidence\n",
- "df_pred = pd.DataFrame({'Stage': y_pred, 'Confidence': confidence})\n",
- "df_pred.head(6)\n",
- "\n",
+ "# We can also add the confidence level:\n",
+ "df_pred = hyp.hypno.to_frame()\n",
+ "df_pred[\"Confidence\"] = confidence\n",
+ "df_pred.head(6)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
"# Now export to a CSV file\n",
- "# df_pred.to_csv(\"my_hypno.csv\")"
+ "df_pred.to_csv(\"my_hypno.csv\")"
]
},
{
@@ -573,33 +778,38 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "/Users/raphael/.pyenv/versions/3.8.3/lib/python3.8/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 0.24.2 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
- "https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations\n",
+ "/Users/raphael/.pyenv/versions/3.9.6/lib/python3.9/site-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator LabelEncoder from version 0.24.2 when using version 1.1.3. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:\n",
+ "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
- "array(['W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',\n",
- " 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W',\n",
- " 'W', 'W', 'N1', 'N2', 'W', 'W', 'N2', 'N2', 'R', 'N2', 'R', 'R',\n",
- " 'N2', 'R', 'R', 'N2', 'R', 'R', 'R', 'R', 'R', 'R', 'R', 'R', 'N2',\n",
- " 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2',\n",
- " 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N2', 'N3', 'N2', 'N2',\n",
- " 'N3', 'N2', 'N2', 'N3', 'N2', 'N3', 'N2', 'N2', 'N2', 'N3', 'N3',\n",
- " 'N3', 'N2', 'N3', 'N2', 'N3', 'N3', 'W', 'N3', 'W', 'W', 'W', 'W',\n",
- " 'W', 'W'], dtype=object)"
+ "Epoch\n",
+ "0 WAKE\n",
+ "1 WAKE\n",
+ "2 WAKE\n",
+ "3 WAKE\n",
+ "4 WAKE\n",
+ " ... \n",
+ "93 WAKE\n",
+ "94 WAKE\n",
+ "95 WAKE\n",
+ "96 WAKE\n",
+ "97 WAKE\n",
+ "Name: Stage, Length: 98, dtype: category\n",
+ "Categories (7, object): ['WAKE', 'N1', 'N2', 'N3', 'REM', 'ART', 'UNS']"
]
},
- "execution_count": 11,
+ "execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
@@ -607,13 +817,13 @@
"source": [
"# Using just an EEG channel (= no EOG or EMG)\n",
"y_pred = yasa.SleepStaging(raw, eeg_name=\"C4\").predict()\n",
- "y_pred"
+ "y_pred.hypno"
]
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -627,7 +837,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.3"
+ "version": "3.9.6"
}
},
"nbformat": 4,
diff --git a/yasa/hypno.py b/yasa/hypno.py
index d67bf82..c52a8b1 100644
--- a/yasa/hypno.py
+++ b/yasa/hypno.py
@@ -74,6 +74,10 @@ class Hypnogram:
scorer : str
An optional string indicating the scorer name. If specified, this will be set as the name
of the :py:class:`pandas.Series`, otherwise the name will be set to "Stage".
+ proba : :py:class:`pandas.DataFrame`
+ An optional dataframe with the probability of each sleep stage for each epoch in hypnogram.
+ Each row must sum to 1. This is automatically included if the hypnogram is created with
+ :py:class:`yasa.SleepStaging`.
Examples
--------
@@ -215,7 +219,7 @@ class Hypnogram:
'%REM': 8.9713}
"""
- def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None):
+ def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None, proba=None):
assert isinstance(
values, (list, np.ndarray, pd.Series)
), "`values` must be a list, numpy.array or pandas.Series"
@@ -228,6 +232,9 @@ def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None):
assert isinstance(
scorer, (type(None), str, int)
), "`scorer` must be either None, or a string or an integer."
+ assert isinstance(
+ proba, (pd.DataFrame, type(None))
+ ), "`proba` must be either None or a pandas.DataFrame"
if n_stages == 2:
accepted = ["W", "WAKE", "S", "SLEEP", "ART", "UNS"]
mapping = {"WAKE": 0, "SLEEP": 1, "ART": -1, "UNS": -2}
@@ -268,6 +275,19 @@ def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None):
fake_dt = pd.date_range(start="2022-12-03 00:00:00", freq=freq, periods=hypno.shape[0])
hypno.index.name = "Epoch"
timedelta = fake_dt - fake_dt[0]
+ # Validate proba
+ if proba is not None:
+ assert proba.shape[1] > 0, "`proba` must have at least one column."
+ assert proba.shape[0] == hypno.shape[0], "`proba` must have the same length as `values`"
+ assert np.allclose(proba.sum(1), 1), "Each row of `proba` must sum to 1."
+ in_proba_but_not_labels = np.setdiff1d(proba.columns, labels)
+ # in_labels_but_not_proba = np.setdiff1d(labels, proba.columns)
+ assert not len(in_proba_but_not_labels), (
+ f"Invalid stages in `proba`: {in_proba_but_not_labels}. The accepted stages are: "
+ f"{labels}."
+ )
+ # Ensure same order as `labels`
+ proba = proba.reindex(columns=labels).dropna(how="all", axis=1)
# Set attributes
self._hypno = hypno
self._n_epochs = hypno.shape[0]
@@ -280,6 +300,7 @@ def __init__(self, values, n_stages=5, *, freq="30s", start=None, scorer=None):
self._labels = labels
self._mapping = mapping
self._scorer = scorer
+ self._proba = proba
def __repr__(self):
# TODO v0.8: Keep only the text between < and >
@@ -294,15 +315,7 @@ def __repr__(self):
)
def __str__(self):
- text_scorer = f", scored by {self.scorer}" if self.scorer is not None else ""
- return (
- f"\n"
- " - Use `.hypno` to get the string values as a pandas.Series\n"
- " - Use `.as_int()` to get the integer values as a pandas.Series\n"
- " - Use `.plot_hypnogram()` to plot the hypnogram\n"
- "See the online documentation for more details."
- )
+ return self.__repr__()
@property
def hypno(self):
@@ -391,6 +404,14 @@ def scorer(self):
"""The scorer name."""
return self._scorer
+ @property
+ def proba(self):
+ """
+ If specified, a :py:class:`pandas.DataFrame` with the probability of each sleep stage
+ for each epoch in hypnogram.
+ """
+ return self._proba
+
# CLASS METHODS BELOW
def as_annotations(self):
@@ -558,6 +579,7 @@ def consolidate_stages(self, new_n_stages):
freq=self.freq,
start=self.start,
scorer=self.scorer,
+ proba=None, # TODO: Combine stages probability?
)
def copy(self):
@@ -568,6 +590,7 @@ def copy(self):
freq=self.freq,
start=self.start,
scorer=self.scorer,
+ proba=self.proba,
)
def find_periods(self, threshold="5min", equal_length=False):
@@ -1055,6 +1078,7 @@ def upsample(self, new_freq, **kwargs):
freq=new_freq,
start=self.start,
scorer=self.scorer,
+ proba=None, # NOTE: Do not upsample probability
)
def upsample_to_data(self, data, sf=None, verbose=True):
diff --git a/yasa/staging.py b/yasa/staging.py
index 9a59ffa..8b42705 100644
--- a/yasa/staging.py
+++ b/yasa/staging.py
@@ -4,6 +4,7 @@
import glob
import joblib
import logging
+import warnings
import numpy as np
import pandas as pd
import antropy as ant
@@ -104,9 +105,9 @@ class SleepStaging:
In addition with the predicted sleep stages, YASA can also return the predicted probabilities
of each sleep stage at each epoch. This can be used to derive a confidence score at each epoch.
- .. important:: The predictions should ALWAYS be double-check by a trained
- visual scorer, especially for epochs with low confidence. A full
- inspection should be performed in the following cases:
+ .. important:: The predictions should ALWAYS be double-check by a trained visual scorer,
+ especially for epochs with low confidence. A full inspection should be performed in the
+ following cases:
* Nap data, because the classifiers were exclusively trained on full-night recordings.
* Participants with sleep disorders.
@@ -122,13 +123,11 @@ class SleepStaging:
If you use YASA's default classifiers, these are the main references for
the `National Sleep Research Resource `_:
- * Dean, Dennis A., et al. "Scaling up scientific discovery in sleep
- medicine: the National Sleep Research Resource." Sleep 39.5 (2016):
- 1151-1164.
+ * Dean, Dennis A., et al. "Scaling up scientific discovery in sleep medicine: the National
+ Sleep Research Resource." Sleep 39.5 (2016): 1151-1164.
- * Zhang, Guo-Qiang, et al. "The National Sleep Research Resource: towards
- a sleep data commons." Journal of the American Medical Informatics
- Association 25.10 (2018): 1351-1358.
+ * Zhang, Guo-Qiang, et al. "The National Sleep Research Resource: towards a sleep data
+ commons." Journal of the American Medical Informatics Association 25.10 (2018): 1351-1358.
Examples
--------
@@ -143,12 +142,15 @@ class SleepStaging:
>>> sls = yasa.SleepStaging(raw, eeg_name="C4-M1", eog_name="LOC-M2",
... emg_name="EMG1-EMG2",
... metadata=dict(age=29, male=True))
+ >>> # Print some basic info
+ >>> sls
>>> # Get the predicted sleep stages
- >>> hypno = sls.predict()
+ >>> hyp = sls.predict()
+ >>> hyp.hypno
>>> # Get the predicted probabilities
- >>> proba = sls.predict_proba()
+ >>> hyp.proba
>>> # Get the confidence
- >>> confidence = proba.max(axis=1)
+ >>> confidence = hyp.proba.max(axis=1)
>>> # Plot the predicted probabilities
>>> sls.plot_predict_proba()
@@ -159,10 +161,10 @@ class SleepStaging:
def __init__(self, raw, eeg_name, *, eog_name=None, emg_name=None, metadata=None):
# Type check
- assert isinstance(eeg_name, str)
- assert isinstance(eog_name, (str, type(None)))
- assert isinstance(emg_name, (str, type(None)))
- assert isinstance(metadata, (dict, type(None)))
+ assert isinstance(eeg_name, str), "`eeg_name` must be a string."
+ assert isinstance(eog_name, (str, type(None))), "`eog_name` must be a string or None."
+ assert isinstance(emg_name, (str, type(None))), "`emg_name` must be a string or None."
+ assert isinstance(metadata, (dict, type(None))), "`metadata` must be a string or None."
# Validate metadata
if isinstance(metadata, dict):
@@ -173,7 +175,7 @@ def __init__(self, raw, eeg_name, *, eog_name=None, emg_name=None, metadata=None
assert metadata["male"] in [0, 1], "male must be 0 or 1."
# Validate Raw instance and load data
- assert isinstance(raw, mne.io.BaseRaw), "raw must be a MNE Raw object."
+ assert isinstance(raw, mne.io.BaseRaw), "`raw` must be a MNE Raw object."
sf = raw.info["sfreq"]
ch_names = np.array([eeg_name, eog_name, emg_name])
ch_types = np.array(["eeg", "eog", "emg"])
@@ -210,6 +212,22 @@ def __init__(self, raw, eeg_name, *, eog_name=None, emg_name=None, metadata=None
self.data = data
self.metadata = metadata
+ def __repr__(self):
+ n_samples = self.data.shape[-1]
+ duration = (n_samples / self.sf) / 60
+ return (
+ f""
+ )
+
+ def __str__(self):
+ n_samples = self.data.shape[-1]
+ duration = n_samples / self.sf
+ return (
+ f""
+ )
+
def fit(self):
"""Extract features from data.
@@ -419,9 +437,13 @@ def predict(self, path_to_model="auto"):
Returns
-------
- pred : :py:class:`numpy.ndarray`
- The predicted sleep stages.
+ pred : :py:class:`yasa.Hypnogram`
+ The predicted sleep stages. Since YASA v0.7, the predicted sleep stages are now
+ returned as a :py:class:`yasa.Hypnogram` instance, which also includes the
+ probability of each sleep stage for each epoch.
"""
+ from yasa.hypno import Hypnogram
+
if not hasattr(self, "_features"):
self.fit()
# Load and validate pre-trained classifier
@@ -430,10 +452,15 @@ def predict(self, path_to_model="auto"):
X = self._features.copy()[clf.feature_name_]
# Predict the sleep stages and probabilities
self._predicted = clf.predict(X)
- proba = pd.DataFrame(clf.predict_proba(X), columns=clf.classes_)
- proba.index.name = "epoch"
+ # Predict the probabilities
+ classes = clf.classes_.copy()
+ classes[classes == "W"] = "WAKE" # Compat for yasa.Hypnogram
+ classes[classes == "R"] = "REM"
+ proba = pd.DataFrame(clf.predict_proba(X), columns=classes)
+ proba.index.name = "Epoch"
self._proba = proba
- return self._predicted.copy()
+ # Convert to a `yasa.Hypnogram` instance (including `proba`)
+ return Hypnogram(values=self._predicted.copy(), freq="30s", n_stages=5, proba=proba.copy())
def predict_proba(self, path_to_model="auto"):
"""
@@ -454,6 +481,12 @@ def predict_proba(self, path_to_model="auto"):
proba : :py:class:`pandas.DataFrame`
The predicted probability for each sleep stage for each 30-sec epoch of data.
"""
+ warnings.warn(
+ "The `predict_proba` function is deprecated and will be removed in v0.8. "
+ "The predicted probabilities can now be accessed with `yasa.Hypnogram.proba` instead, "
+ "e.g `SleepStaging.predict().proba`",
+ FutureWarning,
+ )
if not hasattr(self, "_proba"):
self.predict(path_to_model)
return self._proba.copy()
@@ -475,19 +508,18 @@ def plot_predict_proba(
If True, probabilities of the non-majority classes will be set to 0.
"""
if proba is None and not hasattr(self, "_features"):
- raise ValueError("Must call .predict_proba before this function")
+ raise ValueError("Must call `.predict` before this function")
if proba is None:
proba = self._proba.copy()
else:
- assert isinstance(proba, pd.DataFrame), "proba must be a dataframe"
+ assert isinstance(proba, pd.DataFrame), "`proba` must be a pandas.DataFrame"
if majority_only:
cond = proba.apply(lambda x: x == x.max(), axis=1)
proba = proba.where(cond, other=0)
ax = proba.plot(kind="area", color=palette, figsize=(10, 5), alpha=0.8, stacked=True, lw=0)
# Add confidence
# confidence = proba.max(1)
- # ax.plot(confidence, lw=1, color='k', ls='-', alpha=0.5,
- # label='Confidence')
+ # ax.plot(confidence, lw=1, color='k', ls='-', alpha=0.5, label='Confidence')
ax.set_xlim(0, proba.shape[0])
ax.set_ylim(0, 1)
ax.set_ylabel("Probability")
diff --git a/yasa/tests/test_staging.py b/yasa/tests/test_staging.py
index a87edd8..e396431 100644
--- a/yasa/tests/test_staging.py
+++ b/yasa/tests/test_staging.py
@@ -3,6 +3,7 @@
import unittest
import numpy as np
import matplotlib.pyplot as plt
+from yasa.hypno import Hypnogram
from yasa.staging import SleepStaging
##############################################################################
@@ -11,7 +12,7 @@
# MNE Raw
raw = mne.io.read_raw_fif("notebooks/sub-02_mne_raw.fif", preload=True, verbose=0)
-hypno = np.loadtxt("notebooks/sub-02_hypno_30s.txt", dtype=str)
+y_true = Hypnogram(np.loadtxt("notebooks/sub-02_hypno_30s.txt", dtype=str))
class TestStaging(unittest.TestCase):
@@ -22,12 +23,18 @@ def test_sleep_staging(self):
sls = SleepStaging(
raw, eeg_name="C4", eog_name="EOG1", emg_name="EMG1", metadata=dict(age=21, male=False)
)
+ print(sls)
+ print(str(sls))
sls.get_features()
y_pred = sls.predict()
+ assert isinstance(y_pred, Hypnogram)
+ assert y_pred.proba is not None
proba = sls.predict_proba()
- assert y_pred.size == hypno.size
+ assert y_pred.hypno.size == y_true.hypno.size
+ assert y_true.duration == y_pred.duration
+ assert y_true.n_stages == y_pred.n_stages
# Check that the accuracy is at least 80%
- accuracy = (hypno == y_pred).sum() / y_pred.size
+ accuracy = (y_true.hypno == y_pred.hypno).mean()
assert accuracy > 0.80
# Plot