In [1]:
pip install scikit-image

Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'C:\Users\janlu\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [2]:
pip install emd

Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'C:\Users\janlu\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [3]:
pip install pingouin

Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'C:\Users\janlu\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [4]:
pip install h5py

Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'C:\Users\janlu\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [5]:
pip install sails

Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'C:\Users\janlu\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [1]:
import emd.sift as sift
import emd.spectra as spectra
import numpy as np
import pingouin as pg
import sails
import scipy.io as sio
import h5py
import time
import timeit
import matplotlib.pyplot as plt
from scipy.stats import zscore
from scipy.signal import convolve2d
from scipy.stats import zscore, binned_statistic
from scipy.ndimage import center_of_mass
import matplotlib.pyplot as plt
import os
import glob
import concurrent.futures
from skimage.feature import peak_local_max


In [2]:
# The get_rem_states function takes in an array of sleep states and the sample rate of the data.
def get_rem_states(states, sample_rate):
    """
    Extract consecutive REM (Rapid Eye Movement) sleep states and their start
    and end times from an array of sleep states.

    Parameters:
    - states (numpy.ndarray): One-dimensional array of sleep states.
    - sample_rate (int): The sample rate of the data.

    Returns:
    numpy.ndarray: An array containing start and end times of consecutive REM
    sleep states. Each row represents a pair of start and end times.

    Note:
    - Sleep states are represented numerically. In this function, REM sleep
      states are identified by the value 5 in the 'states' array.

    Example:
    ```python
    import numpy as np

    # Example usage:
    sleep_states = np.array([1, 2, 5, 5, 5, 3, 2, 5, 5, 4, 1])
    sample_rate = 2500  # Example sample rate in Hz
    rem_states_times = get_rem_states(sleep_states, sample_rate)
    print(rem_states_times)
    ```
    """
    try:
        # Ensure the sleep states array is one-dimensional.
        states = np.squeeze(states)
        # Find the indices where the sleep state is equal to 5, indicating REM sleep.
        rem_state_indices = np.where(states == 5)[0]
        
        # Check if there are no REM states. If so, return an empty array.
        if len(rem_state_indices) == 0:
            return np.array([])
        # Calculate the changes between consecutive REM state indices.
        rem_state_changes = np.diff(rem_state_indices)
        # Find the indices where consecutive REM states are not adjacent.
        split_indices = np.where(rem_state_changes != 1)[0] + 1
        # Add indices to split consecutive REM states, including the start and end indices.
        split_indices = np.concatenate(([0], split_indices, [len(rem_state_indices)]))
        # Create an empty array to store start and end times of consecutive REM states.
        consecutive_rem_states = np.empty((len(split_indices) - 1, 2))
        # Iterate through the split indices to extract start and end times.
        for i, (start, end) in enumerate(zip(split_indices, split_indices[1:])):
            start = rem_state_indices[start] * int(sample_rate)
            end = rem_state_indices[end - 1] * int(sample_rate)
            consecutive_rem_states[i] = np.array([start, end])
        # Convert the array to a numpy array.
        ##consecutive_rem_states = np.array(consecutive_rem_states)
        # Create a mask to filter out consecutive REM states with negative duration.
        null_states_mask = np.squeeze(np.diff(consecutive_rem_states) > 0)
        consecutive_rem_states = consecutive_rem_states[null_states_mask]
        # Return the array containing start and end times of consecutive REM states.
        return consecutive_rem_states
    # Handle the case where an IndexError occurs, typically due to an empty array.
    except IndexError as e:
        print(f"An IndexError occurred in get_rem_states: {e}")
        return np.array([])  # or any default value you want


# This function computes the Morlet wavelet transform of a given signal.
# It uses the SAILS library to perform the wavelet transform.
def morlet_wt(x, sample_rate, frequencies=np.arange(1, 200, 1), n=5, mode='complex'):
        """
    Compute the Morlet wavelet transform of a given signal using the SAILS library.

    Parameters:
    - x (numpy.ndarray): The input signal.
    - sample_rate (int): The rate at which the signal is sampled.
    - frequencies (numpy.ndarray, optional): The array of frequencies at which to compute the transform
      (default is from 1 to 200 Hz).
    - n (int, optional): The number of cycles in the Morlet wavelet (default is 5).
    - mode (str, optional): The mode of the return, whether 'complex', 'power', or 'amplitude'
      (default is 'complex').
      
    Returns:
    - numpy.ndarray: The computed Morlet wavelet transform of the input signal.

    Note:
    - This function relies on the SAILS library to perform the wavelet transform.

    Example:
    ```python
    import numpy as np

    # Example usage:
    signal = np.sin(2 * np.pi * 10 * np.arange(0, 1, 1/sample_rate))
    wt_result = morlet_wt(signal, sample_rate)
    print(wt_result)
    ```
    """
    wavelet_transform = sails.wavelet.morlet(x, freqs=frequencies, sample_rate=sample_rate, ncycles=n,
                                             ret_mode=mode, normalise=None)
    # Return the computed wavelet transform.
    return wavelet_transform


# The tg_split function categorizes frequency values into three groups: sub-theta, theta, and supra-theta.
def tg_split(mask_freq, theta_range=(5, 12)):
    """
    Categorize frequency values into three groups: sub-theta, theta, and supra-theta.

    Parameters:
    - mask_freq (numpy.ndarray): An array of frequency values that you want to categorize.
    - theta_range (tuple, optional): A range of frequencies considered as the theta band
      (default is (5, 12) Hz).

    Returns:
    - tuple: A tuple containing three boolean masks representing sub-theta, theta, and supra-theta categories.

    Example:
    ```python
    import numpy as np

    # Example usage:
    freq_values = np.array([3, 8, 10, 15, 20])
    sub_mask, theta_mask, supra_mask = tg_split(freq_values)
    print("Sub-theta frequencies:", freq_values[sub_mask])
    print("Theta frequencies:", freq_values[theta_mask])
    print("Supra-theta frequencies:", freq_values[supra_mask])
    ```
    """
    # Get the lower and upper bounds of the theta range.
    lower = np.min(theta_range)
    upper = np.max(theta_range)
    # Create a boolean mask for frequencies within the theta range.
    mask_index = np.logical_and(mask_freq >= lower, mask_freq < upper)
    # Create boolean masks for frequencies below and above the theta range.
    sub_mask_index = mask_freq < lower
    supra_mask_index = mask_freq > upper
    # Assign the boolean masks to variables for each category.
    sub = sub_mask_index
    theta = mask_index
    supra = supra_mask_index
    # Return the boolean masks for sub-theta, theta, and supra-theta categories.
    return sub, theta, supra

