In [None]:
%load_ext autoreload
%autoreload 2
import os
from pathlib import Path

import numpy as np
import pandas as pd

def load_ihdp_data(ihdp_path: Path) -> tuple[pd.DataFrame, list[str], str]:
    """Load one IHDP table and attach the true individual treatment effect.

    Args:
        ihdp_path: Directory that contains ``columns.txt`` and a ``csv/``
            sub-directory with header-less ``*.csv`` tables.

    Returns:
        A tuple ``(data, exclude_cols, tau_col_name)``:
        ``data`` is the table with named columns plus the effect column,
        ``exclude_cols`` lists the treatment/outcome columns that must not be
        used as covariates, and ``tau_col_name`` is the name of the effect column.

    Raises:
        FileNotFoundError: If ``csv/`` contains no ``*.csv`` file.
    """
    # The last character of every entry in columns.txt is cut off, the last two
    # entries are discarded and replaced by the generated names x2..x25.
    ihdp_cols = [s[:-1] for s in np.loadtxt(ihdp_path / "columns.txt", dtype=str)][:-2]
    ihdp_cols.extend([f"x{i}" for i in range(2, 26)])

    # glob() yields files in file-system order; sort so that "the first table"
    # is the same file on every machine and every run.
    csv_paths = sorted((ihdp_path / "csv").glob("*.csv"))
    if not csv_paths:
        raise FileNotFoundError(f"No IHDP tables (*.csv) found in {ihdp_path / 'csv'}")
    # TODO choose a table, for now using the first table
    data = pd.read_csv(csv_paths[0], header=None)
    data.columns = ihdp_cols

    # The sign flips for treatment == 1, so (assuming treatment is coded 0/1)
    # delta_y is always "outcome under treatment minus outcome under control".
    tau_col_name = "delta_y"
    data[tau_col_name] = (data["y_cfactual"] - data["y_factual"]) * (-1) ** data["treatment"]
    exclude_cols = ["treatment", "y_cfactual", "y_factual", "mu0", "mu1"]
    return data, exclude_cols, tau_col_name

# Dataset location: set the IHDP_DATA_DIR environment variable to point at a
# local copy of CEVAE/datasets/IHDP; the fallback is the original author's checkout.
IHDP_DATA_DIR = Path(
    os.environ.get(
        "IHDP_DATA_DIR",
        "/Users/vzuev/Documents/git/gh_zuevval/tabrel/CEVAE/datasets/IHDP",
    )
)
ihdp_data, ihdp_exclude_cols, ihdp_tau_colname = load_ihdp_data(IHDP_DATA_DIR)
ihdp_data.head()

In [None]:
from typing import Final

# Covariates only: strip the treatment/outcome columns and the effect column.
x_all = ihdp_data.drop(columns=[*ihdp_exclude_cols, ihdp_tau_colname])

# The leading covariate columns are treated as numeric, the rest as categorical.
ihdp_last_numeric_index: Final[int] = 6
x_numeric, x_cat = (
    x_all.iloc[:, :ihdp_last_numeric_index],
    x_all.iloc[:, ihdp_last_numeric_index:],
)
x_numeric

In [None]:
import seaborn as sns

# Pair plot of the numeric covariates, coloured by the individual treatment effect.
x_num_y = x_numeric.assign(**{ihdp_tau_colname: ihdp_data[ihdp_tau_colname]})
sns.pairplot(x_num_y, hue=ihdp_tau_colname)

In [None]:
from itertools import product
from tqdm import tqdm

# The grouping column defines the relations between samples and is therefore
# removed from the feature matrix `x`.
group_col: Final[str] = "x4"
x = x_all.drop(columns=[group_col])
x_len = len(x)
categories = x_all[group_col]
print("n_categories", len(categories.unique()))

# r[i, j] == 1.0 when rows i and j have (numerically) the same value of the
# grouping column, else 0.0. The broadcast comparison evaluates
# np.isclose(categories[i], categories[j]) for every pair in one vectorised call
# and works on row positions, so it does not depend on the frame's index labels.
category_values = categories.to_numpy()
r = np.isclose(category_values[:, None], category_values[None, :]).astype(float)

r

# S-learner

