In [1]:
import sys
import os

# パスをsys.pathに追加
sys.path.append("/p-antipsychotics-sleep/")

import os
import sys
import argparse
import pandas as pd
import numpy as np
import faster2lib.eeg_tools as et
import re
from datetime import datetime, timedelta
from scipy import signal
from hmmlearn import hmm, base
from sklearn import mixture
import matplotlib as mpl
from matplotlib.figure import Figure
from mpl_toolkits.mplot3d import axes3d
from scipy import linalg, stats
from scipy.spatial import distance
from scipy.stats import multivariate_normal
import pickle
from glob import glob
import mne
import logging
from logging import getLogger, StreamHandler, FileHandler, Formatter
import traceback

In [2]:
FASTER2_NAME = 'FASTER2 version 0.4.1'
STAGE_LABELS = ['Wake', 'REM', 'NREM']
XLABEL = 'Total low-freq. log-powers'
YLABEL = 'Total high-freq. log-powers'
ZLABEL = 'REM metric'
SCATTER_PLOT_FIG_WIDTH = 6   # inch
SCATTER_PLOT_FIG_HEIGHT = 6  # inch
FIG_DPI = 100  # dot per inch
COLOR_WAKE = '#DC267F'
COLOR_NREM = '#648FFF'
COLOR_REM = '#FFB000'
COLOR_LIGHT = '#FFD700'  # 'gold'
COLOR_DARK = '#696969'  # 'dimgray'
COLOR_DARKLIGHT = 'lightgray'  # light hours in DD condition

In [3]:
class CustomedGHMM(hmm.GaussianHMM):
    def __init__(self, n_components=1, covariance_type='diag',
                 min_covar=1e-3,
                 startprob_prior=1.0, transmat_prior=1.0,
                 means_prior=0, means_weight=0,
                 covars_prior=1e-2, covars_weight=1,
                 algorithm="viterbi", random_state=None,
                 n_iter=10, tol=1e-2, verbose=False,
                 params="stmc", init_params="stmc"):
        super().__init__(n_components, covariance_type,
                         min_covar, startprob_prior, transmat_prior,
                         means_prior, means_weight,
                         covars_prior, covars_weight,
                         algorithm, random_state,
                         n_iter, tol, verbose,
                         params, init_params)
        self.wr_boundary = None
        self.nr_boundary = None
        self.max_rem_ax = None

    def set_wr_boundary(self, wr_boundary):
        # Wake/REM boundary. REM cluster cannot grow below this boundary
        self.wr_boundary = wr_boundary

    def set_nr_boundary(self, nr_boundary):
        # NREM/REM boundary. REM cluster cannot grow beyond this boundary
        self.nr_boundary = nr_boundary

    def set_max_rem_ax(self, max_rem_ax_len):
        # maximum length of REM's principal axis
        self.max_rem_ax = max_rem_ax_len

    def _confine_REM_in_boundary(self, rem_mean, rem_cov):
        """ By definition, REM cluster is not likely z (i.e. REM-metric)<REM_floor and
        x (i.e. low-freq power)>0.
        This function focuses on the ellipsoid that represents the 95% confidence area
        of REM cluster. If this function finds any principal axis of the ellipsoid
        penetrating the REM floor or the NREM wall (i.e. the end-point of the principal
        axis at z<REM_floor or x>0), it shrinks the length of the axis to the point on
        the constraints.
        """

        w, v = linalg.eigh(rem_cov)
        # all eigenvalues must be positive
        if np.any(w <= 0):
            raise ValueError('Invalid_REM_Cluster')

        w = 2 * np.sqrt(w)  # 95% confidence (2SD) area

        # confine above REM floor
        prn_ax = v@np.diag(w)  # 3x3 matrix: each column is the principal axis
        for i in range(3):
            arr_hd = rem_mean + prn_ax[:, i]  # the arrow head from the mean
            # the negative arrow head from the mean
            narr_hd = rem_mean - prn_ax[:, i]
            if arr_hd[2] < self.wr_boundary:
                sr = (rem_mean[2] - self.wr_boundary) / \
                    (rem_mean[2] - arr_hd[2])  # shrink ratio
            elif narr_hd[2] < self.wr_boundary:
                sr = (rem_mean[2] - self.wr_boundary)/(rem_mean[2] - narr_hd[2])
            else:
                sr = 1
            w[i] = w[i] * sr
        
        # confine the REM cluster within negative low-freq and above the diagonal line
        prn_ax = v@np.diag(w)  # 3x3 matrix: each column is the principal axis
        for i in range(3):
            arr_hd = rem_mean + prn_ax[:, i] # the arrow head from the mean
            narr_hd = rem_mean - prn_ax[:, i] # the negative arrow head from the mean
            
            # condition 1: negative low-freq and BELOW the diagonal line
            if arr_hd[0] > self.nr_boundary and arr_hd[1] < arr_hd[0]:
                # condition 2: if positive high-freq > 0 then allow to grow onto the diagonal line
                if arr_hd[1] > 0:
                    sr = self._shrink_ratio(arr_hd, rem_mean)
                # Otherwise (negative high-freq) then only allow to reach onto the y-axis. 
                else: 
                    sr = (self.nr_boundary - rem_mean[0])/(arr_hd[0] - rem_mean[0]) # shrink ratio
            elif narr_hd[0] > self.nr_boundary and narr_hd[1] < narr_hd[0]:
                if narr_hd[1] > 0:
                    sr = self._shrink_ratio(narr_hd, rem_mean)
                else: 
                    sr = (self.nr_boundary - rem_mean[0])/(narr_hd[0] - rem_mean[0]) # shrink ratio
            else:
                sr = 1
            w[i] = w[i] * sr

        # confine the length of principal axes
        prn_ax = v@np.diag(w/2)  # half w because it was doubled in the previous process
        prn_ax_len = np.sqrt(np.diag(prn_ax.T@prn_ax)) # lengths of principal axes
        for i in range(3):
            if prn_ax_len[i] > self.max_rem_ax:
                sr = self.max_rem_ax / prn_ax_len[i]
            else:
                sr = 1
            w[i] = w[i] * sr
        
        cov_updated = v@np.diag((w/2)**2)@v.T

        return cov_updated


    def _confine_Wake_in_boundary(self, wake_mean, wake_cov):
        """ By definition, Wake cluster is not likely to cross the diagonal line.
        This function focuses on the ellipsoid that represents the 95% confidence area
        of Wake cluster. If this function finds any principal axis of the ellipsoid
        penetrating the diagonal line (i.e. the end-point of the principal
        axis is in the area y < x), it shrinks the length of the axis to the point on
        the constraints.
        """

        w, v = linalg.eigh(wake_cov)
        # all eigenvalues must be positive
        if np.any(w <= 0):
            raise ValueError('Invalid_Wake_Cluster')

        w = 2 * np.sqrt(w)  # 95% confidence (2SD) area
        prn_ax = v@np.diag(w)  # 3x3 matrix: each column is the principal axis
        # confine above diagonal line
        for i in range(3):
            arr_hd = wake_mean + prn_ax[:, i]  # the arrow head from the mean
            # the negative arrow head from the mean
            narr_hd = wake_mean - prn_ax[:, i]
            if arr_hd[1] < arr_hd[0]:
                sr = self._shrink_ratio(arr_hd, wake_mean)
            elif narr_hd[1] < narr_hd[0]:
                sr = self._shrink_ratio(narr_hd, wake_mean)
            else:
                sr = 1
            w[i] = w[i] * sr


        cov_updated = v@np.diag((w/2)**2)@v.T

        return cov_updated

    
    def _confine_NREM_in_boundary(self, nrem_mean, nrem_cov):
        """ By definition, NREM cluster is not likely to cross the diagonal line.
        This function focuses on the ellipsoid that represents the 95% confidence area
        of NREM cluster. If this function finds any principal axis of the ellipsoid
        penetrating the diagonal line (i.e. the end-point of the principal
        axis is in the area y > x), it shrinks the length of the axis to the point on
        the constraints.
        """

        w, v = linalg.eigh(nrem_cov)
        # all eigenvalues must be positive
        if np.any(w <= 0):
            raise ValueError('Invalid_NREM_Cluster')

        w = 2 * np.sqrt(w)  # 95% confidence (2SD) area
        prn_ax = v@np.diag(w)  # 3x3 matrix: each column is the principal axis

        # confine above diagonal line
        for i in range(3):
            arr_hd = nrem_mean + prn_ax[:, i]  # the arrow head from the mean
            # the negative arrow head from the mean
            narr_hd = nrem_mean - prn_ax[:, i]
            if arr_hd[1] > arr_hd[0]:
                sr = self._shrink_ratio(arr_hd, nrem_mean)
            elif narr_hd[1] > narr_hd[0]:
                sr = self._shrink_ratio(narr_hd, nrem_mean)
            else:
                sr = 1
            w[i] = w[i] * sr


        cov_updated = v@np.diag((w/2)**2)@v.T

        return cov_updated
 

    def _shrink_ratio(self, arr_hd, mn):
        r = (arr_hd[1] - mn[1])/(arr_hd[0]-mn[0])
        x_on_diag = (mn[1] - r*mn[0])/(1-r)
        sr = (x_on_diag - mn[0])/(arr_hd[0] - mn[0])
        return sr


    #pylint: disable = redefined-outer-name
    def _do_mstep(self, stats):
        # pylint: disable = attribute-defined-outside-init
        ghmm_stats = stats
        # pylint: disable = protected-access
        base._BaseHMM._do_mstep(self, ghmm_stats)
        means_prior = self.means_prior
        means_weight = self.means_weight

        denom = ghmm_stats['post'][:, np.newaxis]
        if 'm' in self.params:
            self.means_ = ((means_weight * means_prior + ghmm_stats['obs'])
                           / (means_weight + denom))
 
        if 'c' in self.params:
            covars_prior = self.covars_prior
            covars_weight = self.covars_weight
            meandiff = self.means_ - means_prior

            if self.covariance_type in ('spherical', 'diag'):
                cv_num = (means_weight * meandiff**2
                          + ghmm_stats['obs**2']
                          - 2 * self.means_ * ghmm_stats['obs']
                          + self.means_**2 * denom)
                cv_den = max(covars_weight - 1, 0) + denom
                self._covars_ = \
                    (covars_prior + cv_num) / np.maximum(cv_den, 1e-5)
                if self.covariance_type == 'spherical':
                    self._covars_ = np.tile(
                        self._covars_.mean(1)[:, np.newaxis],
                        (1, self._covars_.shape[1]))
            elif self.covariance_type in ('tied', 'full'):
                cv_num = np.empty((self.n_components, self.n_features,
                                   self.n_features))
                for c in range(self.n_components):
                    obsmean = np.outer(ghmm_stats['obs'][c], self.means_[c])

                    cv_num[c] = (means_weight * np.outer(meandiff[c],
                                                         meandiff[c])
                                 + ghmm_stats['obs*obs.T'][c]
                                 - obsmean - obsmean.T
                                 + np.outer(self.means_[c], self.means_[c])
                                 * ghmm_stats['post'][c])
                cvweight = max(covars_weight - self.n_features, 0)
                if self.covariance_type == 'tied':
                    self._covars_ = ((covars_prior + cv_num.sum(axis=0)) /
                                     (cvweight + ghmm_stats['post'].sum()))
                elif self.covariance_type == 'full':
                    #covars = ((covars_prior + cv_num) /
                    #          (cvweight + ghmm_stats['post'][:, None, None]))
                    covars = ((covars_prior + cv_num) /
                            np.maximum(cvweight + ghmm_stats['post'][:, None, None], 1e-6))
                    covars = np.nan_to_num(covars, nan=0.0, posinf=1e6, neginf=-1e6)
                    confined_rem_cov = self._confine_REM_in_boundary(
                        self.means_[1], covars[1])
                    confined_wake_cov = self._confine_Wake_in_boundary(
                        self.means_[0], covars[0])
                    confined_nrem_cov = self._confine_NREM_in_boundary(
                        self.means_[2], covars[2])
                    self._covars_ = np.array(
                        [confined_wake_cov, confined_rem_cov, confined_nrem_cov])


