In [11]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.subplots


def plot_ehrdata_sankey(
    ehrdata,
    *,
    max_obs: int = 20,
    max_vars: int = 20,
    max_timepoints: int = 20,
    obs_indices: list[int] | None = None,
    var_indices: list[int] | None = None,
    time_indices: list[int] | None = None,
    colorscale: str = "Viridis",
    normalize: bool = True,
    link_mode: str = "abs",
    title: str = "EHRData Sankey Diagram",
) -> go.Figure:
    """Create a Sankey diagram of EHRData's 3D tensor ``R`` flowing obs -> var -> time.

    Three columns of nodes are drawn: observations, variables and timepoints.
    Each observation -> variable link carries ``R[obs, var, :]`` summed over the
    selected timepoints. Each variable -> timepoint link carries ``R[:, var, t]``
    summed over the selected observations. NaNs are ignored in both sums, and a
    cell whose sum is zero (or NaN) produces no link.

    Args:
        ehrdata: The EHRData object. Must expose ``R``, ``n_obs``, ``n_vars`` and
            ``n_t``. ``var.index`` and ``t`` (its ``time_value`` column, else its
            index) are used for labels when present.
        max_obs: Number of leading observations to include when ``obs_indices`` is None.
        max_vars: Number of leading variables to include when ``var_indices`` is None.
        max_timepoints: Maximum number of timepoints when ``time_indices`` is None.
            If there are more, they are sampled evenly across the full range.
        obs_indices: Specific observation indices to include.
        var_indices: Specific variable indices to include.
        time_indices: Specific time indices to include.
        colorscale: Currently unused; kept for API compatibility. Links are
            colored by the variable they touch.
        normalize: If True, divide every link value by the largest absolute link value.
        link_mode: ``'abs'`` uses absolute values of the sums; ``'signed'`` keeps
            the sign. Sankey links represent flow volumes, so negative values in
            ``'signed'`` mode may not render meaningfully.
        title: Title for the figure.

    Returns:
        Plotly figure containing the Sankey diagram.

    Raises:
        ValueError: If ``ehrdata.R`` is None or ``link_mode`` is not ``'abs'``/``'signed'``.
    """
    if ehrdata.R is None:
        msg = "EHRData object does not contain R data for visualization"
        raise ValueError(msg)
    if link_mode not in ("abs", "signed"):
        # Without this check any other string (e.g. a typo) silently behaved like 'signed'.
        msg = f"link_mode must be 'abs' or 'signed', got {link_mode!r}"
        raise ValueError(msg)

    n_obs = ehrdata.n_obs
    n_vars = ehrdata.n_vars
    n_t = ehrdata.n_t

    # Default selections: leading observations/variables, evenly spaced timepoints.
    if obs_indices is None:
        obs_indices = list(range(min(max_obs, n_obs)))
    if var_indices is None:
        var_indices = list(range(min(max_vars, n_vars)))
    if time_indices is None:
        if n_t > max_timepoints:
            time_indices = np.linspace(0, n_t - 1, max_timepoints, dtype=int).tolist()
        else:
            time_indices = list(range(n_t))

    # Subset one axis at a time so that plain index lists work on every axis.
    R_subset = ehrdata.R[obs_indices][:, var_indices][:, :, time_indices]

    # --- Node labels -------------------------------------------------------
    obs_labels = [f"Obs {i}" for i in obs_indices]

    if hasattr(ehrdata, "var") and hasattr(ehrdata.var, "index"):
        var_labels = [str(ehrdata.var.index[i]) for i in var_indices]
    else:
        var_labels = [f"Var {i}" for i in var_indices]

    if hasattr(ehrdata, "t") and hasattr(ehrdata.t, "time_value") and "time_value" in ehrdata.t:
        time_col = ehrdata.t.time_value
        if pd.api.types.is_numeric_dtype(time_col):
            time_labels = [f"T {time_col.iloc[i]:.2f}" for i in time_indices]
        else:
            # ".2f" is invalid for strings, and datetimes treat it as a strftime
            # pattern (yielding the literal "T .2f"), so format non-numerics plainly.
            time_labels = [f"T {time_col.iloc[i]}" for i in time_indices]
    elif hasattr(ehrdata, "t") and hasattr(ehrdata.t, "index"):
        time_labels = [str(ehrdata.t.index[i]) for i in time_indices]
    else:
        time_labels = [f"T {i}" for i in time_indices]

    all_labels = obs_labels + var_labels + time_labels

    # --- Node colors: fixed blue for observations, green for timepoints, and a
    # deterministic per-variable color that is reused for that variable's links.
    node_obs_colors = ["rgba(31, 119, 180, 0.8)"] * len(obs_labels)
    node_var_colors = [f"rgba({(i * 50) % 255}, {(i * 120) % 255}, {(i * 180) % 255}, 0.8)" for i in var_indices]
    node_time_colors = ["rgba(44, 160, 44, 0.8)"] * len(time_labels)
    node_colors = node_obs_colors + node_var_colors + node_time_colors

    # Sankey nodes are addressed by position in all_labels: obs first, then vars, then times.
    var_offset = len(obs_labels)
    time_offset = var_offset + len(var_labels)

    # --- Links ---------------------------------------------------------------
    obs_var_flow = np.nansum(R_subset, axis=2)  # (selected obs, selected vars), summed over time
    var_time_flow = np.nansum(R_subset, axis=0)  # (selected vars, selected times), summed over obs

    # np.nonzero yields cells in row-major order, i.e. the same order as
    # iterating rows in the outer loop and columns in the inner loop.
    ov_obs, ov_var = np.nonzero(~np.isnan(obs_var_flow) & (obs_var_flow != 0))
    vt_var, vt_time = np.nonzero(~np.isnan(var_time_flow) & (var_time_flow != 0))

    sources = np.concatenate([ov_obs, var_offset + vt_var]).tolist()
    targets = np.concatenate([var_offset + ov_var, time_offset + vt_time]).tolist()
    link_values = np.concatenate([obs_var_flow[ov_obs, ov_var], var_time_flow[vt_var, vt_time]])
    if link_mode == "abs":
        link_values = np.abs(link_values)
    # Each link takes the color of its variable: the target of obs->var links,
    # the source of var->time links.
    link_colors = [node_var_colors[j] for j in np.concatenate([ov_var, vt_var]).tolist()]

    # Every kept value is non-zero, so the maximum absolute value is > 0 here.
    if normalize and link_values.size:
        link_values = link_values / np.max(np.abs(link_values))
    values = link_values.tolist()

    fig = go.Figure(
        data=[
            go.Sankey(
                node={
                    "pad": 15,
                    "thickness": 20,
                    "line": {"color": "black", "width": 0.5},
                    "label": all_labels,
                    "color": node_colors,
                },
                link={
                    "source": sources,
                    "target": targets,
                    "value": values,
                    "color": link_colors,
                },
            )
        ]
    )

    fig.update_layout(
        title_text=title,
        font_size=10,
        autosize=True,
        height=600,
        margin={"l": 50, "r": 50, "b": 100, "t": 100, "pad": 4},
    )

    return fig


