# In [None]:
# %%
# Imports and configuration
import os
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.neighbors import NearestNeighbors
import statsmodels.api as sm

# Configurable parameters.
# Several of these are bound as default argument values by the functions
# defined below, so they must be assigned before those definitions execute.
N = 5000  # number of simulated units (rows)
N_CONFOUNDERS = 6  # number of confounder columns; the generator indexes columns 0-4, so keep >= 5
TREATMENT_EFFECT = 2.0  # baseline treatment effect (the constant part of tau)
HETEROGENEOUS = True          # if True, tau varies with X; else constant
CALIPER_SD_UNITS = 0.2        # caliper in SD units of logit(p_hat)
MATCH_REPLACE = False         # matching without replacement
POLY_DEGREE_OUTCOME = 2      # degree for polynomial features in outcome model for AIPW
BOOTSTRAP_B = 200             # bootstrap iterations for AIPW se
OUT_DIR = "/mnt/data/causal_project_outputs"  # change if needed
RNG_SEED = 42  # seed shared by the data generator and the AIPW bootstrap

# Create the output directory up front; exist_ok=True makes re-runs harmless.
os.makedirs(OUT_DIR, exist_ok=True)

# %%
# 1) Synthetic data generation
def generate_synthetic_data(n=N, n_conf=N_CONFOUNDERS, treatment_effect=TREATMENT_EFFECT, hetero=HETEROGENEOUS, seed=RNG_SEED):
    """Simulate an observational dataset with a confounded binary treatment.

    Confounders are standard-normal draws, with column 1 rebuilt as a linear
    mix of column 0 and fresh noise so the two are correlated. Treatment is
    Bernoulli with a propensity that is a nonlinear function of columns 0-3;
    the control outcome surface ``mu0`` is a nonlinear function of columns
    0-4; the observed outcome adds unit-variance Gaussian noise.

    Parameters
    ----------
    n : int
        Number of rows to simulate.
    n_conf : int
        Number of confounder columns. Must be at least 5 because columns
        0-4 enter the propensity and outcome equations.
    treatment_effect : float
        Constant part of the unit-level effect ``tau``.
    hetero : bool
        If True, ``tau = treatment_effect + 0.5*tanh(X0) - 0.3*X1``;
        otherwise ``tau`` equals ``treatment_effect`` for every unit.
    seed : int or None
        Seed for the random draws.

    Returns
    -------
    pandas.DataFrame
        Columns ``X1..X{n_conf}``, ``T``, ``Y``, ``true_propensity``,
        ``true_mu0`` and ``true_mu1``.

    Raises
    ------
    ValueError
        If ``n_conf`` is smaller than 5.
    """
    if n_conf < 5:
        raise ValueError(
            f"n_conf must be >= 5 (columns 0-4 are used by the data-generating equations), got {n_conf}"
        )

    # A private legacy RandomState yields the same draw sequence that seeding
    # the global generator with `seed` would, without clobbering the
    # process-wide NumPy random state for the caller.
    rng = np.random.RandomState(seed)

    X = rng.normal(0, 1, size=(n, n_conf))
    # introduce correlation between the first two confounders
    X[:, 1] = 0.6 * X[:, 0] + 0.8 * rng.normal(size=n)

    # nonlinear transforms shared by the equations below
    x1_sq = X[:, 1] ** 2
    sin_x2 = np.sin(X[:, 2])
    exp_x2 = np.exp(0.2 * X[:, 2])
    tanh_x3 = np.tanh(X[:, 3])

    # Treatment assignment. The two x1_sq terms and the X3 / tanh(X3) terms
    # are kept separate (not algebraically merged) so the floating-point
    # result is unchanged.
    logits = (-0.2 + 0.8 * X[:, 0]
              - 0.6 * x1_sq
              + 0.5 * sin_x2
              + 0.3 * X[:, 3]
              + 0.4 * x1_sq
              - 0.3 * tanh_x3)
    p = 1.0 / (1.0 + np.exp(-logits))
    T = rng.binomial(1, p, size=n)

    # Potential-outcome surfaces
    mu0 = (1.5 * X[:, 0] - 0.7 * x1_sq
           + 0.9 * sin_x2 + 0.5 * X[:, 3] + 0.3 * X[:, 4] + 0.4 * exp_x2)
    if hetero:
        tau = treatment_effect + 0.5 * np.tanh(X[:, 0]) - 0.3 * X[:, 1]
    else:
        tau = np.full(n, treatment_effect)
    mu1 = mu0 + tau

    # Observed outcome: the realised potential outcome plus N(0, 1) noise
    Y = mu0 + T * tau + rng.normal(0, 1.0, size=n)

    df = pd.DataFrame(X, columns=[f"X{i+1}" for i in range(n_conf)])
    df['T'] = T
    df['Y'] = Y
    df['true_propensity'] = p
    df['true_mu0'] = mu0
    df['true_mu1'] = mu1
    return df