def initialize_logger(log_file):
    logger = getLogger(FASTER2_NAME)
    logger.setLevel(logging.DEBUG)

    file_handler = FileHandler(log_file)
    stream_handler = StreamHandler()

    file_handler.setLevel(logging.DEBUG)
    stream_handler.setLevel(logging.NOTSET)
    handler_formatter = Formatter('%(message)s')
    file_handler.setFormatter(handler_formatter)
    stream_handler.setFormatter(handler_formatter)

    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)

    return logger


def print_log(msg):
    if 'log' in globals():
        log.debug(msg)
    else:
        print(msg)


def print_log_exception(msg):
    if 'log' in globals():
        log.exception(msg)
    else:
        print(msg)


def read_mouse_info(data_dir):
    """This function reads the mouse.info.csv file
    and returns a DataFrame with fixed column names.

    Args:
        data_dir (str): A path to the data directory that contains the mouse.info.csv


        The data directory should include two information files:
        1. exp.info.csv,
        2. mouse.info.csv,
        and one directory named "raw" that contains all EEG/EMG data to be processed

    Returns:
        DataFrame: A dataframe with a fixed column names
    """

    filepath = os.path.join(data_dir, "mouse.info.csv")

    try:
        codename = et.encode_lookup(filepath)
    except LookupError as e:
        print_log(e)
        exit(1)

    csv_df = pd.read_csv(filepath,
                         engine="python",
                         dtype={'Device label': str, 'Mouse group': str,
                                'Mouse ID': str, 'DOB': str, 'Stats report': str, 'Note': str},
                         names=["Device label", "Mouse group",
                                "Mouse ID", "DOB", "Stats report", "Note"],
                         skiprows=1,
                         header=None,
                         skipinitialspace=True,
                         encoding=codename)

    return csv_df


# def read_exp_info(data_dir):
#     """This function reads the exp.info.csv file
#     and returns a DataFrame with fixed column names.

#     Args:
#         data_dir (str): A path to the data directory that contains the exp.info.csv

#     Returns:
#         DataFrame: A DataFrame with a fixed column names
#     """

#     filepath = os.path.join(data_dir, "exp.info.csv")

#     csv_df = pd.read_csv(filepath,
#                          engine="python",
#                          names=["Experiment label", "Rack label",
#                                 "Start datetime", "End datetime", "Sampling freq"],
#                          skiprows=1,
#                          header=None)

#     return csv_df


# def find_edf_files(data_dir):
#     """returns list of edf files in the directory

#     Args:
#         data_dir (str): A path to the data directory

#     Returns:
#         [list]: A list returned by glob()
#     """
#     return glob(os.path.join(data_dir, '*.edf'))


def read_voltage_matrices(data_dir, device_id, sample_freq, epoch_len_sec, epoch_num,
                          start_datetime=None):
    """ This function reads data files of EEG and EMG, then returns matrices
    in the shape of (epochs, signals).

    Args:
        data_dir (str): a path to the dirctory that contains either dsi.txt/, pkl/ directory,
        or an EDF file.
        device_id (str): a transmitter ID (e.g., ID47476) or channel ID (e.g., 09).
        sample_freq (int): sampling frequency
        epoch_len_sec (int): the length of an epoch in seconds
        epoch_num (int): the number of epochs to be read.
        start_datetime (datetime): start datetime of the analysis (used only for EDF file and
        dsi.txt).

    Returns:
        (np.array(2), np.arrray(2), bool): a pair of voltage 2D matrices in a tuple
        and a switch to tell if there was pickled data.

    Note:
        This function looks into the data_dir/ and first try to read pkl files. If pkl files
        are not found, it tries to read an EDF file. If the EDF file is also not found, it
        tries to read dsi.txt files.
    """

    if os.path.exists(os.path.join(data_dir, 'pkl', f'{device_id}_EEG.pkl')):
        # if it exists, read the pkl file
        not_yet_pickled = False
        # Try to read pickled data
        pkl_path = os.path.join(data_dir, 'pkl', f'{device_id}_EEG.pkl')
        with open(pkl_path, 'rb') as pkl:
            print_log(f'Reading {pkl_path}')
            eeg_vm = pickle.load(pkl)

        pkl_path = os.path.join(data_dir, 'pkl', f'{device_id}_EMG.pkl')
        with open(pkl_path, 'rb') as pkl:
            print_log(f'Reading {pkl_path}')
            emg_vm = pickle.load(pkl)

    elif len(find_edf_files(data_dir)) > 0:
        # try to read EDF file
        not_yet_pickled = True
        # read EDF file
        edf_file = find_edf_files(data_dir)
        if len(edf_file) != 1:
            raise FileNotFoundError(
                f'Too many EDF files were found:{edf_file}. '
                'FASTER2 assumes there is only one file.')
        edf_file = edf_file[0]

        raw = mne.io.read_raw_edf(edf_file)
        measurement_start_datetime = datetime.utcfromtimestamp(
            raw.info['meas_date'][0]) + timedelta(microseconds=raw.info['meas_date'][1])
        try:
            if isinstance(start_datetime, datetime) and (measurement_start_datetime < start_datetime):
                start_offset_sec = (
                    start_datetime - measurement_start_datetime).total_seconds()
                end_offset_sec = start_offset_sec + epoch_num * epoch_len_sec
                bidx = (raw.times >= start_offset_sec) & (
                    raw.times < end_offset_sec)
                start_slice = np.where(bidx)[0][0]
                end_slice = np.where(bidx)[0][-1]+1
                eeg = raw.get_data(f'EEG{device_id}',
                                   start_slice, end_slice)[0]
                emg = raw.get_data(f'EMG{device_id}',
                                   start_slice, end_slice)[0]
            else:
                eeg = raw.get_data(f'EEG{device_id}')[0]
                emg = raw.get_data(f'EMG{device_id}')[0]
        except ValueError:
            print_log(f'Failed to extract the data of "{device_id}" from {edf_file}. '
                      f'Check the channel name: "EEG/EMG{device_id}" is in the EDF file.')
            raise
        raw.close()
        try:
            eeg_vm = eeg.reshape(-1, epoch_len_sec * sample_freq)
            emg_vm = emg.reshape(-1, epoch_len_sec * sample_freq)
        except ValueError:
            print_log(f'Failed to extract {epoch_num} epochs from {edf_file}. '
                      'Check the validity of the epoch number, start datetime, '
                      'and sampling frequency.')
            raise
    elif os.path.exists(os.path.join(data_dir, 'dsi.txt')):
        # try to read dsi.txt
        not_yet_pickled = True
        try:
            dsi_reader_eeg = et.DSI_TXT_Reader(os.path.join(data_dir, 'dsi.txt/'),
                                               f'{device_id}', 'EEG',
                                               sample_freq=sample_freq)
            dsi_reader_emg = et.DSI_TXT_Reader(os.path.join(data_dir, 'dsi.txt/'),
                                               f'{device_id}', 'EMG',
                                               sample_freq=sample_freq)
            if isinstance(start_datetime, datetime):
                end_datetime = start_datetime + \
                    timedelta(seconds=epoch_len_sec*epoch_num)
                eeg_df = dsi_reader_eeg.read_epochs_by_datetime(
                    start_datetime, end_datetime)
                emg_df = dsi_reader_emg.read_epochs_by_datetime(
                    start_datetime, end_datetime)
            else:
                eeg_df = dsi_reader_eeg.read_epochs(1, epoch_num)
                emg_df = dsi_reader_emg.read_epochs(1, epoch_num)
            eeg_vm = eeg_df.value.values.reshape(-1,
                                                 epoch_len_sec * sample_freq)
            emg_vm = emg_df.value.values.reshape(-1,
                                                 epoch_len_sec * sample_freq)
        except FileNotFoundError:
            print_log(
                f'The dsi.txt file for {device_id} was not found in {data_dir}.')
            raise
    else:
        raise FileNotFoundError(
            f'Data file was not found for device {device_id} in {data_dir}.')

    expected_shape = (epoch_num, sample_freq * epoch_len_sec)
    if (eeg_vm.shape != expected_shape) or (emg_vm.shape != expected_shape):
        raise ValueError(f'Unexpected shape of matrices EEG:{eeg_vm.shape} or EMG:{emg_vm.shape}. '
                         f'Expected shape is {expected_shape}. Check the validity of '
                         'the data files or configurations '
                         'such as the epoch number and the sampling frequency.')

    return (eeg_vm, emg_vm, not_yet_pickled)


