In [None]:
import os
import mne
import json
import h5py
import numpy as np
import scipy
from datetime import datetime
from scipy.signal import butter, filtfilt, lfilter
from scipy.signal import butter, lfilter, sosfilt, iirnotch
from sklearn.preprocessing import scale
from tqdm import tqdm
from scipy.signal import spectrogram, get_window

In [None]:
def dsToDt(ds):
    """Parse an annotation timestamp like '2020/05/01 10:30:05 PM' into a naive datetime.

    The expected format is '%Y/%m/%d %I:%M:%S %p' (12-hour clock with AM/PM).
    """
    return datetime.strptime(ds, '%Y/%m/%d %I:%M:%S %p')


def start_time_diff(ann, edf):
    '''
    Whole seconds elapsed from the EDF recording start to the annotation start.

    * Input
        ann: annotation start time, string parsed by ``dsToDt``
             (format '%Y/%m/%d %I:%M:%S %p')
        edf: measurement date (EDF file start time), datetime;
             any tzinfo is dropped before comparing
    * Output
        time_diff: non-negative int, seconds (fractional part truncated)
    * Raises
        ValueError: if the annotation starts before the EDF recording
    '''
    ann_start = dsToDt(ann)
    edf_start = edf.replace(tzinfo=None)

    # Expected: edf_start <= ann_start for every recording.
    # total_seconds() keeps the day component that timedelta.seconds silently drops,
    # and a negative difference is reported instead of wrapping to ~86400 s.
    diff = (ann_start - edf_start).total_seconds()
    if diff < 0:
        raise ValueError(
            f"annotation start {ann_start} precedes EDF start {edf_start}"
        )
    return int(diff)


def matching_edf_ann(signal, t, s_rate, time_diff, num_epoch, epoch_sec=30):
    '''
    Crop the EDF signal to the annotated epochs that fit completely inside it.

    * Input
        signal: edf signal, indexed as (channels, samples)
        t: time vector aligned with the sample axis of ``signal``
        s_rate: sampling rate (Hz), int or float
        time_diff: start time difference (s) between annotation and edf files, >= 0
        num_epoch: number of epochs in annotation
        epoch_sec: epoch length in seconds (default 30)
    * Output
        mat_signal: matched signal data, signal[:, start:end]
        mat_t: matched time
        num_epoch_clip: matched number of epochs (int, >= 0)
    * Raises
        ValueError: if time_diff is negative (would silently slice from the tail)
    '''
    if time_diff < 0:
        raise ValueError(f"time_diff must be non-negative, got {time_diff}")

    # Sample index where the annotation starts
    start_idx = int(time_diff * s_rate)
    epoch_len = epoch_sec * s_rate

    # The signal may be shorter than the annotated num_epoch: keep only complete epochs.
    # Clamp at 0 for annotations that start past the end of the signal.
    available = (signal.shape[-1] - start_idx) // epoch_len
    num_epoch_clip = int(max(0, min(num_epoch, available)))

    end_idx = int(start_idx + num_epoch_clip * epoch_len)

    mat_signal = signal[:, start_idx:end_idx]
    mat_t = t[start_idx:end_idx]

    return mat_signal, mat_t, num_epoch_clip


def convert_Hz(input_data, from_hz=200, to_hz=100, epoch_size=30, show_plot=False):
    '''
    Resample the last axis of ``input_data`` from ``from_hz`` to ``to_hz`` by linear interpolation.

    * Input
        input_data: array (1-D or (channels, samples)); resampling runs along the last axis,
                    and samples beyond ``from_hz * epoch_size`` are ignored
        from_hz: sampling rate of input_data (Hz)
        to_hz: target sampling rate (Hz)
        epoch_size: duration (s) covered by the output
        show_plot: unused; kept for backward compatibility
    * Output
        converted_data: array with ``to_hz * epoch_size`` samples on the last axis
    '''
    n_in = np.shape(input_data)[-1]

    # Position of every output sample, expressed in input-sample units.
    pos = np.arange(to_hz * epoch_size) / to_hz * from_hz
    left_idx = np.floor(pos).astype(np.int64)
    # Clamp the right neighbour: when upsampling, the last output positions fall
    # between the final input sample and one past the end.
    right_idx = np.minimum(np.ceil(pos).astype(np.int64), n_in - 1)

    # Linear-interpolation weights; exact positions get (1, 0).
    right_wght = pos - left_idx
    left_wght = 1.0 - right_wght

    converted_data = input_data[..., left_idx] * left_wght + input_data[..., right_idx] * right_wght
    return converted_data



def butter_bandpass_filter_seoul(data, lowcut, highcut, fs, order=8):
    """Causal Butterworth band-pass (Seoul pipeline), applied in second-order sections.

    ``lowcut``/``highcut`` are band edges in Hz (``fs`` is passed to ``butter`` so
    no Nyquist normalisation is needed). Filtering runs along the last axis of ``data``.
    """
    band_sos = butter(order, [lowcut, highcut], btype='band', fs=fs, output='sos', analog=False)
    return sosfilt(band_sos, data)


