# Dataset Preparation for First ICU Prediction

## Imports & Inits

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../')

import math
import numpy as np
import pandas as pd
import spacy

import seaborn as sns
sns.set(style = 'darkgrid')

import matplotlib.pyplot as plt
%matplotlib inline

from pathlib import Path

from utils.splits import set_two_splits
from args import args
vars(args)

{'path': PosixPath('../data'),
 'workdir': PosixPath('../data/workdir'),
 'figdir': PosixPath('../data/workdir/figdir'),
 'raw_csv': PosixPath('../data/mimic_icu_pred_raw_dataset.csv'),
 'proc_csv': PosixPath('../data/mimic_icu_pred_proc_dataset.csv'),
 'cols': ['imminent_label',
  'discharge_label',
  'scispacy_note',
  'charttime',
  'intime'],
 'dates': ['charttime', 'intime'],
 'imminent_threshold': {'lr': 0.47, 'rf': 0.32, 'gbm': 0.3},
 'discharge_threshold': {'lr': 0.48, 'rf': 0.37, 'gbm': 0.39}}

## Data Processing

1. Load in the data
2. Drop duplicates
3. Merge `category`, `description`, and `text` into a new column called `note`
4. Collapse `ethnicity` into coarse groups (`white`, `black`, `hispanic`, `asian`, `unknown`)
5. Tokenize the text using `scispacy` and create a new column called `scispacy_note` to save the tokenized text
6. Save a csv file onto disk

In [None]:
def tokenize_text(text):
  """Run `text` through the global `nlp` pipeline and return its tokens joined by single spaces."""
  doc = nlp(text)
  return ' '.join(tok.text for tok in doc)

def group_eth(eth):
  """
  Collapse a free-text ethnicity value into one coarse group.

  The value is lower-cased and checked, in order, for the substrings
  'white', 'black', 'hispanic' and 'asian'; the first match is returned.
  Anything else returns 'unknown', including non-string input such as
  NaN or None, which would otherwise fail on `.lower()`.
  """
  if not isinstance(eth, str):
    return 'unknown'

  lowered = eth.lower()
  # Order matters: a value mentioning several groups is assigned to the first one listed.
  for group in ('white', 'black', 'hispanic', 'asian'):
    if group in lowered:
      return group
  return 'unknown'

In [None]:
# Load the raw dataset from the path configured in `args.raw_csv`.
df = pd.read_csv(args.raw_csv)
# Show (rows, columns) as the cell output.
df.shape

In [None]:
# Load the scispaCy model with the parser, NER and tagger components disabled;
# `tokenize_text` only reads the token texts from the pipeline output.
nlp = spacy.load('en_core_sci_md', disable=['parser', 'ner', 'tagger'])

df.drop_duplicates(inplace=True)
# Build the note as category / description / text, one per line.
# `na_rep=''` matters: with the default, `str.cat` returns NaN for every row in which
# any of the parts is missing, and `tokenize_text` cannot process a NaN note.
df['note'] = df['category'].str.cat([df['description'], df['text']], sep='\n', na_rep='')
df['ethnicity'] = df['ethnicity'].apply(group_eth)
df['scispacy_note'] = df['note'].apply(tokenize_text)
# The raw parts are now contained in `note`; `category` is kept as its own column.
df.drop(['text', 'description'], axis=1, inplace=True)
df.shape

In [None]:
# Persist the processed frame to `args.proc_csv` without writing the index column.
df.to_csv(args.proc_csv, index=False)

## Splitting

In [3]:
# Reload the processed dataset from disk so this section does not depend on the processing cells above.
df = pd.read_csv(args.proc_csv)
# Print the available columns and the (rows, columns) shape.
print(df.columns)
print(df.shape)

