# Machine Learning on OMOP Data with PyPOTS

This tutorial demonstrates how to apply machine learning to OMOP data using [PyPOTS](https://github.com/WenjieDu/PyPOTS), a Python toolkit for machine learning on partially observed time series.

**Prerequisites:** Complete the [OMOP Introduction tutorial](omop_tables_tutorial.ipynb) first to understand how to load OMOP data into EHRData.

## Use Case: ICU Mortality Prediction

We'll predict in-hospital mortality for ICU patients using the [MIMIC-IV demo dataset in OMOP format](https://physionet.org/content/mimic-iv-demo-omop/0.9/).

```{note}
This is a demonstration example. Real clinical prediction requires more sophisticated preprocessing, validation, and careful consideration of clinical context.
```


## Setup and Installation


In [None]:
%pip install pypots ehrdata duckdb

In [None]:
# PyPOTS requires this for scipy compatibility
import os

os.environ["SCIPY_ARRAY_API"] = "1"

import ehrdata as ed
import duckdb
import numpy as np
import torch
from pypots.classification import BRITS
from pypots.imputation import SAITS

## Setup Database and Download Data


In [None]:
# Create database connection
con = duckdb.connect(":memory:")

# Download MIMIC-IV OMOP demo data
ed.dt.mimic_iv_omop(backend_handle=con)

## Define the Cohort

We'll focus on ICU patients by filtering `visit_occurrence` for ICU stays using OMOP concept IDs:
- **4305366**: Surgical ICU
- **40481392**: Medical ICU
- **32037**: Intensive Care
- **763903**: Trauma ICU
- **4149943**: Cardiac ICU

If a patient had multiple ICU stays, we select their **first ICU visit**:


In [None]:
# Filter for first ICU visit per patient
con.execute("""
    WITH RankedVisits AS (
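        SELECT
            -- Only the columns the cohort filter needs; ties on the date are broken by visit ID
            v.person_id,
            v.visit_occurrence_id,
            ROW_NUMBER() OVER (
                PARTITION BY v.person_id
                ORDER BY v.visit_start_date, v.visit_occurrence_id
            ) AS rn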
        FROM visit_occurrence v
        JOIN visit_detail vd USING (visit_occurrence_id)
        WHERE vd.visit_detail_concept_id IN (4305366, 40481392, 32037, 763903, 4149943)
    ),
    first_icu_visit_occurrence_id AS (
        SELECT visit_occurrence_id
        FROM RankedVisits
        WHERE rn = 1
    )
    DELETE FROM visit_occurrence
    WHERE visit_occurrence_id NOT IN (SELECT visit_occurrence_id FROM first_icu_visit_occurrence_id)
""")

# Check how many ICU visits remain
n_visits = con.execute("SELECT COUNT(*) FROM visit_occurrence").fetchone()[0]
print(f"ICU cohort: {n_visits} patients (first ICU visit only)")
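
The ranking logic can be checked on a toy table before running it against the real database. The sketch below is an illustration only: it uses Python's built-in `sqlite3` instead of DuckDB, and the `visits` table with its three columns is hypothetical. Only the `ROW_NUMBER() OVER (PARTITION BY ...)` pattern carries over to the query above.

```python
import sqlite3

# Hypothetical toy data: (visit_id, person_id, start_date); not taken from MIMIC-IV
rows = [
    (10, 1, "2020-03-05"),
    (11, 1, "2020-01-02"),
    (12, 2, "2021-07-09"),
    (13, 3, "2019-05-01"),
    (14, 3, "2019-04-30"),
]

toy = sqlite3.connect(":memory:")
toy.execute("CREATE TABLE visits (visit_id INTEGER, person_id INTEGER, start_date TEXT)")
toy.executemany("INSERT INTO visits VALUES (?, ?, ?)", rows)

# Rank each person's visits by start date and keep only the earliest one
first_visits = toy.execute("""
    WITH ranked AS (
        SELECT
            visit_id,
            person_id,
            ROW_NUMBER() OVER (PARTITION BY person_id ORDER BY start_date, visit_id) AS rn
        FROM visits
    )
    SELECT person_id, visit_id FROM ranked WHERE rn = 1 ORDER BY person_id
""").fetchall()

print(first_visits)
```

Each person appears exactly once in the result, paired with the ID of their earliest visit.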

## Build EHRData from OMOP

Now we construct the EHRData object using **ICU visit start** as the time reference (t=0) for each patient:


In [None]:
# Step 1: Setup observations from person + visit_occurrence
edata = ed.io.omop.setup_obs(
    backend_handle=con,
    observation_table="person_visit_occurrence",  # Each row = one ICU visit
)

print(f"Created EHRData with {edata.n_obs} ICU visits")
edata.obs.head()

In [None]:
# Step 2: Extract measurements from the first 24 hours
edata = ed.io.omop.setup_variables(
    edata=edata,
    backend_handle=con,
    layer="measurements",
    data_tables=["measurement"],
    data_field_to_keep={"measurement": "value_as_number"},
    interval_length_number=1,
    interval_length_unit="h",  # Hourly intervals
    num_intervals=24,  # First 24 hours
    aggregation_strategy="last",
    enrich_var_with_feature_info=True,
    instantiate_tensor=True,
)

print(f"\nFinal shape: {edata.n_obs} ICU visits × {edata.n_vars} variables × {edata.n_tem} hours")
edata
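
To make the interval parameters concrete, here is a minimal sketch of what hourly bins, a 24-interval window, and a "last" aggregation mean for a single variable of a single patient. It is not ehrdata's implementation; the helper `bin_last` and the example timestamps are hypothetical.

```python
import math
from datetime import datetime, timedelta


def bin_last(t0, events, num_intervals=24, interval=timedelta(hours=1)):
    """Place (timestamp, value) events into fixed-width bins after t0, keeping the last value per bin."""
    bins = [math.nan] * num_intervals
    for ts, value in sorted(events):  # chronological order, so later values overwrite earlier ones
        idx = (ts - t0) // interval  # integer bin index relative to t0
        if 0 <= idx < num_intervals:
            bins[idx] = value
    return bins


t0 = datetime(2020, 1, 1, 8, 0)  # hypothetical visit start (t=0)
heart_rate = [
    (datetime(2020, 1, 1, 7, 30), 80.0),   # before t0: dropped
    (datetime(2020, 1, 1, 8, 10), 92.0),
    (datetime(2020, 1, 1, 8, 50), 88.0),   # same hour as the previous value: this one is kept
    (datetime(2020, 1, 1, 10, 5), 101.0),
    (datetime(2020, 1, 2, 9, 0), 75.0),    # after the 24-hour window: dropped
]

series = bin_last(t0, heart_rate)
print(series[:4])
```

Hours without any measurement stay `NaN`, which is why the resulting tensor is sparse and why the models used later need to cope with missing values.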

## Prepare Data for PyPOTS

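Extract the time series tensor from EHRData's `.layers`. It is stored as (n_patients, n_variables, n_timepoints), whereas PyPOTS models take (n_samples, n_steps, n_features), so the last two axes have to be swapped before the array is passed to a model.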


In [None]:
# Extract time series data: (n_patients, n_variables, n_timepoints)
X = edata.layers["measurements"]

print(f"Time series shape: {X.shape}")
print(f"Missing values: {np.isnan(X).sum()} / {X.size} ({np.isnan(X).mean() * 100:.1f}%)")
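
The overall missing rate hides how unevenly gaps are spread across variables. The sketch below, on a small synthetic array standing in for `edata.layers["measurements"]`, computes the missing fraction per variable, drops variables that were never measured, and reorders the axes to (n_samples, n_steps, n_features). The names `X_demo`, `X_kept`, and `X_steps_first` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the measurements layer: 5 patients, 3 variables, 4 time steps
X_demo = rng.normal(size=(5, 3, 4))
X_demo[:, 1, :] = np.nan    # variable 1 was never measured
X_demo[0, 0, 2:] = np.nan   # a few scattered gaps in variable 0

# Fraction of missing entries per variable, averaged over patients and time
missing_per_var = np.isnan(X_demo).mean(axis=(0, 2))

# Keep only variables with at least one observed value
keep = missing_per_var < 1.0
X_kept = X_demo[:, keep, :]

# Swap the last two axes: (patients, variables, time) -> (samples, steps, features)
X_steps_first = np.transpose(X_kept, (0, 2, 1))

print(missing_per_var, X_steps_first.shape)
```

Whether to drop sparse variables, and at what threshold, is a modeling decision; the cut-off of 100% missing used here is only the most conservative choice.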

## Task 1: Imputation with SAITS

Clinical ICU data has many missing values. Let's use SAITS (Self-Attention-based Imputation for Time Series) to fill them:


In [None]:
# Initialize SAITS imputer
torch.manual_seed(42)
saits = SAITS(
    n_steps=X.shape[2],
    n_features=X.shape[1],
    n_layers=2,
    d_model=128,
    n_heads=4,
    d_k=32,
    d_v=32,
    d_ffn=64,
    dropout=0.1,
    epochs=10,
    batch_size=16,
)

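# PyPOTS expects (n_samples, n_steps, n_features), so swap the variable and time axes
X_pypots = np.transpose(X, (0, 2, 1))

# Fit and impute; predict() returns a dict whose "imputation" entry holds the filled array
saits.fit({"X": X_pypots})
imputed_result = saits.predict({"X": X_pypots})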

print(f"\nImputed data shape: {imputed_result['imputation'].shape}")
print(f"Remaining missing: {np.isnan(imputed_result['imputation']).sum()}")
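
A count of zero remaining NaNs confirms that the array was filled, not that the filled values are accurate. One common check is to hide a share of the observed values, impute them, and compare against the hidden truth. The sketch below shows that procedure on synthetic data, with a per-feature mean as a stand-in imputer; all names and the 20% hold-out rate are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical data in (n_samples, n_steps, n_features) layout with about 30% natural gaps
X_true = rng.normal(loc=5.0, scale=1.0, size=(20, 6, 3))
X_obs = X_true.copy()
X_obs[rng.random(X_obs.shape) < 0.3] = np.nan

# Hide a further 20% of the observed entries; these are the evaluation targets
observed = ~np.isnan(X_obs)
holdout = observed & (rng.random(X_obs.shape) < 0.2)
X_input = X_obs.copy()
X_input[holdout] = np.nan

# Placeholder imputer: per-feature mean of the values that are still visible
feature_means = np.nanmean(X_input, axis=(0, 1))
X_filled = np.where(np.isnan(X_input), feature_means, X_input)

# Mean absolute error on the held-out entries only
mae = np.abs(X_filled[holdout] - X_obs[holdout]).mean()
print(f"Held-out entries: {holdout.sum()}, MAE: {mae:.3f}")
```

To evaluate a trained imputation model the same way, the mean step would be replaced by fitting and imputing on `X_input`, and the resulting error compared against this simple baseline.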

## Task 2: Mortality Prediction with BRITS

Now let's predict in-hospital mortality using BRITS, which handles missing values during classification.

First, prepare the labels from OMOP's `death` table:


In [None]:
# Get the IDs of all patients with a death record
death_df = con.execute("SELECT DISTINCT person_id FROM death").df()

# Create binary labels aligned row by row with edata.obs (1 = death recorded, 0 = no death record).
# Note: this marks any recorded death, which only approximates in-hospital mortality.
y = edata.obs["person_id"].isin(death_df["person_id"]).astype(int).values

print(f"Mortality rate: {y.mean() * 100:.1f}%")
print(f"Deaths: {y.sum()} / {len(y)} patients")

In [None]:
# Split into train/test (simple split for demonstration)
n_train = int(0.7 * len(X))

X_train, X_test = X[:n_train], X[n_train:]
y_train, y_test = y[:n_train], y[n_train:]

print(f"Training set: {len(X_train)} patients ({y_train.mean() * 100:.1f}% mortality)")
print(f"Test set: {len(X_test)} patients ({y_test.mean() * 100:.1f}% mortality)")
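
The positional split above takes the first 70% of rows in their stored order, so the class balance of each part depends on how the rows happen to be sorted. With few positive cases, a shuffled split that preserves the class ratio is often preferable. The sketch below is one way to do this with NumPy alone; `stratified_split` is a hypothetical helper, and scikit-learn's `train_test_split` with its `stratify` argument serves the same purpose.

```python
import numpy as np


def stratified_split(y, train_frac=0.7, seed=0):
    """Return sorted train/test index arrays that keep class proportions similar."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for label in np.unique(y):
        # Shuffle the positions of this class, then cut them at the requested fraction
        idx = rng.permutation(np.flatnonzero(y == label))
        n_train = int(round(train_frac * len(idx)))
        train_idx.extend(idx[:n_train])
        test_idx.extend(idx[n_train:])
    return np.sort(train_idx), np.sort(test_idx)


# Hypothetical labels: 80 survivors, 20 deaths
y_demo = np.array([0] * 80 + [1] * 20)
train_idx, test_idx = stratified_split(y_demo)

print(len(train_idx), len(test_idx), y_demo[train_idx].mean(), y_demo[test_idx].mean())
```

The returned index arrays can be applied to both the feature tensor and the label vector, for example `X[train_idx]` and `y[train_idx]`.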

In [None]:
# Initialize BRITS classifier
torch.manual_seed(42)
brits = BRITS(n_steps=X.shape[2], n_features=X.shape[1], rnn_hidden_size=64, n_classes=2, epochs=20, batch_size=16)

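# Train the model (PyPOTS expects (n_samples, n_steps, n_features), so swap the last two axes)
print("Training BRITS...")
brits.fit({"X": np.transpose(X_train, (0, 2, 1)), "y": y_train})

# Make predictions
predictions = brits.predict({"X": np.transpose(X_test, (0, 2, 1))})
pred = np.asarray(predictions["classification"])
# Some PyPOTS releases return class probabilities of shape (n_samples, n_classes); reduce those to labels
pred_labels = pred.argmax(axis=1) if pred.ndim == 2 else pred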

# Calculate accuracy
accuracy = (pred_labels == y_test).mean()
print(f"\nTest Accuracy: {accuracy * 100:.1f}%")
print(f"Baseline (predict majority class): {max(y_test.mean(), 1 - y_test.mean()) * 100:.1f}%")
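
Accuracy alone can look good on imbalanced outcomes such as mortality, because a model that always predicts survival already matches the majority baseline. Sensitivity and specificity separate the two kinds of error. The sketch below computes them from 0/1 labels in plain Python; `binary_metrics` and the example labels are hypothetical.

```python
def binary_metrics(y_true, y_pred):
    """Confusion-matrix counts plus accuracy, sensitivity, and specificity for 0/1 labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "accuracy": (tp + tn) / len(pairs),
        # Share of true deaths that were predicted as deaths
        "sensitivity": tp / (tp + fn) if (tp + fn) else float("nan"),
        # Share of true survivors that were predicted as survivors
        "specificity": tn / (tn + fp) if (tn + fp) else float("nan"),
    }


# Hypothetical case: 9 survivors, 1 death, and a model that always predicts survival
y_true_demo = [0] * 9 + [1]
y_pred_demo = [0] * 10

metrics = binary_metrics(y_true_demo, y_pred_demo)
print(metrics)
```

In this example the accuracy is 90% although no death is detected, which is the situation the majority-class baseline is meant to expose.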

## Summary

In this tutorial, we learned:

- ✅ How to define clinical cohorts using SQL queries on OMOP tables
- ✅ How to use `person_visit_occurrence` for visit-level analysis
- ✅ How to extract ICU time series with hourly resolution
- ✅ How to apply PyPOTS models to OMOP data:
  - **SAITS** for intelligent missing value imputation
  - **BRITS** for mortality prediction with built-in imputation
- ✅ The complete workflow: **OMOP cohort** → **EHRData** → **PyPOTS ML**

## Key Advantages of This Approach

1. **OMOP Standardization** - Same code works across different hospitals
2. **Cohort Definition** - SQL queries filter for specific patient populations
3. **Time-Aware Analysis** - Visit start times define temporal reference (t=0)
4. **Missing Data Handling** - PyPOTS handles incomplete clinical data natively
5. **Rapid Prototyping** - From database to predictions in minimal code

## Next Steps

- Try different cohorts: chronic disease patients, emergency admissions, etc.
- Experiment with different time windows and aggregation strategies
- Use other PyPOTS models: `CRLI` (clustering), `Transformer` (forecasting)
- Combine multiple OMOP tables: measurements + drugs + procedures
- See the [PyPOTS tutorial](tutorial_time_series_with_pypots.ipynb) for more advanced features
- Explore the [OHDSI Book](https://ohdsi.github.io/TheBookOfOhdsi/) for OMOP best practices


# Machine Learning on OMOP Data with PyPOTS

This tutorial demonstrates how to apply machine learning to OMOP data using [PyPOTS](https://github.com/WenjieDu/PyPOTS), a Python toolkit for machine learning on partially observed time series.

**Prerequisites:** Complete the [OMOP Introduction tutorial](omop_tables_tutorial.ipynb) first to understand how to load OMOP data into EHRData.

## What is PyPOTS?

PyPOTS provides state-of-the-art neural network models for time series tasks:
- **Imputation** - Fill missing values in incomplete time series
- **Classification** - Predict outcomes from time series
- **Forecasting** - Predict future values
- **Clustering** - Group similar patients

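PyPOTS models take NumPy arrays as input, so the tensors stored in an EHRData object can be passed to them once the axes are in the order PyPOTS expects.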


## Setup and Installation


In [None]:
%pip install pypots ehrdata duckdb

In [None]:
# PyPOTS requires this for scipy compatibility
import os

os.environ["SCIPY_ARRAY_API"] = "1"

import ehrdata as ed
import duckdb
import numpy as np
import torch
from pypots.classification import BRITS
from pypots.imputation import SAITS

## Load OMOP Data

We'll use the same MIMIC-IV OMOP demo dataset from the introduction tutorial:


In [None]:
# Setup database and download data
con = duckdb.connect(":memory:")
data_path = ed.dt.mimic_iv_omop(backend_handle=con)

# Setup observations
edata = ed.io.omop.setup_obs(backend_handle=con, observation_table="person", death_table=True)

print(f"Loaded {edata.n_obs} patients")

In [None]:
# Setup variables - extract time series measurements
edata = ed.io.omop.setup_variables(
    edata=edata,
    backend_handle=con,
    layer="measurements",
    data_tables=["measurement"],
    data_field_to_keep={"measurement": "value_as_number"},
    interval_length_number=1,
    interval_length_unit="day",
    num_intervals=7,  # First week of data
    aggregation_strategy="mean",
    enrich_var_with_feature_info=True,
)

print(f"Final shape: {edata.n_obs} patients × {edata.n_vars} variables × {edata.n_tem} days")
edata

## Prepare Data for PyPOTS

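PyPOTS expects arrays shaped (n_samples, n_steps, n_features). EHRData's `.layers` stores (n_patients, n_variables, n_timepoints), so the last two axes have to be swapped before the array is passed to a model.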


In [None]:
# Extract the time series data
# Shape: (n_patients, n_variables, n_timepoints)
X = edata.layers["measurements"]

print(f"Time series shape: {X.shape}")
print(f"Missing values: {np.isnan(X).sum()} / {X.size} ({np.isnan(X).mean() * 100:.1f}%)")

## Task 1: Imputation with SAITS

Clinical data has many missing values. SAITS (Self-Attention-based Imputation for Time Series) is a PyPOTS model that can fill them:


## Task 2: Mortality Prediction with BRITS

Now let's predict in-hospital mortality using BRITS (Bidirectional Recurrent Imputation for Time Series), which handles missing values during classification:


In [None]:
# Prepare labels from death information in .obs
y = (~edata.obs["death_date"].isna()).astype(int).values

print(f"Mortality rate: {y.mean() * 100:.1f}%")
print(f"Deaths: {y.sum()} / {len(y)} patients")

In [None]:
# Split data (simple split for demonstration)
n_train = int(0.7 * len(X))

X_train, X_test = X[:n_train], X[n_train:]
y_train, y_test = y[:n_train], y[n_train:]

print(f"Training set: {len(X_train)} patients")
print(f"Test set: {len(X_test)} patients")

In [None]:
# Initialize BRITS classifier
brits = BRITS(n_steps=X.shape[2], n_features=X.shape[1], rnn_hidden_size=128, n_classes=2, epochs=10, batch_size=32)

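# Train the model (PyPOTS expects (n_samples, n_steps, n_features), so swap the last two axes)
brits.fit({"X": np.transpose(X_train, (0, 2, 1)), "y": y_train})

# Make predictions
predictions = brits.predict({"X": np.transpose(X_test, (0, 2, 1))})
pred = np.asarray(predictions["classification"])
# Some PyPOTS releases return class probabilities of shape (n_samples, n_classes); reduce those to labels
pred_labels = pred.argmax(axis=1) if pred.ndim == 2 else pred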

# Calculate accuracy
accuracy = (pred_labels == y_test).mean()
print(f"Test Accuracy: {accuracy * 100:.1f}%")

## Summary

In this tutorial, we learned:

- ✅ How to apply PyPOTS models to OMOP data loaded via ehrdata
- ✅ **Imputation**: Using SAITS to fill missing values in clinical time series
- ✅ **Classification**: Using BRITS to predict mortality while handling missing data
- ✅ The seamless workflow: **OMOP data** → **EHRData** → **PyPOTS models**

## Key Takeaways

**Why this workflow is powerful:**
1. **OMOP standardization** - Your code works across different hospitals' data
2. **EHRData structure** - Clean 3D tensor format (patients × variables × time)
3. **PyPOTS integration** - State-of-the-art models work directly with EHRData
4. **Minimal code** - From database to predictions in ~20 lines

## Next Steps

- Try other PyPOTS models: `CRLI` for clustering, `Transformer` for forecasting
- Use different OMOP tables: `drug_exposure`, `procedure_occurrence`
- Apply to your own OMOP database following the same workflow
- See the [PyPOTS documentation](https://docs.pypots.com/) for more models
- Explore the [full PyPOTS tutorial](tutorial_time_series_with_pypots.ipynb) for advanced features