In [None]:
from collections import defaultdict
from tabrel.benchmark.nw_regr import run_training, metrics_mean, train_nw_arbitrary, NwModelConfig
from tabrel.train import train_relnet
from sklearn.metrics import mean_squared_error
import lightgbm as lgb

# Names of the compared estimators. The last cell of the notebook keeps one
# score list per entry (metrics["pehe"][i]) and passes this list to `metrics_mean`.
# NOTE(review): presumably "rel"/"nrel" = NW with/without relations, "lgb" = LightGBM,
# "rel-fts"/"lgb-rel" = relations appended as features — confirm against `run_training`.
labels = ["rel", "nrel", "lgb", "rel-fts", "lgb-rel"]

def generate_indices(
    seed: int, n_query: int = 200, n_train: int = 300
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split the row positions ``0..x_len-1`` into query, background and validation sets.

    The rows are shuffled with the global NumPy RNG seeded by ``seed`` (the global
    RNG state is therefore reset as a side effect), then cut into three
    consecutive, disjoint slices.

    Args:
        seed: Seed for ``np.random.seed``; the same seed always gives the same split.
        n_query: Number of query rows (first slice).
        n_train: END offset of the background slice, i.e. the total number of
            query + background rows. With the defaults the background set has
            ``300 - 200 = 100`` rows, not 300.
            NOTE(review): the original local name was ``n_back``; if 300 background
            rows were intended, call with ``n_train=n_query + 300``.

    Returns:
        ``(query_indices, background_indices, validation_indices)``; validation
        receives every row after position ``n_train`` of the permutation.

    Raises:
        ValueError: If the sizes are negative, out of order, or exceed ``x_len``.
    """
    if not 0 <= n_query <= n_train <= x_len:
        raise ValueError(
            f"Need 0 <= n_query <= n_train <= x_len, got "
            f"n_query={n_query}, n_train={n_train}, x_len={x_len}"
        )
    np.random.seed(seed)
    indices = np.random.permutation(x_len)
    return indices[:n_query], indices[n_query:n_train], indices[n_train:]

def split_treated_non_treated(x: pd.DataFrame, treatment: np.ndarray, y_fact: pd.DataFrame) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Partition features and factual outcomes by treatment status.

    Rows where ``treatment == 1`` form the treated group; all remaining rows form
    the non-treated group. Row order inside each group follows the input order.

    Returns:
        ``(x_treated, y_treated, x_non_treated, y_non_treated)`` as NumPy arrays.
    """
    is_treated = treatment == 1
    is_control = ~is_treated
    return (
        x.loc[is_treated].to_numpy(),
        y_fact.loc[is_treated].to_numpy(),
        x.loc[is_control].to_numpy(),
        y_fact.loc[is_control].to_numpy(),
    )

y_fact_colname, y_cfact_colname = "y_factual", "y_cfactual"
data_y_fact, data_treatment = ihdp_data[y_fact_colname], ihdp_data["treatment"]

# S-learner design matrix: the covariates plus the treatment indicator as one
# more feature. `assign` returns a NEW frame; a plain `x_s = x` would only alias
# `x`, so adding the column would also put "treatment" into the covariates that
# the T-/X-learner cells read from `x`.
x_s = x.assign(treatment=data_treatment)  # x for S-learner
x_s_np = x_s.to_numpy()
y_s = data_y_fact.to_numpy()  # Y for S-learner

# Shared hyper-parameters of the NW models and of LightGBM.
model_cfg, lr, n_epochs = NwModelConfig(), 1e-3, 50
lgb_params: Final[dict[str, str | int]] = {"objective": "regression", "metric": "rmse", "verbosity": -1}

In [None]:
# S-learner benchmark: for 20 seeded splits, fit every estimator on the
# query + background rows (treatment is one of the features of x_s_np) and
# record its score on the validation rows under its own key in `metrics`.
# The targets are the factual outcomes (y_s), so the "rel", "nrel", "lgb",
# "rel-fts" and "lgb-rel" entries are MSEs of the factual-outcome prediction.
metrics = defaultdict(list)
for seed in tqdm(range(20)):
    np.random.seed(seed)
    ids_q, ids_b, ids_v = generate_indices(seed)
    # Columns of the val-vs-train relation matrix are ordered query first, then background.
    # NOTE(review): confirm that train_nw_arbitrary expects this column order for r_val_nonval.
    ids_train = np.concatenate((ids_q, ids_b))
    xb, yb, xq, yq = x_s_np[ids_b], y_s[ids_b], x_s_np[ids_q], y_s[ids_q]
    xv, yv = x_s_np[ids_v], y_s[ids_v]
    r_q_b = r[ids_q][:, ids_b]
    r_val_train = r[ids_v][:, ids_train]

    # TabRel: gets the full data and relation matrix plus the three index sets.
    # Only the first returned value is kept.
    # NOTE(review): the name suggests PEHE, but the target passed in is the factual outcome — confirm.
    relnet_pehe, _, _ = train_relnet(
        x=x_s_np,
        y=y_s,
        r=r,
        backgnd_indices=ids_b,
        query_indices=ids_q,
        val_indices=ids_v,
        lr=0.01,
        n_epochs=800,
        n_layers=2,
        periodic_embed_dim=None,
        num_heads=2,
        progress_bar=False,
    )
    metrics["relnet"].append(relnet_pehe)

    # NW with rel: the true relation matrices are passed to the model
    _, _, _, model_rel = train_nw_arbitrary(
        x_backgnd=xb,
        y_backgnd=yb,
        x_query=xq,
        y_query=yq,
        x_val=xv,
        y_val=yv,
        r_query_backgnd=r_q_b,
        r_val_nonval=r_val_train,
        cfg=model_cfg,
        lr=lr,
        n_epochs=n_epochs,
    )
    metrics["rel"].append(mean_squared_error(yv, model_rel.y_val_pred))

    # NW without rel: same call, but the relation matrices are replaced by zeros
    _, _, _, model_nrel = train_nw_arbitrary(
        x_backgnd=xb,
        y_backgnd=yb,
        x_query=xq,
        y_query=yq,
        x_val=xv,
        y_val=yv,
        r_query_backgnd=np.zeros_like(r_q_b),
        r_val_nonval=np.zeros_like(r_val_train),
        cfg=model_cfg,
        lr=lr,
        n_epochs=n_epochs,
    )
    metrics["nrel"].append(mean_squared_error(yv, model_nrel.y_val_pred))

    # LightGBM: trained on query + background rows, evaluated on the validation rows
    x_train = np.concatenate([xq, xb])
    y_train = np.concatenate([yq, yb])
    lgb_model = lgb.train(lgb_params, lgb.Dataset(x_train, label=y_train))
    y_pred_lgb = lgb_model.predict(xv)
    metrics["lgb"].append(mean_squared_error(yv, y_pred_lgb))

    # NW and LGB with rel as features: every row is extended with its full row of r,
    # i.e. its relations to ALL samples (validation rows included).
    # The matrix does not depend on the seed but is rebuilt on every iteration.
    x_broad = np.concatenate((x_s_np, r), axis=1)
    xb_broad, xq_broad, xv_broad = x_broad[ids_b], x_broad[ids_q], x_broad[ids_v]

    # NW with rel as features: relations live in the features, so the relation inputs are zeros
    _, _, _, model_relfts = train_nw_arbitrary(
        x_backgnd=xb_broad,
        y_backgnd=yb,
        x_query=xq_broad,
        y_query=yq,
        x_val=xv_broad,
        y_val=yv,
        r_query_backgnd=np.zeros((len(xq_broad), len(xb_broad))),
        r_val_nonval=np.zeros((len(xv_broad), len(xb_broad) + len(xq_broad))),
        cfg=model_cfg,
        lr=lr,
        n_epochs=n_epochs,
    )
    metrics["rel-fts"].append(mean_squared_error(yv, model_relfts.y_val_pred))

    # LightGBM with rel as features (row order query, background matches y_train)
    x_train_broad = np.concatenate([xq_broad, xb_broad])
    lgb_model_rel = lgb.train(lgb_params, lgb.Dataset(x_train_broad, label=y_train))
    y_pred_lgb_rel = lgb_model_rel.predict(xv_broad)
    metrics["lgb-rel"].append(mean_squared_error(yv, y_pred_lgb_rel))

# Show mean metrics: average over the seeds, rounded to two decimals
{ k: round(np.mean(v), 2) for k, v in metrics.items() }

# T-learner

In [None]:
# T-learner benchmark: for 20 seeded splits, fit one outcome model on the treated
# rows and one on the non-treated rows, predict both potential outcomes on the
# validation rows and score the predicted effect against the true effect
# (mean squared error of tau, stored per estimator in `metrics`).
label_t: Final[str] = "treated"
label_nt: Final[str] = "non-treated"
nw_broad_key: Final[str] = "nw_rel-as-features"

# Loop-invariant inputs, built once instead of once per seed.
data_y_cfact = ihdp_data[y_cfact_colname]
# Covariates extended with each row's relations to all rows.
x_broad = np.concatenate((x.to_numpy(), r), axis=1)

metrics = defaultdict(list)
for seed in tqdm(range(20)):
    np.random.seed(seed)
    query_indices, back_indices, val_indices = generate_indices(seed)

    # Positional indexing everywhere, so frames and series below share row labels.
    xq, xb, xv = x.iloc[query_indices], x.iloc[back_indices], x.iloc[val_indices]
    yq, yb = data_y_fact.iloc[query_indices], data_y_fact.iloc[back_indices]
    tq, tb = data_treatment.iloc[query_indices], data_treatment.iloc[back_indices]
    xv_np = xv.to_numpy()

    xt_q, yt_q, xnt_q, ynt_q = split_treated_non_treated(xq, tq, yq)
    xt_b, yt_b, xnt_b, ynt_b = split_treated_non_treated(xb, tb, yb)

    # Row positions of the treated / non-treated members of each subset, in the
    # same order as the arrays returned by the split above. Boolean masking keeps
    # an integer dtype even when a group is empty.
    tq_np, tb_np = tq.to_numpy(), tb.to_numpy()
    iq_t, iq_nt = query_indices[tq_np == 1], query_indices[tq_np == 0]
    ib_t, ib_nt = back_indices[tb_np == 1], back_indices[tb_np == 0]
    i_train_t = np.concatenate([iq_t, ib_t])
    i_train_nt = np.concatenate([iq_nt, ib_nt])

    # Both potential outcomes of every validation row: the factual outcome for
    # the arm the row actually received, the counterfactual for the other arm.
    tv_np = data_treatment.to_numpy()[val_indices]
    yv_fact = data_y_fact.to_numpy()[val_indices]
    yv_cfact = data_y_cfact.to_numpy()[val_indices]
    yv_t = np.where(tv_np == 1, yv_fact, yv_cfact)
    yv_nt = np.where(tv_np == 0, yv_fact, yv_cfact)
    # Ground-truth effect; defined here because every scoring step below uses it.
    tau_true = yv_t - yv_nt

    r_q_b_treated = r[iq_t][:, ib_t]
    r_q_b_nt = r[iq_nt][:, ib_nt]
    r_val_nvt = r[val_indices][:, i_train_t]  # rel between val and treated train
    r_val_nvnt = r[val_indices][:, i_train_nt]  # rel between val and non-treated train

    trained_models = {
        "rel=True": {},
        "rel=False": {},
        nw_broad_key: {},
    }

    # NW with / without relations, one model per treatment arm.
    for rel, (xqi, xbi, yqi, ybi, yvi, r_q_b, r_v_nvi, label) in product(
        (True, False),
        (
            (xt_q, xt_b, yt_q, yt_b, yv_t, r_q_b_treated, r_val_nvt, label_t),
            (xnt_q, xnt_b, ynt_q, ynt_b, yv_nt, r_q_b_nt, r_val_nvnt, label_nt),
        ),
    ):
        _, _, _, model = train_nw_arbitrary(
            x_backgnd=xbi,
            y_backgnd=ybi,
            x_query=xqi,
            y_query=yqi,
            x_val=xv_np,
            y_val=yvi,
            r_query_backgnd=r_q_b if rel else np.zeros_like(r_q_b),
            r_val_nonval=r_v_nvi if rel else np.zeros_like(r_v_nvi),
            cfg=model_cfg,
            lr=lr,
            n_epochs=n_epochs,
        )
        trained_models[f"rel={rel}"][label] = model

    # NW with relations as features: relations live in the inputs, so the
    # relation matrices handed to the model are all zeros.
    xt_q_broad, xnt_q_broad = x_broad[iq_t], x_broad[iq_nt]
    xt_b_broad, xnt_b_broad = x_broad[ib_t], x_broad[ib_nt]
    xv_broad = x_broad[val_indices]

    for xqi, xbi, yqi, ybi, yvi, label in (
        (xt_q_broad, xt_b_broad, yt_q, yt_b, yv_t, label_t),
        (xnt_q_broad, xnt_b_broad, ynt_q, ynt_b, yv_nt, label_nt),
    ):
        trained_models[nw_broad_key][label] = train_nw_arbitrary(
            x_backgnd=xbi,
            y_backgnd=ybi,
            x_query=xqi,
            y_query=yqi,
            x_val=xv_broad,
            y_val=yvi,
            r_query_backgnd=np.zeros((len(xqi), len(xbi))),
            r_val_nonval=np.zeros((len(xv_broad), len(xbi) + len(xqi))),
            cfg=model_cfg,
            lr=lr,
            n_epochs=n_epochs,
        )[-1]

    # LightGBM, plain features and relations-as-features.
    yt_train, ynt_train = np.concatenate([yt_q, yt_b]), np.concatenate([ynt_q, ynt_b])
    for xq_ti, xb_ti, xq_nti, xb_nti, xv_i, lgb_key in (
        (xt_q, xt_b, xnt_q, xnt_b, xv_np, "lgb"),
        (xt_q_broad, xt_b_broad, xnt_q_broad, xnt_b_broad, xv_broad, "lgb-rel"),
    ):
        xt_train, xnt_train = np.concatenate([xq_ti, xb_ti]), np.concatenate([xq_nti, xb_nti])
        lgb_model_t = lgb.train(lgb_params, lgb.Dataset(xt_train, label=yt_train))
        lgb_model_nt = lgb.train(lgb_params, lgb.Dataset(xnt_train, label=ynt_train))
        tau_lgb = lgb_model_t.predict(xv_i) - lgb_model_nt.predict(xv_i)
        metrics[lgb_key].append(mean_squared_error(tau_true, tau_lgb))

    # NW estimators: effect = treated-arm prediction minus non-treated-arm prediction.
    for key, models in trained_models.items():
        y_pred_treated = models[label_t].y_val_pred
        y_pred_nt = models[label_nt].y_val_pred

        tau_pred = y_pred_treated - y_pred_nt
        metrics[key].append(mean_squared_error(tau_true, tau_pred))  # PEHE (squared form)

In [None]:
# Mean score of every estimator across the seeds, rounded to two decimals.
{name: round(np.mean(scores), 2) for name, scores in metrics.items()}

# X-learner
## Step 1: two base learners as in T-learner

Copilot prompt:
<blockquote>
Write the code for the first stage of X-learners.

1. split data to xt_q, yt_q, xnt_q, ynt_q, ..., yv_t, yv_nt, just like for T-learner
2. create x_broad by adding r and split to xt_q_broad, ...
3. train estimators: nw rel, nw nonrel, nw nonrel for broad, lightgbm for broad
4. Using each estimators, compute y treated for xnt_q, xnt_b, and y non-treated for xt_q, xt_b
5. save estimated values and val indices for future analysis
</blockquote>

The generated code was modified afterwards.


In [None]:
import torch

# X-learner, stage 1: per seed, fit treated / non-treated outcome models
# (NW with relations, NW without relations, NW and LightGBM with relations as
# features, plain LightGBM) and use each model to predict the counterfactual
# outcome of the training rows of the OTHER arm. Everything is collected in
# `xlearner_stage1_results` (one dict per seed) for the next stage.
xlearner_stage1_results = []

for seed in tqdm(range(20)):
    np.random.seed(seed)
    query_indices, back_indices, val_indices = generate_indices(seed)
    y_np = data_y_fact.to_numpy()
    xq, xb, xv = x.iloc[query_indices], x.iloc[back_indices], x.iloc[val_indices]
    yq, yb = data_y_fact.iloc[query_indices], data_y_fact.iloc[back_indices]
    tq, tb = data_treatment[query_indices], data_treatment[back_indices]

    # Split by treatment
    xt_q, yt_q, xnt_q, ynt_q = split_treated_non_treated(xq, tq, yq)
    xt_b, yt_b, xnt_b, ynt_b = split_treated_non_treated(xb, tb, yb)

    # Row positions of the treated / non-treated members of each subset
    iq_t = np.array([i for i in query_indices if data_treatment[i] == 1])
    iq_nt = np.array([i for i in query_indices if data_treatment[i] == 0])
    ib_t = np.array([i for i in back_indices if data_treatment[i] == 1])
    ib_nt = np.array([i for i in back_indices if data_treatment[i] == 0])

    # Both potential outcomes of the validation rows (factual for the received arm,
    # counterfactual for the other one)
    data_y_cfact = ihdp_data[y_cfact_colname]
    yv_t = np.array([data_y_fact[i] if data_treatment[i] == 1 else data_y_cfact[i] for i in val_indices])
    yv_nt = np.array([data_y_fact[i] if data_treatment[i] == 0 else data_y_cfact[i] for i in val_indices])

    # Add r as features
    x_broad = np.concatenate((x.to_numpy(), r), axis=1)
    xb_broad, xq_broad, xv_broad = x_broad[back_indices], x_broad[query_indices], x_broad[val_indices]
    xt_q_broad, xnt_q_broad = x_broad[iq_t], x_broad[iq_nt]
    xt_b_broad, xnt_b_broad = x_broad[ib_t], x_broad[ib_nt]

    # Train NW (rel) on treated
    _, _, _, nw_treated_rel = train_nw_arbitrary(
        x_backgnd=xt_b, y_backgnd=yt_b,
        x_query=xt_q, y_query=yt_q,
        x_val=xv.to_numpy(), y_val=yv_t,
        r_query_backgnd=r[iq_t][:, ib_t],
        r_val_nonval=r[val_indices][:, np.concatenate([iq_t, ib_t])],
        cfg=model_cfg, lr=lr, n_epochs=n_epochs,
    )
    # Train NW (rel) on non-treated
    _, _, _, nw_nontreated_rel = train_nw_arbitrary(
        x_backgnd=xnt_b, y_backgnd=ynt_b,
        x_query=xnt_q, y_query=ynt_q,
        x_val=xv.to_numpy(), y_val=yv_nt,
        r_query_backgnd=r[iq_nt][:, ib_nt],
        r_val_nonval=r[val_indices][:, np.concatenate([iq_nt, ib_nt])],
        cfg=model_cfg, lr=lr, n_epochs=n_epochs,
    )
    # Train NW (nonrel) on treated
    _, _, _, nw_treated_norel = train_nw_arbitrary(
        x_backgnd=xt_b, y_backgnd=yt_b,
        x_query=xt_q, y_query=yt_q,
        x_val=xv.to_numpy(), y_val=yv_t,
        r_query_backgnd=np.zeros_like(r[iq_t][:, ib_t]),
        r_val_nonval=np.zeros_like(r[val_indices][:, np.concatenate([iq_t, ib_t])]),
        cfg=model_cfg, lr=lr, n_epochs=n_epochs,
    )
    # Train NW (nonrel) on non-treated
    _, _, _, nw_nontreated_norel = train_nw_arbitrary(
        x_backgnd=xnt_b, y_backgnd=ynt_b,
        x_query=xnt_q, y_query=ynt_q,
        x_val=xv.to_numpy(), y_val=yv_nt,
        r_query_backgnd=np.zeros_like(r[iq_nt][:, ib_nt]),
        r_val_nonval=np.zeros_like(r[val_indices][:, np.concatenate([iq_nt, ib_nt])]),
        cfg=model_cfg, lr=lr, n_epochs=n_epochs,
    )
    # NW (nonrel) on broad features, treated
    _, _, _, nw_treated_broad = train_nw_arbitrary(
        x_backgnd=xt_b_broad, y_backgnd=yt_b,
        x_query=xt_q_broad, y_query=yt_q,
        x_val=xv_broad, y_val=yv_t,
        r_query_backgnd=np.zeros((len(xt_q_broad), len(xt_b_broad))),
        r_val_nonval=np.zeros((len(xv_broad), len(xt_b_broad) + len(xt_q_broad))),
        cfg=model_cfg, lr=lr, n_epochs=n_epochs,
    )
    # NW (nonrel) on broad features, non-treated
    _, _, _, nw_nontreated_broad = train_nw_arbitrary(
        x_backgnd=xnt_b_broad, y_backgnd=ynt_b,
        x_query=xnt_q_broad, y_query=ynt_q,
        x_val=xv_broad, y_val=yv_nt,
        r_query_backgnd=np.zeros((len(xnt_q_broad), len(xnt_b_broad))),
        r_val_nonval=np.zeros((len(xv_broad), len(xnt_b_broad) + len(xnt_q_broad))),
        cfg=model_cfg, lr=lr, n_epochs=n_epochs,
    )

    # LightGBM (training matrices are built here, while the splits are still NumPy arrays)
    xt_train = np.concatenate([xt_q, xt_b])
    yt_train = np.concatenate([yt_q, yt_b])
    lgb_model_t = lgb.train(lgb_params, lgb.Dataset(xt_train, label=yt_train))

    xnt_train = np.concatenate([xnt_q, xnt_b])
    ynt_train = np.concatenate([ynt_q, ynt_b])
    lgb_model_nt = lgb.train(lgb_params, lgb.Dataset(xnt_train, label=ynt_train))

    # LightGBM on broad features, treated
    xt_train_broad = np.concatenate([xt_q_broad, xt_b_broad])
    lgb_model_t_broad = lgb.train(lgb_params, lgb.Dataset(xt_train_broad, label=yt_train))
    # LightGBM on broad features, non-treated
    xnt_train_broad = np.concatenate([xnt_q_broad, xnt_b_broad])
    lgb_model_nt_broad = lgb.train(lgb_params, lgb.Dataset(xnt_train_broad, label=ynt_train))

    # 4. Predict counterfactuals for training sets
    # For non-treated: predict y_treated for non-treated patients.
    # From here on the non-treated splits are torch tensors (the NumPy versions are overwritten).
    # NOTE(review): the treated-arm NW models are called with the NON-treated
    # background rows/outcomes (xnt_b, ynt_b). If the model forms its prediction
    # from the background outcomes it receives, this is not a treated-outcome
    # estimate — confirm the call convention of `.model(...)`.
    # NOTE(review): the NW outputs are stored as returned by the model; wrap the
    # calls in torch.no_grad() / detach if they carry autograd history — confirm.
    xnt_b, xnt_q = torch.tensor(xnt_b, dtype=torch.float32), torch.tensor(xnt_q, dtype=torch.float32)
    xnt_b_broad, xnt_q_broad = torch.tensor(xnt_b_broad, dtype=torch.float32), torch.tensor(xnt_q_broad, dtype=torch.float32)
    ynt_b = torch.tensor(ynt_b, dtype=torch.float32)
    rqb_nt = torch.tensor(r[iq_nt][:, ib_nt], dtype=torch.float32)
    y_treated_on_xnt_q_nwrel = nw_treated_rel.model(xnt_b, ynt_b, xnt_q, rqb_nt)
    y_treated_on_xnt_q_nwnorel = nw_treated_norel.model(xnt_b, ynt_b, xnt_q, torch.zeros_like(rqb_nt))
    y_treated_on_xnt_q_broad = nw_treated_broad.model(xnt_b_broad, ynt_b, xnt_q_broad, torch.zeros_like(rqb_nt))
    # TODO: NW predictions for the non-treated BACKGROUND rows (y_treated_on_xnt_b_*) are not computed yet
    y_treated_on_xnt_train_lgb = lgb_model_t.predict(xnt_train)
    y_treated_on_xnt_train_lgb_broad = lgb_model_t_broad.predict(xnt_train_broad)

    # For treated: predict y_non-treated for xt_q, xt_b (same review notes as above apply)
    xt_q, xt_b = torch.tensor(xt_q, dtype=torch.float32), torch.tensor(xt_b, dtype=torch.float32)
    xt_b_broad, xt_q_broad = torch.tensor(xt_b_broad, dtype=torch.float32), torch.tensor(xt_q_broad, dtype=torch.float32)
    yt_b = torch.tensor(yt_b, dtype=torch.float32)
    rqb_t = torch.tensor(r[iq_t][:, ib_t], dtype=torch.float32)
    y_nontreated_on_xt_q_nwrel = nw_nontreated_rel.model(xt_b, yt_b, xt_q, rqb_t)
    y_nontreated_on_xt_q_nwnorel = nw_nontreated_norel.model(xt_b, yt_b, xt_q, torch.zeros_like(rqb_t))
    y_nontreated_on_xt_q_broad = nw_nontreated_broad.model(xt_b_broad, yt_b, xt_q_broad, torch.zeros_like(rqb_t))
    # TODO: NW predictions for the treated BACKGROUND rows (y_nontreated_on_xt_b_*) are not computed yet
    # The counterfactual of a treated row is its NON-treated outcome, so it must
    # come from the model fitted on the non-treated rows (mirrors the broad variant below).
    y_nontreated_on_xt_train_lgb = lgb_model_nt.predict(xt_train)
    y_nontreated_on_xt_train_lgb_broad = lgb_model_nt_broad.predict(xt_train_broad)

    # Save all results for future analysis.
    # Note: at this point yt_b / ynt_b are torch tensors while yt_q / ynt_q are NumPy arrays.
    xlearner_stage1_results.append({
        "seed": seed,
        "indices": {
            "iq_t": iq_t, "iq_nt": iq_nt, "ib_t": ib_t, "ib_nt": ib_nt,
            "query_indices": query_indices, "back_indices": back_indices, "val_indices": val_indices
        },
        "factual": {
            "yt_q": yt_q, "yt_b": yt_b, "ynt_q": ynt_q, "ynt_b": ynt_b
        },
        "cfactual_preds": {
            "y_treated_on_xnt_q_nwrel": y_treated_on_xnt_q_nwrel,
            "y_treated_on_xnt_q_nwnorel": y_treated_on_xnt_q_nwnorel,
            "y_treated_on_xnt_q_broad": y_treated_on_xnt_q_broad,
            "y_treated_on_xnt_train_lgb": y_treated_on_xnt_train_lgb,
            "y_treated_on_xnt_train_lgb_broad": y_treated_on_xnt_train_lgb_broad,
            "y_nontreated_on_xt_q_nwrel": y_nontreated_on_xt_q_nwrel,
            "y_nontreated_on_xt_q_nwnorel": y_nontreated_on_xt_q_nwnorel,
            "y_nontreated_on_xt_q_broad": y_nontreated_on_xt_q_broad,
            "y_nontreated_on_xt_train_lgb": y_nontreated_on_xt_train_lgb,
            "y_nontreated_on_xt_train_lgb_broad": y_nontreated_on_xt_train_lgb_broad,
        }
    })

## Step 2: $\tau$ imputation for the training set

In [None]:
# Step 2 benchmark: train every estimator on the (currently pre-computed) true
# effect and collect its validation score per seed.
# A dedicated container is created here: the `metrics` left over from the
# T-learner cell is a defaultdict(list), so metrics["pehe"][i] would raise
# IndexError on a fresh kernel. One score list is kept per entry of `labels`.
metrics = {"pehe": [[] for _ in labels]}

for seed in tqdm(range(10)):
    query_indices, back_indices, val_indices = generate_indices(seed)

    res = run_training(
        x=x.to_numpy(),
        y=ihdp_data[ihdp_tau_colname].to_numpy(),  # TODO generate this in the previous stage, not use pre-computed
        r=r,
        backgnd_indices=back_indices,
        query_indices=query_indices,
        val_indices=val_indices,
        lr=1e-4,
        n_epochs=10,
        rel_as_feats=r,
    )
    # NOTE(review): assumes `res` holds one entry per label, in the order of
    # `labels`, and that the first element of each entry is the MSE — confirm
    # against `run_training`.
    for i, v in enumerate(res.values()):
        mse = v[0]
        metrics["pehe"][i].append(mse)


# Convert the per-seed score lists to arrays before aggregating.
for i in range(len(labels)):
    metrics["pehe"][i] = np.array(metrics["pehe"][i])
metrics_mean(metrics, labels)