Index(['subject_id', 'hadm_id', 'icustay_id', 'admission_type', 'admittime',
       'dischtime', 'intime', 'outtime', 'charttime', 'icu_los', 'deathtime',
       'adm_icu_period', 'chart_icu_period', 'chartinterval', 'ethnicity',
       'dob', 'gender', 'admission_age', 'category', 'imminent_label',
       'discharge_label', 'note', 'scispacy_note'],
      dtype='object')
(75271, 23)


In [4]:
# Summary of the full dataset: admissions, notes (one row per note) and notes per admission.
n_admissions = df['hadm_id'].nunique()
n_notes = len(df)
notes_per_adm = df.groupby('hadm_id')['scispacy_note'].size().mean()

print(f"Number of admissions: {n_admissions}")
print(f"Number of Notes: {n_notes}")
print(f"Average number of notes per admission: {notes_per_adm:0.2f}")
# Sanity check: the plain ratio and the grouped mean must agree to two decimals.
assert round(n_notes / n_admissions, 2) == round(notes_per_adm, 2)

Number of admissions: 12877
Number of Notes: 75271
Average number of notes per admission: 5.85


In [5]:
from sklearn.model_selection import GroupShuffleSplit

In [6]:
# Column subset that is passed to the splitter as the data argument in the next cell.
split_df = df[['hadm_id', 'scispacy_note', 'imminent_label', 'discharge_label']]

In [9]:
# Group-aware shuffle split keyed on `hadm_id`, so all notes of an admission land on the same side.
splitter = GroupShuffleSplit(test_size=.15, n_splits=2, random_state=42)
# Only the first generated split is consumed.
train_idxs, test_idxs = next(splitter.split(split_df, groups=df['hadm_id']))

# The splitter returns positional indices, hence `iloc`.
train_df, test_df = df.iloc[train_idxs], df.iloc[test_idxs]

In [25]:
# Validate the split at admission level: 85/15 proportions and no admission on both sides.
n_adm_total = df['hadm_id'].nunique()
train_adms = set(train_df['hadm_id'].unique().tolist())
test_adms = set(test_df['hadm_id'].unique().tolist())

assert round(train_df['hadm_id'].nunique() / n_adm_total, 2) == 0.85
assert round(test_df['hadm_id'].nunique() / n_adm_total, 2) == 0.15
assert train_adms.isdisjoint(test_adms)

In [30]:
# Summary of the training split: admissions, notes and notes per admission.
train_n_adm = train_df['hadm_id'].nunique()
train_n_notes = len(train_df)
train_notes_per_adm = train_df.groupby('hadm_id')['scispacy_note'].size().mean()

print(f"Number of admissions: {train_n_adm}")
print(f"Number of Notes: {train_n_notes}")
print(f"Average number of notes per admission: {train_notes_per_adm:0.2f}")
# Sanity check: the plain ratio and the grouped mean must agree to two decimals.
assert round(train_n_notes / train_n_adm, 2) == round(train_notes_per_adm, 2)

Number of admissions: 10945
Number of Notes: 64241
Average number of notes per admission: 5.87


In [31]:
# Summary of the test split: admissions, notes and notes per admission.
test_n_adm = test_df['hadm_id'].nunique()
test_n_notes = len(test_df)
test_notes_per_adm = test_df.groupby('hadm_id')['scispacy_note'].size().mean()

print(f"Number of admissions: {test_n_adm}")
print(f"Number of Notes: {test_n_notes}")
print(f"Average number of notes per admission: {test_notes_per_adm:0.2f}")
# Sanity check: the plain ratio and the grouped mean must agree to two decimals.
assert round(test_n_notes / test_n_adm, 2) == round(test_notes_per_adm, 2)

Number of admissions: 1932
Number of Notes: 11030
Average number of notes per admission: 5.71


## Plots

In [None]:
# Reload the processed dataset so the plotting section starts from the on-disk state.
df = pd.read_csv(args.proc_csv)
# Show the available columns as the cell output.
df.columns

