## Experiments for binary treatment effect estimation comparison

In [1]:
import sys, os

# add the project root to sys.path
root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if root not in sys.path:
    sys.path.insert(0, root)

from data_causl.utils import *
from data_causl.data import *
from frengression import *

device = torch.device('cpu')

import CausalEGM as cegm
# import the module
from models import *

import numpy as np
import pickle
import os
from tqdm import tqdm

from matplotlib import pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)



np.random.seed(42)
n_tr = 1000
n_p = 1000

nI = 2
nX = 2
nO = 2
nS= 2
p = nI+nX+nO+nS
ate = 2
beta_cov = 0
strength_instr = 1
strength_conf = 1
strength_outcome = 1
binary_intervention=True

## Example of hyperparameter tuning

In [2]:
from functools import lru_cache

def tune_and_eval(model_name,
                  X_train, t_train, y_train,
                  X_val,   t_val,   y_val,
                  X_test,  t_test,  y_test,
                  provided_params=None,
                  n_trials=20):
    """
    If best_params is None: runs Optuna, returns (ITE_array, best_params).
    If best_params is given: skips Optuna, returns ITE_array only.
    """
    # 1) hyperparam search
    if provided_params is None:
        import optuna
        study = optuna.create_study(direction="minimize",
                                    study_name=f"{model_name}_tune")
        def objective(trial):
            # common
            lr     = trial.suggest_loguniform("lr", 1e-5, 1e-2)
            wd     = trial.suggest_loguniform("wd", 1e-5, 1e-2)
            bs     = trial.suggest_categorical("bs", [32, 128, 256])
            epochs = trial.suggest_int("epochs", 300, 800)

            # model‐specific
            if model_name == "tarnet":
                rep1 = trial.suggest_int("rep1", 20, 50 )
                rep2 = trial.suggest_int("rep2", 50, 100)
                head = trial.suggest_int("head", 50, 100)
                drop = trial.suggest_uniform("drop", 0.0, 0.001)
                trainer = TARNetTrainer(X_train.shape[1], [rep1,rep2], [head], drop)

            elif model_name == "cfrnet":
                rep1   = trial.suggest_int("rep1", 50, 200)
                rep2   = trial.suggest_int("rep2", 50, 200)
                head   = trial.suggest_int("head", 50, 200)
                drop   = trial.suggest_uniform("drop", 0.0, 0.001)
                ipm_w  = trial.suggest_loguniform("ipm_weight", 0.01, 10.0)
                trainer = CFRNetTrainer(X_train.shape[1], [rep1,rep2], [head], drop, ipm_w)

            elif model_name == "cevae":
                ld = trial.suggest_int("latent_dim", 10, 200)
                hd = trial.suggest_int("hidden_dim", 20, 400)
                nl = trial.suggest_int("num_layers", 2, 5)     # note: 2→5 to avoid pop error
                ns = trial.suggest_categorical("num_samples", [10,50,100,200])
                trainer = CEVAETrainer(X_train.shape[1], ld, hd, nl, ns)

            else:  # dragonnet
                sh = trial.suggest_int("shared_hidden", 50, 200)
                oh = trial.suggest_int("outcome_hidden", 50, 200)
                trainer = DragonNetTrainer(X_train.shape[1], sh, oh)

            return trainer.fit(
                X_train, t_train, y_train,
                X_val,   t_val,   y_val,
                lr=lr, weight_decay=wd,
                batch_size=bs, epochs=epochs
            )

        study.optimize(objective, n_trials=n_trials)
        best_params = study.best_params
        print(f"🔍 Best params for {model_name}: {best_params}")
    else:
        best_params = provided_params
    # 2) retrain on combined train+val
    X_trn = np.vstack([X_train, X_val])
    t_trn = np.concatenate([t_train, t_val])
    y_trn = np.concatenate([y_train, y_val])

    if model_name == "tarnet":
        trainer = TARNetTrainer(
            X_trn.shape[1],
            [best_params['rep1'], best_params['rep2']],
            [best_params['head']],
            best_params['drop']
        )
    elif model_name == "cfrnet":
        trainer = CFRNetTrainer(
            X_trn.shape[1],
            [best_params['rep1'], best_params['rep2']],
            [best_params['head']],
            best_params['drop'],
            best_params['ipm_weight']
        )
    elif model_name == "cevae":
        trainer = CEVAETrainer(
            X_trn.shape[1],
            best_params['latent_dim'],
            best_params['hidden_dim'],
            best_params['num_layers'],
            best_params['num_samples']
        )
    else:
        trainer = DragonNetTrainer(
            X_trn.shape[1],
            best_params['shared_hidden'],
            best_params['outcome_hidden']
        )

    trainer.fit(
        X_trn, t_trn, y_trn,
        X_test, t_test, y_test,
        lr=best_params['lr'],
        weight_decay=best_params['wd'],
        batch_size=best_params['bs'],
        epochs=best_params['epochs']
    )

    if model_name == "cevae":
        ite = trainer.predict(X_test)
    else:
        y0p, y1p = trainer.predict(X_test)
        ite = y1p - y0p

    return (ite, best_params) if provided_params is None else ite


