In [None]:
import pyspark
import time
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from random import sample
import random
# Imported last so none of the star-imports above can shadow `pd`.
import pandas as pd

# Attach to the active SparkSession if one is already running (e.g. provided by the kernel),
# otherwise start a new one with default settings.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

In [None]:
# Explicit schemas for the CSV inputs read below. Declaring them up front lets Spark skip
# schema inference (an extra pass over each file) and pins every column's type.
# Code Reference: https://github.com/MIT-LCP/mimic-code/blob/master/buildmimic/aws-athena/mimictoparquet_glue_job.py


def _nullable_struct(*columns):
    """Return a StructType built from ``(name, DataType)`` pairs.

    Each field uses StructField's defaults: nullable, no metadata.
    """
    return StructType([StructField(name, data_type) for name, data_type in columns])


# ---- Source tables (read from db/ below) ----
schema_icustays = _nullable_struct(
    ("row_id", IntegerType()),
    ("subject_id", IntegerType()),
    ("hadm_id", IntegerType()),
    ("icustay_id", IntegerType()),
    ("dbsource", StringType()),
    ("first_careunit", StringType()),
    ("last_careunit", StringType()),
    ("first_wardid", ShortType()),
    ("last_wardid", ShortType()),
    ("intime", TimestampType()),
    ("outtime", TimestampType()),
    ("los", DoubleType()),
)

schema_patients = _nullable_struct(
    ("row_id", IntegerType()),
    ("subject_id", IntegerType()),
    ("gender", StringType()),
    ("dob", TimestampType()),
    ("dod", TimestampType()),
    ("dod_hosp", TimestampType()),
    ("dod_ssn", TimestampType()),
    ("expire_flag", IntegerType()),
)

schema_services = _nullable_struct(
    ("row_id", IntegerType()),
    ("subject_id", IntegerType()),
    ("hadm_id", IntegerType()),
    ("transfertime", TimestampType()),
    ("prev_service", StringType()),
    ("curr_service", StringType()),
)

schema_chartevents = _nullable_struct(
    ("row_id", IntegerType()),
    ("subject_id", IntegerType()),
    ("hadm_id", IntegerType()),
    ("icustay_id", IntegerType()),
    ("itemid", IntegerType()),
    ("charttime", TimestampType()),
    ("storetime", TimestampType()),
    ("cgid", IntegerType()),
    ("value", StringType()),
    ("valuenum", DoubleType()),
    ("valueuom", StringType()),
    ("warning", IntegerType()),
    ("error", IntegerType()),
    ("resultstatus", StringType()),
    ("stopped", StringType()),
)

schema_ditems = _nullable_struct(
    ("row_id", IntegerType()),
    ("itemid", IntegerType()),
    ("label", StringType()),
    ("abbreviation", StringType()),
    ("dbsource", StringType()),
    ("linksto", StringType()),
    ("category", StringType()),
    ("unitname", StringType()),
    ("param_type", StringType()),
    ("conceptid", IntegerType()),
)

schema_admissions = _nullable_struct(
    ("row_id", IntegerType()),
    ("subject_id", IntegerType()),
    ("hadm_id", IntegerType()),
    ("admittime", TimestampType()),
    ("dischtime", TimestampType()),
    ("deathtime", TimestampType()),
    ("admission_type", StringType()),
    ("admission_location", StringType()),
    ("discharge_location", StringType()),
    ("insurance", StringType()),
    ("language", StringType()),
    ("religion", StringType()),
    ("marital_status", StringType()),
    ("ethnicity", StringType()),
    ("edregtime", TimestampType()),
    ("edouttime", TimestampType()),
    ("diagnosis", StringType()),
    ("hospital_expire_flag", ShortType()),
    ("has_chartevents_data", ShortType()),
)

# ---- Derived tables (read from derived/ below) ----
schema_fio2 = _nullable_struct(
    ("icustay_id", IntegerType()),
    ("charttime", TimestampType()),
    ("fio2", DoubleType()),
)

schema_gcs = _nullable_struct(
    ("icustay_id", IntegerType()),
    ("charttime", TimestampType()),
    ("gcs", DoubleType()),
    ("gcsmotor", DoubleType()),
    ("gcsverbal", DoubleType()),
    ("gcseyes", DoubleType()),
    ("endotrachflag", IntegerType()),
)

schema_sofa = _nullable_struct(
    ("icustay_id", IntegerType()),
    ("hr", IntegerType()),
    ("starttime", TimestampType()),
    ("endtime", TimestampType()),
    ("sofa_24hours", IntegerType()),
)

