This notebook is for a pipeline of analyzing and visualizing a time-seires dataset. 

## Step 1: load package and preprocess functions

In [1]:
import pandas as pd
import numpy as np
import os
from wav2sleep.data.edf import load_edf_data
from wav2sleep.data.txt import parse_txt_annotations
from wav2sleep.data.utils import interpolate_index
from wav2sleep.data.xml import parse_xml_annotations
from wav2sleep.data.xml_parse_all import parse_all_annotations, annotate_waveform
from wav2sleep.data.rpoints import parse_process_rpoints_annotations
from wav2sleep.settings import *
from wav2sleep.config import *


In [2]:

import mne, pandas as pd, pathlib
### please modify path here
# annotation_path = '/scratch/besp/shared_data/shhs/polysomnography/annotations-events-nsrr/shhs1/shhs1-201935-nsrr.xml'
# edf_path = '/scratch/besp/shared_data/shhs/polysomnography/edfs/shhs1/shhs1-201935.edf'
# DATA_FOR_CHECK = 'shhs'
annotation_path = '/scratch/besp/shared_data/ccshs/polysomnography/annotations-events-nsrr/ccshs-trec-1800823-nsrr.xml'
edf_path = '/scratch/besp/shared_data/ccshs/polysomnography/edfs/ccshs-trec-1800823.edf'
# annotation_path = '/scratch/besp/shared_data/ccshs/polysomnography/annotations-events-nsrr/ccshs-trec-1800248-nsrr.xml'
# edf_path = '/scratch/besp/shared_data/ccshs/polysomnography/edfs/ccshs-trec-1800248.edf'
DATA_FOR_CHECK = 'ccshs'


###########################
edf = pathlib.Path(edf_path)
raw = mne.io.read_raw_edf(edf, preload=False, verbose="error")


hdr          = raw._raw_extras[0]              
rec_len_sec  = hdr['record_length']           
n_samps_list = hdr['n_samps']                  

rows = []
for idx, ch in enumerate(raw.info['chs']):
    sfreq = n_samps_list[idx] / rec_len_sec    
    rows.append(dict(channel   = ch['ch_name'],
                     sfreq_hz  = sfreq,
                     phys_unit = ch.get('unit', '—'),
                     lowpass   = ch.get('lowpass',  '—'),
                     highpass  = ch.get('highpass', '—')))

df = pd.DataFrame(rows)#.sort_values("sfreq_hz", ascending=False)
print(df)          

        channel          sfreq_hz  phys_unit lowpass highpass
0            C3    [128.0, 128.0]        107       —        —
1            C4    [128.0, 128.0]        107       —        —
2            A1    [128.0, 128.0]        107       —        —
3            A2    [128.0, 128.0]        107       —        —
4           LOC    [128.0, 128.0]        107       —        —
5           ROC    [128.0, 128.0]        107       —        —
6          ECG2    [256.0, 256.0]        107       —        —
7          ECG1    [256.0, 256.0]        107       —        —
8          EMG1    [256.0, 256.0]        107       —        —
9          EMG2    [256.0, 256.0]        107       —        —
10         EMG3    [256.0, 256.0]        107       —        —
11        L Leg      [64.0, 64.0]        107       —        —
12        R Leg      [64.0, 64.0]        107       —        —
13      AIRFLOW      [32.0, 32.0]        107       —        —
14  THOR EFFORT      [32.0, 32.0]        107       —        —
15  ABDO

In [3]:
'''
Preprocessing notes:
1. select a time window: here we use 10h 
2. select different frequency for different channel
3. resample using interpolate
4. channel-wise normalization
'''
def _mne_lowpass_series(s: pd.Series, fs,
                        cutoff=None) -> pd.Series:
    """
    Apply low-pass filter to a pd.Series using MNE.
    Keeps frequencies below the cutoff.
    
    Parameters:
    - s: input signal
    - fs: sampling rate
    - cutoff: cutoff frequency (Hz)
    """
    if cutoff is None:
        return s

    x = s.to_numpy(np.float64)[np.newaxis, :]  # shape (1, n)

    x_filt = mne.filter.filter_data(
        x, sfreq=fs,
        l_freq=None, h_freq=cutoff, 
        method='fir', phase='zero-double',
        n_jobs='cuda',
        verbose=False
    )[0]

    return pd.Series(x_filt, index=s.index, name=s.name)

