In [1]:
import os

import numpy as np
import pandas as pd
import sqlalchemy
from sqlqueries_unc_preeclampsia import *

### 1. load and massage the dataset

In [2]:
# Read the database connection URL from an environment variable instead of hardcoding it.
# NOTE(review): CONN_STRING is not defined in this notebook; presumably it comes from the
# star import and holds the *name* of the environment variable. If the variable is
# literally named CONN_STRING, this should be os.getenv('CONN_STRING') -- confirm.
connection_string = os.getenv(CONN_STRING)
if not connection_string:
    # Fail with an explicit message rather than a less specific error from create_engine(None).
    # Only the variable name is reported, never its (secret) value.
    raise RuntimeError(f'environment variable {CONN_STRING!r} is not set or is empty')
engine = sqlalchemy.create_engine(connection_string)

In [3]:
def _read_query(sql):
    """Execute one SQL string against the global `engine` and return the rows as a DataFrame."""
    return pd.read_sql(sql, con=engine)


# One extract per source query (the *_sql_string names are presumably supplied by the
# sqlqueries_unc_preeclampsia star import).
dat_bp = _read_query(blood_pressure_sql_string)
dat_vital = _read_query(vital_sql_string)
dat_obs = _read_query(obs_sql_string)
dat_preeclampsia = _read_query(preeclampsia_sql_string)
dat_med = _read_query(get_rx_sql_string)
dat_age = _read_query(age_sql_string)
dat_race = _read_query(race_sql_string)

In [None]:
# Reshape the long BP table (one row per birth and time period) to one row per birth,
# one bp_cat column per period, so it can be joined onto the birth-level data.
bp_wide = dat_bp.pivot(index='BIRTHID', columns='time_period', values='bp_cat')
# Positional rename. NOTE(review): assumes the pivoted time_period columns come out in
# exactly this order (before 20 wk, 20 wk -> delivery, delivery, 90 d after) -- confirm
# against the values produced by blood_pressure_sql_string.
bp_wide.columns = ['before_20wk', '20wk_delivery', 'during_delivery', '90d_after_delivery']
dat_bp = bp_wide.reset_index()
# Display only; the identifier is hidden from the rendered output.
dat_bp.drop(columns=['BIRTHID'])

In [None]:
# Attach the per-period BP flags to every birth in the preeclampsia extract; the left join
# keeps births with no BP reading (their period flags are NaN).
dat = dat_preeclampsia.merge(dat_bp, on='BIRTHID', how='left')
dat.drop(columns=['BIRTHID'])

In [None]:
# Build the binary preeclampsia label. A birth is positive (1) when either
#   1. it has a diagnosis code of O14 or O15 (non-null earliest_diagnosis_date), or
#   2. it has a lab confirmation date AND high blood pressure in at least one of the
#      periods from 20 weeks through 90 days after delivery.
# The period flags come from a left merge, so they are NaN for periods with no BP reading.
# A row-wise sum (skipna) treats a missing period as "no high BP". Adding the columns with
# `+` would instead let a single NaN turn the whole condition False (NaN >= 1 is False).
post_20wk_bp_cols = ['20wk_delivery', 'during_delivery', '90d_after_delivery']
has_diagnosis_code = dat['earliest_diagnosis_date'].notna()
has_lab_and_high_bp = (
    dat['earliest_lab_confirm_date'].notna()
    & (dat[post_20wk_bp_cols].sum(axis=1) >= 1)
)
dat['diagnosis'] = (has_diagnosis_code | has_lab_and_high_bp).astype(int)
dat = dat.drop(columns=['earliest_diagnosis_date', 'latest_diagnosis_date',
                        'earliest_lab_confirm_date', 'latest_lab_confirm_date'])
dat.drop(columns=['BIRTHID'])

In [None]:
# Class balance of the label.
# NOTE(review): the positive share looked higher than expected -- worth re-checking the
# label definition above.
dat['diagnosis'].value_counts()

In [8]:
# Widen the observation table to one row per birth. The (statistic, observation name)
# column pairs are flattened by substituting the name for 'VALUE', e.g. min_VALUE of X -> min_X.
obs_stats = ['min_VALUE', 'max_VALUE', 'mean_VALUE', 'median_VALUE']
dat_obs = dat_obs.pivot(index='BIRTHID', columns=['RAW_OBSCLIN_NAME'], values=obs_stats)
dat_obs.columns = [stat_col.replace('VALUE', obs_name)
                   for stat_col, obs_name in dat_obs.columns.to_flat_index()]

