## Experiments for continuous treatment effect (ADRF) estimation comparison

In [1]:
import sys, os

# add the project root to sys.path
root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if root not in sys.path:
    sys.path.insert(0, root)

from data_causl.utils import *
from data_causl.data import *
from frengression import *

device = torch.device('cpu')

import CausalEGM as cegm
# import the module
from models import *

import numpy as np
import pickle
import os
from tqdm import tqdm

from matplotlib import pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


# Seed every random number generator the experiments rely on, so that
# Restart Kernel -> Run All reproduces the same numbers. Seeding NumPy alone
# is not enough: torch keeps its own generator, which stays unseeded unless
# torch.manual_seed is called as well.
import torch  # already relied on above for `torch.device`; imported explicitly for the seeding call

np.random.seed(42)
torch.manual_seed(42)

# Sample sizes.
# NOTE(review): presumably training / prediction sample counts — confirm
# against the cells that consume them.
n_tr = 1000
n_p = 1000


## Twins

In [2]:
# Hyper-parameter dictionary passed to CausalEGM (as `params=`) for the
# Semi_Twins run. Key order and values are kept exactly as configured.
params = {
    # data set / output location
    "dataset": "Semi_Twins",
    "output_dir": "./",
    # dimensions and treatment range
    "v_dim": 50,
    "z_dims": [1, 1, 1, 7],
    "x_min": 0,
    "x_max": 3,
    # optimisation settings
    "lr": 0.0002,
    "bs": 32,
    "alpha": 1,
    "beta": 1,
    "gamma": 10,
    "g_d_freq": 5,
    # layer widths of the individual networks (each list is its own object)
    "g_units": [64] * 5,
    "e_units": [64] * 5,
    "f_units": [64, 32, 8],
    "h_units": [64, 32, 8],
    "dz_units": [64, 32, 8],
    "dv_units": [64, 32, 8],
    # behaviour flags
    "binary_treatment": False,
    "use_z_rec": True,
    "use_v_gan": False,
    "save_res": False,
    "save_model": False,
}

# Dose grid used for all predictions: 0 up to (but excluding) 6 in steps of 0.1.
x_vals = np.arange(0, 6, 0.1)
# Same grid as a float32 column tensor of shape (len(x_vals), 1).
x_tensor = torch.tensor(x_vals[:, None], dtype=torch.float32)

In [None]:
# --- Experiment configuration ---
nrep = 30                    # number of fit/predict repetitions
n_tr = 1000
num_iters = 1000             # iteration count handed to the Frengression training call
binary_intervention = False  # forwarded to Frengression as `x_binary`

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# One accumulator per Frengression target, in the order
# [2.5% quantile, mean, 97.5% quantile]; CausalEGM curves go in a flat list.
fr_preds = [[] for _ in range(3)]
cegm_preds = []

# Reference ADRF returned by Semi_Twins_adrf for the dose grid x_vals
ground_truth_adrf = Semi_Twins_adrf(x_vals)