def process_edf(edf: pd.DataFrame):
    """Process dataframe of EDF data."""
    signals = []

    def _process_edf_column(col, target_index, preprocessed_fs):
        """Process signal column of EDF"""
        if col in edf:
            
            raw = edf[col].dropna()
            
            # print(len(raw.loc[0:1]))
            raw_fs = len(raw.loc[0:1]) - 1
            
            if raw_fs > preprocessed_fs: 
                raw_hp = _mne_lowpass_series(raw, raw_fs, cutoff = preprocessed_fs/2)
            else:
                raw_hp = raw
            
            resampled = interpolate_index(raw_hp, target_index,
                              method="linear", squeeze=False)
            # normalized_wav = (resampled_wav - resampled_wav.mean()) / resampled_wav.std()
            print("col:", col, "length:", resampled.shape)
            signals.append(resampled)
            return 0
        else:
            return 1

    _process_edf_column(ECG, ECG_SIGNAL_INDEX, FREQ_ECG)
    _process_edf_column(HR, HR_SIGNAL_INDEX, FREQ_HR)

    _process_edf_column(SPO2, SPO2_SIGNAL_INDEX, FREQ_SPO2)
    _process_edf_column(OX, OX_SIGNAL_INDEX, FREQ_OX)
    _process_edf_column(ABD, ABD_SIGNAL_INDEX, FREQ_ABD)
    _process_edf_column(THX, THX_SIGNAL_INDEX, FREQ_THX)
    _process_edf_column(AF, AF_SIGNAL_INDEX, FREQ_AF)
    _process_edf_column(NP, NP_SIGNAL_INDEX, FREQ_NP)
    _process_edf_column(SN, SN_SIGNAL_INDEX, FREQ_SN)
    
    _process_edf_column(EMG_LLeg, EMG_LLeg_SIGNAL_INDEX, FREQ_EMG_LLeg)
    _process_edf_column(EMG_RLeg, EMG_RLeg_SIGNAL_INDEX, FREQ_EMG_RLeg)
    _process_edf_column(EMG_LChin, EMG_LChin_SIGNAL_INDEX, FREQ_EMG_LChin)
    _process_edf_column(EMG_RChin, EMG_RChin_SIGNAL_INDEX, FREQ_EMG_RChin)
    _process_edf_column(EMG_CChin, EMG_CChin_SIGNAL_INDEX, FREQ_EMG_CChin)
    _process_edf_column(EOG_L, EOG_L_SIGNAL_INDEX, FREQ_EOG_L)
    _process_edf_column(EOG_R, EOG_R_SIGNAL_INDEX, FREQ_EOG_R)
    
    is_na_C3 = _process_edf_column(EEG_C3, EEG_C3_SIGNAL_INDEX, FREQ_EEG_C3)
    is_na_C4 = _process_edf_column(EEG_C4, EEG_C4_SIGNAL_INDEX, FREQ_EEG_C4)
    is_na_A1 = _process_edf_column(EEG_A1, EEG_A1_SIGNAL_INDEX, FREQ_EEG_A1)
    is_na_A2 = _process_edf_column(EEG_A2, EEG_A2_SIGNAL_INDEX, FREQ_EEG_A2)
    is_na_O1 = _process_edf_column(EEG_O1, EEG_O1_SIGNAL_INDEX, FREQ_EEG_O1)
    is_na_O2 = _process_edf_column(EEG_O2, EEG_O2_SIGNAL_INDEX, FREQ_EEG_O2)
    is_na_F3 = _process_edf_column(EEG_F3, EEG_F3_SIGNAL_INDEX, FREQ_EEG_F3)
    is_na_F4 = _process_edf_column(EEG_F4, EEG_F4_SIGNAL_INDEX, FREQ_EEG_F4)
    
    # add a logic to check
    
    is_na_C3_A2 = _process_edf_column(EEG_C3_A2, EEG_C3_A2_SIGNAL_INDEX, FREQ_EEG_C3_A2)
    is_na_C4_A1 = _process_edf_column(EEG_C4_A1, EEG_C4_A1_SIGNAL_INDEX, FREQ_EEG_C4_A1)
    is_na_F3_A2 = _process_edf_column(EEG_F3_A2, EEG_F3_A2_SIGNAL_INDEX, FREQ_EEG_F3_A2)
    is_na_F4_A1 = _process_edf_column(EEG_F4_A1, EEG_F4_A1_SIGNAL_INDEX, FREQ_EEG_F4_A1)
    is_na_O1_A2 = _process_edf_column(EEG_O1_A2, EEG_O1_A2_SIGNAL_INDEX, FREQ_EEG_O1_A2)
    is_na_O2_A1 = _process_edf_column(EEG_O2_A1, EEG_O2_A1_SIGNAL_INDEX, FREQ_EEG_O2_A1)
    
    
    
    merged_df = pd.concat(signals, axis=1).astype(np.float32)
    
    if (EEG_C3_A2 not in merged_df.columns.to_list()) and (is_na_C3 == 0) and (is_na_A2 == 0):
        merged_df[EEG_C3_A2] = merged_df[EEG_C3] - merged_df[EEG_A2]
    if (EEG_C4_A1 not in merged_df.columns.to_list()) and (is_na_C4 == 0) and (is_na_A1 == 0):
        merged_df[EEG_C4_A1] = merged_df[EEG_C4] - merged_df[EEG_A1]
    if (EEG_F3_A2 not in merged_df.columns.to_list()) and (is_na_F3 == 0) and (is_na_A2 == 0):
        merged_df[EEG_F3_A2] = merged_df[EEG_F3] - merged_df[EEG_A2]
    if (EEG_F4_A1 not in merged_df.columns.to_list()) and (is_na_F4 == 0) and (is_na_A1 == 0):
        merged_df[EEG_F4_A1] = merged_df[EEG_F4] - merged_df[EEG_A1]
    if (EEG_O1_A2 not in merged_df.columns.to_list()) and (is_na_O1 == 0) and (is_na_A2 == 0):
        merged_df[EEG_O1_A2] = merged_df[EEG_O1] - merged_df[EEG_A2]
    if (EEG_O2_A1 not in merged_df.columns.to_list()) and (is_na_O2 == 0) and (is_na_A1 == 0):
        merged_df[EEG_O2_A1] = merged_df[EEG_O2] - merged_df[EEG_A1]    
    
    merged_df = (merged_df - merged_df.mean()) / merged_df.std()
    return merged_df



