In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import numpy as np
import pandas as pd

def load_ihdp_data(ihdp_path: Path) -> tuple[pd.DataFrame, list[str], str]:
    ihdp_cols = [s[:-1] for s in np.loadtxt(ihdp_path / "columns.txt", dtype=str)][:-2]
    ihdp_cols.extend([f"x{i}" for i in range(2, 26)])

    csvs = []
    for csv_path in (ihdp_path / "csv").glob("*.csv"):
        csvs.append(pd.read_csv(csv_path, header=None))
        break # TODO choose a table, for now using the first table
    data = pd.concat(csvs)
    data.columns = ihdp_cols

    tau_col_name = "delta_y"
    data[tau_col_name] = (data["y_cfactual"] - data["y_factual"]) * (-1) ** data["treatment"]
    exclude_cols = ["treatment", "y_cfactual", "y_factual", "mu0", "mu1"]
    return data, exclude_cols, tau_col_name

import os

# Dataset directory; set the IHDP_DIR environment variable instead of editing this absolute path.
IHDP_DIR = Path(os.environ.get("IHDP_DIR", "/Users/vzuev/Documents/git/gh_zuevval/tabrel/CEVAE/datasets/IHDP"))
ihdp_data, ihdp_exclude_cols, ihdp_tau_colname = load_ihdp_data(IHDP_DIR)
ihdp_data.head()

In [None]:
from typing import Final

# Covariates only: treatment, outcomes, mu0/mu1 and the effect column are removed.
x_all = ihdp_data.drop(columns=[*ihdp_exclude_cols, ihdp_tau_colname])

# Columns [0, ihdp_last_numeric_index) are treated as numeric, the remaining ones as categorical.
ihdp_last_numeric_index: Final[int] = 6
x_numeric, x_cat = (
    x_all.iloc[:, :ihdp_last_numeric_index],
    x_all.iloc[:, ihdp_last_numeric_index:],
)
x_numeric

In [None]:
import seaborn as sns

# Numeric covariates plus the per-row effect column, which colours the pair plot.
x_num_y = x_numeric.assign(**{ihdp_tau_colname: ihdp_data[ihdp_tau_colname]})
sns.pairplot(x_num_y, hue=ihdp_tau_colname)

In [None]:
from itertools import product
from tqdm import tqdm

group_col: Final[str] = "x4"
x = x_all.drop(columns=[group_col])
x_len = len(x)
categories = x_all[group_col]
print("n_categories", len(categories.unique()))

# r[i, j] = 1.0 when rows i and j share the same `group_col` value (diagonal included), else 0.0.
# Built by broadcasting instead of a Python loop over all x_len**2 pairs; the comparison is
# positional, so it does not depend on the frame's index labels.
category_values = categories.to_numpy()
r = np.isclose(category_values[:, None], category_values[None, :]).astype(float)

r

# S-learner

In [None]:
from collections import defaultdict
from tabrel.benchmark.nw_regr import run_training, metrics_mean, train_nw_arbitrary, NwModelConfig
from tabrel.train import train_relnet
from sklearn.metrics import mean_squared_error
import lightgbm as lgb

# Short names of compared methods. NOTE(review): not referenced by the visible cells, and it
# omits "relnet", which the metric dicts below also record.
labels = list(("rel", "nrel", "lgb", "rel-fts", "lgb-rel"))