# This function finds the indices where a signal crosses zero.
# x: The input signal.
def zero_cross(x):
    """
    Find the indices where a signal crosses zero.

    Parameters:
    - x (numpy.ndarray): The input signal.

    Returns:
    - numpy.ndarray: An array containing the indices where the input signal crosses zero.

    Example:
    ```python
    import numpy as np

    # Example usage:
    signal = np.array([1, -2, 3, -1, 0, 2, -4, 5])
    zero_cross_indices = zero_cross(signal)
    print("Zero-crossing indices:", zero_cross_indices)
    ```
    """
    # Identify where the signal goes from positive to non-positive (decay).
    decay = np.logical_and((x > 0)[1:], ~(x > 0)[:-1]).nonzero()[0]
    # Identify where the signal goes from non-positive to positive (rise).
    rise = np.logical_and((x <= 0)[1:], ~(x <= 0)[:-1]).nonzero()[0]
    # Combine the indices of rise and decay, then sort them with ascending indices.
    zero_xs = np.sort(np.append(rise, decay))
    # Return the sorted indices where the signal crosses zero.
    return zero_xs

# This function identifies the zero crossings, peaks, and troughs in a signal.
def extrema(x):
    """
    Identify the zero crossings, peaks, and troughs in a signal.

    Parameters:
    - x (numpy.ndarray): The input signal.

    Returns:
    - tuple: A tuple containing three arrays - zero-crossing indices, trough indices, and peak indices.

    Example:
    ```python
    import numpy as np

    # Example usage:
    signal = np.array([1, -2, 3, -1, 0, 2, -4, 5])
    zero_crossings, trough_indices, peak_indices = extrema(signal)
    print("Zero-crossing indices:", zero_crossings)
    print("Trough indices:", trough_indices)
    print("Peak indices:", peak_indices)
    ```
    """
    # Find the indices where the signal crosses zero.
    zero_xs = zero_cross(x)
    # Initialize empty arrays to store peak and trough indices.
    peaks = np.empty((0,)).astype(int)
    troughs = np.empty((0,)).astype(int)
    # Iterate through pairs of consecutive zero crossings.
    for t1, t2 in zip(zero_xs, zero_xs[1:]):
        # Find the index of the maximum absolute value in the current segment.
        extrema0 = np.argmax(np.abs(x[t1:t2])).astype(int) + t1
        # Check if the value at the found index is positive (peak) or non-positive (trough).
        if bool(x[extrema0] > 0):
            peaks = np.append(peaks, extrema0)
        else:
            troughs = np.append(troughs, extrema0)
    # Return the indices of zero crossings, troughs, and peaks.
    return zero_xs, troughs, peaks