schema_vital = _nullable_struct(
    ("icustay_id", IntegerType()),
    ("charttime", TimestampType()),
    ("heartrate", DoubleType()),
    ("sysbp", DoubleType()),
    ("diasbp", DoubleType()),
    ("meanbp", DoubleType()),
    ("resprate", DoubleType()),
    ("tempc", DoubleType()),
    ("spo2", DoubleType()),
    ("glucose", DoubleType()),
)


In [None]:
# Load every input CSV from the project bucket. All files are comma-separated with a header
# row. The db/ and derived/ tables use the explicit schemas defined above; the sepsis-3
# export has its schema inferred by Spark.


def _read_csv(path, **reader_options):
    """Read a comma-separated CSV with a header row through the shared `spark` session.

    ``reader_options`` (``schema=...`` or ``inferSchema=True``) are forwarded unchanged to
    ``DataFrameReader.csv``.
    """
    return spark.read.csv(path, sep = ',', header = True, **reader_options)


# Source tables (Postgres export)
df_icustays = _read_csv('gs://peaceful-bruin-307600/db/ICUSTAYS.csv', schema = schema_icustays)
df_patients = _read_csv('gs://peaceful-bruin-307600/db/PATIENTS.csv', schema = schema_patients)
df_services = _read_csv('gs://peaceful-bruin-307600/db/SERVICES.csv', schema = schema_services)
df_chartevents = _read_csv('gs://peaceful-bruin-307600/db/CHARTEVENTS.csv', schema = schema_chartevents)
df_admissions = _read_csv('gs://peaceful-bruin-307600/db/ADMISSIONS.csv', schema = schema_admissions)
df_ditems = _read_csv('gs://peaceful-bruin-307600/db/D_ITEMS.csv', schema = schema_ditems)
df_sepsis_no_exclusion = _read_csv('gs://peaceful-bruin-307600/sepsis3-df-no-exclusions.csv', inferSchema = True)

# Derived Physionet tables
df_fio2 = _read_csv('gs://peaceful-bruin-307600/derived/fio2.csv', schema = schema_fio2)
df_gcs = _read_csv('gs://peaceful-bruin-307600/derived/gcs.csv', schema = schema_gcs)
df_sofa = _read_csv('gs://peaceful-bruin-307600/derived/sofa_direct', schema = schema_sofa)
df_vital = _read_csv('gs://peaceful-bruin-307600/derived/vital.csv', schema = schema_vital)

In [None]:
# Register every input DataFrame as a temporary view so the SQL cells below can query it by name.
# createOrReplaceTempView replaces DataFrame.registerTempTable, which is deprecated since Spark 2.0.
_input_views = {
    'icustays': df_icustays,
    'patients': df_patients,
    'services': df_services,
    'chartevents': df_chartevents,
    'admissions': df_admissions,
    'ditems': df_ditems,
    'fio2': df_fio2,
    'gcs': df_gcs,
    'sofa': df_sofa,
    'vital': df_vital,
    'sepsis3': df_sepsis_no_exclusion,
}
for _view_name, _view_df in _input_views.items():
    _view_df.createOrReplaceTempView(_view_name)

In [None]:
# Suspected-infection time (and its `_days` companion column) for every ICU stay in the
# sepsis-3 export that has one; DISTINCT collapses repeated rows.
query = \
"""
select
    Distinct
    icustay_id
    ,suspected_infection_time_poe
    ,suspected_infection_time_poe_days

from sepsis3
    Where suspected_infection_time_poe is NOT NULL
"""

df_susp_inf = spark.sql(query)
# createOrReplaceTempView replaces the deprecated registerTempTable.
df_susp_inf.createOrReplaceTempView('susp_inf')