def plot_ehrdata_3d(
    ehrdata,
    *,
    obs_indices: list[int] | None = None,
    var_indices: list[int] | None = None,
    plot_type: str = "time_series",
) -> go.Figure:
    """Create a visualization of EHRData's 3D tensor ``R`` (obs x var x time).

    Args:
        ehrdata: The EHRData object to visualize. Must expose ``R``, ``n_obs``,
            ``n_vars``, ``n_t`` and ``t``. ``t.time_value`` (if present) supplies
            the time axis, otherwise ``0..n_t-1`` is used.
        obs_indices: Observation indices to include (default: the first 5).
        var_indices: Variable indices to include (default: the first 5).
        plot_type: One of
            ``'time_series'``: a grid of line plots, one row per observation and one column per variable;
            ``'heatmap'``: one variable x time heatmap per observation, with a shared color range;
            ``'surface'``: one 3D surface per observation over variable index x time.

    Returns:
        Plotly figure containing the visualization.

    Raises:
        ValueError: If ``ehrdata.R`` is None or ``plot_type`` is not recognized.
    """
    if ehrdata.R is None:
        msg = "EHRData object does not contain R data for visualization"
        raise ValueError(msg)

    builders = {
        "time_series": _plot3d_time_series,
        "heatmap": _plot3d_heatmap,
        "surface": _plot3d_surface,
    }
    # Previously an unknown plot_type fell through every branch and crashed on `return fig`.
    if plot_type not in builders:
        msg = f"Unknown plot_type {plot_type!r}; expected one of {sorted(builders)}"
        raise ValueError(msg)

    if obs_indices is None:
        obs_indices = list(range(min(5, ehrdata.n_obs)))
    if var_indices is None:
        var_indices = list(range(min(5, ehrdata.n_vars)))

    if hasattr(ehrdata.t, "time_value") and "time_value" in ehrdata.t:
        time_values = ehrdata.t.time_value.values
    else:
        time_values = np.arange(ehrdata.n_t)

    return builders[plot_type](ehrdata, obs_indices, var_indices, time_values)