#The get_cycles_data function generates a nested dictionary containing extracted data and desired metadata of each REM epochs in the input sleep
def get_cycles_data(x, rem_states, sample_rate, frequencies, theta_range=(5, 12)):
    
    """
    Generate a nested dictionary containing extracted data and desired metadata of each REM epoch in the input sleep signal.

    Parameters:
    - x (numpy.ndarray): The input 1D sleep signal.
    - rem_states (numpy.ndarray): A sleep state vector where 5 represents REM sleep and other values indicate non-REM.
    - sample_rate (int or float): The sampling rate of the data.
    - frequencies (numpy.ndarray): The array of frequencies at which to compute the wavelet transform.
    - theta_range (tuple, optional): A tuple defining the theta frequency range (lower, upper).
      Default is (5, 12).

    Returns:
    - dict: A nested dictionary of extracted signal data and signal source metadata for each REM epoch.

    Notes:
    - The dictionary output structure comes out as below:
      |----REM 1
       |    |----start-end:
       |    |----wavelet_transform:
       |    |----IMFs:
       |    |----IMF_Frequencies:
       |    |----Instantaneous Phases:
       |    |----Instantaneous Frequencies:
       |    |----Instantaneous Amplitudes:
       |    |----Cycles:
      |----REM (...)
       |    |--------(...)
    """

    # Squeezing dimensions
    x = np.squeeze(x)
    rem_states = np.squeeze(rem_states)
    # Print the shapes of the input arrays (for debugging)
    print(x.shape)
    print(rem_states.shape)

    # Detect REM periods and get start and end times
    consecutive_rem_states = get_rem_states(rem_states, sample_rate).astype(int)
    # If the consecutive REM states have an extra dimension, squeeze it out
    if consecutive_rem_states.ndim == 3:
        consecutive_rem_states=np.squeeze(consecutive_rem_states,0)
    # Print the shape of the REM states array (for debugging)
    print(consecutive_rem_states.shape)

    # Initialize variables to store various extracted data
    wt_spectrum = []
    rem_imf = []
    rem_mask_freq = []
    instantaneous_phase = []
    instantaneous_freq = []
    instantaneous_amp = []
    sub_theta_sig = np.empty((0,))
    theta_peak_sig = np.empty((0,))
    cycles = np.empty((0, 5))
    rem_dict = {}
    sub_dict = rem_dict

    # Loop through each REM epoch
    for i, rem in enumerate(consecutive_rem_states, start=1):
        # Create a sub-dictionary for the current REM epoch
        sub_dict.setdefault(f'REM {i}', {})
        # Get the start and end indices of the current REM epoch
        start = rem[0]
        end = rem[1]
        # Extract the signal corresponding to the current REM epoch
        signal = x[start:end]

        # Generate the time-frequency power spectrum using Morlet wavelet transform
        wavelet_transform = morlet_wt(signal, sample_rate, frequencies, mode='amplitude')

        # Extraction of IMFs and IMF Frequencies for current REM epoch
        imf, mask_freq = sift.iterated_mask_sift(signal,
                                                 mask_0='zc',
                                                 sample_rate=sample_rate,
                                                 ret_mask_freq=True)

        # Extract Instantaneous Phase, Frequencies and Amplitudes of each IMF for current REM epoch
        IP, IF, IA = spectra.frequency_transform(imf, sample_rate, 'nht')

        # Identify sub-theta, theta, and supra-theta frequencies using a given mask
        sub_theta, theta, _ = tg_split(mask_freq, theta_range)
        # Store the results for the current REM epoch in respective lists
        wt_spectrum.append(wavelet_transform)
        rem_imf.append(imf)
        rem_mask_freq.append(mask_freq)
        instantaneous_phase.append(IP)
        instantaneous_freq.append(IF)
        instantaneous_amp.append(IA)

        # Generate the theta signal to detect cycles
        theta_sig = np.sum(imf.T[theta], axis=0)

        # Parse the sub-theta signal of all REM periods into one variable to set amplitude threshold
        sub_theta_sig = np.append(sub_theta_sig, np.sum(imf.T[sub_theta], axis=0))

        # Generate extrema locations and zero crossing on the generated theta signal
        zero_x, trough, peak = extrema(np.sum(imf.T[theta], axis=0))

        # Transpose the IMFs (Intrinsic Mode Functions) to align them properly for further analysis
        # Each row in the transposed matrix represents the corresponding IMF for all time points
        # This is done to ensure that operations like summing across rows are performed along the correct axis
        zero_x = np.vstack((zero_x[:-2:2], zero_x[1:-1:2], zero_x[2::2])).T
        

        #size_adjust = np.min([trough.shape[0], zero_x.shape[0], peak.shape[0]])
        #zero_x = zero_x[:size_adjust]
        #cycle = np.empty((size_adjust, 5))
        #cycle[:, [0, 2, 4]] = zero_x
        #if trough[0] < peak[0]:
            #cycle[:, 1] = trough[:zero_x.shape[0]]
            #cycle[:, 3] = peak[:zero_x.shape[0]]
        #else:
            #cycle[:, 3] = trough[:zero_x.shape[0]]
            #cycle[:, 1] = peak[:zero_x.shape[0]]
            
        # Calculate the minimum size among the arrays (trough, zero_x, peak)
        size_adjust = np.min([trough.shape[0], zero_x.shape[0], peak.shape[0]])
        # Check if the size_adjust is greater than 0
        if size_adjust > 0:
            # Trim zero_x to the size determined by size_adjust
            zero_x = zero_x[:size_adjust]
            # Initialize an empty array for cycles with the determined size_adjust
            cycle = np.empty((size_adjust, 5))
            # Fill the columns related to zero crossings in the cycle array
            cycle[:, [0, 2, 4]] = zero_x
            # Check the relationship between the first points of trough and peak arrays
            if trough[0] < peak[0]:
                # If trough comes first, fill the respective columns in the cycle array
                cycle[:, 1] = trough[:size_adjust]
                cycle[:, 3] = peak[:size_adjust]
            else:
                # If peak comes first, fill the respective columns in the cycle array
                cycle[:, 3] = trough[:size_adjust]
                cycle[:, 1] = peak[:size_adjust]
        else:
            # Handle case when all arrays are empty or size_adjust is zero
            cycle = np.empty((0, 5))
        # Extract broken cycles (where the condition on diff is not met)
        broken_cycle = cycle[~np.all(np.diff(cycle, axis=1) > 0, axis=1)]
        # Create a mask for broken cycles based on differences between elements
        broken_cycle_mask = np.diff(broken_cycle, axis=1) > 0
        # Check if the broken cycles follow the specific adjust_condition pattern
        adjust_condition = np.all(np.all(broken_cycle_mask[1:] == [True, False, False, True],
                                         axis=0) == True)
        # Find the locations where the condition on diff is not met, excluding the first and last points
        adjust_loc = np.where(np.all(np.diff(cycle, axis=1) > 0, axis=1) == False)[0][1:-1]
        # Extract the fixed cycles between broken cycles
        fixed_cycle = broken_cycle[1:-1]
        # Adjust the fixed cycles based on the adjust_condition
        if adjust_condition:
            # If yes, adjust the fixed_cycle based on the adjacent cycles
            fixed_cycle[:, 1] = cycle[adjust_loc - 1, 1]
            fixed_cycle[:, 3] = cycle[adjust_loc + 1, 3]
        else:
            # If not, adjust the fixed_cycle based on the adjacent cycles in a different order
            fixed_cycle[:, 3] = cycle[adjust_loc - 1, 3]
            fixed_cycle[:, 1] = cycle[adjust_loc + 1, 1]
        # Check if there are cycles to process
        if cycle.size > 0:
            # Keep only cycles where the condition on differences between elements is met
            cycle = cycle[np.all(np.diff(cycle, axis=1) > 0, axis=1)]
            # Stack the fixed cycles on top of the existing cycles
            cycle = np.vstack((cycle, fixed_cycle))
            # Adjust the columns of the cycles based on the relationship between trough and peak
            if trough[0] < peak[0]:
                # If trough comes first, adjust the columns accordingly
                cycle = np.hstack((cycle[:-1, 1:-1], cycle[1:, :2]))
            else:
                # If peak comes first, adjust the columns accordingly
                cycle = np.hstack((cycle[:-1, 3].reshape((-1, 1)), cycle[1:, :-1]))
        else:
            # Output an empty array if there are no cycles
            cycle = np.empty((0, fixed_cycle.shape[1]))


        # Create an array of amplitudes at the peaks in the theta signal
        theta_peak_sig = np.append(theta_peak_sig, theta_sig[cycle[:, 2].astype(int)])
        # Stack the cycles on top of the existing cycles and adjust for the start index
        cycles = np.vstack((cycles, cycle + start))

    # Set the minimum amplitude threshold for theta peaks based on standard deviation of sub-theta signal
    min_peak_amp = 2 * sub_theta_sig.std()
    # Create a mask for satisfying the amplitude threshold criteria for theta peaks
    peak_mask = theta_peak_sig > min_peak_amp

    # Set the frequency threshold and discard and unsatisfactory difference between trough pairs
    upper_diff = np.floor(1000 / np.min(theta_range))
    lower_diff = np.floor(1000 / np.max(theta_range))
    # Create a mask for satisfying the frequency threshold criteria for cycles
    diff_mask = np.logical_and(np.diff(cycles[:, [0, -1]], axis=1) * (1000 / sample_rate) > lower_diff,
                               np.diff(cycles[:, [0, -1]], axis=1) * (1000 / sample_rate) <= upper_diff)

    # Create a boolean mask that satisfies both the frequency and amplitude threshold criteria
    extrema_mask = np.logical_and(np.squeeze(diff_mask), peak_mask)

    # Pass the boolean mask on the cycles array to discard any unsatisfactory cycles
    cycles = cycles[extrema_mask]

    # Place outputs in a nested dictionary
    for j, rem in enumerate(rem_dict.values()):
        rem['start-end'] = consecutive_rem_states[j]
        rem['wavelet_transform'] = wt_spectrum[j]
        rem['IMFs'] = rem_imf[j]
        rem['IMF_Frequencies'] = rem_mask_freq[j]
        rem['Instantaneous Phases'] = instantaneous_phase[j]
        rem['Instantaneous Frequencies'] = instantaneous_freq[j]
        rem['Instantaneous Amplitudes'] = instantaneous_amp[j]
        # Create a boolean mask for cycles within the current REM period
        cycles_mask = (cycles > consecutive_rem_states[j, 0]) & (cycles < consecutive_rem_states[j, 1])
        # Apply the boolean mask to get cycles within the current REM period
        cycles_mask = np.all(cycles_mask == True, axis=1)
        rem_cycles = cycles[cycles_mask]
        # Assign the cycles values to the nested dictionary after converting to integers
        rem['Cycles'] = rem_cycles.astype(int)

    return rem_dict

