## Experiments for binary treatment effect estimation comparison

In [None]:
import sys, os

# add the project root to sys.path
root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if root not in sys.path:
    sys.path.insert(0, root)

from data_causl.utils import *
from data_causl.data import *
from frengression import *

device = torch.device('cpu')

from CausalEGM import *
# import the module
from dragonnet.dragonnet import DragonNet # https://github.com/farazmah/dragonnet-pytorch
from catenets.models.torch import TARNet
from pyro.contrib.cevae import CEVAE

# X: [N×D], t: [N], y: [N]
cevae = CEVAE(feature_dim=D)
cevae.fit(torch.tensor(X), torch.tensor(t), torch.tensor(y))
ite = cevae.ite(torch.tensor(X_test))  # returns per-sample effects



import numpy as np
import pickle
import os
from tqdm import tqdm

from matplotlib import pyplot as plt
import seaborn as sns


import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)



np.random.seed(2024)
n_tr = 1000
n_p = 1000

nI = 2
nX = 2
nO = 2
nS= 2
p = nI+nX+nO+nS
ate = 2
beta_cov = 0
strength_instr = 1
strength_conf = 1
strength_outcome = 1
binary_intervention=True

In [43]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from sklearn.model_selection import train_test_split
import optuna
from pyro.contrib.cevae import CEVAE
from dragonnet.dragonnet import DragonNet

# ------------------------------------------------------------------------
# 1) TARNetModel + Trainer
# ------------------------------------------------------------------------
def mmd_rbf(x, y, gamma=None):
    x_flat = x.view(x.size(0), -1)
    y_flat = y.view(y.size(0), -1)
    Z = torch.cat([x_flat, y_flat], dim=0)
    dist = (
        Z.pow(2).sum(1, keepdim=True)
        - 2 * Z @ Z.t()
        + Z.pow(2).sum(1, keepdim=True).t()
    )
    if gamma is None:
        d = dist.detach().cpu().numpy()
        gamma = 1.0 / (0.5 * np.median(d[d > 0]))
    K = torch.exp(-gamma * dist)
    n = x_flat.size(0)
    return K[:n, :n].mean() + K[n:, n:].mean() - 2 * K[:n, n:].mean()

class TARNetModel(nn.Module):
    def __init__(self, input_dim, rep_dims, head_dims, dropout):
        super().__init__()
        layers = []
        last_dim = input_dim
        for h in rep_dims:
            layers += [nn.Linear(last_dim, h), nn.ReLU(), nn.Dropout(dropout)]
            last_dim = h
        self.repr_net = nn.Sequential(*layers)

        def make_head(in_dim):
            head_layers = []
            cur = in_dim
            for h in head_dims:
                head_layers += [nn.Linear(cur, h), nn.ReLU(), nn.Dropout(dropout)]
                cur = h
            head_layers += [nn.Linear(cur, 1)]
            return nn.Sequential(*head_layers)

        self.h0 = make_head(last_dim)
        self.h1 = make_head(last_dim)

    def forward(self, x):
        z = self.repr_net(x)
        y0 = self.h0(z).squeeze(-1)
        y1 = self.h1(z).squeeze(-1)
        return y0, y1