In [None]:
# Axis labels for the day intervals before ICU admission; the list position is used
# as the lookup key for `chartinterval` values further down.
intervals = (
  ['-1 ≤ t ≤ 0']
  + [f'-{day + 1} ≤ t ≤ -{day}' for day in range(1, 15)]
  + ['t ≤ -15']
)

### ICU LOS Histogram

In [None]:
# Single-column frame with the ICU length of stay; `df` has one row per note.
plot_df = df[['icu_los']]

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))

# Plain histogram (no KDE) of the ICU length of stay, cropped to the first 40 days.
sns.distplot(plot_df, kde=False, bins=80, ax=ax)
ax.set(xlabel='ICU Length of Stay (days)', ylabel='# notes', xlim=(0, 40))
# Push both axis labels away from the tick labels.
for axis in (ax.xaxis, ax.yaxis):
  axis.labelpad = 20
# ax.text(ax.get_xlim()[1]*0.50, ax.get_ylim()[1]*0.80, f"Min: {df['icu_los'].min()}, Avg: {df['icu_los'].mean(): 0.2f}, Max: {df['icu_los'].max()}", fontsize=15, ha='center', va='bottom')

In [None]:
# fig.savefig(args.figdir/'icu_los_hist.tif', dpi=300)

### Bar Plot of Notes Over Days

#### All Notes

In [None]:
# Count notes per `chartinterval` and label every count with its own interval text.
# Looking the label up through `intervals[...]` (as the per-category plot does) keeps
# counts and labels aligned even when an interval has no notes; assigning the whole
# `intervals` list positionally would fail whenever a group is missing.
plot_df = df.groupby(['chartinterval']).size().reset_index(name='n_notes')
plot_df['days'] = plot_df['chartinterval'].apply(lambda x: intervals[x])
plot_df = plot_df.drop(['chartinterval'], axis=1)

In [None]:
fig, ax = plt.subplots(figsize=(15, 8))
sns.barplot(x='days', y='n_notes', data=plot_df, ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
ax.set(xlabel='Time to ICU Admission (days)', ylabel='# notes')
# Write the count on top of each bar; the frame index doubles as the bar position.
for pos, n_notes in plot_df['n_notes'].items():
    ax.text(pos, n_notes, str(n_notes), color='black', ha='center', va='bottom')

In [None]:
# fig.savefig(args.figdir/'note_bp.tif', dpi=300)

#### By Category

In [None]:
def plot_intervals(ax, df, cat):
  """
  Draw a bar plot of `n_notes` per `days` label on `ax` for one note category.

  `df` must provide the columns `days` and `n_notes`; `cat` is only used in the
  title, together with the total number of notes. Each bar is annotated with
  its count. Axis labels are blanked so a shared label can be set on the figure.
  """
  sns.barplot(x='days', y='n_notes', data=df, ax=ax)
  ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
  ax.set_xlabel('')
  ax.set_ylabel('')
  total = df['n_notes'].sum()
  ax.set_title(f"Note Category: {cat}\n# notes: {total}")

  # Bars are drawn in row order, so the running position is the bar's x coordinate.
  for pos, count in enumerate(df['n_notes']):
    ax.text(pos, count, str(count), color='black', ha='center', va='bottom')

In [None]:
# Notes per (category, interval), with the interval number replaced by its display label.
plot_df = df.groupby(['category', 'chartinterval']).size().reset_index(name='n_notes')
plot_df['days'] = [intervals[k] for k in plot_df['chartinterval']]
plot_df = plot_df.drop(['chartinterval'], axis=1)

In [None]:
# Raw `category` values exactly as compared against the data (some carry a trailing
# space); the panel title is the same value with surrounding whitespace stripped.
panel_cats = [
  'Case Management ', 'Consult',
  'General', 'Nursing',
  'Nursing/other', 'Nutrition',
  'Pharmacy', 'Physician ',
  'Radiology', 'Rehab Services',
  'Respiratory ', 'Social Work',
]

fig, ax = plt.subplots(6, 2, figsize=(20, 50))
# `ax.flat` walks the 6x2 grid row by row, matching the order of `panel_cats`.
for panel, raw_cat in zip(ax.flat, panel_cats):
  cat_counts = plot_df.loc[plot_df['category'] == raw_cat, ['n_notes', 'days']]
  plot_intervals(panel, cat_counts, raw_cat.strip())

# Shared axis labels for the whole figure.
fig.text(0.5, 0.1, 'Time to ICU Admission (days)', ha='center')
fig.text(0.08, 0.5, '# notes', va='center', rotation='vertical')

plt.subplots_adjust(hspace=0.3)

In [None]:
# cats = sorted(list(df['category'].unique()))

# n = 0
# fig, ax = plt.subplots(1, 1, figsize=(10, 8))
# plot_intervals(ax, plot_df.loc[plot_df['category'] == cats[n], ['n_notes', 'days']], cats[n])
# ax.set_xlabel('Time to ICU Admission (days)')
# ax.set_ylabel('# notes')

In [None]:
# fig.savefig(args.figdir/'note_cats_bp.tif', dpi=300)

### Note Chart Time to ICU Admission Period Histogram

#### All Notes

In [None]:
# Category and chart-time-to-ICU period for every note; used by the overall and per-category histograms below.
plot_df = df[['category', 'chart_icu_period']]

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))