# def interpret_datetimestr(datetime_str):
#     """ Find a datetime string and convert it to a datatime object
#     allowing some variant forms

#     Args:
#         datetime_str (string): a string containing datetime

#     Returns:
#         a datetime object

#     Raises:
#         ValueError: raised when interpretation is failed
#     """

#     datestr_patterns = [r'(\d{4})(\d{2})(\d{2})',
#                         r'(\d{4})/(\d{1,2})/(\d{1,2})',
#                         r'(\d{4})-(\d{1,2})-(\d{1,2})']

#     timestr_patterns = [r'(\d{2})(\d{2})(\d{2})',
#                         r'(\d{1,2}):(\d{1,2}):(\d{1,2})',
#                         r'(\d{1,2})-(\d{1,2})-(\d{1,2})']

#     datetime_obj = None
#     for pat in datestr_patterns:
#         matched = re.search(pat, datetime_str)
#         if matched:
#             year = int(matched.group(1))
#             month = int(matched.group(2))
#             day = int(matched.group(3))
#             datetime_str = re.sub(pat, '', datetime_str)

#             for pat_time in timestr_patterns:
#                 matched_time = re.search(pat_time, datetime_str)
#                 if matched_time:
#                     hour = int(matched_time.group(1))
#                     minuite = int(matched_time.group(2))
#                     second = int(matched_time.group(3))
#                     datetime_obj = datetime(year, month, day,
#                                             hour, minuite, second)
#                     break
#             if not matched_time:
#                 datetime_obj = datetime(year, month, day)
#     if not datetime_obj:
#         raise ValueError(
#             'failed to interpret datetime string \'{}\''.format(datetime_str))

#     return datetime_obj


# def interpret_exp_info(exp_info_df, epoch_len_sec):
#     try:
#         start_datetime_str = exp_info_df['Start datetime'].values[0]
#         end_datetime_str = exp_info_df['End datetime'].values[0]
#         sample_freq = exp_info_df['Sampling freq'].values[0]
#         exp_label = exp_info_df['Experiment label'].values[0]
#         rack_label = exp_info_df['Rack label'].values[0]
#     except KeyError as e:
#         print_log(
#             f'Failed to parse the column: {e} in exp.info.csv. Check the headers.')
#         exit(1)

#     start_datetime = interpret_datetimestr(start_datetime_str)
#     end_datetime = interpret_datetimestr(end_datetime_str)

#     epoch_num = int(
#         (end_datetime - start_datetime).total_seconds() / epoch_len_sec)

#     return (epoch_num, sample_freq, exp_label, rack_label, start_datetime, end_datetime)


def psd(y, n_fft, sample_freq):
    return signal.welch(y, nfft=n_fft, fs=sample_freq)[1][0:129]


def plot_hist_on_separation_axis(path2figures, d, means, covars, weights, draw_pdf_plot=False):

    if means[0] > means[1]:
        # reverse the order
        means = means[::-1]
        covars = covars[::-1]
        weights = weights[::-1]

    d_axis = np.arange(-20, 20)

    fig = Figure(figsize=(SCATTER_PLOT_FIG_WIDTH,
                          SCATTER_PLOT_FIG_HEIGHT), dpi=FIG_DPI, facecolor='w')
    ax = fig.add_subplot(111)
    ax.set_xlim(-20, 20)
    ax.set_ylim(0, 0.1)

    ax.hist(d, bins=100, density=True)
    ax.plot(d_axis, weights[0]*stats.norm.pdf(d_axis,
                                              means[0], np.sqrt(covars[0])).ravel(), c=COLOR_WAKE)
    ax.plot(d_axis, weights[1]*stats.norm.pdf(d_axis,
                                              means[1], np.sqrt(covars[1])).ravel(), c=COLOR_NREM)
    ax.axvline(x=means[0], color='black', dashes=[2, 2])
    ax.axvline(x=means[1], color='black', dashes=[2, 2])
    ax.axvline(x=np.mean(means), color='black')

    ax.set_xlabel('', fontsize=10)
    ax.set_ylabel('', fontsize=10)

    _savefig(path2figures, 'histogram_on_separation_axis', fig, draw_pdf_plot)

    return fig


def plot_scatter2D(points_2D, classes, means, covariances, colors, xlabel, ylabel, diag_line=False):
    fig = Figure(figsize=(SCATTER_PLOT_FIG_WIDTH,
                          SCATTER_PLOT_FIG_HEIGHT), dpi=FIG_DPI, facecolor='w')
    ax = fig.add_subplot(111)
    ax.set_xlim(-20, 20)
    ax.set_ylim(-20, 20)

    for i, color in enumerate(colors):

        ax.scatter(points_2D[classes == i, 0],
                   points_2D[classes == i, 1], .01, color=color)

        # Plot an ellipse to show the Gaussian component
        if (len(means) > i and len(covariances) > i):
            mean = means[i]
            covar = covariances[i]
            w, v = linalg.eigh(covar)
            w = 4. * np.sqrt(w)  # 95% confidence (2SD) area (2*radius)
            angle = np.arctan(v[1, 0] / v[0, 0])
            angle = 180. * angle / np.pi  # convert to degrees
            #ell = mpl.patches.Ellipse(
            #    mean, w[0], w[1], 180. + angle, facecolor='none', edgecolor=color)
            ell = mpl.patches.Ellipse(
                xy=mean, width=w[0], height=w[1], angle=180. + angle, facecolor='none', edgecolor=color)
            ax.add_patch(ell)
    if diag_line == True:
        ax.plot([-20, 20], [-20, 20], color='gray', linewidth=0.7)
    ax.set_xlabel(xlabel, fontsize=10)
    ax.set_ylabel(ylabel, fontsize=10)
    return fig


def pickle_voltage_matrices(eeg_vm, emg_vm, data_dir, device_id):
    """ To save time for reading CSV files, pickle the voltage matrices
    in pickle files.
    
    Args:
        eeg_vm (np.array): voltage matrix for EEG data
        emg_vm (np.array): voltage matrix for EMG data
        data_dir (str):  path to the directory of pickled data (pkl/)
        device_id (str): a string to identify the recording device (e.g. ID47467)
"""
    pickle_dir = os.path.join(data_dir, 'pkl/')
    os.makedirs(pickle_dir, exist_ok=True)

    # save EEG
    pkl_path = os.path.join(pickle_dir, f'{device_id}_EEG.pkl')
    if os.path.exists(pkl_path):
        print_log(f'File already exists. Nothing to be done. {pkl_path}')
    else:
        with open(pkl_path, 'wb') as pkl:
            print_log(f'Saving the voltage matrix into {pkl_path}')
            pickle.dump(eeg_vm, pkl)

    # save EMG
    pkl_path = os.path.join(pickle_dir, f'{device_id}_EMG.pkl')
    if os.path.exists(pkl_path):
        print_log(f'File already exists. Nothing to be done. {pkl_path} ')
    else:
        with open(pkl_path, 'wb') as pkl:
            print_log(f'Saving the voltage matrix into {pkl_path}')
            pickle.dump(emg_vm, pkl)