class TARNetTrainer:
    def __init__(self, input_dim, rep_dims, head_dims, dropout):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = TARNetModel(input_dim, rep_dims, head_dims, dropout).to(self.device)

    def fit(self, X_train, t_train, y_train,
            X_val, t_val, y_val,
            lr, weight_decay, batch_size, epochs):
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)

        train_ds = torch.utils.data.TensorDataset(
            torch.tensor(X_train, dtype=torch.float32, device=self.device),
            torch.tensor(t_train, dtype=torch.long, device=self.device),
            torch.tensor(y_train, dtype=torch.float32, device=self.device),
        )
        train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)

        best_val = float("inf")
        for _ in range(epochs):
            # training loop
            self.model.train()
            for xb, tb, yb in train_loader:
                optimizer.zero_grad()
                y0, y1 = self.model(xb)
                y_pred = torch.where(tb == 1, y1, y0)
                loss = criterion(y_pred, yb)
                loss.backward()
                optimizer.step()

            # validation
            self.model.eval()
            with torch.no_grad():
                Xv = torch.tensor(X_val, dtype=torch.float32, device=self.device)
                tv = torch.tensor(t_val, dtype=torch.long, device=self.device)
                yv = torch.tensor(y_val, dtype=torch.float32, device=self.device)
                y0v, y1v = self.model(Xv)
                ypv = torch.where(tv == 1, y1v, y0v)
                val_loss = criterion(ypv, yv).item()
                best_val = min(best_val, val_loss)

        return best_val

    def predict(self, X):
        self.model.eval()
        Xb = torch.tensor(X, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            y0, y1 = self.model(Xb)
        return y0.cpu().numpy(), y1.cpu().numpy()

# ------------------------------------------------------------------------
# 2) CEVAE Trainer (Pyro) with normal outcome(self, X):
        self.wrapper.model.eval()
        Xb = torch.tensor(X, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            outs = self.wrapper.model(Xb)
            if len(outs) == 3:
                y0, y1, _ = outs
            else:
                y0, y1 = outs
        return y0.cpu().numpy(), y1.cpu().numpy()

# ------------------------------------------------------------------------
# 2) CEVAE Trainer (Pyro) with normal outcome
# ------------------------------------------------------------------------
class CEVAETrainer:
    def __init__(self, input_dim, latent_dim, hidden_dim, num_layers, num_samples):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CEVAE(
            feature_dim=input_dim,
            latent_dim=latent_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_samples=num_samples,
            outcome_dist="normal"
        ).to(self.device)

    def fit(self, X_train, t_train, y_train,
            X_val, t_val, y_val,
            lr, weight_decay, batch_size, epochs):
        Xb = torch.tensor(X_train, dtype=torch.float32, device=self.device)
        tb = torch.tensor(t_train, dtype=torch.float32, device=self.device)
        yb = torch.tensor(y_train, dtype=torch.float32, device=self.device)
        elbo_list = self.model.fit(
            Xb, tb, yb,
            num_epochs=epochs,
            batch_size=batch_size,
            learning_rate=lr,
            weight_decay=weight_decay
        )
        return elbo_list[-1]

    def predict(self, X):
        Xb = torch.tensor(X, dtype=torch.float32, device=self.device)
        ite_samples = self.model.ite(Xb).cpu().numpy()
        return ite_samples.mean(0)

# ------------------------------------------------------------------------
# 3) DragonNet Trainer
# ------------------------------------------------------------------------
class DragonNetTrainer:
    def __init__(self, input_dim, shared_hidden, outcome_hidden):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.wrapper = DragonNet(input_dim, shared_hidden, outcome_hidden)
        # Move internal nn.Module to device
        self.wrapper.model.to(self.device)

    def fit(self, X_train, t_train, y_train,
            X_val, t_val, y_val,
            lr, weight_decay, batch_size, epochs):
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.wrapper.model.parameters(), lr=lr, weight_decay=weight_decay)

        train_ds = torch.utils.data.TensorDataset(
            torch.tensor(X_train, dtype=torch.float32, device=self.device),
            torch.tensor(t_train, dtype=torch.float32, device=self.device),
            torch.tensor(y_train, dtype=torch.float32, device=self.device)
        )
        train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)

        best_val = float("inf")
        for _ in range(epochs):
            # Training
            self.wrapper.model.train()
            for xb, tb, yb in train_loader:
                optimizer.zero_grad()
                outs = self.wrapper.model(xb)
                # Always take first two outputs as y0, y1
                y0, y1 = outs[0], outs[1]
                y_pred = torch.where(tb.unsqueeze(1) == 1, y1, y0).squeeze(-1)
                loss = criterion(y_pred, yb)
                loss.backward()
                optimizer.step()

            # Validation
            self.wrapper.model.eval()
            with torch.no_grad():
                Xv = torch.tensor(X_val, dtype=torch.float32, device=self.device)
                tv = torch.tensor(t_val, dtype=torch.float32, device=self.device)
                yv = torch.tensor(y_val, dtype=torch.float32, device=self.device)
                outs_v = self.wrapper.model(Xv)
                y0v, y1v = outs_v[0], outs_v[1]
                ypv = torch.where(tv.unsqueeze(1) == 1, y1v, y0v).squeeze(-1)
                val_loss = criterion(ypv, yv).item()
                best_val = min(best_val, val_loss)
        return best_val

    def predict(self, X):
        self.wrapper.model.eval()
        Xb = torch.tensor(X, dtype=torch.float32, device=self.device)
        with torch.no_grad():
            outs = self.wrapper.model(Xb)
            y0, y1 = outs[0], outs[1]
        return y0.cpu().numpy(), y1.cpu().numpy()

