In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import psycopg2
import random
import math
import pandas as pd
from pathlib import Path

In [None]:
# Pull labelled notes (rows with class_label == -1 are excluded) whose text is
# between 100 and 8500 characters long.
query = """
select hadm_id, subject_id, icustay_id, admission_age, wait_period, chartinterval, category, description, text, class_label from notes where class_label != -1 and length(text) between 100 and 8500
"""
# NOTE(review): connection parameters are hard-coded for one machine/user; consider
# moving them to a config cell. Also, pandas only officially supports SQLAlchemy
# connectables (or sqlite3) in read_sql_query and may warn for a raw psycopg2 connection.
con = psycopg2.connect(dbname='mimic', user='sudarshan', host='/var/run/postgresql')
try:
    df = pd.read_sql_query(query, con)
finally:
    # Close even if the query raises, so re-running the cell does not leak connections.
    con.close()
df.head()

In [None]:
def fix_df(df):
    """Collapse the note text fields into a single 'note' column and order columns for export.

    Steps (the input frame is NOT modified; a new frame is returned):
      1. drop exact duplicate rows;
      2. join 'category', 'description' and 'text' with newlines into 'note'
         (missing values are treated as empty strings, so one null field does
         not turn the whole note into NaN);
      3. drop the three source columns;
      4. move 'note' and 'class_label' to the last two positions, keeping every
         other column in its original relative order.

    Args:
        df: frame containing at least 'category', 'description', 'text' and
            'class_label' columns.

    Returns:
        A new DataFrame ending in [..., 'note', 'class_label'].
    """
    deduped = df.drop_duplicates()
    note = deduped['category'].str.cat(
        [deduped['description'], deduped['text']], sep='\n', na_rep=''
    )
    merged = deduped.drop(columns=['category', 'description', 'text']).assign(note=note)
    # Order by name rather than by position so the result does not depend on
    # where 'class_label' happened to sit in the input.
    leading = [col for col in merged.columns if col not in ('note', 'class_label')]
    return merged[leading + ['note', 'class_label']]

def _pct_to_count(n_rows, pct):
    """Return how many of `n_rows` rows make up the fraction `pct`, rounded up.

    The product is rounded to 9 decimal places before `math.ceil`, so float
    noise such as 30 * 0.1 == 3.0000000000000004 does not add an extra row.
    """
    return math.ceil(round(n_rows * pct, 9))


def set_splits(df, val_pct, test_pct=None, seed=None):
    """Randomly assign every row of `df` to 'train', 'val' or 'test'.

    Writes a 'split' column into `df` in place and returns the same frame.
    Rows are chosen by *position* after a shuffle, so the frame's index may be
    arbitrary (e.g. the non-contiguous labels left by boolean filtering).

    Args:
        df: frame to tag; modified in place.
        val_pct: fraction of rows (rounded up) tagged 'val'.
        test_pct: fraction of rows (rounded up) tagged 'test', taken from the
            shuffled order right after the validation rows. Falsy (None/0)
            means no test split. If val_pct + test_pct exceeds 1 the test
            slice is simply truncated.
        seed: optional seed for a private RNG. When None (default) the shuffle
            uses the global `random` module state.

    Returns:
        `df`, with every row not selected for val/test left as 'train'.
    """
    df['split'] = 'train'
    n_rows = len(df)
    order = list(range(n_rows))
    if seed is None:
        random.shuffle(order)
    else:
        random.Random(seed).shuffle(order)

    # Positional assignment: `order` holds row positions, not index labels.
    split_col = df.columns.get_loc('split')
    n_val = _pct_to_count(n_rows, val_pct)
    df.iloc[order[:n_val], split_col] = 'val'

    if test_pct:
        n_test = _pct_to_count(n_rows, test_pct)
        df.iloc[order[n_val:n_val + n_test], split_col] = 'test'

    return df

In [None]:
# Merge the note text fields and reorder columns. Not idempotent: re-running this
# cell without reloading fails because the source columns have already been dropped.
df = df.pipe(fix_df)

In [None]:
# One frame per class, ordered by label value. Built from the labels actually
# present rather than assuming they are exactly 0..k-1 (a gap would otherwise
# produce an empty frame and silently drop the highest label).
classes = [group.copy() for _, group in df.groupby('class_label', sort=True)]

In [None]:
# Seed the global RNG that set_splits shuffles with, so the split is reproducible.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)
# Splitting each class separately keeps the class ratio the same in train/val/test.
for idx in range(len(classes)):
    classes[idx] = set_splits(classes[idx], 0.1, 0.1)

In [None]:
# Stitch the per-class frames (each now carrying a 'split' column) back into one frame.
df = pd.concat(classes, axis='index')

In [None]:
# Sanity check: for each class frame, the number of 'train' rows and their share
# of that class (should be roughly 1 - val_pct - test_pct as passed to set_splits).
for idx, class_df in enumerate(classes):
    n_train = int((class_df['split'] == 'train').sum())
    share = n_train / len(class_df) if len(class_df) else float('nan')
    print(idx, n_train, share)

In [None]:
path = Path('./data')
# to_csv does not create missing directories; make sure the output folder exists.
path.mkdir(parents=True, exist_ok=True)
# index=False: the index only holds leftover row labels from filtering.
df.to_csv(path / 'data.csv', index=False)

In [None]:
# Round-trip check: reload the exported CSV. Column dtypes come from read_csv's
# own inference, so they may differ from the SQL-sourced frame.
df = pd.read_csv(path.joinpath('data.csv'))
df.head()