### Heterogeneous Treatment Effects with Sequential Paths

**Data Generation**  
- Paths: Each $X = (X_1, \dots, X_T)$ is a sequence of length $T \in \{5, \dots, 30\}$, with actions $X_t \in \{0, 1, 2, 3\}$.  
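- Alpha, Beta (the code below min–max rescales both to $[0,1]$ over the population before outcomes are generated):  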
  $\alpha(X) = \sum_{t=1}^{T-1} \big[\mathbf{1}\{X_t=1, X_{t+1}=3\} \cdot 5 + \mathbf{1}\{X_t=2, X_{t+1}=2\} \cdot 10 \big] + 0.2 \, T$,  
  $\beta(X) = 5 - 0.2 \, (\#\text{Up}) - 0.1 \, T$, where $\#\text{Up} = \sum_{t=1}^{T} \mathbf{1}\{X_t = 0\}$ is the number of steps with action $0$.  
- Outcome:  
  $Y = \alpha(X) + \beta(X)\,W + \varepsilon$,  
  $W \sim \text{Bernoulli}(0.5)$,  
  $\varepsilon \sim \mathcal{N}(0,1)$.

**Models**  
1. DNN: Flatten the path into a one-hot encoding, input into an MLP predicting $\alpha$ and $\beta$.  
2. LSTM: Embed the sequence, process it with an LSTM, output $\alpha$ and $\beta$.  
3. Transformer: Embed using positional encoding, apply self-attention, output $\alpha$ and $\beta$.

**Double-Robust ATE**  
$$
\widehat{\text{ATE}}
=
\frac{1}{n} \sum_{i=1}^{n} \biggl[\bigl(\hat{\alpha}_i + \hat{\beta}_i\bigr) 
+ \frac{W_i}{p} \bigl(Y_i - (\hat{\alpha}_i + \hat{\beta}_i)\bigr) 
- \hat{\alpha}_i 
- \frac{1-W_i}{1-p} \bigl(Y_i - \hat{\alpha}_i\bigr)\biggr], 
\quad p=0.5.
$$

In [31]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

from econml.dml import CausalForestDML
from lightgbm import LGBMRegressor

############################################################################
# 0) HYPERPARAMETERS
############################################################################

# Data-generation and optimisation settings shared by every model.
GEN_H = dict(
    max_len=30,          # longest path; also the padded / one-hot length
    action_dim=4,        # number of distinct actions per step
    pop_size=20000,      # size of the synthetic population
    sample_train=10000,  # rows drawn for training
    sample_test=2000,    # rows drawn for testing
    batch_size=256,      # mini-batch size for training and evaluation
    lr=1e-3,             # Adam learning rate
)

# Architecture sizes and epoch counts for the three neural models.
NN_H = dict(
    # feed-forward network on the flattened one-hot path
    dnn_hidden_dim=16,
    dnn_num_layers=4,
    dnn_epochs=100,   # smaller number for demonstration
    # LSTM on embedded actions
    lstm_embed_dim=4,
    lstm_hidden_dim=8,
    lstm_epochs=150,
    # Transformer encoder on embedded actions
    tf_d_model=4,
    tf_nhead=1,
    tf_num_layers=1,
    tf_epochs=50,
)

# Settings for the CausalForestDML baseline and its LightGBM nuisance models.
CF_H = dict(
    n_estimators=200,   # trees in the causal forest
    max_depth=None,     # forwarded to CausalForestDML as-is
    num_leaves=31,      # used for the LGBMRegressor nuisance models
    criterion="mse",    # forwarded to CausalForestDML as-is
)


############################################################################
# 1) DATA GENERATION
############################################################################

def generate_random_path(max_len=30, action_dim=4):
    """Draw one random action path from NumPy's global random state.

    The number of steps is uniform on 5..max_len (inclusive) and every step
    is an integer action uniform on 0..action_dim-1.

    Returns:
        1-D integer ndarray holding the sampled actions.
    """
    n_steps = np.random.randint(low=5, high=max_len + 1)
    actions = np.random.randint(low=0, high=action_dim, size=n_steps)
    return actions

def compute_alpha_beta_from_path(path):
    """Compute the baseline (alpha) and treatment effect (beta) of a path.

    alpha = 5 * #{t: path[t]==1, path[t+1]==3}
          + 10 * #{t: path[t]==2, path[t+1]==2}
          + 0.2 * len(path)
    beta  = 5 - 0.2 * #{t: path[t]==0} - 0.1 * len(path)

    Args:
        path: 1-D sequence of integer actions (list, tuple or ndarray).

    Returns:
        (alpha_val, beta_val) as Python floats.
    """
    # Coerce to an array first: on a plain list `path == 0` is a single
    # False, which would silently make the action-0 count zero.
    steps = np.asarray(path)
    length = len(steps)

    # Consecutive pairs (path[t], path[t+1]) without a Python loop.
    current, following = steps[:-1], steps[1:]
    n_one_three = int(np.count_nonzero((current == 1) & (following == 3)))
    n_two_two = int(np.count_nonzero((current == 2) & (following == 2)))
    alpha_val = 5.0 * n_one_three + 10.0 * n_two_two + 0.2 * length

    n_up = int(np.count_nonzero(steps == 0))
    beta_val = 5.0 - 0.2 * n_up - 0.1 * length
    return alpha_val, beta_val

def build_population(n=20000, max_len=30, action_dim=4):
    """Create a synthetic population of `n` random paths.

    Returns:
        List of (path, alpha, beta) tuples, one per individual, where alpha
        and beta come from compute_alpha_beta_from_path(path).
    """
    def _one_individual():
        # Draw the path first, then derive its two coefficients from it.
        path = generate_random_path(max_len, action_dim)
        alpha_val, beta_val = compute_alpha_beta_from_path(path)
        return (path, alpha_val, beta_val)

    return [_one_individual() for _ in range(n)]

def minmax_normalize_population(pop):
    """Rescale alpha and beta of every population entry to roughly [0, 1].

    Args:
        pop: list of (path, alpha, beta) tuples.

    Returns:
        (new_pop, (a_min, a_max, b_min, b_max)) where new_pop has the same
        paths with min-max scaled alpha/beta, and the second item holds the
        bounds used for the scaling.
    """
    alphas, betas = [], []
    for _, a, b in pop:
        alphas.append(a)
        betas.append(b)

    a_lo, a_hi = min(alphas), max(alphas)
    b_lo, b_hi = min(betas), max(betas)

    # The 1e-9 keeps the division finite when all values coincide.
    a_span = a_hi - a_lo + 1e-9
    b_span = b_hi - b_lo + 1e-9

    scaled = [
        (path, (a - a_lo) / a_span, (b - b_lo) / b_span)
        for path, a, b in pop
    ]
    return scaled, (a_lo, a_hi, b_lo, b_hi)

def compute_true_ate(pop):
    """Return the population average of beta (third field of each entry)."""
    treatment_effects = [entry[2] for entry in pop]
    return np.mean(treatment_effects)

def sample_dataset(pop, sample_size=10000, max_len=30, action_dim=4):
    """Draw an observational sample from the population.

    `sample_size` distinct individuals are picked without replacement; for
    each one a treatment W ~ Bernoulli(0.5) and standard-normal noise are
    drawn and the outcome is y = alpha + beta * W + noise.

    `max_len` and `action_dim` are accepted but not used here.

    Returns:
        List of (path, w, y, alpha, beta) tuples.
    """
    chosen = np.random.choice(len(pop), size=sample_size, replace=False)

    rows = []
    for pos in chosen:
        path, alpha_val, beta_val = pop[pos]
        # Treatment is drawn before the noise for every individual.
        treated = np.random.binomial(1, 0.5)
        eps = np.random.randn()
        outcome = alpha_val + beta_val * treated + eps
        rows.append((path, treated, outcome, alpha_val, beta_val))
    return rows

############################################################################
# 2) MODEL DEFS
############################################################################

def path_to_onehot(path, max_len=30, action_dim=4):
    """Encode a path as a flat one-hot vector of size max_len * action_dim.

    Step i sets the entry at row i, column path[i] of a (max_len, action_dim)
    grid; steps beyond max_len are dropped and unused rows stay all-zero.

    Returns:
        1-D float32 ndarray (the grid flattened row by row).
    """
    grid = np.zeros((max_len, action_dim), dtype=np.float32)
    n_used = min(len(path), max_len)
    # One fancy-index assignment marks (step, action) for every kept step.
    grid[np.arange(n_used), path[:n_used]] = 1.0
    return grid.reshape(-1)

def collate_fn_dnn(samples, max_len=30, action_dim=4):
    """Turn a batch of samples into tensors for the feed-forward model.

    Args:
        samples: iterable of (path, w, y, alpha, beta) tuples.
        max_len, action_dim: forwarded to path_to_onehot.

    Returns:
        (X_t, W_t, Y_t, A_t, B_t), all float32:
        X_t holds one flattened one-hot path per row, W_t and Y_t are column
        vectors of shape (batch, 1), A_t and B_t are 1-D of shape (batch,).
    """
    X_list, W_list, Y_list, A_list, B_list = [], [], [], [], []
    for path, w, y, a, b in samples:
        X_list.append(path_to_onehot(path, max_len, action_dim))
        W_list.append(w)
        Y_list.append(y)
        A_list.append(a)
        B_list.append(b)

    # Stack in NumPy first: torch.tensor() on a list of ndarrays converts
    # element by element and is very slow for batches of vectors.
    X_t = torch.from_numpy(np.asarray(X_list, dtype=np.float32))
    W_t = torch.tensor(W_list, dtype=torch.float32).view(-1, 1)
    Y_t = torch.tensor(Y_list, dtype=torch.float32).view(-1, 1)
    A_t = torch.tensor(A_list, dtype=torch.float32)
    B_t = torch.tensor(B_list, dtype=torch.float32)
    return X_t, W_t, Y_t, A_t, B_t

class DNNAlphaBeta(nn.Module):
    """
    Feed-forward network mapping a flat input vector to two outputs,
    read by callers as (alpha, beta).

    'input_dim' : size of the input vector
    'hidden_dim': width of every hidden layer
    'num_layers': number of hidden Linear+ReLU layers (at least one is
                  always built, even for num_layers < 1)
    """
    def __init__(self, input_dim=120, hidden_dim=16, num_layers=3):
        super().__init__()
        # Layer widths from the input through all hidden layers.
        widths = [input_dim] + [hidden_dim] * max(num_layers, 1)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Two-unit output head: column 0 = alpha, column 1 = beta.
        stack.append(nn.Linear(hidden_dim, 2))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        return self.net(x)

def collate_fn_lstm(samples, max_len=30):
    """Turn a batch of samples into padded index tensors for sequence models.

    Args:
        samples: iterable of (path, w, y, alpha, beta) tuples.
        max_len: padded length; longer paths are truncated to it.

    Returns:
        (x_pad, slen, W_t, Y_t, A_t, B_t):
        x_pad is int64 of shape (batch, max_len), right-padded with 0;
        slen is int64 of shape (batch,) with the number of real steps per row;
        W_t and Y_t are float32 of shape (batch, 1);
        A_t and B_t are float32 of shape (batch,).
    """
    samples = list(samples)
    n_rows = len(samples)

    # Fill one preallocated array instead of calling torch.tensor() on a
    # list of ndarrays, which converts element by element and is slow.
    # An empty batch also keeps its (0, max_len) shape this way.
    x_padded = np.zeros((n_rows, max_len), dtype=np.int64)
    seq_len = np.zeros(n_rows, dtype=np.int64)
    W_list, Y_list, A_list, B_list = [], [], [], []
    for row, (path, w, y, a, b) in enumerate(samples):
        plen = min(len(path), max_len)
        x_padded[row, :plen] = path[:plen]
        seq_len[row] = plen
        W_list.append(w)
        Y_list.append(y)
        A_list.append(a)
        B_list.append(b)

    x_pad = torch.from_numpy(x_padded)
    slen = torch.from_numpy(seq_len)
    W_t = torch.tensor(W_list, dtype=torch.float32).view(-1, 1)
    Y_t = torch.tensor(Y_list, dtype=torch.float32).view(-1, 1)
    A_t = torch.tensor(A_list, dtype=torch.float32)
    B_t = torch.tensor(B_list, dtype=torch.float32)
    return (x_pad, slen, W_t, Y_t, A_t, B_t)

class LSTMAlphaBeta(nn.Module):
    """
    Embedding -> single-layer LSTM -> linear head with two outputs,
    read by callers as (alpha, beta).

    'action_dim': number of distinct action ids (embedding rows)
    'embed_dim' : embedding size fed to the LSTM
    'hidden_dim': LSTM hidden size
    """
    def __init__(self, action_dim=4, embed_dim=4, hidden_dim=4):
        super().__init__()
        self.embed = nn.Embedding(action_dim, embed_dim)
        self.lstm  = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.head  = nn.Linear(hidden_dim, 2)

    def forward(self, x_padded, seq_len):
        """
        x_padded: (batch, max_len) integer action ids, right-padded
        seq_len : per-row count of real steps (tensor or list of ints)
        returns : (batch, 2) tensor
        """
        emb = self.embed(x_padded)
        # pack_padded_sequence needs its lengths on the CPU, even when the
        # inputs and the model live on another device.
        lengths = torch.as_tensor(seq_len, dtype=torch.int64).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths, batch_first=True, enforce_sorted=False
        )
        # Packing makes h_n the state after each row's last real step, so
        # padding never influences the prediction.
        _, (h_n, _) = self.lstm(packed)
        return self.head(h_n[-1])

def collate_fn_transformer(samples, max_len=30):
    """Turn a batch of samples into padded index tensors for the Transformer.

    The Transformer consumes exactly the batch layout built by
    collate_fn_lstm, so this delegates to it instead of keeping a second
    copy of the same code.

    Args:
        samples: iterable of (path, w, y, alpha, beta) tuples.
        max_len: padded length; longer paths are truncated to it.

    Returns:
        (x_pad, slen, W_t, Y_t, A_t, B_t) as returned by collate_fn_lstm.
    """
    return collate_fn_lstm(samples, max_len)

class TransformerAlphaBeta(nn.Module):
    """
    Embedding + learned positional embedding -> Transformer encoder ->
    mean over the real (non-padded) steps -> linear head with two outputs,
    read by callers as (alpha, beta).

    'action_dim': number of distinct action ids (embedding rows)
    'd_model'   : embedding / encoder width
    'nhead'     : attention heads per encoder layer
    'num_layers': number of encoder layers
    'max_len'   : longest sequence the positional embedding covers
    """
    def __init__(self, action_dim=4, d_model=8, nhead=1, num_layers=1, max_len=30):
        super().__init__()
        self.embed = nn.Embedding(action_dim, d_model)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, d_model))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=16, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.head = nn.Linear(d_model, 2)
        self.max_len = max_len
        self.d_model = d_model

    def forward(self, x_padded, seq_len):
        """
        x_padded: (batch, L) integer action ids, right-padded, L <= max_len
        seq_len : per-row count of real steps (tensor or list of ints)
        returns : (batch, 2) tensor
        """
        _, L = x_padded.shape
        seq_len = torch.as_tensor(seq_len, device=x_padded.device)
        emb = self.embed(x_padded) + self.pos_embed[:, :L, :]

        # True marks padding: every position at or beyond the row's length.
        positions = torch.arange(L, device=emb.device)
        pad_mask = positions.unsqueeze(0) >= seq_len.unsqueeze(1)
        out_seq = self.transformer(emb, src_key_padding_mask=pad_mask)

        # Masked mean over real steps. masked_fill (rather than multiplying
        # by 0) also discards NaNs that attention produces for a row whose
        # keys are all masked, so a zero-length row pools to a zero vector.
        summed = out_seq.masked_fill(pad_mask.unsqueeze(-1), 0.0).sum(dim=1)
        n_valid = (~pad_mask).sum(dim=1, keepdim=True).clamp(min=1)
        out_vec = summed / n_valid.to(summed.dtype)
        return self.head(out_vec)

############################################################################
# 3) METRICS + EVALUATION
############################################################################

def r2_score(true_arr, pred_arr):
    """Coefficient of determination, 1 - SS_res / SS_tot.

    When the true values are (numerically) constant, so that SS_tot is
    below 1e-9, the score is defined as 1.0 whatever the predictions are.
    """
    centered = true_arr - np.mean(true_arr)
    total_ss = np.sum(centered ** 2)
    if total_ss < 1e-9:
        return 1.0

    residual = pred_arr - true_arr
    residual_ss = np.sum(residual ** 2)
    return 1.0 - residual_ss / total_ss

def dr_ate(alpha_hat, beta_hat, w_arr, y_arr, e=0.5):
    """Doubly-robust (AIPW) estimate of the average treatment effect.

    Uses mu0 = alpha_hat and mu1 = alpha_hat + beta_hat as outcome models
    and a constant propensity score `e`:

        IF_i = [mu1_i + W_i (Y_i - mu1_i) / e]
             - [mu0_i + (1 - W_i) (Y_i - mu0_i) / (1 - e)]

    Args:
        alpha_hat: predicted baseline outcome per unit.
        beta_hat: predicted treatment effect per unit.
        w_arr: binary treatment indicator per unit.
        y_arr: observed outcome per unit.
        e: probability of treatment; defaults to 0.5.

    Returns:
        (ate, se): mean of IF and its standard error (sample std with
        ddof=1 divided by sqrt(n)).
    """
    # Accept plain lists as well as arrays; `1 - w_arr` fails on a list.
    mu0 = np.asarray(alpha_hat, dtype=float)
    mu1 = mu0 + np.asarray(beta_hat, dtype=float)
    w_arr = np.asarray(w_arr, dtype=float)
    y_arr = np.asarray(y_arr, dtype=float)

    treated_term = mu1 + w_arr * (y_arr - mu1) / e
    control_term = mu0 + (1 - w_arr) * (y_arr - mu0) / (1 - e)
    IF = treated_term - control_term

    ate = IF.mean()
    se = IF.std(ddof=1) / np.sqrt(len(IF))
    return ate, se

def evaluate_neural(model, dataset, model_type="dnn", max_len=30, action_dim=4, batch_size=256):
    """
    Returns (ate_est, ate_se, r2_y, r2_alpha, r2_beta) for given dataset.

    model     : network returning a (batch, 2) tensor of (alpha, beta)
    dataset   : list of (path, w, y, alpha, beta) tuples
    model_type: "dnn" or "lstm"; any other value is treated as "transformer"

    ate_est / ate_se come from dr_ate on the predictions; the three R^2
    values compare predicted vs. observed y and predicted vs. true
    alpha / beta.
    """
    alphaT_list, betaT_list = [], []
    alphaP_list, betaP_list = [], []
    yT_list, yP_list        = [], []
    w_list                  = []

    # Evaluate in eval mode so stochastic layers (the Transformer's dropout)
    # are switched off; the previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        for start in range(0, len(dataset), batch_size):
            batch = dataset[start:start + batch_size]
            with torch.no_grad():
                if model_type == "dnn":
                    X_t, W_t, Y_t, A_t, B_t = collate_fn_dnn(batch, max_len, action_dim)
                    ab = model(X_t)
                elif model_type == "lstm":
                    x_pad, slen, W_t, Y_t, A_t, B_t = collate_fn_lstm(batch, max_len)
                    ab = model(x_pad, slen)
                else:  # "transformer"
                    x_pad, slen, W_t, Y_t, A_t, B_t = collate_fn_transformer(batch, max_len)
                    ab = model(x_pad, slen)

            alpha_hat = ab[:, 0].cpu().numpy()
            beta_hat  = ab[:, 1].cpu().numpy()

            # reshape(-1) keeps these 1-D even for a single-row batch;
            # squeeze() would give 0-d arrays that np.concatenate rejects.
            w_arr  = W_t.reshape(-1).numpy()
            y_true = Y_t.reshape(-1).numpy()
            y_pred = alpha_hat + beta_hat * w_arr

            alphaT_list.append(A_t.numpy())
            betaT_list.append(B_t.numpy())
            alphaP_list.append(alpha_hat)
            betaP_list.append(beta_hat)
            w_list.append(w_arr)
            yT_list.append(y_true)
            yP_list.append(y_pred)
    finally:
        model.train(was_training)

    alpha_true = np.concatenate(alphaT_list)
    beta_true  = np.concatenate(betaT_list)
    alpha_pred = np.concatenate(alphaP_list)
    beta_pred  = np.concatenate(betaP_list)
    w_arr      = np.concatenate(w_list)
    y_true     = np.concatenate(yT_list)
    y_pred     = np.concatenate(yP_list)

    ate_est, ate_se = dr_ate(alpha_pred, beta_pred, w_arr, y_true)
    r2_y     = r2_score(y_true, y_pred)
    r2_alpha = r2_score(alpha_true, alpha_pred)
    r2_beta  = r2_score(beta_true,  beta_pred)
    return ate_est, ate_se, r2_y, r2_alpha, r2_beta

def evaluate_cforest(cf_est, dataset, max_len=30, action_dim=4, batch_size=256):
    """
    Returns (ate_est, ate_se, nan, nan, r2_beta) for a fitted causal forest.

    cf_est    : fitted estimator exposing ate(), ate_interval() and effect()
    dataset   : list of (path, w, y, alpha, beta) tuples
    batch_size: kept for signature compatibility; all rows are scored at once

    The two NaNs stand in for R^2(y) and R^2(alpha), which the forest does
    not provide, so the tuple has the same layout as evaluate_neural's.
    """
    X_list, beta_true_list = [], []
    for (path, _w, _y, _a, b) in dataset:
        X_list.append(path_to_onehot(path, max_len, action_dim))
        beta_true_list.append(b)

    X_arr = np.array(X_list, dtype=np.float32)
    beta_true = np.array(beta_true_list, dtype=np.float32)

    ate_est = cf_est.ate(X=X_arr)

    # Convert the two-sided 95% interval into a standard error: its
    # half-width is z_{0.975} * SE, not the SE itself.
    # NOTE(review): assumes ate_interval is a normal-approximation interval.
    ci_low, ci_high = cf_est.ate_interval(X=X_arr, alpha=0.05)
    z_975 = 1.959963984540054
    ate_se = 0.5 * (ci_high - ci_low) / z_975

    # ravel() guards against an (n, 1) effect array broadcasting against
    # the (n,) vector of true effects inside r2_score.
    beta_hat = np.ravel(cf_est.effect(X_arr))
    r2_beta  = r2_score(beta_true, beta_hat)

    return ate_est, ate_se, np.nan, np.nan, r2_beta

############################################################################
# 4) MAIN
############################################################################

def main():
    """Run the full experiment and print a train/test comparison table.

    Steps: build and min-max normalise the synthetic population, sample a
    train/test split, train the DNN, LSTM and Transformer on the outcome
    MSE, fit a CausalForestDML baseline, then print the ATE estimate and
    R^2 scores for every model on both splits.
    """
    # Unpack
    max_len      = GEN_H["max_len"]
    action_dim   = GEN_H["action_dim"]
    pop_size     = GEN_H["pop_size"]
    sample_train = GEN_H["sample_train"]
    sample_test  = GEN_H["sample_test"]
    batch_size   = GEN_H["batch_size"]
    lr           = GEN_H["lr"]

    dnn_hidden_dim = NN_H["dnn_hidden_dim"]
    dnn_num_layers = NN_H["dnn_num_layers"]
    dnn_epochs     = NN_H["dnn_epochs"]

    lstm_embed_dim  = NN_H["lstm_embed_dim"]
    lstm_hidden_dim = NN_H["lstm_hidden_dim"]
    lstm_epochs     = NN_H["lstm_epochs"]

    tf_d_model    = NN_H["tf_d_model"]
    tf_nhead      = NN_H["tf_nhead"]
    tf_num_layers = NN_H["tf_num_layers"]
    tf_epochs     = NN_H["tf_epochs"]

    cf_n_estimators = CF_H["n_estimators"]
    cf_max_depth    = CF_H["max_depth"]
    cf_num_leaves   = CF_H["num_leaves"]
    cf_criterion    = CF_H["criterion"]

    # 1) Build population => minmax => sample train/test
    pop_raw = build_population(pop_size, max_len, action_dim)
    pop_norm, _ = minmax_normalize_population(pop_raw)
    true_ate = compute_true_ate(pop_norm)
    print(f"Population => alpha,beta in [0,1], True ATE= {true_ate:.4f}")

    full_data = sample_dataset(pop_norm, (sample_train+sample_test), max_len, action_dim)
    np.random.shuffle(full_data)
    train_data = full_data[:sample_train]
    test_data  = full_data[sample_train:]

    # 2) Instantiate DNN, LSTM, Transformer from the module-level classes.
    #    (Local copies of the LSTM / Transformer classes used to be redefined
    #    here, shadowing the module-level ones with identical code.)
    dnn_model = DNNAlphaBeta(
        input_dim=max_len*action_dim,
        hidden_dim=dnn_hidden_dim,
        num_layers=dnn_num_layers
    )
    lstm_model = LSTMAlphaBeta(
        action_dim=action_dim,
        embed_dim=lstm_embed_dim,
        hidden_dim=lstm_hidden_dim
    )
    transformer_model = TransformerAlphaBeta(
        action_dim=action_dim,
        d_model=tf_d_model,
        nhead=tf_nhead,
        num_layers=tf_num_layers,
        max_len=max_len
    )

    # 3) Train loop, shared by all three networks
    mse_loss = nn.MSELoss()

    def fit(model, epochs, desc, make_batch, **adam_kwargs):
        """Minimise MSE between alpha + beta*W and Y on train_data."""
        opt = optim.Adam(model.parameters(), lr=lr, **adam_kwargs)
        model.train()
        for _ in tqdm(range(epochs), desc=desc):
            np.random.shuffle(train_data)
            for i in range(0, len(train_data), batch_size):
                # Every collate function returns (*model_inputs, W, Y, A, B).
                *inputs, W_t, Y_t, _, _ = make_batch(train_data[i:i+batch_size])
                ab_pred = model(*inputs)
                alpha_hat = ab_pred[:, 0:1]
                beta_hat  = ab_pred[:, 1:2]
                Y_hat = alpha_hat + beta_hat*W_t
                loss = mse_loss(Y_hat, Y_t)
                opt.zero_grad()
                loss.backward()
                opt.step()

    # Train them
    fit(dnn_model, dnn_epochs, "DNN",
        lambda batch: collate_fn_dnn(batch, max_len, action_dim))
    fit(lstm_model, lstm_epochs, "LSTM",
        lambda batch: collate_fn_lstm(batch, max_len))
    fit(transformer_model, tf_epochs, "Transformer",
        lambda batch: collate_fn_transformer(batch, max_len),
        weight_decay=1e-5)

    # 4) Fit causal forest
    def parse_data_cforest(data):
        """Split samples into (X one-hot, T treatment, Y outcome) arrays."""
        X_list, T_list, Y_list = [], [], []
        for (path, w, y, a, b) in data:
            X_list.append(path_to_onehot(path, max_len, action_dim))
            T_list.append(w)
            Y_list.append(y)
        return (
            np.array(X_list, dtype=np.float32),
            np.array(T_list, dtype=np.float32),
            np.array(Y_list, dtype=np.float32),
        )

    print("\nFitting CausalForestDML with more leaves...")
    X_train_cf, T_train_cf, Y_train_cf = parse_data_cforest(train_data)
    # NOTE(review): model_t is a regressor although discrete_treatment=True;
    # a classifier may be the intended choice here -- confirm against the
    # installed econml version before changing.
    cf_est = CausalForestDML(
        model_y=LGBMRegressor(num_leaves=cf_num_leaves, verbose=-1),
        model_t=LGBMRegressor(num_leaves=cf_num_leaves, verbose=-1),
        discrete_treatment=True,
        random_state=123,
        n_estimators=cf_n_estimators,
        criterion=cf_criterion,
        max_depth=cf_max_depth
    )
    cf_est.fit(Y_train_cf, T_train_cf, X=X_train_cf, W=None)

    # 5) Evaluate both train & test
    #    For neural models => (ATE, SE, R2(y), R2(a), R2(b))
    #    For CF => (ATE, SE, N/A, N/A, R2(b))

    # Switch off dropout etc. so the reported metrics are deterministic.
    for net in (dnn_model, lstm_model, transformer_model):
        net.eval()

    # Neural nets
    dnn_train_res = evaluate_neural(dnn_model, train_data, "dnn", max_len, action_dim, batch_size)
    dnn_test_res  = evaluate_neural(dnn_model, test_data,  "dnn", max_len, action_dim, batch_size)

    lstm_train_res = evaluate_neural(lstm_model, train_data, "lstm", max_len, action_dim, batch_size)
    lstm_test_res  = evaluate_neural(lstm_model, test_data,  "lstm", max_len, action_dim, batch_size)

    tf_train_res = evaluate_neural(transformer_model, train_data, "transformer", max_len, action_dim, batch_size)
    tf_test_res  = evaluate_neural(transformer_model, test_data,  "transformer", max_len, action_dim, batch_size)

    # CForest
    cf_train_res = evaluate_cforest(cf_est, train_data, max_len, action_dim, batch_size)
    cf_test_res  = evaluate_cforest(cf_est, test_data,  max_len, action_dim, batch_size)

    # 6) Print final table
    #    Model | ATE(train) R^2(y) R^2(a) R^2(b) || ATE(test) R^2(y) R^2(a) R^2(b)
    #    For CF => alpha,y => N/A
    def format_r2(x):
        if np.isnan(x):
            return "  N/A"
        return f"{x:7.3f}"

    # We unify in one table; each res => (ate, se, r2y, r2a, r2b)
    names = ["DNN", "LSTM", "Transformer", "CForest"]
    train_res = [dnn_train_res, lstm_train_res, tf_train_res, cf_train_res]
    test_res  = [dnn_test_res,  lstm_test_res,  tf_test_res,  cf_test_res]

    print("\n=== TRAIN vs. TEST RESULTS ===")
    print("Model        |    ATE(train)  R^2(y)   R^2(a)   R^2(b)   ||   ATE(test)   R^2(y)   R^2(a)   R^2(b)")

    for nm, train_eval, test_eval in zip(names, train_res, test_res):
        at_tr, _, r2y_tr, r2a_tr, r2b_tr = train_eval
        at_te, _, r2y_te, r2a_te, r2b_te = test_eval

        # alpha,y => N/A if nan
        r2ytr_str = format_r2(r2y_tr)
        r2atr_str = format_r2(r2a_tr)
        r2btr_str = format_r2(r2b_tr)

        r2yte_str = format_r2(r2y_te)
        r2ate_str = format_r2(r2a_te)
        r2bte_str = format_r2(r2b_te)

        print(f"{nm:<12} | {at_tr:10.4f}  {r2ytr_str:>7}  {r2atr_str:>7}  {r2btr_str:>7}  || {at_te:10.4f}  {r2yte_str:>7}  {r2ate_str:>7}  {r2bte_str:>7}")

    print(f"\nTrue ATE= {true_ate:.4f}")

if __name__=="__main__":
    main()

Population => alpha,beta in [0,1], True ATE= 0.6413


DNN: 100%|██████████| 100/100 [00:13<00:00,  7.43it/s]
LSTM: 100%|██████████| 150/150 [00:30<00:00,  4.85it/s]
Transformer: 100%|██████████| 50/50 [00:28<00:00,  1.74it/s]



Fitting CausalForestDML with more leaves...

=== TRAIN vs. TEST RESULTS ===
Model        |    ATE(train)  R^2(y)   R^2(a)   R^2(b)   ||   ATE(test)   R^2(y)   R^2(a)   R^2(b)
DNN          |     0.6205    0.357   -8.097   -2.518  ||     0.7189   -0.128   -9.177   -2.374
LSTM         |     0.6271    0.103   -0.090    0.370  ||     0.6957    0.105   -0.143    0.372
Transformer  |     0.6297    0.096   -0.015    0.721  ||     0.7007    0.113   -0.020    0.725
CForest      |     0.6198      N/A      N/A    0.505  ||     0.6193      N/A      N/A    0.550

True ATE= 0.6413
