<a href="https://colab.research.google.com/github/pharringtonp19/housing-and-homelessness/blob/main/notebooks/Residualized_Regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### **Import Libraries**

In [2]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

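#### **Generate Data**

Each observation draws a treatment $D \sim \text{Bernoulli}(0.5)$ and a control $X \sim U[0, 1)$, then sets $Y = 2D + X + \varepsilon$ with $\varepsilon \sim N(0, 1)$, so the true treatment effect is $2$.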

In [9]:
def Outcome(treatment, controls, key, effect=2.0):
  """Simulate the outcome of the linear data-generating process.

  Computes  Y = effect * D + X + eps,  with eps ~ N(0, 1).

  Args:
    treatment: treatment indicator D. When drawn with
      `jax.random.bernoulli` it is a boolean array; the multiplication by
      `effect` promotes it to float.
    controls: control variable X (enters with coefficient 1).
    key: JAX PRNG key used to draw the noise term.
    effect: true treatment effect, i.e. the coefficient the regressions
      should recover. Defaults to 2.0, the value used in this notebook.

  Returns:
    The simulated outcome. The noise is drawn with shape (1,), so for
    length-1 `treatment`/`controls` the result has shape (1,).
  """
  noise = jax.random.normal(key, shape=(1,))
  return effect * treatment + controls + noise

def sample(key):
  """Draw a single (treatment, controls, outcome) observation.

  The incoming key is split into three independent sub-keys, one per
  random quantity, so a single key fully determines the observation.

  Returns:
    A tuple (treatment, controls, outcome), each of shape (1,):
    treatment is a boolean Bernoulli draw, controls is a uniform draw,
    and outcome is produced by `Outcome` from the two.
  """
  key_d, key_x, key_y = jax.random.split(key, 3)
  d_draw = jax.random.bernoulli(key_d, shape=(1,))
  x_draw = jax.random.uniform(key_x, shape=(1,))
  return d_draw, x_draw, Outcome(d_draw, x_draw, key_y)

#### **Sample Data**

In [13]:
# Draw n_obs i.i.d. observations: one PRNG key per row, mapped over `sample`.
n_obs = 10_000
init_key = jax.random.PRNGKey(0)
row_keys = jax.random.split(init_key, num=n_obs)
Ds, Xs, Ys = jax.vmap(sample)(row_keys)
print(Ds.shape, Xs.shape, Ys.shape)

(10000, 1) (10000, 1) (10000, 1)


#### **Linear Regression**

In [18]:
# Full OLS of Y on [D, 1, X]. The Bernoulli treatment draws are boolean, so
# cast them to the controls' float dtype and build the intercept from the
# float controls; otherwise the design matrix is only float because hstack
# silently promotes a boolean intercept column.
D_float = Ds.astype(Xs.dtype)
regs = jnp.hstack((D_float, jnp.ones_like(Xs), Xs))
coeffs = jnp.linalg.lstsq(regs, Ys)[0]
print(coeffs)  # rows: treatment effect, intercept, coefficient on X

[[ 1.9840641 ]
 [-0.02037146]
 [ 1.0632734 ]]


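#### **Residualized Regression**

By the Frisch–Waugh–Lovell theorem, the treatment coefficient from the full regression can be recovered by first regressing $D$ on the intercept and controls, then regressing $Y$ on the residuals $D - \hat{D}$ alone.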

In [22]:
# Frisch–Waugh–Lovell: partial the intercept and controls out of the
# treatment, then regress Y on the residualized treatment alone. The slope
# should agree (up to float32 rounding) with the treatment coefficient from
# the full regression.
D_float = Ds.astype(Xs.dtype)  # Bernoulli draws are boolean; work in float
regs1 = jnp.hstack((jnp.ones_like(Xs), Xs))  # intercept + controls, all float
coeffsD = jnp.linalg.lstsq(regs1, D_float)[0]
Dhat = regs1 @ coeffsD
D_resid = D_float - Dhat
coeffs = jnp.linalg.lstsq(D_resid, Ys)[0]
print(coeffs)

[[1.9840647]]
