In [1]:
from snorkel.labeling.apply.dask import PandasParallelLFApplier
from snorkel.labeling.model.label_model import LabelModel
from snorkel.labeling import labeling_function, LFAnalysis
from sklearn.metrics import confusion_matrix
from matrixprofile import matrixProfile
from datetime import datetime, timedelta
import pandas as pd
import numpy as np

In [2]:
# Label mappings: the integer votes a labeling function can return
# (abstain, "not suppressible", "suppressible").
ABSTAIN, NOT_SUPPRESSIBLE, SUPPRESSIBLE = -1, 0, 1

In [3]:
# Read in data: the SpO2 alarm table first, then the vital-sign table, both stored as pickles.
# NOTE(review): unpickling can execute arbitrary code -- only load these files from a trusted source.
alarms_df, vitals_df = (
    pd.read_pickle(pickle_path)
    for pickle_path in ('chop_spo2alarms_df.pkl', 'chop_vitals_df.pkl')
)

In [4]:
# Age-group lookup table for the labeling functions: one row per age group, holding the
# multipliers applied to the heart-rate (hr) and respiratory-rate (rr) thresholds.
age_factors_df = pd.DataFrame(
    [
        (1, '< 1 month', 3.833, 0.933),
        (2, '1-< 2 month', 3.766, 0.9),
        (3, '2-< 6 month', 3.733, 0.866),
        (4, '6 months and older', 3.533, 0.8),
    ],
    columns=['pt_age_group', 'pt_age_group_L', 'hr_age_factor', 'rr_age_factor'],
    index=[1, 2, 3, 4],
)

In [5]:
# Matrix profile related functions

# Silence NumPy's divide-by-zero and invalid-value floating-point warnings for every
# later NumPy call in this session (the setting is global, not limited to this cell).
# The matrix profile computation below is the stated reason for this setting.
np.seterr(divide='ignore', invalid='ignore')    # setting for mp