## Fitting synthetic data generated by causl

### Data generation

In [None]:
nrep = 30  # Number of repetitions
n_tr = 1000  # Training sample size
n_val = 400
n_te = 400  # Testing sample size
strength_instr_values = np.arange(1,4.5,1)  # Varying strength of instrumental variables
nI = 4 # Fixed number of instrumental variables
nX = 3
nO = 3
nS = 0
binary_intervention = True
num_iters = 800  # Fixed number of training iterations
ate = 2
strength_conf = 1
strength_outcome = 1

# Initialize tracker for strength_instr
tracker = {strength_instr: {"fr": [], "dr": [], "causalegm":[], "tarnet":[], "cfrnet":[], "cevae":[], "dragonnet":[]}
           for strength_instr in strength_instr_values}
best_hps = {model: None for model in ["tarnet","cfrnet","cevae","dragonnet"]}
# Begin loop over strength_instr
for strength_instr in strength_instr_values:
    print(f"Running experiments for strength_instr = {strength_instr}")
    p = nI + nX + nO + nS  # Update the number of covariates
    
    for rep in tqdm(range(nrep)):
        # Generate training and testing data
        df_tr = generate_data_causl(n=n_tr, nI=nI, nX=nX, nO=nO, nS=nS, ate=ate, 
                                    beta_cov=beta_cov, strength_instr=strength_instr, 
                                    strength_conf=strength_conf, 
                                    strength_outcome=strength_outcome, 
                                    binary_intervention=binary_intervention)
        z_tr = torch.tensor(df_tr[[f"X{i}" for i in range(1, p + 1)]].values, dtype=torch.float32)
        x_tr = torch.tensor(df_tr['A'].values, dtype=torch.int32).view(-1, 1) if binary_intervention else \
            torch.tensor(df_tr['A'].values, dtype=torch.float32).view(-1, 1)
        y_tr = torch.tensor(df_tr['y'].values, dtype=torch.float32).view(-1, 1)
        
        z_tr_np = df_tr[[f"X{i}" for i in range(1, p + 1)]].values
        x_tr_np = df_tr['A'].values
        y_tr_np = df_tr['y'].values

        df_val = generate_data_causl(n=n_val, nI=nI, nX=nX, nO=nO, nS=nS, ate=ate, 
                                    beta_cov=beta_cov, strength_instr=strength_instr, 
                                    strength_conf=strength_conf, 
                                    strength_outcome=strength_outcome, 
                                    binary_intervention=binary_intervention)


        z_val_np = df_val[[f"X{i}" for i in range(1, p + 1)]].values
        x_val_np = df_val['A'].values
        y_val_np = df_val['y'].values

        df_te = generate_data_causl(n=n_te, nI=nI, nX=nX, nO=nO, nS=nS, ate=ate, 
                                    beta_cov=beta_cov, strength_instr=strength_instr, 
                                    strength_conf=strength_conf, 
                                    strength_outcome=strength_outcome, 
                                    binary_intervention=binary_intervention)

        z_te_np = df_te[[f"X{i}" for i in range(1, p + 1)]].values
        x_te_np = df_te['A'].values
        y_te_np = df_te['y'].values
        z_te = torch.tensor(z_te_np, dtype=torch.float32)

        model = Frengression(x_dim = x_tr.shape[1], y_dim = 1, z_dim =z_tr.shape[1], 
                             noise_dim=1, num_layer=3, hidden_dim=100, 
                             device=device, x_binary=binary_intervention, z_binary_dims=0)

        # Train Frengression model
        model.train_y(x=x_tr,
                      z=z_tr, 
                      y=y_tr, 
                      num_iters=num_iters, lr=1e-4, print_every_iter=1000)

        # Sample model distributions
        P0 = model.sample_causal_margin(torch.tensor([0], dtype=torch.int32), sample_size=n_te).numpy().reshape(-1, 1)
        P1 = model.sample_causal_margin(torch.tensor([1], dtype=torch.int32), sample_size=n_te).numpy().reshape(-1, 1)
        ate_fr = np.mean(P1) - np.mean(P0)

        # DR Estimation
        ate_dr, _ = dr_ate(x_tr_np, y_tr_np, z_tr_np ,x_te_np, y_te_np, z_te_np)

        for model in ["tarnet","cfrnet","cevae","dragonnet"]:
            if rep == 0:
                ite, best_hps[model] = tune_and_eval(
                    model,
                    z_tr_np, x_tr_np, y_tr_np,
                    z_val_np, x_val_np, y_val_np,
                    z_te_np, x_te_np,y_te_np,
                    provided_params=None,
                    n_trials=20
                )
            else:
                ite = tune_and_eval(
                    model,
                    z_tr_np, x_tr_np, y_tr_np,
                    z_val_np, x_val_np, y_val_np,
                    z_te_np, x_te_np,y_te_np,
                    provided_params=best_hps[model]
                )
            tracker[strength_instr][model].append(ite.mean())


        cegm_params = {'dataset': 'Semi_acic', 
                        'output_dir': '.', 
                        'v_dim': z_tr.shape[1], 
                        'z_dims': [1, 1, 1, 1], 
                        'lr': 0.0002, 
                        'alpha': 1, 
                        'beta': 1, 
                        'gamma': 10, 
                        'g_d_freq': 5, 
                        'g_units': [64, 64, 64, 64, 64], 
                        'e_units': [64, 64, 64, 64, 64], 
                        'f_units': [64, 32, 8], 
                        'h_units': [64, 32, 8], 
                        'dz_units': [64, 32, 8], 
                        'dv_units': [64, 32, 8], 'save_res': False, 'save_model': False, 'binary_treatment': True, 'use_z_rec': True, 'use_v_gan': True}
        egm_model = cegm.CausalEGM(params=cegm_params, random_seed=42)
        egm_model.train(data=[x_tr,y_tr,z_tr],n_iter=1000, verbose=False)
        ate_causalegm=egm_model.getCATE(z_te).mean()

        # Log results
        tracker[strength_instr]["fr"].append(ate_fr)
        tracker[strength_instr]["dr"].append(ate_dr)
        tracker[strength_instr]["causalegm"].append(ate_causalegm)
    
        

