## SpO2 Features Extraction
This code calculates the SpO2 and respiration-rate features for each participant from their EDF recordings and XML annotation files, and writes the results to a CSV file. Note that the main function requires the EDF files to be stored in a folder titled "clean-edf" and the annotation files to be stored in a folder titled "edf-annotations".

In [1]:
# Load all necessary functions and packages 

from collections import defaultdict
import datetime
import os
import shutil
import zipfile
import pickle
import numpy as np
import pandas as pd
import scipy.io as sio
from scipy.signal import resample
from tqdm import tqdm
import neurokit2 as nk
import mne
import sys
from itertools import groupby
from scipy.signal import resample
import xmltodict
sys.path.insert(0, '..')
sys.path.insert(0, 'hypoxic_burden')
from hb_functions import detect_oxygen_desaturation, calc_hypoxic_burden, get_spo2_no_desat, filter_spo2

In [2]:
# Compute sleep metrics based on sleep stages. Takes an array of stages and returns a dictionary, r, of sleep metrics, including:
# tib = time in bed 
# tst = total sleep time 
# se = sleep efficiency 
# sl = sleep latency 

# Sleep macrostructure extraction from sleep stage array
def _stage_percent(mask):
    """Return the percentage of True entries in *mask*, or NaN when *mask* is empty."""
    if mask.size == 0:
        return np.nan
    return mask.mean()*100


def get_macrostructures(sleep_stages, epoch_time=30):
    """Compute sleep macrostructure metrics from an epoch-level hypnogram.

    Stage codes: 1 = N3, 2 = N2, 3 = N1, 4 = REM, 5 = wake. Any other value
    (including NaN) is ignored by every stage-specific metric, but still
    counts towards the total number of epochs.

    Assumes sleep_stages covers lights off to lights on only, with one value
    per epoch.

    Parameters
    ----------
    sleep_stages : array-like of numbers
        One stage code per epoch. Lists and integer/float arrays are accepted.
    epoch_time : number, default 30
        Epoch length; the /3600 and /60 conversions below assume seconds.

    Returns
    -------
    dict
        tib, tst (hours), se (%), waso, sl, rl (minutes), time per stage
        (minutes), percentage of all epochs per stage (whole night and each
        half), probability of staying in the same stage from one epoch to the
        next, and the longest bout per sleep stage (minutes). Metrics that
        cannot be computed (e.g. no sleep, empty input) are NaN.
    """
    # Coerce once so lists and integer arrays behave like float arrays below.
    stages = np.asarray(sleep_stages, dtype=float)
    n_epochs = len(stages)

    # First part pulls out tib, tst, se, waso, sl, and rl
    r = {}
    r['tib'] = n_epochs*epoch_time/3600
    sleep_ids = np.flatnonzero(np.isin(stages, [1, 2, 3, 4]))
    r['tst'] = len(sleep_ids)*epoch_time/3600
    # Guard against an empty recording: 0.0/0.0 on Python floats would raise.
    r['se'] = r['tst']/r['tib']*100 if r['tib'] > 0 else np.nan
    if len(sleep_ids) > 0:
        # Wake is only counted between the first and the last sleep epoch.
        in_sleep_period = stages[sleep_ids[0]:sleep_ids[-1]+1]
        r['waso'] = np.sum(in_sleep_period == 5)*epoch_time/60
        r['sl'] = sleep_ids[0]*epoch_time/60
    else:
        r['waso'] = np.nan
        r['sl'] = np.nan
    rem_ids = np.flatnonzero(stages == 4)
    # REM latency is measured from the first epoch, not from sleep onset.
    r['rl'] = rem_ids[0]*epoch_time/60 if len(rem_ids) > 0 else np.nan

    # Time in each stage
    r['w_time'] = (stages == 5).sum()*epoch_time/60
    r['r_time'] = (stages == 4).sum()*epoch_time/60
    r['n1_time'] = (stages == 3).sum()*epoch_time/60
    r['n2_time'] = (stages == 2).sum()*epoch_time/60
    r['n3_time'] = (stages == 1).sum()*epoch_time/60

    # Percent of all epochs spent in each stage, for the whole night and then
    # for the first and second half separately. Insertion order is kept
    # because callers may use the key order for column order.
    stage_codes = (('r', 4), ('n1', 3), ('n2', 2), ('n3', 1))
    half = n_epochs//2
    for suffix, part in (('', stages), ('_half1', stages[:half]), ('_half2', stages[half:])):
        for name, code in stage_codes:
            r[f'{name}_perc{suffix}'] = _stage_percent(part == code)

    # Transition probability matrix: count epoch-to-epoch transitions between
    # scored stages, then normalise each row so entries are conditional on the
    # stage being left.
    transmat = np.zeros((5, 5))
    scored = np.isin(stages, [1, 2, 3, 4, 5])
    both_scored = scored[:-1] & scored[1:]
    src = stages[:-1][both_scored].astype(int) - 1
    dst = stages[1:][both_scored].astype(int) - 1
    np.add.at(transmat, (src, dst), 1)
    # A stage that never occurs has an all-zero row; 0/0 is left as NaN.
    with np.errstate(invalid='ignore', divide='ignore'):
        transmat = transmat/transmat.sum(axis=1, keepdims=True)
    r['n3_continue_prob'] = transmat[0, 0]
    r['n2_continue_prob'] = transmat[1, 1]
    r['n1_continue_prob'] = transmat[2, 2]
    r['r_continue_prob'] = transmat[3, 3]
    r['w_continue_prob'] = transmat[4, 4]

    # Bout duration: longest continuous period spent in each sleep stage.
    # NaN never equals itself, so non-finite values are replaced by -1 to let
    # groupby see a gap as a single run.
    runs = np.where(np.isfinite(stages), stages, -1).tolist()
    longest = {1: 0, 2: 0, 3: 0, 4: 0}
    for code, run in groupby(runs):
        if code in longest:
            longest[code] = max(longest[code], sum(1 for _ in run)*epoch_time/60)
    r['n1_bout_dur'] = longest[3]
    r['n2_bout_dur'] = longest[2]
    r['n3_bout_dur'] = longest[1]
    r['r_bout_dur'] = longest[4]

    return r