def compute_std(ts, window):
    """
    Compute the moving (population) standard deviation of a series.

    One value is produced for every full window ``ts[i:i + window]``. The values
    are then zero-padded to ``len(ts)`` with ``window // 2 - 1`` leading zeros,
    which is the same head offset compute_mp uses for the matrix profile, so
    position ``k`` of both arrays describes the same window.

    Args:
        ts - 1-D sequence of numbers (numpy array, list or pandas Series)
        window - window length in samples, must be >= 1

    Return:
        numpy array of length len(ts) - moving std, zero in the padded head/tail.
        If ts is shorter than one window the result is all zeros.

    Raises:
        ValueError - if window is smaller than 1
    """
    if window < 1:
        raise ValueError("window must be >= 1")

    n_windows = len(ts) - window + 1
    if n_windows < 1:
        # not even one full window fits: nothing to measure
        return np.zeros(len(ts))

    # every window holds exactly `window` samples, and there is one window per
    # matrix profile entry (len(ts) - window + 1 of them)
    ts_std = np.array([np.std(ts[i:i + window]) for i in range(n_windows)])

    # clamp so that window == 1 does not ask for a negative amount of padding
    head = max(window // 2 - 1, 0)
    tail = len(ts) - n_windows - head
    return np.concatenate([np.zeros(head), ts_std, np.zeros(tail)])


def compute_mp(ts, window, threshold=None):
    """
    Compute the matrix profile of a series at a given window length.

    The raw profile is zero-padded to ``len(ts)`` so that each value sits roughly
    in the middle of the window it describes. Positions whose moving std
    (see compute_std) is exactly zero are forced to 0 so that they cannot count
    as outliers. Trailing NaNs of ``ts`` are not trimmed.

    Args:
        ts - 1-D numeric series
        window - window length in samples
        threshold - optional matrix profile value above which a point is an outlier

    Return:
        numpy array - padded matrix profile of length len(ts)
        float - only when threshold is given: fraction (0..1, despite the name
                "percentage") of outliers among the eligible points. Eligible
                points exclude the padded head/tail and the zero-std positions.
                0.0 is returned when there are no eligible points.
    """
    # compute mp by STOMP; element 0 of the returned pair is used as the profile.
    # errstate keeps the function quiet even if the global np.seterr cell was not run.
    with np.errstate(divide='ignore', invalid='ignore'):
        profile = np.array(matrixProfile.stomp(ts, m=window))[0]

    # calibrate ts and mp, so mp value is assigned to the middle of that window
    head = window // 2 - 1
    tail = len(ts) - len(profile) - head
    mp = np.concatenate([np.zeros(head), profile, np.zeros(tail)])
    first, last = head, len(mp) - tail

    # remove error results due to zero std (make them 0 so they don't contribute to outliers)
    ts_std = compute_std(ts, window=window)
    mp[np.flatnonzero(ts_std == 0)] = 0

    if threshold is None:
        return mp

    # Only interior positions are eligible. ts_std is itself zero-padded, so its
    # zeros must be counted inside [first, last) only; counting them over the whole
    # array would subtract the head and tail from the denominator a second time.
    eligible = ts_std[first:last] != 0
    n_eligible = int(np.count_nonzero(eligible))
    n_outliers = int(np.count_nonzero(mp[first:last][eligible] > threshold))
    outlier_percentage = n_outliers / n_eligible if n_eligible > 0 else 0.0
    return mp, outlier_percentage


def apply_compute_mp(signals=('SPO2-%', 'HR', 'RESP'), windows=range(20, 130, 10)):
    """
    Add per-patient matrix profile columns to the global ``vitals_df``.

    For every vital sign ``v`` and window length ``w`` a column named
    ``v + ' MP' + str(w)`` is written, holding compute_mp of that patient's
    series. With the defaults this covers window lengths 20, 30, ..., 120 samples
    (presumably seconds at one sample per second -- TODO confirm sampling rate).

    Args:
        signals - names of the vital-sign columns to profile
        windows - iterable of window lengths in samples

    Return:
        None - ``vitals_df`` is modified in place
    """
    for v in signals:
        for w in windows:
            # transform() keeps one value per original row. groupby().apply() would
            # return a single array object per patient indexed by pt_id, which
            # cannot be aligned row-by-row with vitals_df on assignment.
            vitals_df[v + ' MP' + str(w)] = (
                vitals_df.groupby('pt_id')[v]
                .transform(lambda s, w=w: compute_mp(s.to_numpy(), w))
            )

In [None]:
# Add the '<vital> MP<window>' matrix profile columns to vitals_df (modifies it in place).
apply_compute_mp()

In [None]:
# Labeling functions for Snorkel

def get_vitals(pt_id, v_sign):
    """
    Look up one vital-sign column for one patient in the global ``vitals_df``.

    Args:
        pt_id - patient id, used as the first key of the row selection
        v_sign - name of the vital-sign column to return

    Return:
        the result of ``vitals_df.loc[(pt_id,), v_sign]`` -- presumably the
        patient's time-indexed series for that vital sign (confirm against the
        index layout of vitals_df)
    """
    patient_rows = (pt_id, )
    return vitals_df.loc[patient_rows, v_sign]


@labeling_function()
def lf_long_alarm_60s(x):
    """
    Vote NOT_SUPPRESSIBLE when the alarm lasted 60 seconds or longer;
    abstain for anything shorter.
    """
    if x.duration >= 60:
        return NOT_SUPPRESSIBLE
    return ABSTAIN


@labeling_function()
def lf_long_alarm_65s(x):
    """
    Vote NOT_SUPPRESSIBLE when the alarm lasted 65 seconds or longer;
    abstain for anything shorter.
    """
    if x.duration >= 65:
        return NOT_SUPPRESSIBLE
    return ABSTAIN


@labeling_function()
def lf_long_alarm_70s(x):
    """
    Vote NOT_SUPPRESSIBLE when the alarm lasted 70 seconds or longer;
    abstain for anything shorter.
    """
    if x.duration >= 70:
        return NOT_SUPPRESSIBLE
    return ABSTAIN


def _spo2_stays_in_band(x, seconds, upper, lower=None):
    """
    Shared check for the SpO2 band labeling functions.

    Selects the patient's 'SPO2-%' readings between the alarm start and
    ``seconds`` after it, and votes NOT_SUPPRESSIBLE only if every reading is
    ``<= upper`` (and ``> lower`` when a lower bound is given).

    Args:
        x - alarm instance (uses pt_id and alarm_datetime)
        seconds - length of the window after alarm start, in seconds
        upper - inclusive upper bound of the band
        lower - exclusive lower bound of the band, or None for no lower bound

    Return:
        NOT_SUPPRESSIBLE or ABSTAIN
    """
    spo2 = get_vitals(x.pt_id, 'SPO2-%')
    spo2 = spo2[x.alarm_datetime:(x.alarm_datetime + timedelta(seconds=seconds))]

    # np.all() of an empty selection is True, which would turn "no readings at
    # all" into a NOT_SUPPRESSIBLE vote; with no evidence the function abstains.
    # NOTE(review): a window that is only partly covered by readings still passes.
    if len(spo2) == 0:
        return ABSTAIN

    in_band = spo2 <= upper
    if lower is not None:
        in_band = (lower < spo2) & in_band
    return NOT_SUPPRESSIBLE if np.all(in_band) else ABSTAIN


@labeling_function()
def lf_spo2_below85_over120s(x):
    """
    If every SpO2 reading in the 121 seconds from alarm start lies in (80, 85]
    then the alarm is not suppressible, otherwise abstain
    (also abstains when there are no readings in that window)
    """
    return _spo2_stays_in_band(x, seconds=121, upper=85, lower=80)


@labeling_function()
def lf_spo2_below80_over100s(x):
    """
    If every SpO2 reading in the 101 seconds from alarm start lies in (70, 80]
    then the alarm is not suppressible, otherwise abstain
    (also abstains when there are no readings in that window)
    """
    return _spo2_stays_in_band(x, seconds=101, upper=80, lower=70)


@labeling_function()
def lf_spo2_below70_over90s(x):
    """
    If every SpO2 reading in the 91 seconds from alarm start lies in (60, 70]
    then the alarm is not suppressible, otherwise abstain
    (also abstains when there are no readings in that window)
    """
    return _spo2_stays_in_band(x, seconds=91, upper=70, lower=60)


@labeling_function()
def lf_spo2_below60_over60s(x):
    """
    If every SpO2 reading in the 61 seconds from alarm start lies in (50, 60]
    then the alarm is not suppressible, otherwise abstain
    (also abstains when there are no readings in that window)
    """
    return _spo2_stays_in_band(x, seconds=61, upper=60, lower=50)


@labeling_function()
def lf_spo2_below50_over30s(x):
    """
    If every SpO2 reading in the 31 seconds from alarm start is at or below 50
    (no lower bound is applied) then the alarm is not suppressible, otherwise abstain
    (also abstains when there are no readings in that window)
    """
    return _spo2_stays_in_band(x, seconds=31, upper=50)
    

def _rate_stays_in_band(x, v_sign, factor_col, seconds, upper, lower=None):
    """
    Shared check for the age-scaled HR / RR band labeling functions.

    Selects the patient's ``v_sign`` readings between the alarm start and
    ``seconds`` after it, scales the bounds by the patient's age factor taken
    from ``age_factors_df[factor_col]``, and votes NOT_SUPPRESSIBLE only if every
    reading is ``<= upper * age_factor`` (and ``> lower * age_factor`` when a
    lower bound is given).

    Args:
        x - alarm instance (uses pt_id, alarm_datetime and pt_age_group)
        v_sign - vital-sign column name ('HR' or 'RESP')
        factor_col - column of age_factors_df holding the multiplier
        seconds - length of the window after alarm start, in seconds
        upper - inclusive upper bound before scaling
        lower - exclusive lower bound before scaling, or None for no lower bound

    Return:
        NOT_SUPPRESSIBLE or ABSTAIN
    """
    rate = get_vitals(x.pt_id, v_sign)
    rate = rate[x.alarm_datetime:(x.alarm_datetime + timedelta(seconds=seconds))]
    age_factor = age_factors_df.loc[x.pt_age_group, factor_col]

    # np.all() of an empty selection is True, which would turn "no readings at
    # all" into a NOT_SUPPRESSIBLE vote; with no evidence the function abstains.
    if len(rate) == 0:
        return ABSTAIN

    in_band = rate <= upper * age_factor
    if lower is not None:
        in_band = (lower * age_factor < rate) & in_band
    return NOT_SUPPRESSIBLE if np.all(in_band) else ABSTAIN


@labeling_function()
def lf_hr_below50_over120s(x):
    """
    If every HR reading in the 121 seconds from alarm start lies in
    (40 * age factor, 50 * age factor] then the alarm is not suppressible,
    otherwise abstain (also abstains when there are no readings in that window)
    """
    return _rate_stays_in_band(x, 'HR', 'hr_age_factor', seconds=121, upper=50, lower=40)


@labeling_function()
def lf_hr_below40_over60s(x):
    """
    If every HR reading in the 61 seconds from alarm start lies in
    (30 * age factor, 40 * age factor] then the alarm is not suppressible,
    otherwise abstain (also abstains when there are no readings in that window)
    """
    return _rate_stays_in_band(x, 'HR', 'hr_age_factor', seconds=61, upper=40, lower=30)


@labeling_function()
def lf_hr_below30(x):
    """
    If every HR reading over the alarm's duration (whole seconds) is at or below
    30 * age factor then the alarm is not suppressible, otherwise abstain
    (also abstains when there are no readings in that window)

    NOTE(review): earlier wording said "for any duration"; the check has always
    required ALL readings in the window to be low -- confirm which is intended.
    """
    return _rate_stays_in_band(x, 'HR', 'hr_age_factor', seconds=int(x.duration), upper=30)


@labeling_function()
def lf_rr_below50_over120s(x):
    """
    If every RR reading in the 121 seconds from alarm start lies in
    (40 * age factor, 50 * age factor] then the alarm is not suppressible,
    otherwise abstain (also abstains when there are no readings in that window)
    """
    return _rate_stays_in_band(x, 'RESP', 'rr_age_factor', seconds=121, upper=50, lower=40)


@labeling_function()
def lf_rr_below40_over60s(x):
    """
    If every RR reading in the 61 seconds from alarm start lies in
    (30 * age factor, 40 * age factor] then the alarm is not suppressible,
    otherwise abstain (also abstains when there are no readings in that window)
    """
    return _rate_stays_in_band(x, 'RESP', 'rr_age_factor', seconds=61, upper=40, lower=30)


@labeling_function()
def lf_rr_below30(x):
    """
    If every RR reading over the alarm's duration (whole seconds) is at or below
    30 * age factor then the alarm is not suppressible, otherwise abstain
    (also abstains when there are no readings in that window)

    NOTE(review): earlier wording said "for any duration"; the check has always
    required ALL readings in the window to be low -- confirm which is intended.
    """
    return _rate_stays_in_band(x, 'RESP', 'rr_age_factor', seconds=int(x.duration), upper=30)


def repeat_alarms(x, t):
    """
    If the same patient has another alarm in ``alarms_df`` starting within 't'
    seconds before the current alarm's start time and/or within 't' seconds after
    the current alarm's end time then the alarm is not suppressible, otherwise abstain

    The end time is the start time plus the duration truncated to whole seconds.
    Only alarms that start strictly before (prior) or strictly after (subsequent)
    the current alarm's start are considered, so the alarm never counts itself.

    Args:
        x - alarm instance (uses pt_id, alarm_datetime and duration)
        t - timespan to consider (in seconds)

    Return:
        NOT_SUPPRESSIBLE or ABSTAIN
    """
    same_patient = alarms_df['pt_id'] == x.pt_id
    starts = alarms_df['alarm_datetime']
    alarm_end = x.alarm_datetime + timedelta(seconds=int(x.duration))

    prior = (
        same_patient
        & (x.alarm_datetime - timedelta(seconds=t) <= starts)
        & (starts < x.alarm_datetime)
    )

    # `starts > x.alarm_datetime` matters when the duration truncates to 0 seconds:
    # then alarm_end equals the start time and, without this guard, the alarm's own
    # row would fall inside its "subsequent" window and force NOT_SUPPRESSIBLE.
    subsequent = (
        same_patient
        & (starts > x.alarm_datetime)
        & (alarm_end <= starts)
        & (starts <= x.alarm_datetime + timedelta(seconds=int(x.duration + t)))
    )

    return NOT_SUPPRESSIBLE if (prior.any() or subsequent.any()) else ABSTAIN


@labeling_function()
def lf_repeat_alarms_15s(x):
    """
    Delegate to repeat_alarms with a 15 second timespan: not suppressible when
    it finds another alarm near this one, otherwise abstain
    """
    timespan_seconds = 15
    return repeat_alarms(x, timespan_seconds)


@labeling_function()
def lf_repeat_alarms_30s(x):
    """
    Delegate to repeat_alarms with a 30 second timespan: not suppressible when
    it finds another alarm near this one, otherwise abstain
    """
    timespan_seconds = 30
    return repeat_alarms(x, timespan_seconds)


@labeling_function()
def lf_repeat_alarms_60s(x):
    """
    Delegate to repeat_alarms with a 60 second timespan: not suppressible when
    it finds another alarm near this one, otherwise abstain
    """
    timespan_seconds = 60
    return repeat_alarms(x, timespan_seconds)


@labeling_function()
def lf_short_alarm_15s(x):
    """
    Vote SUPPRESSIBLE when the alarm lasted 15 seconds or less;
    abstain for anything longer.
    """
    if x.duration <= 15:
        return SUPPRESSIBLE
    return ABSTAIN


@labeling_function()
def lf_short_alarm_10s(x):
    """
    Vote SUPPRESSIBLE when the alarm lasted 10 seconds or less;
    abstain for anything longer.
    """
    if x.duration <= 10:
        return SUPPRESSIBLE
    return ABSTAIN


@labeling_function()
def lf_short_alarm_5s(x):
    """
    Vote SUPPRESSIBLE when the alarm lasted 5 seconds or less;
    abstain for anything longer.
    """
    if x.duration <= 5:
        return SUPPRESSIBLE
    return ABSTAIN


def max_recovery(data):
    """
    Largest rise from an earlier reading to a later one.

    Equivalent to ``max(data[j] - data[i] for i < j)``, computed in a single
    pass by comparing each reading with the lowest reading seen before it.
    The result is negative when the series only falls.

    Args:
        data - 1-D sequence of numbers (list, numpy array or pandas Series);
               values are read in positional order and NaN readings are skipped

    Return:
        float - the largest later-minus-earlier difference, or -inf when fewer
                than two non-NaN readings are available (so a ``> threshold``
                comparison by the caller is simply False instead of raising)
    """
    values = np.asarray(data, dtype=float)
    values = values[~np.isnan(values)]
    if values.size < 2:
        return float('-inf')

    # lowest_so_far[k] is the minimum of values[0..k], i.e. the best "earlier"
    # reading to pair with values[k + 1]
    lowest_so_far = np.minimum.accumulate(values[:-1])
    return float(np.max(values[1:] - lowest_so_far))


@labeling_function()
def lf_immediate_recovery_10s(x):
    """
    Vote SUPPRESSIBLE when max_recovery of the SpO2 readings taken from alarm
    start to 10 seconds later exceeds 20 percentage points; otherwise abstain
    """
    spo2 = get_vitals(x.pt_id, 'SPO2-%')
    window_end = x.alarm_datetime + timedelta(seconds=10)
    recovery = max_recovery(spo2[x.alarm_datetime:window_end])
    if recovery > 20:
        return SUPPRESSIBLE
    return ABSTAIN


@labeling_function()
def lf_immediate_recovery_15s(x):
    """
    Vote SUPPRESSIBLE when max_recovery of the SpO2 readings taken from alarm
    start to 15 seconds later exceeds 30 percentage points; otherwise abstain
    """
    spo2 = get_vitals(x.pt_id, 'SPO2-%')
    window_end = x.alarm_datetime + timedelta(seconds=15)
    recovery = max_recovery(spo2[x.alarm_datetime:window_end])
    if recovery > 30:
        return SUPPRESSIBLE
    return ABSTAIN


@labeling_function()
def lf_hr_tech_err_20(x):
    """
    Vote SUPPRESSIBLE when the 'SPO2-R' reading (the SpO2-derived heart rate) and
    the 'HR' reading (the ECG heart rate) at the alarm timestamp differ by more
    than 20, in the units of those columns; otherwise abstain

    NOTE(review): both series are indexed at exactly x.alarm_datetime, so a
    reading is presumably expected to exist at that timestamp -- confirm.
    """
    spo2_hr = get_vitals(x.pt_id, 'SPO2-R')
    ecg_hr = get_vitals(x.pt_id, 'HR')
    alarm_time = x.alarm_datetime
    rate_gap = abs(spo2_hr[alarm_time] - ecg_hr[alarm_time])
    if rate_gap > 20:
        return SUPPRESSIBLE
    return ABSTAIN


@labeling_function()
def lf_hr_tech_err_30(x):
    """
    Vote SUPPRESSIBLE when the 'SPO2-R' reading (the SpO2-derived heart rate) and
    the 'HR' reading (the ECG heart rate) at the alarm timestamp differ by more
    than 30, in the units of those columns; otherwise abstain

    NOTE(review): both series are indexed at exactly x.alarm_datetime, so a
    reading is presumably expected to exist at that timestamp -- confirm.
    """
    spo2_hr = get_vitals(x.pt_id, 'SPO2-R')
    ecg_hr = get_vitals(x.pt_id, 'HR')
    alarm_time = x.alarm_datetime
    rate_gap = abs(spo2_hr[alarm_time] - ecg_hr[alarm_time])
    if rate_gap > 30:
        return SUPPRESSIBLE
    return ABSTAIN


@labeling_function()
def lf_outlier_spo2_120(x):
    """
    Vote SUPPRESSIBLE when any value of the 'SPO2-% MP120' column (the SpO2
    matrix profile for window 120) from 60 seconds before to 60 seconds after
    the alarm start is above 8.4; otherwise abstain

    NOTE(review): the 8.4 cut-off is a hard-coded constant whose derivation is
    not visible here.
    """
    profile = get_vitals(x.pt_id, 'SPO2-% MP120')
    half_span = timedelta(seconds=60)
    nearby = profile[(x.alarm_datetime - half_span):(x.alarm_datetime + half_span)]
    if np.any(nearby > 8.4):
        return SUPPRESSIBLE
    return ABSTAIN




In [None]:
# Ad-hoc spot check: take one alarm row by position, print it, and run a single
# labeling function on it. The cell's displayed value is that function's vote.
# NOTE(review): position 194 is hard-coded; why this row was chosen is not visible here.
x = alarms_df.iloc[194]
print(x)

lf_outlier_spo2_120(x)

In [None]:
def pred_threshold(p, thres):
    """
    Turn a pair of class probabilities into a label.

    SUPPRESSIBLE is predicted only when its probability is strictly larger than
    the not-suppressible one AND at least ``thres``. NOT_SUPPRESSIBLE is
    predicted whenever its probability is strictly larger. Every other case
    (a tie, or a suppressible lead that stays under the threshold) abstains.

    Args:
        p - array [P(NOT_SUPPRESSIBLE), P(SUPPRESSIBLE)]
        thres - minimum probability required for a SUPPRESSIBLE prediction

    Return:
        int - 1 if suppressible, 0 if not-suppressible, -1 if abstain
    """
    p_suppressible = p[SUPPRESSIBLE]
    p_not_suppressible = p[NOT_SUPPRESSIBLE]

    if p_not_suppressible > p_suppressible:
        return NOT_SUPPRESSIBLE
    if p_suppressible > p_not_suppressible and p_suppressible >= thres:
        return SUPPRESSIBLE
    return ABSTAIN

In [None]:
# Labeling functions handed to Snorkel, grouped by the kind of evidence they use.
# NOTE(review): of the lf_outlier_* functions only lf_outlier_spo2_120 is defined in the
# visible part of this notebook; the others must exist before this cell runs.
lfs = [
    # alarm duration (long)
    lf_long_alarm_60s, lf_long_alarm_65s, lf_long_alarm_70s,
    # sustained low SpO2
    lf_spo2_below85_over120s, lf_spo2_below80_over100s, lf_spo2_below70_over90s,
    lf_spo2_below60_over60s, lf_spo2_below50_over30s,
    # sustained low heart rate
    lf_hr_below50_over120s, lf_hr_below40_over60s, lf_hr_below30,
    # sustained low respiratory rate
    lf_rr_below50_over120s, lf_rr_below40_over60s, lf_rr_below30,
    # neighbouring alarms
    lf_repeat_alarms_15s, lf_repeat_alarms_30s, lf_repeat_alarms_60s,
    # alarm duration (short)
    lf_short_alarm_15s, lf_short_alarm_10s, lf_short_alarm_5s,
    # quick SpO2 recovery
    lf_immediate_recovery_10s, lf_immediate_recovery_15s,
    # SpO2 heart rate vs ECG heart rate disagreement
    lf_hr_tech_err_20, lf_hr_tech_err_30,
    # matrix profile outliers: SpO2
    lf_outlier_spo2_120, lf_outlier_spo2_110, lf_outlier_spo2_100, lf_outlier_spo2_90,
    lf_outlier_spo2_80, lf_outlier_spo2_70, lf_outlier_spo2_60, lf_outlier_spo2_50,
    lf_outlier_spo2_40, lf_outlier_spo2_30, lf_outlier_spo2_20,
    # matrix profile outliers: heart rate
    lf_outlier_hr_120, lf_outlier_hr_110, lf_outlier_hr_100, lf_outlier_hr_90,
    lf_outlier_hr_80, lf_outlier_hr_70, lf_outlier_hr_60, lf_outlier_hr_50,
    lf_outlier_hr_40, lf_outlier_hr_30, lf_outlier_hr_20,
    # matrix profile outliers: respiratory rate
    lf_outlier_rr_120, lf_outlier_rr_110, lf_outlier_rr_100, lf_outlier_rr_90,
    lf_outlier_rr_80, lf_outlier_rr_70, lf_outlier_rr_60, lf_outlier_rr_50,
    lf_outlier_rr_40, lf_outlier_rr_30, lf_outlier_rr_20,
]