In [None]:
# Fraction of missing values in each column of dat_vital.
dat_vital.isnull().mean()

In [None]:
# Preview dat_vital without the identifier column (display only, dat_vital is unchanged).
dat_vital.drop(columns='BIRTHID')

In [11]:
# Fill the small share of missing BMI values from the recorded weight and mean height.
# ASSUMPTION (review): WEIGHT is in pounds and HEIGHT in inches (imperial BMI formula);
# confirm against vital_sql_string.
# BMI = k * weight[lb] / height[in]**2 with k = 0.45359237 / 0.0254**2 ~= 703.07
# (1 lb = 0.45359237 kg, 1 in = 0.0254 m); the previous factor 705 overstated BMI ~0.3%.
LB_PER_SQ_IN_TO_KG_PER_SQ_M = 0.45359237 / 0.0254 ** 2
# Non-positive heights become NaN so they yield a missing BMI rather than inf,
# which StandardScaler would reject downstream.
height_sq_in = dat_vital['mean_HEIGHT'].where(dat_vital['mean_HEIGHT'] > 0) ** 2
for measure in ['max', 'min', 'mean', 'median']:
    derived_bmi = dat_vital[f'{measure}_WEIGHT'] / height_sq_in * LB_PER_SQ_IN_TO_KG_PER_SQ_M
    dat_vital[f'{measure}_BMI'] = dat_vital[f'{measure}_BMI'].fillna(derived_bmi)

In [None]:
# Preview dat_vital after the BMI fill, identifier hidden (display only).
dat_vital.drop(columns='BIRTHID')

In [None]:
# Left-join the observation, vital, age and race tables onto the labelled births, in that
# order. dat_obs carries BIRTHID as its index after the pivot; `on=` resolves the key
# against either a column or an index level of that name.
for right_table in (dat_obs, dat_vital, dat_age, dat_race):
    dat = dat.merge(right_table, on='BIRTHID', how='left')
dat.shape

In [None]:
# First rows of the merged feature table, identifier hidden.
dat.drop(columns='BIRTHID').head(5)

In [15]:
# Medication feature columns: everything in the medication extract except the join key.
med_col = dat_med.columns.tolist()
med_col.remove('BIRTHID')

In [None]:
# Attach the medication columns; births absent from the extract get NaN (filled next).
dat = dat.merge(dat_med, on='BIRTHID', how='left')
dat.drop(columns=['BIRTHID'])

In [17]:
# Births with no row in the medication extract get 0 in every medication column
# (presumably meaning "no such prescription" -- confirm against get_rx_sql_string).
dat[med_col] = dat[med_col].fillna(0)

In [None]:
# Label proportions for Black (is_black == 1) vs non-Black (is_black == 0) patients,
# shown as a pair. Observation: the two rates are not very different.
tuple(
    dat.loc[dat['is_black'] == flag, 'diagnosis'].value_counts(normalize=True)
    for flag in (1, 0)
)

### 2. Take a look at the NaNs

In [None]:
# Share of missing values per column, most-missing first. Keep it as a labelled Series:
# a bare list of fractions cannot be mapped back to the feature names.
dat.isna().mean().sort_values(ascending=False)

In [None]:
# Spearman correlation of each feature with the label; show the 50 largest.
# - BIRTHID is an identifier, not a feature (and raises in pandas 2 if it is non-numeric).
# - Constant columns yield NaN, and sort_values puts NaN last, so without dropna the
#   tail(50) would be filled with NaNs instead of the strongest correlations.
(
    dat.drop(columns=['BIRTHID'])
    .corrwith(dat['diagnosis'], method='spearman')
    .dropna()
    .sort_values()
    .tail(50)
)

### 3. build the model

In [21]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

In [22]:
# Scale -> KNN-impute -> random forest. Scaling first means the KNN imputer's distances
# are computed on standardized features. class_weight='balanced' reweights classes
# inversely to their frequency.
# RANDOM_STATE pins the forest's bootstrap/feature sampling so grid-search scores and
# SHAP values are reproducible on Restart & Run All.
RANDOM_STATE = 42
pipeline = Pipeline([
    ('scale', StandardScaler()),
    ('imputer', KNNImputer()),
    ('clf', RandomForestClassifier(class_weight='balanced', random_state=RANDOM_STATE)),
])

In [23]:
# Specificity = TN / (TN + FP), i.e. the recall of the negative class (label 0).
from sklearn.metrics import recall_score, make_scorer
specificity = make_scorer(score_func=recall_score, pos_label=0)