Running experiments for strength_instr = 1.0


  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1: loss 2.3854,	loss_y 1.6279, 1.6506, 0.0455,	loss_eta 0.7576, 0.7917, 0.0683


[I 2025-04-26 07:30:03,303] A new study created in memory with name: tarnet_tune
Exception ignored in: <function _xla_gc_callback at 0x351d55580>
Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.11/site-packages/jax/_src/lib/__init__.py", line 96, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 
[I 2025-04-26 07:30:08,798] Trial 0 finished with value: 0.7342360615730286 and parameters: {'lr': 0.004206814077835359, 'wd': 0.0014372372254375487, 'bs': 256, 'epochs': 716, 'rep1': 29, 'rep2': 93, 'head': 100, 'drop': 0.0007061397880518154}. Best is trial 0 with value: 0.7342360615730286.
[I 2025-04-26 07:30:12,186] Trial 1 finished with value: 1.122741937637329 and parameters: {'lr': 1.2212100789764055e-05, 'wd': 0.0020118063827876172, 'bs': 128, 'epochs': 384, 'rep1': 42, 'rep2': 84, 'head': 55, 'drop': 0.0002811096668861708}. Best is trial 0 with value: 0.7342360615730286.


In [None]:
import json
output_dir = "result/binary"
os.makedirs(output_dir, exist_ok=True)
tracker_serializable = {
    str(k): [float(x) for x in v_dict.get("fr",[])] 
              + []  # (we'll overwrite below) 
    for k, v_dict in tracker.items()
}
# actually build full dict:
tracker_serializable = {
    str(k): {
        model: [float(x) for x in v_list]
        for model, v_list in v_dict.items()
    }
    for k, v_dict in tracker.items()
}