# Begin experiment: every repetition refits both models from scratch and
# stores their predicted curves on the dose grid (`x_tensor` / `x_vals`).
for rep in tqdm(range(nrep)):
    # Fetch the data through the sampler's load_all(); this is repeated in
    # every repetition.
    # NOTE(review): if load_all() is deterministic this could be hoisted out
    # of the loop — confirm before changing.
    x_train, y_train, z_train = Semi_Twins_sampler(path="../data_causl").load_all()

    # Convert to float32 tensors for the Frengression model
    x_tr = torch.tensor(x_train, dtype=torch.float32)
    y_tr = torch.tensor(y_train, dtype=torch.float32)
    z_tr = torch.tensor(z_train, dtype=torch.float32)

    # --- Frengression Model ---
    # Initialize with the column counts of x, y and z, then train
    fr_model = Frengression(
        x_tr.shape[1], y_tr.shape[1], z_tr.shape[1],
        noise_dim=1, num_layer=3, hidden_dim=100,
        device=device, x_binary=binary_intervention
    )
    fr_model.train_y(x_tr, z_tr, y_tr, num_iters=num_iters, lr=1e-4, print_every_iter=400)

    # Predict on the dose grid (2.5% quantile, mean, 97.5% quantile).
    # Expected: one tensor per requested target, one row per entry of
    # x_tensor — TODO confirm against Frengression.predict_causal.
    fr_pred = fr_model.predict_causal(
        x_tensor, target=[0.025, "mean", 0.975], sample_size=1000
    )

    # Append each target's curve to its accumulator as a flat NumPy array.
    for i, quantile_pred in enumerate(fr_pred):
        # detach().cpu() first: Tensor.numpy() raises for tensors that live on
        # a CUDA device or that require grad, and `device` is 'cuda' whenever
        # a GPU is available. Both calls are no-ops for plain CPU tensors.
        fr_preds[i].append(quantile_pred.detach().cpu().numpy().flatten())

    # --- CausalEGM Model ---
    # A different seed per repetition (123 + rep) is passed to CausalEGM.
    cegm_model = CausalEGM(params=params, random_seed=123 + rep)
    cegm_model.train(data=[x_train, y_train, z_train], n_iter=1000, verbose=False)

    # Predict the ADRF on the same dose grid
    cegm_pred = cegm_model.getADRF(x_list=x_vals.tolist())
    cegm_preds.append(cegm_pred)

# Stack the per-repetition curves so that repetitions run along axis 1.
# NOTE(review): assumes every stored curve is 1-D over the dose grid, giving
# arrays of shape (n_doses, nrep) — confirm for the CausalEGM output.
fr_preds = [np.array(curves).T for curves in fr_preds]
cegm_preds = np.array(cegm_preds).T

# Frengression: average each target across repetitions. The band is therefore
# the mean of the per-fit 2.5% / 97.5% quantile curves, not a quantile taken
# across repetitions.
fr_ci_lower, fr_mean_pred, fr_ci_upper = (stacked.mean(axis=1) for stacked in fr_preds)

# CausalEGM: mean curve plus the 2.5% / 97.5% quantiles across repetitions.
cegm_mean_pred = cegm_preds.mean(axis=1)
cegm_ci_lower, cegm_ci_upper = np.quantile(cegm_preds, [0.025, 0.975], axis=1)


  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1: loss 1.1069,	loss_y 0.3359, 0.3541, 0.0365,	loss_eta 0.7710, 0.8010, 0.0598


In [None]:
import os
import json
import numpy as np

# --- Persist the ADRF summary; run this after the arrays below exist ---
# Inputs: x_vals, fr_ci_lower / fr_mean_pred / fr_ci_upper and
#         cegm_ci_lower / cegm_mean_pred / cegm_ci_upper.
# NOTE(review): expected to be 1-D NumPy arrays over the dose grid — confirm.

# 1) Collect each model's (lower, mean, upper) curves and convert everything
#    to plain Python lists so that json can serialise it.
_bands = {
    "Frengression": (fr_ci_lower, fr_mean_pred, fr_ci_upper),
    "CausalEGM": (cegm_ci_lower, cegm_mean_pred, cegm_ci_upper),
}
summary = {"x_vals": x_vals.tolist()}
summary.update(
    (model, {"ci_lower": low.tolist(), "mean": mid.tolist(), "ci_upper": high.tolist()})
    for model, (low, mid, high) in _bands.items()
)

# 2) Create the target folder if it is missing
output_dir = "result/continuous"
os.makedirs(output_dir, exist_ok=True)

# 3) Dump the summary as indented JSON
outfile = os.path.join(output_dir, "twins.json")
with open(outfile, "w") as f:
    json.dump(summary, f, indent=4)

print(f"Saved ADRF summary to {outfile}")