def butter_bandpass_filter_hallym(signals, lowcut, highcut, fs, order=4, bandstop=None):
    '''
    Optional IIR notch followed by a causal Butterworth band-pass (Hallym pipeline).

    * Input
        signals: array to filter; filtering runs along the last axis, so a
                 (n_epochs, samples) array filters each row independently
        lowcut, highcut: band-pass edges (Hz)
        fs: sampling rate (Hz)
        order: Butterworth order
        bandstop: notch frequency in Hz, or None to skip the notch
                  (a bool is treated as a flag for the 60 Hz notch)
    * Output
        y: filtered array, same shape as ``signals``
    '''
    if bandstop is not None:
        notch_freq = 60.0 if isinstance(bandstop, bool) else float(bandstop)
        quality_factor = 30.0
        b_notch, a_notch = iirnotch(notch_freq, quality_factor, fs)
        # Zero-phase notch (forward-backward)
        signals = filtfilt(b_notch, a_notch, signals)

    nyq = 0.5 * fs
    b, a = butter(N=order, Wn=[lowcut / nyq, highcut / nyq], btype='bandpass', analog=False, output='ba')

    # Causal band-pass (single pass), unlike the zero-phase notch above
    y = lfilter(b, a, signals)
    return y


def notch_filter(data, f0, Q, fs):
    """Causal IIR notch at ``f0`` Hz with quality factor ``Q``, along the last axis of ``data``."""
    return lfilter(*iirnotch(f0, Q, fs), data)

In [None]:
# Input locations: the sleep-stage annotation JSON and the root of the per-patient EDF tree.
ann_path, edf_path = (
    '/tf/data/#_2020_Sleep_Quality/1_PSG_210426.json',
    '/tf/data/#_2020_Sleep_Quality/1_PSG/',
)

In [None]:
# JSON text is UTF-8 (RFC 8259); say so explicitly instead of relying on the platform locale.
with open(ann_path, 'r', encoding='utf-8') as f:
    ann = json.load(f)

In [None]:
# Per-patient annotation records; each one is read below through the keys
# 'Report', 'Event' and 'Num_of_Image(epoch)'.
patients = ann['Patient']

In [None]:
# Epoching / spectrogram settings
epoch_second = 30   # epoch length (s)
fs = 100            # sampling rate (Hz) used for the spectrogram of the 100 Hz epochs
win_size = 2        # spectrogram window length (s)
overlap = 1         # spectrogram window overlap (s)
# FFT length: next power of two >= window length in samples
nfft = (2 ** np.ceil(np.log2(fs * win_size))).astype(int)

offset = 1          # re-initialised per patient inside the preprocessing loop

In [None]:
# Number of patient records in the annotation file (rendered as the cell output).
len(patients)

In [None]:
root_path = "/tf/data2/jmjeong/prep_sg_snuh_mat/"
os.makedirs(root_path, exist_ok=True)

# Annotation 'Event_Label' -> integer sleep stage; events with any other label are ignored.
STAGE_TO_LABEL = {'Wake': 0, 'N1': 1, 'N2': 2, 'N3': 3, 'REM': 4}
# Candidate EEG derivations, tried in this order.
EEG_CHANNELS = ['C4-A1', 'C4-M1']


def _stage_labels(events, num_epoch):
    """Per-epoch stage labels built from one patient's annotation events.

    Index 0 corresponds to the 'Start_Epoch' of the first sleep-stage event;
    epochs not covered by any stage event stay -1.

    Returns (labels, stage_start_time): a float array of length ``num_epoch`` and
    the 'Start_Time' string of the first sleep-stage event (None if there is none).
    """
    labels = -1 * np.ones((num_epoch,))
    offset = None
    stage_start_time = None
    for event in events:
        status = STAGE_TO_LABEL.get(event['Event_Label'])
        if status is None:
            continue
        if offset is None:
            offset = event['Start_Epoch']
            stage_start_time = event['Start_Time']
        # Start_Epoch appears to count from 1 -> re-base on the first stage event.
        labels[event['Start_Epoch'] - offset:event['End_Epoch'] - offset] = status
    return labels, stage_start_time


def _pick_channel(raw, candidates):
    """(data, times) of the first channel in ``candidates`` present in ``raw``; None if none is."""
    for name in candidates:
        if name in raw.ch_names:
            return raw.get_data(picks=name, return_times=True)
    return None