# Plain histogram (no KDE) of the note-to-ICU period, cropped to the first 60 days.
period = plot_df['chart_icu_period']
sns.distplot(period, kde=False, bins=80, ax=ax)
ax.set(xlabel='Period between Note Chart Time and ICU Admission Time (days)', ylabel='# notes')
ax.set_xlim(0, 60)

# ax.text(ax.get_xlim()[1]*0.50, ax.get_ylim()[1]*0.80, f"Min: {mdf['chart_icu_period'].min()}, Avg: {mdf['chart_icu_period'].mean(): 0.2f}, Max: {mdf['chart_icu_period'].max()}", fontweight='bold', fontsize=15, ha='center', va='bottom')

In [None]:
# fig.savefig(args.figdir/'note_icu_period_hist.tif', dpi=300)

#### By Category

In [None]:
def plot_period(ax, df, cat):
  """Draw a 10-bin histogram (no KDE) of `df` on `ax`, titled with the note category `cat`; axis labels are blanked."""
  sns.distplot(df, bins=10, kde=False, ax=ax)
  ax.set(xlabel='', ylabel='', title=f"Note Category: {cat}")

In [None]:
# Raw `category` values exactly as compared against the data (some carry a trailing
# space); the panel title is the same value with surrounding whitespace stripped.
panel_cats = [
  'Case Management ', 'Consult',
  'General', 'Nursing',
  'Nursing/other', 'Nutrition',
  'Pharmacy', 'Physician ',
  'Radiology', 'Rehab Services',
  'Respiratory ', 'Social Work',
]

fig, ax = plt.subplots(6, 2, figsize=(20, 50))
# `ax.flat` walks the 6x2 grid row by row, matching the order of `panel_cats`.
for panel, raw_cat in zip(ax.flat, panel_cats):
  cat_periods = plot_df.loc[plot_df['category'] == raw_cat, ['chart_icu_period']]
  plot_period(panel, cat_periods, raw_cat.strip())

# Shared axis labels for the whole figure.
fig.text(0.5, 0.11, 'Period between Note Chart Time and ICU Admission Time (days)', ha='center')
fig.text(0.08, 0.5, '# notes', va='center', rotation='vertical')

plt.subplots_adjust(hspace=0.1)

In [None]:
# cats = sorted(list(df['category'].unique()))

# n = 0
# fig, ax = plt.subplots(1, 1, figsize=(10, 8))
# plot_period(ax, plot_df.loc[plot_df['category'] == cats[n], ['chart_icu_period']], cats[n])
# ax.set_xlabel('Time to ICU Admission (days)')
# ax.set_ylabel('# notes')