# Simulate the working dataset with the default settings, then collect the
# confounder column names (every column whose name begins with "X").
df = generate_synthetic_data()
feature_cols = list(filter(lambda name: name.startswith("X"), df.columns))

# %%
# 2) Fit propensity model (logistic regression)
def fit_propensity_model(df, feature_cols):
    """Estimate propensity scores with a logistic regression on the raw features.

    Fits ``T ~ feature_cols`` and returns a copy of ``df`` extended with two
    columns: ``p_hat`` (the second column of ``predict_proba``, presumably
    P(T=1 | X) when ``T`` is coded 0/1) and ``logit_p_hat`` (its log-odds).
    The input frame is not modified.

    Returns
    -------
    (pandas.DataFrame, LogisticRegression)
        The augmented copy and the fitted classifier.
    """
    eps = 1e-8  # keeps the log-odds finite if a fitted probability hits 0 or 1

    features = df[feature_cols].values
    treatment = df['T'].values
    model = LogisticRegression(solver='lbfgs', max_iter=1000).fit(features, treatment)

    p_hat = model.predict_proba(features)[:, 1]

    scored = df.copy()
    scored['p_hat'] = p_hat
    scored['logit_p_hat'] = np.log((p_hat + eps) / (1 - p_hat + eps))
    return scored, model

# Rebind `df` to the returned copy that carries the 'p_hat' and 'logit_p_hat'
# columns; keep the fitted classifier as `prop_model`.
df, prop_model = fit_propensity_model(df, feature_cols)

