# Chapter 6: Post-Treatment Bias

**When Controlling for Consequences Breaks Causal Inference**

## Goal

Understand **post-treatment bias** - what happens when you include a variable that is a **consequence** of the treatment.

**Key insight:**
- Treatment causes intermediate variable
- Intermediate variable causes outcome
- Controlling for intermediate variable **blocks the causal path**
- Treatment effect disappears (even though it's real!)

**Example**: Plant growth experiment with anti-fungal treatment

---

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats

plt.style.use('default')
%matplotlib inline

np.random.seed(42)

print('✓ Imports loaded')

In [None]:
# Import quap
import sys
from pathlib import Path

sys.path.append(str(Path.cwd().parent.parent))
from src.quap import quap, QuapResult

print('✓ Loaded quap')

---

## Step 1: The Causal Structure

**Scenario**: Testing an anti-fungal treatment on plant growth

**The DAG:**
```
H₀ (initial height) ──→ H₁ (final height) ←── Fungus ←── Treatment
```

**Causal paths:**
1. H₀ → H₁ (initial height affects final height)
2. Treatment → Fungus → H₁ (treatment reduces fungus, less fungus increases growth)

**The trap:**
- Fungus is **post-treatment** (caused by treatment)
- If you control for fungus, you block path #2
- Treatment effect vanishes!

---

## Step 2: Simulate the Data

Let's create data following the causal structure.

In [None]:
# Simulate plant growth experiment
n = 100
np.random.seed(71)

# Initial height (standardized)
h0 = np.random.normal(0, 1, n)

# Treatment: 0=control, 1=treated
treatment = np.repeat([0, 1], n//2)

# Fungus: treatment reduces fungus
# Control group has more fungus
fungus = np.random.binomial(1, 0.5 - treatment * 0.4, n)

# Final height: grows from initial height, reduced by fungus
h1 = h0 + np.random.normal(5 - 3 * fungus, 1, n)

print(f"Simulated {n} plants")
print(f"\nTreatment groups:")
print(f"  Control (0): {np.sum(treatment==0)}")
print(f"  Treated (1): {np.sum(treatment==1)}")
print(f"\nFungus presence:")
print(f"  Control group: {np.mean(fungus[treatment==0]):.1%} have fungus")
print(f"  Treated group: {np.mean(fungus[treatment==1]):.1%} have fungus")
print(f"\nFinal height:")
print(f"  Control group: {h1[treatment==0].mean():.2f} ± {h1[treatment==0].std():.2f}")
print(f"  Treated group: {h1[treatment==1].mean():.2f} ± {h1[treatment==1].std():.2f}")
print(f"\n✓ Treatment reduces fungus, increases final height!")

In [None]:
# Visualize the data
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Plot 1: Treatment vs Fungus
ax = axes[0]
fungus_rate = [np.mean(fungus[treatment==0]), np.mean(fungus[treatment==1])]
bars = ax.bar(['Control', 'Treated'], fungus_rate, 
              color=['lightcoral', 'lightgreen'], edgecolor='black', linewidth=1.5)
ax.set_ylabel('Proportion with Fungus', fontsize=11)
ax.set_title('Treatment → Fungus\nTreatment reduces fungus',
             fontsize=12, fontweight='bold')
ax.set_ylim(0, 0.7)
for bar, val in zip(bars, fungus_rate):
    ax.text(bar.get_x() + bar.get_width()/2, val + 0.02, 
           f'{val:.1%}', ha='center', fontsize=10, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Plot 2: Fungus vs Height
ax = axes[1]
heights_by_fungus = [h1[fungus==0], h1[fungus==1]]
bp = ax.boxplot(heights_by_fungus, positions=[0, 1], widths=0.5,
               labels=['No Fungus', 'Fungus'], patch_artist=True)
bp['boxes'][0].set_facecolor('lightgreen')
bp['boxes'][1].set_facecolor('lightcoral')
ax.set_ylabel('Final Height', fontsize=11)
ax.set_title('Fungus → Final Height\nFungus reduces growth',
             fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Plot 3: Treatment vs Height (total effect)
ax = axes[2]
heights_by_treatment = [h1[treatment==0], h1[treatment==1]]
bp2 = ax.boxplot(heights_by_treatment, positions=[0, 1], widths=0.5,
                labels=['Control', 'Treated'], patch_artist=True)
bp2['boxes'][0].set_facecolor('lightcoral')
bp2['boxes'][1].set_facecolor('lightgreen')
ax.set_ylabel('Final Height', fontsize=11)
ax.set_title('Treatment → Final Height\nTreatment increases growth (via fungus)',
             fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nCausal pathway:")
print("  Treatment → reduces fungus → increases final height")

---

## Step 3: Model WITHOUT Fungus - Correct Total Effect

First, let's fit the model **without** controlling for fungus.

**Model**: h₁ ~ h₀ + treatment
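
Written out in full, the model fitted in the next cell looks roughly like this. The priors are read off the code that follows; the symbols ($T_i$ for treatment, $h_{0,i}$ and $h_{1,i}$ for the standardized heights) are shorthand chosen here, not notation from the notebook itself.

$$
\begin{aligned}
h_{1,i} &\sim \mathrm{Normal}(\mu_i, \sigma) \\
\mu_i &= \alpha + \beta_{h_0}\, h_{0,i} + \beta_T\, T_i \\
\alpha &\sim \mathrm{Normal}(0, 0.2) \\
\beta_{h_0},\ \beta_T &\sim \mathrm{Normal}(0, 0.5) \\
\sigma &\sim \mathrm{Exponential}(1)
\end{aligned}
$$

The optimizer works on $s = \log \sigma$ rather than on $\sigma$, so the prior density has to be re-expressed for $s$. A change of variables gives

$$
\log p(s) = \log p_\sigma(e^{s}) + s,
$$

which appears to be why the code adds `log_sigma` to the log posterior before negating it.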

In [None]:
# Standardize for regression
h0_std = (h0 - h0.mean()) / h0.std()
h1_std = (h1 - h1.mean()) / h1.std()

print("Variables standardized for regression")

In [None]:
# Model 1: h₁ ~ h₀ + treatment (NO fungus)
def neg_log_posterior_no_fungus(params):
    alpha, beta_h0, beta_treatment, log_sigma = params
    sigma = np.exp(log_sigma)
    mu = alpha + beta_h0 * h0_std + beta_treatment * treatment
    log_lik = np.sum(stats.norm.logpdf(h1_std, loc=mu, scale=sigma))
    log_prior = (stats.norm.logpdf(alpha, 0, 0.2) +
                 stats.norm.logpdf(beta_h0, 0, 0.5) +
                 stats.norm.logpdf(beta_treatment, 0, 0.5) +
                 stats.expon.logpdf(sigma, scale=1))
    return -(log_lik + log_prior + log_sigma)

m_no_fungus = quap(neg_log_posterior_no_fungus, [0, 0, 0, np.log(1)], 
                   ['alpha', 'beta_h0', 'beta_treatment', 'log_sigma'])
m_no_fungus.transform_param('log_sigma', 'sigma', np.exp)

print("Model 1: h₁ ~ h₀ + treatment (NO fungus)")
print("="*70)
m_no_fungus.summary()

In [None]:
# Interpret results
coef_no_fungus = m_no_fungus.coef()

print("\nResults WITHOUT controlling for fungus:")
print("="*70)
print(f"β_treatment = {coef_no_fungus['beta_treatment']:.3f}")
print("\n✓ Expected: clearly POSITIVE (check the interval in the summary above).")
print("  Treatment increases final height.")
print("  This is the CORRECT total causal effect.")
print("\n  It captures the full pathway: Treatment → Fungus → Height")

---

## Step 4: Model WITH Fungus - POST-TREATMENT BIAS!

Now let's add fungus to the model and see what happens.

**Model**: h₁ ~ h₀ + treatment + fungus

In [None]:
# Model 2: h₁ ~ h₀ + treatment + fungus (WITH fungus)
def neg_log_posterior_with_fungus(params):
    alpha, beta_h0, beta_treatment, beta_fungus, log_sigma = params
    sigma = np.exp(log_sigma)
    mu = alpha + beta_h0 * h0_std + beta_treatment * treatment + beta_fungus * fungus
    log_lik = np.sum(stats.norm.logpdf(h1_std, loc=mu, scale=sigma))
    log_prior = (stats.norm.logpdf(alpha, 0, 0.2) +
                 stats.norm.logpdf(beta_h0, 0, 0.5) +
                 stats.norm.logpdf(beta_treatment, 0, 0.5) +
                 stats.norm.logpdf(beta_fungus, 0, 0.5) +
                 stats.expon.logpdf(sigma, scale=1))
    return -(log_lik + log_prior + log_sigma)

m_with_fungus = quap(neg_log_posterior_with_fungus, [0, 0, 0, 0, np.log(1)], 
                     ['alpha', 'beta_h0', 'beta_treatment', 'beta_fungus', 'log_sigma'])
m_with_fungus.transform_param('log_sigma', 'sigma', np.exp)

print("Model 2: h₁ ~ h₀ + treatment + fungus (WITH fungus)")
print("="*70)
m_with_fungus.summary()

In [None]:
# Compare the two models
coef_with_fungus = m_with_fungus.coef()

comparison = pd.DataFrame({
    'Without Fungus': [coef_no_fungus['beta_treatment'], 'n/a'],
    'With Fungus': [coef_with_fungus['beta_treatment'], coef_with_fungus['beta_fungus']]
}, index=['β_treatment', 'β_fungus'])

print("\nModel Comparison:")
print("="*70)
print(comparison)
print("="*70)

print("\n⚠️ POST-TREATMENT BIAS!")
print(f"\n  WITHOUT fungus: β_treatment = {coef_no_fungus['beta_treatment']:.3f}")
print("    → Shows treatment works! (CORRECT)")
print(f"\n  WITH fungus: β_treatment = {coef_with_fungus['beta_treatment']:.3f}")
print("    → Treatment effect DISAPPEARS! (WRONG as a total effect)")
print(f"\n  WITH fungus: β_fungus = {coef_with_fungus['beta_fungus']:.3f}")
print("    → Fungus has strong negative effect (blocks pathway)")
print("\n💡 By controlling for fungus (a consequence), we BLOCK the causal path!")

---

## Step 5: Why Does This Happen?

**The causal mechanism:**

```
Treatment → Fungus → Height
```

**Without controlling for fungus:**
- We see the full effect: Treatment → reduced fungus → taller plants
- This is the **total causal effect**

**With controlling for fungus:**
- We ask: "Among plants with the SAME fungus level, does treatment matter?"
- Answer: No! Treatment only works through fungus
- We've blocked the causal pathway
- This is the **direct effect** (which is zero)

**Analogy**: It's like asking "Does studying improve test scores, holding knowledge constant?"
- Of course not! Studying works BY increasing knowledge
- Controlling for knowledge blocks the causal path
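
The arithmetic behind the total effect can be checked directly. The sketch below re-simulates the plant experiment with the same parameters as Step 2 (fungus probability 0.5 for control and 0.1 for treated, a height penalty of 3 for fungus) and fits both models by ordinary least squares with `np.linalg.lstsq` instead of `quap`. The larger sample size, the seed, and the helper name `ols_coefs` are assumptions made for illustration; they are not part of the analysis above.

```python
import numpy as np

rng = np.random.default_rng(0)   # assumed seed, for reproducibility only
n = 20_000                       # assumed: larger than Step 2 to reduce noise

# Same data-generating process as Step 2
h0 = rng.normal(0, 1, n)
treatment = np.repeat([0, 1], n // 2)
fungus = rng.binomial(1, 0.5 - 0.4 * treatment)
h1 = h0 + rng.normal(5 - 3 * fungus, 1)

def ols_coefs(y, *columns):
    """Least-squares coefficients for y ~ 1 + columns (intercept first)."""
    X = np.column_stack([np.ones(len(y)), *columns])
    coefs, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coefs

# Expected total effect:
# (change in fungus probability) x (effect of fungus on height)
expected_total = (0.1 - 0.5) * (-3)

total = ols_coefs(h1, h0, treatment)[2]            # h1 ~ h0 + treatment
direct = ols_coefs(h1, h0, treatment, fungus)[2]   # h1 ~ h0 + treatment + fungus

print(f"expected total effect:         {expected_total:.2f}")
print(f"treatment coef without fungus: {total:.2f}")
print(f"treatment coef with fungus:    {direct:.2f}")
```

Under these assumptions the first estimate should land near 1.2 and the second near zero, on the raw height scale rather than the standardized scale used in Steps 3 and 4.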

In [None]:
# Visualize the bias with predictions
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Plot 1: Model without fungus (correct)
ax = axes[0]
# Predictions for control vs treated
h0_grid = np.linspace(h0_std.min(), h0_std.max(), 100)
pred_control = coef_no_fungus['alpha'] + coef_no_fungus['beta_h0'] * h0_grid + coef_no_fungus['beta_treatment'] * 0
pred_treated = coef_no_fungus['alpha'] + coef_no_fungus['beta_h0'] * h0_grid + coef_no_fungus['beta_treatment'] * 1

ax.scatter(h0_std[treatment==0], h1_std[treatment==0], s=40, alpha=0.6, 
          color='red', label='Control', edgecolor='black', linewidth=0.5)
ax.scatter(h0_std[treatment==1], h1_std[treatment==1], s=40, alpha=0.6, 
          color='green', label='Treated', edgecolor='black', linewidth=0.5)
ax.plot(h0_grid, pred_control, 'r-', linewidth=3, alpha=0.7, label='Control prediction')
ax.plot(h0_grid, pred_treated, 'g-', linewidth=3, alpha=0.7, label='Treated prediction')

# Add arrow showing treatment effect
mid_x = 0
mid_y_control = coef_no_fungus['alpha'] + coef_no_fungus['beta_treatment'] * 0
mid_y_treated = coef_no_fungus['alpha'] + coef_no_fungus['beta_treatment'] * 1
ax.annotate('', xy=(mid_x, mid_y_treated), xytext=(mid_x, mid_y_control),
           arrowprops=dict(arrowstyle='<->', lw=2, color='blue'))
ax.text(mid_x + 0.3, (mid_y_control + mid_y_treated)/2, 
       f'Treatment\neffect\n{coef_no_fungus["beta_treatment"]:.2f}',
       fontsize=9, color='blue', fontweight='bold')

ax.set_xlabel('Initial Height (standardized)', fontsize=11)
ax.set_ylabel('Final Height (standardized)', fontsize=11)
ax.set_title('Model WITHOUT Fungus\n✓ Shows treatment effect (CORRECT)',
            fontsize=12, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# Plot 2: Model with fungus (biased)
ax = axes[1]
# Now predictions are almost identical because we control for fungus
# Average fungus level
avg_fungus = fungus.mean()
pred_control_f = coef_with_fungus['alpha'] + coef_with_fungus['beta_h0'] * h0_grid + coef_with_fungus['beta_treatment'] * 0 + coef_with_fungus['beta_fungus'] * avg_fungus
pred_treated_f = coef_with_fungus['alpha'] + coef_with_fungus['beta_h0'] * h0_grid + coef_with_fungus['beta_treatment'] * 1 + coef_with_fungus['beta_fungus'] * avg_fungus

ax.scatter(h0_std[treatment==0], h1_std[treatment==0], s=40, alpha=0.6, 
          color='red', label='Control', edgecolor='black', linewidth=0.5)
ax.scatter(h0_std[treatment==1], h1_std[treatment==1], s=40, alpha=0.6, 
          color='green', label='Treated', edgecolor='black', linewidth=0.5)
ax.plot(h0_grid, pred_control_f, 'r-', linewidth=3, alpha=0.7, label='Control prediction')
ax.plot(h0_grid, pred_treated_f, 'g-', linewidth=3, alpha=0.7, label='Treated prediction')

ax.set_xlabel('Initial Height (standardized)', fontsize=11)
ax.set_ylabel('Final Height (standardized)', fontsize=11)
ax.set_title('Model WITH Fungus\n⚠️ Treatment effect gone (POST-TREATMENT BIAS)',
            fontsize=12, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nVisualization shows:")
print("  Left: Lines separated → treatment has effect")
print("  Right: Lines overlap → treatment effect disappears!")

---

## Step 6: General Lessons

### What is Post-Treatment Bias?

**Definition**: The bias introduced by controlling for a variable that is a **consequence** of the treatment.

**Structure**:
```
Treatment → Mediator → Outcome
```

**Problem**: Controlling for the mediator blocks the causal pathway from treatment to outcome.

### How to Detect It

1. **Draw the DAG**: Is the variable caused by treatment?
2. **Think temporally**: Does it occur AFTER treatment?
3. **Ask**: Does treatment work THROUGH this variable?

### Common Examples

| Treatment | Post-Treatment Variable | Outcome |
|-----------|------------------------|----------|
| Education | Job skills | Income |
| Exercise | Weight loss | Health |
| Medicine | Symptoms | Recovery |
| Training | Knowledge | Test scores |

In each case, controlling for the middle variable blocks the part of the causal effect that flows through it!

### What To Do

**Don't control for post-treatment variables if you want the total causal effect!**

**When to control:**
- If you specifically want the **direct effect** (e.g., does treatment do anything BESIDES reducing fungus?)
- For mediation analysis (advanced topic)

**When NOT to control:**
- If you want the **total causal effect** (usual goal)
- If the variable is an intermediate step in the causal chain
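
To make the difference between the total and direct effects concrete, the sketch below changes the plant simulation so that treatment also affects height directly. The direct effect of 0.5 is an invented value for illustration, as are the sample size, the seed, and the helper name `treatment_coef`; everything else follows the Step 2 setup. With these numbers the total effect should be about 0.5 + (0.1 - 0.5) × (-3) = 1.7, and the treatment coefficient after adding fungus should be close to 0.5.

```python
import numpy as np

rng = np.random.default_rng(1)   # assumed seed
n = 20_000                       # assumed sample size
direct_effect = 0.5              # hypothetical direct effect of treatment on height

h0 = rng.normal(0, 1, n)
treatment = np.repeat([0, 1], n // 2)
fungus = rng.binomial(1, 0.5 - 0.4 * treatment)
h1 = h0 + 5 + direct_effect * treatment - 3 * fungus + rng.normal(0, 1, n)

def treatment_coef(y, h0, treatment, *extra_controls):
    """OLS coefficient on treatment for y ~ 1 + h0 + treatment + extra controls."""
    X = np.column_stack([np.ones(len(y)), h0, treatment, *extra_controls])
    coefs, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coefs[2]   # column 0 is the intercept, 1 is h0, 2 is treatment

total_est = treatment_coef(h1, h0, treatment)            # total effect
direct_est = treatment_coef(h1, h0, treatment, fungus)   # direct effect only

print(f"total effect estimate:  {total_est:.2f}")
print(f"direct effect estimate: {direct_est:.2f}")
```

Neither number is wrong; they answer different questions. Including fungus is appropriate only when the direct effect is the target.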

---

## Step 7: Another Example - Simulated

Let's create another example to solidify understanding: **Job training → Skills → Salary**

In [None]:
# Simulate job training example
n = 200
np.random.seed(123)

# Education level (pre-treatment covariate; not a confounder here because training is randomized)
education = np.random.normal(0, 1, n)

# Training: 0=no, 1=yes
training = np.random.binomial(1, 0.5, n)

# Skills: improved by training
skills = education + 2 * training + np.random.normal(0, 0.5, n)

# Salary: determined by skills
salary = 10 * skills + np.random.normal(0, 2, n)

# Standardize
education_std = (education - education.mean()) / education.std()
skills_std = (skills - skills.mean()) / skills.std()
salary_std = (salary - salary.mean()) / salary.std()

print(f"Job training simulation: {n} people")
print(f"\nTraining effect on skills:")
print(f"  No training: skills = {skills[training==0].mean():.2f}")
print(f"  Training: skills = {skills[training==1].mean():.2f}")
print(f"\nTraining effect on salary:")
print(f"  No training: salary = {salary[training==0].mean():.2f}")
print(f"  Training: salary = {salary[training==1].mean():.2f}")

In [None]:
# Model A: Salary ~ Education + Training (NO skills)
def neg_log_posterior_salary_no_skills(params):
    alpha, beta_ed, beta_train, log_sigma = params
    sigma = np.exp(log_sigma)
    mu = alpha + beta_ed * education_std + beta_train * training
    log_lik = np.sum(stats.norm.logpdf(salary_std, loc=mu, scale=sigma))
    log_prior = (stats.norm.logpdf(alpha, 0, 0.2) +
                 stats.norm.logpdf(beta_ed, 0, 0.5) +
                 stats.norm.logpdf(beta_train, 0, 0.5) +
                 stats.expon.logpdf(sigma, scale=1))
    return -(log_lik + log_prior + log_sigma)

m_salary_no_skills = quap(neg_log_posterior_salary_no_skills, [0, 0, 0, np.log(1)], 
                          ['alpha', 'beta_ed', 'beta_train', 'log_sigma'])
m_salary_no_skills.transform_param('log_sigma', 'sigma', np.exp)

print("Model A: Salary ~ Education + Training (NO skills)")
print("="*70)
m_salary_no_skills.summary()

In [None]:
# Model B: Salary ~ Education + Training + Skills (WITH skills)
def neg_log_posterior_salary_with_skills(params):
    alpha, beta_ed, beta_train, beta_skills, log_sigma = params
    sigma = np.exp(log_sigma)
    mu = alpha + beta_ed * education_std + beta_train * training + beta_skills * skills_std
    log_lik = np.sum(stats.norm.logpdf(salary_std, loc=mu, scale=sigma))
    log_prior = (stats.norm.logpdf(alpha, 0, 0.2) +
                 stats.norm.logpdf(beta_ed, 0, 0.5) +
                 stats.norm.logpdf(beta_train, 0, 0.5) +
                 stats.norm.logpdf(beta_skills, 0, 0.5) +
                 stats.expon.logpdf(sigma, scale=1))
    return -(log_lik + log_prior + log_sigma)

m_salary_with_skills = quap(neg_log_posterior_salary_with_skills, [0, 0, 0, 0, np.log(1)], 
                            ['alpha', 'beta_ed', 'beta_train', 'beta_skills', 'log_sigma'])
m_salary_with_skills.transform_param('log_sigma', 'sigma', np.exp)

print("Model B: Salary ~ Education + Training + Skills (WITH skills)")
print("="*70)
m_salary_with_skills.summary()

In [None]:
# Compare
coef_no_skills = m_salary_no_skills.coef()
coef_with_skills = m_salary_with_skills.coef()

comparison2 = pd.DataFrame({
    'Without Skills': [coef_no_skills['beta_train'], 'n/a'],
    'With Skills': [coef_with_skills['beta_train'], coef_with_skills['beta_skills']]
}, index=['β_training', 'β_skills'])

print("\nJob Training Model Comparison:")
print("="*70)
print(comparison2)
print("="*70)

print("\n⚠️ POST-TREATMENT BIAS AGAIN!")
print(f"\n  WITHOUT skills: β_training = {coef_no_skills['beta_train']:.3f}")
print("    → Training increases salary! (CORRECT total effect)")
print(f"\n  WITH skills: β_training = {coef_with_skills['beta_train']:.3f}")
print("    → Training effect DISAPPEARS! (WRONG - post-treatment bias)")
print(f"\n  WITH skills: β_skills = {coef_with_skills['beta_skills']:.3f}")
print("    → Skills strongly predict salary (blocks pathway)")
print("\n💡 Training works BY improving skills. Controlling for skills blocks this!")

---

## Summary

### What We Learned

**1. Post-Treatment Bias Definition**
- Occurs when you control for a **consequence** of the treatment
- Blocks the causal pathway from treatment to outcome
- Makes treatment appear ineffective (even when it works!)

**2. The Causal Structure**
```
Treatment → Mediator → Outcome
```
- Mediator is POST-treatment (caused by treatment)
- Outcome is caused by mediator
- Controlling for mediator breaks the chain

**3. Examples**
- Plants: Treatment → Fungus → Height
- Jobs: Training → Skills → Salary
- Health: Medicine → Symptoms → Recovery
- Education: Study → Knowledge → Scores

**4. How to Avoid**
- Draw a DAG BEFORE modeling
- Identify post-treatment variables
- DON'T control for consequences if you want the total effect
- Only control for pre-treatment confounders

**5. When It's Okay**
- Mediation analysis (explicitly studying the pathway)
- Estimating direct effects (rare)
- When you understand the causal structure

**Key insight**: Not all variables should be controlled for! Think about causality, not just correlation.

---

**Next**: Collider bias - another way controlling can go wrong!