*Created On: Feb 11, 2026*
## Description
Prototype an `EHRFoundationalModelMIMIC4(BaseTask)` class, focusing first on the pre-processing of clinical discharge notes and radiology notes.

- [PyHealth Task: Length of Stay Stagenet](https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/tasks/length_of_stay_stagenet_mimic4.py)
- [PyHealth Task: Mortality Prediction](https://github.com/sunlabuiuc/PyHealth/blob/master/pyhealth/tasks/mortality_prediction.py)
- [`hadm_id`](https://archive.physionet.org/mimic2/UserGuide/node12.html): Hospital Admission ID

In [None]:
# Change directory to package root
import os

# The checkout location differs per machine: allow overriding it through the
# PYHEALTH_PROJECT_ROOT environment variable, falling back to the original
# hard-coded path so existing setups keep working unchanged.
PROJECT_ROOT = os.environ.get("PYHEALTH_PROJECT_ROOT", '/Users/wpang/Desktop/PyHealth')
os.chdir(PROJECT_ROOT)

# Other General Packages
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

In [None]:
# PyHealth Packages
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import MultimodalMortalityPredictionMIMIC4
from pyhealth.tasks.base_task import BaseTask

# Will's Contribution Utilities
from will_contribution.utils import delete_cache

In [None]:
import pandas as pd
import polars as pl

# Widen polars' display limits so that event tables and long note text are
# not truncated when a frame is printed in the notebook.
pl.Config.set_tbl_rows(1000)  # row limit for printed tables
pl.Config.set_tbl_cols(100)  # column limit for printed tables
pl.Config.set_fmt_str_lengths(1000)  # character limit for printed string values

In [None]:
# Data locations, all resolved against PROJECT_ROOT.
# The three PhysioNet datasets share a common relative prefix.
_PHYSIONET_FILES = "srv/local/data/physionet.org/files"

EHR_ROOT = os.path.join(PROJECT_ROOT, f"{_PHYSIONET_FILES}/mimiciv/2.2")
NOTE_ROOT = os.path.join(PROJECT_ROOT, f"{_PHYSIONET_FILES}/mimic-iv-note/2.2")
CXR_ROOT = os.path.join(PROJECT_ROOT, f"{_PHYSIONET_FILES}/mimic-cxr-jpg/2.0.0")

# Cache directory handed to the dataset constructor.
CACHE_DIR = os.path.join(PROJECT_ROOT, "srv/local/data/wp/pyhealth_cache")

In [None]:
class EHRFoundationalModelMIMIC4(BaseTask):
    """Task that builds one note-based mortality sample per patient.

    For each patient the task walks the admissions in the order returned by
    ``patient.get_events(event_type="admissions")`` and keeps every admission
    that precedes the first admission flagged with ``hospital_expire_flag``.
    The discharge and radiology notes of the kept admissions are aggregated
    into two ``(texts, hours_from_admission)`` tuples, and the sample is
    labelled ``mortality = 1`` when a flagged admission was encountered.

    Attributes:
        task_name: Name under which the task is identified.
        TOKEN_REPRESENTING_MISSING_TEXT: Placeholder text emitted when an
            admission has no notes of a given type, or a note is malformed.
        TOKEN_REPRESENTING_MISSING_FLOAT: Placeholder time offset (NaN) paired
            with the missing-text placeholder.
    """

    task_name: str = "EHRFoundationalModelMIMIC4"
    TOKEN_REPRESENTING_MISSING_TEXT = "<missing>"
    TOKEN_REPRESENTING_MISSING_FLOAT = float("nan")

    def __init__(self):
        """Initialize the EHR Foundational Model task."""
        self.input_schema: Dict[str, Union[str, Tuple[str, Dict]]] = {
            "discharge_note_times": (
                "tuple_time_text",
                {
                    "tokenizer_name": "bert-base-uncased",
                    "type_tag": "note",
                },
            ),
            "radiology_note_times": (
                "tuple_time_text",
                {
                    "tokenizer_name": "bert-base-uncased",
                    "type_tag": "note",
                },
            ),
        }
        self.output_schema: Dict[str, str] = {"mortality": "binary"}

    def _clean_text(self, text: Optional[str]) -> Optional[str]:
        """Return text if non-empty, otherwise None."""
        return text if text else None

    def _append_notes(
        self,
        notes: List[Any],
        admission_time: datetime,
        texts: List[str],
        times: List[float],
    ) -> None:
        """Append the notes of one admission to the running aggregates.

        Args:
            notes: Note events belonging to a single admission.
            admission_time: Timestamp of that admission; note offsets are
                measured from it.
            texts: Output list receiving one text entry per appended note.
            times: Output list receiving the matching offset in hours.

        An admission without notes contributes a single placeholder pair so
        that it still leaves a trace in the aggregated sequence. Notes whose
        text is empty are skipped without a placeholder.
        """
        if not notes:
            texts.append(self.TOKEN_REPRESENTING_MISSING_TEXT)
            times.append(self.TOKEN_REPRESENTING_MISSING_FLOAT)
            return

        for note in notes:
            try:
                note_text = self._clean_text(note.text)
                if note_text:
                    hours_from_admission = (
                        note.timestamp - admission_time
                    ).total_seconds() / 3600.0
                    texts.append(note_text)
                    times.append(hours_from_admission)
            except AttributeError:
                # Malformed note: the event lacks .text or .timestamp.
                texts.append(self.TOKEN_REPRESENTING_MISSING_TEXT)
                times.append(self.TOKEN_REPRESENTING_MISSING_FLOAT)

    def __call__(self, patient: Any) -> List[Dict[str, Any]]:
        """Build the sample for a single patient.

        Args:
            patient: Patient object exposing ``patient_id`` and
                ``get_events(event_type=..., filters=...)``.

        Returns:
            A one-element list holding the sample dict with the keys
            ``patient_id``, ``discharge_note_times``, ``radiology_note_times``
            and ``mortality``. An empty list is returned when the patient has
            no demographics record or no admission before the first
            in-hospital death.
        """
        # Patients without a demographics record are skipped.
        if not patient.get_events(event_type="patients"):
            return []

        admissions = patient.get_events(event_type="admissions")

        # Keep every admission that precedes the first in-hospital death.
        # The fatal admission itself only sets the label; its data is excluded.
        admissions_to_process = []
        mortality_label = 0
        for admission in admissions:
            if admission.hospital_expire_flag in [1, "1"]:
                mortality_label = 1
                break
            admissions_to_process.append(admission)

        if not admissions_to_process:
            return []

        # Notes and time offsets aggregated across all kept admissions.
        all_discharge_texts: List[str] = []
        all_discharge_times_from_admission: List[float] = []
        all_radiology_texts: List[str] = []
        all_radiology_times_from_admission: List[float] = []

        # Process each admission independently (per hadm_id).
        for admission in admissions_to_process:
            admission_time = admission.timestamp

            # Get notes for this hadm_id only
            discharge_notes = patient.get_events(
                event_type="discharge", filters=[("hadm_id", "==", admission.hadm_id)]
            )
            radiology_notes = patient.get_events(
                event_type="radiology", filters=[("hadm_id", "==", admission.hadm_id)]
            )

            self._append_notes(
                discharge_notes,
                admission_time,
                all_discharge_texts,
                all_discharge_times_from_admission,
            )
            self._append_notes(
                radiology_notes,
                admission_time,
                all_radiology_texts,
                all_radiology_times_from_admission,
            )

        discharge_note_times_from_admission = (
            all_discharge_texts,
            all_discharge_times_from_admission,
        )
        radiology_note_times_from_admission = (
            all_radiology_texts,
            all_radiology_times_from_admission,
        )

        return [
            {
                "patient_id": patient.patient_id,
                "discharge_note_times": discharge_note_times_from_admission,
                "radiology_note_times": radiology_note_times_from_admission,
                "mortality": mortality_label,
            }
        ]

### Main Code

In [None]:
# Collect the constructor arguments first so each setting is easy to scan.
dataset_kwargs = {
    # Where the structured EHR tables and the clinical notes live on disk.
    "ehr_root": EHR_ROOT,
    "note_root": NOTE_ROOT,
    # Tables to load from each source.
    "ehr_tables": ["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
    "note_tables": ["discharge", "radiology"],
    "cache_dir": CACHE_DIR,
    "num_workers": 8,
    # NOTE(review): presumably restricts loading to a small development
    # subset - confirm against the MIMIC4Dataset documentation.
    "dev": True,
}
dataset = MIMIC4Dataset(**dataset_kwargs)

# To Display:
# dataset.global_event_df.head(10).collect().to_pandas()

### Quick Notes

Patient `10000032`:
- Num Hospital Admissions: 4 (`len(samples[0]['admissions_to_process'])`)
  - `'hadm_id': '22595853'`: `timestamp=datetime.datetime(2180, 5, 6, 22, 23)`
    - Discharge Note ID: ['10000032-DS-21']
    - Radiology Note ID: ["10000032-RR-14", "10000032-RR-15", "10000032-RR-16"]
  - `'hadm_id': '22841357'`: `timestamp=datetime.datetime(2180, 6, 26, 18, 27)`
    - Discharge Note ID: ['10000032-DS-22']
    - Radiology Note ID: ['10000032-RR-22', '10000032-RR-23']
  - `'hadm_id': '29079034'`: `timestamp=datetime.datetime(2180, 7, 23, 12, 35)`
    - Discharge Note ID: ['10000032-DS-23']
    - Radiology Note ID: ["10000032-RR-44", "10000032-RR-45"]
  - `'hadm_id': '25742920'`: `timestamp=datetime.datetime(2180, 8, 5, 23, 44)`
    - Discharge Note ID: ['10000032-DS-24']
    - Radiology Note ID: ["10000032-RR-46", "10000032-RR-47", "10000032-RR-33"]
- Discharge Notes Count: 4


In [None]:
# Patient and hospital admission to inspect in the cells below.
ID = "10000032"
HADM_ID = "22841357"

# Note type under inspection. To look at discharge summaries instead,
# set TYPE to "discharge" and NOTE to "discharge_notes".
TYPE = "radiology"
NOTE = "radiology_notes"

In [None]:
task = EHRFoundationalModelMIMIC4()

# Single patient
patient = dataset.get_patient(ID)
print("----")
print("Admission IDs (hadm_id)")
for admission in patient.get_events(event_type="admissions"):
    print(f"{admission.attr_dict['hadm_id']} -> Admission Time: {admission.timestamp}")

# Query the notes of the selected admission once; both the count and the
# per-note listing below are derived from this single result.
notes_for_admission = patient.get_events(
    event_type=TYPE, filters=[("hadm_id", "==", HADM_ID)]
)

print("----")
print(f"Count of {TYPE} notes for hadm_id: {HADM_ID}")
print(len(notes_for_admission))
print("----")
print(f"Note ID for {TYPE} notes for hadm_id: {HADM_ID}")
for note in notes_for_admission:
    print(f"{note.attr_dict['note_id']} -> Note Timestamp: {note.timestamp} -> First 100 Characters: {note.text[:100]}")
print("----")

In [None]:
samples = task(patient)

# Pick the sample field that matches the note type selected via TYPE
# ("radiology" -> "radiology_note_times", "discharge" -> "discharge_note_times"),
# so this cell follows the same toggle as the inspection cell.
texts, times = samples[0][f"{TYPE}_note_times"]

# Each text is paired with its offset from the admission time, in hours.
for text, hours_from_admission in zip(texts, times):
    print("Time Diff (Hrs):", hours_from_admission)
    print(text)
    print("-----")