# Automated Causal Discovery for Healthcare Systems

This notebook applies causal discovery algorithms (PC and GES) to learn causal mechanisms linking CHW interventions, patient engagement, and health outcomes in Medicaid populations.

## Research Question
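Which causal mechanisms linking CHW interventions, patient engagement, and health outcomes in Medicaid populations are supported by the observational data, given explicit temporal-ordering assumptions? Graphs learned from observational data are only as valid as each algorithm's assumptions (e.g., no unmeasured confounding), so they should be read as well-founded hypotheses rather than proof of causation.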

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd

# Make the local `causal_discovery` package (kept under <repo>/notebooks) importable.
# Assumes the notebook is started from a direct subdirectory of the repo root.
# The membership guard keeps re-runs of this cell from prepending duplicate entries.
REPO_ROOT = Path.cwd().parent
_notebooks_dir = str(REPO_ROOT / "notebooks")
if _notebooks_dir not in sys.path:
    sys.path.insert(0, _notebooks_dir)

from causal_discovery.data_loader import (
    load_causal_dataset,
    TemporalConfig,
)
from causal_discovery.algorithms import (
    PCAlgorithm,
    GESAlgorithm,
    CausalGraph,
)

# Default figure geometry for every plot in this notebook.
plt.rcParams.update({
    "figure.figsize": (12, 8),
    "figure.dpi": 100,
})

## 1. Load and Prepare Data

We load a member-level dataset with the following temporal structure (windows are relative to each member's activation date):
- **Baseline features** (6 months before activation): demographics, risk score, historical utilization
- **Treatment features**: intervention exposure during follow-up
- **Outcome features** (6 months after activation): utilization and costs

In [None]:
# Load data with temporal alignment
config = TemporalConfig(
    baseline_months=6,
    followup_months=6,
    intervention_buffer_days=30,
)

# Sample 1000 members for faster computation
dataset, metadata = load_causal_dataset(
    config=config,
    sample_size=1000,
)

print(f"Dataset shape: {dataset.shape}")
print(f"\nMetadata: {metadata}")
print(f"\nColumns: {list(dataset.columns)}")

In [None]:
# Summary statistics of the loaded dataset, shown via rich display.
summary_stats = dataset.describe()
summary_stats

## 2. Define Temporal Tiers

We enforce temporal precedence by organizing variables into tiers:
- **Tier 0** (baseline/static): Demographics, risk score, baseline utilization
- **Tier 1** (treatment): Intervention exposure
- **Tier 2** (outcomes): Follow-up utilization and costs

No edge may point from a later tier to an earlier one (e.g., a follow-up outcome cannot cause a baseline variable).

In [None]:
# Define variable selection and temporal tiers
baseline_vars = [
    "age",
    "gender_female",
    "risk_score",
    "baseline_ed_ct",
    "baseline_ip_ct",
    "baseline_total_paid",
]

treatment_vars = [
    "intervention_any",
    "intervention_count",
]

outcome_vars = [
    "followup_ed_ct",
    "followup_ip_ct",
    "followup_total_paid",
]

# Select variables that exist in dataset
all_vars = baseline_vars + treatment_vars + outcome_vars
selected_vars = [v for v in all_vars if v in dataset.columns]

print(f"Selected variables ({len(selected_vars)}): {selected_vars}")

# Define temporal tiers
temporal_tiers = [
    [v for v in baseline_vars if v in selected_vars],  # Tier 0: Baseline
    [v for v in treatment_vars if v in selected_vars],  # Tier 1: Treatment
    [v for v in outcome_vars if v in selected_vars],    # Tier 2: Outcomes
]

print(f"\nTemporal tiers:")
for i, tier in enumerate(temporal_tiers):
    print(f"  Tier {i}: {tier}")

# Prepare analysis dataset
analysis_data = dataset[selected_vars].copy()

# Remove rows with missing values
analysis_data = analysis_data.dropna()
print(f"\nAnalysis dataset shape after dropna: {analysis_data.shape}")

## 3. Run PC Algorithm

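The PC (Peter-Clark) algorithm is a constraint-based method that:
1. Uses conditional-independence tests to learn the skeleton (the undirected graph of adjacencies)
2. Orients edges using v-structures and the temporal constraints
3. Leaves edges that neither the data nor the constraints can orient as undirected (reported separately below)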

In [None]:
# Initialize PC algorithm
pc = PCAlgorithm(
    alpha=0.05,  # Significance level
    max_conditioning_set_size=3,  # Limit for computational efficiency
    temporal_tiers=temporal_tiers,
)

# Fit the algorithm
print("Running PC algorithm...")
pc_graph = pc.fit(analysis_data, variable_names=selected_vars)

print(f"\nPC Algorithm Results:")
print(f"  Nodes: {len(pc_graph.nodes)}")
print(f"  Directed edges: {len(pc_graph.edges)}")
print(f"  Undirected edges: {len(pc_graph.undirected_edges)}")
print(f"\nDirected edges:")
for edge in sorted(pc_graph.edges):
    print(f"  {edge[0]} -> {edge[1]}")
print(f"\nUndirected edges:")
for edge in sorted(pc_graph.undirected_edges):
    print(f"  {edge[0]} - {edge[1]}")

## 4. Run GES Algorithm

The GES (Greedy Equivalence Search) algorithm is a score-based method that:
1. Starts with an empty graph
2. Greedily adds edges (forward phase) and then removes edges (backward phase), at each step choosing the change that most improves the BIC score
3. Respects the temporal constraints

In [None]:
# Initialize GES algorithm
ges = GESAlgorithm(
    temporal_tiers=temporal_tiers,
    max_iter=100,
)

# Fit the algorithm
print("Running GES algorithm...")
ges_graph = ges.fit(analysis_data, variable_names=selected_vars)

print(f"\nGES Algorithm Results:")
print(f"  Nodes: {len(ges_graph.nodes)}")
print(f"  Directed edges: {len(ges_graph.edges)}")
print(f"\nDirected edges:")
for edge in sorted(ges_graph.edges):
    print(f"  {edge[0]} -> {edge[1]}")

## 5. Visualize Causal Graphs

In [None]:
def visualize_causal_graph(graph: "CausalGraph", title: str = "Causal Graph", tiers=None):
    """Draw a learned causal graph with nodes colored by temporal tier.

    Directed edges are drawn as solid black arrows. Undirected edges (adjacencies
    the algorithm could not orient) are drawn as dashed gray lines so they stay
    visible in the figure.

    Args:
        graph: Learned graph exposing ``nodes``, ``edges`` (directed ``(u, v)``
            pairs) and ``undirected_edges`` (``(u, v)`` pairs).
        title: Figure title.
        tiers: Ordered list of tiers, each a list of variable names. Defaults to
            the notebook-level ``temporal_tiers``. The first three tiers are
            colored light blue / light green / light coral. Nodes in no colored
            tier are gray. If a node appears in several tiers, the first tier wins.
    """
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    if tiers is None:
        tiers = temporal_tiers

    tier_palette = ["lightblue", "lightgreen", "lightcoral"]
    node_color_by_name = {}
    # zip() stops at the shorter sequence, so fewer than three tiers is fine and
    # tiers beyond the palette fall back to gray.
    for color, tier in zip(tier_palette, tiers):
        for node in tier:
            node_color_by_name.setdefault(node, color)

    directed = [(edge[0], edge[1]) for edge in graph.edges]
    undirected = [(u, v) for u, v in graph.undirected_edges]

    digraph = nx.DiGraph()
    digraph.add_nodes_from(graph.nodes)
    digraph.add_edges_from(directed)

    # Lay out on the full skeleton so unoriented adjacencies also pull nodes together.
    skeleton = digraph.to_undirected()
    skeleton.add_edges_from(undirected)
    pos = nx.spring_layout(skeleton, k=2, iterations=50, seed=42)

    nodes = list(skeleton.nodes())
    node_colors = [node_color_by_name.get(node, "gray") for node in nodes]
    node_size = 3000

    fig, ax = plt.subplots(figsize=(14, 10))
    nx.draw_networkx_nodes(
        skeleton, pos, nodelist=nodes, node_color=node_colors,
        node_size=node_size, alpha=0.9, ax=ax,
    )
    nx.draw_networkx_labels(skeleton, pos, font_size=9, font_weight="bold", ax=ax)

    # node_size must match the drawn nodes: networkx uses it to stop arrows at the
    # node boundary; with the default (300) the arrowheads hide under large nodes.
    if directed:
        nx.draw_networkx_edges(
            digraph, pos, edgelist=directed,
            edge_color="black", arrows=True, arrowsize=20, width=2,
            node_size=node_size, ax=ax,
        )
    if undirected:
        nx.draw_networkx_edges(
            skeleton, pos, edgelist=undirected,
            edge_color="gray", style="dashed", arrows=False, width=2,
            node_size=node_size, ax=ax,
        )

    handles = [
        Patch(color=color, label=f"Tier {i}")
        for i, color in enumerate(tier_palette[: len(tiers)])
    ]
    if any(node not in node_color_by_name for node in nodes):
        handles.append(Patch(color="gray", label="No tier"))
    if undirected:
        handles.append(Line2D([0], [0], color="gray", linestyle="--", label="Undirected edge"))
    if handles:
        ax.legend(handles=handles, loc="upper left")

    ax.set_title(title, fontsize=16, fontweight="bold")
    ax.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# Figure for the graph learned by PC
visualize_causal_graph(graph=pc_graph, title="PC Algorithm: Learned Causal Graph")

In [None]:
# Figure for the graph learned by GES
visualize_causal_graph(graph=ges_graph, title="GES Algorithm: Learned Causal Graph")

## 6. Interpret Mechanisms

### Key Questions:
1. **Do interventions cause changes in outcomes?** Look for edges from `intervention_any` or `intervention_count` to outcome variables.
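2. **What are the mediating pathways?** Look for chains like `intervention_any -> followup_ed_ct -> followup_total_paid`. (Baseline variables cannot act as mediators: the temporal tiers forbid edges from a treatment back to a baseline variable.)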
3. **Are there confounders?** Check if baseline variables have edges to both treatment and outcomes.
4. **What drives intervention exposure?** Look at edges pointing to `intervention_any`.

In [None]:
def analyze_mechanisms(graph: "CausalGraph", algorithm_name: str):
    """Print a plain-language summary of the mechanisms in a learned causal graph.

    Nodes are classified by name:
      - nodes containing ``"intervention"`` are treatments;
      - nodes containing ``"followup"`` are outcomes;
      - nodes containing ``"baseline"``, or one of ``age`` / ``risk_score`` /
        ``gender_female``, are baseline covariates.

    Four sections are printed:
      1. children of each treatment node;
      2. baseline parents of each outcome node;
      3. parents of each treatment node;
      4. two-step directed paths ``treatment -> mediator -> outcome``. The
         mediator may be another treatment (e.g. ``intervention_count``) or
         another outcome (e.g. ``followup_ed_ct`` on the way to costs). Those
         are the only in-tier or forward nodes a treatment can point to, given
         the temporal tiers.

    Neighbour lists come from ``graph.to_adjacency_dict()``. It is assumed to map
    each node to a dict with ``"parents"`` and ``"children"`` collections.
    NOTE(review): confirm whether unoriented (undirected) edges appear there.
    Neighbours are printed in sorted order, so output is stable across runs
    even if they are stored in sets.

    Args:
        graph: Learned causal graph exposing ``nodes`` and ``to_adjacency_dict()``.
        algorithm_name: Label used in the report header.
    """
    # Baseline covariates whose names lack the "baseline" prefix.
    static_baseline_vars = ("age", "risk_score", "gender_female")

    adj = graph.to_adjacency_dict()

    def neighbours(node, relation):
        """Sorted ``relation`` ("parents" or "children") of ``node``; empty if absent."""
        return sorted(adj.get(node, {}).get(relation, ()))

    print(f"\n{'=' * 60}")
    print(f"{algorithm_name} - Mechanism Analysis")
    print(f"{'=' * 60}")

    treatment_nodes = [n for n in graph.nodes if "intervention" in n]
    outcome_nodes = [n for n in graph.nodes if "followup" in n]

    # 1. Treatment effects
    print("\n1. INTERVENTION EFFECTS:")
    for treatment in treatment_nodes:
        children = neighbours(treatment, "children")
        if children:
            print(f"   {treatment} causes:")
            for child in children:
                print(f"     -> {child}")
        else:
            print(f"   {treatment}: NO DIRECT EFFECTS DETECTED")

    # 2. Baseline predictors of outcomes
    print("\n2. BASELINE PREDICTORS OF OUTCOMES:")
    for outcome in outcome_nodes:
        baseline_parents = [
            p for p in neighbours(outcome, "parents")
            if "baseline" in p or p in static_baseline_vars
        ]
        if baseline_parents:
            print(f"   {outcome} is predicted by:")
            for parent in baseline_parents:
                print(f"     <- {parent}")
        else:
            print(f"   {outcome}: NO BASELINE PREDICTORS DETECTED")

    # 3. What drives intervention exposure?
    print("\n3. DRIVERS OF INTERVENTION EXPOSURE:")
    for treatment in treatment_nodes:
        parents = neighbours(treatment, "parents")
        if parents:
            print(f"   {treatment} is driven by:")
            for parent in parents:
                print(f"     <- {parent}")
        else:
            print(f"   {treatment}: NO DETECTED CAUSES (exogenous or weakly predicted)")

    # 4. Potential mediators: any child of a treatment that itself points to a
    # (different) outcome.
    print("\n4. POTENTIAL MEDIATING PATHWAYS:")
    print("   (Treatment -> Mediator -> Outcome)")
    found_mediator = False
    for treatment in treatment_nodes:
        for mediator in neighbours(treatment, "children"):
            mediator_children = neighbours(mediator, "children")
            for outcome in outcome_nodes:
                if outcome != mediator and outcome in mediator_children:
                    print(f"   {treatment} -> {mediator} -> {outcome}")
                    found_mediator = True
    if not found_mediator:
        print("   (None detected)")

    print(f"\n{'=' * 60}\n")

In [None]:
# Mechanism report for the PC graph
analyze_mechanisms(graph=pc_graph, algorithm_name="PC Algorithm")

In [None]:
# Mechanism report for the GES graph
analyze_mechanisms(graph=ges_graph, algorithm_name="GES Algorithm")

## 7. Compare Algorithms

Compare the graphs learned by PC and GES to identify consensus mechanisms.

In [None]:
def compare_graphs(graph1: "CausalGraph", graph2: "CausalGraph", name1: str, name2: str):
    """Compare two learned causal graphs adjacency by adjacency and print the result.

    Both directed ``edges`` and ``undirected_edges`` are considered. Each
    adjacency (unordered node pair) is placed in exactly one category:

    - consensus: present in both graphs with the same orientation (same arrow
      direction, or undirected in both);
    - different orientation: present in both graphs but oriented differently
      (reversed, or oriented in one and undirected in the other);
    - only in ``name1`` / only in ``name2``: present in one graph only.

    Args:
        graph1, graph2: Graphs exposing ``edges`` and ``undirected_edges`` as
            iterables of ``(u, v)`` pairs.
        name1, name2: Labels used in the printed report.
    """
    def directed_pairs(graph):
        """Set of ordered ``(u, v)`` pairs for the graph's directed edges."""
        return {(edge[0], edge[1]) for edge in graph.edges}

    def adjacencies(graph):
        """Set of unordered node pairs adjacent in the graph, regardless of orientation."""
        pairs = {frozenset((edge[0], edge[1])) for edge in graph.edges}
        pairs |= {frozenset((edge[0], edge[1])) for edge in graph.undirected_edges}
        return pairs

    def describe(pair, arrows):
        """Render an adjacency as 'u -> v', 'v -> u' or 'u - v' given a graph's arrows."""
        ordered = sorted(pair)
        u, v = ordered[0], ordered[-1]
        if (u, v) in arrows:
            return f"{u} -> {v}"
        if (v, u) in arrows:
            return f"{v} -> {u}"
        return f"{u} - {v}"

    def sort_key(pair):
        return tuple(sorted(pair))

    arrows1, arrows2 = directed_pairs(graph1), directed_pairs(graph2)
    adj1, adj2 = adjacencies(graph1), adjacencies(graph2)

    shared = sorted(adj1 & adj2, key=sort_key)
    consensus = [p for p in shared if describe(p, arrows1) == describe(p, arrows2)]
    conflicting = [p for p in shared if describe(p, arrows1) != describe(p, arrows2)]
    only_in_1 = sorted(adj1 - adj2, key=sort_key)
    only_in_2 = sorted(adj2 - adj1, key=sort_key)

    print(f"\nGraph Comparison: {name1} vs {name2}")
    print(f"{'=' * 60}")
    print(f"Consensus edges (in both): {len(consensus)}")
    for pair in consensus:
        print(f"  {describe(pair, arrows1)}")

    print(f"\nSame adjacency, different orientation: {len(conflicting)}")
    for pair in conflicting:
        print(f"  {name1}: {describe(pair, arrows1)}  vs  {name2}: {describe(pair, arrows2)}")

    print(f"\nOnly in {name1}: {len(only_in_1)}")
    for pair in only_in_1:
        print(f"  {describe(pair, arrows1)}")

    print(f"\nOnly in {name2}: {len(only_in_2)}")
    for pair in only_in_2:
        print(f"  {describe(pair, arrows2)}")
    print(f"{'=' * 60}\n")

# Side-by-side comparison of the PC and GES graphs
compare_graphs(graph1=pc_graph, graph2=ges_graph, name1="PC", name2="GES")

## 8. Export Results

In [None]:
# Save learned edge lists to CSV. The output directory is anchored on REPO_ROOT
# (defined in the imports cell) rather than on the current working directory.
output_dir = REPO_ROOT / "results" / "causal_discovery"
output_dir.mkdir(parents=True, exist_ok=True)

for prefix, learned_graph in (("pc", pc_graph), ("ges", ges_graph)):
    # Directed edges: "<prefix>_edges.csv" with "from"/"to" columns.
    # sorted() gives a deterministic row order and a list that pandas accepts
    # even when the edges are stored as a set.
    pd.DataFrame(sorted(learned_graph.edges), columns=["from", "to"]).to_csv(
        output_dir / f"{prefix}_edges.csv", index=False
    )
    # Unoriented adjacencies go to a separate file so the directed-edge schema stays unchanged.
    undirected_pairs = sorted(tuple(sorted(pair)) for pair in learned_graph.undirected_edges)
    pd.DataFrame(undirected_pairs, columns=["node_a", "node_b"]).to_csv(
        output_dir / f"{prefix}_undirected_edges.csv", index=False
    )

print(f"Results saved to {output_dir}")