def pickle_powerspec_matrices(spec_norm_eeg, spec_norm_emg, result_dir_path, device_id):
    """ pickles the power spectrum density matrices for subsequent analyses

    Args:
        spec_norm_eeg (dict): a dict returned by spectrum_normalize() for EEG data
        spec_norm_emg (dict): a dict returned by spectrum_normalize() for EMG data
        bidx_unknown (np.array): an array of the boolean index
        result_dir_path (str):  path to the directory of the pickled data (PSD/)
        device_id (str): a string to identify the recording device (e.g. ID47467)
    """
    pickle_dir = os.path.join(result_dir_path, 'PSD/')
    os.makedirs(pickle_dir, exist_ok=True)

    print_log(f'Saving PSD files')

    # save EEG PSD
    pkl_path = os.path.join(pickle_dir, f'{device_id}_EEG_PSD.pkl')
    #if os.path.exists(pkl_path):
    #    print_log(f'File already exists: {pkl_path}')
    #else:
    with open(pkl_path, 'wb') as pkl:
        print_log(f'Saving the EEG PSD matrix into {pkl_path}')
        pickle.dump(spec_norm_eeg, pkl)

    # save EMG PSD
    pkl_path = os.path.join(pickle_dir, f'{device_id}_EMG_PSD.pkl')
    #if os.path.exists(pkl_path):
    #    print_log(f'File already exists: {pkl_path} ')
    #else:
    with open(pkl_path, 'wb') as pkl:
        print_log(f'Saving the EMG PSD matrix into {pkl_path}')
        pickle.dump(spec_norm_emg, pkl)


def pickle_cluster_params(means2, covars2, c_means, c_covars, result_dir_path, device_id):
    """ pickles the cluster parameters
    
    Args:
        means2 (np.array(2,2)): a mean matrix of 2 stage-clusters
        covars2 (np.array(2,2,2)): a covariance matrix of 2 stage-clusters
        c_means (np.array(3,3)):  a mean matrix of 3 stage-clusters
        c_covars (np.array(3,3,3)): a covariance matrix of 3 stage-clusters
    """
    pickle_dir = os.path.join(result_dir_path, 'cluster_params/')
    os.makedirs(pickle_dir, exist_ok=True)

    # save
    pkl_path = os.path.join(pickle_dir, f'{device_id}_cluster_params.pkl')
    print_log(f'Saving the cluster parameters into {pkl_path}')
    with open(pkl_path, 'wb') as pkl:
        pickle.dump({'2stage-means': means2, '2stage-covars': covars2,
                     '3stage-means': c_means, '3stage-covars': c_covars}, pkl)

def remove_extreme_voltage(y, sample_freq):
    """An optional function to remove periodic-spike noises such as
    heart beat in EMG. Since the spikes are ofhen above 1.64 SD (upper 10%) 
    within the data region of interest, FASTER2 tries to replace those points with
    randam values.
    Note: This function is destructive i.e. it changes values of the
    given vector.


    Args:
        y (np.array(1)): a vector of voltages in an epoch
        sample_freq (int): the sampling frequency
    """
    vm = y.reshape(-1, sample_freq*2) # 2 sec
    for v in vm:
        m = np.mean(v)
        s = np.std(v)
        bidx = (abs(v) - m) > 1.64*s
        v[bidx] =  np.random.normal(m, s, np.sum(bidx))


def remove_extreme_power(y):
    """In FASTER2, the spectrum powers are normalized so that the mean and 
    SD of each frequency power over all epochs become 0 and 1, respectively.  
    This function removes extremely high or low powers in the normalized 
    spectrum by replacing the value with a random number sampled from the 
    noraml distribution: N(0, 1). 
    
    Args:
        y (np.array(1)): a vector of normalized power spectrum
    
    Returns:
        float: The ratio of the replaced exreme values in the given vector.
    """
    n_len = len(y)

    bidx = (np.abs(y) > 3)  # extreme means over 3SD
    n_extr = np.sum(bidx)

    y[bidx] = np.random.normal(0, 1, n_extr)

    return n_extr / n_len


def spectrum_normalize(voltage_matrix, n_fft, sample_freq):
    # power-spectrum normalization of EEG
    psd_mat = np.apply_along_axis(lambda y: psd(
        y, n_fft, sample_freq), 1, voltage_matrix)
    psd_mat = 10*np.log10(psd_mat)  # decibel-like
    psd_mean = np.apply_along_axis(np.nanmean, 0, psd_mat)
    psd_sd = np.apply_along_axis(np.nanstd, 0, psd_mat)
    spec_norm_fac = 1/psd_sd
    psd_norm_mat = np.apply_along_axis(lambda y: spec_norm_fac*(y - psd_mean),
                                       1,
                                       psd_mat)
    return {'psd': psd_norm_mat, 'mean': psd_mean, 'norm_fac': spec_norm_fac}


def cancel_weight_bias(stage_coord_2D):
    # Estimate the center of the two clusters
    print_log('Estimate the bias of the two cluster means')
    d = stage_coord_2D@np.array([1, -1]).T  # project onto the separation axis
    d = d.reshape(-1, 1)
    # Two states: active and quiet
    gmm = mixture.BayesianGaussianMixture(n_components=2, n_init=10)
    gmm.fit(d)
    weights = gmm.weights_
    means = gmm.means_.flatten()
    covars = gmm.covariances_.flatten()

    # this is supposed to be zero if weights are completely balanced
    bias = np.mean(means[0:2])
    s = bias/np.sqrt(2)  # x,y-axis components
    print_log(f'Estimated bias: {s}')

    stage_coord_2D = stage_coord_2D + [-s, s]

    return {'proj_data': d, 'means': means, 'covars': covars, 'weights': weights, 'stage_coord_2D': stage_coord_2D}


def classify_active_and_NREM(stage_coord_2D):
    # Initialize active/stative(NREM) clusters by Gaussian mixture model ignoring transition probablity
    print_log('Initialize active/NREM clusters with the diagonal line')

    # projection onto the separation "line" (which is perpendicular to the separation "axis")
    stage_coord_1DD = stage_coord_2D@np.array([1, 1]).T
    bidx_over_outliers = stage_coord_1DD > (
        np.mean(stage_coord_1DD) + 3*np.std(stage_coord_1DD))
    bidx_under_outliers = stage_coord_1DD < (
        np.mean(stage_coord_1DD) - 3*np.std(stage_coord_1DD))
    bidx_valid = ~(bidx_over_outliers | bidx_under_outliers)

    # To estimate clusters, use only epochs projected within the reasonable region on the separation line (<3SD)
    def _geo_classifier(coord):
        # geometrical classifier (simpl separation by the diagonal line)
        if coord[0] - coord[1] > 0:
            return 1
        else:
            return 0

    # Means and covariances of the active and NREM clusters
    geo_pred = np.array([_geo_classifier(c) for c in stage_coord_2D])
    mm_2D = np.array([
        np.mean(stage_coord_2D[geo_pred == 0 & bidx_valid], axis=0),
        np.mean(stage_coord_2D[geo_pred == 1 & bidx_valid], axis=0),
    ])
    cc_2D = np.array([
        np.cov(stage_coord_2D[geo_pred == 0 & bidx_valid], rowvar=False),
        np.cov(stage_coord_2D[geo_pred == 1 & bidx_valid], rowvar=False)
    ])

    likelihood = np.stack([multivariate_normal.pdf(stage_coord_2D, mean=mm_2D[i], cov=cc_2D[i])
                           for i in [0, 1]])
    geo_pred_proba = (likelihood / likelihood.sum(axis=0))

    return (geo_pred, geo_pred_proba.T, mm_2D, cc_2D)


def classify_active_and_NREM_by_GHMM(stage_coord_2D, pred_2D, mm_2D, cc_2D):
    # Initialize active/stative(NREM) clusters by Gaussian mixture model ignoring transition probablity
    print_log('Classify active/NREM clusters with GHMM')

    weights = np.array(
        [np.sum(pred_2D == 0), np.sum(pred_2D == 1)])/len(pred_2D)

    ghmm_2D = hmm.GaussianHMM(
        n_components=2, covariance_type='full', init_params='t', params='tmcs')
    ghmm_2D.startprob_ = weights
    ghmm_2D.means_ = mm_2D
    ghmm_2D.covars_ = cc_2D

    ghmm_2D.fit(stage_coord_2D)
    ghmm_2D_pred = ghmm_2D.predict(stage_coord_2D)
    ghmm_2D_proba = ghmm_2D.predict_proba(stage_coord_2D)
    return (ghmm_2D_pred, ghmm_2D_proba, ghmm_2D.means_, ghmm_2D.covars_)