In [None]:
# fig.savefig(args.figdir/'note_cat_icu_period_hist.tif', dpi=300)

### Hospital Admission to ICU Admission Period Histogram

In [None]:
# Single-column frame with the hospital-admission-to-ICU period; `df` has one row per note.
plot_df = df[['adm_icu_period']]

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))

# Plain histogram (no KDE) of the admission-to-ICU period, cropped to the first 70 days.
sns.distplot(plot_df, kde=False, bins=80, ax=ax)
ax.set(xlabel='Time between hospital admission and ICU admission (days)', ylabel='# notes')
ax.set_xlim(0, 70)
# ax.text(ax.get_xlim()[1]*0.50, ax.get_ylim()[1]*0.80, f"Min: {mdf['adm_icu_period'].min()}, Avg: {mdf['adm_icu_period'].mean(): 0.2f}, Max: {mdf['adm_icu_period'].max()}", fontweight='bold', fontsize=15, ha='center', va='bottom')    

In [None]:
# fig.savefig(args.figdir/'adm_icu_period_hist.tif', dpi=300)

### Note Length Histogram

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))
# Character count of every note, shown as a plain histogram (no KDE).
note_lengths = df['note'].apply(len)
sns.distplot(note_lengths, kde=False, bins=100, ax=ax)
ax.set_xlabel('Length of Note (characters)')
ax.set_ylabel('# notes')

In [None]:
# fig.savefig(args.figdir/'note_len_hist.tif', dpi=300)

### Imminent ICU Prediction Class Distribution

In [None]:
# Display names assigned positionally to the grouped `imminent_label` rows in the cells below.
# Presumably the sorted label values are -1, 0, 1 — TODO confirm against the data.
desc = ['not used', 'negative class', 'positive class']

#### Without Admissions

In [None]:
# Notes per imminent label, relabelled with the display names and reversed so the last group comes first.
plot_df = df.groupby(['imminent_label']).size().reset_index(name='n_notes')
plot_df['imminent_label'] = desc
plot_df = plot_df.reindex([2, 1, 0]).reset_index(drop=True)

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))
sns.barplot(x='imminent_label', y='n_notes', data=plot_df, ax=ax)
ax.set(xlabel='Imminent Class Label', ylabel='# notes')
# Write the count slightly above each bar; the frame index doubles as the bar position.
for pos, n_notes in plot_df['n_notes'].items():
  ax.text(pos + 0.05, n_notes + 50, str(n_notes), color='black', ha='right', va='bottom')

In [None]:
# fig.savefig(args.figdir/'imminent_label_bp.tif', dpi=300)

#### With Admissions

In [None]:
# Per imminent label: number of notes (p1) and number of distinct admissions (p2), joined on the label.
by_label = df.groupby(['imminent_label'])
p1 = by_label.size().reset_index(name='n_notes')
p2 = by_label['hadm_id'].nunique().reset_index()
p = p1.merge(p2, on=['imminent_label'])

In [None]:
# Replace the label values with the display names in `desc`; the assignment is positional.
p['imminent_label'] = desc

In [None]:
# Reverse the three rows and renumber them from zero.
p = p.reindex([2, 1, 0]).reset_index(drop=True)
p

In [None]:
# Long format for a grouped bar plot: one row per (label, measure) pair with its count.
legend_names = {'hadm_id': '# Admissions', 'n_notes': '# Notes'}
plot_df = pd.melt(
  p.rename(columns=legend_names),
  id_vars='imminent_label',
  var_name='Legend',
  value_name='counts',
)

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))

# Grouped bars: one group per imminent class label, one bar per 'Legend' value.
sns.barplot(x='imminent_label', y='counts', hue='Legend', data=plot_df, ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), ha='right')
ax.set_xlabel('Imminent Class Label')
# NOTE(review): the bars show every 'Legend' series, yet the axis is labelled '# notes' — confirm this is intended.
ax.set_ylabel('# notes')

