# Deep Learning on OMOP Data in `EHRData` with `PyPOTS`

This tutorial demonstrates how to apply deep learning to OMOP data using [PyPOTS](https://github.com/WenjieDu/PyPOTS), a Python toolbox for machine learning on partially observed time series.

**Prerequisites:** Complete the [OMOP Introduction tutorial](omop_intro.ipynb) first to understand how to load OMOP data into EHRData.

## Use Case: ICU Mortality Prediction

We'll predict in-hospital mortality for ICU patients using the [MIMIC-IV demo dataset in OMOP format](https://physionet.org/content/mimic-iv-demo-omop/0.9/).

```{note}
This is a demonstration example. Real clinical prediction requires more sophisticated preprocessing, validation, and careful consideration of clinical context.
```

## What is PyPOTS?

PyPOTS provides state-of-the-art neural network models for time series tasks:
- **Imputation** - Fill missing values in incomplete time series
- **Classification** - Predict outcomes from time series
- **Forecasting** - Predict future values
- **Clustering** - Group similar patients

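PyPOTS models take NumPy arrays as input, so the 3D arrays stored in EHRData `.layers` can be passed to them once the axes are reordered to (samples, time steps, features).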




## Setup and Installation


In [None]:
%pip install pypots

In [54]:
# PyPOTS requires this for scipy compatibility
import os

os.environ["SCIPY_ARRAY_API"] = "1"

import ehrdata as ed
import duckdb
import numpy as np
import torch
from pypots.classification import BRITS

## Setup Database and Download Data


In [55]:
# Create database connection
con = duckdb.connect(":memory:")

# Download MIMIC-IV OMOP demo data
ed.dt.mimic_iv_omop(backend_handle=con)

## Define the Cohort

We'll focus on ICU patients by keeping only those rows of `visit_occurrence` whose `visit_detail` records carry one of the following OMOP concept IDs:
- **4305366**: Surgical ICU
- **40481392**: Medical ICU
- **32037**: Intensive Care
- **763903**: Trauma ICU
- **4149943**: Cardiac ICU

We apply two key filters:
1. **Duration**: Only ICU stays >24 hours (to ensure sufficient data for 24-hour analysis)
2. **First visit**: If a patient had multiple ICU stays, we select their **first ICU visit**


In [56]:
# Filter for first ICU visit per patient (>24 hours only)
con.execute("""
    WITH RankedVisits AS (
        SELECT
            v.*,
            vd.*,
            ROW_NUMBER() OVER (PARTITION BY v.person_id ORDER BY v.visit_start_date) AS rn
        FROM visit_occurrence v
        JOIN visit_detail vd USING (visit_occurrence_id)
        WHERE vd.visit_detail_concept_id IN (4305366, 40481392, 32037, 763903, 4149943)
            -- visit_start_date and visit_end_date are DATE columns, so this duration has day granularity
            AND date_diff('hour', v.visit_start_date, v.visit_end_date) > 24
    ),
    first_icu_visit_occurrence_id AS (
        SELECT visit_occurrence_id
        FROM RankedVisits
        WHERE rn = 1
    )
    DELETE FROM visit_occurrence
    WHERE visit_occurrence_id NOT IN (SELECT visit_occurrence_id FROM first_icu_visit_occurrence_id)
""")

# Check how many ICU visits remain
n_visits = con.execute("SELECT COUNT(*) FROM visit_occurrence").fetchone()[0]
print(f"ICU cohort: {n_visits} patients (first ICU visit >24h only)")

ICU cohort: 99 patients (first ICU visit >24h only)


## Build EHRData from OMOP

Now we construct the EHRData object using **ICU visit start** as the time reference (t=0) for each patient:


In [57]:
# Step 1: Setup observations from person + visit_occurrence
edata = ed.io.omop.setup_obs(
    backend_handle=con,
    observation_table="person_visit_occurrence",  # Each row = one ICU visit
)

print(f"Created EHRData with {edata.n_obs} ICU visits")
edata.obs.head()

Created EHRData with 99 ICU visits


Unnamed: 0,person_id,gender_concept_id,year_of_birth,month_of_birth,day_of_birth,birth_datetime,race_concept_id,ethnicity_concept_id,location_id,provider_id,...,visit_type_concept_id,provider_id_1,care_site_id_1,visit_source_value,visit_source_concept_id,admitting_source_concept_id,admitting_source_value,discharge_to_concept_id,discharge_to_source_value,preceding_visit_occurrence_id
0,4239478333578644568,8507,2111,,,NaT,8527,0,,,...,32817,,,10022880|27708593,2000001808,38004207,PHYSICIAN REFERRAL,581476,HOME,
1,-8090189584974691216,8507,2118,,,NaT,8527,0,,,...,32817,,,10009049|22995465,2000001806,8870,EMERGENCY ROOM,581476,HOME,
2,2161418207209636934,8507,2060,,,NaT,2000001401,0,,,...,32817,,,10002495|24982426,2000001809,8717,TRANSFER FROM HOSPITAL,8863,SKILLED NURSING FACILITY,
3,1532249960797525190,8532,2106,,,NaT,2000001405,0,,,...,32817,,,10014078|25809882,2000001806,8870,EMERGENCY ROOM,581476,HOME HEALTH CARE,
4,2288881942133868955,8532,2102,,,NaT,8527,0,,,...,32817,,,10001217|24597018,2000001806,8870,EMERGENCY ROOM,581476,HOME HEALTH CARE,


In [58]:
# Step 2: Extract measurements from the first 24 hours
edata = ed.io.omop.setup_variables(
    edata=edata,
    backend_handle=con,
    layer="measurements",
    data_tables=["measurement"],
    data_field_to_keep={"measurement": "value_as_number"},
    interval_length_number=1,
    interval_length_unit="h",  # Hourly intervals
    num_intervals=24,  # First 24 hours
    aggregation_strategy="last",
    enrich_var_with_feature_info=True,
    instantiate_tensor=True,
)

edata



EHRData object with n_obs × n_vars × n_t = 99 × 450 × 24
    obs: 'person_id', 'gender_concept_id', 'year_of_birth', 'month_of_birth', 'day_of_birth', 'birth_datetime', 'race_concept_id', 'ethnicity_concept_id', 'location_id', 'provider_id', 'care_site_id', 'person_source_value', 'gender_source_value', 'gender_source_concept_id', 'race_source_value', 'race_source_concept_id', 'ethnicity_source_value', 'ethnicity_source_concept_id', 'visit_occurrence_id', 'person_id_1', 'visit_concept_id', 'visit_start_date', 'visit_start_datetime', 'visit_end_date', 'visit_end_datetime', 'visit_type_concept_id', 'provider_id_1', 'care_site_id_1', 'visit_source_value', 'visit_source_concept_id', 'admitting_source_concept_id', 'admitting_source_value', 'discharge_to_concept_id', 'discharge_to_source_value', 'preceding_visit_occurrence_id'
    var: 'data_table_concept_id', 'data_table_concept_id_mapped', 'concept_id', 'concept_name', 'domain_id', 'vocabulary_id', 'concept_class_id', 'standard_concept', 'c
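
To make the effect of `interval_length_number=1`, `num_intervals=24`, and `aggregation_strategy="last"` more concrete, the sketch below bins a handful of made-up measurements into hourly intervals relative to visit start and keeps the last value per interval. It illustrates the idea only and is not ehrdata's implementation; the helper `bin_last` and the example readings are hypothetical, and details such as how interval boundaries are treated are assumptions.

```python
import numpy as np


def bin_last(hours_since_start, values, num_intervals=24, interval_hours=1.0):
    """Return one value per interval: the last measurement in it, NaN if none."""
    hours = np.asarray(hours_since_start, dtype=float)
    vals = np.asarray(values, dtype=float)
    out = np.full(num_intervals, np.nan)
    # Walk through measurements in time order so later ones overwrite earlier ones.
    for i in np.argsort(hours, kind="stable"):
        idx = int(np.floor(hours[i] / interval_hours))
        # Measurements before t=0 or after the last interval are ignored.
        if 0 <= idx < num_intervals:
            out[idx] = vals[i]
    return out


# Hypothetical readings for one patient, in hours after visit start
hours = [0.2, 0.7, 3.5, 30.0]
values = [88.0, 92.0, 79.0, 60.0]

series = bin_last(hours, values)
print(series[:5])
```

Intervals without any measurement stay `NaN`, which is why the resulting tensor is mostly missing when hundreds of rarely measured variables are included.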

## Prepare Data for PyPOTS

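Extract the time series tensor from `edata.layers`. The layer is stored as (patients, variables, time points), while PyPOTS expects (samples, time steps, features), so the last two axes are swapped when the data is passed to the model.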


In [59]:
# Extract time series data: (n_patients, n_variables, n_timepoints)
X = edata.layers["measurements"]

print(f"Time series shape: {X.shape}")
print(f"Missing values: {np.isnan(X).sum()} / {X.size} ({np.isnan(X).mean() * 100:.1f}%)")

Time series shape: (99, 450, 24)
Missing values: 1063929 / 1069200 (99.5%)
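
With 99.5% of entries missing, it can help to look at missingness per variable before modelling. The following is a minimal sketch on a toy array standing in for `edata.layers["measurements"]`; the threshold `min_observed` is a hypothetical choice, not a recommendation from ehrdata or PyPOTS. It also shows the axis swap from (patients, variables, time) to (samples, time steps, features).

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for edata.layers["measurements"]: (patients, variables, time)
X_toy = rng.normal(size=(6, 5, 4))
X_toy[:, 1, :] = np.nan   # variable 1 is never measured
X_toy[:, 3, 1:] = np.nan  # variable 3 is measured only in the first interval

# Fraction of observed (non-NaN) entries per variable, over patients and time
observed = 1.0 - np.isnan(X_toy).mean(axis=(0, 2))

# Hypothetical threshold: keep variables observed in at least half of all entries
min_observed = 0.5
keep = observed >= min_observed
X_kept = X_toy[:, keep, :]

# Swap axes to (samples, time steps, features)
X_pypots = X_kept.transpose(0, 2, 1)
print(observed.round(2), X_pypots.shape)
```

If variables are dropped this way, `n_features` passed to the model has to match the reduced array, and the same `keep` mask must be applied to training and test data.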


## Mortality Prediction with BRITS

Now let's predict in-hospital mortality using BRITS, which handles missing values during classification.

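First, prepare the labels from OMOP's `death` table. A patient is labeled positive if any death record exists, which serves here as a simple proxy for in-hospital mortality: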


In [60]:
# Get death information from the database
death_df = con.execute("SELECT person_id, death_datetime FROM death").df()

# Merge with our cohort to get mortality labels
obs_with_death = edata.obs.merge(death_df, on="person_id", how="left", suffixes=("", "_death"))

# Create binary labels (1 = died, 0 = survived)
y = (~obs_with_death["death_datetime"].isna()).astype(int).values

print(f"Mortality rate: {y.mean() * 100:.1f}%")
print(f"Deaths: {y.sum()} / {len(y)} patients")

Mortality rate: 14.1%
Deaths: 14 / 99 patients
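
The label above marks any recorded death, regardless of when it occurred. A stricter in-hospital definition would require the death time to fall inside the visit window. The sketch below shows that comparison on made-up timestamps with NumPy; in the tutorial's data the corresponding columns would presumably be `visit_start_datetime`, `visit_end_datetime`, and `death_datetime`, but that mapping is an assumption and the values here are invented.

```python
import numpy as np

# Hypothetical visit windows and death times for three patients
visit_start = np.array(
    ["2110-01-01T08:00", "2110-02-01T10:00", "2110-03-01T12:00"], dtype="datetime64[m]"
)
visit_end = np.array(
    ["2110-01-05T08:00", "2110-02-04T10:00", "2110-03-03T12:00"], dtype="datetime64[m]"
)
death_time = np.array(
    ["2110-01-03T09:00", "NaT", "2110-06-01T00:00"], dtype="datetime64[m]"
)

# Any recorded death (the definition used in the tutorial cell)
died_any = ~np.isnat(death_time)

# Death recorded between visit start and visit end (stricter definition)
died_in_visit = died_any & (death_time >= visit_start) & (death_time <= visit_end)

y_any = died_any.astype(int)
y_in_visit = died_in_visit.astype(int)
print(y_any, y_in_visit)
```

The third patient has a death record months after discharge, so the two definitions disagree for that row.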


In [61]:
# Split into train/test (simple split for demonstration)
n_train = int(0.7 * len(X))

X_train, X_test = X[:n_train], X[n_train:]
y_train, y_test = y[:n_train], y[n_train:]

print(f"Training set: {len(X_train)} patients ({y_train.mean() * 100:.1f}% mortality)")
print(f"Test set: {len(X_test)} patients ({y_test.mean() * 100:.1f}% mortality)")

Training set: 69 patients (17.4% mortality)
Test set: 30 patients (6.7% mortality)
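
The sequential split leaves the two sets with quite different mortality rates (17.4% vs. 6.7%). One common remedy is a stratified split, which samples each class separately. The sketch below is one possible NumPy implementation; the helper `stratified_split` is hypothetical, and the toy labels only mirror the class counts reported above (85 survivors, 14 deaths).

```python
import numpy as np


def stratified_split(y, test_fraction=0.3, seed=0):
    """Return sorted train/test index arrays with similar class proportions."""
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for label in np.unique(y):
        # Shuffle the indices of this class, then cut off the test share.
        idx = rng.permutation(np.flatnonzero(y == label))
        n_test = max(1, int(round(test_fraction * len(idx))))
        test_idx.extend(idx[:n_test].tolist())
        train_idx.extend(idx[n_test:].tolist())
    return np.sort(np.array(train_idx)), np.sort(np.array(test_idx))


# Toy labels with the same class counts as the cohort above
y_toy = np.array([0] * 85 + [1] * 14)
train_idx, test_idx = stratified_split(y_toy)

print(f"Train mortality: {y_toy[train_idx].mean() * 100:.1f}%")
print(f"Test mortality: {y_toy[test_idx].mean() * 100:.1f}%")
```

The returned indices can be used to slice both the feature array and the labels, for example `X[train_idx]` and `y[train_idx]`.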


In [62]:
X_train.shape

(69, 450, 24)

In [63]:
y_train

array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
       0, 0, 0])

In [64]:
# Initialize BRITS classifier
torch.manual_seed(42)
brits = BRITS(n_steps=X.shape[2], n_features=X.shape[1], rnn_hidden_size=64, n_classes=2, epochs=20, batch_size=16)

# Train the model
print("Training BRITS...")
brits.fit({"X": X_train.transpose(0, 2, 1), "y": y_train})

# Make predictions
predictions = brits.predict({"X": X_test.transpose(0, 2, 1)})
pred_labels = predictions["classification"]

# Calculate accuracy
accuracy = (pred_labels == y_test).mean()
print(f"\nTest Accuracy: {accuracy * 100:.1f}%")
print(f"Baseline (predict majority class): {max(y_test.mean(), 1 - y_test.mean()) * 100:.1f}%")

2026-01-23 13:44:24 [INFO]: No given device, using default device: cpu
2026-01-23 13:44:24 [INFO]: Using customized CrossEntropy as the training loss function.
2026-01-23 13:44:24 [INFO]: Using customized CrossEntropy as the validation metric function.


2026-01-23 13:44:24 [INFO]: BRITS initialized with the given hyperparameters, the number of trainable parameters: 2,233,780


Training BRITS...


2026-01-23 13:44:25 [INFO]: Epoch 001 - training loss (CrossEntropy): 8.8671
2026-01-23 13:44:25 [INFO]: Epoch 002 - training loss (CrossEntropy): 7.6410
2026-01-23 13:44:25 [INFO]: Epoch 003 - training loss (CrossEntropy): 7.1593
2026-01-23 13:44:26 [INFO]: Epoch 004 - training loss (CrossEntropy): 7.6085
2026-01-23 13:44:26 [INFO]: Epoch 005 - training loss (CrossEntropy): 6.9683
2026-01-23 13:44:27 [INFO]: Epoch 006 - training loss (CrossEntropy): 6.5387
2026-01-23 13:44:28 [INFO]: Epoch 007 - training loss (CrossEntropy): 7.3746
2026-01-23 13:44:28 [INFO]: Epoch 008 - training loss (CrossEntropy): 5.8364
2026-01-23 13:44:29 [INFO]: Epoch 009 - training loss (CrossEntropy): 5.8552
2026-01-23 13:44:30 [INFO]: Epoch 010 - training loss (CrossEntropy): 6.2894
2026-01-23 13:44:30 [INFO]: Epoch 011 - training loss (CrossEntropy): 5.8102
2026-01-23 13:44:31 [INFO]: Epoch 012 - training loss (CrossEntropy): 5.6322
2026-01-23 13:44:32 [INFO]: Epoch 013 - training loss (CrossEntropy): 6.1972


Test Accuracy: 93.3%
Baseline (predict majority class): 93.3%
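
Accuracy alone says little on imbalanced data. As an illustration, the sketch below computes sensitivity, specificity, and balanced accuracy for a hypothetical classifier that predicts "survived" for every one of 30 test patients, 2 of whom died. These predictions are invented for the example and are not the BRITS outputs; the helper `binary_report` is likewise hypothetical.

```python
import numpy as np


def binary_report(y_true, y_pred):
    """Confusion-matrix counts and derived metrics for binary labels (1 = positive)."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tp = int(np.sum(y_true & y_pred))
    tn = int(np.sum(~y_true & ~y_pred))
    fp = int(np.sum(~y_true & y_pred))
    fn = int(np.sum(y_true & ~y_pred))
    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    return {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "accuracy": (tp + tn) / len(y_true),
        "sensitivity": sensitivity,
        "specificity": specificity,
        "balanced_accuracy": (sensitivity + specificity) / 2,
    }


# Hypothetical test set: 28 survivors, 2 deaths, all predicted as survivors
y_true = np.array([0] * 28 + [1] * 2)
y_pred = np.zeros(30, dtype=int)

report = binary_report(y_true, y_pred)
print(report)
```

In this constructed case accuracy is high while sensitivity is zero, which is why a majority-class baseline should always be reported next to accuracy.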


### Understanding the Results

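**What are we predicting?**
- **Target**: Mortality, approximated by the presence of a record in the OMOP `death` table
- **Input**: 24 hours of vital signs and lab measurements, starting at the visit start time
- **Model**: BRITS learns patterns in time-series data while handling missing values
- **Result**: The test accuracy (93.3%) equals the majority-class baseline, so on this split the model is no more accurate than always predicting survival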

**Important caveats for this demo:**

```{warning}
This demonstration uses only **99 ICU visits** from the MIMIC-IV demo dataset. Real clinical prediction models require:
- **Much larger datasets** (thousands of patients)
- **Careful feature engineering** and clinical domain knowledge
- **Proper validation** (cross-validation, external validation)
- **Clinical evaluation** and prospective testing

The model performance shown here is **not clinically meaningful** due to the small sample size and simplified preprocessing. This tutorial demonstrates the *technical workflow*, not a production-ready model.
```


## Key Advantages of This Approach

1. **OMOP Standardization** - Same code works across different hospitals
2. **Cohort Definition** - SQL queries filter for specific patient populations
3. **Time-Aware Analysis** - Visit start times define temporal reference (t=0)
4. **Missing Data Handling** - PyPOTS handles incomplete clinical data natively
5. **Rapid Prototyping** - From database to predictions in minimal code

## Summary

In this tutorial, we learned:

- ✅ How to apply PyPOTS models to OMOP data loaded via ehrdata
- ✅ **Integration**: arrays in `edata.layers` can be passed to PyPOTS after a single axis transpose
- ✅ **Classification**: Using BRITS to predict mortality while handling missing data
- ✅ The overall workflow: **OMOP data** → **EHRData** → **PyPOTS models**

## Key Takeaways

**Why this workflow is powerful:**
1. **OMOP standardization** - Your code works across different hospitals' data
2. **EHRData structure** - Clean 3D tensor format (patients × variables × time)
3. **PyPOTS integration** - Use `edata.layers["measurements"]` as input after transposing it to (samples, time steps, features)
4. **Minimal code** - From database to predictions in ~30 lines

## Next Tutorial

Continue with **[PhysioNet 2012 Machine Learning](physionet2012_ml)** for another example of an ML prototyping workflow.

## Further Resources

- **[PyPOTS Documentation](https://docs.pypots.com/)** - Comprehensive documentation for PyPOTS models and utilities