# 3) write it out
with open(os.path.join(output_dir, "synthetic_1k.json"), "w") as f:
    json.dump(tracker_serializable, f, indent=4)


## IHDP

In [None]:
# Experiment parameters
binary_intervention = True
num_iters = 1000
p = 25
z_binary_dims = 19
path = '/Users/linyingyang/Documents/Project/frengression/frengression/data_causl'

# Initialize tracker for valid trials

valid_trials = 0
max_trials = 100  # We want results from 100 valid trials
trial = 0


tracker = {"fr": [], "dr": [], "causalegm":[], "tarnet":[], "cfrnet":[], "cevae":[], "dragonnet":[]}
best_hps = {model: None for model in ["tarnet","cfrnet","cevae","dragonnet"]}

while valid_trials < max_trials:
    print(f"Checking trial = {trial}")
    df_tr, df_te = process_data(path=path, trial=trial)

    # Skip this trial if any y_factual in df_tr exceeds 20
    if (df_tr['y_factual'] > 20).any():
        print(f"Skipping trial {trial} because y_factual > 20")
        trial += 1
        continue

    print(f"Running on valid trial = {trial}")
    
    # Prepare tensors for training
    z_tr = torch.tensor(df_tr[[f"X{i}" for i in range(1, p + 1)]].values, dtype=torch.float32)
    x_tr = torch.tensor(df_tr['treatment'].values, dtype=torch.float32).view(-1, 1)
    y_tr = torch.tensor(df_tr['y_factual'].values, dtype=torch.float32).view(-1, 1)
    ate_sample = torch.tensor(np.mean(df_tr['mu1'].values - df_tr['mu0'].values), dtype=torch.float32).view(-1, 1)
    
    # DR ATE estimation
    z_tr_np = df_tr[[f"X{i}" for i in range(1, p + 1)]].values
    x_tr_np = df_tr['treatment'].values
    y_tr_np = df_tr['y_factual'].values

    z_te_np = df_te[[f"X{i}" for i in range(1, p + 1)]].values
    x_te_np = df_te['treatment'].values
    y_te_np = df_te['y_factual'].values
    hat_dr, _ = dr_ate(x_tr_np, y_tr_np, z_tr_np, x_te_np, y_te_np, z_te_np)

    # Initialize Frengression model
    model = Frengression(x_tr.shape[1], y_tr.shape[1], z_tr.shape[1], 
                         noise_dim=1, num_layer=3, hidden_dim=400, 
                         device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                         x_binary=binary_intervention, z_binary_dims=19)
    

    model.train_y(x_tr, z_tr, y_tr, num_iters=num_iters, lr=1e-4, print_every_iter=400)


    # Sample model distributions
    P0 = model.sample_causal_margin(torch.tensor([0], dtype=torch.int32), sample_size=n_te).numpy().reshape(-1, 1)
    P1 = model.sample_causal_margin(torch.tensor([1], dtype=torch.int32), sample_size=n_te).numpy().reshape(-1, 1)
    ate_fr = np.mean(P1) - np.mean(P0)

    # DR Estimation
    ate_dr, _ = dr_ate(x_tr_np, y_tr_np, z_tr_np ,x_te_np, y_te_np, z_te_np)

    for model in ["tarnet","cfrnet","cevae","dragonnet"]:
        if valid_trials == 0:
            ite, best_hps[model] = tune_and_eval(
                model,
                z_tr_np, x_tr_np, y_tr_np,
                z_tr_np, x_tr_np, y_tr_np,
                z_te_np, x_te_np,y_te_np,
                provided_params=None,
                n_trials=20
            )
        else:
            ite = tune_and_eval(
                model,
                z_tr_np, x_tr_np, y_tr_np,
                z_tr_np, x_tr_np, y_tr_np,
                z_te_np, x_te_np,y_te_np,
                provided_params=best_hps[model]
            )
        tracker[model].append(ite.mean())


    cegm_params = {'dataset': 'Semi_acic', 
                    'output_dir': '.', 
                    'v_dim': z_tr.shape[1], 
                    'z_dims': [1, 1, 1, 1], 
                    'lr': 0.0002, 
                    'alpha': 1, 
                    'beta': 1, 
                    'gamma': 10, 
                    'g_d_freq': 5, 
                    'g_units': [64, 64, 64, 64, 64], 
                    'e_units': [64, 64, 64, 64, 64], 
                    'f_units': [64, 32, 8], 
                    'h_units': [64, 32, 8], 
                    'dz_units': [64, 32, 8], 
                    'dv_units': [64, 32, 8], 'save_res': False, 'save_model': False, 'binary_treatment': True, 'use_z_rec': True, 'use_v_gan': True}
    egm_model = cegm.CausalEGM(params=cegm_params, random_seed=42)
    egm_model.train(data=[x_tr,y_tr,z_tr],n_iter=1000, verbose=False)
    ate_causalegm=egm_model.getCATE(z_te).mean()
    tracker['causalegm'].append(ate_causalegm)

    # Increment valid trials counter and move to the next trial
    valid_trials += 1
    trial += 1