# Annotate every bar with its count. `plot_df` is in melted (long) form, so the
# row index is split in two halves that share the same group positions.
for index, row in plot_df.iterrows():
    if index < len(plot_df)//2:
        # First half of the rows: text is shifted to the left of the group position.
        ax.text(index-0.13, row['counts']+50, str(row['counts']), color='black', ha='right', va='bottom')
    else:
        # Second half: wrap the index back onto the group positions and shift right.
        ax.text(index % (len(plot_df)//2)+0.25, row['counts']+50, str(row['counts']), color='black', ha='right', va='bottom')

In [None]:
# fig.savefig(args.figdir/'imminent_label_adms_bp.tif', dpi=300)

### Discharge ICU Prediction Class Distribution

In [None]:
# Display names assigned positionally to the grouped `discharge_label` rows in the cells below.
# This rebinds `desc`, which held the three imminent-label names before.
desc = ['negative class', 'positive class']

#### Without Admissions

In [None]:
# Notes per discharge label, relabelled with the display names and reversed so the last group comes first.
plot_df = df.groupby(['discharge_label']).size().reset_index(name='n_notes')
plot_df['discharge_label'] = desc
plot_df = plot_df.reindex([1, 0]).reset_index(drop=True)
plot_df

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))
sns.barplot(x='discharge_label', y='n_notes', data=plot_df, ax=ax)
ax.set(xlabel='5 Day Discharge Class Label', ylabel='# notes')
# Write the count slightly above each bar; the frame index doubles as the bar position.
for pos, n_notes in plot_df['n_notes'].items():
  ax.text(pos + 0.05, n_notes + 50, str(n_notes), color='black', ha='right', va='bottom')

In [None]:
# fig.savefig(args.figdir/'discharge_label_bp.tif', dpi=300)

#### With Admissions

In [None]:
# Per discharge label: number of notes (p1) and number of distinct admissions (p2), joined on the label.
by_label = df.groupby(['discharge_label'])
p1 = by_label.size().reset_index(name='n_notes')
p2 = by_label['hadm_id'].nunique().reset_index()
p = p1.merge(p2, on=['discharge_label'])
# Swap in the display names (positional), then reverse the two rows and renumber them.
p['discharge_label'] = desc
p = p.reindex([1, 0]).reset_index(drop=True)
p

In [None]:
# Long format for a grouped bar plot: one row per (label, measure) pair with its count.
legend_names = {'hadm_id': '# Admissions', 'n_notes': '# Notes'}
plot_df = pd.melt(
  p.rename(columns=legend_names),
  id_vars='discharge_label',
  var_name='Legend',
  value_name='counts',
)

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))

# Grouped bars: one group per discharge class label, one bar per 'Legend' value.
sns.barplot(x='discharge_label', y='counts', hue='Legend', data=plot_df, ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), ha='right')
ax.set_xlabel('5 Day Discharge Class Label')
# NOTE(review): the bars show every 'Legend' series, yet the axis is labelled '# notes' — confirm this is intended.
ax.set_ylabel('# notes')