def bin_tf_to_fpp(x, power, bin_count):
    """
    Bin the frequency power profile (TF representation) into frequency power profiles (FPP).

    Parameters:
    - x (numpy.ndarray): A 1D or 2D array specifying the frequency ranges for binning.
      For a 1D array, it represents the start and end indices of the frequency range.
      For a 2D array of size (n, 2), each row represents the start and end indices for each binning range.
    - power (numpy.ndarray): The power values in the frequency domain.
    - bin_count (int): The number of bins to use for binning the frequency power profile.

    Returns:
    - numpy.ndarray: A 2D array representing the binned frequency power profile.
      Each row corresponds to the mean power within each bin for the specified frequency range(s).

    Raises:
    - ValueError: If the size of x is invalid.

    Example:
    ```python
    import numpy as np

    # Example usage:
    frequency_ranges = np.array([[5, 10], [15, 20]])  # Define two frequency ranges
    power_spectrum = np.random.rand(100, 30)  # Replace with your actual power spectrum
    bin_count = 10

    # Bin the power spectrum into frequency power profiles
    result_fpp = bin_tf_to_fpp(frequency_ranges, power_spectrum, bin_count)
    print(result_fpp)
    ```
    """
    # Check if x is a 1D array (dimensionality of 1)
    if x.ndim == 1:  # Handle the case when x is of size (2)
        # If yes, create bin ranges using the values in x and specified bin count
        bin_ranges = np.arange(x[0], x[1], 1)
        # Calculate the mean power within each bin using binned_statistic
        fpp = binned_statistic(bin_ranges, power[:, x[0]:x[1]], 'mean', bins=bin_count)[0]
        # Add an extra dimension to match the desired output shape (row vector)
        fpp = np.expand_dims(fpp, axis=0)  # Add an extra dimension to match the desired output shape
    elif x.ndim == 2:  # Handle the case when x is of size (n, 2)
        # If yes, initialize an empty list to store results for each row in x
        fpp = []
        # Iterate through each row in x
        for i in range(x.shape[0]):
            # Create bin ranges using the values in the current row of x and specified bin count
            bin_ranges = np.arange(x[i, 0], x[i, 1], 1)
            # Calculate the mean power within each bin using binned_statistic
            fpp_row = binned_statistic(bin_ranges, power[:, x[i, 0]:x[i, 1]], 'mean', bins=bin_count)[0]
            # Append the result for the current row to the list
            fpp.append(fpp_row)
        # Convert the list of results to a numpy array
        fpp = np.array(fpp)
    # If x has an invalid size, raise a ValueError
    else:
        raise ValueError("Invalid size for x")
    # Return the final result (frequency power profile)
    return fpp


