<a href="https://colab.research.google.com/github/pharringtonp19/econometrics/blob/main/notebooks/probability_and_statistics/phillips.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### **Import Libraries**

In [12]:
import jax
import jax.numpy as jnp
from functools import partial

#### **Helper Functions**

In [13]:
def dgp(fstage, key, effect=1.0, noise_sd=0.1):
    """Draw one (z, d, y) observation from a binary-instrument simulation.

    z ~ Bernoulli(0.5). A unit with z == 1 takes up treatment with
    probability `fstage`; a unit with z == 0 never does. The outcome is
    y = effect * d + noise_sd * N(0, 1).

    Args:
        fstage: take-up probability for units with z == 1.
        key: JAX PRNG key for this single observation.
        effect: true effect of d on y (default 1.0, the original design).
        noise_sd: scale of the Gaussian outcome noise (default 0.1).

    Returns:
        Tuple (z, d, y) of scalars in JAX's default float dtype.
    """
    k_z, k_d, k_y = jax.random.split(key, 3)
    # The builtin `float` resolves to JAX's default float dtype (float32, or
    # float64 under x64 mode). The string 'float' means float64, which JAX
    # truncates to float32 with a UserWarning when x64 mode is off.
    z = jax.random.bernoulli(k_z).astype(float)
    d = jax.random.bernoulli(k_d, p=jnp.where(z == 1., fstage, 0.0)).astype(float)
    y = effect * d + noise_sd * jax.random.normal(k_y)
    return z, d, y

def sample(init_key, fstage, n):
    """Simulate n independent observations from `dgp`.

    The initial key is split into n keys and `dgp` is vectorised over them
    with `jax.vmap`.

    Returns:
        Tuple (zs, ds, ys), each a length-n array.
    """
    obs_keys = jax.random.split(init_key, n)
    draw_one = lambda obs_key: dgp(fstage, obs_key)
    return jax.vmap(draw_one)(obs_keys)

def fwl(X, D, Y):
    """Coefficient on D from a regression of Y on [D, X], via Frisch-Waugh-Lovell.

    D is first residualised on the controls X. Y is then regressed on those
    residuals. Y does not also need residualising, because the residuals are
    already orthogonal to X.

    Args:
        X: (n, k) control matrix.
        D: (n, m) regressor(s) of interest.
        Y: (n, p) outcome(s).

    Returns:
        The first row of the least-squares coefficient matrix. For D and Y of
        shape (n, 1), this has shape (1,).
    """
    gamma = jnp.linalg.lstsq(X, D, rcond=None)[0]
    d_tilde = D - X @ gamma
    beta, *_ = jnp.linalg.lstsq(d_tilde, Y, rcond=None)
    return beta[0]

def fst_stage(X, Z, D):
    """First-stage fitted values: project D onto the instruments Z and controls X.

    Args:
        X: (n, k) controls.
        Z: (n, l) instruments.
        D: (n, m) endogenous regressor(s).

    Returns:
        Fitted values of D from its least-squares regression on [Z, X].
    """
    design = jnp.hstack((Z, X))
    pi_hat, *_ = jnp.linalg.lstsq(design, D, rcond=None)
    return design @ pi_hat

def estimate(init_key, fstage, n):
    """Simulate one dataset of size n and return its 2SLS estimate of d's effect on y.

    The first stage regresses d on z and a constant. The second stage uses
    FWL to recover the coefficient on the fitted d, with a constant as the
    only control.
    """
    z, d, y = sample(init_key, fstage, n)
    as_col = lambda v: v.reshape(-1, 1)
    const = as_col(jnp.ones_like(z))
    d_fitted = fst_stage(const, as_col(z), as_col(d))
    return fwl(const, d_fitted, as_col(y))

#### **Parameters**

In [14]:
n_sims = 1000  # number of simulated datasets per design
init_key = jax.random.PRNGKey(0)  # single seed from which every simulation key is split

#### **Perfect Compliance**

In [15]:
# fstage = 1: every unit with z == 1 takes up treatment, so d == z.
n_obs = 144
compliance = jax.vmap(
    lambda sim_key: estimate(sim_key, fstage=1., n=n_obs)
)(jax.random.split(init_key, n_sims))
print(f"Mean: {jnp.mean(compliance)}\nStd: {jnp.std(compliance)}")

Mean: 0.9990527629852295
Std: 0.016402436420321465


  return lax_numpy.astype(arr, dtype)


#### **Imperfect Compliance**

In [18]:
# fstage = 0.12: a unit with z == 1 takes up treatment with probability 0.12,
# so the instrument is much weaker than in the perfect-compliance design.
n_obs = 250
partial_compliance = jax.vmap(
    lambda sim_key: estimate(sim_key, fstage=0.12, n=n_obs)
)(jax.random.split(init_key, n_sims))
print(f"Mean: {jnp.mean(partial_compliance)}\nStd: {jnp.std(partial_compliance)}")

Mean: 1.0010396242141724
Std: 0.12306595593690872