In [3]:
# Calculate respiration rate (RR) from chest and abdomen signals using neurokit2

def get_resp_rate(abd, chest, Fs):
    """
    Estimate the respiration rate (RR) of one PSG from its effort belts.

    abd = abdomen signal, array
    chest = chest signal, array
    Fs = sampling frequency in Hz of the two signals

    Each belt is run through the neurokit2 pipeline (clean -> peak detection
    -> peak correction -> rate). Rates above 30 or below 5 are set to NaN.
    If any step fails for a belt, that belt's rate is all NaN. The rate from
    the belt with the smaller fraction of NaN samples is returned (the
    abdomen belt wins ties).
    """
    candidates = []
    nan_fractions = []

    for signal in (abd, chest):
        try:
            signal = nk.rsp_clean(signal, sampling_rate=Fs)
            _, peaks = nk.rsp_peaks(signal, sampling_rate=Fs)
            peaks = nk.rsp_fixpeaks(peaks)
            rate = nk.rsp_rate(signal, peaks, sampling_rate=Fs)
            # Discard implausible values
            implausible = (rate > 30) | (rate < 5)
            rate[implausible] = np.nan
        except Exception:
            # Best effort: an unusable belt contributes an all-NaN trace
            rate = np.zeros_like(signal) + np.nan
        nan_fractions.append(np.isnan(rate).mean())
        candidates.append(rate)

    best = np.argmin(nan_fractions)
    return candidates[best]


In [4]:
# Main function that applies all pre-defined functions to each edf file and extracts SpO2 features for each participant from edf files 
# Stores results in CSV 

def _stage_index(stages, sn):
    """Build a NumPy index selecting the samples of *stages* for selector *sn*.

    sn may be:
      * a list of stage codes -> samples whose stage is in the list
      * 0                     -> every sample
      * 1..5                  -> samples scored with exactly that stage code
      * 6                     -> samples before the first sleep sample (codes 1-4)
      * 7                     -> samples from the first sleep sample to the end
    For 6 and 7 an empty index is returned when there is no sleep at all.
    """
    if isinstance(sn, list):
        return np.isin(stages, sn)
    if sn == 0:
        return np.ones(len(stages), dtype=bool)
    if sn in (6, 7):
        sleep_indices = np.flatnonzero(np.isin(stages, [1, 2, 3, 4]))
        if len(sleep_indices) == 0:
            return np.array([], dtype=int)
        if sn == 6:
            return np.arange(sleep_indices[0])
        return np.arange(sleep_indices[0], len(stages))
    return stages == sn