def process(edf_fp: str, label_fp: str, output_fp: str, overwrite: bool = False) -> bool:
    """Process night of data."""
    if os.path.exists(output_fp) and not overwrite:
        logger.debug(f'Skipping {edf_fp=}, {output_fp=}, already exists')
        return False
    else:
        os.makedirs(os.path.dirname(output_fp), exist_ok=True)
        
    # Process labels
    if label_fp.endswith('.xml'):
        try:
            labels = parse_xml_annotations(label_fp) # parse sleep stages
            all_df = parse_all_annotations(label_fp) # parse all other events (arousals, respiratory)
        except Exception as e:
            logger.error(f'Failed to parse: {label_fp}.')
            logger.error(e)
            return False
    else:
        labels = parse_txt_annotations(fp=label_fp)
        # NOTE: If we end up using a dataset with txt annotations, we will want to write another function to parse all other events
        if labels is None:
            logger.error(f'Failed to parse: {label_fp}.')
            return False
    labels = labels.reindex(TARGET_LABEL_INDEX).fillna(-1) # these are sleep stage labels
    # Check for N1, N3 or REM presence. (Recordings with just sleep-wake typically use N2 as sole sleep class)
    stage_counts = labels.value_counts()
    if stage_counts.get(1.0) is None and stage_counts.get(3.0) is None and stage_counts.get(4.0) is None: 
        logger.error(f'No N1, N3 or REM in {label_fp}.')
        output_fp = output_fp.replace('.parquet', 'sleepstage.issues.parquet') # note these are still useful since sleep stages are not strictly necessary for captions
    
    edf = load_edf_data(edf_fp, columns=EDF_COLS, raise_on_missing=False)
    waveform_df = process_edf(edf)
    output_df = pd.concat([waveform_df, labels], axis=1)
    
    
    # output_df = annotate_waveform(output_df, [all_df])
    
    # output_df.to_parquet(output_fp)
    return output_df, all_df