In [None]:
#Create Cohort
#Code Reference: https://github.com/MIT-LCP/mimic-code/blob/7ff270c7079a42621f6e011de6ce4ddc0f7fd45c/tutorials/cohort-selection.ipynb
#Code Reference: https://github.com/alistairewj/sepsis3-mimic/blob/master/query/tbls/cohort.sql
# Builds candidate ICU stays (icustays INNER JOIN patients and admissions) with exclusion_*
# flags, where 1 means "exclude"; the next cell keeps only rows where every flag is 0.
#  - age: DATEDIFF(intime, dob) / 365, i.e. approximate age in years at ICU admission.
#  - icustay_id_order: RANK of the stay by intime within subject_id (1 = first stay).
#  - serv: curr_service from the services row with the latest transfertime for the admission;
#    surgical = 1 only for CSURG / VSURG / TSURG.
#    NOTE(review): the `transfertime < intime + 12h` filter is commented out, so this is the
#    admission's last service rather than the service around ICU admission - confirm intended.
#  - exclusion_los: los < 0.5 (presumably days, i.e. < 12 hours - confirm units).
#  - inf_window_start / inf_window_end: suspected infection time - 48h / + 24h.
#  - exclusion_sus_inf_window: 1 when the infection window starts after outtime - 12h or ends
#    before intime + 12h. Stays with no suspected infection time compare against NULL, fall
#    through to ELSE 0, and so are NOT excluded by this flag (they become controls later).
# NOTE(review): the serv / susp_inf LEFT JOINs can duplicate a stay if RANK ties occur or
# susp_inf holds several rows for one icustay_id.
query = \
"""
WITH co AS
(
SELECT 
icu.subject_id
,icu.hadm_id
,icu.icustay_id
,icu.dbsource
,first_careunit
,los as icu_length_of_stay
,icu.intime
,icu.outtime
,DATEDIFF (icu.intime , pat.dob )/365 as age
,pat.gender
,adm.ethnicity
,adm.HAS_CHARTEVENTS_DATA
,RANK() OVER (PARTITION BY icu.subject_id ORDER BY icu.intime) AS icustay_id_order
FROM icustays icu
INNER JOIN patients pat ON icu.subject_id = pat.subject_id
INNER JOIN admissions adm ON icu.hadm_id = adm.hadm_id
--LIMIT 10
)
,serv AS
(
SELECT 
icu.*
,se.curr_service
,CASE
--WHEN curr_service like '%SURG' then 1
--WHEN curr_service = 'ORTHO' then 1
WHEN curr_service in ('CSURG','VSURG','TSURG') then 1
ELSE 0 END
as surgical
,RANK() OVER (PARTITION BY icu.hadm_id ORDER BY se.transfertime DESC) as rank
FROM icustays icu
LEFT JOIN services se ON icu.hadm_id = se.hadm_id
--AND se.transfertime < icu.intime + interval '12' hour
)

SELECT
co.*
,CASE
WHEN co.icu_length_of_stay < .5 then 1
ELSE 0 END
AS exclusion_los
,CASE
WHEN co.age <= 16 then 1
ELSE 0 END
AS exclusion_age
,CASE 
WHEN co.icustay_id_order != 1 THEN 1
ELSE 0 END 
AS exclusion_first_stay
,CASE
WHEN serv.surgical == 1 THEN 1
ELSE 0 END
as exclusion_surgical
,CASE
when co.dbsource != 'metavision' THEN 1
ELSE 0 END 
as exclusion_icu_db
,Case 
when co.HAS_CHARTEVENTS_DATA == 0 then 1
when co.intime is null then 1
when co.outtime is null then 1
else 0 end 
as exclusion_bad_data

,inf.suspected_infection_time_poe
,inf.suspected_infection_time_poe - INTERVAL 48 HOURS as inf_window_start
,inf.suspected_infection_time_poe + INTERVAL 24 HOURS as inf_window_end

--Exclude cases where there is no overlap with suspected infection window
,Case 
--The line below limits suspected infection that occur outside the end of the ICU Stay to ensure at least 12 hours of data are avaliable for SOFA score window
when (inf.suspected_infection_time_poe - INTERVAL 48 HOURS) > (co.outtime - INTERVAL 12 HOURS) Then 1
--The line below limits suspected infection time to within 12 hours prior to ICU stay to ensure enough data points for sofa window 
--Cuts out Noise for specificity as suspected infection time is often round to day and can lead to misleading no sepsis diagnosis ground truth
when (inf.suspected_infection_time_poe + INTERVAL 24 HOURS ) < (co.intime + INTERVAL 12 HOURS) Then 1
else 0 end 
as exclusion_sus_inf_window

FROM co
LEFT JOIN serv ON co.icustay_id = serv.icustay_id AND serv.rank = 1
LEFT JOIN susp_inf inf ON (co.icustay_id = inf.icustay_id)

"""    
# Unfiltered cohort carrying all exclusion flags (filtered in the next cell).
df_cohort_no_exclusion = spark.sql(query)