# %%
# 3) Nearest Neighbor Matching on logit(p_hat) with caliper
def nearest_neighbor_match_on_logit(df, caliper_sd_units=CALIPER_SD_UNITS, replace=MATCH_REPLACE):
    """Greedy 1:1 nearest-neighbour matching on ``logit_p_hat`` within a caliper.

    Treated units (``T == 1``) are processed in row order. Each is paired with
    the closest control (``T == 0``) whose ``logit_p_hat`` lies within the
    caliper. When ``replace`` is False and the closest control has already
    been used, the next-closest unused control inside the caliper is taken
    instead of discarding the treated unit. A treated unit with no eligible
    control is left unmatched.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``T`` and ``logit_p_hat``. Index labels are assumed to
        be unique, since they are used to look rows up again.
    caliper_sd_units : float
        Caliper width as a multiple of the sample standard deviation of
        ``logit_p_hat`` over all rows of ``df``.
    replace : bool
        If True a control may be matched to several treated units.

    Returns
    -------
    matched_df : pandas.DataFrame
        Matched rows interleaved as treated, control, treated, control, ...
        with an added ``match_group`` column (pair number starting at 0).
        Rows keep the dtypes they have in ``df``. Empty, with the same
        columns, when nothing matched.
    pairs_df : pandas.DataFrame
        One row per pair with the original index labels in ``treated_idx``
        and ``control_idx``.
    cal : float
        The caliper actually applied, in logit units.
    """
    treated = df[df['T'] == 1]
    control = df[df['T'] == 0]
    cal = caliper_sd_units * df['logit_p_hat'].std()

    matched_t_idx = []
    matched_c_idx = []

    # Nothing can be matched when an arm is empty or the caliper is negative
    # or NaN (e.g. std of a single row); `cal >= 0` is False for NaN.
    if len(treated) > 0 and len(control) > 0 and cal >= 0:
        control_logit = control['logit_p_hat'].to_numpy().reshape(-1, 1)
        treated_logit = treated['logit_p_hat'].to_numpy().reshape(-1, 1)
        control_labels = control.index

        nbrs = NearestNeighbors(n_neighbors=1).fit(control_logit)
        # For every treated unit: positions of all controls inside the
        # caliper, closest first. This materialises every candidate, which is
        # fine at this data size but grows with sample size and caliper width.
        _, candidates = nbrs.radius_neighbors(treated_logit, radius=cal, sort_results=True)

        used_c = set()
        for t_label, cand in zip(treated.index, candidates):
            for c_pos in cand:
                if replace or c_pos not in used_c:
                    matched_t_idx.append(t_label)
                    matched_c_idx.append(control_labels[c_pos])
                    used_c.add(c_pos)
                    break

    pairs_df = pd.DataFrame({'treated_idx': matched_t_idx, 'control_idx': matched_c_idx})

    if pairs_df.empty:
        matched_df = pd.DataFrame(columns=df.columns.tolist() + ['match_group'])
    else:
        # Row-major ravel of the (n_pairs, 2) label array interleaves
        # treated/control labels pair by pair.
        row_labels = pairs_df[['treated_idx', 'control_idx']].to_numpy().ravel()
        matched_df = df.loc[row_labels].reset_index(drop=True)
        matched_df['match_group'] = np.repeat(np.arange(len(pairs_df)), 2)

    return matched_df, pairs_df, cal

# Run the matching with the module-level defaults: `matched_df` holds the
# matched rows, `pairs_df` the treated/control index pairs, and `cal_used`
# the caliper in logit units.
matched_df, pairs_df, cal_used = nearest_neighbor_match_on_logit(df)

# %%
# 4) Standardized Mean Differences (SMD)
def standardized_mean_difference(df_t, df_c, cols):
    """Return the standardized mean difference of each column in ``cols``.

    For every column the difference of group means (``df_t`` minus ``df_c``)
    is divided by the square root of the average of the two sample variances
    (``ddof=1``). A tiny constant is added to the denominator so a
    zero-variance column does not divide by zero.

    Returns a ``pandas.Series`` indexed by column name, in the order of ``cols``.
    """
    def _smd(col):
        treated_vals = df_t[col]
        control_vals = df_c[col]
        mean_gap = treated_vals.mean() - control_vals.mean()
        pooled_sd = np.sqrt((treated_vals.var(ddof=1) + control_vals.var(ddof=1)) / 2.0)
        return mean_gap / (pooled_sd + 1e-10)

    return pd.Series({col: _smd(col) for col in cols})

# Covariate balance on the full sample, then on the matched sample. With no
# matched rows the "after" column is filled with NaN.
smd_before = standardized_mean_difference(
    df[df['T'] == 1],
    df[df['T'] == 0],
    feature_cols,
)
if len(matched_df) == 0:
    smd_after = pd.Series(np.nan, index=feature_cols)
else:
    smd_after = standardized_mean_difference(
        matched_df[matched_df['T'] == 1],
        matched_df[matched_df['T'] == 0],
        feature_cols,
    )
smd_df = pd.DataFrame({'SMD_before': smd_before, 'SMD_after': smd_after})