def classify_Wake_and_REM(stage_coord_active, rem_floor):
    # Classify REM and Wake in the active cluster in the 3D space  (Low freq. x High freq. x REM metric)
    print_log('Classify REM and Wake clusters with GMM')

    # exclude intermediate points between REM and Wake, and points having substantial sleep_freq power
    bidx_wake_rem = ((stage_coord_active[:, 2] > rem_floor) | (
        stage_coord_active[:, 2] < 0)) & (stage_coord_active[:, 0] < 0)
    stage_coord_wake_rem = stage_coord_active[bidx_wake_rem, :]

    # gmm for wake & REM
    gmm_wr = mixture.GaussianMixture(n_components=3, n_init=100, means_init=[
                                     [-5, 5, -10], [0, 0, 20], [0, 0, 0]])  # Wake, REM, intermediate
    gmm_wr.fit(stage_coord_wake_rem)
    ww_wr = gmm_wr.weights_
    mm_wr = gmm_wr.means_
    cc_wr = gmm_wr.covariances_
    pred_wr = gmm_wr.predict(stage_coord_active)
    pred_wr_proba = gmm_wr.predict_proba(stage_coord_active)

    # Treat the intermediate as wake
    pred_wr[pred_wr == 2] = 0
    pred_wr_proba = np.array([[x[0]+x[2], x[1]] for x in pred_wr_proba])

    # The subsequent process uses the Wake and REM clusters
    ww_wr = ww_wr[np.r_[0, 1]]
    mm_wr = mm_wr[np.r_[0, 1]]
    cc_wr = cc_wr[np.r_[0, 1]]

    if mm_wr[0, 2] > mm_wr[1, 2]:
        # flip the order of clusters to assure the order of indices 0:Wake, 1:REM
        mm_wr = np.array([mm_wr[1], mm_wr[0]])
        cc_wr = np.array([cc_wr[1], cc_wr[0]])
        pred_wr = np.array([0 if x == 1 else 1 for x in pred_wr])
        pred_wr_proba = pred_wr_proba[:, np.r_[1, 0]]

    return (pred_wr, pred_wr_proba, mm_wr, cc_wr, ww_wr)


def classify_three_stages(stage_coord, mm_3D, cc_3D, weights_3c, max_rem_prn_len):
    # pylint: disable = attribute-defined-outside-init
    # classify REM, Wake, and NREM by Gaussian HMM in the 3D space
    print_log('Classify REM, Wake, and NREM by Gaussian HMM')
    ghmm_3D = CustomedGHMM(
        n_components=3, covariance_type='full', init_params='t', params='ct')
    ghmm_3D.startprob_ = weights_3c
    ghmm_3D.means_ = mm_3D
    ghmm_3D.covars_ = cc_3D
    ghmm_3D.set_wr_boundary(0)
    ghmm_3D.set_nr_boundary(0)
    ghmm_3D.set_max_rem_ax(max_rem_prn_len)

    ghmm_3D.fit(stage_coord)
    pred_3D = ghmm_3D.predict(stage_coord)
    pred_3D_proba = ghmm_3D.predict_proba(stage_coord)
    pred_3D_mm = ghmm_3D.means_
    pred_3D_cc = ghmm_3D.covars_

    return (pred_3D, pred_3D_proba, pred_3D_mm, pred_3D_cc)


def classify_two_stages(stage_coord, pred_2D_org, mm_2D_org, cc_2D_org, mm_active, cc_active):
    ndata = len(stage_coord)
    bidx_active = (pred_2D_org == 0)
    # perform GMM to refine active/NREM classification
    pred_2D, pred_2D_proba, mm_2D, cc_2D = classify_active_and_NREM_by_GHMM(
        stage_coord[:, 0:2], pred_2D_org, mm_2D_org, cc_2D_org)

    # construct 3D means and covariances from mm_2D and mm_active with TINY (non-effective) REM cluster
    # This non-effective REM cluster is just for convenience of plotting, so has nothing to do with analytical process.
    mm_3D = np.vstack([mm_active[0], [0, 0, 100], np.mean(
        stage_coord[pred_2D == 1], axis=0)])  # Wake, REM, NREM
    cc_3D = np.vstack([[cc_active[0]], [np.diag([0.01, 0.01, 0.01])], [
                      np.cov(stage_coord[pred_2D == 1], rowvar=False)]])

    # change label of NREM from 1 to 2 so that REM can use label:1
    pred_3D = np.array([2 if x == 1 else 0 for x in pred_2D])
    idx_active = np.where(bidx_active)[0]
    # idx_REMlike = idx_active[bidx_REMlike]
    # pred_3D[idx_REMlike] = 1

    pred_3D_proba = np.zeros([ndata, 3])
    # probability of REM is always zero, but sometimes REM like.
    pred_3D_proba[:, np.r_[0, 2]] = pred_2D_proba

    return pred_2D, pred_2D_proba, mm_2D, cc_2D, pred_3D, pred_3D_proba, mm_3D, cc_3D


def classification_process(stage_coord, rem_floor):
    if np.any(np.isnan(stage_coord)) or np.any(np.isinf(stage_coord)):
        raise ValueError(f"Invalid values in stage_coord: {stage_coord}")
    ndata = len(stage_coord)

    # 2-stage classification
    print("classify active and NREM")
    # classify active and NREM clusters on the 2D plane of (Low freq. x High freq.)
    pred_2D, pred_2D_proba, mm_2D, cc_2D = classify_active_and_NREM(
        stage_coord[:, 0:2])

    # Calculate the length of the longest principal axis of the active cluster
    w, v = linalg.eigh(cc_2D[0])
    w = np.sqrt(w)
    prn_ax = v@np.diag(w) # each column is the principal axis
    prn_ax_len = np.sqrt(np.diag(prn_ax.T@prn_ax))
    max_prn_ax_len = np.max(prn_ax_len)

    # Classify REM and Wake in the active cluster in the 3D space  (Low freq. x High freq. x REM metric)
    bidx_active = (pred_2D == 0)
    stage_coord_active = stage_coord[bidx_active, :]
    # pylint: disable=unused-variable
    print("classify Wake and REM")
    pred_active, pred_active_proba, mm_active, cc_active, ww_active = classify_Wake_and_REM(
        stage_coord_active, rem_floor)

    # If the z values of the both clusters are negative or zero, it means there is no REM cluster
    if np.all(mm_active[:, 2] <= 0):
        # process for data NOT having effective REM cluster
        print_log('No effective REM cluster was found.')

        pred_2D, pred_2D_proba, mm_2D, cc_2D, pred_3D, pred_3D_proba, mm_3D, cc_3D = classify_two_stages(
            stage_coord, pred_2D, mm_2D, cc_2D, mm_active, cc_active)

    else:
        # process for data having effective REM culster (this is the standard process)

        # construct 3D means and covariances from mm_2D and mm_active
        # Wake, REM, NREM
        mm_3D = np.vstack(
            [mm_active, np.mean(stage_coord[pred_2D == 1], axis=0)])
        cc_3D = np.vstack(
            [cc_active, [np.cov(stage_coord[pred_2D == 1], rowvar=False)]])

        # three cluster weights; Wake, REM, NREM
        weights_3c = np.array([np.sum(pred_active == 0), np.sum(
            pred_active == 1), np.sum(pred_2D == 1)])/ndata

        # 3-stage classification: classify REM, Wake, and NREM by Gaussian HMM on the 3D space
        try:
            pred_3D, pred_3D_proba, mm_3D, cc_3D = classify_three_stages(
                stage_coord, mm_3D, cc_3D, weights_3c, max_prn_ax_len)
        except ValueError as valerr:
            if valerr.args[0] == 'Invalid_REM_Cluster':
                print_log('REM cluster is invalid.')
                pred_2D, pred_2D_proba, mm_2D, cc_2D, pred_3D, pred_3D_proba, mm_3D, cc_3D = classify_two_stages(
                    stage_coord, pred_2D, mm_2D, cc_2D, mm_active, cc_active)
            else:
                raise

    return pred_2D, pred_2D_proba, mm_2D, cc_2D, pred_3D, pred_3D_proba, mm_3D, cc_3D


