## Causal Inference with Factorial Design

### **Objective:** Provide an accurate and scalable model to solve the causal clustering problem

In [47]:
%load_ext autoreload
%autoreload 2


In [48]:
import numpy as np
import statsmodels.api as sm
from sklearn import preprocessing
from itertools import combinations
from tabulate import tabulate

### Set up FactorialModel class

In [49]:
class FactorialModel(object):
    def __init__(
        self,
        n,
        p_t=0.5,
        k=2,
        degree=2,
        sigma=0.1,
        sparsity=0.5,
        beta_seed=42,
    ) -> None:
        self.n = n
        self.p_t = p_t
        self.k = k
        self.degree = degree
        self.sigma = sigma
        self.sparsity = sparsity
        self.beta_seed = beta_seed
        # initialize beta random number generator
        self.rng_beta = np.random.default_rng(self.beta_seed)
        # initialize interaction expansion transformation
        self.xfm = preprocessing.PolynomialFeatures(
            degree=self.degree, interaction_only=True, include_bias=True
        )
        _ = self.xfm.fit_transform(np.zeros((1, self.k), dtype="float32"))
        # sample ground truth betas
        self.beta = self.rng_beta.normal(0, 1, self.xfm.n_output_features_).astype(
            "float32"
        )
        zero_indices = self.rng_beta.choice(
            self.xfm.n_output_features_,
            size=int(self.xfm.n_output_features_ * self.sparsity),
            replace=False,
        )
        self.beta[zero_indices] = 0.0

    def sample(self, seed=None):
        self.rng = np.random.default_rng(seed)
        # sample treatment array
        t = self.rng.binomial(1, self.p_t, (self.n, self.k)).astype("float32")
        # expand treatment array
        T = self.xfm.transform(t)  # reuse the expansion fitted in __init__
        # build response surface
        self.mu = T @ self.beta
        # sample outcome
        self.eps = self.rng.normal(0, self.sigma, size=self.n)
        y = self.mu + self.eps
        return t, y

### Initialize model parameters

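Every coefficient of the expanded model (bias, main effects, and interactions; $2^k$ of them when `degree = k`) is drawn independently from $\mathcal{N}(0, 1)$. A uniformly random subset of `int(n_features * sparsity)` coefficients is then set to zero.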
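The coefficient sampling in `__init__` can be isolated as a small function. This is a minimal sketch that assumes the same NumPy `Generator` calls as the class (one normal draw per expanded coefficient, then `choice` without replacement for the zeroed positions); the name `sample_sparse_beta` is hypothetical. With the same seed it should reproduce `fm.beta`.

```python
import numpy as np

def sample_sparse_beta(n_features, sparsity, seed):
    """Hypothetical helper mirroring the beta sampling in FactorialModel.__init__."""
    rng = np.random.default_rng(seed)
    # Draw every coefficient (bias, main effects, interactions) from N(0, 1)
    beta = rng.normal(0, 1, n_features).astype("float32")
    # Zero out int(n_features * sparsity) positions chosen uniformly without replacement
    n_zero = int(n_features * sparsity)
    zero_indices = rng.choice(n_features, size=n_zero, replace=False)
    beta[zero_indices] = 0.0
    return beta

# k = 3 treatments with degree = 3 gives 2**3 = 8 expanded features, so 4 are zeroed
beta = sample_sparse_beta(n_features=8, sparsity=0.5, seed=42)
print(beta)
```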

In [50]:
n = 1000
k = 3
degree = 3
sigma = 0.1
sparsity = 0.5

fm = FactorialModel(
    n=n,
    k=k,
    degree=degree,
    sigma=sigma,
    sparsity=sparsity,
    beta_seed=42,
)

For `degree=3` and `interaction_only=True`, the parameters appear in the following order: [bias, $\beta_{t_1}$, $\beta_{t_2}$, $\beta_{t_3}$, $\beta_{t_1, t_2}$, $\beta_{t_1, t_3}$, $\beta_{t_2, t_3}$, $\beta_{t_1, t_2, t_3}$]
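The ordering can be generated directly: columns are grouped by interaction order (bias first, then main effects, then pairs, and so on), with subsets in lexicographic order inside each group. This is a minimal pure-NumPy sketch of that expansion, assuming it matches `PolynomialFeatures(interaction_only=True, include_bias=True)` column for column, which is also what the summary table at the end relies on; `interaction_subsets` and `expand_interactions` are hypothetical names.

```python
from itertools import combinations

import numpy as np

def interaction_subsets(k, degree):
    """Treatment subsets in bias, main-effect, pair, ... order, lexicographic within each order."""
    return [s for d in range(degree + 1) for s in combinations(range(k), d)]

def expand_interactions(t, degree):
    """Hypothetical equivalent of the interaction-only expansion: one product column per subset."""
    t = np.asarray(t, dtype=float)
    # The empty subset gives a product over zero columns, i.e. the all-ones bias column
    cols = [np.prod(t[:, list(s)], axis=1) for s in interaction_subsets(t.shape[1], degree)]
    return np.column_stack(cols)

print(interaction_subsets(3, 3))
# → [(), (0,), (1,), (2,), (0, 1), (0, 2), (1, 2), (0, 1, 2)]
```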

In [51]:
print(fm.beta)

[ 0.        -1.0399841  0.         0.        -1.9510351  0.
  0.1278404 -0.3162426]


### Create a sample dataset

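**Treatment matrix:** $t_{ij} \sim \text{Bern}(p_t)$ independently for each unit $i = 1, \dots, n$ and treatment $j = 1, \dots, k$ (here $p_t = 0.5$); $T$ is $t$ expanded with all interaction columns up to order `degree`.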

**Outcomes (for k=3):** $y_i = \beta_0 + \beta_1 t_{i1} + \beta_2 t_{i2} + \beta_3 t_{i3} + \beta_{12} t_{i1} t_{i2} + \beta_{13} t_{i1} t_{i3} + \beta_{23} t_{i2} t_{i3} + \beta_{123} t_{i1} t_{i2} t_{i3} + \epsilon_i$, where $\epsilon_i \sim \mathcal{N}(0, \sigma^2)$ independently.
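Because every column of $T$ is a product of binary indicators, the noise-free response for a unit is simply the sum of the coefficients whose treatment subset is fully switched on. The following is a minimal sketch, assuming the column order listed above; `cell_mean` is a hypothetical name and the coefficients are made up for illustration.

```python
from itertools import combinations

def cell_mean(beta, cell, degree):
    """Noise-free response: sum of beta_S over subsets S with t_j = 1 for every j in S."""
    k = len(cell)
    subsets = [s for d in range(degree + 1) for s in combinations(range(k), d)]
    return sum(b for b, s in zip(beta, subsets) if all(cell[j] == 1 for j in s))

# Illustrative coefficients in the order [bias, b1, b2, b3, b12, b13, b23, b123]
beta = [0.5, 1.0, 2.0, 3.0, 10.0, 20.0, 30.0, 100.0]
print(cell_mean(beta, (1, 1, 0), degree=3))  # bias + b1 + b2 + b12 → 13.5
```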

In [52]:
t, y = fm.sample(seed=0)
print(t.shape, y.shape)

(1000, 3) (1000,)


### Fit OLS model with sample data

In [53]:
T = preprocessing.PolynomialFeatures(
    degree=degree, interaction_only=True, include_bias=True,
).fit_transform(t)
print(T.shape)

(1000, 8)


In [54]:
m = sm.OLS(y, T)
results = m.fit()
print(results.summary())

                            OLS Regression Results                            
Dep. Variable:                      y   R-squared:                       0.993
Model:                            OLS   Adj. R-squared:                  0.993
Method:                 Least Squares   F-statistic:                 2.093e+04
Date:                Tue, 27 Feb 2024   Prob (F-statistic):               0.00
Time:                        11:00:44   Log-Likelihood:                 881.00
No. Observations:                1000   AIC:                            -1746.
Df Residuals:                     992   BIC:                            -1707.
Df Model:                           7                                         
Covariance Type:            nonrobust                                         
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const          0.0085      0.009      0.964      0.3
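statsmodels' `OLS` solves an ordinary least-squares problem, so the point estimates can be cross-checked with NumPy's `np.linalg.lstsq`. This is a minimal, self-contained sketch on freshly simulated data; the coefficients below are illustrative rather than `fm.beta`, and the design is built with the same subset ordering as `PolynomialFeatures`.

```python
from itertools import combinations

import numpy as np

rng = np.random.default_rng(0)
n, k, sigma = 1000, 3, 0.1
# Illustrative ground truth in [bias, b1, b2, b3, b12, b13, b23, b123] order
beta_true = np.array([0.0, -1.0, 0.0, 0.0, -2.0, 0.0, 0.1, -0.3])

# Binary treatments and their full interaction expansion (bias column included)
t = rng.binomial(1, 0.5, (n, k)).astype(float)
subsets = [s for d in range(k + 1) for s in combinations(range(k), d)]
T = np.column_stack([np.prod(t[:, list(s)], axis=1) for s in subsets])
y = T @ beta_true + rng.normal(0, sigma, size=n)

# Ordinary least squares: minimize ||y - T b||^2 over b
beta_hat, *_ = np.linalg.lstsq(T, y, rcond=None)
print(np.round(beta_hat, 2))
```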

## Average Marginal Interaction Effect (AMIE)

[Egami & Imai, 2018](https://www.tandfonline.com/doi/10.1080/01621459.2018.1476246)

The AMIE is defined recursively as:

$$\pi_{1:K}(\mathbf{t}^{1:K};\mathbf{t}_0^{1:K}) = \tau_{1:K}(\mathbf{t}^{1:K};\mathbf{t}_0^{1:K}) - \sum_{k=1}^{K-1}\sum_{\mathcal{K}_k \in \mathcal{K}_K} \pi_{\mathcal{K}_k}(\mathbf{t}^{\mathcal{K}_k};\mathbf{t}_0^{\mathcal{K}_k})$$

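where the inner sum runs over all $k$-element subsets $\mathcal{K}_k$ of the $K$ treatments of interest, and $\tau$ is the average combination effect (ACE):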

$$\tau_{1:K}(\mathbf{t}^{1:K};\mathbf{t}_0^{1:K}) = \mathbb{E}\biggl[\int \Bigl\{Y_i(\mathbf{T}_i^{1:K} = \mathbf{t}^{1:K}, \mathbf{T}_i^{(K+1):J}) - Y_i(\mathbf{T}_i^{1:K} = \mathbf{t}_0^{1:K}, \mathbf{T}_i^{(K+1):J})\Bigr\} \, dF(\mathbf{T}_i^{(K+1):J})\biggr]$$

where $J$ is the total number of treatments, $K$ is the number of treatments of interest, and the integral averages over the distribution $F$ of the remaining $J - K$ treatments. Without loss of generality, we take the first $K \leq J$ treatments to be those of interest.
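Because the simulation's response surface is known, the recursion above can be evaluated exactly at the population level. This is a minimal sketch that assumes independent $\text{Bern}(p)$ treatments, all-ones versus all-zeros contrasts ($\mathbf{t} = \mathbf{1}$, $\mathbf{t}_0 = \mathbf{0}$), and the column order listed earlier. The functions `mu`, `population_ace`, and `population_amie` are hypothetical names, not part of the notebook. For a single treatment the sum in the AMIE formula is empty, so $\pi$ reduces to the ACE of that treatment, i.e. its average marginal component effect (AMCE).

```python
from itertools import combinations, product

import numpy as np

def _subsets(k, degree):
    return [s for d in range(degree + 1) for s in combinations(range(k), d)]

def mu(beta, cell, degree):
    """Noise-free response for one full treatment assignment `cell`."""
    return sum(b for b, s in zip(beta, _subsets(len(cell), degree)) if all(cell[j] for j in s))

def population_ace(beta, S, J, degree, p=0.5):
    """tau_S(1; 0): all treatments in S on vs. off, averaged over Bern(p) draws of the rest."""
    others = [j for j in range(J) if j not in S]
    total = 0.0
    for values in product([0, 1], repeat=len(others)):
        weight = np.prod([p if v else 1 - p for v in values])
        on, off = [0] * J, [0] * J
        for j, v in zip(others, values):
            on[j] = off[j] = v
        for j in S:
            on[j] = 1  # treatments of interest switched on; they stay 0 in `off`
        total += weight * (mu(beta, on, degree) - mu(beta, off, degree))
    return total

def population_amie(beta, S, J, degree, p=0.5):
    """pi_S = tau_S minus pi_U for every non-empty proper subset U of S (recursive)."""
    value = population_ace(beta, S, J, degree, p)
    for size in range(1, len(S)):
        for U in combinations(S, size):
            value -= population_amie(beta, U, J, degree, p)
    return value
```

For example, `population_amie(fm.beta, (0, 1), k, degree, fm.p_t)` returns the value the formula defines for $t_1$ and $t_2$ under the simulated model, which gives a population target to compare sample estimates against.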

In [55]:
J = k   # number of treatments
K = 2   # number of treatments of interest
assert 2 <= K <= J  # `&` binds tighter than comparisons, so chain the bounds instead

In [56]:
# Compute average combination effect (ACE)
def average_combination_effect(y, t, K_indices):
    combo_idx = np.where(np.all(t[:, K_indices] == 1, axis=1))[0]
    control_idx = np.where(np.all(t[:, K_indices] == 0, axis=1))[0]
    ACE = np.mean(y[combo_idx]) - np.mean(y[control_idx])
    return ACE

In [57]:
# Helper used by the K = 2 base case below. It adds the contrast of t_a at t_b = 1 to the
# contrast of t_a at t_b = 0; under a balanced design this is about twice the AMCE of t_a,
# and no separate term for t_b is included. y and t are passed explicitly instead of read as globals.
def average_marginal_effect(y, t, t_a, t_b):
    a1_b1_idx = np.where((t[:, t_a] == 1) & (t[:, t_b] == 1))[0]
    a1_b0_idx = np.where((t[:, t_a] == 1) & (t[:, t_b] == 0))[0]
    a0_b1_idx = np.where((t[:, t_a] == 0) & (t[:, t_b] == 1))[0]
    a0_b0_idx = np.where((t[:, t_a] == 0) & (t[:, t_b] == 0))[0]
    AME = np.mean(y[a1_b1_idx]) - np.mean(y[a0_b1_idx]) + np.mean(y[a1_b0_idx]) - np.mean(y[a0_b0_idx])
    return AME

In [58]:
# Recursive function for computing the average marginal interaction effect (AMIE)
def average_marginal_interaction_effect(y, t, K_indices):
    # Base case: exactly two treatments of interest
    if len(K_indices) == 2:
        a, b = K_indices
        return average_combination_effect(y, t, K_indices) - average_marginal_effect(y, t, a, b)

    # Recursive case: start from the ACE and subtract the AMIE of every subset of size 2 to K-1
    # (the k = 1 terms of the formula are not subtracted; with fewer than two indices the ACE is returned)
    AMIE = average_combination_effect(y, t, K_indices)
    for size in range(2, len(K_indices)):
        for subset in combinations(K_indices, size):
            AMIE -= average_marginal_interaction_effect(y, t, subset)
    return AMIE

In [59]:
average_marginal_interaction_effect(y, t, [0, 1])

1.1064595750620718
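For $K = 2$ the AMIE formula reduces to $\pi_{ab} = \tau_{ab} - \pi_a - \pi_b$, where each single-treatment term is the AMCE of that treatment averaged over the others. Under independent randomization, a difference in means estimates each AMCE. This is a minimal sketch of that estimator with `y` and `t` passed explicitly; `amce_hat`, `ace_hat`, and `amie2_hat` are hypothetical names, and the sketch is not the notebook's `average_marginal_effect`. It is checked on a synthetic additive outcome, where the interaction should be close to zero.

```python
import numpy as np

def amce_hat(y, t, a):
    """Difference in means for treatment a, averaging over the empirical mix of the others."""
    return y[t[:, a] == 1].mean() - y[t[:, a] == 0].mean()

def ace_hat(y, t, S):
    """All-treated vs. all-control contrast for the treatments in S."""
    S = list(S)
    on = np.all(t[:, S] == 1, axis=1)
    off = np.all(t[:, S] == 0, axis=1)
    return y[on].mean() - y[off].mean()

def amie2_hat(y, t, a, b):
    """K = 2 AMIE estimate: ACE of (a, b) minus the AMCEs of a and b."""
    return ace_hat(y, t, (a, b)) - amce_hat(y, t, a) - amce_hat(y, t, b)

# Synthetic additive outcome with no interaction between t_0 and t_1
rng = np.random.default_rng(1)
t = rng.binomial(1, 0.5, (10_000, 3)).astype(float)
y = 1.0 * t[:, 0] + 2.0 * t[:, 1] + rng.normal(0, 0.1, size=10_000)
print(amie2_hat(y, t, 0, 1))
```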

## Summary data

In [60]:
# Table enumerating each treatment combination (t), its true coefficient (fm.beta), the OLS
# estimate (results.params), and the value returned by average_marginal_interaction_effect.
# For the empty combination that value is 0; for a single treatment it is that treatment's ACE.
t_combos = [combo for size in range(k + 1) for combo in combinations(range(k), size)]

table = []
for combo, beta, pred in zip(t_combos, fm.beta, results.params):
    amie = average_marginal_interaction_effect(y, t, combo)
    table.append([combo, beta, pred, amie])
print(tabulate(table, headers=["treatment combos", "true betas", "predicted betas", "AMIE"]))

treatment combos      true betas    predicted betas       AMIE
------------------  ------------  -----------------  ---------
()                      0                0.00854169   0
(0,)                   -1.03998         -1.04548     -1.98994
(1,)                    0               -0.00407739  -0.7829
(2,)                    0                0.00466065   0.083422
(0, 1)                 -1.95104         -1.93547      1.10646
(0, 2)                  0               -0.00817732   2.09497
(1, 2)                  0.12784          0.117534     0.843058
(0, 1, 2)              -0.316243        -0.312447    -7.22794