def _plot3d_spacing(preferred: float, n_panels: int) -> float:
    """Return a subplot spacing no larger than ``preferred`` that still fits ``n_panels``.

    ``make_subplots`` raises when the spacing exceeds ``1 / (n_panels - 1)``. Capping
    at half of that bound leaves at least half of the figure for the panels themselves.
    """
    if n_panels <= 1:
        return preferred
    return min(preferred, 0.5 / (n_panels - 1))


def _plot3d_var_color(var_idx: int) -> str:
    """Deterministic RGBA color derived from a variable index (stable across observations)."""
    return f"rgba({(var_idx * 50) % 255}, {(var_idx * 120) % 255}, {(var_idx * 180) % 255}, 0.8)"


def _plot3d_time_series(ehrdata, obs_indices, var_indices, time_values) -> go.Figure:
    """Grid of line plots: one row per observation, one column per variable (NaNs dropped)."""
    n_rows = len(obs_indices)
    n_cols = len(var_indices)

    if hasattr(ehrdata, "var") and hasattr(ehrdata.var, "index"):
        var_names = [str(ehrdata.var.index[i]) for i in var_indices]
    else:
        var_names = [f"Var {i}" for i in var_indices]
    var_colors = [_plot3d_var_color(var_idx) for var_idx in var_indices]

    fig = plotly.subplots.make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=var_names * n_rows,  # every row repeats the variable names
        shared_xaxes=True,
        vertical_spacing=_plot3d_spacing(0.05, n_rows),
        horizontal_spacing=_plot3d_spacing(0.05, n_cols),
    )

    for row, obs_idx in enumerate(obs_indices, start=1):
        for col, (var_idx, var_name, color) in enumerate(zip(var_indices, var_names, var_colors), start=1):
            series = ehrdata.R[obs_idx, var_idx, :]
            valid = ~np.isnan(series)
            if not valid.any():
                continue  # leave the panel empty rather than drawing an empty trace
            fig.add_trace(
                go.Scatter(
                    x=time_values[valid],
                    y=series[valid],
                    mode="lines+markers",
                    name=var_name,
                    line={"width": 2, "color": color},
                    marker={"size": 6, "color": color},
                ),
                row=row,
                col=col,
            )

    fig.update_layout(
        height=250 * n_rows,
        width=200 * n_cols,
        title_text="EHRData Time Series Visualization",
        showlegend=False,
    )
    for col in range(1, n_cols + 1):
        fig.update_xaxes(title_text="Time", row=n_rows, col=col)
    for row, obs_idx in enumerate(obs_indices, start=1):
        fig.update_yaxes(title_text=f"Obs {obs_idx}", row=row, col=1)

    return fig