def draw_scatter_plots(path2figures, stage_coord, pred2, means2, covars2, c_pred3, c_means, c_covars, draw_pdf_plot=False):
    print_log('Drawing scatter plots')

    colors = [COLOR_WAKE, COLOR_NREM]
    axes = [0, 1]
    points = stage_coord[:, np.r_[axes]]
    fig = plot_scatter2D(points, pred2, means2, covars2,
                         colors, XLABEL, YLABEL, diag_line=True)
    _savefig(path2figures, 'ScatterPlot2D_LowFreq-HighFreq_Axes_Active-NREM', fig, draw_pdf_plot)

    points_active = stage_coord[((c_pred3 == 0) | (c_pred3 == 1)), :]
    pred_active = c_pred3[((c_pred3 == 0) | (c_pred3 == 1))]

    axes = [0, 2]  # Low-freq axis & REM axis
    points_prj = stage_coord[:, np.r_[axes]]
    colors = [COLOR_WAKE, COLOR_REM, COLOR_NREM]
    mm = np.array([m[np.r_[axes]] for m in c_means[np.r_[0, 1, 2]]])
    cc = np.array([c[np.r_[axes]][:, np.r_[axes]]
                   for c in c_covars[np.r_[0, 1, 2]]])
    fig = plot_scatter2D(points_prj, c_pred3, mm,
                         cc, colors, XLABEL, ZLABEL)
    _savefig(path2figures, 'ScatterPlot2D_LowFreq-REM_axes', fig, draw_pdf_plot)

    axes = [1, 2]  # High-freq axis & REM axis
    points_prj = stage_coord[:, np.r_[axes]]
    mm = np.array([m[np.r_[axes]] for m in c_means[np.r_[0, 1, 2]]])
    cc = np.array([c[np.r_[axes]][:, np.r_[axes]]
                   for c in c_covars[np.r_[0, 1, 2]]])
    fig = plot_scatter2D(points_prj, c_pred3, mm,
                         cc, colors, YLABEL, ZLABEL)
    _savefig(path2figures, 'ScatterPlot2D_HighFreq-REM_axes', fig, draw_pdf_plot)

    axes = [0, 1]  # Low-freq axis & High-freq axis
    points_prj = stage_coord[:, np.r_[axes]]
    colors = [COLOR_WAKE, COLOR_REM, COLOR_NREM]
    mm_proj = np.array([m[np.r_[axes]] for m in c_means[np.r_[0, 1, 2]]])
    cc_proj = np.array([c[np.r_[axes]][:, np.r_[axes]]
                        for c in c_covars[np.r_[0, 1, 2]]])
    fig = plot_scatter2D(points_prj, c_pred3, mm_proj,
                         cc_proj, colors, XLABEL, YLABEL, diag_line=True)
    _savefig(path2figures, 'ScatterPlot2D_LowFreq-HighFreq_axes_Wake_REM_NREM', fig, draw_pdf_plot)

    colors = [COLOR_WAKE, COLOR_REM, COLOR_NREM]
    colors_light = [lighten_color(c) for c in colors]
    fig = Figure(figsize=(SCATTER_PLOT_FIG_WIDTH,
                          SCATTER_PLOT_FIG_HEIGHT), dpi=FIG_DPI, facecolor='w')
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=10, azim=-135)

    ax.set_xlim(-20, 20)
    ax.set_ylim(-20, 20)
    ax.set_zlim(-20, 20)
    ax.set_xlabel(XLABEL, fontsize=10, rotation=0)
    ax.set_ylabel(YLABEL, fontsize=10, rotation=0)
    ax.set_zlabel(ZLABEL, fontsize=10, rotation=0)
    ax.tick_params(axis='both', which='major', labelsize=8)

    for c in set(c_pred3):
        t_points = stage_coord[c_pred3 == c]
        ax.scatter3D(t_points[:, 0], t_points[:, 1], min(
            ax.get_zlim()), s=0.005, color=colors_light[c])
        ax.scatter3D(t_points[:, 0], max(ax.get_ylim()),
                     t_points[:, 2], s=0.005, color=colors_light[c])
        ax.scatter3D(max(ax.get_xlim()),
                     t_points[:, 1], t_points[:, 2], s=0.005, color=colors_light[c])

        ax.scatter3D(t_points[:, 0], t_points[:, 1],
                     t_points[:, 2], s=0.01, color=colors[c])

    _savefig(path2figures, 'ScatterPlot3D', fig, draw_pdf_plot)


def _savefig(output_dir, basefilename, fig, draw_pdf_plot):
    # PNG
    filename = f'{basefilename}.png'
    fig.savefig(os.path.join(output_dir, filename), pad_inches=0,
                bbox_inches='tight', dpi=100)
    # PDF
    if draw_pdf_plot:
        filename = f'{basefilename}.pdf'
        fig.savefig(os.path.join(output_dir, 'pdf', filename), pad_inches=0,
                    bbox_inches='tight', dpi=100)

def find_edf_files(data_dir):
    """returns list of edf files in the directory

    Args:
        data_dir (str): A path to the data directory

    Returns:
        [list]: A list returned by glob()
    """
    return glob(os.path.join(data_dir, '*.edf'))

def lighten_color(hex):
    return rgb_to_hex(tuple([int(x+(255-x)*0.5) for x in hex_to_rgb(hex)]))


def hex_to_rgb(hex_code):
    h = hex_code.lstrip('#')
    rgb = tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
    return rgb


def rgb_to_hex(rgb_tuple):
    return '#%02x%02x%02x' % rgb_tuple


def voltage_normalize(v_mat):
    v_array = v_mat.flatten()
    v_array = v_array[~np.isnan(v_array)]
    bidx_over = v_array > (np.mean(v_array)+3*np.std(v_array))
    bidx_under = v_array < (np.mean(v_array)-3*np.std(v_array))
    bidx_valid = ~(bidx_over | bidx_under)
    v_mat_norm = (
        v_mat - np.mean(v_array[bidx_valid]))/np.std(v_array[bidx_valid])

    return v_mat_norm



In [4]:
# assures frequency bins compatibe among different sampling frequencies
epoch_len_sec = 8
sample_freq = 128
n_fft = int(256 * sample_freq/100)
# same frequency bins given by signal.welch()
freq_bins = 1/(n_fft/sample_freq)*np.arange(0, 129)
bidx_sleep_freq = (freq_bins < 4) | ((freq_bins > 10) &
                                     (freq_bins < 20))  # without theta, 37 bins
bidx_active_freq = (freq_bins > 30)  # 52 bins
bidx_theta_freq = (freq_bins >= 4) & (freq_bins < 10)  # 15 bins
bidx_delta_freq = (freq_bins < 4)  # 11 bins
bidx_muscle_freq = (freq_bins > 30)  # 52 bins

n_active_freq = np.sum(bidx_active_freq)
n_sleep_freq = np.sum(bidx_sleep_freq)
n_theta_freq = np.sum(bidx_theta_freq)
n_delta_freq = np.sum(bidx_delta_freq)
n_muscle_freq = np.sum(bidx_muscle_freq)

rem_floor = np.sum(np.sqrt([n_muscle_freq, n_theta_freq]))