In [24]:
# Tune the forest's min_samples_split jointly with the imputer's neighbour count.
# Five metrics are recorded per candidate. The final model is refit with the parameters
# that have the best mean cross-validated ROC AUC.
param_grid = {
    'clf__min_samples_split': [5, 10, 20, 50, 100],
    'imputer__n_neighbors': [5, 10, 30, 50],
}
scoring = {
    'f1': 'f1',
    'roc_auc': 'roc_auc',
    'sensitivity': 'recall',
    'precision': 'precision',
    'specificity': specificity,
}
clf = GridSearchCV(pipeline, param_grid=param_grid, scoring=scoring,
                   refit='roc_auc', verbose=3)

In [None]:
# First pass: every column except the identifier and the label. This keeps the
# 20wk_delivery / during_delivery / 90d_after_delivery flags, which feed directly into the
# label definition, so they leak the target (section 5 refits without them).
X = dat.drop(columns=['BIRTHID', 'diagnosis'])
# Discard features missing in more than 20% of rows; the KNN imputer fills the rest.
missing_share = X.isna().mean()
X = X.drop(columns=missing_share[missing_share > 0.2].index)
y = dat['diagnosis']
clf.fit(X, y)

In [None]:
# Winning hyper-parameters and their mean cross-validated score on the refit metric (ROC AUC).
(clf.best_params_, clf.best_score_)

### 4. initial look at feature importance

In [None]:
import shap

In [None]:
# Display the refitted best pipeline (scale -> KNN impute -> random forest) with its tuned parameters.
clf.best_estimator_

In [29]:
# Push X through the fitted preprocessing steps (scale, KNN impute) -- everything but the
# final classifier -- so SHAP explains the forest on the inputs it actually saw.
X_processed = clf.best_estimator_[:-1].transform(X)

In [30]:
# Tree SHAP over the fitted random-forest step of the best pipeline (additivity check off).
forest = clf.best_estimator_.named_steps['clf']
explainer = shap.TreeExplainer(forest)
shap_values = explainer.shap_values(X_processed, check_additivity=False)

In [None]:
# Summary plot of the 30 most influential features for the positive class (diagnosis == 1).
# Older shap releases return a list with one (n_samples, n_features) array per class;
# newer ones (0.45+) return a single (n_samples, n_features, n_classes) array, where
# shap_values[1] would silently select *sample* 1 instead of class 1.
if isinstance(shap_values, list):
    shap_positive = shap_values[1]
elif np.ndim(shap_values) == 3:
    shap_positive = shap_values[..., 1]
else:
    shap_positive = shap_values
shap.summary_plot(shap_positive, X_processed, max_display=30, feature_names=X.columns)

### 5. Remove the post-20-week blood-pressure columns (they are used to build the label, so keeping them leaks the target)

In [None]:
# Second pass: also drop the three post-20-week BP flags that define the label, leaving
# before_20wk as the only BP feature.
X = dat.drop(columns=['BIRTHID', 'diagnosis',
                      '20wk_delivery', 'during_delivery', '90d_after_delivery'])
# Discard features missing in more than 20% of rows; the KNN imputer fills the rest.
missing_share = X.isna().mean()
X = X.drop(columns=missing_share[missing_share > 0.2].index)
y = dat['diagnosis']
clf.fit(X, y)

In [None]:
# Winning hyper-parameters and mean cross-validated ROC AUC once the label-defining BP flags are gone.
(clf.best_params_, clf.best_score_)

In [35]:
# Re-derive the preprocessed matrix and the Tree SHAP values for the refitted model:
# every step except the last is preprocessing, the last is the random forest.
best_pipeline = clf.best_estimator_
X_processed = best_pipeline[:-1].transform(X)
explainer = shap.TreeExplainer(best_pipeline[-1])
shap_values = explainer.shap_values(X_processed, check_additivity=False)

In [None]:
# Summary plot of the 30 most influential features for the positive class (diagnosis == 1).
# Older shap releases return a list with one (n_samples, n_features) array per class;
# newer ones (0.45+) return a single (n_samples, n_features, n_classes) array, where
# shap_values[1] would silently select *sample* 1 instead of class 1.
if isinstance(shap_values, list):
    shap_positive = shap_values[1]
elif np.ndim(shap_values) == 3:
    shap_positive = shap_values[..., 1]
else:
    shap_positive = shap_values
shap.summary_plot(shap_positive, X_processed, max_display=30, feature_names=X.columns)