### This is a copy of ml-data-preprocessing, written to help with scaffolding for the query multiplexer.
### 02-ml-data-preprocessing is still the main preprocessing file

### ML Data pre-processing
This notebook loads and cleans the data that the ML model will be trained on:
vitals such as patient heart rate and blood pressure readings recorded around the time the second dose was administered.

It should persist the data into the "out" directory, where it is consumed by the ML training notebook.

In [1]:
import root_config as rc
from detectdd import config
import pandas as pd

# NOTE(review): presumably initialises project-level configuration (e.g. paths) before
# the remaining detectdd imports — confirm against root_config.
rc.configure()

from detectdd.auth_bigquery import BigQueryClient
from detectdd.serializer import Serializer

# Upper bound on the number of rows taken from each of the two cohorts when they
# are combined into `cohort`.
MAX_ROWS_PER_COHORT = 10000

print("Loading cohort")

try:
    serializer = Serializer()
    # Both cohort files are produced by 01-cohort.ipynb, which must have been run first.
    cohort_with_icd = serializer.read_cohort()
    print(cohort_with_icd.describe())
    print(len(cohort_with_icd))
    cohort_without_icd = serializer.read_cohort_with_no_icd()
    print(len(cohort_without_icd))
    # Combined cohort: at most MAX_ROWS_PER_COHORT rows from each source, ICD rows first.
    cohort = pd.concat([
        cohort_with_icd.head(MAX_ROWS_PER_COHORT),
        cohort_without_icd.head(MAX_ROWS_PER_COHORT),
    ])
except FileNotFoundError as err:
    # Chain the original error so the traceback still shows which file was missing.
    raise Exception(
        "Need to run [01-cohort.ipynb] at least once to create the cohort file in the /out directory"
    ) from err

# BigQuery client handed to the QueryMultiplexer further down the notebook.
big_query = BigQueryClient.auth()


from detectdd.query_multiplexer import WhereClauseGenerator
from detectdd.query_multiplexer import QueryMultiplexer
import pandas as pd
from detectdd.auth_bigquery import BigQueryClient

# Load the non-drug-interaction cohort from the "out" directory.
cohort_with_no_ddi = pd.read_csv(config.out_dir / 'non-drug-interactions.csv')

# read_csv was not asked to parse dates, so cast dose_b_time to a datetime dtype explicitly.
cohort_with_no_ddi["dose_b_time"] = cohort_with_no_ddi["dose_b_time"].astype("datetime64[s]")

# Print the per-column distinct counts; as a bare mid-cell expression the result
# would be discarded, because only a cell's last expression is rendered.
print(cohort_with_no_ddi.nunique())

# Choose which data set to fetch vitals for. Flip this switch instead of
# commenting lines in and out; the default keeps the ICD cohort.
USE_NO_DDI_COHORT = False

if USE_NO_DDI_COHORT:
    data_cohort = cohort_with_no_ddi
    cohort_filename = "vitals_data_before_and_after_no_drug_interaction.csv"
else:
    data_cohort = cohort_with_icd
    cohort_filename = "vitals_data_before_and_after.csv"

Loading cohort
Loaded cohort from ..\out\cohort-full.out
            subject_id          hadm_id          stay_id  \
count           7356.0           7356.0           7356.0   
mean   14943046.671017  25021995.427814  34998651.263187   
min         10004733.0       20038242.0       30004144.0   
25%         12474247.0       22532253.0       32482959.0   
50%         14886080.0       25085291.0       34996638.0   
75%         17444849.0       27536715.0       37456269.0   
max         19983257.0       29996046.0       39999230.0   
std     2871783.101868   2893328.688213   2905726.423783   

                      dose_b_time  event_count  admin_count  num_icd_codes  
count                        7356       7356.0  7356.000000         7356.0  
mean   2153-12-28 12:40:12.960848     7.245922     4.754486        1.40348  
min           2110-02-10 22:06:00          0.0     1.000000            1.0  
25%           2133-04-08 13:30:00          2.0     1.000000            1.0  
50%           215

In [2]:

# Half-width, in minutes, of the charttime window fetched around each dose_b_time.
VITALS_WINDOW_MINUTES = 720

# Unique subject_ids of the selected cohort, joined into a comma-separated string.
# Only the commented-out debugging query below references them.
subject_ids = data_cohort['subject_id'].unique()
subject_id_str = ', '.join(str(subject_id) for subject_id in subject_ids)

