In [1]:
import os
import sys

# Resolve the parent of the current working directory and, if it is not yet on
# sys.path, put it at the front so modules located there take import priority
# (per the message below, this is expected to be the PyHealth checkout).
pyhealth_path = os.path.split(os.getcwd())[0]
if sys.path.count(pyhealth_path) == 0:
    sys.path.insert(0, pyhealth_path)
    print(f"Adding PyHealth to sys.path: {pyhealth_path}")

Adding PyHealth to sys.path: /home/johnwu3/projects/PyHealth_Branch_Testing/PyHealth


In [2]:
# Display the computed parent-directory path as this cell's output (sanity check).
pyhealth_path

'/home/johnwu3/projects/PyHealth_Branch_Testing/PyHealth'

In [3]:
from pyhealth.tasks import BaseTask
from typing import Any, Dict, List, Optional
from datetime import datetime

class MortalityPredictionMIMIC3Heterogeneous(BaseTask):
    """Next-visit mortality prediction on MIMIC-III using coded clinical events.

    For every admission of a patient except the last one, a sample is built from
    the diagnosis codes, procedure codes and prescribed drugs that
    ``patient.get_events`` returns for the window starting at the admission's
    timestamp and ending at its discharge time. The label is the
    ``hospital_expire_flag`` of the *following* admission.

    Attributes:
        task_name: Identifier of the task.
        input_schema: Maps each input feature name to its processor type.
        output_schema: Maps the label name to its processor type.
    """

    # NOTE(review): this name is identical to the stock MortalityPredictionMIMIC3
    # task. If sample caches are keyed by task_name, both tasks would share one
    # cache file -- confirm and consider a distinct name / cache_dir.
    task_name: str = "MortalityPredictionMIMIC3"
    input_schema: Dict[str, str] = {
        "conditions": "sequence",
        "procedures": "sequence",
        "drugs": "sequence",
    }
    output_schema: Dict[str, str] = {"mortality": "binary"}

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Build mortality-prediction samples for a single patient.

        Args:
            patient: Object exposing ``patient_id`` and
                ``get_events(event_type=..., start=..., end=...)``.

        Returns:
            One dict per usable admission with the keys ``hadm_id``,
            ``patient_id``, ``conditions``, ``procedures``, ``drugs`` and
            ``mortality``. Empty when the patient has at most one admission.
        """
        samples = []

        visits = patient.get_events(event_type="admissions")

        # The last visit has no successor to take a label from, so a patient
        # needs at least two admissions to yield any sample.
        if len(visits) <= 1:
            return []

        for i in range(len(visits) - 1):
            visit = visits[i]
            next_visit = visits[i + 1]

            # Label comes from the NEXT admission; anything that is not a clean
            # 0/1 flag (int or string) is treated as "survived".
            if next_visit.hospital_expire_flag not in [0, 1, "0", "1"]:
                mortality_label = 0
            else:
                mortality_label = int(next_visit.hospital_expire_flag)

            # Discharge time may arrive as a string or already as a datetime.
            try:
                raw_dischtime = visit.dischtime
                if isinstance(raw_dischtime, str):
                    discharge_time = datetime.strptime(
                        raw_dischtime, "%Y-%m-%d %H:%M:%S"
                    )
                else:
                    discharge_time = raw_dischtime
            except (ValueError, AttributeError):
                # Skip the visit. getattr with a default keeps the handler itself
                # from raising when the attribute is missing altogether.
                print(
                    "Error parsing discharge time:",
                    getattr(visit, "dischtime", None),
                )
                continue

            # Clinical events falling inside this admission's time window.
            diagnoses = patient.get_events(
                event_type="diagnoses_icd",
                start=visit.timestamp,
                end=discharge_time,
            )
            procedures = patient.get_events(
                event_type="procedures_icd",
                start=visit.timestamp,
                end=discharge_time,
            )
            prescriptions = patient.get_events(
                event_type="prescriptions",
                start=visit.timestamp,
                end=discharge_time,
            )

            conditions = [event.icd9_code for event in diagnoses]
            procedures_list = [event.icd9_code for event in procedures]
            drugs = [event.drug for event in prescriptions]

            # Exclude visits without condition, procedure, or drug code
            if not (conditions and procedures_list and drugs):
                continue

            samples.append(
                {
                    "hadm_id": visit.hadm_id,
                    "patient_id": patient.patient_id,
                    "conditions": conditions,
                    "procedures": procedures_list,
                    "drugs": drugs,
                    "mortality": mortality_label,
                }
            )

        return samples



In [4]:
from pyhealth.datasets import MIMIC3Dataset
# Location of the synthetic MIMIC-III files and the event tables requested
# from the loader.
mimic3_root = "https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III"
mimic3_tables = ["diagnoses_icd", "procedures_icd", "prescriptions", "noteevents"]
dataset = MIMIC3Dataset(root=mimic3_root, tables=mimic3_tables)

No config path provided, using default config
Initializing mimic3 dataset from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III (dev mode: False)
Scanning table: patients from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/PATIENTS.csv




Some column names were converted to lowercase
Scanning table: admissions from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv
Some column names were converted to lowercase
Scanning table: icustays from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ICUSTAYS.csv
Some column names were converted to lowercase
Scanning table: diagnoses_icd from https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/DIAGNOSES_ICD.csv.gz
Original path does not exist. Using alternative: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/DIAGNOSES_ICD.csv
Some column names were converted to lowercase
Joining with table: https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/ADMISSIONS.csv.gz
Original path

In [8]:
from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC3
from pyhealth.datasets import split_by_patient, get_dataloader
# Instantiate the custom task class defined in this notebook (not the
# library's MortalityPredictionMIMIC3) and apply it to the dataset with a
# single worker, caching under "cache/".
mimic3_mortality_prediction = MortalityPredictionMIMIC3Heterogeneous()
samples = dataset.set_task(
    mimic3_mortality_prediction,
    num_workers=1,
    cache_dir="cache/",
)

Setting task MortalityPredictionMIMIC3 for mimic3 base dataset...
Loading cached samples from cache/MortalityPredictionMIMIC3.parquet
Loaded 2776 cached samples
Label mortality vocab: {0: 0, 1: 1}


Processing samples: 100%|██████████| 2776/2776 [00:00<00:00, 63155.03it/s]

Generated 2776 samples for task MortalityPredictionMIMIC3





In [9]:
from pyhealth.datasets import split_by_sample


# 70/10/20 train/val/test split over individual samples. The fixed seed keeps
# the split identical across kernel restarts so results are reproducible.
# NOTE(review): one patient can contribute several samples, so a per-sample
# split may place the same patient in more than one subset; consider
# split_by_patient if patient-level separation is required -- confirm intent.
train_dataset, val_dataset, test_dataset = split_by_sample(
    dataset=samples,
    ratios=[0.7, 0.1, 0.2],
    seed=42,
)