def calculate_cog(frequencies, angles, amplitudes, ratio):
    """
    Calculate the center of gravity (COG) of the frequency and phase distributions.

    Parameters:
    - frequencies (numpy.ndarray): Array of frequency values.
    - angles (numpy.ndarray): Array of phase angles (in degrees).
    - amplitudes (numpy.ndarray): Array of amplitude values corresponding to frequencies and angles.
    - ratio (float): Threshold ratio for identifying significant amplitudes.

    Returns:
    - numpy.ndarray: A 2D array representing the center of gravity (COG) for each dimension.
      Each row corresponds to the COG values for a specific dimension (frequency, phase).

    Notes:
    - The COG is calculated based on the circular mean of angles weighted by significant amplitudes.
    - The threshold for significance is determined by comparing amplitudes to the maximum amplitude within a narrow frequency range.

    Example:
    ```python
    import numpy as np

    # Example usage:
    frequencies = np.arange(1, 10, 1)
    angles = np.random.rand(3, 10) * 360  # Replace with your actual phase angles
    amplitudes = np.random.rand(3, 10)  # Replace with your actual amplitude values
    ratio = 0.5

    # Calculate the COG for the given data
    cog_result = calculate_cog(frequencies, angles, amplitudes, ratio)
    print(cog_result)
    ```
    """
    # Convert angles to radians
    angles = np.deg2rad(angles)
    # Initialize an empty array for the center of gravity (COG)
    cog = np.empty((0, 2))
    # Check if amplitudes have 2 dimensions (2D array)
    if amplitudes.ndim == 2:
        # Calculate the numerator and denominator for frequency COG
        numerator = np.sum(frequencies * np.sum(amplitudes, axis=1))
        denominator = np.sum(amplitudes)
        # Calculate the frequency COG (cog_f)
        cog_f = numerator / denominator
        # Calculate floor and ceil indices for frequency COG
        floor = np.floor(cog_f).astype(int) - frequencies[0]
        ceil = np.ceil(cog_f).astype(int) - frequencies[0]
        # Create a new frequency power profile (FPP) with values greater than the threshold ratio
        new_fpp = np.where(amplitudes >= np.max(amplitudes[[floor, ceil], :]) * ratio, amplitudes, 0)
        # Calculate phase COG using circular mean of angles weighted by FPP
        cog_ph = np.rad2deg(pg.circ_mean(angles, w=np.sum(new_fpp, axis=0)))
        # Create a 2D array for COG (frequency, phase)
        cog = np.array([cog_f, cog_ph])
    # Check if amplitudes have 3 dimensions (3D array)
    elif amplitudes.ndim == 3:
        # Initialize arrays to store indices for amplitude COG calculation
        indices_to_subset = np.empty((amplitudes.shape[0], 2)).astype(int)
        cog = np.empty((amplitudes.shape[0], 2))
        # Calculate numerator and denominator for frequency COG
        numerator = np.sum(frequencies * np.sum(amplitudes, axis=2), axis=1)
        denominator = np.sum(amplitudes, axis=(1, 2))
         # Calculate frequency COG for each dimension
        cog_f = (numerator / denominator)
        # Vectorize floor and ceil functions for efficient array operations
        vectorized_floor = np.vectorize(np.floor)
        vectorized_ceil = np.vectorize(np.ceil)
        # Set floor and ceil indices for each dimension
        indices_to_subset[:, 0] = vectorized_floor(cog_f) - frequencies[0]
        indices_to_subset[:, 1] = vectorized_ceil(cog_f) - frequencies[0]
        # Calculate max amplitudes for each dimension
        max_amps = np.max(amplitudes[np.arange(amplitudes.shape[0])[:, np.newaxis], indices_to_subset, :], axis=(1, 2))
        print(max_amps.shape)
        # Loop through each dimension and calculate phase COG
        for i, max_amp in enumerate(max_amps):
            # Create a new FPP for the current dimension with values greater than the threshold ratio
            new_fpp = np.where(amplitudes[i] >= max_amp * ratio, amplitudes[i], 0)
            # Calculate phase COG using circular mean of angles weighted by FPP
            cog[i, 1] = np.rad2deg(pg.circ_mean(angles, w=np.sum(new_fpp, axis=0)))
        # Set frequency COG values for each dimension
        cog[:, 0] = cog_f
    # Return the final COG array
    return cog


def boxcar_smooth(x, boxcar_window):
    """
    Smooth a 1D or 2D array using a boxcar window.

    Parameters:
    - x (numpy.ndarray): Input array to be smoothed.
    - boxcar_window (int or tuple): Size of the boxcar window for smoothing.
      For 1D array, an integer representing the window size.
      For 2D array, a tuple (t, f) representing window sizes along the time (t) and frequency (f) dimensions.

    Returns:
    - numpy.ndarray: Smoothed array using the boxcar window.

    Notes:
    - If the input array is 1D, the boxcar window size is adjusted to be odd.
    - If the input array is 2D, separate boxcar windows are created for the time (t) and frequency (f) dimensions.

    Example:
    ```python
    import numpy as np

    # Example usage:
    signal_1d = np.random.rand(100)  # Replace with your actual 1D signal
    window_size_1d = 5
    smoothed_1d = boxcar_smooth(signal_1d, window_size_1d)
    print(smoothed_1d)

    signal_2d = np.random.rand(100, 50)  # Replace with your actual 2D signal
    window_size_2d = (5, 3)
    smoothed_2d = boxcar_smooth(signal_2d, window_size_2d)
    print(smoothed_2d)
    ```
    """
    # Check if the input array x is 1-dimensional
    if x.ndim == 1:
        # Check if the boxcar window size is even, and if so, make it odd by adding 1
        if boxcar_window % 2 == 0:
            boxcar_window += 1
        # Create a boxcar window of size boxcar_window for smoothing
        window = np.ones((1, boxcar_window)) / boxcar_window
        # Perform 1-dimensional convolution to smooth the input array x
        x_spectrum = np.convolve(x, window, mode='same')
    else:
        # Adjust the boxcar window size to be odd for both dimensions
        bool_window = np.where(~boxcar_window % 2 == 0, boxcar_window, boxcar_window + 1)
        # Create separate boxcar windows for time (t) and frequency (f) dimensions
        window_t = np.ones((1, bool_window[0])) / bool_window[0]
        window_f = np.ones((1, bool_window[1])) / bool_window[1]
        # Perform 2-dimensional convolution first along the time dimension (t)
        x_spectrum_t = convolve2d(x, window_t, mode='same')
        # Perform 2-dimensional convolution along the frequency dimension (f)
        x_spectrum = convolve2d(x_spectrum_t, window_f.T, mode='same')
    # Return the smoothed array x_spectrum
    return x_spectrum


# def peak_cog(frequencies, angles, amplitudes, ratio):
#     def nearest_peaks(frequency, angle, amplitude, ratio):
#         peak_indices = peak_local_max(amplitude, min_distance=1, threshold_abs=0)
#         cog_f = calculate_cog(frequency, angle, amplitude, ratio)

#         if peak_indices.shape[0] == 0:
#             cog_peak = cog_f
#         else:
#             cog_fx = np.array([cog_f[0], cog_f[0] * np.cos(np.deg2rad(cog_f[1] - angle[0])),
#                                cog_f[0] * np.sin(np.deg2rad(cog_f[1] - angle[0]))])
#             peak_loc = peak_loc = np.empty((peak_indices.shape[0], 4))
#             peak_loc[:, [0, 1]] = np.array([frequency[peak_indices.T[0]], angle[peak_indices.T[1]]]).T
#             peak_loc[:, 2] = peak_loc[:, 0] * np.cos(np.deg2rad(peak_loc[:, 1] - angle[0]))
#             peak_loc[:, 3] = peak_loc[:, 0] * np.sin(np.deg2rad(peak_loc[:, 1] - angle[0]))
#             peak_loc = peak_loc[:, [0, 2, 3]]
#             distances = np.abs(peak_loc - cog_fx)