# ------------------------------------------------------------------------
# 4) Synthetic data & split
# ------------------------------------------------------------------------
def make_synthetic_data(N=1000, D=5, seed=0):
    torch.manual_seed(seed)
    X = torch.randn(N, D)
    w_p, b_p = torch.randn(D), 0.1
    p = torch.sigmoid(X @ w_p + b_p)
    t = torch.bernoulli(p)
    beta0 = torch.randn(D)
    y0 = X @ beta0 + 0.1 * torch.randn(N)
    tau = (X[:, 0] * 2.0).clamp(min=0)
    y1 = y0 + tau + 0.1 * torch.randn(N)
    y = y0 * (1 - t) + y1 * t
    return X.numpy(), t.numpy(), y.numpy(), y0.numpy(), y1.numpy()

X, t, y, y0, y1 = make_synthetic_data(N=400, D=2)
X_train, X_tmp, t_train, t_tmp, y_train, y_tmp = train_test_split(
    X, t, y, test_size=0.4, random_state=42
)
X_val, X_test, t_val, t_test, y_val, y_test = train_test_split(
    X_tmp, t_tmp, y_tmp, test_size=0.5, random_state=42
)

# ------------------------------------------------------------------------
# 5) Optuna tuning & evaluation
# ------------------------------------------------------------------------
def tune_and_eval(model_name):
    def objective(trial):
        lr = trial.suggest_loguniform("lr", 1e-4, 1e-2)
        wd = trial.suggest_loguniform("wd", 1e-6, 1e-3)
        bs = trial.suggest_categorical("bs", [64, 128, 256])
        epochs = trial.suggest_int("epochs", 20, 60)

        if model_name == "tarnet":
            rep1 = trial.suggest_categorical("rep1", [100, 200])
            rep2 = trial.suggest_categorical("rep2", [100, 200])
            head = trial.suggest_int("head", 50, 150, step=50)
            drop = trial.suggest_uniform("drop", 0.0, 0.5)
            trainer = TARNetTrainer(
                input_dim=X_train.shape[1], rep_dims=[rep1, rep2], head_dims=[head], dropout=drop
            )
        elif model_name == "cevae":
            ld = trial.suggest_int("latent_dim", 10, 100)
            hd = trial.suggest_int("hidden_dim", 50, 200)
            nl = trial.suggest_int("num_layers", 2, 4)
            ns = trial.suggest_categorical("num_samples", [10, 50, 100])
            trainer = CEVAETrainer(
                input_dim=X_train.shape[1], latent_dim=ld,
                hidden_dim=hd, num_layers=nl, num_samples=ns
            )
        else:  # dragonnet
            sh = trial.suggest_int("shared_hidden", 50, 200)
            oh = trial.suggest_int("outcome_hidden", 50, 200)
            trainer = DragonNetTrainer(
                input_dim=X_train.shape[1], shared_hidden=sh, outcome_hidden=oh
            )

        return trainer.fit(
            X_train, t_train, y_train,
            X_val, t_val, y_val,
            lr=lr, weight_decay=wd, batch_size=bs, epochs=epochs
        )

    study = optuna.create_study(direction="minimize", study_name=f"{model_name}_tune")
    study.optimize(objective, n_trials=30)
    best = study.best_params
    print(f"Best params for {model_name}: {best}")

    # retrain on train+val
    X_trn = np.vstack([X_train, X_val])
    t_trn = np.concatenate([t_train, t_val])
    y_trn = np.concatenate([y_train, y_val])

    if model_name == "tarnet":
        trainer = TARNetTrainer(
            input_dim=X_trn.shape[1], rep_dims=[best["rep1"], best["rep2"]],
            head_dims=[best["head"]], dropout=best["drop"]
        )
    elif model_name == "cevae":
        trainer = CEVAETrainer(
            input_dim=X_trn.shape[1], latent_dim=best["latent_dim"], hidden_dim=best["hidden_dim"],
            num_layers=best["num_layers"], num_samples=best["num_samples"]
        )
    else:
        trainer = DragonNetTrainer(
            input_dim=X_trn.shape[1], shared_hidden=best["shared_hidden"], outcome_hidden=best["outcome_hidden"]
        )

    trainer.fit(
        X_trn, t_trn, y_trn,
        X_test, t_test, y_test,
        lr=best["lr"], weight_decay=best["wd"], batch_size=best["bs"], epochs=best["epochs"]
    )

    if model_name == "cevae":
        return trainer.predict(X_test)
    else:
        y0p, y1p = trainer.predict(X_test)
        return y1p - y0p