In [5]:
def preprocess_edf(idx,edf,epoch_num,sample_freq,epoch_len_sec,result_dir,device_id,data_dir,offset_in_msec=0):
    raw = mne.io.read_raw_edf(edf)
    measurement_start_datetime = raw.info['meas_date']
    eegname = 'EEG_{0}'.format(idx)
    emgname = 'EMG_{0}'.format(idx)
    eeg = raw.get_data(eegname)
    #print(eeg.shape)
    emg = raw.get_data(emgname)
    raw.close()
    start_idx=int(offset_in_msec*sample_freq/1000)
    actual_epoch_num=int((eeg.shape[1]-start_idx)/sample_freq/epoch_len_sec)
    if epoch_num>actual_epoch_num:
        epoch_num=actual_epoch_num
    end_idx=start_idx+epoch_num*sample_freq*epoch_len_sec
    eeg=eeg[:,start_idx:end_idx]
    emg=emg[:,start_idx:end_idx]
    eeg_vm = eeg.reshape(-1, epoch_len_sec * int(sample_freq))
    emg_vm = emg.reshape(-1, epoch_len_sec * int(sample_freq))
    os.makedirs(os.path.join(data_dir,"pkl"),exist_ok=True)
    os.makedirs(os.path.join(result_dir,"pkl"),exist_ok=True)
    eeg_pkl_path=os.path.join(result_dir,'pkl', f'{device_id}_EEG.pkl')
    #eeg_pkl_path = os.path.join(data_dir, 'pkl', f'{device_id}_EEG.pkl')
    emg_pkl_path=os.path.join(result_dir,'pkl', f'{device_id}_EMG.pkl')
    #emg_pkl_path = os.path.join(data_dir, 'pkl', f'{device_id}_EEG.pkl')


    eeg_vm_org = eeg_vm[:epoch_num,]
    emg_vm_org = emg_vm[:epoch_num,]

    with open(eeg_pkl_path, 'wb') as pkl:
        pickle.dump(eeg_vm_org, pkl)
    with open(emg_pkl_path, 'wb') as pkl:
        pickle.dump(emg_vm_org, pkl)

    print_log('Applying the optional filter on the EEG signal')
    epoch_sd = np.apply_along_axis(np.nanstd, 1 ,eeg_vm_org)
    med_sd = np.median(epoch_sd)
    bidx_no_eeg_signal = (epoch_sd / med_sd) < 0.3 # A definition of "NO signal of EEG"
    eeg_vm_org[bidx_no_eeg_signal, :] = np.nan
    print_log(f'The number of epochs with no EEG signal: {np.sum(bidx_no_eeg_signal)}')

    # recover nans in the data if possible
    nan_ratio_eeg = np.apply_along_axis(et.patch_nan, 1, eeg_vm_org)
    nan_ratio_emg = np.apply_along_axis(et.patch_nan, 1, emg_vm_org)

    # exclude unrecoverable epochs as unknown
    bidx_unknown = np.apply_along_axis(np.any, 1, np.isnan(
        eeg_vm_org)) | np.apply_along_axis(np.any, 1, np.isnan(emg_vm_org))
    eeg_vm = eeg_vm_org[~bidx_unknown, :]
    emg_vm = emg_vm_org[~bidx_unknown, :]

    # make data comparable among different mice. Not necessary for staging,
    # but convenient for subsequnet analyses.
    eeg_vm_norm = voltage_normalize(eeg_vm)
    emg_vm_norm = voltage_normalize(emg_vm)

    # remove extreme voltages (e.g. heart beat) from EMG
    print_log('Applying the optional filter on the EMG signal')
    np.apply_along_axis(remove_extreme_voltage, 1, emg_vm_norm, sample_freq)

    # power-spectrum normalization of EEG and EMG
    spec_norm_eeg = spectrum_normalize(eeg_vm_norm, n_fft, sample_freq)
    spec_norm_emg = spectrum_normalize(emg_vm_norm, n_fft, sample_freq)
    psd_norm_mat_eeg = spec_norm_eeg['psd']
    psd_norm_mat_emg = spec_norm_emg['psd']

    # remove extreme powers
    extrp_ratio_eeg = np.apply_along_axis(
        remove_extreme_power, 1, psd_norm_mat_eeg)
    extrp_ratio_emg = np.apply_along_axis(
        remove_extreme_power, 1, psd_norm_mat_emg)

    # save the PSD matrices and associated factors for subsequent analyses
    ## set bidx_unknown; other factors were set by spectrum_normalize()
    spec_norm_eeg['bidx_unknown'] = bidx_unknown
    spec_norm_emg['bidx_unknown'] = bidx_unknown
    pickle_powerspec_matrices(
        spec_norm_eeg, spec_norm_emg, result_dir, device_id)

    # spread epochs on the 3D (Low freq. x High freq. x REM metric) space
    psd_mat = np.concatenate([
        psd_norm_mat_eeg.reshape(*psd_norm_mat_eeg.shape, 1),
        psd_norm_mat_emg.reshape(*psd_norm_mat_emg.shape, 1)
    ], axis=2)
    stage_coord = np.array([(
        np.sum(y[bidx_sleep_freq, 0])/np.sqrt(n_sleep_freq),
        np.sum(y[bidx_active_freq, 0])/np.sqrt(n_active_freq),
        np.sum(y[bidx_theta_freq, 0])/np.sqrt(n_theta_freq)-np.sum(y[bidx_delta_freq, 0]) /
        np.sqrt(n_delta_freq) -
        np.sum(y[bidx_muscle_freq, 1]) / np.sqrt(n_muscle_freq)
    ) for y in psd_mat])
    ndata = len(stage_coord)

    # cancel the weight bias of active/NREM clusters
    cwb = cancel_weight_bias(stage_coord[:, 0:2])
    stage_coord[:, 0:2] = cwb['stage_coord_2D']

    # run the classification process
    pred_2D, pred_2D_proba, means_2D, covars_2D, pred_3D, pred_3D_proba, means_3D, covars_3D = classification_process(
            stage_coord, rem_floor)


    # output staging result
    stage_proba = np.zeros(3*epoch_num).reshape(epoch_num, 3)
    proba_REM = pred_3D_proba[:, 1]
    proba_WAKE = pred_3D_proba[:, 0]
    proba_NREM = pred_3D_proba[:, 2]
    stage_proba[~bidx_unknown, 0] = proba_REM
    stage_proba[~bidx_unknown, 1] = proba_WAKE
    stage_proba[~bidx_unknown, 2] = proba_NREM

    stage_call = np.repeat('Unknown', epoch_num)
    stage_call[~bidx_unknown] = np.array(
        [STAGE_LABELS[y] for y in pred_3D])

    # Print a brief result
    print_log(f'2-stage means:\n {means_2D}')
    print_log(f'2-stage covars:\n {covars_2D}')
    print_log('\n')
    print_log(f'3-stage means:\n{means_3D}')
    print_log(f'3-stage covars:\n{covars_3D}')

    # Compose stage table
    extreme_power_ratio = np.zeros(2*epoch_num).reshape(epoch_num, 2)
    extreme_power_ratio[~bidx_unknown, 0] = extrp_ratio_eeg
    extreme_power_ratio[~bidx_unknown, 1] = extrp_ratio_emg

    stage_table = pd.DataFrame({'Stage': stage_call,
                                'REM probability': stage_proba[:, 0],
                                'NREM probability': stage_proba[:, 2],
                                'Wake probability': stage_proba[:, 1],
                                'NaN ratio EEG-TS': nan_ratio_eeg,
                                'NaN ratio EMG-TS': nan_ratio_emg,
                                'Outlier ratio EEG-TS': extreme_power_ratio[:, 0],
                                'Outlier ratio EMG-TS': extreme_power_ratio[:, 1]})
    stage_file_path = os.path.join(result_dir, f'{device_id}.faster2.stage.csv')

    with open(stage_file_path, 'w', encoding='UTF-8') as f:
        f.write(f'# Start: {measurement_start_datetime} \n')
        f.write(f'# Epoch num: {epoch_num}  Epoch length: {epoch_len_sec} [s]\n')
        f.write(f'# Sampling frequency: {sample_freq} [Hz]\n')
        f.write(f'# Staged by {FASTER2_NAME}\n')
    with open(stage_file_path, 'a') as f:
        print(f)
        print(type(f))
        print(type(stage_table))
        stage_table.to_csv(f, header=True, index=False)
        #stage_table.to_csv(f, header=True, index=False, line_terminator='\n')

    path2figures = os.path.join(result_dir, 'figure', f'{device_id}')
    os.makedirs(path2figures, exist_ok=True)

    # draw the bias histogram
    plot_hist_on_separation_axis(path2figures, cwb['proj_data'], cwb['means'], cwb['covars'], cwb['weights']) 

    # draw scatter plots
    draw_scatter_plots(path2figures, stage_coord, pred_2D, means_2D, covars_2D, pred_3D, means_3D, covars_3D)

    # pickle cluster parameters
    pickle_cluster_params(means_2D, covars_2D, means_3D, covars_3D, result_dir, device_id)

    

In [6]:
# Parameters injected by papermill or CLI
prj_dir = locals().get("prj_dir", "/p-antipsychotics-sleep/raw_data/kaist")
result_dir_name = locals().get("result_dir_name", "result")
epoch_len_sec = locals().get("epoch_len_sec", 8)
sample_freq = locals().get("sample_freq", 128)
overwrite = locals().get("overwrite", False)
offset_in_msec = locals().get("offset_in_msec", 0)

crawl_prj_for_preprocess_edf(
    prj_dir,
    result_dir_name,
    epoch_len_sec,
    sample_freq=sample_freq,
    is_overwite=overwrite,
    offset_in_msec=offset_in_msec,
)


In [None]:
prj_dir="/p-antipsychotics-sleep/raw_data/kaist"
result_dir_name="result"
epoch_len_sec=8
crawl_prj_for_preprocess_edf(prj_dir,result_dir_name,epoch_len_sec,sample_freq=128,is_overwite=False,offset_in_msec=0)

/p-antipsychotics-sleep/raw_data/kaist/20251120_KA001-004/data/20251120_KA001-004_C1_20251121_0659.edf
Ch0
Extracting EDF parameters from /p-antipsychotics-sleep/raw_data/kaist/20251120_KA001-004/data/20251120_KA001-004_C1_20251121_0659.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Applying the optional filter on the EEG signal
The number of epochs with no EEG signal: 0
Applying the optional filter on the EMG signal
Saving PSD files
Saving the EEG PSD matrix into /p-antipsychotics-sleep/raw_data/kaist/20251120_KA001-004/result/PSD/Ch0_EEG_PSD.pkl
Saving the EMG PSD matrix into /p-antipsychotics-sleep/raw_data/kaist/20251120_KA001-004/result/PSD/Ch0_EMG_PSD.pkl
Estimate the bias of the two cluster means
Estimated bias: 0.2525074846393874
classify active and NREM
Initialize active/NREM clusters with the diagonal line
classify Wake and REM
Classify REM and Wake clusters with GMM
Classify REM, Wake, and NREM by Gaussian HMM


Model is not converging.  Current: -1629140.6276807406 is not greater than -1545147.617589524. Delta is -83993.01009121654


2-stage means:
 [[-3.89268075  1.56408334]
 [ 3.93686728 -1.21142719]]
2-stage covars:
 [[[3.55096284 1.70800358]
  [1.70800358 7.94878852]]

 [[9.79981055 5.32973048]
  [5.32973048 8.44204267]]]


3-stage means:
[[-4.13981282  2.13202418 -7.61062683]
 [-4.16658802 -0.21423975 13.66234877]
 [ 3.94682668 -1.20308277  5.23015322]]
3-stage covars:
[[[ 1.97220645e+02  2.28918747e+02 -2.46289591e+01]
  [ 2.28918747e+02  2.90118803e+02 -4.15912166e+01]
  [-2.46289591e+01 -4.15912166e+01  2.49589494e+01]]

 [[ 4.39978973e+00 -1.87936722e-01 -4.42189143e-01]
  [-1.87936722e-01  8.52566277e+00 -2.01004392e-02]
  [-4.42189143e-01 -2.01004392e-02  8.48691219e+00]]

 [[ 7.50331138e+01  8.74540258e+01 -3.96537663e+01]
  [ 8.74540258e+01  1.19766113e+02 -5.84542701e+01]
  [-3.96537663e+01 -5.84542701e+01  3.92456131e+01]]]
<_io.TextIOWrapper name='/p-antipsychotics-sleep/raw_data/kaist/20251120_KA001-004/result/Ch0.faster2.stage.csv' mode='a' encoding='UTF-8'>
<class '_io.TextIOWrapper'>
<class 'pan

Model is not converging.  Current: -1397391.5354502366 is not greater than -1327824.0005018285. Delta is -69567.53494840814


2-stage means:
 [[-2.96504341  1.59869415]
 [ 3.09599893 -1.35592619]]
2-stage covars:
 [[[ 5.30184384  3.57037938]
  [ 3.57037938  8.70058843]]

 [[17.55087331 12.36932082]
  [12.36932082 14.81654669]]]


3-stage means:
[[-2.75423082  2.64328567 -7.04658669]
 [-4.53393678  0.50608025 13.34618073]
 [ 3.10530894 -1.3517144   4.33428252]]
