<a href="https://colab.research.google.com/github/pharringtonp19/rfp/blob/main/notebooks/pld.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax 
import jax.numpy as jnp 

In [219]:
def sample(key, weights, features, clusters):
  subkey1, subkey2, subkey3, subkey4 = jax.random.split(key, 4)
  controls = jax.random.normal(subkey1, shape=(features,))
  time_period = jax.random.bernoulli(subkey2, p=0.3, shape=())
  cluster = jax.random.choice(subkey3, clusters, shape=())
  treatment = (cluster >= (clusters / 2)).astype(jnp.float32)
  outcome = jnp.dot(weights, controls) + treatment + time_period + cluster + 2*treatment*time_period + jax.random.normal(subkey3, shape=())#+ 0.1*treatment*time_period*cluster
  return outcome, treatment, time_period, controls, jax.nn.one_hot(cluster, clusters)

In [220]:
features = 5 
clusters = 6
n = 1000
weights = jax.random.normal(jax.random.PRNGKey(0), shape=(features,))

In [221]:
y, d, t, f, c = jax.vmap(lambda key: sample(key, weights, features, clusters))(jax.random.split(jax.random.PRNGKey(0), n))
y, d, t = y.reshape(-1,1), d.reshape(-1,1), t.reshape(-1,1)

In [223]:
regs = jnp.hstack((d*t, d, jnp.ones_like(d), t, f, c))

In [224]:
ols = jnp.linalg.lstsq(regs, y)[0][0]

In [225]:
ols

DeviceArray([2.0564907], dtype=float32)

In [226]:
type(jnp.where(t==0))

tuple

In [227]:
y_pre = y[jnp.where(t==0)[0]]#.reshape(-1,1)
print(y_pre.shape)
f_pre = f[jnp.where(t==0)[0]]
print(f_pre.shape)
c_pre = c[jnp.where(t==0)[0]]

y_post = y[jnp.where(t==1)[0]].reshape(-1,1)
f_post = f[jnp.where(t==1)[0]]
c_post = c[jnp.where(t==1)[0]]
d_post = d[jnp.where(t==1)[0]]

(691, 1)
(691, 5)


### **Pre-Process**

In [228]:
left_reg_pre = jnp.hstack((jnp.ones_like(y_pre), f_pre, c_pre))
left_reg_post = jnp.hstack((jnp.ones_like(y_post), f_post, c_post))
left_coef = jnp.linalg.lstsq(left_reg_pre, y_pre)[0]
yhat = left_reg_post @ left_coef
ydiff = y_post - yhat

In [229]:
right_reg = jnp.hstack((jnp.ones_like(d_post),f_post))
right_coef = jnp.linalg.lstsq(right_reg, d_post)[0]
dhat = right_reg @ right_coef
ddiff = d_post - dhat

In [230]:
beta = jnp.linalg.lstsq(ddiff, ydiff)[0]
beta

DeviceArray([[2.0474837]], dtype=float32)

In [231]:
left_reg = jnp.hstack((jnp.ones_like(y), f))
left_coef = jnp.linalg.lstsq(left_reg, y)[0]
yhat = left_reg @ left_coef
ydiff = y - yhat

In [232]:
right_reg = jnp.hstack((jnp.ones_like(d), f))
right_coef = jnp.linalg.lstsq(right_reg, d)[0]
dhat = right_reg @ right_coef
ddiff = d - dhat

In [233]:
final_regs = jnp.hstack((ddiff*t, ddiff))
beta = jnp.linalg.lstsq(final_regs, ydiff)[0]
print(beta)

[[2.3252537]
 [3.9920003]]
