# Causal Graphs: DAGs, d-Separation, and Causal Reasoning

This notebook introduces **Directed Acyclic Graphs (DAGs)** as a language for expressing causal assumptions and reasoning about causal effects.

## Learning Objectives

1. Understand DAGs as representations of causal structure
2. Learn the three fundamental connection types (chains, forks, colliders)
3. Master d-separation for identifying conditional independencies
4. Apply these concepts to identify confounders and valid adjustment sets
5. Use computational tools for causal graph analysis

## Prerequisites

- Basic probability (conditional independence)
- Familiarity with confounding (see `01_treatment_effects.ipynb`)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
from itertools import combinations

# Set random seed for reproducibility
np.random.seed(42)

# Plotting style
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

---

## Part 1: Introduction to DAGs

A **Directed Acyclic Graph (DAG)** is a graph where:
- **Directed**: Edges have arrows indicating direction
- **Acyclic**: No cycles (you can't follow arrows and return to where you started)

In causal DAGs:
- **Nodes** represent variables
- **Edges** represent direct causal effects (arrow points from cause to effect)
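Acyclicity can be checked mechanically. A directed graph is acyclic exactly when its nodes admit a topological order, meaning every arrow points from an earlier node to a later one. Below is a minimal sketch of Kahn's algorithm in plain Python. `topological_order` is a hypothetical helper, not part of this notebook; networkx's `nx.is_directed_acyclic_graph` performs the same check.

```python
from collections import defaultdict, deque


def topological_order(edges):
    """Return a topological order of the nodes, or None if the graph has a cycle.

    Kahn's algorithm: repeatedly remove a node with no remaining incoming edges.
    `edges` is a list of (cause, effect) pairs.
    """
    children = defaultdict(list)
    indegree = defaultdict(int)
    nodes = set()
    for u, v in edges:
        children[u].append(v)
        indegree[v] += 1
        nodes.update((u, v))

    # Start from root nodes (no parents); sorted for a reproducible order
    queue = deque(sorted(n for n in nodes if indegree[n] == 0))
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in children[node]:
            indegree[child] -= 1
            if indegree[child] == 0:
                queue.append(child)

    # Nodes left over are stuck on a cycle
    return order if len(order) == len(nodes) else None


print(topological_order([('Z', 'X'), ('Z', 'Y'), ('X', 'Y')]))  # ['Z', 'X', 'Y']
print(topological_order([('A', 'B'), ('B', 'A')]))              # None (cycle)
```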

### Why DAGs?

DAGs provide a visual and mathematical language for:
1. Encoding causal assumptions
2. Identifying what to control for (and what NOT to control for)
3. Determining if causal effects are identifiable from observational data

In [None]:
# Helper function to draw DAGs
def draw_dag(edges, pos=None, title="", node_colors=None, figsize=(8, 6)):
    """
    Draw a DAG with labeled nodes and directed edges.
    
    Parameters
    ----------
    edges : list of tuples
        List of (source, target) edges
    pos : dict, optional
        Node positions {node: (x, y)}
    title : str
        Plot title
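    node_colors : dict, optional
        Colors for specific nodes (all others default to light blue)
    figsize : tuple, optional
        Figure size in inches

    Returns
    -------
    G : networkx.DiGraph
        The graph that was drawn
    """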
    G = nx.DiGraph()
    G.add_edges_from(edges)
    
    if pos is None:
        pos = nx.spring_layout(G, seed=42)
    
    fig, ax = plt.subplots(figsize=figsize)
    
    # Default colors
    colors = ['lightblue'] * len(G.nodes())
    if node_colors:
        for i, node in enumerate(G.nodes()):
            if node in node_colors:
                colors[i] = node_colors[node]
    
    nx.draw(G, pos, ax=ax, with_labels=True, node_color=colors,
            node_size=2000, font_size=14, font_weight='bold',
            arrows=True, arrowsize=20, edge_color='gray',
            connectionstyle="arc3,rad=0.1")
    
    ax.set_title(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    return G

In [None]:
# Example: Simple confounding DAG
# Z confounds the relationship between X and Y
edges_confounding = [('Z', 'X'), ('Z', 'Y'), ('X', 'Y')]
pos_confounding = {'Z': (0.5, 1), 'X': (0, 0), 'Y': (1, 0)}

draw_dag(edges_confounding, pos_confounding, 
         title="Confounding: Z affects both X and Y",
         node_colors={'Z': 'salmon', 'X': 'lightgreen', 'Y': 'lightyellow'})
plt.show()

---

## Part 2: The Three Fundamental Structures

Any path through a DAG can be broken into consecutive three-node segments, and each segment is one of three basic structures:

### 1. Chain (Mediator): A → B → C
- B **mediates** the effect of A on C
- A and C are dependent (marginally)
- A and C become **independent** when conditioning on B

### 2. Fork (Confounder): A ← B → C
- B is a **common cause** of A and C
- A and C are dependent (marginally)
- A and C become **independent** when conditioning on B

### 3. Collider: A → B ← C
- B is a **common effect** of A and C
- A and C are **independent** (marginally)
- A and C become **dependent** when conditioning on B ("collider bias")
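The three cases differ only in which way the two arrows at the middle node point, so classification can be reduced to a small lookup. The sketch below uses a hypothetical helper, `classify_triple`, which is not part of the notebook. It takes the same `(parent, child)` edge lists used throughout and returns the structure type together with the blocking rule from the lists above.

```python
def classify_triple(edges, a, b, c):
    """Classify the middle node b on the path a - b - c.

    Returns "chain", "fork", or "collider". `edges` holds (parent, child) pairs,
    and a-b and b-c must each be joined by an edge (in either direction).
    """
    e = set(edges)
    for u in (a, c):
        if (u, b) not in e and (b, u) not in e:
            raise ValueError(f"{u} and {b} are not adjacent")

    a_into_b = (a, b) in e
    c_into_b = (c, b) in e
    if a_into_b and c_into_b:
        return "collider"   # a → b ← c
    if not a_into_b and not c_into_b:
        return "fork"       # a ← b → c
    return "chain"          # a → b → c or a ← b ← c


# When is the segment blocked?
BLOCKED_WHEN = {
    "chain": "b IS in the conditioning set",
    "fork": "b IS in the conditioning set",
    "collider": "neither b nor any descendant of b is in the conditioning set",
}

for edges in ([('A', 'B'), ('B', 'C')], [('B', 'A'), ('B', 'C')], [('A', 'B'), ('C', 'B')]):
    kind = classify_triple(edges, 'A', 'B', 'C')
    print(f"{edges}: {kind:<8} blocked when {BLOCKED_WHEN[kind]}")
```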

In [None]:
# Visualize the three fundamental structures
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

structures = [
    ([('A', 'B'), ('B', 'C')], "Chain (Mediator)\nA → B → C"),
    ([('B', 'A'), ('B', 'C')], "Fork (Confounder)\nA ← B → C"),
    ([('A', 'B'), ('C', 'B')], "Collider\nA → B ← C")
]

pos = {'A': (0, 0), 'B': (1, 0), 'C': (2, 0)}

for ax, (edges, title) in zip(axes, structures):
    G = nx.DiGraph()
    G.add_edges_from(edges)
    
    nx.draw(G, pos, ax=ax, with_labels=True, node_color='lightblue',
            node_size=1500, font_size=14, font_weight='bold',
            arrows=True, arrowsize=20, edge_color='gray')
    ax.set_title(title, fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Simulate data to demonstrate the three structures
n = 5000

def simulate_chain(n):
    """A → B → C: Conditioning on B blocks the path"""
    A = np.random.normal(0, 1, n)
    B = 0.7 * A + np.random.normal(0, 0.5, n)
    C = 0.7 * B + np.random.normal(0, 0.5, n)
    return pd.DataFrame({'A': A, 'B': B, 'C': C})

def simulate_fork(n):
    """A ← B → C: Conditioning on B blocks the path"""
    B = np.random.normal(0, 1, n)
    A = 0.7 * B + np.random.normal(0, 0.5, n)
    C = 0.7 * B + np.random.normal(0, 0.5, n)
    return pd.DataFrame({'A': A, 'B': B, 'C': C})

def simulate_collider(n):
    """A → B ← C: Conditioning on B OPENS the path"""
    A = np.random.normal(0, 1, n)
    C = np.random.normal(0, 1, n)
    B = 0.5 * A + 0.5 * C + np.random.normal(0, 0.5, n)
    return pd.DataFrame({'A': A, 'B': B, 'C': C})

# Generate data
df_chain = simulate_chain(n)
df_fork = simulate_fork(n)
df_collider = simulate_collider(n)

In [None]:
# Compute correlations to demonstrate independence patterns
from sklearn.linear_model import LinearRegression

def partial_correlation(df, x, y, z):
    """Compute the partial correlation of x and y given z.

    Regress x and y on z separately, then correlate the residuals.
    """
    x_resid = df[x] - LinearRegression().fit(df[[z]], df[x]).predict(df[[z]])
    y_resid = df[y] - LinearRegression().fit(df[[z]], df[y]).predict(df[[z]])

    return np.corrcoef(x_resid, y_resid)[0, 1]

print("Correlation between A and C:")
print("="*50)
print(f"{'Structure':<15} {'Marginal':<15} {'Given B':<15}")
print("-"*50)

for name, df in [('Chain', df_chain), ('Fork', df_fork), ('Collider', df_collider)]:
    marginal = np.corrcoef(df['A'], df['C'])[0, 1]
    conditional = partial_correlation(df, 'A', 'C', 'B')
    print(f"{name:<15} {marginal:>10.3f}     {conditional:>10.3f}")

print("="*50)
print("\nKey insight:")
print("- Chain/Fork: Conditioning on B BLOCKS the path (correlation → 0)")
print("- Collider: Conditioning on B OPENS the path (0 → correlation)")

---

## Part 3: d-Separation

**d-separation** is a graphical criterion for reading conditional independencies directly off a DAG.

### Definition

Two nodes X and Y are **d-separated** given a set Z if every path between X and Y is "blocked" by Z.

A path is **blocked** by Z if it contains at least one of:
1. A chain A → B → C (or A ← B ← C) where B ∈ Z, OR
2. A fork A ← B → C where B ∈ Z, OR
3. A collider A → B ← C where B ∉ Z and no descendant of B is in Z

### Implication

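If X and Y are d-separated given Z, then X ⊥ Y | Z (conditional independence) in every distribution that factorizes according to the DAG (the causal Markov condition). The converse, that d-connected variables are actually dependent, additionally requires the faithfulness assumption.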
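An independent way to test d-separation, useful for cross-checking the path-based class below, is the ancestral moral graph criterion. X and Y are d-separated by Z exactly when they are disconnected in the following graph: take X, Y, Z, and all their ancestors; drop edge directions; connect ("marry") any two parents that share a child; then delete Z. The sketch below is plain Python. `d_separated_moral` is a hypothetical helper, not part of the notebook, and it assumes X and Y are not themselves in the conditioning set.

```python
from collections import deque
from itertools import combinations


def d_separated_moral(edges, x, y, given=()):
    """Return True if x and y are d-separated given `given`.

    Ancestral moral graph criterion. `edges` holds (parent, child) pairs;
    assumes x and y are not in `given`.
    """
    given = set(given)
    parents = {}
    for u, v in edges:
        parents.setdefault(u, set())
        parents.setdefault(v, set()).add(u)

    # 1. Ancestral set of {x, y} ∪ given (the nodes plus all their ancestors)
    ancestral, stack = set(), [x, y, *given]
    while stack:
        node = stack.pop()
        if node not in ancestral:
            ancestral.add(node)
            stack.extend(parents.get(node, ()))

    # 2. Moralize: drop directions and marry parents that share a child
    adj = {n: set() for n in ancestral}
    for child in ancestral:
        pars = parents.get(child, set())  # parents of ancestral nodes are ancestral
        for p in pars:
            adj[p].add(child)
            adj[child].add(p)
        for p1, p2 in combinations(pars, 2):
            adj[p1].add(p2)
            adj[p2].add(p1)

    # 3. Delete the conditioning set and search for a route from x to y
    seen, queue = {x}, deque([x])
    while queue:
        node = queue.popleft()
        if node == y:
            return False  # still connected, so x and y are d-connected
        for nb in adj[node] - given:
            if nb not in seen:
                seen.add(nb)
                queue.append(nb)
    return True


collider = [('A', 'B'), ('C', 'B'), ('B', 'D')]
print(d_separated_moral(collider, 'A', 'C'))          # True: collider blocks
print(d_separated_moral(collider, 'A', 'C', {'D'}))   # False: descendant of collider opens it
```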

In [None]:
class CausalDAG:
    """
    A simple class for causal DAG analysis.
    
    Provides methods for:
    - Finding paths between nodes
    - Checking d-separation
    - Finding valid adjustment sets
    """
    
    def __init__(self, edges):
        """
        Initialize DAG from edge list.
        
        Parameters
        ----------
        edges : list of tuples
            List of (parent, child) edges
        """
        self.graph = nx.DiGraph()
        self.graph.add_edges_from(edges)
        self.nodes = set(self.graph.nodes())
        
    def parents(self, node):
        """Get parents of a node."""
        return set(self.graph.predecessors(node))
    
    def children(self, node):
        """Get children of a node."""
        return set(self.graph.successors(node))
    
    def ancestors(self, node):
        """Get all ancestors of a node."""
        return nx.ancestors(self.graph, node)
    
    def descendants(self, node):
        """Get all descendants of a node."""
        return nx.descendants(self.graph, node)
    
    def is_collider(self, path, node):
        """
        Check if node is a collider on the given path.
        
        A node is a collider if both adjacent nodes on the path
        point INTO it (→ node ←).
        """
        idx = path.index(node)
        if idx == 0 or idx == len(path) - 1:
            return False
        
        prev_node = path[idx - 1]
        next_node = path[idx + 1]
        
        # Check if both edges point INTO node
        return (self.graph.has_edge(prev_node, node) and 
                self.graph.has_edge(next_node, node))
    
    def is_path_blocked(self, path, conditioning_set):
        """
        Check if a path is blocked by the conditioning set.
        
        A path is blocked if ANY node on the path (except endpoints) is:
        1. A non-collider that IS in the conditioning set, OR
        2. A collider where neither it nor its descendants are in the conditioning set
        """
        conditioning_set = set(conditioning_set)
        
        for i, node in enumerate(path[1:-1], 1):  # Skip endpoints
            is_collider = self.is_collider(path, node)
            
            if is_collider:
                # Collider: blocked if neither it nor descendants are conditioned on
                descendants_and_self = self.descendants(node) | {node}
                if not (descendants_and_self & conditioning_set):
                    return True
            else:
                # Non-collider: blocked if conditioned on
                if node in conditioning_set:
                    return True
        
        return False
    
    def find_all_paths(self, source, target, max_length=10):
        """
        Find all undirected paths between source and target.
        
        Uses the underlying undirected graph to find paths; paths with more
        than `max_length` edges are not returned.
        """
        undirected = self.graph.to_undirected()
        try:
            return list(nx.all_simple_paths(undirected, source, target, cutoff=max_length))
        except (nx.NetworkXError, nx.NodeNotFound):
            # e.g. source or target is not a node of the graph
            return []
    
    def d_separated(self, x, y, conditioning_set=None):
        """
        Check if x and y are d-separated given the conditioning set.
        
        Returns True if ALL paths between x and y are blocked.
        """
        if conditioning_set is None:
            conditioning_set = set()
        else:
            conditioning_set = set(conditioning_set)
        
        paths = self.find_all_paths(x, y)
        
        if not paths:
            return True  # No paths means d-separated
        
        # Check if ALL paths are blocked
        for path in paths:
            if not self.is_path_blocked(path, conditioning_set):
                return False  # Found an open path
        
        return True  # All paths blocked
    
    def backdoor_paths(self, treatment, outcome):
        """
        Find all backdoor paths from treatment to outcome.
        
        Backdoor paths are paths that start with an arrow INTO treatment.
        """
        all_paths = self.find_all_paths(treatment, outcome)
        backdoor = []
        
        for path in all_paths:
            if len(path) < 2:
                continue
            # Check if first edge points INTO treatment
            second_node = path[1]
            if self.graph.has_edge(second_node, treatment):
                backdoor.append(path)
        
        return backdoor
    
    def valid_adjustment_sets(self, treatment, outcome, max_size=3):
        """
        Find valid adjustment sets for estimating causal effect of treatment on outcome.
        
        A set is valid under the backdoor criterion if it:
        1. Contains no descendants of treatment
        2. Blocks all backdoor paths
        
        Conditioning on a collider can re-open a backdoor path, but that case is
        already handled: `is_path_blocked` re-checks every backdoor path
        (up to the length cutoff of `find_all_paths`) against the candidate set.
        """
        # Candidates: all nodes except treatment, outcome, and descendants of treatment
        # (sorted so the output order is reproducible)
        descendants_of_treatment = self.descendants(treatment)
        candidates = sorted(self.nodes - {treatment, outcome} - descendants_of_treatment)
        
        backdoor = self.backdoor_paths(treatment, outcome)
        valid_sets = []
        
        # Check every subset, from the empty set up to max_size
        for size in range(min(max_size, len(candidates)) + 1):
            for subset in combinations(candidates, size):
                subset = set(subset)
                if all(self.is_path_blocked(p, subset) for p in backdoor):
                    valid_sets.append(subset)
        
        return valid_sets

In [None]:
# Example: Classic confounding DAG
# Z is a confounder of X → Y
edges = [('Z', 'X'), ('Z', 'Y'), ('X', 'Y')]
dag = CausalDAG(edges)

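print("DAG: Z → X, Z → Y, X → Y")
print("="*50)
# Both d-separation checks print False: the direct edge X → Y is an open path
# that no conditioning set can block (d-separation considers ALL paths).
print(f"X and Y d-separated (unconditional): {dag.d_separated('X', 'Y')}")
print(f"X and Y d-separated given Z: {dag.d_separated('X', 'Y', {'Z'})}")
print()
print("Backdoor paths from X to Y (nodes in order; arrow directions not shown):")
for path in dag.backdoor_paths('X', 'Y'):
    print(f"  {' - '.join(path)}")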
print()
print("Valid adjustment sets for X → Y:")
for adj_set in dag.valid_adjustment_sets('X', 'Y'):
    print(f"  {adj_set if adj_set else '(empty set)'}")
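
Backdoor paths are easier to read when each edge shows its real direction, for example X ← Z → Y rather than a bare list of nodes. The helper below is a hypothetical addition, not part of the `CausalDAG` class. It looks up each consecutive pair in the `(parent, child)` edge list and draws the matching arrow.

```python
def format_path(edges, path):
    """Render a node sequence with the true direction of each edge."""
    e = set(edges)
    parts = [path[0]]
    for u, v in zip(path, path[1:]):
        if (u, v) in e:
            parts.append("→")
        elif (v, u) in e:
            parts.append("←")
        else:
            raise ValueError(f"{u} and {v} are not adjacent")
        parts.append(v)
    return " ".join(parts)


edges = [('Z', 'X'), ('Z', 'Y'), ('X', 'Y')]
print(format_path(edges, ['X', 'Z', 'Y']))  # X ← Z → Y
print(format_path(edges, ['X', 'Y']))       # X → Y
```

With the notebook's objects, the same call would be `format_path(edges, path)` inside the loops that print `dag.backdoor_paths(...)`.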

---

## Part 4: Collider Bias (Selection Bias)

One of the most important insights from causal graphs is **collider bias**: conditioning on a common effect can CREATE spurious associations.

This is counterintuitive because we often think "controlling for more variables is better." But controlling for a collider is harmful!

In [None]:
# Simulate collider bias: "Berkson's Paradox"
# Example: Talent and Attractiveness in Hollywood
#
# Talent → Success ← Attractiveness
#
# Among successful actors, talent and attractiveness appear negatively correlated
# (a successful actor who scores low on one trait tends to score high on the other)

n = 10000

# Independent causes
talent = np.random.normal(0, 1, n)
attractiveness = np.random.normal(0, 1, n)

# Success depends on both (collider)
success_score = talent + attractiveness + np.random.normal(0, 0.5, n)
is_successful = success_score > 1.5  # Top ~16% succeed (success_score has SD 1.5, so this is 1 SD above the mean)

print(f"Overall correlation (talent, attractiveness): {np.corrcoef(talent, attractiveness)[0,1]:.3f}")
print(f"Among successful (conditioning on collider): {np.corrcoef(talent[is_successful], attractiveness[is_successful])[0,1]:.3f}")
print(f"\nNumber successful: {is_successful.sum()} / {n}")
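The fraction of actors who succeed can be computed exactly instead of read off the simulation. `success_score` is the sum of three independent normals with SDs 1, 1, and 0.5, so its SD is √2.25 = 1.5 and the cutoff of 1.5 sits exactly one SD above the mean. The sketch below uses only the standard library's `statistics.NormalDist`.

```python
from math import sqrt
from statistics import NormalDist

# talent ~ N(0, 1), attractiveness ~ N(0, 1), noise ~ N(0, 0.5): independent, so variances add
sd = sqrt(1**2 + 1**2 + 0.5**2)
p_success = 1 - NormalDist(mu=0, sigma=sd).cdf(1.5)

print(f"SD of success_score: {sd:.2f}")   # 1.50
print(f"P(success) = {p_success:.3f}")    # 0.159
```

With n = 10,000, the simulated count of successful actors should therefore land near 1,600.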

In [None]:
# Visualize collider bias
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Full population
axes[0].scatter(talent, attractiveness, alpha=0.3, s=10)
axes[0].set_xlabel('Talent')
axes[0].set_ylabel('Attractiveness')
axes[0].set_title(f'Full Population\nCorrelation: {np.corrcoef(talent, attractiveness)[0,1]:.3f}')

# Conditioned on success
axes[1].scatter(talent[is_successful], attractiveness[is_successful], 
                alpha=0.5, s=20, color='orange')
axes[1].set_xlabel('Talent')
axes[1].set_ylabel('Attractiveness')
corr_cond = np.corrcoef(talent[is_successful], attractiveness[is_successful])[0,1]
axes[1].set_title(f'Among Successful (Collider Bias)\nCorrelation: {corr_cond:.3f}')

# Add regression lines
for ax, x, y in [(axes[0], talent, attractiveness), 
                  (axes[1], talent[is_successful], attractiveness[is_successful])]:
    z = np.polyfit(x, y, 1)
    p = np.poly1d(z)
    x_line = np.linspace(x.min(), x.max(), 100)
    ax.plot(x_line, p(x_line), 'r--', linewidth=2, label='Trend')
    ax.legend()

plt.tight_layout()
plt.show()

print("\n⚠️  Collider bias creates a NEGATIVE correlation where none exists!")
print("   This is why 'controlling for everything' can be harmful.")

---

## Part 5: Biological Example - Gene Regulatory Network

Let's apply these concepts to a simplified gene regulatory network.

In [None]:
# Biological DAG: Simplified signaling pathway
#
# Stimulus → Receptor → Kinase → TF → Target Gene
#                  ↓              ↑
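#              Feedback ──────────┘
#
# Note: despite its name, 'Feedback' is a second feed-forward branch
# (Receptor → Feedback → TF). A true feedback loop would be a cycle,
# which a DAG cannot represent.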

bio_edges = [
    ('Stimulus', 'Receptor'),
    ('Receptor', 'Kinase'),
    ('Kinase', 'TF'),
    ('TF', 'Target'),
    ('Receptor', 'Feedback'),
    ('Feedback', 'TF')
]

pos_bio = {
    'Stimulus': (0, 1),
    'Receptor': (1, 1),
    'Kinase': (2, 1),
    'TF': (3, 1),
    'Target': (4, 1),
    'Feedback': (2, 0)
}

draw_dag(bio_edges, pos_bio, 
         title="Gene Regulatory Network",
         node_colors={'Stimulus': 'lightgreen', 'Target': 'lightyellow'})
plt.show()

In [None]:
# Analyze the biological DAG
bio_dag = CausalDAG(bio_edges)

print("Causal Analysis: Effect of Kinase on Target")
print("="*60)

# What are the paths from Kinase to Target?
print("\nAll paths from Kinase to Target (nodes in order; arrow directions not shown):")
for path in bio_dag.find_all_paths('Kinase', 'Target'):
    print(f"  {' - '.join(path)}")

# Backdoor paths?
print("\nBackdoor paths from Kinase to Target:")
backdoor = bio_dag.backdoor_paths('Kinase', 'Target')
if backdoor:
    for path in backdoor:
        print(f"  {' - '.join(path)}")
else:
    print("  (none)")

# Valid adjustment sets
print("\nValid adjustment sets:")
for adj_set in bio_dag.valid_adjustment_sets('Kinase', 'Target'):
    print(f"  {adj_set if adj_set else '(empty set)'}")

In [None]:
# Simulate data from the biological DAG
def simulate_bio_network(n=2000, seed=42):
    np.random.seed(seed)
    
    # Exogenous noise
    noise = lambda: np.random.normal(0, 0.3, n)
    
    # Generate according to DAG
    stimulus = np.random.normal(0, 1, n)
    receptor = 0.8 * stimulus + noise()
    kinase = 0.7 * receptor + noise()
    feedback = 0.6 * receptor + noise()
    tf = 0.5 * kinase + 0.4 * feedback + noise()
    target = 0.8 * tf + noise()
    
    return pd.DataFrame({
        'Stimulus': stimulus,
        'Receptor': receptor,
        'Kinase': kinase,
        'Feedback': feedback,
        'TF': tf,
        'Target': target
    })

df_bio = simulate_bio_network()

# Show correlation matrix
plt.figure(figsize=(8, 6))
corr = df_bio.corr()
plt.imshow(corr, cmap='RdBu_r', vmin=-1, vmax=1)
plt.colorbar(label='Correlation')
plt.xticks(range(len(corr.columns)), corr.columns, rotation=45, ha='right')
plt.yticks(range(len(corr.columns)), corr.columns)
plt.title('Correlation Matrix of Gene Network')

# Add correlation values
for i in range(len(corr)):
    for j in range(len(corr)):
        plt.text(j, i, f'{corr.iloc[i, j]:.2f}', ha='center', va='center', fontsize=9)

plt.tight_layout()
plt.show()
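The adjustment sets found above can be checked numerically. Only one causal path runs from Kinase to Target (Kinase → TF → Target), so the true effect in this simulation is 0.5 × 0.8 = 0.4. The sketch below re-simulates the same structural equations. The larger n = 20,000 and NumPy's `default_rng` are our choices, not the notebook's. It then compares a naive regression of Target on Kinase with regressions that also include Receptor or Feedback. Ordinary least squares is a reasonable estimator here only because the simulated mechanisms are linear with Gaussian noise.

```python
import numpy as np


def simulate_bio(n, rng):
    """Same structural equations and coefficients as simulate_bio_network."""
    noise = lambda: rng.normal(0, 0.3, n)
    stimulus = rng.normal(0, 1, n)
    receptor = 0.8 * stimulus + noise()
    kinase = 0.7 * receptor + noise()
    feedback = 0.6 * receptor + noise()
    tf = 0.5 * kinase + 0.4 * feedback + noise()
    target = 0.8 * tf + noise()
    return receptor, kinase, feedback, target


def ols_coef(y, *covariates):
    """OLS slopes of y on the covariates (an intercept is fitted and dropped)."""
    design = np.column_stack([np.ones_like(y), *covariates])
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)
    return beta[1:]


rng = np.random.default_rng(0)
receptor, kinase, feedback, target = simulate_bio(20_000, rng)

true_effect = 0.5 * 0.8                               # Kinase → TF → Target
naive = ols_coef(target, kinase)[0]                   # backdoor path left open
adj_receptor = ols_coef(target, kinase, receptor)[0]  # {Receptor} blocks it
adj_feedback = ols_coef(target, kinase, feedback)[0]  # {Feedback} blocks it too

print(f"True effect:           {true_effect:.3f}")
print(f"Unadjusted:            {naive:.3f}")
print(f"Adjusted for Receptor: {adj_receptor:.3f}")
print(f"Adjusted for Feedback: {adj_feedback:.3f}")
```

The unadjusted slope should come out near 0.62, because the open backdoor path Kinase ← Receptor → Feedback → TF → Target adds spurious association. Both adjusted slopes should sit close to 0.4.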

---

## Part 6: The Backdoor Criterion

The **backdoor criterion** tells us when we can identify a causal effect by adjusting for a set of variables.

### Definition

A set Z satisfies the backdoor criterion relative to (X, Y) if:
1. No node in Z is a descendant of X
2. Z blocks every path between X and Y that contains an arrow INTO X

### Adjustment Formula

If Z satisfies the backdoor criterion:

$$P(Y \mid do(X=x)) = \sum_z P(Y \mid X=x, Z=z)\, P(Z=z)$$

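This is the foundation of methods like IPW and outcome regression. Outcome regression models $P(Y \mid X, Z)$ and averages it over $P(Z)$, while IPW reweights each unit by $1/P(X \mid Z)$. Both target the same quantity.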
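To make the formula concrete, here is a worked example with one binary confounder Z. All probabilities are made-up illustrative assumptions, not data: P(Z=1) = 1/2, P(X=1 | Z) = 4/5 if Z=1 and 1/5 otherwise, and P(Y=1 | X, Z) = 1/10 + X/5 + Z/2. Exact `Fraction` arithmetic shows that the adjustment formula recovers the true causal contrast of 1/5. The naive conditional contrast instead weights each stratum by P(Z | X) rather than P(Z).

```python
from fractions import Fraction as F

p_z = {0: F(1, 2), 1: F(1, 2)}           # P(Z = z)
p_x1_given_z = {0: F(1, 5), 1: F(4, 5)}  # P(X = 1 | Z = z)


def p_y1(x, z):
    """P(Y = 1 | X = x, Z = z); the causal effect of X is 1/5 by construction."""
    return F(1, 10) + F(1, 5) * x + F(1, 2) * z


def p_x_given_z(x, z):
    return p_x1_given_z[z] if x == 1 else 1 - p_x1_given_z[z]


def adjusted(x):
    """Backdoor adjustment: sum_z P(Y=1 | X=x, Z=z) P(Z=z)."""
    return sum(p_y1(x, z) * p_z[z] for z in (0, 1))


def naive(x):
    """Plain conditioning: P(Y=1 | X=x) = sum_z P(Y=1 | x, z) P(z | x)."""
    p_x = sum(p_x_given_z(x, z) * p_z[z] for z in (0, 1))
    return sum(p_y1(x, z) * p_x_given_z(x, z) * p_z[z] for z in (0, 1)) / p_x


print(adjusted(1) - adjusted(0))  # 1/5
print(naive(1) - naive(0))        # 1/2
```

Z raises both the chance of treatment and the chance of Y=1, so the naive contrast overstates the effect.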

In [None]:
# More complex DAG with multiple confounders
complex_edges = [
    ('U', 'Z1'),
    ('U', 'Z2'),
    ('Z1', 'X'),
    ('Z2', 'X'),
    ('Z2', 'Y'),
    ('X', 'M'),
    ('M', 'Y'),
    ('X', 'Y')  # Direct effect
]

pos_complex = {
    'U': (1, 2),
    'Z1': (0, 1),
    'Z2': (2, 1),
    'X': (0, 0),
    'M': (1, 0),
    'Y': (2, 0)
}

draw_dag(complex_edges, pos_complex,
         title="Complex DAG with Mediator and Confounders",
         node_colors={'X': 'lightgreen', 'Y': 'lightyellow', 'U': 'salmon', 'Z2': 'salmon'})
plt.show()

In [None]:
# Analyze the complex DAG
complex_dag = CausalDAG(complex_edges)

print("Analysis: Effect of X on Y")
print("="*60)

print("\nAll paths from X to Y (nodes in order; arrow directions not shown):")
for path in complex_dag.find_all_paths('X', 'Y'):
    print(f"  {' - '.join(path)}")

print("\nBackdoor paths:")
for path in complex_dag.backdoor_paths('X', 'Y'):
    print(f"  {' - '.join(path)}")

print("\nValid adjustment sets (up to size 3):")
valid = complex_dag.valid_adjustment_sets('X', 'Y', max_size=3)
for adj_set in valid:
    print(f"  {adj_set if adj_set else '(empty set)'}")

print("\n⚠️  Note: M is a mediator (and a descendant of X), so it is never a candidate;")
print("   adjusting for it would block the indirect path X → M → Y.")
print("   If you want the TOTAL effect, don't adjust for M.")
print("   The DIRECT effect (X → Y, not through M) is a different estimand that needs mediation-analysis methods.")

---

## Part 7: Common Pitfalls

### 1. Adjusting for a Mediator
If you want the total effect, don't adjust for variables on the causal path.

### 2. Adjusting for a Collider
This opens a path and creates bias.

### 3. Adjusting for a Descendant of Treatment
This can block part of the effect you're trying to measure.

### 4. Unmeasured Confounding
If there's a confounder you can't measure, adjustment won't work.
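Pitfall 4 can be seen directly in a linear simulation. In this hypothetical setup, U → X, U → Y, and X → Y with a true effect of 0.3. Regressing Y on X alone then estimates cov(Y, X)/var(X) = (0.3·2 + 1)/2 = 0.8 in the population. The sketch below uses NumPy least squares. The "oracle" regression that includes U is only possible because U exists in the simulation.

```python
import numpy as np


def ols_coef(y, *covariates):
    """OLS slopes of y on the covariates (an intercept is fitted and dropped)."""
    design = np.column_stack([np.ones_like(y), *covariates])
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)
    return beta[1:]


rng = np.random.default_rng(1)
n = 20_000
U = rng.normal(0, 1, n)                 # confounder we pretend we could not measure
X = U + rng.normal(0, 1, n)
Y = 0.3 * X + U + rng.normal(0, 1, n)   # true causal effect of X on Y: 0.3

naive = ols_coef(Y, X)[0]               # backdoor path X ← U → Y left open
oracle = ols_coef(Y, X, U)[0]           # requires measuring U

print(f"Naive (U unmeasured): {naive:.3f}")   # near 0.8
print(f"Oracle (adjust for U): {oracle:.3f}")  # near 0.3
```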

In [None]:
# Demonstrate the mediator problem
n = 5000
np.random.seed(42)

# X → M → Y (M mediates the effect)
X = np.random.normal(0, 1, n)
M = 0.6 * X + np.random.normal(0, 0.5, n)
Y = 0.7 * M + np.random.normal(0, 0.5, n)  # No direct effect of X

# True total effect: 0.6 * 0.7 = 0.42
from sklearn.linear_model import LinearRegression

# Unadjusted (correct for total effect)
lr_unadj = LinearRegression().fit(X.reshape(-1, 1), Y)
effect_unadj = lr_unadj.coef_[0]

# Adjusted for M (WRONG - blocks the effect)
lr_adj = LinearRegression().fit(np.column_stack([X, M]), Y)
effect_adj = lr_adj.coef_[0]

print("Mediator Problem: X → M → Y")
print("="*50)
print("True total effect: 0.42")
print(f"Unadjusted estimate: {effect_unadj:.3f} ✓")
print(f"Adjusted for M: {effect_adj:.3f} ✗ (≈ 0: only the direct effect, which is 0 here, remains)")
print("\n⚠️  Adjusting for the mediator blocks the causal path!")
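The same regression check illustrates pitfall 2. In this hypothetical linear setup, X → Y with a true effect of 0.5, and C is a common effect of both (X → C ← Y). Adding C to the regression conditions on a collider. A short Gaussian calculation gives E[Y | X, C] = -0.7·X + 0.8·C, so the adjusted coefficient on X even flips sign. The sketch below uses NumPy least squares.

```python
import numpy as np


def ols_coef(y, *covariates):
    """OLS slopes of y on the covariates (an intercept is fitted and dropped)."""
    design = np.column_stack([np.ones_like(y), *covariates])
    beta, *_ = np.linalg.lstsq(design, y, rcond=None)
    return beta[1:]


rng = np.random.default_rng(2)
n = 20_000
X = rng.normal(0, 1, n)
Y = 0.5 * X + rng.normal(0, 1, n)     # true effect of X on Y: 0.5
C = X + Y + rng.normal(0, 0.5, n)     # collider: X → C ← Y

unadj = ols_coef(Y, X)[0]             # correct: near 0.5
adj_c = ols_coef(Y, X, C)[0]          # collider bias: near -0.7

print("Collider Problem: X → C ← Y (plus X → Y)")
print(f"Unadjusted estimate: {unadj:.3f}")
print(f"Adjusted for C:      {adj_c:.3f}")
```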

---

## Summary

### Key Concepts

1. **DAGs** encode causal assumptions visually
2. **Three structures**: chains, forks, and colliders have different blocking rules
3. **d-separation** determines conditional independence from graph structure
4. **Backdoor criterion** identifies valid adjustment sets
5. **Collider bias** is a common pitfall - don't adjust for common effects

### Practical Guidelines

- Draw your DAG before analyzing data
- Identify confounders (common causes) and adjust for them
- Don't adjust for mediators (if you want total effects)
- Don't adjust for colliders unless another variable in your adjustment set blocks the path they open
- Be explicit about unmeasured confounders

### Next Steps

- `03_sensitivity_analysis.ipynb`: What if your DAG is wrong?

---

## Exercises

1. **Draw a DAG** for a biological system you work with. Identify potential confounders.

2. **Collider identification**: In the DAG below, which variables are colliders?
   - A → C ← B
   - C → D
   - A → E ← D

3. **Adjustment sets**: For the DAG A → B → C ← D → E, find all valid adjustment sets for estimating the effect of B on E.

4. **Selection bias**: Design a simulation where conditioning on a collider creates a spurious correlation between two independent variables.