In [4]:

output_path = f'./test/test_{DATA_FOR_CHECK}/test.parquet'
output_df, all_df = process(edf_path, annotation_path, output_path)

<class 'pandas.core.indexes.numeric.Float64Index'> <class 'pandas.core.indexes.numeric.Float64Index'>
col: ECG length: (3686400, 1)
<class 'pandas.core.indexes.numeric.Float64Index'> <class 'pandas.core.indexes.numeric.Float64Index'>
col: HR length: (28800, 1)
<class 'pandas.core.indexes.numeric.Float64Index'> <class 'pandas.core.indexes.numeric.Float64Index'>
col: SPO2 length: (28800, 1)
<class 'pandas.core.indexes.numeric.Float64Index'> <class 'pandas.core.indexes.numeric.Float64Index'>
col: OX length: (28800, 1)
<class 'pandas.core.indexes.numeric.Float64Index'> <class 'pandas.core.indexes.numeric.Float64Index'>
col: ABD length: (230400, 1)
<class 'pandas.core.indexes.numeric.Float64Index'> <class 'pandas.core.indexes.numeric.Float64Index'>
col: THX length: (230400, 1)
<class 'pandas.core.indexes.numeric.Float64Index'> <class 'pandas.core.indexes.numeric.Float64Index'>
col: AF length: (230400, 1)
<class 'pandas.core.indexes.numeric.Float64Index'> <class 'pandas.core.indexes.numeric.

In [5]:
all_df

Unnamed: 0,Types,Concepts,Starts,Ends,Signal
0,Limb Movement,Limb movement,42.5,47.0,R Leg
1,Limb Movement,Limb movement,61.4,62.2,R Leg
2,Limb Movement,Limb movement,94.2,96.0,R Leg
3,Limb Movement,Limb movement,110.8,114.7,L Leg
4,Limb Movement,Limb movement,116.7,118.9,R Leg
...,...,...,...,...,...
1352,Limb Movement,Limb movement,41180.8,41181.3,R Leg
1353,Limb Movement,Limb movement,41185.2,41186.1,R Leg
1354,Respiratory,SpO2 artifact,41192.8,41236.0,SpO2
1355,Limb Movement,Limb movement,41199.6,41200.5,L Leg


## Step 2: Sanity Check for the pre-processed data

In [None]:
df = pd.read_parquet(output_path)

In [None]:
df.head(128)

In [None]:
print(pd.isna(df['EMG_LLeg_events']))

In [None]:
print(len(df))

df_clean = df.dropna(how="all") 

print(len(df_clean))

In [None]:
df.describe()

In [None]:
print(df['ECG'])
print(df['ECG'].dropna())
print(df.columns)
print(df['SPO2_events'].unique())
print(df['Stage'].unique())

## Step 3: Check Each Channel in the Time Domain

In [None]:
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
def format_overlapping_events(events_df, event_cols):
    # Create a new column for overlapping events
    events_df['Multiple Events'] = None

    # For each row, check for overlapping events
    for idx in events_df.index:
        row_events = []
        for col in event_cols:
            if pd.notna(events_df.loc[idx, col]):
                row_events.append(events_df.loc[idx, col])
        
        # If multiple events, combine them and set original columns to NaN
        if len(row_events) > 1:
            events_df.loc[idx, 'Multiple Events'] = ' and '.join(row_events)
            for col in event_cols:
                events_df.loc[idx, col] = None

    # Add overlapping_events to event_cols
    event_cols.append('Multiple Events')

    return events_df, event_cols