query_multiplexer = QueryMultiplexer(big_query)

# Vitals query template. NOTE(review): $where is presumably replaced by the multiplexer
# with OR-ed copies of where_fragment — confirm in detectdd.query_multiplexer.
# Rows with no vital-sign reading at all are excluded.
query = """
SELECT stay_id, subject_id, charttime, heart_rate, sbp, dbp, mbp
FROM `physionet-data.mimiciv_derived.vitalsign`
WHERE ($where) 
    AND (heart_rate IS NOT NULL OR sbp IS NOT NULL OR dbp IS NOT NULL OR mbp IS NOT NULL)
"""

# Ad-hoc debugging alternative (not executed):
# query = f"""
# SELECT subject_id, heart_rate, sbp, dbp, mbp
# FROM `physionet-data.mimiciv_derived.vitalsign`
# WHERE subject_id IN ({subject_id_str}) limit 100"""

# One condition per (stay_id, dose_b_time) pair: chart times strictly inside
# +/- VITALS_WINDOW_MINUTES of the second dose for that stay.
where_fragment = (
    "(stay_id= $stay_id"
    f" AND charttime > DATETIME_ADD('$dose_b_time', INTERVAL -{VITALS_WINDOW_MINUTES} MINUTE)"
    f" AND charttime < DATETIME_ADD('$dose_b_time', INTERVAL {VITALS_WINDOW_MINUTES} MINUTE))"
)

# Map each stay_id to the list of its dose_b_time values.
multimap_data = {k: v.tolist() for k, v in data_cohort.groupby('stay_id')['dose_b_time']}
results = query_multiplexer.multiplex_query(query, multi_map_data=multimap_data,
                                            where_clause=WhereClauseGenerator(where_fragment, "stay_id", "dose_b_time"))

Executing query 1, with 1697 pairs at 2023-11-01 17:40:35.444264
Partitioning key value pairs 1697
Number of partitions 6 with partition_size 282.8333333333333
Got result with 9084 values
Got result with 9634 values
Got result with 9594 values
Got result with 9648 values
Got result with 9445 values
Got result with 9372 values
Executing query 2, with 1194 pairs at 2023-11-01 17:42:29.327415
Single partition
Got result with 39873 values
Executing query 3, with 896 pairs at 2023-11-01 17:43:59.865507
Single partition
Got result with 29777 values
Executing query 4, with 690 pairs at 2023-11-01 17:45:01.314372
Single partition
Got result with 23420 values
Executing query 5, with 546 pairs at 2023-11-01 17:45:45.149220
Single partition
Got result with 17791 values
Executing query 6, with 460 pairs at 2023-11-01 17:46:20.825052
Single partition
Got result with 15077 values
Executing query 7, with 384 pairs at 2023-11-01 17:46:49.242237
Single partition
Got result with 12497 values
Executing q

In [3]:
# The query already ran in the previous cell; bind its result to a descriptive name
# (same object, no copy) and display summary statistics for it.
vitals_data = results
vitals_data.describe()

Unnamed: 0,dose_b_time,subject_id,charttime,heart_rate,sbp,dbp,mbp
count,243186,243186.0,243186,187107.0,190323.0,190279.0,190385.0
mean,2154-03-31 06:27:29.784363008,14989478.751289,2154-03-31 06:28:17.823888,90.283604,116.679622,61.421599,77.438982
min,2110-02-10 22:06:00,10004733.0,2110-02-10 10:15:00,5.0,8.0,1.0,1.0
25%,2133-10-26 10:05:00,12492737.0,2133-10-26 10:33:45,77.0,101.0,52.0,67.0
50%,2154-01-27 08:30:00,14998916.0,2154-01-27 00:57:30,89.0,114.0,60.0,75.0
75%,2176-02-10 05:12:00,17529736.0,2176-02-09 23:00:00,102.0,130.0,69.0,86.0
max,2209-05-30 02:04:00,19983257.0,2209-05-30 14:00:00,217.0,329.0,290.0,299.0
std,,2881415.701602,,19.654878,21.948004,14.326254,15.926375


In [4]:
# Write the fetched vitals to the "out" directory under the file name set with the cohort choice.
vitals_csv_path = config.out_dir / cohort_filename
vitals_data.to_csv(vitals_csv_path)