def _log_spectrograms(epochs):
    """Per-epoch spectrograms in dB stacked as (n_epochs, n_frames, nfft // 2 + 1).

    Uses the notebook settings ``fs``, ``win_size`` and ``overlap``.
    """
    win_len = win_size * fs
    n_fft = 2 ** np.ceil(np.log2(win_len)).astype(int)
    window = get_window('hamming', win_len)
    specs = []
    for epoch in epochs:
        _, _, sxx = spectrogram(epoch, fs=fs, window=window, noverlap=overlap * fs, nfft=n_fft)
        # NOTE(review): sxx is a power spectral density (scipy default mode='psd'), for which
        # 10*log10 is the usual dB scale; 20*log10 is kept so features match earlier runs.
        specs.append(20 * np.log10(np.abs(sxx)).T)
    return np.stack(specs)


name_idx = 1
for idx, p in enumerate(patients):
    try:
        patient_snum = p['Report']['Patient Serial Number']
        branch = '-'.join(patient_snum.split('-')[:-1])
        edf_filename = os.path.join(edf_path, branch, patient_snum, patient_snum + '_edf.edf')
        raw = mne.io.read_raw_edf(edf_filename)

        edf_start = raw.info['meas_date']
        sampling_rate = raw.info['sfreq']
        num_epoch = p['Num_of_Image(epoch)']

        labels, stage_start_time = _stage_labels(p['Event'], num_epoch)
        if stage_start_time is None or sum(labels) < 0:
            print("Error file!")
            continue

        # labels[0] is the epoch of the first sleep-stage event, so the EDF is aligned to
        # that event's start time rather than to the first event of any kind.
        timediff = start_time_diff(stage_start_time, edf_start)

        picked = _pick_channel(raw, EEG_CHANNELS)
        if picked is None:
            # Skip explicitly instead of silently reusing the previous patient's signal.
            print(f"{idx}: none of {EEG_CHANNELS} found, skipped")
            continue
        eeg, eeg_times = picked

        mat_signal, _, num_epoch_clip = matching_edf_ann(eeg, eeg_times, sampling_rate, timediff, num_epoch)
        num_epoch_clip = int(num_epoch_clip)
        if num_epoch_clip <= 0:
            print(f"{idx}: no complete epoch after alignment, skipped")
            continue
        labels = labels[:num_epoch_clip]

        # Resample the continuous signal to 200 Hz
        if sampling_rate != 200:
            mat_signal = convert_Hz(mat_signal, from_hz=sampling_rate, to_hz=200, epoch_size=num_epoch_clip * 30)

        # Cut into 30 s epochs (rows) and keep only the labeled ones
        target_idx = np.where(labels != -1)[0]
        if target_idx.size == 0:
            print(f"{idx}: no labeled epoch, skipped")
            continue
        mat = np.reshape(mat_signal, [-1, epoch_second * 200])[target_idx]
        stages = labels[target_idx]

        # prep: hallym -- 0.3-35 Hz band-pass + 60 Hz notch; filtering runs along the
        # last axis, so every epoch (row) is still filtered independently.
        mat = butter_bandpass_filter_hallym(signals=mat, lowcut=0.3, highcut=35, fs=200, order=4, bandstop=60)

        # 200 Hz -> 100 Hz. The data is always 200 Hz here, so the conversion is
        # unconditional; each row is exactly one epoch of epoch_second seconds.
        mat = convert_Hz(mat, from_hz=200, to_hz=100, epoch_size=epoch_second)

        stages_one_hot = np.eye(len(STAGE_TO_LABEL))[stages.astype(int)]

        mat_eeg = _log_spectrograms(mat)

        # Drop epochs whose spectrogram is not finite (e.g. log10 of an all-zero bin -> -inf).
        bad = ~np.isfinite(mat_eeg.reshape(len(mat_eeg), -1)).all(axis=1)
        count = int(bad.sum())
        if count > 0:
            print(f"{idx}: {count} inf epochs removed")
            keep = ~bad
            stages = stages[keep]
            stages_one_hot = stages_one_hot[keep, :]
            mat = mat[keep, :]
            mat_eeg = mat_eeg[keep, :, :]

        assert np.sum(np.isnan(mat_eeg)) == 0, 'NaN'
        assert np.sum(np.isinf(mat_eeg)) == 0, 'Inf'

        out_file = os.path.join(root_path, f"n{name_idx:04d}_eeg.mat")
        scipy.io.savemat(out_file, {
            'X1': mat.astype(np.float32),                             # filtered EEG epochs at 100 Hz
            'X2': mat_eeg.astype(np.float32),                         # per-epoch spectrograms (dB)
            'label': (stages.reshape(-1, 1) + 1).astype(np.float32),  # stages shifted to 1..5
            'y': stages_one_hot.astype(np.float32),                   # one-hot stages
        }, do_compression=True)
        print(out_file)
        name_idx += 1

    except Exception as exc:
        # Best-effort batch run: report the failing patient instead of hiding it.
        print(f"{idx}: skipped ({type(exc).__name__}: {exc})")