def get_colormap(events_df, event_cols):
    custom_colors = [
        '#800000',  # Maroon (Dark Red)
        '#457B9D',  # Steel Blue
        '#556B2F',  # Dark Olive Green
        '#533100',  # Brown
        '#581845',  # Dark Purple
        '#2F4F4F',  # Dark Slate Gray'
        '#CD5C5C',  # Indian Red (a bit lighter, but distinct)
        '#DAA520',  # Goldenrod
        '#40E0D0',  # Turquoise
    ]
    all_events = []
    for col in event_cols:
        all_events.extend(events_df[col].dropna().unique())
    unique_events = np.unique(all_events)
    event_colors = dict(zip(unique_events, custom_colors[:len(unique_events)]))
    return event_colors
def plot_channels_annotations_sleep_stage(df_input, channel_cols, stage_col = 'Sleep Stage', stage_labels=None):
    """
    Plot multiple channels of sensor data with highlighted event intervals. NOTE: sleep stage should be bfilled to propagate 30-sec epoch stage to all timepoints in the interval

    Parameters:
    -----------
    df_input : pandas.DataFrame
        DataFrame containing the data with channel and channel_events annotation columns
    channel_cols : list
        List of channel column names to plot. Assumes annotation columns with _events exists for some channels.
    stage_col : str 
        Name of sleep stage column
    """

    n_channels = len(channel_cols)
    fig, axes = plt.subplots(n_channels, 1, figsize=(18, 5*n_channels), sharex=True)

    # First plot the event annotation boxes that will span all subplots
    event_cols = [col for col in df_input.columns if col.endswith('_events')]
    # Create a copy of df with just event columns
    events_df = df_input[event_cols].copy()
    # Create copy of df, might manipulate 
    df = df_input.copy()

    # Format overlapping events in a new column 'Multiple Events'
    events_df, event_cols = format_overlapping_events(events_df, event_cols)

    # Create a color map for all unique events across all columns
    event_colors = get_colormap(events_df, event_cols)

    # --- Plot annotation events
    for event_col in event_cols:
        # Get unique events in this window
        events = events_df[event_col].dropna().unique()
        for event in events:
            # Find start and end times for each event
            event_mask = events_df[event_col] == event
            event_starts = events_df.index[event_mask & ~event_mask.shift(1).fillna(False)]
            event_ends = events_df.index[event_mask & ~event_mask.shift(-1).fillna(False)]
            
            # Add transparent box for each event occurrence spanning all subplots
            for start_time, end_time in zip(event_starts, event_ends):
                # Add to first subplot to avoid duplicate legend entries+_
                span = axes[0].axvspan(start_time, end_time, color=event_colors[event], alpha=0.3, label=f'{event_col}: {event}')
                # Add same span to all other subplots without labels
                for ax in axes[1:]:
                    ax.axvspan(start_time, end_time, color=event_colors[event], alpha=0.3)

    # --- Plot each channel in its own subplot
    for idx, (channel, ax) in enumerate(zip(channel_cols, axes)):
        if channel in df.columns:
            # Plot non-NaN values only
            valid_data = df[channel].dropna()
            if not valid_data.empty:
                ax.plot(valid_data.index, valid_data.values, linewidth=0.5)
                ax.set_ylabel(channel)
                # Add grid
                #ax.grid(True, alpha=0.3)
                
                ax.set_xlabel('Time (seconds)')
                ax.tick_params(axis='x', labelbottom=True)

    # --- get handles and labels ready to plot legend of annotation events
    handles, labels = axes[0].get_legend_handles_labels()
    if handles:
        # Get unique handles and labels while preserving order
        unique_labels = []
        unique_handles = []
        for h, l in zip(handles, labels):
            if l not in unique_labels:
                unique_labels.append(l)
                unique_handles.append(h)

    # --- Sleep Stages
    # Optional: define readable labels
    if stage_labels is None:
        # Get unique sleep stages
        unique_stages = df[stage_col].dropna().unique()
        stage_labels = {int(k): 'Stage' + str(int(k)) for k in unique_stages}
    else:
        # map sleep stage numbers to labels 
        df[stage_col] = df[stage_col].map(stage_labels)
        # Get unique sleep stages
        unique_stages = df[stage_col].dropna().unique()
    # Overlay sleep stage bar at bottom of every axis (use height just below lowest frequency)
    stage_band_y_axes = {}
    band_height_axes = {}
    for i, ax in enumerate(axes):
        ymin, ymax = ax.get_ylim()
        stage_band_y_axes[i] = ymin - ((ymax-ymin) * 0.05)  # slight offset below
        band_height_axes[i] = ((ymax-ymin) * 0.05)
    # Custom color map
    color_map = ListedColormap(plt.cm.tab10.colors[:len(unique_stages)])
    stage_to_idx = {stage: idx for idx, stage in enumerate(unique_stages)} # map stage number to index number to use to access color_map
    # Construct colored rectangles (one per time bin)
    for j, (channel, ax) in enumerate(zip(channel_cols, axes)):
        df_subset = df[[channel, stage_col]].dropna()
        stage_values = df_subset[stage_col].values
        time_values = df_subset.index.values
        # Detect stage changes
        changes = np.where(stage_values[:-1] != stage_values[1:])[0] + 1 # indices where stage changes
        segments = np.split(np.arange(len(stage_values)), changes) # split into segments of continuous sleep stages
        # Draw one rectangle per segment
        for seg in segments:
            stage = stage_values[seg[0]]
            ax.add_patch(plt.Rectangle(
                (time_values[seg[0]], stage_band_y_axes[j]), # anchor
                time_values[seg[-1]] - time_values[seg[0]], # width
                band_height_axes[j], # height
                color=color_map(stage_to_idx[stage]),
                linewidth=0
            ))

    # --- Legend for sleep stages
    legend_patches = [plt.Line2D([0], [0], color=color_map(stage_to_idx[stage]), lw=6)
                      for stage in unique_stages] # dummy lines for legend
    #legend_labels = [stage_labels[int(stage)] for stage in unique_stages]
    legend_labels = unique_stages
    for j, ax in enumerate(axes):
        ymin, ymax = ax.get_ylim()
        legend1 = ax.legend(legend_patches, legend_labels, loc='center left', bbox_to_anchor=(1.0, 0.5),
                title="Sleep Stage", fontsize='small')
        ax.add_artist(legend1) # add first legend as artist
        ax.set_ylim(stage_band_y_axes[j], ymax)
    
    # --- Legend for annotation events
    if handles:
        for ax in axes:
            ax.legend(unique_handles, unique_labels, loc='upper left', bbox_to_anchor=(1, 1))

    plt.tight_layout()
    return fig

