In [2]:
from pathlib import Path
from rich.console import Console
import polars as pl
import json
import csv

# Shared rich console; later cells call cons.print(...) to render schemas and frames.
cons = Console()

## Load the MDACE dataset

In [3]:
# Root folder of the MDACE data, resolved relative to the current working directory.
data_path = Path("..") / "data" / "MDACE"
assert data_path.exists(), "The specified path does not exist."

In [4]:
# Read the inpatient ICD-10 parquet file, then print its schema, shape and first rows.
inpatient_parquet = data_path / "parquet" / "Inpatient-ICD-10.parquet"
df_inpatient = pl.read_parquet(inpatient_parquet)
for overview in (df_inpatient.schema, df_inpatient.shape, df_inpatient.head()):
    cons.print(overview)

### Validate the dataset (from Table 3 in the MDACE paper)

In [5]:
# Sanity-check the loaded frame against hard-coded reference counts:
# distinct admissions, distinct notes, distinct codes, and total rows.
assert df_inpatient.get_column('hadm_id').n_unique() == 302
assert df_inpatient.get_column('note_id').n_unique() == 604
assert df_inpatient.get_column('code').n_unique() == 1024
assert df_inpatient.height == 3936

### Validate the dataset (from Table 4 in the MDACE paper)

In [6]:
# Count the rows in each category and display them, largest group first.
category_counts = df_inpatient.group_by('category').agg(count=pl.len())
category_counts.sort('count', descending=True)

category,count
str,u32
"""Discharge summary""",3436
"""Physician""",364
"""Radiology""",60
"""General""",28
"""Nutrition""",19
…,…
"""Rehab Services""",8
"""Consult""",4
"""ECG""",2
"""Case Management""",2


In [7]:
# Expected row count for each of the five most frequent categories.
expected_category_rows = {
    "Discharge summary": 3436,  # 2 more than in the paper
    "Physician": 364,
    "Radiology": 60,
    "General": 28,
    "Nutrition": 19,
}
for category_name, expected_rows in expected_category_rows.items():
    assert df_inpatient.filter(pl.col('category') == category_name).height == expected_rows

## Load the Splits

In [8]:
def _read_hadm_ids(path):
    """Read a list of integer ids from ``path``, one id per line.

    Lines that are empty after stripping whitespace (for example a trailing
    blank line at the end of the file) are skipped; previously they made
    ``int('')`` raise ``ValueError``.

    Args:
        path: Location of the text file to read.

    Returns:
        The ids as a list of ``int``, in file order.

    Raises:
        ValueError: If a non-blank line cannot be parsed as an integer.
    """
    with open(path, 'r') as f:
        return [int(stripped) for stripped in map(str.strip, f) if stripped]


# Read the list of ids for every split from its file, one per line.
splits_dir = data_path / "splits/Inpatient"
train_hadm_ids = _read_hadm_ids(splits_dir / "MDace-code-ev-train.csv")
dev_hadm_ids = _read_hadm_ids(splits_dir / "MDace-code-ev-val.csv")
test_hadm_ids = _read_hadm_ids(splits_dir / "MDace-code-ev-test.csv")

# Set copies give constant-time membership tests in get_split; the lists above
# are kept unchanged because other cells pass them to polars' is_in().
_train_hadm_id_set = set(train_hadm_ids)
_dev_hadm_id_set = set(dev_hadm_ids)
_test_hadm_id_set = set(test_hadm_ids)


def get_split(hadm_id):
    """Return the name of the split that ``hadm_id`` belongs to.

    The splits are checked in the order train, dev, test, so an id that is
    listed in more than one file is reported under the first match.

    Args:
        hadm_id: Integer id to look up.

    Returns:
        ``'train'``, ``'dev'`` or ``'test'``; ``'unknown'`` when the id is in
        none of the three split files.
    """
    if hadm_id in _train_hadm_id_set:
        return 'train'
    if hadm_id in _dev_hadm_id_set:
        return 'dev'
    if hadm_id in _test_hadm_id_set:
        return 'test'
    return 'unknown'

In [9]:
# Label every row with the split whose id list contains its hadm_id. The
# branches are tried in the order train, dev, test; rows whose id is in none
# of the lists are labelled 'unknown'.
hadm_id_col = pl.col('hadm_id')
split_label = (
    pl.when(hadm_id_col.is_in(train_hadm_ids)).then(pl.lit('train'))
    .when(hadm_id_col.is_in(dev_hadm_ids)).then(pl.lit('dev'))
    .when(hadm_id_col.is_in(test_hadm_ids)).then(pl.lit('test'))
    .otherwise(pl.lit('unknown'))
)
df_inpatient_splits = df_inpatient.with_columns(split=split_label)
cons.print(df_inpatient_splits.head())

In [10]:
# write_parquet does not create missing parent directories by default, so make
# sure the output folder exists before writing (a fresh checkout may lack it).
splits_out_dir = data_path / "parquet-splits"
splits_out_dir.mkdir(parents=True, exist_ok=True)

# Full frame, including the 'split' column (and any rows labelled 'unknown').
df_inpatient_splits.write_parquet(splits_out_dir / "Inpatient-ICD-10-with-splits.parquet")

# One file per split, each holding only the rows labelled with that split.
for split_name in ('train', 'dev', 'test'):
    df_inpatient_splits.filter(
        pl.col('split') == split_name
    ).write_parquet(splits_out_dir / f"Inpatient-ICD-10-{split_name}.parquet")

## Hugging Face Dataset

In [11]:
from datasets import load_dataset

In [12]:
# Build a three-split dataset from the per-split parquet files using the
# generic "parquet" loader; the mapping keys become the split names.
split_files_dir = data_path / "parquet-splits"
split_files = dict(
    train=str(split_files_dir / "Inpatient-ICD-10-train.parquet"),
    dev=str(split_files_dir / "Inpatient-ICD-10-dev.parquet"),
    test=str(split_files_dir / "Inpatient-ICD-10-test.parquet"),
)
dataset = load_dataset("parquet", data_files=split_files)

Generating train split: 0 examples [00:00, ? examples/s]

Generating dev split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [17]:
# Persist the loaded dataset (all splits) under the MDACE data folder.
hf_dataset_dir = data_path / "mdace-inpatient-icd10"
dataset.save_to_disk(hf_dataset_dir)

Saving the dataset (0/1 shards):   0%|          | 0/2411 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/764 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/761 [00:00<?, ? examples/s]