# %%
# 5) PSM ATE from matched pairs (diff-in-means)
def psm_ate_from_pairs(df, pairs_df):
    """Mean and standard error of the within-pair outcome differences.

    ``pairs_df`` supplies index labels of ``df`` in its ``treated_idx`` and
    ``control_idx`` columns; for each pair the control ``Y`` is subtracted
    from the treated ``Y``.

    Returns
    -------
    (float, float)
        The mean difference and ``std(ddof=1) / sqrt(n_pairs)``;
        ``(nan, nan)`` when there are no pairs.
    """
    if len(pairs_df) == 0:
        return np.nan, np.nan

    # Label-based lookups return the outcomes in pair order.
    y_treated = df.loc[pairs_df['treated_idx'].to_numpy(), 'Y'].to_numpy()
    y_control = df.loc[pairs_df['control_idx'].to_numpy(), 'Y'].to_numpy()
    pair_gaps = y_treated - y_control

    estimate = pair_gaps.mean()
    std_err = pair_gaps.std(ddof=1) / np.sqrt(len(pair_gaps))
    return estimate, std_err

# Matched-pair difference-in-means estimate and its standard error
# (both NaN when `pairs_df` is empty).
psm_ate, psm_se = psm_ate_from_pairs(df, pairs_df)

# %%
# 6) Matched-sample regression (Y ~ T + X)
def matched_regression_ate(matched_df, feature_cols):
    """Regress ``Y`` on the features plus ``T`` within the matched sample.

    Fits an OLS model with an added constant and HC1 covariance, then reads
    the coefficient and standard error of ``T``.

    Returns
    -------
    (float, float, results)
        The ``T`` coefficient, its standard error, and the fitted
        statsmodels results object.
    """
    regressors = sm.add_constant(matched_df[feature_cols + ['T']])
    outcome = matched_df['Y'].values
    res = sm.OLS(outcome, regressors).fit(cov_type='HC1')
    return float(res.params['T']), float(res.bse['T']), res

# Skip the regression when matching produced no rows; report NaN instead.
if len(matched_df) == 0:
    matched_ate, matched_se, matched_res = np.nan, np.nan, None
else:
    matched_ate, matched_se, matched_res = matched_regression_ate(matched_df, feature_cols)

# %%
# 7) Doubly Robust AIPW estimator
def aipw_estimator(df, feature_cols, poly_degree=POLY_DEGREE_OUTCOME, bootstrap_B=BOOTSTRAP_B, seed=RNG_SEED):
    """Augmented inverse-probability-weighted (AIPW) estimate of the ATE.

    A single linear outcome model is fitted on ``[T, poly(X)]``, where
    ``poly(X)`` are polynomial features of ``feature_cols`` up to
    ``poly_degree`` (no bias column). Predictions with ``T`` set to 1 and to 0
    give ``m1`` and ``m0``. These are combined with the propensities already
    stored in ``df['p_hat']`` into per-row AIPW scores, whose mean is the
    estimate.

    The standard error is the standard deviation (``ddof=1``) of
    ``bootstrap_B`` bootstrap means of the fixed score vector. The outcome
    and propensity models are not refitted inside the bootstrap.

    Returns
    -------
    (float, float)
        The ATE estimate and its bootstrap standard error.
    """
    eps = 1e-8  # guards the inverse-probability weights against division by zero
    n = len(df)
    T = df['T'].values
    Y = df['Y'].values
    p = df['p_hat'].values

    X_poly = PolynomialFeatures(degree=poly_degree, include_bias=False).fit_transform(df[feature_cols])

    def _design(treatment_col):
        # Treatment indicator first, polynomial features after it.
        return np.column_stack([treatment_col.reshape(-1, 1), X_poly])

    out_model = LinearRegression()
    out_model.fit(_design(T), Y)
    m1 = out_model.predict(_design(np.ones(n)))
    m0 = out_model.predict(_design(np.zeros(n)))

    treated_correction = T * (Y - m1) / (p + eps)
    control_correction = (1 - T) * (Y - m0) / (1 - p + eps)
    aipw_scores = treated_correction - control_correction + (m1 - m0)
    ate = aipw_scores.mean()

    # One resample of row positions per bootstrap replicate.
    rng = np.random.default_rng(seed)
    boot_ates = [aipw_scores[rng.integers(0, n, n)].mean() for _ in range(bootstrap_B)]
    se = float(np.std(boot_ates, ddof=1))
    return float(ate), se

