<a href="https://colab.research.google.com/github/pharringtonp19/econometrics/blob/main/notebooks/probability_and_statistics/phillips.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### **Import Libraries**

In [1]:
import jax 
import jax.numpy as jnp 
from functools import partial 

#### **Helper Functions**

In [13]:
def dgp1(key):
    """Draw one observation under perfect compliance.

    Treatment ``d`` is a fair coin flip (``jax.random.bernoulli`` with its
    default ``p=0.5``) and the outcome is ``y = 1*d + 0.5*eps`` with
    ``eps ~ N(0, 1)``, so the true treatment effect is 1.

    Args:
        key: a JAX PRNG key.

    Returns:
        Tuple ``(d, y)`` of scalar arrays in JAX's default float dtype.
    """
    k1, k2 = jax.random.split(key)
    # Cast to the *default* float dtype. ``astype('float')`` explicitly asks for
    # float64, which JAX truncates to float32 (with a UserWarning) unless x64 is
    # enabled; ``jnp.result_type(float)`` yields the same dtype silently.
    d = jax.random.bernoulli(k1).astype(jnp.result_type(float))
    y = 1*d + 0.5*jax.random.normal(k2)
    return d, y

def dgp2(key):
    """Draw one observation under imperfect (partial) compliance.

    The instrument ``z`` is a fair coin flip. Take-up ``d`` is Bernoulli with
    probability ``0.4 + 0.12*z`` (so ``z`` raises take-up from 0.40 to 0.52),
    and the outcome is ``y = 1*d + 0.5*eps`` with ``eps ~ N(0, 1)``.

    Args:
        key: a JAX PRNG key.

    Returns:
        Tuple ``(z, d, y)`` of scalar arrays in JAX's default float dtype.
    """
    k1, k2, k3 = jax.random.split(key, 3)
    # Default float dtype: avoids the float64 -> float32 truncation warning that
    # ``astype('float')`` triggers when jax_enable_x64 is off.
    float_dtype = jnp.result_type(float)
    z = jax.random.bernoulli(k1).astype(float_dtype)
    d = jax.random.bernoulli(k2, p=0.4 + 0.12*z).astype(float_dtype)
    y = 1*d + 0.5*jax.random.normal(k3)
    return z, d, y

def sample(init_key, dgp, n):
    """Draw ``n`` independent observations from a single-draw generator.

    ``init_key`` is split into ``n`` keys and ``dgp`` is vectorised over them
    with ``jax.vmap``, so every output leaf gains a leading axis of length ``n``.
    """
    per_draw_keys = jax.random.split(init_key, n)
    batched_dgp = jax.vmap(dgp)
    return batched_dgp(per_draw_keys)


def fwl(X, D, Y):
    """Coefficient on ``D`` from a regression of ``Y`` on ``[D, X]``, via Frisch-Waugh-Lovell.

    ``D`` is first residualised on the controls ``X`` by least squares; ``Y`` is
    then regressed on those residuals.

    Returns:
        Row 0 of the second-stage solution, i.e. the coefficient(s) on the first
        column of ``D`` -- shape ``(m,)`` for ``Y`` of shape ``(n, m)``.
    """
    # Partial the controls out of D.
    proj, *_ = jnp.linalg.lstsq(X, D, rcond=None)
    D_tilde = D - X @ proj
    # Regress Y on the residualised D.
    beta, *_ = jnp.linalg.lstsq(D_tilde, Y, rcond=None)
    return beta[0]

def compliance_est(init_key, n):
    """Estimate the treatment effect from ``n`` perfect-compliance draws.

    Samples ``(d, y)`` from ``dgp1`` and returns the ``fwl`` coefficient on
    ``d`` with an intercept as the only control.
    """
    treat, outcome = sample(init_key, dgp1, n)
    as_column = lambda v: v.reshape(-1, 1)
    intercept = jnp.ones_like(treat)
    return fwl(as_column(intercept), as_column(treat), as_column(outcome))

def fst_stage(X, Z, D):
    """First-stage fitted values: least-squares projection of ``D`` on ``[Z, X]``.

    Args:
        X: exogenous controls (e.g. an intercept column).
        Z: instrument(s), stacked to the left of ``X``.
        D: endogenous regressor(s) to be predicted.

    Returns:
        ``[Z, X] @ coeffs``, the fitted values of ``D``.
    """
    design = jnp.hstack((Z, X))
    coeffs, *_ = jnp.linalg.lstsq(design, D, rcond=None)
    return design @ coeffs

def partial_compliance_est(init_key, n):
    """Two-stage least-squares estimate of the effect of ``d`` on ``y`` from ``n`` draws of ``dgp2``.

    Stage 1 projects ``d`` on ``[z, 1]`` (``fst_stage``); stage 2 takes the
    ``fwl`` coefficient of ``y`` on those fitted values, partialling out an
    intercept.
    """
    instr, treat, outcome = sample(init_key, dgp2, n)
    as_column = lambda v: v.reshape(-1, 1)
    fitted_treat = fst_stage(as_column(jnp.ones_like(instr)), as_column(instr), as_column(treat))
    return fwl(as_column(jnp.ones_like(treat)), fitted_treat, as_column(outcome))

#### **Parameters**

In [14]:
# Single master seed; each experiment splits it into per-simulation keys.
init_key = jax.random.PRNGKey(0)
n_sims = 1_000  # Monte Carlo replications per experiment

#### **Perfect Compliance**

In [15]:
# Sample size per simulated study under perfect compliance.
n_obs = 144
compliance = jax.vmap(
    partial(compliance_est, n=n_obs)
)(jax.random.split(init_key, n_sims))
print(f"Mean: {compliance.mean()}\nStd: {compliance.std()}")

Mean: 1.000201940536499
Std: 0.08543296903371811


  return getattr(self.aval, name).fun(self, *args, **kwargs)


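#### **Imperfect Compliance**

Here the binary instrument $z$ only shifts the take-up probability from $0.40$ to $0.52$, a first stage of $\pi = 0.12$. The sampling variance of the IV estimator scales with $1/\pi^2$. Matching the precision of the perfect-compliance experiment therefore requires about $144 / 0.12^2 = 10{,}000$ observations.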

In [17]:
# Larger sample: the weak first stage needs far more data for the same precision.
n_obs = 10_000
partial_compliance = jax.vmap(
    partial(partial_compliance_est, n=n_obs)
)(jax.random.split(init_key, n_sims))
print(f"Mean: {partial_compliance.mean()}\nStd: {partial_compliance.std()}")

Mean: 1.0019097328186035
Std: 0.08531661331653595