In [None]:
# Keep only ICU stays that pass every exclusion criterion. Each exclusion_* flag is 1 when the
# stay should be excluded; the CASE expressions that produce them always yield 0 or 1.
EXCLUSION_FLAGS = [
    'exclusion_age',
    'exclusion_first_stay',
    'exclusion_surgical',
    'exclusion_icu_db',
    'exclusion_bad_data',
    'exclusion_sus_inf_window',
    'exclusion_los',
]

keep_stay = df_cohort_no_exclusion[EXCLUSION_FLAGS[0]] == 0
for flag in EXCLUSION_FLAGS[1:]:
    keep_stay = keep_stay & (df_cohort_no_exclusion[flag] == 0)

df_cohort_exclusion = df_cohort_no_exclusion.filter(keep_stay)

# createOrReplaceTempView replaces the deprecated registerTempTable.
df_cohort_exclusion.createOrReplaceTempView('cohort_exclusion')
df_cohort_exclusion.count()

8569

In [None]:
# Build an hourly time-series template for each (subject_id, icustay_id) in the filtered cohort.
# `time` is an array of hourly timestamps from intime truncated to the hour through outtime
# truncated to the hour (sequence() includes both ends). outtime_round is one hour past the
# last element of `time`.
# Code Reference: https://stackoverflow.com/questions/43141671/sparksql-on-pyspark-how-to-generate-time-series
query = \
"""
SELECT
ce.*
,DATE_TRUNC('hour', ce.intime) as intime_round
,DATE_TRUNC('hour', ce.outtime)+ INTERVAL 1 HOURS as outtime_round
,sequence(to_timestamp(DATE_TRUNC('hour', ce.intime)), to_timestamp(DATE_TRUNC('hour', ce.outtime)), interval 1 hour) as time

FROM cohort_exclusion ce
"""
# Use the `spark` session from the first cell: `sqlContext` is never defined in this notebook,
# so `sqlContext.sql` raises NameError on any kernel that does not inject it.
df_cohort_exclusion_ts = spark.sql(query)

In [None]:
# Explode each stay's hourly array into one row per hour. posexplode yields the array index
# (renamed `hour`; 0 is the hour containing intime) and the element (renamed `timestamp`).
# Columns are referenced by name instead of via `col(...)`: later cells loop with a variable
# named `col`, which shadows pyspark's `col` and made this cell fail when re-run.
# NOTE: this overwrites its own input; re-run the previous cell before re-running this one.
df_cohort_exclusion_ts = (
    df_cohort_exclusion_ts
    .select("*", posexplode("time"))
    .drop("time")
    .withColumnRenamed('col', 'timestamp')
    .withColumnRenamed('pos', 'hour')
)

# Register for querying (createOrReplaceTempView replaces the deprecated registerTempTable).
df_cohort_exclusion_ts.createOrReplaceTempView('cohort_exclusion_ts')

In [None]:
# Project the SOFA columns used downstream (rows keyed by icustay_id and hr) prior to joining.
query = \
"""
select 
    s.icustay_id
    ,s.hr
    ,s.starttime
    ,s.endtime
    ,s.sofa_24hours
    
from sofa s 
"""    

# `spark.sql` instead of the undefined `sqlContext.sql`.
df_sofa_cleansed = spark.sql(query)

In [None]:
# Impute missing sofa_24hours per ICU stay (ordered by hr): first forward-fill with the last
# non-null value seen so far, then back-fill any remaining leading nulls with the next
# non-null value.
# Code Reference: Paul Lee's Lab Notebook, https://stackoverflow.com/questions/38131982/forward-fill-missing-values-in-spark-python

# Named windows / loop variables deliberately avoid `window` and `col`, which would shadow
# pyspark.sql.functions.window / col brought in by the star import.
sofa_ffill_window = (
    Window.partitionBy('icustay_id')
    .orderBy('hr')
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)  # was a -1000000 row cap
)
sofa_bfill_window = (
    Window.partitionBy('icustay_id')
    .orderBy('hr')
    .rowsBetween(Window.currentRow, Window.unboundedFollowing)
)

colsfill = ['sofa_24hours']

for fill_col in colsfill:
    df_sofa_cleansed = df_sofa_cleansed.withColumn(
        fill_col, last(fill_col, ignorenulls=True).over(sofa_ffill_window))

for fill_col in colsfill:
    df_sofa_cleansed = df_sofa_cleansed.withColumn(
        fill_col, first(fill_col, ignorenulls=True).over(sofa_bfill_window))

# createOrReplaceTempView replaces the deprecated registerTempTable.
df_sofa_cleansed.createOrReplaceTempView('sofa_cleansed')