# AIPW estimate and bootstrap standard error on the full (unmatched) sample;
# relies on the 'p_hat' column added by fit_propensity_model.
aipw_ate, aipw_se = aipw_estimator(df, feature_cols)

# %%
# 8) Naive and true ATE (for comparison)
# Naive estimate: raw difference of observed outcome means between arms.
naive_ate = float(df.loc[df['T'] == 1, 'Y'].mean() - df.loc[df['T'] == 0, 'Y'].mean())
# Benchmark: sample average of the simulated unit-level effects mu1 - mu0.
true_ate = float(df['true_mu1'].sub(df['true_mu0']).mean())

# %%
# 9) Save dataset and concise report
data_outfile = os.path.join(OUT_DIR, "synthetic_data.csv")
report_outfile = os.path.join(OUT_DIR, "this_report.txt")
df.to_csv(data_outfile, index=False)

# The report is assembled as a list of lines and joined once; the written
# text has no trailing newline.
report_lines = [
    "Project: Propensity Score Matching & Doubly Robust Estimation",
    f"Dataset: Synthetic (N={len(df)}), {len(feature_cols)} continuous confounders",
    "",
    "1) Data generation",
    f" - Heterogeneous treatment effect: {HETEROGENEOUS}",
    f" - True ATE (population): {true_ate:.6f}",
    "",
    "2) Propensity score modeling & matching",
    " - Logistic regression used for propensity estimation.",
    f" - Caliper on logit(p_hat) used: {cal_used:.6f} (logit units); caliper parameter (SD units) = {CALIPER_SD_UNITS}",
    f" - Matched pairs retained: {len(pairs_df)}",
    " - SMDs before vs after matching:",
    # one line per confounder
    *(
        f"    - {col}: before={smd_df.loc[col,'SMD_before']:.3f}, after={smd_df.loc[col,'SMD_after']:.3f}"
        for col in feature_cols
    ),
    "",
    "3) ATE estimators and results",
    f" - Naive (unadjusted) diff-in-means: {naive_ate:.6f}",
    f" - PSM (matched difference-in-means): {psm_ate:.6f} (SE {psm_se:.6f})",
    f" - Matched-sample regression (Y ~ T + X): {matched_ate:.6f} (SE {matched_se:.6f})",
    f" - Doubly Robust AIPW estimator: {aipw_ate:.6f} (bootstrap SE {aipw_se:.6f})",
    "",
    "Interpretation:",
    " - Matching reduces covariate imbalance (SMDs close to zero after matching).",
    " - The naive estimator is biased; adjusted methods move toward the true ATE.",
    "",
    f"Files saved to: {OUT_DIR}",
    " - synthetic_data.csv",
    " - this_report.txt",
]

# Explicit encoding so the file does not depend on the platform's locale.
with open(report_outfile, "w", encoding="utf-8") as f:
    f.write("\n".join(report_lines))

# %%
# 10) Print concise summary
# Emit the estimator summary and the balance table in one call; each
# argument lands on its own line and empty strings give the blank lines.
print(
    "Summary of results:",
    f" True ATE: {true_ate:.6f}",
    f" Naive ATE: {naive_ate:.6f}",
    f" PSM ATE: {psm_ate:.6f} (SE {psm_se:.6f})",
    f" Matched regression ATE: {matched_ate:.6f} (SE {matched_se:.6f})",
    f" AIPW ATE: {aipw_ate:.6f} (SE {aipw_se:.6f})",
    "",
    "SMDs before vs after matching:",
    smd_df,
    "",
    sep="\n",
)
# Where the artefacts were written.
print("Files saved:", os.listdir(OUT_DIR))
print("Dataset saved to:", data_outfile)
print("Report saved to:", report_outfile)