#             cog_pos = peak_indices[np.argmin(np.linalg.norm(distances, axis=1))]

#             cog_peak = np.array([frequency[cog_pos[0]], angle[cog_pos[1]]])

#         return cog_peak

#     if amplitudes.ndim == 2:
#         cog = nearest_peaks(frequencies, angles, amplitudes, ratio)
#     elif amplitudes.ndim == 3:
#         cog = np.empty((amplitudes.shape[0], 2))
#         for i, fpp in enumerate(amplitudes):
#             cog[i] = nearest_peaks(frequencies, angles, fpp, ratio)
#     return cog


# def max_peaks(amplitudes):
#     new_fpp = np.zeros(amplitudes.shape)
#     if amplitudes.ndim == 2:
#         peak_indices = peak_local_max(amplitudes, min_distance=1, threshold_abs=0)
#         if peak_indices.shape[0] == 0:
#             new_fpp = np.where(amplitudes > 0, amplitudes, 0)
#         else:
#             new_fpp[peak_indices.T[0], peak_indices.T[1]] = amplitudes[peak_indices.T[0], peak_indices.T[1]]
#     elif amplitudes.ndim == 3:
#         for i, fpp in enumerate(amplitudes):
#             peak_indices = peak_local_max(fpp, min_distance=1, threshold_abs=0)
#             if peak_indices.shape[0] == 0:
#                 new_fpp[i] = np.where(fpp > 0, fpp, 0)
#             else:
#                 new_fpp[i, peak_indices.T[0], peak_indices.T[1]] = fpp[peak_indices.T[0], peak_indices.T[1]]
#     return new_fpp


# def boundary_peaks(amplitudes):
#     adjusted_fpp = np.zeros(amplitudes.shape)
#     if amplitudes.ndim == 2:
#         peak_indices = peak_local_max(amplitudes, min_distance=1, threshold_abs=0)
#         if peak_indices.shape[0] == 0:
#             adjusted_fpp = np.where(amplitudes > 0, amplitudes, 0)
#         else:
#             new_fpp = amplitudes[peak_indices.T[0], peak_indices.T[1]]
#             maximum = np.max(new_fpp)
#             minimum = np.min(new_fpp)
#             adjusted_fpp = np.where((amplitudes <= maximum) & (amplitudes >= 0.95*minimum), amplitudes, 0)
#     elif amplitudes.ndim == 3:
#         for i, fpp in enumerate(amplitudes):
#             peak_indices = peak_local_max(fpp, min_distance=1, threshold_abs=0)
#             print(peak_indices.shape)
#             if peak_indices.shape[0] == 0:
#                 adjusted_fpp[i] = np.where(fpp > 0, fpp, 0)
#             else:
#                 maximum = np.max(fpp[peak_indices.T[0], peak_indices.T[1]])
#                 minimum = np.min(fpp[peak_indices.T[0], peak_indices.T[1]])
#                 adjusted_fpp[i] = np.where((fpp <= maximum) & (fpp >= 0.95*minimum), fpp, 0)
#     return adjusted_fpp


def rem_fpp_gen(rem_dict, x, sample_rate, frequencies, angles, ratio, boxcar_window=None, norm='', fpp_method='',
                cog_method=''):
        """
    Generate Frequency-Power-Phase (FPP) plots for each REM epoch in the input dictionary.

    Parameters:
    - rem_dict (dict): Dictionary containing information about REM epochs and cycles.
    - x (numpy.ndarray): 1D sleep signal.
    - sample_rate (int or float): Sampling rate of the sleep signal.
    - frequencies (numpy.ndarray): Array of frequency values.
    - angles (numpy.ndarray): Array of phase angles (in degrees).
    - ratio (float): Threshold ratio for identifying significant amplitudes.
    - boxcar_window (int or None): Size of the boxcar window for smoothing (default is None).
    - norm (str): Normalization method for the time-frequency power spectrum (default is '').
    - fpp_method (str): Method for generating FPP plots (default is '').
    - cog_method (str): Method for calculating the center of gravity (CoG) (default is '').

    Returns:
    - dict: Dictionary containing FPP plots and CoG information for each REM epoch.

    Notes:
    - The function processes each REM epoch in the input dictionary, extracting relevant information such as cycles,
      time indices, and the sleep signal. It then generates FPP plots based on the time-frequency power spectrum
      obtained using the Morlet wavelet. Additional options for smoothing, normalization, FPP generation, and CoG calculation
      can be applied based on the specified parameters.

    """
    # Ensure the input array x is 1-dimensional
    x = np.squeeze(x)
    # Create an empty dictionary to store REM features
    cycles_dict = rem_dict
    rem_dict = {}
    # Create a sub-dictionary to store features for each REM epoch
    sub_dict = rem_dict
    # Loop through each REM epoch in the input dictionary
    for key, value in cycles_dict.items():
        print(key)
        # Check if the REM epoch has cycle information
        if 'Cycles' in value.keys():
            # Create a sub-dictionary for the current REM epoch
            sub_dict.setdefault(key, {})
            # Extract the time indices for the current REM epoch
            sub_dict.setdefault(key, {})
            # Extract the signal for the current REM epoch
            t = value['start-end'].astype(np.int32)
            print(t, t[0], t[1])
            # Extract the signal for the current REM epoch
            sig = x[t[0]:t[1]]
            print(sig.shape)
            # Generate the time-frequency power spectrum using Morlet wavelet
            power = morlet_wt(sig, sample_rate, frequencies, mode='power').astype(np.float32)
             # Extract cycle information and adjust indices to match the current REM epoch
            cycles = (value['Cycles'][:, [0, -1]] - t[0]).astype(np.int32)
            # if boxcar_window is not None:
            #     power = boxcar_smooth(power, boxcar_window)
            # if norm == 'simple_x':
            #     power = power / np.sum(power, axis=0)
            # elif norm == 'simple_y':
            #     power = power / np.sum(power, axis=1)[:, np.newaxis]
            # elif norm == 'zscore_y':
            #     power = zscore(power, axis=0)
            # elif norm == 'zscore_x':
            #     power = zscore(power, axis=1)
            # Bin the time-frequency power spectrum to generate FPP (Frequency-Power-Phase) plots
            fpp_plots = bin_tf_to_fpp(cycles, power, 19).astype(np.float32)
            # Store the FPP plots in the sub-dictionary
            sub_dict[key]['FPP_cycles'] = fpp_plots
            # if fpp_method == 'max_peaks':
            #     fpp_plots = max_peaks(fpp_plots)
            #     print(fpp_plots.shape)
            # elif fpp_method == 'boundary_peaks':
            #     fpp_plots = boundary_peaks(fpp_plots)
            # if cog_method == 'nearest':
            #     cog = peak_cog(frequencies, angles, fpp_plots, ratio).astype(np.float32)
            # else:
            #     cog = calculate_cog(frequencies, angles, fpp_plots, ratio).astype(np.float32)
            # sub_dict[key]['CoG'] = cog
        else:
            continue
    return rem_dict