In [None]:
# Attach the imputed 24-hour SOFA score to each hourly row of the cohort template. Rows match
# on ICU stay and on the template timestamp equal to the SOFA row's starttime. The LEFT JOIN
# keeps template hours with no SOFA match (sofa_24hours is NULL there).
query = \
"""
select 
    ts.*
    --,s.hr
    --,s.starttime
    --,s.endtime
    ,s.sofa_24hours
    
from cohort_exclusion_ts ts 
        LEFT JOIN sofa_cleansed s ON (ts.icustay_id = s.icustay_id) AND (ts.timestamp = s.starttime)
 
where 1=1
"""    

# `spark.sql` instead of the undefined `sqlContext.sql`; createOrReplaceTempView replaces the
# deprecated registerTempTable.
df_cohort_cleansed = spark.sql(query)
df_cohort_cleansed.createOrReplaceTempView('cohort_cleansed')

In [None]:
# Sepsis-3 labelling per ICU stay.
#  - sus_window_flg = 1 for hours with timestamp in [inf_window_start, inf_window_end - 1h],
#    i.e. 48h before to 24h after suspected infection (already bounded by the ICU-stay template).
#  - Within those hours only (the WHERE is applied before the window functions), an hour is a
#    sepsis start when its sofa_24hours exceeds the minimum of the earlier in-window hours by
#    >= 2. The first in-window hour has no baseline, so its delta is NULL and its flag is 0.
#  - Sepsis3_diag_flg = 1 on every row of a stay that has at least one sepsis-start hour.
query = \
"""
With Inf_W_CTE AS (
--Flags Window over Time Series for suspected infection: 48 hrs preceding and 24 post suspicion
Select 
    cc.*
    ,CASE WHEN timestamp BETWEEN cc.inf_window_start AND cc.inf_window_end - INTERVAL 1 HOURS Then 1
          ELSE 0 
          END AS sus_window_flg
    
from cohort_cleansed as cc
)

,DIAG_CTE as (
--Compares sofa_24hours score with min. score prior to current row and flags when a geq +2 change occurs 
Select
    W.icustay_id
    ,W.hour
    ,W.sus_window_flg
    ,sofa_24hours - (Min(sofa_24hours) OVER(PARTITION BY icustay_id ORDER BY hour ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING)) as sofa_24hours_delta
    ,CASE WHEN (sofa_24hours - (Min(sofa_24hours) OVER(PARTITION BY icustay_id ORDER BY hour ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING))) >= 2 Then 1
          ELSE 0
          END AS Sepsis3_start_flg
          
From INF_W_CTE as W
Where W.sus_window_flg = 1
)

Select 
    cc.*
    --,d.sus_window_flg
    --,d.sofa_24hours_delta
    ,d.Sepsis3_start_flg
    ,CASE WHEN (SUM(d.Sepsis3_start_flg) OVER (PARTITION BY cc.icustay_id) >= 1) THEN 1 ELSE 0 END as Sepsis3_diag_flg 
FROM cohort_cleansed as cc
LEFT JOIN DIAG_CTE as d ON (cc.icustay_id = d.icustay_id) AND (cc.hour = d.hour)
"""  

# `spark.sql` instead of the undefined `sqlContext.sql`; createOrReplaceTempView replaces the
# deprecated registerTempTable.
df_cohort_diag = spark.sql(query)
df_cohort_diag.createOrReplaceTempView('cohort_diag')

In [None]:
# Find the sepsis onset hour: the earliest hour with Sepsis3_start_flg = 1 per ICU stay.
# Only that row receives sepsis_onset_hr; every other row is NULL until the next cell fills it.
query = \
"""
With first_cte as (

    Select
        cd.icustay_id
        ,cd.hour
        ,ROW_NUMBER() OVER(PARTITION BY cd.icustay_id order by cd.hour) as rn

    from cohort_diag as cd
    Where Sepsis3_start_flg = 1
)

select
cd.*
,fc.hour as sepsis_onset_hr
from cohort_diag cd left join first_cte fc ON (cd.icustay_id = fc.icustay_id) AND (cd.hour = fc.hour) AND fc.rn = 1 
"""  
# `spark.sql` instead of the undefined `sqlContext.sql`.
df_cohort_diag_onset = spark.sql(query)

In [None]:
# Propagate the sepsis onset hour to every row of its ICU stay (ordered by hour). First
# forward-fill from the onset row onward, then back-fill the rows before onset. Stays
# without an onset stay NULL throughout.
# Code Reference: Paul Lee's Lab Notebook, https://stackoverflow.com/questions/38131982/forward-fill-missing-values-in-spark-python

