<a href="https://colab.research.google.com/github/thousandoaks/Intro-Causal-Inference/blob/main/Example_2_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
Estimating the causal effect of sodium on blood pressure in a simulated example
adapted from Luque-Fernandez et al. (2018):
    https://academic.oup.com/ije/article/48/2/640/5248195
"""

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

In [2]:
def generate_data(n=1000, seed=0, beta1=1.05, alpha1=0.4, alpha2=0.3, binary_treatment=True, binary_cutoff=3.5):
    """Simulate the sodium / blood-pressure example and return it as a DataFrame.

    Structural equations (every ``noise`` is an independent standard-normal draw):

        age            ~ Normal(mean=65, sd=5)
        sodium         = age / 18 + noise          (thresholded when binary)
        blood_pressure = beta1 * sodium + 2 * age + noise
        proteinuria    = alpha1 * sodium + alpha2 * blood_pressure + noise

    Parameters
    ----------
    n : int
        Number of simulated rows.
    seed : int
        Passed to ``np.random.seed``. Note that this reseeds NumPy's *global*
        random state as a side effect.
    beta1 : float
        Coefficient of ``sodium`` in the ``blood_pressure`` equation.
    alpha1, alpha2 : float
        Coefficients of ``sodium`` and ``blood_pressure`` in the
        ``proteinuria`` equation.
    binary_treatment : bool
        If True, ``sodium`` is replaced by the 0/1 indicator
        ``sodium > binary_cutoff``; otherwise it is left continuous.
    binary_cutoff : float or None
        Threshold used when ``binary_treatment`` is True. ``None`` means
        "use the sample mean of the continuous sodium values".

    Returns
    -------
    pandas.DataFrame
        Columns ``blood_pressure``, ``sodium``, ``age``, ``proteinuria`` (in
        that order), one row per simulated individual.
    """
    np.random.seed(seed)
    # The order of the random draws below determines the simulated values for
    # a given seed, so it must not be rearranged.
    age = np.random.normal(65, 5, n)
    sodium = age / 18 + np.random.normal(size=n)
    if binary_treatment:
        if binary_cutoff is None:
            binary_cutoff = sodium.mean()
        sodium = (sodium > binary_cutoff).astype(int)
    blood_pressure = beta1 * sodium + 2 * age + np.random.normal(size=n)
    proteinuria = alpha1 * sodium + alpha2 * blood_pressure + np.random.normal(size=n)
    # A binary outcome could be derived as (blood_pressure >= 140) if needed;
    # it is not materialised here because nothing consumes it.
    return pd.DataFrame({'blood_pressure': blood_pressure, 'sodium': sodium,
                         'age': age, 'proteinuria': proteinuria})

In [3]:
def estimate_causal_effect(Xt, y, model=None, treatment_idx=0, regression_coef=False):
    """Fit ``model`` on ``(Xt, y)`` and return an estimate of the treatment effect.

    Parameters
    ----------
    Xt : pandas.DataFrame
        Design matrix containing the treatment column plus any covariates
        being adjusted for. It is not modified.
    y : array-like
        Outcome values, one per row of ``Xt``.
    model : estimator or None
        Object exposing ``fit(X, y)`` and ``predict(X)`` (and ``coef_`` when
        ``regression_coef`` is True). ``None`` (the default) means a new
        ``LinearRegression()`` is created for this call. A model passed in
        is fitted in place.
    treatment_idx : int
        Position of the treatment column within ``Xt.columns``.
    regression_coef : bool
        If True, return the fitted coefficient at ``treatment_idx``.
        If False, return the mean difference between predictions with the
        treatment column set to 1 and with it set to 0.

    Returns
    -------
    float
        The estimated effect.
    """
    if model is None:
        # Build the estimator per call. A `LinearRegression()` written as the
        # signature default is created once at definition time and would be
        # shared (and refit) by every call relying on the default.
        model = LinearRegression()
    model.fit(Xt, y)
    if regression_coef:
        return model.coef_[treatment_idx]

    treatment_col = Xt.columns[treatment_idx]
    # Work on copies so the caller's frame keeps its observed treatment values.
    Xt1 = Xt.copy()
    Xt1[treatment_col] = 1
    Xt0 = Xt.copy()
    Xt0[treatment_col] = 0
    return (model.predict(Xt1) - model.predict(Xt0)).mean()

In [4]:
# Simulate the study twice with identical structural parameters: once with
# sodium thresholded to a 0/1 treatment and once with it left continuous.
# The binary variant is generated first, then the continuous one.
binary_t_df, continuous_t_df = (
    generate_data(n=10_000_000, beta1=1.05, alpha1=0.4, alpha2=0.3,
                  binary_treatment=is_binary)
    for is_binary in (True, False)
)

In [5]:
# Placeholders for the three ATE estimates; they are filled in further down.
ate_est_naive = ate_est_adjust_all = ate_est_adjust_age = None

Let's start with the binary-treatment example.

In [9]:
# Peek at three randomly selected rows of the binary-treatment data.
binary_t_df.sample(n=3)

Unnamed: 0,blood_pressure,sodium,age,proteinuria
8752272,131.442772,0,65.691963,40.938811
9100273,146.237698,1,72.035257,43.332678
9995594,137.621423,1,68.193237,41.992012


In [11]:
# `df` is a second name for the binary-treatment frame (an alias, not a copy).
df = binary_t_df

In [12]:
# Adjustment-formula estimates: mean predicted difference between treatment
# set to 1 and treatment set to 0, under three choices of covariates.
# In the simulation, blood_pressure is generated with a sodium coefficient
# of beta1 = 1.05, age enters both sodium and blood_pressure, and
# proteinuria is generated from sodium and blood_pressure.
outcome = df['blood_pressure']
ate_est_naive = estimate_causal_effect(df[['sodium']], outcome, treatment_idx=0)
ate_est_adjust_all = estimate_causal_effect(
    df[['sodium', 'age', 'proteinuria']], outcome, treatment_idx=0)
ate_est_adjust_age = estimate_causal_effect(df[['sodium', 'age']], outcome)

print('# Adjustment Formula Estimates #')
for label, estimate in (
    ('Naive ATE estimate:\t\t\t\t\t\t\t', ate_est_naive),
    ('ATE estimate adjusting for all covariates:\t', ate_est_adjust_all),
    ('ATE estimate adjusting for age:\t\t\t\t', ate_est_adjust_age),
):
    print(label, estimate)
print()

# Adjustment Formula Estimates #
Naive ATE estimate:							 5.328501680864975
ATE estimate adjusting for all covariates:	 0.8537946431496021
ATE estimate adjusting for age:				 1.0502124539714488



In [13]:
# Regression-coefficient estimates: same three covariate sets as the
# adjustment-formula estimates, but the effect is read directly from the
# fitted coefficient of the treatment column.
outcome = df['blood_pressure']
coef_kwargs = dict(regression_coef=True)
ate_est_naive = estimate_causal_effect(
    df[['sodium']], outcome, treatment_idx=0, **coef_kwargs)
ate_est_adjust_all = estimate_causal_effect(
    df[['sodium', 'age', 'proteinuria']], outcome, treatment_idx=0, **coef_kwargs)
ate_est_adjust_age = estimate_causal_effect(
    df[['sodium', 'age']], outcome, **coef_kwargs)

print('# Regression Coefficient Estimates #')
for label, estimate in (
    ('Naive ATE estimate:\t\t\t\t\t\t\t', ate_est_naive),
    ('ATE estimate adjusting for all covariates:\t', ate_est_adjust_all),
    ('ATE estimate adjusting for age:\t\t\t\t', ate_est_adjust_age),
):
    print(label, estimate)
print()

# Regression Coefficient Estimates #
Naive ATE estimate:							 5.328501680864978
ATE estimate adjusting for all covariates:	 0.8537946431495851
ATE estimate adjusting for age:				 1.0502124539714823

