In [None]:
import duckdb
import polars as pl
import polars.selectors as cs

# Database Setup

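Follow the instructions [here](https://github.com/MIT-LCP/mimic-code/tree/main/mimic-iv/buildmimic/duckdb) to build the MIMIC-IV DuckDB database file. The next cell expects that file to be named `mimic4.db` and to be located in the notebook's working directory.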

In [None]:
# Open the DuckDB database file by relative path; the helper functions defined
# in this cell bind this connection as their default `con` argument.
con = duckdb.connect("mimic4.db")

def q(query: str, con: duckdb.DuckDBPyConnection = con):
    """Run a SQL statement against the MIMIC-IV database.

    Args:
        query: SQL text, passed to DuckDB unchanged.
        con: Connection to run on; defaults to the notebook-wide connection.

    Returns:
        Whatever ``con.sql(query)`` returns.
    """
    # The annotation uses the public `duckdb.DuckDBPyConnection` name rather than
    # the package's internal `duckdb.duckdb` submodule path.
    return con.sql(query)

def table_info(table: str, con: duckdb.DuckDBPyConnection = con):
    """Show column information for a table in the ``mimiciv_hosp`` schema.

    Args:
        table: Table name without the schema prefix (e.g. ``"labevents"``).
        con: Connection to run on; defaults to the notebook-wide connection.

    Returns:
        The result of ``PRAGMA table_info`` for ``mimiciv_hosp.<table>``.
    """
    # `table` is interpolated directly into the SQL text, so only pass trusted names.
    return con.sql(f"PRAGMA table_info('mimiciv_hosp.{table}')")

def lab_count(itemid, con: duckdb.DuckDBPyConnection = con):
    """Count lab events for one itemid, or per itemid for several.

    Args:
        itemid: A single ``int`` itemid, or an iterable of itemids.
        con: Connection to run on; defaults to the notebook-wide connection.

    Returns:
        For a single ``int``: a one-column ``COUNT(*)`` result.
        For an iterable: one row per itemid with ``itemid`` and ``COUNT(*)``.

    Raises:
        ValueError: If an empty iterable is passed.
    """
    if isinstance(itemid, int):
        return con.sql(f"SELECT COUNT(*) FROM mimiciv_hosp.labevents WHERE itemid = {itemid}")

    # Build the IN list explicitly: interpolating tuple(itemid) renders a
    # one-element sequence as "(x,)" and an empty one as "()".
    ids = ", ".join(str(int(i)) for i in itemid)
    if not ids:
        raise ValueError("itemid must contain at least one lab item id")
    return con.sql(
        f"SELECT itemid, COUNT(*) FROM mimiciv_hosp.labevents WHERE itemid IN ({ids}) GROUP BY itemid"
    )

def lab_unit(itemid, con: duckdb.DuckDBPyConnection = con):
    """List the distinct units (``valueuom``) recorded for the given itemid(s).

    Args:
        itemid: A single ``int`` itemid, or an iterable of itemids.
        con: Connection to run on; defaults to the notebook-wide connection.

    Returns:
        A one-column result of distinct ``valueuom`` values.

    Raises:
        ValueError: If an empty iterable is passed.
    """
    if isinstance(itemid, int):
        return con.sql(f"SELECT DISTINCT valueuom FROM mimiciv_hosp.labevents WHERE itemid = {itemid}")

    # Build the IN list explicitly: interpolating tuple(itemid) renders a
    # one-element sequence as "(x,)" and an empty one as "()".
    ids = ", ".join(str(int(i)) for i in itemid)
    if not ids:
        raise ValueError("itemid must contain at least one lab item id")
    return con.sql(f"SELECT DISTINCT valueuom FROM mimiciv_hosp.labevents WHERE itemid IN ({ids})")

def _last_lab_query(itemid, colname: str, value_expr: str) -> str:
    """Build the SQL that keeps each subject's latest non-null lab measurement."""
    if isinstance(itemid, int):
        item_filter = f"itemid = {itemid}"
    else:
        # Build the IN list explicitly: interpolating tuple(itemid) renders a
        # one-element sequence as "(x,)" and an empty one as "()".
        ids = ", ".join(str(int(i)) for i in itemid)
        if not ids:
            raise ValueError("itemid must contain at least one lab item id")
        item_filter = f"itemid IN ({ids})"

    # Rows are numbered per subject in charttime order; the row whose number
    # equals the subject's maximum is the most recent measurement.
    return f"""
            WITH measure AS (
                SELECT subject_id, charttime,
                    {value_expr},
                    ROW_NUMBER() OVER (PARTITION BY subject_id ORDER BY charttime) AS measure_num
                FROM mimiciv_hosp.labevents
                WHERE {item_filter} AND valuenum NOT NULL
            ),
            max_measure AS (
                SELECT *, MAX(measure_num) OVER (PARTITION BY subject_id) AS max_num
                FROM measure
            )
            SELECT subject_id, valuenum AS {colname}
            FROM max_measure
            WHERE measure_num = max_num
            """


def lab_table_last(itemid, colname: str):
    """Get the last lab measurement per subject for the given itemid(s).

    The query partitions by ``subject_id`` (not by admission), so the result
    has one row per patient: the non-null ``valuenum`` with the latest
    ``charttime`` among the requested itemids.

    Args:
        itemid: A single ``int`` itemid, or an iterable of itemids to pool.
        colname: Name given to the value column in the result.

    Returns:
        The result of ``q(...)`` with columns ``subject_id`` and ``colname``.

    Raises:
        ValueError: If an empty iterable is passed.
    """
    return q(_last_lab_query(itemid, colname, "valuenum"))


def lab_table_last_wbc(itemid, colname: str):
    """Like ``lab_table_last``, but rescales values whose unit is ``'#/uL'``.

    Values recorded with ``valueuom = '#/uL'`` are divided by 1000 (cells/uL
    to 10^3 cells/uL); values with any other unit are returned unchanged.

    Args:
        itemid: A single ``int`` itemid, or an iterable of itemids to pool.
        colname: Name given to the value column in the result.

    Returns:
        The result of ``q(...)`` with columns ``subject_id`` and ``colname``.

    Raises:
        ValueError: If an empty iterable is passed.
    """
    return q(
        _last_lab_query(
            itemid,
            colname,
            "IF(valueuom = '#/uL', valuenum/1000, valuenum) AS valuenum",
        )
    )


# Create Dataset

## Demographics

In [None]:
# One row per subject, taken from the admission with the latest admittime.
# age = anchor_age + whole years between Jan 1 of anchor_year and admittime.
# Unknown race: 'PATIENT DECLINED TO ANSWER', 'UNABLE TO OBTAIN', 'UNKNOWN', null.
# NULL is tested with IS NULL: `race IN (..., NULL)` never matches a NULL race,
# so null values were previously passed through instead of becoming 'UNKNOWN'.
demographics = q(
"""
WITH demo AS (
    SELECT subject_id, gender, hadm_id, admittime, insurance, language,
        IF(marital_status IS NULL, 'UNKNOWN', marital_status) AS marital_status,
        IF(race IS NULL OR race IN ('PATIENT DECLINED TO ANSWER', 'UNABLE TO OBTAIN', 'UNKNOWN'), 'UNKNOWN', race) AS race,
        anchor_age + DATE_PART('year', AGE(admittime, MAKE_TIMESTAMP(anchor_year, 1, 1, 0, 0, 0))) AS age
    FROM mimiciv_hosp.admissions
    LEFT JOIN mimiciv_hosp.patients
    USING (subject_id)
),
demo_max_time AS (
    SELECT *, MAX(admittime) OVER (PARTITION BY subject_id) AS max_time
    FROM demo
)
SELECT subject_id, gender, insurance, language, marital_status, race, age
FROM demo_max_time
WHERE admittime = max_time;
""")
demographics

## Diagnoses

In [None]:
# Create one indicator column per condition: 1 if the subject has any matching
# ICD-9 or ICD-10 code in diagnoses_icd (across all admissions), otherwise 0.
# ICD-9 diabetes codes (250.xy) carry the type in the fifth digit:
#   0 = type II/unspecified, 1 = type I, 2 = type II uncontrolled, 3 = type I uncontrolled.
# Both fifth digits are therefore matched for each type.
diagnoses = q("""
SELECT subject_id,
MAX(CASE
        WHEN icd_version = 9 AND icd_code LIKE '401%' THEN 1
        WHEN icd_version = 10 AND icd_code LIKE 'I10%' THEN 1
        ELSE 0
    END
) AS hypertension,

MAX(CASE
        WHEN icd_version = 9 AND icd_code = '42731' THEN 1
        WHEN icd_version = 10 AND icd_code LIKE 'I48%' THEN 1
        ELSE 0
    END
) AS atrial_fibrillation,

MAX(CASE
        WHEN icd_version = 9 AND (
                icd_code LIKE '250_1' OR
                icd_code LIKE '250_3') THEN 1
        WHEN icd_version = 10 AND icd_code LIKE 'E10%' THEN 1
        ELSE 0
    END
) AS diabetes_type_1,

MAX(CASE
        WHEN icd_version = 9 AND (
                icd_code LIKE '250_0' OR
                icd_code LIKE '250_2') THEN 1
        WHEN icd_version = 10 AND icd_code LIKE 'E11%' THEN 1
        ELSE 0
    END
) AS diabetes_type_2,

MAX(CASE
        WHEN icd_version = 9 AND icd_code LIKE '496%' THEN 1
        WHEN icd_version = 10 AND icd_code LIKE 'J44%' THEN 1
        ELSE 0
    END
) AS copd,

MAX(CASE
        WHEN icd_version = 9 AND icd_code LIKE '493%' THEN 1
        WHEN icd_version = 10 AND icd_code LIKE 'J45%' THEN 1
        ELSE 0
    END
) AS asthma,

MAX(CASE
        WHEN icd_version = 9 AND icd_code LIKE '571%' THEN 1
        WHEN icd_version = 10 AND icd_code LIKE 'K7%' THEN 1
        ELSE 0
    END
) AS liver_disease,

MAX(CASE
        WHEN icd_version = 9 AND (
                icd_code LIKE '403%' OR
                icd_code LIKE '404%' OR
                icd_code LIKE '585%') THEN 1
        WHEN icd_version = 10 AND icd_code LIKE 'N18%' THEN 1
        ELSE 0
    END
) AS ckd,

MAX(CASE
        WHEN icd_version = 9 AND (
                icd_code LIKE '14%' OR
                icd_code LIKE '15%' OR
                icd_code LIKE '16%' OR
                icd_code LIKE '17%' OR
                icd_code LIKE '18%' OR
                icd_code LIKE '19%' OR
                icd_code LIKE '20%' OR
                icd_code LIKE '21%' OR
                icd_code LIKE '22%' OR
                icd_code LIKE '23%') THEN 1
        WHEN icd_version = 10 AND (
                icd_code LIKE 'C%' OR
                icd_code LIKE 'D0%' OR
                icd_code LIKE 'D1%' OR
                icd_code LIKE 'D2%' OR
                icd_code LIKE 'D3%' OR
                icd_code LIKE 'D4%') THEN 1
        ELSE 0
    END
) AS cancer,

MAX(CASE
        WHEN icd_version = 9 AND (
                icd_code LIKE '2962%' OR
                icd_code LIKE '2963%' OR
                icd_code LIKE '311%') THEN 1
        WHEN icd_version = 10 AND (
                icd_code LIKE 'F32%' OR
                icd_code LIKE 'F33%') THEN 1
        ELSE 0
    END
) AS depression,

MAX(CASE
        WHEN icd_version = 9 AND icd_code LIKE '715%' THEN 1
        WHEN icd_version = 10 AND (
                icd_code LIKE 'M15%' OR
                icd_code LIKE 'M16%' OR
                icd_code LIKE 'M17%' OR
                icd_code LIKE 'M18%' OR
                icd_code LIKE 'M19%') THEN 1
        ELSE 0
    END
) AS osteoarthritis,

MAX(CASE
        WHEN icd_version = 9 AND (
                icd_code LIKE '280%' OR
                icd_code LIKE '281%' OR
                icd_code LIKE '282%' OR
                icd_code LIKE '283%' OR
                icd_code LIKE '284%' OR
                icd_code LIKE '285%') THEN 1
        WHEN icd_version = 10 AND (
                icd_code LIKE 'D5%' OR
                icd_code LIKE 'D60%' OR
                icd_code LIKE 'D61%' OR
                icd_code LIKE 'D62%' OR
                icd_code LIKE 'D63%' OR
                icd_code LIKE 'D64%') THEN 1
        ELSE 0
    END
) AS anemia

FROM mimiciv_hosp.diagnoses_icd
GROUP BY subject_id;
""")
diagnoses

In [None]:
# Indicator variable for whether heart failure was diagnosed at the most recent visit.
# Steps performed by the query:
#   1. attach admittime to every diagnosis row (diagnoses_icd LEFT JOIN admissions),
#   2. keep only rows whose admittime equals the subject's latest admittime,
#   3. flag the subject with 1 if any remaining code is ICD-9 428* or ICD-10 I50*, else 0.
# Only subjects that appear in diagnoses_icd get a row.
heart_failure = q("""
WITH visits AS (
  SELECT subject_id, hadm_id, icd_code, icd_version, admittime
  FROM mimiciv_hosp.diagnoses_icd
  LEFT JOIN mimiciv_hosp.admissions
  USING (subject_id, hadm_id)
),
last_visit_time AS (
  SELECT *,
    MAX(admittime) OVER (PARTITION BY subject_id) AS last_visit_time
  FROM visits
),
last_visit AS (
  SELECT subject_id, icd_code, icd_version
  FROM last_visit_time
  WHERE admittime = last_visit_time
)
SELECT subject_id, 
  MAX(CASE 
        WHEN icd_version = 9 AND icd_code LIKE '428%' THEN 1
        WHEN icd_version = 10 AND icd_code LIKE 'I50%' THEN 1
        ELSE 0
    END
  ) AS heart_failure
FROM last_visit
GROUP BY subject_id;
""")
heart_failure

## Medications

In [None]:
# Create indicator variables for whether a patient had a pharmacy entry for a
# medication that started before the admit time of their most recent visit.
# - last_visit is de-duplicated: diagnoses_icd has one row per code, which would
#   otherwise repeat every pharmacy row once per diagnosis.
# - Grouping uses la.subject_id: ph.subject_id is NULL for subjects without a
#   matching pharmacy row, which previously collapsed them into one NULL group.
#   Those subjects now get their own row with all indicators 0.
# NOTE(review): the next cell rebinds `medications` to the time-independent
# version, so under Run All that is the table the final join sees.
medications = q("""
WITH visits AS (
  SELECT subject_id, admittime
  FROM mimiciv_hosp.diagnoses_icd
  LEFT JOIN mimiciv_hosp.admissions
  USING (subject_id, hadm_id)
),
last_visit_time AS (
  SELECT *,
    MAX(admittime) OVER (PARTITION BY subject_id) AS last_visit_time
  FROM visits
),
last_visit AS (
  SELECT DISTINCT subject_id, admittime
  FROM last_visit_time
  WHERE admittime = last_visit_time
)
SELECT la.subject_id,
    MAX(IF(LOWER(medication) LIKE '%enalapril%', 1, 0))             AS enalapril,
    MAX(IF(LOWER(medication) LIKE '%lisinopril%', 1, 0))            AS lisinopril,
    MAX(IF(LOWER(medication) LIKE '%ramipril%', 1, 0))              AS ramipril,

    MAX(IF(LOWER(medication) LIKE '%carvedilol%', 1, 0))            AS carvedilol,
    MAX(IF(LOWER(medication) LIKE '%metoprolol succinate%', 1, 0))  AS metoprolol_succinate,
    MAX(IF(LOWER(medication) LIKE '%bisoprolol%', 1, 0))            AS bisoprolol,

    MAX(IF(LOWER(medication) LIKE '%furosemide%', 1, 0))            AS furosemide,
    MAX(IF(LOWER(medication) LIKE '%bumetanide%', 1, 0))            AS bumetanide,
    MAX(IF(LOWER(medication) LIKE '%spironolactone%', 1, 0))        AS spironolactone,

    MAX(IF(LOWER(medication) LIKE '%warfarin%', 1, 0))              AS warfarin,
    MAX(IF(LOWER(medication) LIKE '%apixaban%', 1, 0))              AS apixaban,
    MAX(IF(LOWER(medication) LIKE '%rivaroxaban%', 1, 0))           AS rivaroxaban

FROM last_visit AS la
LEFT JOIN mimiciv_hosp.pharmacy AS ph
    ON ph.subject_id = la.subject_id
    AND ph.starttime < la.admittime
GROUP BY la.subject_id;
""")

In [None]:
# Create indicator variables for whether a patient ever had a pharmacy entry whose
# medication name contains the drug name (case-insensitive substring match).
# Does not consider time of prescription. One row per subject_id present in
# mimiciv_hosp.pharmacy. This assignment replaces the `medications` value bound
# by the previous cell.
medications = q("""
SELECT subject_id,
    MAX(IF(LOWER(medication) LIKE '%enalapril%', 1, 0))             AS enalapril,
    MAX(IF(LOWER(medication) LIKE '%lisinopril%', 1, 0))            AS lisinopril,
    MAX(IF(LOWER(medication) LIKE '%ramipril%', 1, 0))              AS ramipril,
    
    MAX(IF(LOWER(medication) LIKE '%carvedilol%', 1, 0))            AS carvedilol,
    MAX(IF(LOWER(medication) LIKE '%metoprolol succinate%', 1, 0))  AS metoprolol_succinate,
    MAX(IF(LOWER(medication) LIKE '%bisoprolol%', 1, 0))            AS bisoprolol,
    
    MAX(IF(LOWER(medication) LIKE '%furosemide%', 1, 0))            AS furosemide,
    MAX(IF(LOWER(medication) LIKE '%bumetanide%', 1, 0))            AS bumetanide,
    MAX(IF(LOWER(medication) LIKE '%spironolactone%', 1, 0))        AS spironolactone,
    
    MAX(IF(LOWER(medication) LIKE '%warfarin%', 1, 0))              AS warfarin,
    MAX(IF(LOWER(medication) LIKE '%apixaban%', 1, 0))              AS apixaban,
    MAX(IF(LOWER(medication) LIKE '%rivaroxaban%', 1, 0))           AS rivaroxaban

FROM mimiciv_hosp.pharmacy
GROUP BY subject_id;
""")
medications

## Vital Signs / Lab Results

### Lab Result Keys

In [None]:
# List d_labitems rows whose label contains "oxygen" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%oxygen%'
""")

In [None]:
# List blood lab items whose label contains "bnp" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%bnp%' AND fluid = 'Blood'
""")

In [None]:
# List lab items whose label contains "creatinine" and that are either blood
# items or also mention "serum" in the label.
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%creatinine%' AND (fluid = 'Blood' OR LOWER(label) LIKE '%serum%');
""")

In [None]:
# List blood lab items whose label contains "urea" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%urea%' AND fluid = 'Blood';
""")

In [None]:
# List blood lab items whose label contains "sodium" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%sodium%' AND fluid = 'Blood';
""")

In [None]:
# List blood lab items whose label contains "potassium" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%potassium%' AND fluid = 'Blood';
""")

In [None]:
# List lab items whose label contains "asparate" (any fluid).
# NOTE(review): "asparate" is not the usual spelling of "aspartate"; presumably it
# matches the label as stored in d_labitems — verify before changing the pattern.
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%asparate%';
""")

In [None]:
# List lab items whose label contains "alanine" (any fluid).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%alanine%';
""")

In [None]:
# List lab items whose label contains "tropo" (any fluid), sorted by label.
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%tropo%'
ORDER BY label;
""")

In [None]:
# List blood lab items labelled exactly "hemoglobin" (case-insensitive); the
# LIKE pattern has no wildcards, so longer labels are not matched.
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE 'hemoglobin' AND fluid = 'Blood';
""")

In [None]:
# List blood lab items whose label contains "hematocrit" or "hct".
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE (LOWER(label) LIKE '%hematocrit%' OR LOWER(label) LIKE '%hct%') AND fluid = 'Blood';
""")

In [None]:
# List lab items whose label contains "mcv" (any fluid).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%mcv%';
""")

In [None]:
# List lab items labelled exactly "mch" (case-insensitive; no wildcards, so
# "MCHC" is not matched).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE 'mch';
""")

In [None]:
# List lab items labelled exactly "mchc" (case-insensitive; no wildcards).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE 'mchc';
""")

In [None]:
# List lab items labelled exactly "rdw" (case-insensitive; no wildcards).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE 'rdw';
""")

In [None]:
# List lab items whose label contains "platelet count" (any fluid).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE LOWER(label) LIKE '%platelet count%';
""")

In [None]:
# List blood lab items whose label contains "white" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE fluid = 'Blood' AND LOWER(label) LIKE '%white%';
""")

In [None]:
# List blood lab items labelled exactly "red blood cells" (case-insensitive;
# no wildcards).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE fluid = 'Blood' AND LOWER(label) LIKE 'red blood cells';
""")

In [None]:
# List blood lab items whose label contains "neutrophil" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE fluid = 'Blood' AND LOWER(label) LIKE '%neutrophil%';
""")

In [None]:
# List blood lab items whose label contains "lympho" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE fluid = 'Blood' AND LOWER(label) LIKE '%lympho%';
""")

In [None]:
# List blood lab items whose label contains "monocyte" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE fluid = 'Blood' AND LOWER(label) LIKE '%monocyte%';
""")

In [None]:
# List blood lab items whose label contains "eosinophil" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE fluid = 'Blood' AND LOWER(label) LIKE '%eosinophil%';
""")

In [None]:
# List blood lab items whose label contains "basophil" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE fluid = 'Blood' AND LOWER(label) LIKE '%basophil%';
""")

In [None]:
# List blood lab items whose label contains "immature" (case-insensitive).
q("""
SELECT * 
FROM mimiciv_hosp.d_labitems 
WHERE fluid = 'Blood' AND LOWER(label) LIKE '%immature%';
""")

### Blood Pressure

In [None]:
# Latest blood-pressure reading per subject from the OMR table, split into
# systolic and diastolic components.
# - Rows are ordered by (chartdate, seq_num) within each subject; the highest
#   row number is the most recent reading.
# - result_value is parsed as "<systolic>/<diastolic>". TRY_CAST yields NULL
#   when a value does not match that shape, instead of aborting the whole query.
# - last_measure selects only the columns it needs; the previous
#   `SELECT *, subject_id, result_value` emitted those two columns twice.
blood_pressure = q("""
WITH subject_measure AS (
    SELECT subject_id, chartdate, seq_num, result_value,
        ROW_NUMBER() OVER (PARTITION BY subject_id ORDER BY chartdate, seq_num) AS measure_num
    FROM mimiciv_hosp.omr
    WHERE LOWER(result_name) LIKE 'blood pressure%' AND result_value NOT NULL
),
max_measure AS (
    SELECT *, MAX(measure_num) OVER (PARTITION BY subject_id) AS max_num
    FROM subject_measure
),
last_measure AS (
    SELECT subject_id, result_value
    FROM max_measure
    WHERE measure_num = max_num
)
SELECT subject_id,
    TRY_CAST(regexp_extract(result_value, '(\\d+)/\\d+', 1) AS SMALLINT) AS bp_systolic,
    TRY_CAST(regexp_extract(result_value, '\\d+/(\\d+)', 1) AS SMALLINT) AS bp_diastolic
FROM last_measure
""")
blood_pressure

### Oxygen

In [None]:
# Latest non-null value per subject, pooled over itemids 50816 and 50817.
oxygen = lab_table_last((50816, 50817), "oxygen")
oxygen

### NT-proBNP

In [None]:
# Latest non-null value per subject for itemid 50963.
nt_probnp = lab_table_last(50963, "nt_probnp")
nt_probnp

### Creatinine

In [None]:
# Latest non-null value per subject, pooled over the four listed itemids.
creatinine = lab_table_last((50912, 52024, 52546, 51081), "creatinine")
creatinine

### Blood Urea Nitrogen

In [None]:
# Latest non-null value per subject, pooled over itemids 51006 and 52647.
bun = lab_table_last((51006, 52647), "bun")
bun

### Sodium

In [None]:
# Latest non-null value per subject, pooled over the four listed itemids.
sodium = lab_table_last((50824, 50983, 52455, 52623), "sodium")
sodium

### Potassium

In [None]:
# Latest non-null value per subject, pooled over the four listed itemids.
potassium = lab_table_last((50822, 50971, 52452, 52610), "potassium")
potassium

### AST

In [None]:
# Latest non-null value per subject, pooled over itemids 50878 and 53088.
# Note: the variable name `ast` is also the name of a standard-library module;
# that module is not imported in the visible cells.
ast = lab_table_last((50878, 53088), "ast")
ast

### ALT

In [None]:
# Latest non-null value per subject, pooled over itemids 50861 and 53084.
alt = lab_table_last((50861, 53084), "alt")
alt

### Troponin

In [None]:
# Latest non-null value per subject for itemid 51003.
troponin = lab_table_last(51003, "troponin")
troponin

### Complete Blood Count

#### Hemoglobin

In [None]:
# Latest non-null value per subject, pooled over the three listed itemids.
hgb = lab_table_last((50811, 51222, 51640), "hgb")
hgb

#### Hematocrit

In [None]:
# Latest non-null value per subject, pooled over the five listed itemids.
hct = lab_table_last((50810, 51221, 51638, 51639, 52028), "hct")
hct

#### MCV

In [None]:
# Latest non-null value per subject, pooled over itemids 51250 and 51691.
mcv = lab_table_last((51250, 51691), "mcv")
mcv

#### MCH

In [None]:
# Latest non-null value per subject for itemid 51248.
mch = lab_table_last(51248, "mch")
mch

#### MCHC

In [None]:
# Latest non-null value per subject for itemid 51249.
mchc = lab_table_last(51249, "mchc")
mchc

#### RDW

In [None]:
# Latest non-null value per subject for itemid 51277.
rdw = lab_table_last(51277, "rdw")
rdw

#### Platelet Count

In [None]:
# Latest non-null value per subject, pooled over itemids 51265 and 51704.
# Note: `plt` is also the conventional alias for matplotlib.pyplot; matplotlib
# is not imported in the visible cells.
plt = lab_table_last((51265, 51704), "plt")
plt

#### White Blood Cell Count

In [None]:
# Latest non-null value per subject, pooled over the three listed itemids.
# NOTE(review): this uses lab_table_last, so no '#/uL' rescaling is applied
# here — confirm that these itemids all report in the same unit.
wbc = lab_table_last((51301, 51755, 51756), "wbc")
wbc

#### Red Blood Cell Count

In [None]:
# Latest non-null value per subject for itemid 51279.
rbc = lab_table_last(51279, "rbc")
rbc

#### Neutrophil

In [None]:
# Latest value per subject. The "_c" table goes through lab_table_last_wbc, which
# divides values recorded in '#/uL' by 1000; the "_p" table is taken as recorded.
# (Presumably "_p" = percentage and "_c" = absolute count — confirm.)
neutrophil_p = lab_table_last(51256, "neutrophil_p")
neutrophil_c = lab_table_last_wbc((52075, 53133), "neutrophil_c")

#### Lymphocyte

In [None]:
# Latest value per subject; the "_c" table rescales '#/uL' values by 1/1000
# via lab_table_last_wbc, the "_p" table is taken as recorded.
lymphocyte_p = lab_table_last((51244, 51245, 51690), "lymphocyte_p")
lymphocyte_c = lab_table_last_wbc((51133, 52769, 53132), "lymphocyte_c")

#### Monocyte

In [None]:
# Latest value per subject; the "_c" table rescales '#/uL' values by 1/1000
# via lab_table_last_wbc, the "_p" table is taken as recorded.
monocyte_p = lab_table_last(51254, "monocyte_p")
monocyte_c = lab_table_last_wbc((51253, 52074), "monocyte_c")

#### Eosinophil

In [None]:
# Latest value per subject; the "_c" table rescales '#/uL' values by 1/1000
# via lab_table_last_wbc, the "_p" table is taken as recorded.
eosinophil_p = lab_table_last(51200, "eosinophil_p")
eosinophil_c = lab_table_last_wbc((51199, 52073), "eosinophil_c")

#### Basophil

In [None]:
# Latest value per subject; the "_c" table rescales '#/uL' values by 1/1000
# via lab_table_last_wbc, the "_p" table is taken as recorded.
basophil_p = lab_table_last(51146, "basophil_p")
basophil_c = lab_table_last_wbc(52069, "basophil_c")

#### Immature Granulocytes

In [None]:
# Latest non-null value per subject for itemid 52135.
immature_gran_p = lab_table_last(52135, "immature_gran_p")

# Join Tables

In [None]:
# Assemble the analysis table: start from heart_failure (one row per subject with
# at least one diagnosis row) and LEFT JOIN every feature table on subject_id, so
# subjects lacking a given feature keep a row with NULLs for that feature.
# The table names in the SQL are the Python variables bound in earlier cells;
# this relies on DuckDB resolving them from the notebook namespace
# (presumably via replacement scans — confirm for the installed DuckDB version).
# All of those cells must therefore have been run before this one.
final = q("""
SELECT *
FROM heart_failure
    LEFT JOIN diagnoses
    USING (subject_id)

    LEFT JOIN demographics
    USING (subject_id)
    
    LEFT JOIN medications
    USING (subject_id)
    
    LEFT JOIN blood_pressure
    USING (subject_id)
    
    LEFT JOIN oxygen
    USING (subject_id)
          
    LEFT JOIN nt_probnp
    USING (subject_id)
    
    LEFT JOIN creatinine
    USING (subject_id)
          
    LEFT JOIN bun
    USING (subject_id)
    
    LEFT JOIN sodium
    USING (subject_id)
    
    LEFT JOIN potassium
    USING (subject_id)
    
    LEFT JOIN ast
    USING (subject_id)
    
    LEFT JOIN alt
    USING (subject_id)
    
    LEFT JOIN troponin
    USING (subject_id)
    
    LEFT JOIN hgb
    USING (subject_id)
    
    LEFT JOIN hct
    USING (subject_id)
    
    LEFT JOIN mcv
    USING (subject_id)
    
    LEFT JOIN mch
    USING (subject_id)
    
    LEFT JOIN mchc
    USING (subject_id)
    
    LEFT JOIN rdw
    USING (subject_id)
    
    LEFT JOIN plt
    USING (subject_id)
    
    LEFT JOIN wbc
    USING (subject_id)
    
    LEFT JOIN rbc
    USING (subject_id)
    
    LEFT JOIN neutrophil_p
    USING (subject_id)
    
    LEFT JOIN neutrophil_c
    USING (subject_id)
    
    LEFT JOIN lymphocyte_p
    USING (subject_id)
    
    LEFT JOIN lymphocyte_c
    USING (subject_id)
    
    LEFT JOIN monocyte_p
    USING (subject_id)
    
    LEFT JOIN monocyte_c
    USING (subject_id)
    
    LEFT JOIN eosinophil_p
    USING (subject_id)
    
    LEFT JOIN eosinophil_c
    USING (subject_id)
    
    LEFT JOIN basophil_p
    USING (subject_id)
    
    LEFT JOIN basophil_c
    USING (subject_id)
    
    LEFT JOIN immature_gran_p
    USING (subject_id)
;""")

In [None]:
# Write the joined table to a zstd-compressed Parquet file in the working
# directory; the "Clean Data" section reads this file back with polars.
q("""
COPY final
    TO 'mimic4_hosp_dataset_patients1.parquet'
    (FORMAT 'parquet', CODEC 'zstd');
""")

# Clean Data

In [None]:
import polars as pl
import polars.selectors as cs
import pandas as pd

In [None]:
# Load the joined dataset written by the COPY cell into a polars DataFrame.
data = pl.read_parquet("mimic4_hosp_dataset_patients1.parquet")
# mimic4_hosp_dataset_patients.parquet: converted dtypes
# NOTE(review): the line above names a file without the "1" suffix; no visible
# cell writes or reads that file — confirm whether the note is still relevant.

In [None]:
# NOTE(review): no visible cell writes "medications.parquet", so this cell fails
# on a fresh run unless that file already exists. It also rebinds `medications`
# (previously the result of q(...)) to a polars DataFrame.
medications = pl.read_parquet("medications.parquet")

In [None]:
# Show the columns of `data` that have the same names as the columns of the
# reloaded `medications` frame (for side-by-side comparison with the next cell).
data.select(medications.columns)

In [None]:
# Display the medications frame loaded from "medications.parquet".
medications

In [None]:
# Summary statistics of the raw joined dataset, before cleaning.
data.describe()

In [None]:
# Clean the joined dataset:
#   1. downcast integer columns (except ids and blood pressure) to Int8,
#   2. drop columns with excessive missing values,
#   3. null out implausibly low blood-pressure readings,
#   4. collapse race into broader categories,
#   5. reorder columns: strings, then integers, then floats.
data2 = data.cast(
    {cs.integer() & ~cs.contains("_id") & ~cs.contains("bp_"): pl.Int8}
).drop(
    # Columns with excessive missing values
    cs.contains("_c") | pl.col("immature_gran_p") | pl.col("nt_probnp") | pl.col("oxygen") | pl.col("troponin")
).with_columns(
    # Remove implausible values. Use a native when/then/otherwise expression:
    # `pl.Null` is a dtype class, not a null value, so returning it from a
    # map_elements lambda does not produce a null. Existing nulls stay null
    # because a null comparison falls through to `otherwise`.
    pl.when(pl.col("bp_systolic") >= 60)
      .then(pl.col("bp_systolic"))
      .otherwise(None)
      .alias("bp_systolic"),
    pl.when(pl.col("bp_diastolic") >= 40)
      .then(pl.col("bp_diastolic"))
      .otherwise(None)
      .alias("bp_diastolic"),

    # Collapse race categories using US Census categories
    pl.col("race"
        ).str.replace(r'WHITE.*|PORTUGUESE', 'WHITE/EUROPEAN'
        ).str.replace(r'ASIAN.*', 'ASIAN'
        ).str.replace(r'BLACK.*', 'BLACK'
        ).str.replace(r'HISPANIC.*|SOUTH AMERICAN.*', 'HISPANIC/LATINO/SOUTH AMERICAN'
        ).str.replace(r'OTHER.*|MULTIPLE.*', 'OTHER'
        ).str.replace(r'AMERICAN INDIAN.*|NATIVE.*', 'NATIVE AMERICAN/PACIFIC ISLANDER'
    )
).select(
    cs.string(),
    cs.integer(),
    cs.float()
)

In [None]:
# Summary statistics after cleaning, to check the effect of the steps above.
data2.describe()

In [None]:
# Persist the cleaned dataset; the "Multiple Imputation" section reloads it.
data2.write_parquet("mimic4-hosp_hf-cleaned.parquet")

# Multiple Imputation

In [None]:
import pandas as pd
import polars as pl
import polars.selectors as cs
import miceforest as mf

In [None]:
# Reload the cleaned dataset and mark categorical variables before imputation.
data2 = pl.read_parquet("mimic4-hosp_hf-cleaned.parquet")

# String columns become Categorical. Integer columns other than blood pressure,
# age and subject_id (i.e. the 0/1 indicator columns) are cast to String and
# then to Categorical. The result is converted to a pandas DataFrame.
data3 = data2.with_columns(
    cs.string().cast(pl.Categorical),
    (cs.integer() & ~cs.starts_with("bp_") & ~cs.starts_with("age") & ~cs.starts_with("subject_id")).cast(pl.String).cast(pl.Categorical),
).to_pandas()

# Save the categorical version so the imputation can start from this file.
data3.to_parquet("mimic4-hosp_hf-cleaned_categorical.parquet")

In [None]:
# Reload the categorical dataset from disk (lets the imputation section be run
# without re-executing the conversion cell above).
data3 = pd.read_parquet("mimic4-hosp_hf-cleaned_categorical.parquet")

In [None]:
# Preview the first rows of the imputation input.
data3.head()

In [None]:
# Build the imputation kernel directly on the cleaned data (identifier dropped so
# it is not used as a predictor).
# The data is deliberately NOT passed through mf.ampute_data first: that helper
# randomly removes observed values (it exists to create missingness in complete
# demo datasets), so using it here would discard real measurements and replace
# them with imputed ones. The cleaned data already has genuine missing values.
imputation_input = data3.drop(columns=["subject_id"])
kernel = mf.ImputationKernel(
  imputation_input,
  datasets=5,
  random_state=100
)

In [None]:
# Run 5 MICE iterations, then persist the fitted kernel to disk.
# min_data_in_leaf=100 is forwarded as a keyword argument to mice()
# (presumably a parameter of the underlying tree model — confirm).
kernel.mice(5, min_data_in_leaf=100)
kernel.save_kernel("miceforest_kernel")

In [None]:
# Reload the saved kernel so the cells below can run without repeating mice().
kernel = mf.load_kernel("miceforest_kernel")

In [None]:
# Display the kernel's summary representation.
kernel

In [None]:
# Materialise one completed dataset; called without a dataset index
# (presumably returns the first imputed dataset — confirm against miceforest docs).
completed_dataset = kernel.complete_data()

In [None]:
# Write each imputed dataset (indices 0-4) to its own Parquet file.
# range(5) must stay in sync with datasets=5 used when the kernel was built.
for n in range(5): 
    kernel.complete_data(n).to_parquet(f"mimic4-hosp_hf-multiple_imputed_{n}.parquet")