def _plot3d_heatmap(ehrdata, obs_indices, var_indices, time_values) -> go.Figure:
    """One variable x time heatmap per observation, stacked vertically with a shared color range."""
    n_rows = len(obs_indices)
    fig = plotly.subplots.make_subplots(
        rows=n_rows,
        cols=1,
        subplot_titles=[f"Observation {obs_idx}" for obs_idx in obs_indices],
        vertical_spacing=_plot3d_spacing(0.1, n_rows),
    )

    # Consistent scale across all observations; computed once over the full tensor.
    zmin = np.nanmin(ehrdata.R)
    zmax = np.nanmax(ehrdata.R)
    y_labels = [f"Var {var_idx}" for var_idx in var_indices]

    for row, obs_idx in enumerate(obs_indices, start=1):
        heatmap_data = ehrdata.R[obs_idx, :, :][var_indices, :]
        fig.add_trace(
            go.Heatmap(
                z=heatmap_data,
                x=time_values,
                y=y_labels,
                colorscale="Viridis",
                zmin=zmin,
                zmax=zmax,
                # All panels share one range, so a single colorbar is enough. Per-panel
                # colorbars overlapped, and their y positions ran bottom-up while rows run top-down.
                showscale=(row == 1),
                colorbar={"title": "Value"},
            ),
            row=row,
            col=1,
        )

    fig.update_layout(height=300 * n_rows, width=800, title_text="EHRData Heatmap Visualization")
    return fig


def _plot3d_surface(ehrdata, obs_indices, var_indices, time_values) -> go.Figure:
    """One 3D surface per observation over (variable index, time), colored by variable index."""
    fig = go.Figure()

    # With default 'xy' indexing both grids are (len(time_values), len(var_indices)),
    # matching the transposed per-observation slice below. The grid does not depend
    # on the observation, so it is built once.
    var_grid, time_grid = np.meshgrid(var_indices, time_values)

    for i, obs_idx in enumerate(obs_indices):
        surface_data = ehrdata.R[obs_idx, :, :][var_indices, :].T
        fig.add_trace(
            go.Surface(
                x=var_grid,
                y=time_grid,
                z=surface_data,
                name=f"Obs {obs_idx}",
                colorscale="Viridis",
                showscale=(i == 0),  # identical color mapping for every surface
                opacity=0.9,
                # Color by variable index for consistency across observations. This
                # previously read an unbound loop variable `j`.
                surfacecolor=var_grid,
                colorbar={"title": "Variable index"},
            )
        )

    fig.update_layout(
        title_text="EHRData 3D Surface Visualization",
        scene={"xaxis_title": "Variables", "yaxis_title": "Time", "zaxis_title": "Value"},
        height=800,
        width=800,
    )
    return fig

In [None]:
from ehrdata import EHRData

# Fixed seed so Restart & Run All reproduces the same data and figures.
SEED = 0
rng = np.random.default_rng(SEED)

# Sample data dimensions
n_obs = 50
n_vars = 10
n_time = 20

# Random data for demonstration: static features X and the obs x var x time tensor R
X = rng.standard_normal((n_obs, n_vars))
R = rng.standard_normal((n_obs, n_vars, n_time))

# Annotation frames for observations, variables and timepoints
obs = pd.DataFrame(index=[f"Patient_{i}" for i in range(n_obs)])
var = pd.DataFrame(index=[f"Biomarker_{i}" for i in range(n_vars)])
t = pd.DataFrame(index=pd.date_range("2023-01-01", periods=n_time, freq="D"))

ehr_data = EHRData(X=X, R=R, obs=obs, var=var, t=t)

plot_ehrdata_sankey(ehr_data)

In [13]:
plot_ehrdata_3d(ehr_data, plot_type="time_series")