3-stage covars:
[[[ 5.43447707e+02  5.98495753e+02  1.09599379e+01]
  [ 5.98495753e+02  6.75393688e+02  2.61711518e+00]
  [ 1.09599379e+01  2.61711518e+00  1.86728348e+01]]

 [[ 1.09553880e+01 -2.65949598e-16  1.74162640e-15]
  [-3.05736779e-16  1.09553880e+01 -2.15103939e-15]
  [ 1.89456606e-15 -2.38160898e-15  1.09553880e+01]]

 [[ 1.66334662e+01  1.48419188e+01 -3.00605646e+00]
  [ 1.48419188e+01  2.79491643e+01 -1.51759689e+01]
  [-3.00605646e+00 -1.51759689e+01  2.04796129e+01]]]
<_io.TextIOWrapper name='/p-antipsychotics-sleep/raw_data/kaist/20251120_KA001-004/result/Ch1.faster2.stage.csv' mode='a' encoding='UTF-8'>
<class '_io.TextIOWrapper'>
<cl

Model is not converging.  Current: -1275746.1211524303 is not greater than -1273487.9274916253. Delta is -2258.193660805002


2-stage means:
 [[-3.94318376  1.28839122]
 [ 2.73996412 -1.27390965]]
2-stage covars:
 [[[ 8.91690753  4.96502583]
  [ 4.96502583  9.97698654]]

 [[ 8.30763522  6.337278  ]
  [ 6.337278   11.63294245]]]


3-stage means:
[[-5.01212345  1.1431659  -9.15767857]
 [-2.91128128 -0.1026919  11.9460041 ]
 [ 2.79925979 -1.42489872  5.05331422]]
3-stage covars:
[[[ 2.96442494e+01  5.51472244e+00  2.46608931e-02]
  [ 5.51472244e+00  9.80088566e+00 -5.53336473e+00]
  [ 2.46608931e-02 -5.53336473e+00  2.72254521e+01]]

 [[ 8.61811719e+00 -2.07191925e+00 -1.19498436e+00]
  [-2.07191925e+00  1.25471408e+01 -5.25514630e-01]
  [-1.19498436e+00 -5.25514630e-01  1.41862169e+01]]

 [[ 4.35733947e+01  5.09273650e+01 -8.74294003e+00]
  [ 5.09273650e+01  7.16639716e+01 -1.88322336e+01]
  [-8.74294003e+00 -1.88322336e+01  1.61003379e+01]]]
<_io.TextIOWrapper name='/p-antipsychotics-sleep/raw_data/kaist/20251120_KA001-004/result/Ch2.faster2.stage.csv' mode='a' encoding='UTF-8'>
<class '_io.TextIOWrapper'>
<cl



Estimated bias: 0.49839837081563915
classify active and NREM
Initialize active/NREM clusters with the diagonal line
classify Wake and REM
Classify REM and Wake clusters with GMM
Classify REM, Wake, and NREM by Gaussian HMM


Model is not converging.  Current: -1260948.0467121517 is not greater than -1260849.4924072747. Delta is -98.55430487706326


2-stage means:
 [[-3.94751927  1.17692335]
 [ 3.87623548 -0.33454276]]
2-stage covars:
 [[[ 7.59934492  3.69197052]
  [ 3.69197052  9.62261437]]

 [[14.1156202  10.47979193]
  [10.47979193 13.99595105]]]


3-stage means:
[[ -5.13470028   0.6141648  -10.15438861]
 [ -2.7754298    0.18502861  12.34242145]
 [  3.86204041  -0.37566479   5.19367669]]
3-stage covars:
[[[24.30561771  3.56454196  1.69138952]
  [ 3.56454196  7.61055346 -3.647628  ]
  [ 1.69138952 -3.647628   25.58161414]]

 [[ 9.43422368  3.69568745 -0.30250105]
  [ 3.69568745  7.89363712  0.37205239]
  [-0.30250105  0.37205239 12.40858726]]

 [[12.28252613  5.64003426  6.87403941]
  [ 5.64003426 12.46615144 -2.65984376]
  [ 6.87403941 -2.65984376 15.70394118]]]
<_io.TextIOWrapper name='/p-antipsychotics-sleep/raw_data/kaist/20251120_KA001-004/result/Ch3.faster2.stage.csv' mode='a' encoding='UTF-8'>
<class '_io.TextIOWrapper'>
<class 'pandas.core.frame.DataFrame'>
Drawing scatter plots
Saving the cluster parameters into /p-anti



Estimated bias: 0.7053791672451819
classify active and NREM
Initialize active/NREM clusters with the diagonal line
classify Wake and REM
Classify REM and Wake clusters with GMM
Classify REM, Wake, and NREM by Gaussian HMM
2-stage means:
 [[-2.30313336  2.85801019]
 [ 1.57321731 -2.04185038]]
2-stage covars:
 [[[ 6.03511666  3.67239639]
  [ 3.67239639 12.28972861]]

 [[12.61846096 10.95777609]
  [10.95777609 14.52458201]]]


3-stage means:
[[-2.89729967  3.8293945  -6.91766077]
 [-4.16429545 -0.21888622 13.06369802]
 [ 1.51726171 -2.10721544  4.2964438 ]]
3-stage covars:
[[[12.28131014  0.93740312  3.79069874]
  [ 0.93740312 23.52980699 -0.15895033]
  [ 3.79069874 -0.15895033 22.06500285]]

 [[11.92632621  2.21313563 -0.37081816]
  [ 2.21313563 11.60788787  0.39845425]
  [-0.37081816  0.39845425 13.91920029]]

 [[18.4756079  21.37802081  0.44137799]
  [21.37802081 34.13305963 -5.5986988 ]
  [ 0.44137799 -5.5986988  10.60215933]]]
<_io.TextIOWrapper name='/p-antipsychotics-sleep/raw_data

Model is not converging.  Current: -1221637.287987207 is not greater than -1220551.0218082368. Delta is -1086.2661789702252


2-stage means:
 [[-2.00791695  1.36732864]
 [ 2.66576638 -1.92587904]]
2-stage covars:
 [[[ 5.57452154  4.05187835]
  [ 4.05187835  6.13775593]]

 [[11.50269544  9.5094813 ]
  [ 9.5094813  15.82646553]]]


3-stage means:
[[-2.1977545   2.11996426 -7.43324256]
 [-3.18472055 -0.62644907 15.62064365]
 [ 2.49361129 -2.19053299  3.27728175]]
3-stage covars:
[[[ 7.76926609 -0.50220448  4.02392504]
  [-0.50220448  5.20834642 -5.37500413]
  [ 4.02392504 -5.37500413 27.65331204]]

 [[ 5.96478881 -1.55891718 -0.60258664]
  [-1.55891718  9.30301302 -0.23763773]
  [-0.60258664 -0.23763773  9.82593498]]

 [[13.98937321  8.83011239 15.11890085]
  [ 8.83011239 20.12675731 -1.7127052 ]
  [15.11890085 -1.7127052  45.64887199]]]
<_io.TextIOWrapper name='/p-antipsychotics-sleep/raw_data/kaist/20251120_KA005-008/result/Ch1.faster2.stage.csv' mode='a' encoding='UTF-8'>
<class '_io.TextIOWrapper'>
<class 'pandas.core.frame.DataFrame'>
Drawing scatter plots
Saving the cluster parameters into /p-antipsychotic

Model is not converging.  Current: -1334062.7789356124 is not greater than -1322299.4080426334. Delta is -11763.370892978972


2-stage means:
 [[-2.97549343  2.84644791]
 [ 1.27173061 -1.71913553]]
2-stage covars:
 [[[11.82333025  9.27851767]
  [ 9.27851767 18.69892964]]

 [[11.97590483 10.66917343]
  [10.66917343 14.90477055]]]


3-stage means:
[[-3.76428789  4.24662563 -8.89867605]
 [-5.85278478 -2.16324544 13.79647495]
 [ 1.21552656 -1.9964256   4.93149316]]
3-stage covars:
[[[147.08241323  88.61499597  35.78440703]
  [ 88.61499597  78.27863025   8.53194182]
  [ 35.78440703   8.53194182  40.6611643 ]]

 [[ 22.58224957   2.74093434  -0.90841514]
  [  2.74093434  22.23711936   0.96740619]
  [ -0.90841514   0.96740619  24.83542271]]

 [[ 14.76820458   7.70653432   3.56173459]
  [  7.70653432   8.38234156  -3.20031541]
  [  3.56173459  -3.20031541  14.08997156]]]
<_io.TextIOWrapper name='/p-antipsychotics-sleep/raw_data/kaist/20251120_KA005-008/result/Ch2.faster2.stage.csv' mode='a' encoding='UTF-8'>
<class '_io.TextIOWrapper'>
<class 'pandas.core.frame.DataFrame'>
Drawing scatter plots
Saving the cluster param

In [2]:
import pandas as pd
print(pd.DataFrame.to_csv)

<function NDFrame.to_csv at 0x7fcb85d3ba60>


In [1]:
import matplotlib
print(matplotlib.__version__)

3.9.4


In [10]:
!pip install --upgrade matplotlib

Collecting matplotlib
  Using cached matplotlib-3.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Using cached matplotlib-3.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.3 MB)
Installing collected packages: matplotlib
  Attempting uninstall: matplotlib
    Found existing installation: matplotlib 3.3.4
    Uninstalling matplotlib-3.3.4:
      Successfully uninstalled matplotlib-3.3.4
Successfully installed matplotlib-3.9.4
[0m