In [10]:
class DataProcessor:
    """
    A class for processing and analyzing data.
    """
    def __init__(self, data_dir, output_dir):
        """
        Initializes the DataProcessor with the input data directory and output directory.

        Parameters:
        - data_dir (str): The directory containing input data.
        - output_dir (str): The directory for storing processed output data.
        """
        self.data_dir = data_dir  # Set the input data directory.
        self.output_dir = output_dir  # Set the output data directory.

    def load_data(self, subfolder):
        """
        Load LFP (Local Field Potential) and states data from the specified subfolder.

        Parameters:
        - subfolder (str): The subfolder within the data directory containing the data files.

        Returns:
        Tuple: A tuple containing LFP data and states data.
        """
        # Generate file paths for LFP and states data using glob.
        # Check for the merged HPC file first, otherwise, use the regular HPC file.
        merged_hpc_files = glob.glob(os.path.join(self.data_dir, subfolder, '*_HPC_merged.mat'))
        hpc_files = glob.glob(os.path.join(self.data_dir, subfolder, '*HPC*.continuous*.mat'))

        if merged_hpc_files:
            # If there are merged HPC files, use the first one.
            lfp_file = merged_hpc_files[0]
        elif hpc_files:
            # If no merged HPC files found, but regular HPC files exist, use the first one.
            lfp_file = hpc_files[0]
        else:
            # If no HPC files are found, raise a FileNotFoundError.
            raise FileNotFoundError("No HPC files found in the specified subfolder.")

        # Find the file that matches the pattern '*states*' in the specified subfolder.
        states_file = glob.glob(os.path.join(self.data_dir, subfolder, '*states*'))[0]

        print("Loading data from:", lfp_file)
        print("Loading data from:", states_file)

        # Load LFP and states data using scipy's loadmat function.
        lfp_data = sio.loadmat(lfp_file)['HPC']  # Load LFP data from the MATLAB file.
        states_data = sio.loadmat(states_file)['states']  # Load states data.

        print("LFP data shape:", lfp_data.shape)
        print("States data shape:", states_data.shape)

        # Check data format
        if not isinstance(states_data, np.ndarray):
            raise ValueError("States data should be a NumPy array.")

        # Check data dimensions (adjust as needed)
        if states_data.ndim != 2:
            raise ValueError("States data should be a 2D array.")

        # Check data range
        min_value = np.min(states_data)
        max_value = np.max(states_data)
        expected_min = 0  # Adjust as needed
        expected_max = 5  # Adjust as needed
        if min_value < expected_min or max_value > expected_max:
            raise ValueError("States data range is outside of expected bounds.")

        # Check for missing values
        if np.isnan(states_data).any():
            raise ValueError("States data contains NaN values.")

        return lfp_data, states_data 
     

    def process_data(self, lfp_data, states_data):
        """
        Process LFP (Local Field Potential) and states data.

        Parameters:
        - lfp_data (numpy.ndarray): The LFP data to be processed.
        - states_data (numpy.ndarray): The states data corresponding to the LFP data.

        Returns:
        dict: A dictionary containing processed data.
        """
        print("Processing data...")
        # Define a frequency range for processing.
        frequency_range=np.arange(20,140,1)
        # Call a function (get_cycles_data) to process cycles data.
        rem_dict = get_cycles_data(lfp_data, states_data, 2500, frequency_range, (5, 12))
        print("Called get_cycles_data")
        for key, value in rem_dict.items():
            # Loop through each key-value pair in rem_dict.
            if isinstance(value, (int, float)):
                # Check if the value associated with the key is an int or float.
                rem_dict[key] = np.float32(value)  # Convert numerical values to np.float32
        # Print a message indicating that data processing is completed.
        print("Data processing completed.")
        # Return the processed rem_dict.
        return rem_dict

    
    def save_data(self, subfolder, rem_dict):
        """
        Save processed data to a file.

        Parameters:
        - subfolder (str): The subfolder within the output directory for saving data.
        - rem_dict (dict): The processed data dictionary to be saved.
        """
        # Create an output subfolder if it doesn't exist.
        output_subfolder = os.path.join(self.output_dir, subfolder)
        # Create a new folder in the output directory to store the processed data.
        os.makedirs(output_subfolder, exist_ok=True)
        # Define the output file paths based on the subfolder name.
        rem_dict_filename = f"{subfolder}_REM_dict.h5"
        # Create the full file path for saving the processed data dictionary.
        rem_dict_file = os.path.join(output_subfolder, rem_dict_filename)

        # Create a function to save a dictionary as an HDF5 group
        def save_dict_as_hdf5_group(hdf_group, data_dict):
            """
            Recursively saves a dictionary as an HDF5 group.

            Parameters:
            - hdf_group (h5py.Group): The HDF5 group to which the dictionary will be saved.
            - data_dict (dict): The dictionary to be saved.
            """
            
            for key, value in data_dict.items():
                # Loop through each key-value pair in the dictionary.
                if isinstance(value, dict):
                    # If the value is another dictionary, create a subgroup in the HDF5 group.
                    subgroup = hdf_group.create_group(key)
                    # Create a subgroup within the HDF5 group.
                    save_dict_as_hdf5_group(subgroup, value)
                else:
                    # Otherwise, save the value to the HDF5 group
                    hdf_group[key] = value

        # Save the rem_dict dictionary as an HDF5 file.
        with h5py.File(rem_dict_file, 'w') as hdf_file:
            # Use the subfolder name as the top-level group name
            subfolder_group = hdf_file.create_group(subfolder)

            # Call the function to save rem_dict within the subfolder group
            save_dict_as_hdf5_group(subfolder_group, rem_dict)

    def process_subfolder_with_timing(self, subfolder):
        """
        Process a subfolder's data with timing information.

        Parameters:
        - subfolder (str): The subfolder within the data directory to process.

        Returns:
        - dict: The processed data dictionary (rem_dict).
        """
        start_time = time.time()  # Start measuring time
        # Load LFP (Local Field Potential) and states data from the specified subfolder.
        lfp_data, states_data = self.load_data(subfolder)
        # Process the loaded data and obtain the rem_dict (processed data dictionary).
        rem_dict = self.process_data(lfp_data, states_data)  # Process and get the rem_dict
        # Save the processed data dictionary to the output directory.
        self.save_data(subfolder, rem_dict)  # Save the rem_dict
        end_time = time.time()  # Stop measuring time
        elapsed_time = end_time - start_time  # Calculate elapsed time
        print(elapsed_time)  # Print the elapsed time in seconds
        # Return the processed data dictionary (rem_dict).
        return rem_dict