def generate_indices(seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    np.random.seed(seed)
    indices = np.random.permutation(x_len)
    n_query, n_back = 200, 300
    q_indices = indices[:n_query]
    b_indices = indices[n_query:n_back]
    v_indices = indices[n_back:]
    return q_indices, b_indices, v_indices

def split_treated_non_treated(x: pd.DataFrame, treatment: np.ndarray, y_fact: pd.DataFrame) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    treated = treatment == 1
    x_treated, y_treated = x.loc[treated], y_fact.loc[treated]
    x_non_treated, y_non_treated = x.loc[~treated], y_fact.loc[~treated]
    return x_treated.to_numpy(), y_treated.to_numpy(), x_non_treated.to_numpy(), y_non_treated.to_numpy()

y_fact_colname, y_cfact_colname = "y_factual", "y_cfactual"
data_y_fact, data_treatment = ihdp_data[y_fact_colname], ihdp_data["treatment"]

# x for the S-learner: covariates plus the treatment indicator as an ordinary feature.
# `assign` returns a new frame, so `x` itself stays treatment-free for the T- and X-learner
# cells (plain `x_s = x` followed by item assignment would add the column to `x` as well).
x_s = x.assign(treatment=data_treatment)
x_s_np = x_s.to_numpy()
y_s = data_y_fact.to_numpy()  # Y for S-learner: the factual outcome

model_cfg, lr, n_epochs = NwModelConfig(), 1e-3, 50
lgb_params: Final[dict[str, str | int]] = {"objective": "regression", "metric": "rmse", "verbosity": -1}

In [None]:
import torch

# S-learner: every method fits one outcome model on covariates + treatment indicator (x_s_np)
# using query/background rows and is scored on the remaining validation rows.
# NOTE(review): the scores are MSE of the *factual* outcome y_s, not PEHE; PEHE would need
# predictions with the treatment column flipped, which this cell does not compute.
metrics = defaultdict(list)

# Relation rows appended as extra feature columns. They do not depend on the seed, so the
# matrix is built once rather than on every iteration.
x_broad = np.concatenate((x_s_np, r), axis=1)

for seed in tqdm(range(20)):
    np.random.seed(seed)
    # Seed torch too, so RelNet/NW fits are repeatable across runs
    # (assumes they draw from torch's global RNG — TODO confirm).
    torch.manual_seed(seed)
    ids_q, ids_b, ids_v = generate_indices(seed)
    ids_train = np.concatenate((ids_q, ids_b))
    xb, yb, xq, yq = x_s_np[ids_b], y_s[ids_b], x_s_np[ids_q], y_s[ids_q]
    xv, yv = x_s_np[ids_v], y_s[ids_v]
    r_q_b = r[ids_q][:, ids_b]
    r_val_train = r[ids_v][:, ids_train]  # columns ordered query | background, like ids_train

    # TabRel; the first return value is recorded as its validation score
    relnet_mse, _, _, _, _ = train_relnet(
        x=x_s_np,
        y=y_s,
        r=r,
        backgnd_indices=ids_b,
        query_indices=ids_q,
        val_indices=ids_v,
        lr=0.01,
        n_epochs=800,
        n_layers=2,
        periodic_embed_dim=None,
        num_heads=2,
        progress_bar=False,
    )
    metrics["relnet"].append(relnet_mse)

    # NW with rel
    _, _, _, model_rel = train_nw_arbitrary(
        x_backgnd=xb,
        y_backgnd=yb,
        x_query=xq,
        y_query=yq,
        x_val=xv,
        y_val=yv,
        r_query_backgnd=r_q_b,
        r_val_nonval=r_val_train,
        cfg=model_cfg,
        lr=lr,
        n_epochs=n_epochs,
    )
    metrics["rel"].append(mean_squared_error(yv, model_rel.y_val_pred))

    # NW without rel: same model with the relation matrices zeroed out
    _, _, _, model_nrel = train_nw_arbitrary(
        x_backgnd=xb,
        y_backgnd=yb,
        x_query=xq,
        y_query=yq,
        x_val=xv,
        y_val=yv,
        r_query_backgnd=np.zeros_like(r_q_b),
        r_val_nonval=np.zeros_like(r_val_train),
        cfg=model_cfg,
        lr=lr,
        n_epochs=n_epochs,
    )
    metrics["nrel"].append(mean_squared_error(yv, model_nrel.y_val_pred))

    # LightGBM on query + background rows
    x_train = np.concatenate([xq, xb])
    y_train = np.concatenate([yq, yb])
    lgb_model = lgb.train(lgb_params, lgb.Dataset(x_train, label=y_train))
    y_pred_lgb = lgb_model.predict(xv)
    metrics["lgb"].append(mean_squared_error(yv, y_pred_lgb))

    # Broad (covariates + relation row) features for this split
    xb_broad, xq_broad, xv_broad = x_broad[ids_b], x_broad[ids_q], x_broad[ids_v]

    # NW with rel as features (relation matrices zeroed)
    _, _, _, model_relfts = train_nw_arbitrary(
        x_backgnd=xb_broad,
        y_backgnd=yb,
        x_query=xq_broad,
        y_query=yq,
        x_val=xv_broad,
        y_val=yv,
        r_query_backgnd=np.zeros((len(xq_broad), len(xb_broad))),
        r_val_nonval=np.zeros((len(xv_broad), len(xb_broad) + len(xq_broad))),
        cfg=model_cfg,
        lr=lr,
        n_epochs=n_epochs,
    )
    metrics["rel-fts"].append(mean_squared_error(yv, model_relfts.y_val_pred))

    # LightGBM with rel as features
    x_train_broad = np.concatenate([xq_broad, xb_broad])
    lgb_model_rel = lgb.train(lgb_params, lgb.Dataset(x_train_broad, label=y_train))
    y_pred_lgb_rel = lgb_model_rel.predict(xv_broad)
    metrics["lgb-rel"].append(mean_squared_error(yv, y_pred_lgb_rel))

# Show mean metrics
{ k: round(np.mean(v), 2) for k, v in metrics.items() }

# T-learner

In [None]:
import torch


# T-learner: one outcome model per arm (treated / non-treated). The estimated effect on the
# validation rows is the difference of the two models' predictions, scored as PEHE
# (MSE against the effect built from factual and counterfactual outcomes).
label_t: Final[str] = "treated"
label_nt: Final[str] = "non-treated"
nw_broad_key: Final[str] = "nw_rel-as-features"

# Seed-independent arrays, built once.
x_np = x.to_numpy()
y_fact_np = data_y_fact.to_numpy()
y_cfact_np = ihdp_data[y_cfact_colname].to_numpy()
treatment_np = data_treatment.to_numpy()
# Relation rows appended as extra feature columns.
x_broad = np.concatenate((x_np, r), axis=1)


def train_relnet_shorthand(x_: np.ndarray, y_: np.ndarray, r_: np.ndarray,
                           bi_: np.ndarray, qi_: np.ndarray, vi_: np.ndarray) -> torch.Tensor:
    """Train RelNet with the T-learner hyper-parameters and return its validation predictions.

    ``bi_`` / ``qi_`` / ``vi_`` are background / query / validation row positions into
    ``x_``, ``y_`` and the square relation matrix ``r_``.
    """
    _, _, _, y_pred, _ = train_relnet(
        x=x_,
        y=y_,
        r=r_,
        backgnd_indices=bi_,
        query_indices=qi_,
        val_indices=vi_,
        lr=.007,
        n_epochs=1500,
        n_layers=2,
        periodic_embed_dim=None,
        embed_dim=32,
        num_heads=2,
        progress_bar=False,
    )
    return y_pred


metrics = defaultdict(list)
for seed in tqdm(range(30)):
    np.random.seed(seed)
    # Seed torch before the NW fits as well (assumes they draw from torch's global RNG — TODO confirm).
    torch.manual_seed(seed)
    query_indices, back_indices, val_indices = generate_indices(seed)

    # Row positions of treated / non-treated units inside the query and background sets
    # (order within each set is preserved). Assumes a binary 0/1 treatment.
    treated_q = treatment_np[query_indices] == 1
    treated_b = treatment_np[back_indices] == 1
    iq_t, iq_nt = query_indices[treated_q], query_indices[~treated_q]
    ib_t, ib_nt = back_indices[treated_b], back_indices[~treated_b]

    xt_q, yt_q, xnt_q, ynt_q = x_np[iq_t], y_fact_np[iq_t], x_np[iq_nt], y_fact_np[iq_nt]
    xt_b, yt_b, xnt_b, ynt_b = x_np[ib_t], y_fact_np[ib_t], x_np[ib_nt], y_fact_np[ib_nt]
    xv = x_np[val_indices]

    # Both potential outcomes of every validation row: the factual outcome for the arm the
    # row actually received and the counterfactual outcome for the other arm.
    treated_v = treatment_np[val_indices] == 1
    yv_t = np.where(treated_v, y_fact_np[val_indices], y_cfact_np[val_indices])
    yv_nt = np.where(treated_v, y_cfact_np[val_indices], y_fact_np[val_indices])

    i_train_t = np.concatenate([iq_t, ib_t])
    i_train_nt = np.concatenate([iq_nt, ib_nt])

    r_q_b_treated = r[iq_t][:, ib_t]
    r_q_b_nt = r[iq_nt][:, ib_nt]
    r_val_nvt = r[val_indices][:, i_train_t]  # rel between val and treated train
    r_val_nvnt = r[val_indices][:, i_train_nt]  # rel between val and non-treated train

    # NW with and without relations, one model per arm
    trained_models = {
        "rel=True": {},
        "rel=False": {},
        nw_broad_key: {},
    }
    arms = (
        (xt_q, xt_b, yt_q, yt_b, yv_t, r_q_b_treated, r_val_nvt, label_t),
        (xnt_q, xnt_b, ynt_q, ynt_b, yv_nt, r_q_b_nt, r_val_nvnt, label_nt),
    )
    for rel, (xqi, xbi, yqi, ybi, yvi, r_q_b, r_v_nvi, label) in product((True, False), arms):
        _, _, _, model = train_nw_arbitrary(
            x_backgnd=xbi,
            y_backgnd=ybi,
            x_query=xqi,
            y_query=yqi,
            x_val=xv,
            y_val=yvi,
            r_query_backgnd=r_q_b if rel else np.zeros_like(r_q_b),
            r_val_nonval=r_v_nvi if rel else np.zeros_like(r_v_nvi),
            cfg=model_cfg,
            lr=lr,
            n_epochs=n_epochs,
        )
        trained_models[f"rel={rel}"][label] = model

    # NW with relations as features (relation matrices zeroed)
    xt_q_broad, xnt_q_broad = x_broad[iq_t], x_broad[iq_nt]
    xt_b_broad, xnt_b_broad = x_broad[ib_t], x_broad[ib_nt]
    xv_broad = x_broad[val_indices]

    for (xqi, xbi, yqi, ybi, yvi, label) in (
        (xt_q_broad, xt_b_broad, yt_q, yt_b, yv_t, label_t),
        (xnt_q_broad, xnt_b_broad, ynt_q, ynt_b, yv_nt, label_nt),
    ):
        trained_models[nw_broad_key][label] = train_nw_arbitrary(
            x_backgnd=xbi,
            y_backgnd=ybi,
            x_query=xqi,
            y_query=yqi,
            x_val=xv_broad,
            y_val=yvi,
            r_query_backgnd=np.zeros((len(xqi), len(xbi))),
            r_val_nonval=np.zeros((len(xv_broad), len(xbi) + len(xqi))),
            cfg=model_cfg,
            lr=lr,
            n_epochs=n_epochs,
        )[-1]

    # Reference effect on the validation rows
    tau_true = yv_t - yv_nt

    # LightGBM (plain and with relations as features): one model per arm, effect = difference
    yt_train, ynt_train = np.concatenate([yt_q, yt_b]), np.concatenate([ynt_q, ynt_b])
    for xq_ti, xb_ti, xq_nti, xb_nti, xv_i, lgb_key in (
        (xt_q, xt_b, xnt_q, xnt_b, xv, "lgb"),
        (xt_q_broad, xt_b_broad, xnt_q_broad, xnt_b_broad, xv_broad, "lgb-rel"),
    ):
        xt_train, xnt_train = np.concatenate([xq_ti, xb_ti]), np.concatenate([xq_nti, xb_nti])
        lgb_model_t = lgb.train(lgb_params, lgb.Dataset(xt_train, label=yt_train))
        lgb_model_nt = lgb.train(lgb_params, lgb.Dataset(xnt_train, label=ynt_train))
        tau_lgb = lgb_model_t.predict(xv_i) - lgb_model_nt.predict(xv_i)
        metrics[lgb_key].append(mean_squared_error(tau_true, tau_lgb))

    # TabRel: per arm, rows are stacked as background | query | validation.
    # Re-seed so RelNet's RNG state does not depend on how much randomness the NW fits consumed.
    torch.manual_seed(seed)
    xt = np.concatenate([xt_b, xt_q, xv])
    yt = np.concatenate([yt_b, yt_q, yv_t])
    bi_t = np.arange(len(xt_b))
    qi_t = np.arange(len(xt_q)) + len(xt_b)
    vi_t = np.arange(len(xv)) + len(xt_b) + len(xt_q)

    xnt = np.concatenate([xnt_b, xnt_q, xv])
    ynt = np.concatenate([ynt_b, ynt_q, yv_nt])
    bi_nt = np.arange(len(xnt_b))
    qi_nt = np.arange(len(xnt_q)) + len(xnt_b)
    vi_nt = np.arange(len(xv)) + len(xnt_b) + len(xnt_q)

    i_t = np.concatenate([ib_t, iq_t, val_indices])
    i_nt = np.concatenate([ib_nt, iq_nt, val_indices])
    r_bqv_t = r[i_t][:, i_t]
    r_bqv_nt = r[i_nt][:, i_nt]
    y_pred_relnet_t = train_relnet_shorthand(x_=xt, y_=yt, r_=r_bqv_t, bi_=bi_t, qi_=qi_t, vi_=vi_t)
    y_pred_relnet_nt = train_relnet_shorthand(x_=xnt, y_=ynt, r_=r_bqv_nt, bi_=bi_nt, qi_=qi_nt, vi_=vi_nt)
    tau_pred_relnet = y_pred_relnet_t - y_pred_relnet_nt

    metrics["relnet"].append(mean_squared_error(tau_true, tau_pred_relnet))

    # PEHE of the NW variants
    for key, models in trained_models.items():
        y_pred_treated = models[label_t].y_val_pred
        y_pred_nt = models[label_nt].y_val_pred

        tau_pred = y_pred_treated - y_pred_nt
        metrics[key].append(mean_squared_error(tau_true, tau_pred)) # PEHE

In [None]:
# Per-metric mean and standard deviation over seeds, formatted as LaTeX table cells "mean & std".
{name: "{:.2f} & {:.2f}".format(np.mean(scores), np.std(scores)) for name, scores in metrics.items()}

# X-learner
## Step 1: two base learners as in T-learner

Copilot prompt:
<blockquote>
Write the first stage of X-learner using LightGBM only. Split the data into train and validation set, then split each set into treated and non-treated. Train two separate models for treated and non-treated patients, use cross-validation. Save their predictions for training and validation sets. For each point in the training set, calculate estimated tau values (difference between outcomes when treated and when not treated). Reduce code duplication by introducing functions when necessary
</blockquote>

In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
import lightgbm as lgb

def split_by_treatment(x: np.ndarray, y: np.ndarray, treatment: np.ndarray):
    """Partition rows into the treated group (treatment == 1) and everyone else.

    Returns ``(x_treated, y_treated, x_non_treated, y_non_treated)``.
    """
    is_treated = treatment == 1
    is_control = ~is_treated
    x_t, y_t = x[is_treated], y[is_treated]
    x_c, y_c = x[is_control], y[is_control]
    return x_t, y_t, x_c, y_c

def lgb_cv_train_predict(x_train, y_train, x_pred, params, n_splits=5, random_state=0):
    """Average the predictions of ``n_splits`` LightGBM models fit on K-fold training splits.

    This is K-fold *bagging*, not cross-validated scoring. For each fold, a model is trained
    on that fold's training part only (the held-out part is not used), and its predictions
    on ``x_pred`` are averaged with equal weight. Rows of ``x_pred`` that also appear in
    ``x_train`` therefore get mostly in-sample predictions.

    Args:
        x_train: training features (array-like, one row per sample).
        y_train: training targets aligned with ``x_train``.
        x_pred: features to predict on.
        params: parameter dict passed to ``lgb.train``.
        n_splits: number of folds, i.e. number of models averaged.
        random_state: seed for the ``KFold`` shuffle.

    Returns:
        1-D NumPy array with one averaged prediction per row of ``x_pred``.
    """
    # The fold indices below are positional; `df[idx_array]` on a DataFrame would select columns.
    x_train = np.asarray(x_train)
    y_train = np.asarray(y_train)
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    preds = np.zeros(x_pred.shape[0])
    for train_idx, _ in kf.split(x_train):
        dtrain = lgb.Dataset(x_train[train_idx], label=y_train[train_idx])
        model = lgb.train(params, dtrain)
        preds += model.predict(x_pred) / n_splits
    return preds

# Prepare data
X = x.to_numpy()
y = data_y_fact.to_numpy()
treatment = data_treatment.to_numpy()

# Split row *positions* rather than the arrays themselves, so the original row indices survive:
# the stage-2 cell reads them back from the result frames' index. train_test_split applies one
# shuffle to every input, so these are exactly the rows an array-based call would have selected.
idx_train, idx_val = train_test_split(
    np.arange(len(X)), test_size=0.2, random_state=42, stratify=treatment
)
X_train, X_val = X[idx_train], X[idx_val]
y_train, y_val = y[idx_train], y[idx_val]
tr_train, tr_val = treatment[idx_train], treatment[idx_val]

# Split train and val sets into treated and non-treated
X_train_t, y_train_t, X_train_nt, y_train_nt = split_by_treatment(X_train, y_train, tr_train)
X_val_t, y_val_t, X_val_nt, y_val_nt = split_by_treatment(X_val, y_val, tr_val)

# LightGBM parameters
lgb_params = {"objective": "regression", "metric": "rmse", "verbosity": -1}

# Outcome models per arm (K-fold bagged LightGBM), predicted on every train and val row
y_hat_treated_train = lgb_cv_train_predict(X_train_t, y_train_t, X_train, lgb_params)
y_hat_nontreated_train = lgb_cv_train_predict(X_train_nt, y_train_nt, X_train, lgb_params)
y_hat_treated_val = lgb_cv_train_predict(X_train_t, y_train_t, X_val, lgb_params)
y_hat_nontreated_val = lgb_cv_train_predict(X_train_nt, y_train_nt, X_val, lgb_params)

# Save predictions for the training and validation sets
train_preds = {
    "y_hat_treated": y_hat_treated_train,
    "y_hat_nontreated": y_hat_nontreated_train,
    "treatment": tr_train,
    "y_train": y_train
}
val_preds = {
    "y_hat_treated": y_hat_treated_val,
    "y_hat_nontreated": y_hat_nontreated_val,
    "treatment": tr_val,
    "y_val": y_val
}

# Imputed effect for each training point; each arm uses the model trained on the *other* arm:
# For treated: tau_hat = y_train - y_hat_nontreated
# For non-treated: tau_hat = y_hat_treated - y_train
tau_hat_train = np.where(
    tr_train == 1,
    y_train - y_hat_nontreated_train,
    y_hat_treated_train - y_train
)
train_preds["tau_hat"] = tau_hat_train

# Frames indexed by the original row positions in `x` (used by the stage-2 cell)
xlearner_stage1_train_df = pd.DataFrame({
    "y_train": y_train,
    "treatment": tr_train,
    "y_hat_treated": y_hat_treated_train,
    "y_hat_nontreated": y_hat_nontreated_train,
    "tau_hat": tau_hat_train
}, index=idx_train)
xlearner_stage1_val_df = pd.DataFrame({
    "y_val": y_val,
    "treatment": tr_val,
    "y_hat_treated": y_hat_treated_val,
    "y_hat_nontreated": y_hat_nontreated_val
}, index=idx_val)

xlearner_stage1_train_df.head(), xlearner_stage1_val_df.head()

## Step 2: fitting $\tau$ models to the imputed training-set effects

Copilot prompt:
<blockquote>
Now write the second stage of X-learner. Just like for S-learner, split data into query, background, validation sets; train models (NW, NW without R, NW with broad features (rels as features), LightGBM, LightGBM with rels as features, relnet), but do not include the treatment indicator in the feature set and predict tau, not the factual outcome. Calculate metrics (PEHE calculated as MSE between estimated and real tau) for the validation set

instead of using generate_indices, use the fixed validation set from the previous X-learner stage. Split train into query and background for each seed
</blockquote>


In [None]:
from collections import defaultdict
from sklearn.metrics import mean_squared_error
from tqdm import tqdm
import numpy as np

import warnings

import torch

# X-learner stage 2: regress the stage-1 imputed effects (tau_hat) on covariates without the
# treatment indicator; score each method by PEHE = MSE against the true tau on the fixed
# stage-1 validation rows.
# NOTE(review): a full X-learner fits separate tau models on treated and control units and
# blends them with a propensity score; here one model per method is fit on all imputed effects.
tau_true = ihdp_data[ihdp_tau_colname].to_numpy()
tau_hat = xlearner_stage1_train_df["tau_hat"].to_numpy()

# Reuse the stage-1 split. This requires the stage-1 frames to be indexed by original row
# positions of `x`; with a default RangeIndex, train and validation would overlap.
X = x.to_numpy()
train_idx = xlearner_stage1_train_df.index.values
val_idx = xlearner_stage1_val_df.index.values
if np.intersect1d(train_idx, val_idx).size:
    warnings.warn(
        "X-learner stage-1 train/val frames share row indices; they are probably not indexed "
        "by original row positions, so the stage-2 metrics below are not meaningful."
    )

X_train, X_val = X[train_idx], X[val_idx]
y_train, y_val = tau_hat, tau_true[val_idx]

# Relation rows appended as extra feature columns; seed-independent, so built once.
x_broad = np.concatenate((X, r), axis=1)
xv_broad = x_broad[val_idx]

metrics_x = defaultdict(list)
n_train = len(X_train)
n_query, n_back = 200, 300

for seed in tqdm(range(20)):
    np.random.seed(seed)
    torch.manual_seed(seed)  # repeatable RelNet/NW fits (assumes torch's global RNG — TODO confirm)
    # Split train into query and background for this seed
    perm = np.random.permutation(n_train)
    ids_q = perm[:n_query]
    ids_b = perm[n_query:n_query + n_back]
    ids_train = np.concatenate([ids_q, ids_b])

    # Prepare features (no treatment indicator)
    xq, xb = X_train[ids_q], X_train[ids_b]
    yq, yb = y_train[ids_q], y_train[ids_b]
    xv, yv = X_val, y_val

    # NW with rel (relation matrix indexed by original row positions)
    r_q_b = r[train_idx[ids_q]][:, train_idx[ids_b]]
    r_val_train = r[val_idx][:, train_idx[ids_train]]
    _, _, _, model_rel = train_nw_arbitrary(
        x_backgnd=xb,
        y_backgnd=yb,
        x_query=xq,
        y_query=yq,
        x_val=xv,
        y_val=yv,
        r_query_backgnd=r_q_b,
        r_val_nonval=r_val_train,
        cfg=model_cfg,
        lr=lr,
        n_epochs=n_epochs,
    )
    metrics_x["rel"].append(mean_squared_error(yv, model_rel.y_val_pred))

    # NW without rel
    _, _, _, model_nrel = train_nw_arbitrary(
        x_backgnd=xb,
        y_backgnd=yb,
        x_query=xq,
        y_query=yq,
        x_val=xv,
        y_val=yv,
        r_query_backgnd=np.zeros_like(r_q_b),
        r_val_nonval=np.zeros_like(r_val_train),
        cfg=model_cfg,
        lr=lr,
        n_epochs=n_epochs,
    )
    metrics_x["nrel"].append(mean_squared_error(yv, model_nrel.y_val_pred))

    # LightGBM
    x_train = np.concatenate([xq, xb])
    y_train_lgb = np.concatenate([yq, yb])
    lgb_model = lgb.train(lgb_params, lgb.Dataset(x_train, label=y_train_lgb))
    y_pred_lgb = lgb_model.predict(xv)
    metrics_x["lgb"].append(mean_squared_error(yv, y_pred_lgb))

    # Broad (covariates + relation row) features for this split
    x_train_broad = x_broad[train_idx[ids_train]]
    xq_broad, xb_broad = x_broad[train_idx[ids_q]], x_broad[train_idx[ids_b]]

    # NW with rel as features
    _, _, _, model_relfts = train_nw_arbitrary(
        x_backgnd=xb_broad,
        y_backgnd=yb,
        x_query=xq_broad,
        y_query=yq,
        x_val=xv_broad,
        y_val=yv,
        r_query_backgnd=np.zeros((len(xq_broad), len(xb_broad))),
        r_val_nonval=np.zeros((len(xv_broad), len(xb_broad) + len(xq_broad))),
        cfg=model_cfg,
        lr=lr,
        n_epochs=n_epochs,
    )
    metrics_x["rel-fts"].append(mean_squared_error(yv, model_relfts.y_val_pred))

    # LightGBM with rel as features
    lgb_model_rel = lgb.train(lgb_params, lgb.Dataset(x_train_broad, label=y_train_lgb))
    y_pred_lgb_rel = lgb_model_rel.predict(xv_broad)
    metrics_x["lgb-rel"].append(mean_squared_error(yv, y_pred_lgb_rel))

    # RelNet (TabRel). Its index arguments point into `x`, so the validation rows must be part
    # of it: stack background | query | validation (as in the T-learner cell) and pass the
    # matching relation sub-matrix. The first return value is recorded as the score.
    i_bqv = np.concatenate([train_idx[ids_b], train_idx[ids_q], val_idx])
    x_bqv = np.concatenate([xb, xq, xv])
    y_bqv = np.concatenate([yb, yq, yv])
    n_b, n_q = len(ids_b), len(ids_q)
    relnet_pehe, _, _, _, _ = train_relnet(
        x=x_bqv,
        y=y_bqv,
        r=r[i_bqv][:, i_bqv],
        backgnd_indices=np.arange(n_b),
        query_indices=np.arange(n_b, n_b + n_q),
        val_indices=np.arange(n_b + n_q, len(x_bqv)),
        lr=0.01,
        n_epochs=800,
        n_layers=2,
        periodic_embed_dim=None,
        num_heads=2,
        progress_bar=False,
    )
    metrics_x["relnet"].append(relnet_pehe)

# Show mean metrics
{ k: round(np.mean(v), 2) for k, v in metrics_x.items() }