In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from tabrel.utils.treatment import load_ihdp_data, generate_indices

# Load the IHDP benchmark together with the names of the columns that play
# special roles (columns the loader marks for exclusion, ground-truth effect,
# factual / counterfactual outcomes, treatment indicator).
(
    ihdp_data,
    ihdp_exclude_cols,
    ihdp_tau_colname,
    ihdp_y_fact_colname,
    ihdp_y_cfact_colname,
    ihdp_treatment_colname,
) = load_ihdp_data(Path("../CEVAE/datasets/IHDP"))

# Covariates: drop the excluded columns and the ground-truth effect column.
non_feature_cols = ihdp_exclude_cols + [ihdp_tau_colname]
x_all = ihdp_data.drop(columns=non_feature_cols)

y_fact = ihdp_data[ihdp_y_fact_colname]
y_cfact = ihdp_data[ihdp_y_cfact_colname]
treatment_all = ihdp_data[ihdp_treatment_colname]

# Query / background / validation row indices (used below as query_indices,
# backgnd_indices and val_indices); fixed seed for reproducibility.
q_ids, b_ids, v_ids = generate_indices(n_total=len(x_all), seed=42)

In [None]:
import numpy as np
import pandas as pd

# Augment the data with one extra row per validation unit: identical covariates,
# the opposite treatment and the counterfactual outcome as target. The model is
# later asked to predict both rows of every validation unit.
n_orig = len(x_all)

x_rs = pd.concat([x_all, x_all.iloc[v_ids]], ignore_index=True)
y_fact_rs = pd.concat([y_fact, y_cfact.iloc[v_ids]], ignore_index=True)

flipped_treatment_val = 1 - treatment_all.iloc[v_ids]
treatment_rs = pd.concat([treatment_all, flipped_treatment_val], ignore_index=True)

# Validation positions in the augmented frame. ORDER MATTERS downstream:
# first the appended (flipped-treatment) rows n_orig .. n_orig + len(v_ids) - 1,
# then the original (factual) rows v_ids.
new_indices = n_orig + np.arange(len(v_ids))
v_rs_ids = np.concatenate([new_indices, v_ids])

# Relation matrix: r_rs[i, j] = 1 iff rows i and j have the same treatment value.
treatment_array = treatment_rs.to_numpy()
r_rs = np.equal.outer(treatment_array, treatment_array).astype(int)

In [None]:
import torch
from sklearn.metrics import mean_squared_error

from tabrel.train import train_relnet

# Ground-truth effects and treatment assignment of the validation units.
# Positional .iloc indexing, consistent with how v_ids is used in the cells above
# (plain [] on a Series is label-based and breaks if the index is not 0..n-1).
tau_val_ground = ihdp_data[ihdp_tau_colname].iloc[v_ids].to_numpy()
treatment_val = treatment_all.iloc[v_ids].to_numpy()
n_val = len(v_ids)

def train_relnet_shorthand(x: pd.DataFrame, r: np.ndarray, seed: int = 42) -> float:
    """Train a RelNet on the augmented data and return the validation PEHE.

    Relies on notebook-level state: ``y_fact_rs``, ``q_ids``, ``b_ids``,
    ``v_rs_ids``, ``n_val``, ``treatment_val`` and ``tau_val_ground``.

    Args:
        x: Feature frame, one row per position of the augmented data
            (original rows followed by the appended flipped-treatment rows).
        r: Square relation matrix matching the rows of ``x``.
        seed: Torch seed set before training for reproducibility.

    Returns:
        Mean squared error between the true and predicted individual treatment
        effects of the validation units (PEHE without the square root).
    """
    torch.manual_seed(seed)
    _, _, _, y_val_pred, _ = train_relnet(
        x=x.to_numpy(),
        y=y_fact_rs.to_numpy(),
        r=r,
        backgnd_indices=b_ids,
        query_indices=q_ids,
        val_indices=v_rs_ids,
        lr=.01,
        n_epochs=1500,
        n_layers=2,
        periodic_embed_dim=None,
        num_heads=2,
        progress_bar=False,
        print_loss=True,
    )
    # Flatten so a (2 * n_val, 1) output cannot broadcast against (n_val,) below.
    y_val_pred = np.asarray(y_val_pred).reshape(-1)
    assert y_val_pred.shape == (2 * n_val,), y_val_pred.shape

    # v_rs_ids lists the appended (flipped-treatment, counterfactual) rows FIRST,
    # then the original (factual) rows.
    # NOTE(review): assumes train_relnet returns predictions in val_indices order — confirm.
    y_val_cfact_hat, y_val_fact_hat = y_val_pred[:n_val], y_val_pred[n_val:]

    # ITE = Y(1) - Y(0): factual minus counterfactual for treated units, the
    # reverse for controls. np.abs() would discard the sign of negative effects.
    # NOTE(review): assumes the tau column is Y(1) - Y(0) (e.g. mu1 - mu0) — confirm in load_ihdp_data.
    tau_val_pred = np.where(
        treatment_val == 1,
        y_val_fact_hat - y_val_cfact_hat,
        y_val_cfact_hat - y_val_fact_hat,
    )
    return mean_squared_error(tau_val_ground, tau_val_pred)

pehe_val_rs = train_relnet_shorthand(x_rs, r_rs)
print(f"PEHE {pehe_val_rs:.2f}") # 16.51 was recorded with the old abs()-based tau estimate; re-run to refresh

# Baseline: standard S-learner (treatment as an input feature, no relations between rows)

In [None]:
# Baseline S-learner: the (possibly flipped) treatment is given to the model as
# an ordinary input column, and the relation matrix is all zeros, so no pair of
# rows is marked as related.
r_s = np.zeros(r_rs.shape, dtype=r_rs.dtype)
x_s = x_rs.copy()
x_s[ihdp_treatment_colname] = treatment_rs  # index-aligned: both use a fresh RangeIndex

pehe_val_s = train_relnet_shorthand(x_s, r_s)
print(f"PEHE (standard S-learner): {pehe_val_s:.2f}") # 26.89 (recorded output)