<a href="https://colab.research.google.com/github/pharringtonp19/housing-and-homelessness/blob/main/notebooks/Residualized_Regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### **Import Libraries**

In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

#### **Generate Data**

In [5]:
def Outcome(treatment, control, key):
  k1, k2 = jax.random.split(key)
  treatment = jax.random.bernoulli(k1, p=jax.nn.sigmoid(control), shape=(1,))
  return 2.0*treatment + 3.0*control + jax.random.normal(k2, shape=(1,))

def sample(key):
  k1, k2, k3 = jax.random.split(key, 3)
  treatment = jax.random.bernoulli(k1, shape=(1,))
  control = jax.random.normal(k2, shape=(1,))
  outcome = Outcome(treatment, control, k3)
  return treatment, control, outcome

### **Sample Data**

In [6]:
n_obs = 10000
init_key = jax.random.PRNGKey(0)
Ds, Xs, Ys = jax.vmap(sample)(jax.random.split(init_key, n_obs))
print(Ds.shape, Xs.shape, Ys.shape)

(10000, 1) (10000, 1) (10000, 1)


#### **Linear Regression**

In [8]:
regs = jnp.hstack((Ds, jnp.ones_like(Ds), Xs))
coeffs = jnp.linalg.lstsq(regs, Ys)[0]
print(coeffs)

[[-0.02938269]
 [ 1.023129  ]
 [ 3.396988  ]]


#### **Residualized Regression**

In [9]:
regs1 = jnp.hstack((jnp.ones_like(Ds), Xs))
coeffsD = jnp.linalg.lstsq(regs1, Ds)[0]
Dhat = regs1@coeffsD
D_resid = Ds - Dhat
coeffs = jnp.linalg.lstsq(D_resid, Ys)[0]
print(coeffs)

[[-0.02938769]]