# Specify the input data directory and output directory.
data_dir = "E:/Donders/11/raw/OD"  # Replace with your actual data directory path.
output_dir = "E:/Donders/11/processed/OD"  # Replace with your desired output directory path.

# DataProcessor object.
processor = DataProcessor(data_dir, output_dir)

# Process all data in the specified directories using parallel processing.
with concurrent.futures.ProcessPoolExecutor() as executor:
    """
    This code concurrently processes data from multiple subfolders using a concurrent.futures.ProcessPoolExecutor. 
    
    It starts by creating the executor, then generates a list of subfolders in the specified data directory.
    
    In a parallel loop, it calls the process_subfolder_with_timing method of a DataProcessor object for each subfolder, 
    collecting the processed data dictionaries in the processed_data list. The result is a list containing the 
    dictionaries for each subfolder after parallel processing.
    """
    # Get a list of subfolders within the specified data directory.
    subfolders = [subfolder for subfolder in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, subfolder))]
    # Create a list to collect processed data dictionaries for each subfolder.
    processed_data = []  # Create a list to collect processed data
    # Iterate over each subfolder for parallel processing.
    for subfolder in subfolders:
        # Call the 'process_subfolder_with_timing' method on the DataProcessor object for each subfolder.
        result = processor.process_subfolder_with_timing(subfolder)  # Call the method and collect the dictionaries
        # Append the processed data dictionary to the list.
        processed_data.append(result)  # Append the dictionaries to the list

# Now, 'processed_data' contains the dictionaries for each subfolder.

Loading data from: E:/Donders/11/raw/OD\2018-10-31_10-25-02_Pre-sleep\HPC_100_CH32_0.continuous.mat
Loading data from: E:/Donders/11/raw/OD\2018-10-31_10-25-02_Pre-sleep\2018-10-31_10-25-02_presleep-states.mat
LFP data shape: (6751147, 1)
States data shape: (1, 2699)
Processing data...
(6751147,)
(2699,)
(0,)
Called get_cycles_data
Data processing completed.
0.6540532112121582
Loading data from: E:/Donders/11/raw/OD\2018-10-31_11-15-54_Post_Trial1\HPC_100_CH32_0.continuous.mat
Loading data from: E:/Donders/11/raw/OD\2018-10-31_11-15-54_Post_Trial1\2018-10-31_11-15-54_post_trial1-states.mat
LFP data shape: (6750976, 1)
States data shape: (1, 2699)
Processing data...
(6750976,)
(2699,)
(4, 2)
Called get_cycles_data
Data processing completed.
291.0214374065399
Loading data from: E:/Donders/11/raw/OD\2018-10-31_12-06-31_Post_Trial2\HPC_100_CH32_0.continuous.mat
Loading data from: E:/Donders/11/raw/OD\2018-10-31_12-06-31_Post_Trial2\2018-10-31_12-06-31_post_trial2-states.mat
LFP data shape:

Processing HDF file: post_trial1_2017-09-28_11-30-59_REM_dict.h5
Keys in HDF file:
post_trial1_2017-09-28_11-30-59
Processing HDF file: post_trial1_2017-11-13_10-54-15_REM_dict.h5
Keys in HDF file:
post_trial1_2017-11-13_10-54-15
Processing HDF file: post_trial2_2017-09-28_12-21-41_REM_dict.h5
Keys in HDF file:
post_trial2_2017-09-28_12-21-41
Processing HDF file: post_trial2_2017-11-13_11-45-00_REM_dict.h5
Keys in HDF file:
post_trial2_2017-11-13_11-45-00
Processing HDF file: post_trial3_2017-09-28_13-12-53_REM_dict.h5
Keys in HDF file:
post_trial3_2017-09-28_13-12-53
Processing HDF file: post_trial3_2017-11-13_12-35-45_REM_dict.h5
Keys in HDF file:
post_trial3_2017-11-13_12-35-45
Processing HDF file: post_trial4_2017-09-28_14-03-38_REM_dict.h5
Keys in HDF file:
post_trial4_2017-09-28_14-03-38
Processing HDF file: post_trial4_2017-11-13_13-26-21_REM_dict.h5
Keys in HDF file:
post_trial4_2017-11-13_13-26-21
Processing HDF file: post_trial5_2017-09-28_14-55-18_REM_dict.h5
Keys in HDF fil

[5412500. 5625000.]


print(rem_states.shape)

(2,)


<class 'numpy.ndarray'>
(2,)
5412500.0
<class 'int'>
Processing REM period 1, start: 5412500, end: 5625000
start: 5412500, end: 5625000
start: 5412500, end: 5625000


ValueError: operands could not be broadcast together with shapes (613,5) (2,) 