In [None]:
print(df.columns)
# df['Sleep Stage'] = 1 # just make a column

# real_COLS = df.columns.to_list()
# real_COLS.remove('Stage')

real_COLS = ['ECG', 'HR', 'SPO2', 'OX', 'ABD', 'THX', 'AF', 'NP', 'SN', 'EMG_LLeg',
       'EMG_RLeg', 'EMG_LChin', 'EMG_RChin', 'EMG_CChin', 'EOG_L', 'EOG_R']
print(real_COLS)
#####################################

stage_labels={
    0: 'Awake', 
    1: 'Light Sleep', 
    2: 'Light Sleep', 
    3: 'Deep Sleep', 
    4: 'REM'
}
SPO2_labels={
    'SpO2 artifact': 'SpO2 artifact', 
    'SpO2 desaturation': 'SpO2 desaturation'
}

df['Stage'] = df['Stage'].bfill()
df['SPO2_events'] = df['SPO2_events'].bfill()

df_show = df[100:140]
# might want to subset df to a smaller window here
fig = plot_channels_annotations_sleep_stage(df_show, real_COLS, 'Stage', stage_labels)
# fig = plot_channels_annotations_sleep_stage(df_show, real_COLS, 'SPO2_events', SPO2_labels)

plt.show()
# EDF_COLS are the channel column names you want to plot 
# 'Sleep Stage' is the name of the sleep stage column
# stage_labels is the dict I sent you above, grouping stages as awake, light sleep, deep sleep, or REM