def main():
    """Extract sleep, SpO2 and respiration-rate features for every participant.

    Reads 'mastersheet_temp.xlsx' (needs the columns PID, Age, Sex and race).
    For each row it loads 'clean-edf/shhs1-<PID>.edf' and
    'edf-annotations/shhs1-<PID>-nsrr.xml', computes the features, stores the
    intermediate signals in 'intermediate_results/<PID>.pickle', and rewrites
    'all_spo2_features.csv' after every participant. A participant whose
    processing raises is reported on stdout and skipped.
    """
    # Import mastersheet for demographics and PIDs
    df = pd.read_excel('mastersheet_temp.xlsx')

    # Create folders for storing all data generated
    res_dir = 'intermediate_results'
    os.makedirs(res_dir, exist_ok=True)

    pickle_folder ="clean-pickle"
    os.makedirs(pickle_folder, exist_ok=True)

    # Annotation label -> numeric stage code (5 = wake, 4 = REM, 3 = N1, 2 = N2, 1 = N3).
    # NOTE(review): labels not listed here (e.g. a possible 'Stage 4 sleep|4')
    # are left unscored (NaN) -- confirm that this is intended.
    stage_mapping = {'Wake|0':5, 'REM sleep|5':4,  'Stage 1 sleep|1':3, 'Stage 2 sleep|2':2, 'Stage 3 sleep|3':1}
    valid_stage_codes = set(stage_mapping.values())

    # Stage selectors and the suffix used in the feature names (see _stage_index).
    selectors = [0,1,2,3,4,5,[1,2,3],6,[1,2,3,4],7]
    selector_names = ['ALL', 'N3', 'N2', 'N1', 'R', 'W','NREM','WBSO','SLEEP','AFTER_SO']

    resp_ch_names = ['THOR RES', 'ABDO RES']
    target_fs_resp = 10  # Hz, sampling rate the effort belts are brought to

    for i in tqdm(range(len(df))):
        pid = None  # defined up front so the error message below can always be formatted
        try:
            pid = df.PID.iloc[i]
            edf_path = os.path.join('clean-edf', f'shhs1-{pid}.edf')
            print(pid)

            # Header-only read, used just to list the channel names.
            # Note: when loading an EDF using the mne package, all signals are resampled to the highest frequency
            edf = mne.io.read_raw_edf(edf_path, verbose=False, preload=False)
            channel_names = edf.info['ch_names']

            # Load only the SpO2 data
            exclude_list = [x for x in channel_names if x!='SaO2']
            edf_spo2 = mne.io.read_raw_edf(edf_path, verbose=False, exclude=exclude_list)

            data = edf_spo2.get_data()
            ## NOTE: for 91 files, two additional channels were stored for SpO2. Airflow 0 and Airflow 1.

            spo2 = data[0]  # shape = (T,)

            # Bring SpO2 to 1Hz by keeping every int(Fs)-th sample
            Fs = edf_spo2.info['sfreq']
            if Fs!=1:
                spo2 = spo2[::int(Fs)]
                Fs = 1

            # Load respiratory effort belts
            edf_resp = mne.io.read_raw_edf(edf_path, verbose=False, exclude=[x for x in channel_names if x not in resp_ch_names])
            resp = edf_resp.get_data(picks=resp_ch_names)  # shape = (2,T)
            Fs_resp = edf_resp.info['sfreq']

            # Resample to 10Hz if not 10Hz. The target length comes from the
            # number of samples, not from the signal values.
            if Fs_resp!=target_fs_resp:
                n_target = int(round(resp.shape[-1]/Fs_resp*target_fs_resp))
                resp = resample(resp, n_target, axis=-1)
                Fs_resp = target_fs_resp
            chest = resp[0]
            abd = resp[1]

            # Read annotation files to extract scored sleep stage info
            annot_path = f'edf-annotations/shhs1-{pid}-nsrr.xml'
            with open(annot_path, 'r') as f:
                annot = xmltodict.parse(f.read())
            annot = pd.DataFrame(annot['PSGAnnotation']['ScoredEvents']['ScoredEvent'])
            # annot is a pandas dataframe with columns: EventType, EventConcept, Start, Duration...

            annot = annot[annot.EventType=='Stages|Stages'].reset_index(drop=True)
            annot = annot.rename(columns={'Start':'onset', 'Duration':'duration', 'EventConcept':'description'})
            annot['description'] = [stage_mapping.get(d, np.nan) for d in annot['description']]
            annot['onset'] = pd.to_numeric(annot['onset'], errors='coerce')
            annot['duration'] = pd.to_numeric(annot['duration'], errors='coerce')

            # Paint the scored stages onto a vector with one value per SpO2 sample (1Hz).
            # Samples not covered by a recognised stage stay NaN.
            sleep_stages = np.zeros(len(spo2)) + np.nan
            for description, onset, duration in zip(annot['description'], annot['onset'], annot['duration']):
                if description not in valid_stage_codes:
                    continue  # unmapped label (NaN)
                start = max(0, int(round(onset*Fs)))
                end = min(len(sleep_stages), int(round((onset+duration)*Fs)))
                if start<end:
                    sleep_stages[start:end] = description

            # Downsample to one value per 30-second epoch for the macrostructure metrics
            epoch_stages = sleep_stages[::30]
            feat1 = get_macrostructures(epoch_stages)
            tst = feat1['tst']

            # Hypoxic burden, using the helpers imported from hb_functions
            feat2 = {}
            spo2 = filter_spo2(spo2, Fs, verbose=False)
            try:
                od = detect_oxygen_desaturation(spo2, is_plot=False, max_duration=90)
            except Exception as ee:
                print(pid, str(ee))
                od = []

            # With desaturations: use each event's midpoint as its event time.
            # Without: assign NaN and use the filtered SpO2 as the "no desat" signal.
            if len(od)>0:
                event_times = od.Start.values+od.Duration.values/2
                feat2['hb_desat'], hb_response_desat = calc_hypoxic_burden(spo2, event_times, 1, tst)
                spo2_nodesat = get_spo2_no_desat(spo2, od)
            else:
                feat2['hb_desat'] = np.nan
                hb_response_desat = np.nan
                spo2_nodesat = spo2

            # SpO2 metrics per stage selection
            for sn, ss in zip(selectors, selector_names):
                ids = _stage_index(sleep_stages, sn)
                feat2['%spo2<95%_'+ss] = np.nanmean(spo2[ids]<95)*100 # SPO2 levels below 95%
                feat2['%spo2<90%_'+ss] = np.nanmean(spo2[ids]<90)*100 # SPO2 below 90%
                feat2['avg_spo2_'+ss] = np.nanmean(spo2[ids]) # AVG spo2
                feat2['avg_spo2_no_desat_'+ss] = np.nanmean(spo2_nodesat[ids]) # AVG no desats

            # Difference between after sleep onset and wake before sleep onset
            for fn in ['%spo2<95%', '%spo2<90%','avg_spo2','avg_spo2_no_desat']:
                feat2[fn+'_diff'] = feat2[fn+'_AFTER_SO']-feat2[fn+'_WBSO']

            # Respiration rate: repeat the 1Hz stage vector to match the 10Hz
            # belts, then trim everything to a common length.
            resp_stages = np.repeat(sleep_stages, target_fs_resp)
            L = min(len(abd), len(resp_stages))
            abd = abd[:L]
            chest = chest[:L]
            resp_stages = resp_stages[:L]
            resp_rate = get_resp_rate(abd, chest, Fs_resp)

            with open(os.path.join(res_dir, f'{pid}.pickle'), 'wb') as fff:
                pickle.dump({
                    'hb_response_desat':hb_response_desat,
                    'spo2':spo2, 'spo2_nodesat':spo2_nodesat,
                    'resp_rate':resp_rate,
                    'sleep_stages':sleep_stages,
                    }, fff)

            # RR metrics per stage selection. Uses the same selector helper as
            # the SpO2 loop, so a recording without sleep yields NaN here too.
            feat3 = {}
            for sn, ss in zip(selectors, selector_names):
                ids = _stage_index(resp_stages, sn)
                feat3['avg_rr_'+ss] = np.nanmean(resp_rate[ids])
            feat3['avg_rr_diff'] = feat3['avg_rr_AFTER_SO']-feat3['avg_rr_WBSO']

            # Merge the three feature dictionaries and write each feature into
            # this participant's row (keys become column names)
            feats = feat1|feat2|feat3
            for k,v in feats.items():
                df.loc[i, k] = v

            # Participant metadata first, then the feature columns
            cols = ['PID', 'Age', 'Sex', 'race']+list(feats.keys())
            df = df[cols]

            # Save after every participant so partial results survive a later failure
            df.to_csv('all_spo2_features.csv', index=False)
            print(f"Completed: {pid}")
        except Exception as outer_e:
            print(f"Error in processing PID {pid}: {outer_e}")
            continue

In [None]:
# Entry point: run the full extraction only when this file is executed as a script.
# Note: EDF files must be stored in a folder titled "clean-edf",
# and the XML annotation files in a folder titled "edf-annotations".
if __name__=='__main__':
    main()