Checking trial = 0
Running on valid trial = 0
Epoch 1: loss 3.9422,	loss_y 3.2075, 3.2219, 0.0289,	loss_eta 0.7347, 0.7670, 0.0646
Epoch 400: loss 1.0614,	loss_y 0.4755, 0.9720, 0.9931,	loss_eta 0.5860, 1.1017, 1.0315
Epoch 800: loss 0.9473,	loss_y 0.3557, 0.7377, 0.7640,	loss_eta 0.5916, 1.0262, 0.8692


[I 2025-04-26 07:28:26,313] A new study created in memory with name: tarnet_tune
[I 2025-04-26 07:28:28,582] Trial 0 finished with value: 0.05483689531683922 and parameters: {'lr': 0.006491187885625268, 'wd': 0.004936980702479022, 'bs': 256, 'epochs': 392, 'rep1': 36, 'rep2': 92, 'head': 99, 'drop': 0.0006358724407596506}. Best is trial 0 with value: 0.05483689531683922.
[I 2025-04-26 07:28:32,857] Trial 1 finished with value: 3.3733069896698 and parameters: {'lr': 2.307907865982853e-05, 'wd': 0.0025256327336519034, 'bs': 32, 'epochs': 309, 'rep1': 20, 'rep2': 59, 'head': 92, 'drop': 0.0003527484933601588}. Best is trial 0 with value: 0.05483689531683922.
[I 2025-04-26 07:28:39,931] Trial 2 finished with value: 0.25672170519828796 and parameters: {'lr': 0.00039374419049185164, 'wd': 0.0010826011185339763, 'bs': 32, 'epochs': 536, 'rep1': 22, 'rep2': 50, 'head': 71, 'drop': 0.0001100284911583589}. Best is trial 0 with value: 0.05483689531683922.
[I 2025-04-26 07:28:41,594] Trial 3 finis

KeyboardInterrupt: 

In [None]:
output_dir = "result/binary"
os.makedirs(output_dir, exist_ok=True)
tracker_serializable = {
    str(k): [float(x) for x in v_dict.get("fr",[])] 
              + []  # (we'll overwrite below) 
    for k, v_dict in tracker.items()
}
# actually build full dict:
tracker_serializable = {
    str(k): {
        model: [float(x) for x in v_list]
        for model, v_list in v_dict.items()
    }
    for k, v_dict in tracker.items()
}

# 3) write it out
with open(os.path.join(output_dir, "ihdp.json"), "w") as f:
    json.dump(tracker_serializable, f, indent=4)ate_fr

3.996327