# Causal Forest Analysis

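Estimate heterogeneous (individual-level) treatment effects with a causal forest, examine how they vary across age and income quartiles, and derive targeting recommendations from the estimates.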

In [None]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from src.models.causal_forest import fit_causal_forest

# Read the preprocessed observations and the candidate confounder names
# (taken from the 'confounder' column of confounders.csv).
df = pd.read_csv('../data/processed/preprocessed_ad_data.csv')
confounder_table = pd.read_csv('../data/processed/confounders.csv')
confounders = confounder_table['confounder'].tolist()

n_rows, n_confounders = len(df), len(confounders)
print(f"Dataset: {n_rows} observations, {n_confounders} confounders")

In [None]:
# Fit the causal forest on only the first 10 confounders to keep runtime down.
# NOTE(review): treating these as the "top" 10 assumes confounders.csv is
# ordered by importance — confirm against the preprocessing step.
selected_confounders = confounders[:10]
cf_results = fit_causal_forest(df, selected_confounders)

print("=== CAUSAL FOREST RESULTS ===")
ate, ate_se = cf_results['ate'], cf_results['ate_se']
print(f"Average Treatment Effect: {ate:.4f} ± {ate_se:.4f}")
ci_lo, ci_hi = cf_results['ate_ci_lower'], cf_results['ate_ci_upper']
print(f"95% CI: [{ci_lo:.4f}, {ci_hi:.4f}]")
het_p = cf_results['heterogeneity_pvalue']
print(f"Heterogeneity p-value: {het_p:.4f}")

# Individual-level effect estimates; later cells plot, bucket and threshold these.
treatment_effects = cf_results['treatment_effects']
effect_min, effect_max = treatment_effects.min(), treatment_effects.max()
print(f"\nEffect Range: [{effect_min:.4f}, {effect_max:.4f}]")
print(f"Effect Std Dev: {treatment_effects.std():.4f}")

In [None]:
# Visualize treatment effects: their overall distribution, and how the mean
# estimated effect varies across age and income quartiles.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Distribution of individual effects, with the forest's ATE marked for reference
axes[0].hist(treatment_effects, bins=30, alpha=0.7, color='skyblue')
axes[0].axvline(cf_results['ate'], color='red', linestyle='--', label=f'ATE = {cf_results["ate"]:.4f}')
axes[0].set_title('Treatment Effect Distribution')
axes[0].set_xlabel('Estimated Individual Treatment Effect')
axes[0].set_ylabel('Count')
axes[0].legend()

# Attach the effect estimates to df so they can be grouped by covariate buckets.
# NOTE(review): assumes treatment_effects is row-aligned with df (same length
# and order) — pandas raises a length-mismatch ValueError otherwise.
df['treatment_effect'] = treatment_effects

QUARTILE_LABELS = ['Q1', 'Q2', 'Q3', 'Q4']


def _quartile_bins(values):
    """Return Q1-Q4 quartile labels for the Series *values*.

    Plain ``pd.qcut`` raises ``ValueError`` ("Bin edges must be unique") when
    heavy ties make two quartile edges coincide; in that case fall back to
    binning the ranks (ties broken by order of appearance) so all four labels
    can still be assigned.
    """
    try:
        return pd.qcut(values, q=4, labels=QUARTILE_LABELS)
    except ValueError:
        return pd.qcut(values.rank(method='first'), q=4, labels=QUARTILE_LABELS)


# Mean effect per quartile for each segmenting variable, one subplot each.
segment_effects = {}
for ax, var, color in ((axes[1], 'age', 'lightcoral'), (axes[2], 'income', 'lightgreen')):
    bucket_col = f'{var}_quartile'
    df[bucket_col] = _quartile_bins(df[var])
    # observed=False keeps every quartile category (the current pandas default)
    # and silences the FutureWarning pandas >= 2.1 emits for categorical groupers.
    effects = df.groupby(bucket_col, observed=False)['treatment_effect'].mean()
    segment_effects[var] = effects

    ax.bar(effects.index.astype(str), effects.values, alpha=0.8, color=color)
    ax.axhline(0, color='grey', linewidth=0.8)  # makes sign of each bar obvious
    ax.set_title(f'Effects by {var.title()} Quartile')
    ax.set_xlabel(f'{var.title()} Quartile')
    ax.set_ylabel('Average Treatment Effect')

# Keep the original per-variable result names available to later cells.
age_effects = segment_effects['age']
income_effects = segment_effects['income']

plt.tight_layout()
plt.show()

In [None]:
# Targeting recommendations: split users into the top and bottom quartiles of
# estimated effect, report each group, and profile the high-responder segment.
high_cutoff = np.percentile(treatment_effects, 75)
low_cutoff = np.percentile(treatment_effects, 25)
high_responders = treatment_effects >= high_cutoff
low_responders = treatment_effects <= low_cutoff

print("=== TARGETING RECOMMENDATIONS ===")
print(f"High Responders (Top 25%): {high_responders.sum()} users")
print(f"  Average Effect: {treatment_effects[high_responders].mean():.4f}")
print("  Recommendation: Premium targeting")

print(f"\nLow Responders (Bottom 25%): {low_responders.sum()} users")
print(f"  Average Effect: {treatment_effects[low_responders].mean():.4f}")
print("  Recommendation: Reduce or exclude from campaigns")

# Profile differences: high-responder mean vs population mean per variable
print("\n📊 HIGH-VALUE SEGMENT PROFILE:")
for var in ['age', 'income', 'website_visits']:
    if var not in df.columns:
        continue
    high_val_mean = df.loc[high_responders, var].mean()
    overall_mean = df[var].mean()
    if np.isclose(overall_mean, 0.0):
        # A relative difference is meaningless when the baseline is ~0 (e.g. a
        # standardized column); report the absolute difference instead.
        print(f"  {var.title()}: {high_val_mean:.1f} ({high_val_mean - overall_mean:+.2f} vs overall)")
    else:
        # abs() keeps "+" meaning "above overall" even when the mean is negative.
        diff_pct = (high_val_mean - overall_mean) / abs(overall_mean) * 100
        print(f"  {var.title()}: {high_val_mean:.1f} ({diff_pct:+.1f}% vs overall)")

# Save results; create the output directory first so the dump does not fail
# with FileNotFoundError when ../results does not exist yet.
import pickle
from pathlib import Path

results_path = Path('../results/causal_forest_results.pkl')
results_path.parent.mkdir(parents=True, exist_ok=True)
with open(results_path, 'wb') as f:
    pickle.dump(cf_results, f)
print("\n✅ Results saved")