# Named windows / loop variables deliberately avoid `window` and `col`, which would shadow
# pyspark.sql.functions.window / col brought in by the star import.
onset_ffill_window = (
    Window.partitionBy('icustay_id')
    .orderBy('hour')
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)  # was a -1000000 row cap
)
onset_bfill_window = (
    Window.partitionBy('icustay_id')
    .orderBy('hour')
    .rowsBetween(Window.currentRow, Window.unboundedFollowing)
)

colsfill = ['sepsis_onset_hr']

for fill_col in colsfill:
    df_cohort_diag_onset = df_cohort_diag_onset.withColumn(
        fill_col, last(fill_col, ignorenulls=True).over(onset_ffill_window))

for fill_col in colsfill:
    df_cohort_diag_onset = df_cohort_diag_onset.withColumn(
        fill_col, first(fill_col, ignorenulls=True).over(onset_bfill_window))

# createOrReplaceTempView replaces the deprecated registerTempTable.
df_cohort_diag_onset.createOrReplaceTempView('cohort_diag_onset')


In [None]:
# Determine the 12-hour observation window per ICU stay and keep only those hours.
#  - Septic stays: hours in (onset - 15, onset - 3], i.e. 12 hours ending 3 hours before onset.
#    Stays whose onset is too early for this window lose those hours (none if onset <= 3).
#  - Non-septic stays (sepsis_onset_hr is NULL): the last 12 hours, (max(hour) - 12, max(hour)].
query = \
"""
with window_cte as (
    select icustay_id
           --,round(((max(icu_length_of_stay)*24)/4),0) as control_index_hr
           --,(round(((max(icu_length_of_stay)*24)/4),0) -12) as control_start_hr
           ,max(hour) as control_index_hr
           ,(max(hour) - 12) as control_start_hr
           ,max(sepsis_onset_hr - 3) as case_index_hr
           ,(max(sepsis_onset_hr) -3 - 12) as case_start_hr
           
    from cohort_diag_onset 
    --where sepsis_onset_hr is NULL
    group by icustay_id
)

select cdo.* 
from cohort_diag_onset as cdo 
left join window_cte as cc on cdo.icustay_id = cc.icustay_id

where 1=1
      and (((hour <= case_index_hr) and (hour > case_start_hr)) 
             or ((sepsis_onset_hr is NULL) and (hour <= control_index_hr) and (hour > control_start_hr))) 
"""

# `spark.sql` instead of the undefined `sqlContext.sql`.
df_cohort_diag_onset_final = spark.sql(query)

In [None]:
# Balance the dataset at the ICU-stay level: keep every Sepsis-3 stay plus a seeded random
# sample of non-sepsis stays of the same size. All hourly rows of each selected stay are kept.
pandasDF = df_cohort_diag_onset_final.toPandas()

icustay_id_sepsis3 = pandasDF[pandasDF['Sepsis3_diag_flg'] == 1]
print(icustay_id_sepsis3['icustay_id'].nunique())

icustay_id_nonsepsis3 = pandasDF[pandasDF['Sepsis3_diag_flg'] != 1]
print(icustay_id_nonsepsis3['icustay_id'].nunique())

# Sort the candidate ids. toPandas() row order is not guaranteed across Spark runs, so
# sampling from an unsorted unique() list is not reproducible even with a fixed seed.
l_icustay_id_nonsepsis3 = sorted(icustay_id_nonsepsis3['icustay_id'].unique())
print(len(l_icustay_id_nonsepsis3))

# Never request more controls than exist; random.sample raises ValueError if k > population.
n_controls = min(icustay_id_sepsis3['icustay_id'].nunique(), len(l_icustay_id_nonsepsis3))

random.seed(10)
l_icustay_id_nonsepsis3_sample = sample(l_icustay_id_nonsepsis3, n_controls)
print(len(l_icustay_id_nonsepsis3_sample))

idx = icustay_id_nonsepsis3.icustay_id.isin(l_icustay_id_nonsepsis3_sample)
icustay_id_nonsepsis3_sample = icustay_id_nonsepsis3[idx]
print(icustay_id_nonsepsis3_sample['icustay_id'].nunique())

# pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
df_final = pd.concat([icustay_id_sepsis3, icustay_id_nonsepsis3_sample], ignore_index=True)
print((df_final['icustay_id'].nunique()))

df_final.to_csv('gs://peaceful-bruin-307600/cohort_v1.csv',index=False)