# Annotate every bar with its count. `plot_df` is in melted (long) form, so the
# row index is split in two halves that share the same group positions.
for index, row in plot_df.iterrows():
    if index < len(plot_df)//2:
        # First half of the rows: text is shifted to the left of the group position.
        ax.text(index-0.13, row['counts']+50, str(row['counts']), color='black', ha='right', va='bottom')
    else:
        # Second half: wrap the index back onto the group positions and shift right.
        ax.text(index % (len(plot_df)//2)+0.25, row['counts']+50, str(row['counts']), color='black', ha='right', va='bottom')

In [None]:
# fig.savefig(args.figdir/'discharge_label_adms_bp.tif', dpi=300)

## Stats

In [None]:
# Reload the processed dataset so the statistics start from the on-disk state.
df = pd.read_csv(args.proc_csv)

Make sure the average prevalence across random test splits is approximately the same as the real prevalence.

In [None]:
def _test_prevalences(frame, label_col, seeds=range(127, 227)):
  """Share of rows with `label_col == 1` inside the 'test' split that `set_two_splits` yields for each seed."""
  shares = []
  for seed in seeds:
    sdf = set_two_splits(frame.copy(), name='test', seed=seed)
    in_test = sdf['split'] == 'test'
    test_size = int(in_test.sum())
    test_pos = int((in_test & (sdf[label_col] == 1)).sum())
    shares.append(test_pos / test_size)
  return np.array(shares)


# Imminent label: rows labelled -1 are excluded before splitting.
tdf = df.loc[df['imminent_label'] != -1, ['imminent_label']].copy()
tdf.reset_index(inplace=True, drop=True)

avg_p = _test_prevalences(tdf, 'imminent_label')
real_imminent = len(df.loc[df['imminent_label'] == 1]) / len(df.loc[df['imminent_label'] != -1])
print(f"Real prevalence of Imminent class: {real_imminent:0.3f}")
print(f"Average = {avg_p.mean():0.3f}, std = {avg_p.std():0.3f}")

# Discharge label: every row takes part in the split.
avg_p = _test_prevalences(df, 'discharge_label')
real_discharge = len(df.loc[df['discharge_label'] == 1]) / len(df)
print(f"Real prevalence of 5 day discharge class: {real_discharge:0.3f}")
print(f"Average = {avg_p.mean():0.3f}, std = {avg_p.std():0.3f}")

In [None]:
# Distinct subjects with (True) and without (False) a recorded `deathtime`.
df.groupby(df['deathtime'].notnull())['subject_id'].nunique().reset_index()

In [None]:
# Distinct subjects per gender.
df.groupby('gender')['subject_id'].nunique().reset_index()

In [None]:
# Distinct subjects per ethnicity group.
df.groupby('ethnicity')['subject_id'].nunique().reset_index()

In [None]:
# Distinct admissions per admission type.
df.groupby('admission_type')['hadm_id'].nunique().reset_index()

In [None]:
# Number of notes (rows) per note category.
df.groupby('category').size().reset_index()

In [None]:
# Distinct subjects per note category among notes with imminent_label == 1.
df.loc[df['imminent_label'] == 1].groupby('category')['subject_id'].nunique().reset_index()

In [None]:
# Distinct subjects per note category among notes with imminent_label == 0.
df.loc[df['imminent_label'] == 0].groupby('category')['subject_id'].nunique().reset_index()

In [None]:
# Distinct subjects per note category among notes with discharge_label == 1.
df.loc[df['discharge_label'] == 1].groupby('category')['subject_id'].nunique().reset_index()

In [None]:
# Distinct subjects per note category among notes with discharge_label == 0.
df.loc[df['discharge_label'] == 0].groupby('category')['subject_id'].nunique().reset_index()

In [None]:
# Age statistics over distinct (subject, admission age) pairs.
# Grouping by `admission_age` and describing the resulting group keys would summarise
# the set of distinct age values, so subjects sharing an age would count only once.
df.drop_duplicates(subset=['subject_id', 'admission_age'])['admission_age'].describe()

In [None]:
# Summary statistics of `icu_los` over all rows of `df`.
# NOTE(review): `df` has one row per note, so these figures are weighted by note count — confirm intended.
df['icu_los'].describe()

In [None]:
# Summary statistics of `adm_icu_period` over all rows of `df`.
# NOTE(review): `df` has one row per note, so these figures are weighted by note count — confirm intended.
df['adm_icu_period'].describe()

In [None]:
# Summary statistics of `chart_icu_period` over all rows of `df` (one row per note).
df['chart_icu_period'].describe()