# Longitudinal EHR Data Analysis with ehrapy and ehrdata

This notebook demonstrates how to analyze longitudinal electronic health records (EHR) data using `ehrapy` and `ehrdata`. We will use the Physionet2012 Challenge dataset to showcase:

- Data loading and description
- Cohort tracking
- Longitudinal feature imputation
- Time series visualization
- Sankey diagrams (static and time-based)
- Machine learning for representation learning
- Clustering with the Leiden algorithm
- UMAP visualization
- Feature ranking based on clusters

The Physionet2012 Challenge dataset contains ICU stay data with 37 numeric features measured over 48 hours for approximately 4,000 patients.


## Imports


In [None]:
import ehrdata as ed
import ehrapy as ep
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import holoviews as hv

hv.extension("bokeh")


## Data Loading and Description

We load the Physionet2012 Challenge dataset using `ehrdata`. The temporal measurements are stored as a 3D array (patients × features × timepoints) in the layer named by the `layer` argument, here `tem_data`.
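
To make the 3D layout concrete, here is a minimal NumPy sketch with made-up values. It does not use `ehrdata`; the array `toy` and its contents are purely illustrative and only mimic the axis order described above.

```python
import numpy as np

# Toy longitudinal array: 2 patients x 3 features x 4 timepoints.
# NaN marks "not measured in this interval". All values are made up.
n_obs, n_vars, n_t = 2, 3, 4
toy = np.full((n_obs, n_vars, n_t), np.nan)
toy[0, 0, :] = [80.0, 82.0, np.nan, 85.0]  # patient 0, feature 0 over time
toy[1, 2, 1] = 7.4                         # a single measurement for patient 1

trajectory = toy[0, 0, :]  # one patient, one feature: shape (n_t,)
snapshot = toy[:, :, 1]    # all patients and features at timepoint 1: shape (n_obs, n_vars)
observed_fraction = np.mean(~np.isnan(toy))  # share of cells holding a measurement

print(trajectory.shape, snapshot.shape, round(float(observed_fraction), 3))
```

Slicing along the last axis yields a trajectory, while fixing the last axis yields a patients × features table for one timepoint.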


In [None]:
# Load the Physionet2012 dataset
edata = ed.dt.physionet2012(layer="tem_data")


In [None]:
# Display basic information about the dataset
print("Dataset shape:", edata.shape)
print("Number of patients:", edata.n_obs)
print("Number of features:", edata.n_vars)
print("Number of timepoints:", edata.tem.shape[0] if hasattr(edata, 'tem') else "N/A")
print("\nFeature names:")
print(edata.var_names[:10])  # Show first 10 features
print("\nObservation metadata columns:")
print(list(edata.obs.columns)[:10])  # Show first 10 obs columns


In [None]:
# Display the dataset object
edata


## Cohort Tracking

We use `CohortTracker` to monitor changes in patient demographics and characteristics throughout the analysis pipeline.
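
The idea behind cohort tracking can be sketched in plain Python: take a snapshot of the cohort's composition after each processing step and compare the snapshots. This is only a conceptual illustration with hypothetical records and a hypothetical `summarize` helper; it is not how `CohortTracker` is implemented.

```python
from collections import Counter

def summarize(records, column):
    """Return {category: share of the cohort} for one column."""
    counts = Counter(r[column] for r in records)
    total = sum(counts.values())
    return {category: n / total for category, n in counts.items()}

# Hypothetical toy cohort; field names and values are illustrative only.
cohort = [
    {"sex": "F", "age": 71}, {"sex": "M", "age": 45},
    {"sex": "F", "age": 83}, {"sex": "M", "age": 66},
    {"sex": "M", "age": 59},
]

# Snapshot 1: the cohort as loaded
history = [("Initial cohort", len(cohort), summarize(cohort, "sex"))]

# A filtering step changes the cohort, so we record a second snapshot
older = [r for r in cohort if r["age"] >= 60]
history.append(("Age >= 60", len(older), summarize(older, "sex")))

for label, n, shares in history:
    print(label, n, shares)
```

Comparing the snapshots shows whether a step such as filtering shifted the balance of a tracked variable.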


In [None]:
# Identify categorical columns for tracking
categorical_cols = []
for col in edata.obs.columns:
    if edata.obs[col].dtype == 'object' or edata.obs[col].dtype.name == 'category':
        unique_vals = edata.obs[col].nunique()
        if unique_vals < 20:  # Consider columns with < 20 unique values as categorical
            categorical_cols.append(col)

print("Categorical columns for tracking:", categorical_cols[:5])  # Show first 5


In [None]:
# Initialize CohortTracker
# We'll track a subset of columns that are available in the dataset
tracking_cols = [col for col in ['Gender', 'Age', 'In-hospital_death'] if col in edata.obs.columns]
if len(tracking_cols) == 0:
    # If those columns don't exist, use available columns
    tracking_cols = list(edata.obs.columns)[:5]

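# Only columns that are actually tracked should be flagged as categorical
tracked_categorical = [col for col in tracking_cols if col in categorical_cols]
ct = ep.tl.CohortTracker(edata, columns=tracking_cols, categorical=tracked_categorical or None)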

# Track initial state
ct(edata, label="Initial cohort", operations_done="Data loaded")


## Longitudinal Feature Imputation

Longitudinal data usually contain many missing values, because not every feature is measured in every time interval. We use `ep.pp.simple_impute` with `layer="tem_data"` so that missing values in the temporal layer are filled using the chosen strategy, here the mean.
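
As a rough illustration of mean imputation on a 3D array, the sketch below fills each gap with the feature's mean pooled over all patients and timepoints. The pooling choice and the helper name `mean_impute_3d` are assumptions made for this example; `ehrapy` may compute the mean differently.

```python
import numpy as np

def mean_impute_3d(values):
    """Fill NaNs with the per-feature mean taken over all patients and timepoints."""
    # Assumption: the mean is pooled over axes 0 (patients) and 2 (timepoints)
    feature_means = np.nanmean(values, axis=(0, 2), keepdims=True)  # shape (1, n_vars, 1)
    return np.where(np.isnan(values), feature_means, values)

# Toy data: 2 patients x 2 features x 3 timepoints, values made up
toy = np.array([
    [[1.0, np.nan, 3.0], [10.0, 10.0, np.nan]],
    [[np.nan, 5.0, 7.0], [np.nan, 30.0, 20.0]],
])
filled = mean_impute_3d(toy)

print(np.isnan(filled).sum())          # no missing values remain
print(filled[0, 0, 1], filled[1, 1, 0])  # gaps replaced by the feature means
```

Observed values are left untouched; only the NaN cells change. A feature with no observed value at all would stay NaN under this scheme.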


In [None]:
# Check missing values before imputation
if "tem_data" in edata.layers:
    missing_before = np.isnan(edata.layers["tem_data"]).sum()
    total_values = edata.layers["tem_data"].size
    print(f"Missing values before imputation: {missing_before} ({100*missing_before/total_values:.2f}%)")
    
    # Perform longitudinal imputation using mean strategy
    ep.pp.simple_impute(edata, layer="tem_data", strategy="mean")
    
    # Check missing values after imputation
    missing_after = np.isnan(edata.layers["tem_data"]).sum()
    print(f"Missing values after imputation: {missing_after} ({100*missing_after/total_values:.2f}%)")
    
    # Track after imputation
    ct(edata, label="After imputation", operations_done="Longitudinal imputation (mean)")
else:
    print("tem_data layer not found")


## Time Series Visualization

We can visualize time series data for individual patients and features using `ep.pl.timeseries`.


In [None]:
# Plot time series for a few patients and features
if "tem_data" in edata.layers and hasattr(edata, 'tem'):
    # Select a few patients and features to visualize
    n_patients = min(3, edata.n_obs)
    n_features = min(3, edata.n_vars)
    
    plot = ep.pl.timeseries(
        edata,
        obs_names=slice(0, n_patients),
        var_names=slice(0, n_features),
        layer="tem_data",
        overlay=False,
        title="Time Series for Selected Patients and Features"
    )
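    # hv.render() only returns the backend figure; display() (built into Jupyter) shows it from inside the if-block
    display(plot)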
else:
    print("Temporal data layer not available for plotting")


## Sankey Diagrams

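Sankey diagrams show how patients are distributed across categories and how they move between them. We create two: a standard diagram linking categorical columns in `obs`, and a time-based diagram showing transitions between the states of one feature over time.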


In [None]:
# Standard Sankey diagram for categorical variables in obs
# Find categorical columns that can be used for Sankey
sankey_cols = []
for col in edata.obs.columns:
    if col in categorical_cols or edata.obs[col].nunique() < 10:
        sankey_cols.append(col)
        if len(sankey_cols) >= 3:
            break

if len(sankey_cols) >= 2:
    sankey_plot = ep.pl.sankey_diagram(edata, columns=sankey_cols[:3], title="Patient Flow Across Categories")
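    # hv.render() only returns the backend figure; display() (built into Jupyter) shows it from inside the if-block
    display(sankey_plot)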
else:
    print("Not enough categorical columns for Sankey diagram")


In [None]:
# Time-based Sankey diagram for a categorical variable
# First, we need to identify or create a categorical variable in the temporal data
# For demonstration, we'll use a feature that can be discretized

if "tem_data" in edata.layers and edata.n_vars > 0:
    # Use the first feature and discretize it into states
    # This is a simplified example - in practice, you'd use an actual categorical feature
    feature_idx = 0
    feature_data = edata.layers["tem_data"][:, feature_idx, :]  # Shape: (n_obs, n_timepoints)

    # Allocate the state array up front so it exists even if a timepoint has no valid values
    discrete_layer = np.zeros((edata.n_obs, 1, feature_data.shape[1]), dtype=int)

    # Discretize into 3 states (low, medium, high) using per-timepoint percentiles
    for t in range(feature_data.shape[1]):
        values = feature_data[:, t]
        valid_values = values[~np.isnan(values)]
        if len(valid_values) > 0:
            p33 = np.percentile(valid_values, 33)
            p66 = np.percentile(valid_values, 66)

            # Values below p33 keep state 0; comparisons with NaN are False, so any remaining NaNs also stay 0
            states = np.zeros_like(values, dtype=int)
            states[(values >= p33) & (values < p66)] = 1
            states[values >= p66] = 2
            discrete_layer[:, 0, t] = states

    # Subset to the single feature so the (n_obs, 1, n_timepoints) layer matches the object's shape
    edata_temp = edata[:, [feature_idx]].copy()
    edata_temp.layers["discrete_states"] = discrete_layer
    edata_temp.var_names = ["state_feature"]

    # Create state labels
    state_labels = {0: "Low", 1: "Medium", 2: "High"}

    # Plot time-based Sankey
    sankey_time_plot = ep.pl.sankey_diagram_time(
        edata_temp,
        var_name="state_feature",
        layer="discrete_states",
        state_labels=state_labels,
        title="State Transitions Over Time"
    )
    # hv.render() only returns the backend figure; display() (built into Jupyter) shows it from inside the if-block
    display(sankey_time_plot)
else:
    print("Temporal data not available for time-based Sankey diagram")
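
The quantity a time-based Sankey diagram draws is, in essence, a table of transition counts between states at consecutive timepoints. The following sketch computes such a table for made-up states; the helper `transition_counts` is hypothetical and is not part of `ehrapy`.

```python
import numpy as np

def transition_counts(states, t_from, t_to, n_states):
    """Count patients moving from each state at t_from to each state at t_to."""
    counts = np.zeros((n_states, n_states), dtype=int)
    # np.add.at accumulates correctly even when the same (from, to) pair repeats
    np.add.at(counts, (states[:, t_from], states[:, t_to]), 1)
    return counts

# Toy states (0=Low, 1=Medium, 2=High): 5 patients x 3 timepoints
states = np.array([
    [0, 0, 1],
    [0, 1, 1],
    [1, 1, 2],
    [2, 2, 2],
    [2, 1, 0],
])
first_step = transition_counts(states, 0, 1, n_states=3)
print(first_step)  # rows: state at t=0, columns: state at t=1
```

Each nonzero entry corresponds to one link in the diagram, with the count determining the link width.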


## Machine Learning for Representation Learning

Rather than training a model, this section derives two simple fixed-length patient representations from the 3D temporal layer: the feature values at the final timepoint, and the mean of each feature across all timepoints. Both are stored in `edata.obsm` for downstream analysis.
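
The two summaries behave differently when values are missing, which matters if the imputation step is skipped. The toy example below uses random made-up data with an artificial gap to show the difference.

```python
import numpy as np

rng = np.random.default_rng(0)
toy = rng.normal(size=(4, 3, 6))  # 4 patients x 3 features x 6 timepoints
toy[0, 1, 2:] = np.nan            # feature 1 of patient 0 is unobserved from t=2 onward

last = toy[:, :, -1]              # snapshot at the final timepoint
mean = np.nanmean(toy, axis=2)    # average over time, ignoring NaN

print(last.shape, mean.shape)
print(np.isnan(last[0, 1]), np.isnan(mean[0, 1]))
```

The final-timepoint snapshot inherits the gap, whereas the time average still uses the two early observations.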


In [None]:
# Extract patient representations at the final timepoint (or a specific timepoint)
# This creates a 2D representation from the 3D temporal data
if "tem_data" in edata.layers:
    # Use the last timepoint for representation
    final_timepoint = edata.layers["tem_data"][:, :, -1]  # Shape: (n_obs, n_vars)
    
    # Store in obsm for downstream analysis
    edata.obsm["X_final_timepoint"] = final_timepoint
    
    # Also create a mean representation across all timepoints
    mean_representation = np.nanmean(edata.layers["tem_data"], axis=2)
    edata.obsm["X_mean_timepoint"] = mean_representation
    
    print("Created representations:")
    print(f"  - Final timepoint: {edata.obsm['X_final_timepoint'].shape}")
    print(f"  - Mean across timepoints: {edata.obsm['X_mean_timepoint'].shape}")
else:
    print("Temporal data layer not available")


## Leiden Clustering

Leiden clustering operates on a graph, so we first build a k-nearest-neighbor graph from the mean representation (`X_mean_timepoint`) and then partition that graph into clusters.
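
As a conceptual sketch of the neighborhood graph, the snippet below finds each patient's nearest neighbors by brute-force Euclidean distance. The helper `knn_indices` is hypothetical; real implementations typically use faster approximate search and also compute edge weights.

```python
import numpy as np

def knn_indices(X, k):
    """Indices of the k nearest rows (Euclidean) for every row of X, excluding itself."""
    diff = X[:, None, :] - X[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    np.fill_diagonal(dist, np.inf)  # a point is not its own neighbor
    return np.argsort(dist, axis=1)[:, :k]

# Two well-separated toy groups of "patients" in a 2-feature space
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
neighbors = knn_indices(X, k=2)
print(neighbors)
```

Because every neighbor link stays within its own group, a graph-based clustering method would be expected to separate the two groups.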


In [None]:
# Compute neighbors using the mean representation
if "X_mean_timepoint" in edata.obsm:
    ep.pp.neighbors(edata, use_rep="X_mean_timepoint", n_neighbors=15, n_pcs=None)
    
    # Perform Leiden clustering
    ep.tl.leiden(edata, resolution=0.5, key_added="leiden")
    
    print(f"Leiden clusters computed. Number of clusters: {edata.obs['leiden'].nunique()}")
    print(f"Cluster distribution:\n{edata.obs['leiden'].value_counts().head()}")
    
    # Track after clustering
    ct(edata, label="After clustering", operations_done="Leiden clustering")
else:
    print("Representations not available for clustering")


## UMAP Visualization

We'll compute and visualize UMAP embeddings to explore the patient representations in 2D space.


In [None]:
# Compute UMAP embedding
if "neighbors" in edata.uns:
    ep.tl.umap(edata)
    
    # Visualize UMAP with Leiden clusters
    ep.pl.umap(edata, color="leiden", title="UMAP colored by Leiden clusters", show=False)
    plt.tight_layout()
    plt.show()
    
    # Also visualize with other metadata if available
    if "In-hospital_death" in edata.obs.columns:
        ep.pl.umap(edata, color=["leiden", "In-hospital_death"], title=["Leiden clusters", "In-hospital death"], show=False)
        plt.tight_layout()
        plt.show()
else:
    print("Neighbors not computed. Please run ep.pp.neighbors first.")


## Feature Ranking Based on Leiden Clusters

We'll use `rank_features_groups` to identify the features whose values differ most between Leiden clusters.
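
One common way to rank features for a cluster is to compare it against all other patients with a per-feature test statistic. The sketch below uses a Welch t-statistic on synthetic data; the helper `rank_features_vs_rest` is hypothetical, and the statistical test that `ehrapy` applies may differ.

```python
import numpy as np

def rank_features_vs_rest(X, labels, group):
    """Order feature indices by |Welch t-statistic| of `group` versus all other rows."""
    in_group = labels == group
    a, b = X[in_group], X[~in_group]
    se = np.sqrt(a.var(axis=0, ddof=1) / len(a) + b.var(axis=0, ddof=1) / len(b))
    t = (a.mean(axis=0) - b.mean(axis=0)) / se
    order = np.argsort(-np.abs(t))  # most discriminative feature first
    return order, t

# Synthetic data: 100 patients, 4 features, two clusters of 50
rng = np.random.default_rng(0)
labels = np.array(["0"] * 50 + ["1"] * 50)
X = rng.normal(size=(100, 4))
X[labels == "1", 2] += 3.0  # plant a shift in feature 2 for cluster "1"

order, t = rank_features_vs_rest(X, labels, group="1")
print(order[0])
```

The planted feature ranks first for cluster "1", and the sign of its statistic indicates that its values are higher inside the cluster than outside.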


In [None]:
# Rank features based on Leiden clusters
# We'll use the mean representation stored in X for ranking
if "leiden" in edata.obs.columns:
    # Set X to the mean representation for ranking
    edata.X = edata.obsm["X_mean_timepoint"].copy()
    
    # Rank features
    ep.tl.rank_features_groups(edata, groupby="leiden", n_features=10)
    
    # Visualize top ranked features
    ep.pl.rank_features_groups(edata, n_features=5, key="rank_features_groups", show=False)
    plt.tight_layout()
    plt.show()
    
    # Get ranked features as DataFrame
    for cluster in sorted(edata.obs["leiden"].unique())[:3]:  # Show the first 3 clusters (labels sorted as strings)
        df = ep.get.rank_features_groups_df(edata, group=cluster, key="rank_features_groups")
        print(f"\nTop features for cluster {cluster}:")
        print(df.head())
else:
    print("Leiden clusters not found. Please run clustering first.")


## Cohort Tracking Summary

Let's visualize the cohort changes tracked throughout the analysis.


In [None]:
# Plot cohort tracking barplot
if len(ct._tracked_tables) > 1:
    try:
        ct.plot_cohort_barplot()
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"Could not plot cohort barplot: {e}")
        print("This is expected if the tracked columns don't have the right format")
else:
    print("Not enough tracking steps to plot")


## Summary

This notebook demonstrated:

1. **Data Loading**: Loading and exploring the Physionet2012 longitudinal EHR dataset
2. **Cohort Tracking**: Monitoring patient cohort changes throughout the analysis
3. **Longitudinal Imputation**: Handling missing values in temporal data
4. **Time Series Visualization**: Plotting patient trajectories over time
5. **Sankey Diagrams**: Visualizing flows between categories and state transitions over time
6. **Representation Learning**: Extracting patient representations from temporal data
7. **Leiden Clustering**: Identifying patient subgroups
8. **UMAP Visualization**: Exploring patient relationships in 2D space
9. **Feature Ranking**: Identifying features that distinguish between clusters

These tools enable comprehensive analysis of longitudinal EHR data, from preprocessing to visualization and interpretation.