# ------------------------------------------------------------------------
# 6) Run all
# ------------------------------------------------------------------------
if __name__ == "__main__":
    print("Tuning and training TARNet...")
    ite_tarnet = tune_and_eval("tarnet")
    print("TARNet ITE shape:", ite_tarnet.shape)

    print("Tuning and training CEVAE...")
    ite_cevae = tune_and_eval("cevae")
    print("CEVAE ITE shape:", ite_cevae.shape)

    print("Tuning and training DragonNet...")
    ite_dragonnet = tune_and_eval("dragonnet")
    print("DragonNet ITE shape:", ite_dragonnet.shape)

    print("All models complete.")


[I 2025-04-24 20:38:19,755] A new study created in memory with name: tarnet_tune
[I 2025-04-24 20:38:19,932] Trial 0 finished with value: 0.040450528264045715 and parameters: {'lr': 0.0004994689501493047, 'wd': 4.886282511483952e-05, 'bs': 64, 'epochs': 48, 'rep1': 100, 'rep2': 100, 'head': 150, 'drop': 0.42985230220261156}. Best is trial 0 with value: 0.040450528264045715.


Tuning and training TARNet...


[I 2025-04-24 20:38:20,121] Trial 1 finished with value: 0.022942176088690758 and parameters: {'lr': 0.007955498151842575, 'wd': 0.00011449883718167686, 'bs': 64, 'epochs': 53, 'rep1': 100, 'rep2': 200, 'head': 50, 'drop': 0.03220843654610184}. Best is trial 1 with value: 0.022942176088690758.
[I 2025-04-24 20:38:20,312] Trial 2 finished with value: 0.02519337460398674 and parameters: {'lr': 0.0020630840624249716, 'wd': 0.0003783508656240764, 'bs': 64, 'epochs': 53, 'rep1': 200, 'rep2': 100, 'head': 100, 'drop': 0.20117938788096745}. Best is trial 1 with value: 0.022942176088690758.
[I 2025-04-24 20:38:20,425] Trial 3 finished with value: 0.03102118894457817 and parameters: {'lr': 0.00047758478415335106, 'wd': 0.00021975273678424403, 'bs': 128, 'epochs': 43, 'rep1': 100, 'rep2': 100, 'head': 150, 'drop': 0.1347563316704551}. Best is trial 1 with value: 0.022942176088690758.
[I 2025-04-24 20:38:20,572] Trial 4 finished with value: 0.041703782975673676 and parameters: {'lr': 0.0047927428

Best params for tarnet: {'lr': 0.0006750814407541536, 'wd': 2.288879600687233e-06, 'bs': 64, 'epochs': 43, 'rep1': 100, 'rep2': 200, 'head': 150, 'drop': 3.431328529691219e-05}


[I 2025-04-24 20:38:25,187] A new study created in memory with name: cevae_tune
INFO 	 Training with 1 minibatches per epoch
[I 2025-04-24 20:38:25,336] Trial 0 finished with value: 13.590973663330079 and parameters: {'lr': 0.0002680757208310392, 'wd': 7.53440203370384e-06, 'bs': 256, 'epochs': 30, 'latent_dim': 25, 'hidden_dim': 188, 'num_layers': 2, 'num_samples': 10}. Best is trial 0 with value: 13.590973663330079.
INFO 	 Training with 4 minibatches per epoch


TARNet ITE shape: (80,)
Tuning and training CEVAE...


[I 2025-04-24 20:38:25,870] Trial 1 finished with value: 10.820096397399903 and parameters: {'lr': 0.0004513322364411776, 'wd': 2.4012210537286576e-05, 'bs': 64, 'epochs': 35, 'latent_dim': 13, 'hidden_dim': 105, 'num_layers': 4, 'num_samples': 10}. Best is trial 1 with value: 10.820096397399903.
INFO 	 Training with 4 minibatches per epoch
[I 2025-04-24 20:38:26,387] Trial 2 finished with value: 11.075271479288737 and parameters: {'lr': 0.0023287815570887972, 'wd': 5.022581646361076e-06, 'bs': 64, 'epochs': 36, 'latent_dim': 34, 'hidden_dim': 58, 'num_layers': 4, 'num_samples': 10}. Best is trial 1 with value: 10.820096397399903.
INFO 	 Training with 2 minibatches per epoch
[I 2025-04-24 20:38:26,743] Trial 3 finished with value: 17.482425181070962 and parameters: {'lr': 0.00017177227711979916, 'wd': 0.0005270547830928239, 'bs': 128, 'epochs': 36, 'latent_dim': 79, 'hidden_dim': 94, 'num_layers': 4, 'num_samples': 100}. Best is trial 1 with value: 10.820096397399903.
INFO 	 Training w

Best params for cevae: {'lr': 0.0028436412555809133, 'wd': 2.7020290041516205e-06, 'bs': 128, 'epochs': 38, 'latent_dim': 10, 'hidden_dim': 82, 'num_layers': 4, 'num_samples': 50}


INFO 	 Evaluating 1 minibatches
[I 2025-04-24 20:38:36,952] A new study created in memory with name: dragonnet_tune
[I 2025-04-24 20:38:37,145] Trial 0 finished with value: 0.02311547100543976 and parameters: {'lr': 0.009035165710312931, 'wd': 2.892278644424301e-05, 'bs': 64, 'epochs': 36, 'shared_hidden': 159, 'outcome_hidden': 196}. Best is trial 0 with value: 0.02311547100543976.


CEVAE ITE shape: ()
Tuning and training DragonNet...


[I 2025-04-24 20:38:37,277] Trial 1 finished with value: 0.02035456709563732 and parameters: {'lr': 0.00047778814865433595, 'wd': 0.00033016151836169464, 'bs': 128, 'epochs': 45, 'shared_hidden': 123, 'outcome_hidden': 125}. Best is trial 1 with value: 0.02035456709563732.
[I 2025-04-24 20:38:37,463] Trial 2 finished with value: 0.02062293514609337 and parameters: {'lr': 0.008227575578011268, 'wd': 8.526111829171785e-06, 'bs': 128, 'epochs': 56, 'shared_hidden': 110, 'outcome_hidden': 200}. Best is trial 1 with value: 0.02035456709563732.
[I 2025-04-24 20:38:37,600] Trial 3 finished with value: 0.020305996760725975 and parameters: {'lr': 0.0005777826642197931, 'wd': 0.0008023549564377211, 'bs': 64, 'epochs': 39, 'shared_hidden': 80, 'outcome_hidden': 66}. Best is trial 3 with value: 0.020305996760725975.
[I 2025-04-24 20:38:37,702] Trial 4 finished with value: 0.7565242648124695 and parameters: {'lr': 0.00022333130290256336, 'wd': 0.0008888479896610172, 'bs': 256, 'epochs': 55, 'shared

Best params for dragonnet: {'lr': 0.0007446133996675199, 'wd': 0.0005419686217158935, 'bs': 64, 'epochs': 40, 'shared_hidden': 85, 'outcome_hidden': 110}
DragonNet ITE shape: (80, 1)
All models complete.


In [46]:
ite_dragonnet.mean()

0.98210394

In [47]:
ite_cevae.mean()

0.05340328

In [48]:
ite_tarnet.mean()

0.9251715