<a href="https://colab.research.google.com/github/pharringtonp19/housing-and-homelessness/blob/main/notebooks/Residualized_Regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### **Import Libraries**

In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

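#### **Generate Data**

Data-generating process: $X \sim N(0, 1)$, $D \mid X \sim \text{Bernoulli}(\sigma(X))$, and $Y = 2D + 3X + \varepsilon$ with $\varepsilon \sim N(0, 1)$.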

In [18]:
def expected_treatment(control):
  """Return the propensity P(D = 1 | X = control), modelled as sigmoid(control).

  Used as the Bernoulli probability when drawing the treatment indicator.
  Works elementwise on scalars or arrays.
  """
  propensity = jax.nn.sigmoid(control)
  return propensity

def Outcome_and_Treatment(control, key, treatment_effect=2.0, control_effect=3.0,
                          noise_scale=1.0):
  """Draw a treatment indicator and an outcome given the control value(s).

  Treatment D ~ Bernoulli(expected_treatment(control)); the outcome follows

      Y = treatment_effect * D + control_effect * X + noise_scale * eps,
      eps ~ N(0, 1).

  Args:
    control: scalar or array of control values X.
    key: JAX PRNG key; split so the treatment draw and the noise draw use
      independent streams.
    treatment_effect: true coefficient on D (default 2.0).
    control_effect: true coefficient on X (default 3.0).
    noise_scale: standard deviation of the additive Gaussian noise (default 1.0).

  Returns:
    (outcome, treatment): the outcome is a float array and the treatment a boolean
    array. Both have shape ``broadcast_shapes(shape(control), (1,))``, i.e. ``(1,)``
    for a scalar or length-1 control.
  """
  k1, k2 = jax.random.split(key)
  # Derive the draw shape from `control` instead of hard-coding (1,), which
  # failed for controls with more than one element; the result is still (1,)
  # for scalar / length-1 inputs, so those draws are unchanged.
  draw_shape = jnp.broadcast_shapes(jnp.shape(control), (1,))
  treatment = jax.random.bernoulli(k1, p=expected_treatment(control), shape=draw_shape)
  noise = jax.random.normal(k2, shape=draw_shape)
  outcome = treatment_effect*treatment + control_effect*control + noise_scale*noise
  return outcome, treatment

def sample(key):
  """Draw one simulated observation, returned as (treatment, control, outcome).

  The key is split so the control draw and the treatment/outcome draws use
  independent randomness. Each returned array has shape (1,).
  """
  control_key, outcome_key = jax.random.split(key)
  x = jax.random.normal(control_key, shape=(1,))
  y, d = Outcome_and_Treatment(x, outcome_key)
  return d, x, y

#### **Sample Data**

In [19]:
n_obs = 10_000
init_key = jax.random.PRNGKey(0)
# One independent key per observation; vmap draws all observations in a single call.
per_obs_keys = jax.random.split(init_key, n_obs)
Ds, Xs, Ys = jax.vmap(sample)(per_obs_keys)
print(Ds.shape, Xs.shape, Ys.shape)

(10000, 1) (10000, 1) (10000, 1)


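#### **Linear Regression**

Regress $Y$ on $(D, 1, X)$. Under the simulated model the true coefficients are 2 on $D$, 0 on the intercept, and 3 on $X$.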

In [20]:
# Design matrix columns: [D, intercept, X]. D is boolean, so cast it explicitly
# rather than relying on concatenate's bool -> float promotion.
regs = jnp.concatenate([Ds.astype(Xs.dtype), jnp.ones_like(Xs), Xs], axis=1)
coeffs, *_ = jnp.linalg.lstsq(regs, Ys)
print(coeffs)

[[ 1.9936699]
 [-0.0033822]
 [ 3.0164602]]


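#### **Residualized Regression**

By the Frisch–Waugh–Lovell theorem, regressing $Y$ on the residual from regressing $D$ on $(1, X)$ recovers the same coefficient on $D$ as the full regression above.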

In [21]:
# Stage 1: project the (boolean, cast to float) treatment onto [intercept, X].
D_float = Ds.astype(Xs.dtype)
regs1 = jnp.concatenate([jnp.ones_like(Xs), Xs], axis=1)
coeffsD, *_ = jnp.linalg.lstsq(regs1, D_float)
Dhat = regs1 @ coeffsD
# Stage 2: the residual is orthogonal to [intercept, X], so no intercept is needed here.
D_resid = D_float - Dhat
coeffs, *_ = jnp.linalg.lstsq(D_resid, Ys)
print(coeffs)

[[1.9936712]]
