In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import numpy as np
import pandas as pd
from tabrel.utils.treatment import load_ihdp_data, generate_indices


# Load the IHDP dataset. `load_ihdp_data` returns six values; only the data frame,
# the list of columns to exclude and the treatment-effect column name are kept.
ihdp_dir = Path("../CEVAE/datasets/IHDP")
ihdp_data, ihdp_exclude_cols, ihdp_tau_colname, _, _, _ = load_ihdp_data(ihdp_dir)
ihdp_data.head()

In [None]:
from typing import Final
# Covariate frame: everything except the excluded columns and the treatment-effect column.
non_feature_cols = ihdp_exclude_cols + [ihdp_tau_colname]
x_all = ihdp_data.drop(columns=non_feature_cols)

# The leading columns are taken as the numeric covariates
# (presumably the first six IHDP features are continuous - TODO confirm).
ihdp_last_numeric_index: Final[int] = 6
x_numeric = x_all.iloc[:, 0:ihdp_last_numeric_index]
# x_numeric

In [None]:
import seaborn as sns

# Numeric covariates plus the treatment-effect column, on a fresh copy of `x_numeric`.
x_num_y = x_numeric.assign(**{ihdp_tau_colname: ihdp_data[ihdp_tau_colname]})
# sns.pairplot(x_num_y, hue=ihdp_tau_colname)

In [None]:
# Pairwise correlations between the numeric covariates and the treatment-effect column.
x_num_y.corr()

In [None]:
from itertools import product
from tqdm import tqdm
import hdbscan
from sklearn.preprocessing import StandardScaler

# Cluster the samples on a hand-picked subset of standardised covariates; the cluster
# label of each sample is then used as its relational "category".
cluster_feature_cols = ["x1", "x2", "x4", "x6", "x14", "x16", "x18"]
scaler = StandardScaler()
x_scaled = scaler.fit_transform(x_all[cluster_feature_cols])
clusterer = hdbscan.HDBSCAN(min_cluster_size=30)
clusters = clusterer.fit_predict(x_scaled)

# group_col: Final[str] = "x4"
# x = x_all.drop(columns=[group_col])
x_len = x_all.shape[0]
# categories = x_all[group_col].to_numpy()
# NOTE(review): HDBSCAN labels noise points as -1, so all noise samples end up
# sharing one category - confirm this is intended.
categories = clusters
print("n_categories", np.unique(categories).size)

def make_r(cats_: np.ndarray, progress_bar: bool) -> np.ndarray:
    """Build the dense pairwise "same category" relation matrix.

    Args:
        cats_: 1-D sequence of numeric category labels, one per sample.
        progress_bar: Kept for backward compatibility with existing callers. The
            matrix is now computed in a single vectorised step, so there is no
            per-pair loop left to report progress on and the flag has no effect.

    Returns:
        A float64 array of shape ``(n, n)`` where entry ``[i, j]`` is 1.0 when
        ``np.isclose(cats_[i], cats_[j])`` holds and 0.0 otherwise.
    """
    cats_arr = np.asarray(cats_)
    # Broadcasting a column against a row compares every pair at once; this replaces
    # the former Python loop over all n*n index pairs with identical results.
    same_category = np.isclose(cats_arr[:, None], cats_arr[None, :])
    return same_category.astype(np.float64)
# Relation matrix over all samples; sliced by index in the T-learner and X-learner cells.
r = make_r(categories, progress_bar=True)

# S-learner

In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
import lightgbm as lgb
from sklearn.metrics import r2_score
import torch
import torch.nn as nn

from tabrel.benchmark.nw_regr import run_training, metrics_mean, train_nw_arbitrary, NwModelConfig, RelNwRegr
from tabrel.train import train_relnet
from tabrel.utils.misc import to_tensor

# NOTE(review): `labels` is not referenced by any other cell visible in this notebook.
labels = ["rel", "nrel", "lgb", "rel-fts", "lgb-rel"]

def split_treated_non_treated(x: pd.DataFrame, treatment: np.ndarray, y_fact: pd.DataFrame) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split features and factual outcomes into treated / non-treated parts.

    Rows where ``treatment == 1`` form the treated part, all other rows the
    non-treated part. Selection is done with ``.loc`` and the boolean mask.

    Returns:
        ``(x_treated, y_treated, x_non_treated, y_non_treated)`` as numpy arrays.
    """
    is_treated = treatment == 1
    pieces = []
    for mask in (is_treated, ~is_treated):
        pieces.append(x.loc[mask].to_numpy())
        pieces.append(y_fact.loc[mask].to_numpy())
    x_t, y_t, x_nt, y_nt = pieces
    return x_t, y_t, x_nt, y_nt

# Column names of the factual / counterfactual outcomes in `ihdp_data`.
y_fact_colname = "y_factual"
y_cfact_colname = "y_cfactual"

# Outcome and treatment columns as pandas Series ...
data_y_fact = ihdp_data[y_fact_colname]
data_y_cfact = ihdp_data[y_cfact_colname]
data_treatment = ihdp_data["treatment"]

# ... and as plain numpy arrays for positional indexing.
y_s = data_y_fact.to_numpy()  # Y for S-learner
y_s_cfact = data_y_cfact.to_numpy()
treatment_np = data_treatment.to_numpy()
tau_true = ihdp_data[ihdp_tau_colname].to_numpy()

# Shared training hyper-parameters.
lr = 1e-3
n_epochs = 50
lgb_params: Final[dict[str, str | int]] = {"objective": "regression", "metric": "rmse", "verbosity": -1}
    

### S-learner: MLP + Optuna

In [None]:
from tabrel.benchmark.nw_regr import MlpConfig, train_nw_arbitrary

def train_mlp_s_learner(
        mlp_out_dim: int,
        mlp_hidden_dim: int,
        dropout: float,
        lr: float,
        _seed: int,
        n_epochs: int = 1000,
) -> float:
    """Train an MLP + Nadaraya-Watson S-learner and score it on the validation split.

    The treatment indicator is appended to the covariates as the last feature. The
    validation set consists of the factual validation rows followed by copies of the
    same rows with the treatment flipped, so the model predicts both potential
    outcomes for every validation sample.

    Args:
        mlp_out_dim: Output dimension of the MLP (also passed as the NW ``input_dim``).
        mlp_hidden_dim: Hidden dimension of the MLP.
        dropout: Dropout rate of the MLP.
        lr: Learning rate passed to ``train_nw_arbitrary``.
        _seed: Seed passed to ``generate_indices`` for the query/background/validation split.
        n_epochs: Number of training epochs (default 1000, the value previously hard-coded).

    Returns:
        Mean squared error between the true and the predicted treatment effect on
        the validation samples.
    """
    ids_q, ids_b, ids_v = generate_indices(_seed, n_total=x_len)

    # Build the design matrix on a COPY with `treatment` as the last column.
    # Assigning the column onto `x_all` itself changed the feature set seen by
    # every cell that runs afterwards.
    x_np = (
        x_all
        .drop(columns=["treatment"], errors="ignore")
        .assign(treatment=data_treatment)
        .to_numpy()
    )

    xv, yv, yv_cfact, treatment_v = (
        x_np[ids_v],
        y_s[ids_v],
        y_s_cfact[ids_v],
        treatment_np[ids_v],
    )

    # Factual validation rows followed by their counterfactual twins.
    yv_s = np.concatenate([yv, yv_cfact])
    xv_cfact = xv.copy()
    xv_cfact[:, -1] = 1 - treatment_v  # treatment is the last column (see above)
    xv_s = np.concatenate([xv, xv_cfact])

    # The counterfactual rows get indices x_len .. x_len + n_val - 1 and inherit the
    # category of the sample they were copied from.
    cats_s = np.concatenate([categories, categories[ids_v]])
    r_s = make_r(cats_s, progress_bar=True)

    r_q_b = r_s[ids_q][:, ids_b]

    ids_v_s = np.concatenate([ids_v, np.arange(start=x_len, stop=len(ids_v) + x_len)])
    # Columns follow the (background, query) order.
    r_val_train = r_s[ids_v_s][:, np.concatenate([ids_b, ids_q])]

    _, _, _, model = train_nw_arbitrary(
        x_backgnd=x_np[ids_b],
        y_backgnd=y_s[ids_b],
        x_query=x_np[ids_q],
        y_query=y_s[ids_q],
        x_val=xv_s,
        y_val=yv_s,
        r_query_backgnd=r_q_b,
        r_val_nonval=r_val_train,
        cfg=NwModelConfig(
            init_sigma=1.,
            init_r_scale=1.,
            input_dim=mlp_out_dim,
            trainable_weights_matrix=False,
            mlp_config=MlpConfig(
                in_dim=x_np.shape[1],
                hidden_dim=mlp_hidden_dim,
                out_dim=mlp_out_dim,
                dropout=dropout,
            )
        ),
        lr=lr,
        n_epochs=n_epochs,
    )

    n_val = len(xv)
    y_val_pred_fact, y_val_pred_cfact = model.y_val_pred[:n_val], model.y_val_pred[n_val:]
    # Sign is +1 for treated rows and -1 for untreated rows, so the difference is
    # always (treated outcome - untreated outcome).
    tau_val_pred = (y_val_pred_fact - y_val_pred_cfact) * (-1) ** (1 - treatment_v)
    tau_val_true = tau_true[ids_v]

    return mean_squared_error(tau_val_true, tau_val_pred)


In [None]:
from datetime import datetime

import optuna


# Optuna storage URL; currently only referenced by the commented-out `storage=` argument below.
sqlite_path: Final[str] = "sqlite:///db.sqlite3"

# Module-level seed value.
# NOTE(review): later cells reuse the name `seed` as a loop variable and overwrite this.
seed = 0


def mlp_objective(trial: optuna.Trial) -> float:
    """Optuna objective: validation treatment-effect MSE of the MLP S-learner.

    The data-split seed is derived from the trial number (1 for the first trial of a
    fresh study, 2 for the second, ...). This yields the same sequence as the former
    global counter, but does not depend on a mutable global `seed` that other cells
    overwrite with their loop variable.

    Args:
        trial: The Optuna trial used to sample hyper-parameters.

    Returns:
        The value returned by ``train_mlp_s_learner`` for the sampled configuration.
    """
    trial_seed = trial.number + 1

    return train_mlp_s_learner(
        mlp_out_dim=trial.suggest_int("mlp_out_dim", 1, 40),
        mlp_hidden_dim=trial.suggest_int("hidden_dim", 4, 100),
        dropout=trial.suggest_float("dropout", .0, .6),
        lr=trial.suggest_float("lr", 1e-4, 1e-1),
        _seed=trial_seed,
    )

    
# Upper bound on the number of trials. Without `n_trials` (or `timeout`)
# `study.optimize` keeps creating trials until it is interrupted, so the notebook
# could never complete a "Restart & Run All". Tune the budget as needed.
N_OPTUNA_TRIALS: Final[int] = 100

study = optuna.create_study(
    direction="minimize",
    study_name=f"TE_sLearner_mlpNw_{datetime.now()}",
    # storage=sqlite_path,
)

study.optimize(mlp_objective, n_trials=N_OPTUNA_TRIALS)

In [None]:
# Best hyper-parameters found by the study (the dict below appears to be a recorded earlier result).
study.best_params

# {'mlp_out_dim': 40,
#  'hidden_dim': 78,
#  'dropout': 0.5335761824224884,
#  'lr': 0.014194376718333799}

In [None]:
# Re-evaluate one fixed MLP S-learner configuration over 15 data-split seeds.
best_mlp_kwargs = dict(
    mlp_out_dim=40,
    mlp_hidden_dim=78,
    dropout=.534,
    lr=.014,
)

mlp_pehes = []
for seed in range(15):
    seed_score = train_mlp_s_learner(_seed=seed, **best_mlp_kwargs)
    mlp_pehes.append(seed_score)

In [None]:
# Mean and standard deviation of the per-seed scores, rounded to two decimals.
round(np.mean(mlp_pehes), 2), round(np.std(mlp_pehes), 2)  # (3.86, 0.5)

In [None]:
metrics = defaultdict(list)

# Epoch budget of the hand-written RelNwRegr training loop further down. It has its
# own name so that it does not rebind the notebook-wide `n_epochs` read by the
# `train_nw_arbitrary` calls in this cell; rebinding it made those calls run with a
# different budget depending on the seed and on their position in the loop body.
manual_nw_n_epochs: Final[int] = 500

for x_src, xs_label in (
    #  (x, "x"),
    # (x_all.drop(columns=["x1", "x2", "x4", "x6", "x14", "x16", "x18"]), "x"),
    (x_all[["x1", ]], "x"),
    (x_all, "x_all"),
):
    # S-learner design matrix: covariates plus the treatment indicator as the LAST
    # column. Built on a copy so that the shared `x_all` frame is left untouched.
    x_s = x_src.drop(columns=["treatment"], errors="ignore").assign(treatment=data_treatment)
    x_s_base_np = x_s.to_numpy()
    n_samples, n_feats = x_s_base_np.shape
    for seed in range(2):
        np.random.seed(seed)
        ids_q, ids_b, ids_v = generate_indices(seed, n_total=x_len)
        ids_train = np.concatenate((ids_b, ids_q))  # (background, query) order
        xb, yb = x_s_base_np[ids_b], y_s[ids_b]
        xq, yq = x_s_base_np[ids_q], y_s[ids_q]
        xv_fact, yv_fact = x_s_base_np[ids_v], y_s[ids_v]
        yv_cfact, treatment_v = y_s_cfact[ids_v], treatment_np[ids_v]
        n_val = len(xv_fact)

        # Validation set: factual rows followed by their counterfactual twins
        # (same covariates, treatment flipped).
        xv_cfact = xv_fact.copy()
        xv_cfact[:, -1] = 1 - treatment_v  # treatment is the last column
        xv = np.concatenate([xv_fact, xv_cfact])
        yv = np.concatenate([yv_fact, yv_cfact])

        # Extended arrays for TabRel / rel-as-features: the counterfactual rows are
        # appended after the original samples and get indices len(x_s) ... onwards.
        # Local names are used so the notebook-wide `y_s` is not rebound.
        x_s_np = np.concatenate([x_s_base_np, xv_cfact])
        y_s_ext = np.concatenate([y_s[:len(x_s)], yv_cfact])
        cats_s = np.concatenate([categories, categories[ids_v]])
        tau_val_true = tau_true[ids_v]
        ids_v_ext = np.concatenate([ids_v, np.arange(start=len(x_s), stop=len(ids_v) + len(x_s))])

        # +1 for treated rows, -1 for untreated rows: turns (factual - counterfactual)
        # into (treated outcome - untreated outcome) for every validation sample.
        tau_sign = (-1) ** (1 + treatment_v)

        # NW without rel
        _, _, _, model_nrel = train_nw_arbitrary(
            x_backgnd=xb,
            y_backgnd=yb,
            x_query=xq,
            y_query=yq,
            x_val=xv,
            y_val=yv,
            r_query_backgnd=np.zeros((len(xq), len(xb))),
            r_val_nonval=np.zeros((len(xv), len(xb) + len(xq))),
            cfg=NwModelConfig(input_dim=n_feats),
            lr=lr,
            n_epochs=n_epochs,
        )

        y_val_pred_fact, y_val_pred_cfact = model_nrel.y_val_pred[:n_val], model_nrel.y_val_pred[n_val:]
        tau_val_pred = (y_val_pred_fact - y_val_pred_cfact) * tau_sign
        metrics[f"nrel_{xs_label}"].append(mean_squared_error(tau_val_true, tau_val_pred))

        # LightGBM - features and labels both in (query, background) order
        x_train = np.concatenate([xq, xb])
        y_train = np.concatenate([yq, yb])
        lgb_model = lgb.train(lgb_params, lgb.Dataset(x_train, label=y_train))
        y_pred_lgb = lgb_model.predict(xv)

        y_lgb_fact, y_lgb_cfact = y_pred_lgb[:n_val], y_pred_lgb[n_val:]
        tau_lgb_pred = (y_lgb_fact - y_lgb_cfact) * tau_sign

        metrics[f"lgb_{xs_label}"].append(mean_squared_error(tau_val_true, tau_lgb_pred))

        # plt.figure()
        # plt.title(f"LGB {xs_label} seed {seed}")
        # plt.plot(range(len(tau_val_true)), tau_val_true, label="tau validate true")
        # plt.plot(range(len(tau_lgb_pred)), tau_lgb_pred, label = "tau val LGB")
        # plt.show()

        # continue # TODO
        # The relational models are only evaluated on the full covariate set.
        if xs_label != "x_all":
            continue

        r_s = make_r(cats_s, progress_bar=True)
        r_q_b = r_s[ids_q][:, ids_b]
        r_val_train = r_s[ids_v_ext][:, ids_train]

        # TabRel
        torch.manual_seed(seed)
        _, _, _, y_val_pred, _ = train_relnet(
            x=x_s_np,
            y=y_s_ext,
            r=r_s,
            backgnd_indices=ids_b,
            query_indices=ids_q,
            val_indices=ids_v_ext,
            lr=0.01,
            n_epochs=800,
            n_layers=2,
            periodic_embed_dim=None,
            num_heads=2,
            progress_bar=True,
        )
        y_val_pred_fact, y_val_pred_cfact = y_val_pred[:n_val], y_val_pred[n_val:]
        tau_val_pred = (y_val_pred_fact - y_val_pred_cfact) * tau_sign
        relnet_pehe = mean_squared_error(tau_val_pred, tau_val_true)
        metrics["relnet"].append(relnet_pehe)

        # NW with rel
        _, _, _, model_rel = train_nw_arbitrary(
            x_backgnd=xb,
            y_backgnd=yb,
            x_query=xq,
            y_query=yq,
            x_val=xv,
            y_val=yv,
            r_query_backgnd=r_q_b,
            r_val_nonval=r_val_train,
            cfg=NwModelConfig(input_dim=n_feats),
            lr=lr,
            n_epochs=n_epochs,
        )

        y_val_pred_fact, y_val_pred_cfact = model_rel.y_val_pred[:n_val], model_rel.y_val_pred[n_val:]
        tau_val_pred = (y_val_pred_fact - y_val_pred_cfact) * tau_sign
        metrics["rel"].append(mean_squared_error(tau_val_true, tau_val_pred))

        # Hand-written RelNwRegr training, with and without a learnable weight matrix.
        torch.manual_seed(seed)
        x_back, x_query, x_val = to_tensor(xb), to_tensor(xq), to_tensor(xv)
        y_back, y_query, y_val = to_tensor(yb), to_tensor(yq), to_tensor(yv)
        # Torch training set in (background, query) order, matching the columns of
        # `r_val_train`. Kept under separate names so the numpy `x_train` / `y_train`
        # above (query, background order) stay available for LightGBM.
        x_train_t, y_train_t = torch.cat([x_back, x_query]), torch.cat([y_back, y_query])
        r_q_b_t = to_tensor(r_q_b)  # loop-invariant, converted once
        for learnable_norm in (True, False):
            config = NwModelConfig(
                init_sigma=1.0,
                init_r_scale=1.0,
                input_dim=n_feats,  # if x_label == "xInit" else n_feats + n_samples,
                trainable_weights_matrix=learnable_norm,
            )
            model = RelNwRegr(config)
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
            loss_fn = nn.MSELoss()
            torch.manual_seed(seed)
            model.train()

            for _ in range(manual_nw_n_epochs):
                optimizer.zero_grad()
                y_pred = model(x_back, y_back, x_query, r_q_b_t)
                loss = loss_fn(y_pred, y_query)
                loss.backward()
                optimizer.step()

            model.eval()
            with torch.no_grad():
                y_pred = model(
                    x_train_t,
                    y_train_t,
                    x_val,
                    to_tensor(r_val_train),
                )
                y_pred_np = y_pred.numpy()

                y_val_pred_fact, y_val_pred_cfact = y_pred_np[:n_val], y_pred_np[n_val:]
                tau_val_pred = (y_val_pred_fact - y_val_pred_cfact) * tau_sign

                mse = mean_squared_error(tau_val_true, tau_val_pred)
                metrics[f"learnNorm{learnable_norm}_pehe"].append(mse)

        # NW and LGB with rel as features
        x_broad = np.concatenate((x_s_np, r_s), axis=1)
        xb_broad, xq_broad, xv_broad = x_broad[ids_b], x_broad[ids_q], x_broad[ids_v_ext]

        # NW with rel as features
        _, _, _, model_relfts = train_nw_arbitrary(
            x_backgnd=xb_broad,
            y_backgnd=yb,
            x_query=xq_broad,
            y_query=yq,
            x_val=xv_broad,
            y_val=yv,
            r_query_backgnd=np.zeros((len(xq_broad), len(xb_broad))),
            r_val_nonval=np.zeros((len(xv_broad), len(xb_broad) + len(xq_broad))),
            cfg=NwModelConfig(input_dim=len(xb_broad[0])),
            lr=lr,
            n_epochs=n_epochs,
        )
        y_val_pred_fact, y_val_pred_cfact = model_relfts.y_val_pred[:n_val], model_relfts.y_val_pred[n_val:]
        tau_val_pred = (y_val_pred_fact - y_val_pred_cfact) * tau_sign
        metrics["rel-fts"].append(mean_squared_error(tau_val_true, tau_val_pred))

        # LightGBM with rel as features.
        # Features are stacked (query, background), so the labels must be the numpy
        # `y_train` built in the same order - NOT the (background, query) torch labels.
        x_train_broad = np.concatenate([xq_broad, xb_broad])
        lgb_model_rel = lgb.train(lgb_params, lgb.Dataset(x_train_broad, label=y_train))
        y_pred_lgb_rel = lgb_model_rel.predict(xv_broad)

        y_lgb_fact, y_lgb_cfact = y_pred_lgb_rel[:n_val], y_pred_lgb_rel[n_val:]
        tau_lgb_pred = (y_lgb_fact - y_lgb_cfact) * tau_sign
        metrics["lgb-rel"].append(mean_squared_error(tau_val_true, tau_lgb_pred))


In [None]:
# (rows, columns) of the covariate frame at this point of the notebook.
x_all.shape

In [None]:
# Show mean metrics
{
    metric_name: f"{round(np.mean(values), 2)} {round(np.std(values), 2)}"
    for metric_name, values in metrics.items()
}

# T-learner

In [None]:
# Dictionary keys under which the per-arm (treated / non-treated) results are stored.
label_t: Final[str] = "treated"
label_nt: Final[str] = "non-treated"

In [None]:
import torch

x = x_all
metrics = defaultdict(list)

# Cell-level constants, hoisted out of the seed loop.
nw_broad_key: Final[str] = "nw_rel-as-features"
mlp_out_dim = 40
t_learner_n_epochs: Final[int] = 50

for seed in tqdm(range(15)):
    # seed = 1
    np.random.seed(seed)
    query_indices, back_indices, val_indices = generate_indices(seed, n_total=x_len)
    y_np = data_y_fact.to_numpy()
    xq, xb, xv = x.iloc[query_indices], x.iloc[back_indices], x.iloc[val_indices]
    yq, yb = data_y_fact.iloc[query_indices], data_y_fact.iloc[back_indices]
    tq, tb = data_treatment[query_indices], data_treatment[back_indices]

    # One training set per treatment arm.
    xt_q, yt_q, xnt_q, ynt_q = split_treated_non_treated(xq, tq, yq)
    xt_b, yt_b, xnt_b, ynt_b = split_treated_non_treated(xb, tb, yb)

    # Absolute sample indices of the four (query/background) x (treated/non-treated) groups.
    iq_t = np.array([i for i in query_indices if data_treatment[i] == 1])
    iq_nt = np.array([i for i in query_indices if data_treatment[i] == 0])

    ib_t = np.array([i for i in back_indices if data_treatment[i] == 1])
    ib_nt = np.array([i for i in back_indices if data_treatment[i] == 0])

    # Treated / untreated outcome of every validation sample: the factual outcome
    # when the sample was in that arm, the counterfactual one otherwise.
    data_y_cfact = ihdp_data[y_cfact_colname]
    yv_t = np.array([data_y_fact[i] if data_treatment[i] == 1 else data_y_cfact[i] for i in val_indices])
    yv_nt = np.array([data_y_fact[i] if data_treatment[i] == 0 else data_y_cfact[i] for i in val_indices])

    # True effect on the validation samples. Deliberately NOT named `tau_true`: that
    # name holds the full-length effect array which the S-learner code indexes with
    # validation indices, and overwriting it here broke re-running those cells.
    tau_val_true = yv_t - yv_nt

    i_train_t = np.concatenate([ib_t, iq_t])
    i_train_nt = np.concatenate([ib_nt, iq_nt])

    r_q_b_treated = r[iq_t][:, ib_t]
    r_q_b_nt = r[iq_nt][:, ib_nt]
    r_val_nvt = r[val_indices][:, i_train_t]  # rel between val and treated train
    r_val_nvnt = r[val_indices][:, i_train_nt]  # rel between val and non-treated train

    # trained_models = {
    #     "rel=True": {},
    #     "rel=False": {},
    #     nw_broad_key: {},
    # }
    trained_models = defaultdict(dict)

    for metrics_key in ("mlp", "nw_learn_norm"):
        # Validation predictions of the two arm-specific models for this variant.
        y_val_pred_by_arm = {}
        for xqi, xbi, yqi, ybi, yvi, r_q_b, r_v_nvi, label in (
            (xt_q, xt_b, yt_q, yt_b, yv_t, r_q_b_treated, r_val_nvt, label_t),
            (xnt_q, xnt_b, ynt_q, ynt_b, yv_nt, r_q_b_nt, r_val_nvnt, label_nt),
        ):
            n_feats = len(xqi[0])
            torch.manual_seed(seed)
            x_back, x_query, x_val = to_tensor(xbi), to_tensor(xqi), to_tensor(xv.to_numpy())
            y_back, y_query, y_val = to_tensor(ybi), to_tensor(yqi), to_tensor(yvi)
            # (background, query) order, matching the columns of r_val_nvt / r_val_nvnt.
            x_train, y_train = torch.cat([x_back, x_query]), torch.cat([y_back, y_query])

            config = NwModelConfig(
                init_sigma=1.0,
                init_r_scale=1.0,
                input_dim=mlp_out_dim if metrics_key == "mlp" else n_feats,
                trainable_weights_matrix=True if metrics_key == "nw_learn_norm" else False,
                mlp_config=MlpConfig(
                    in_dim=n_feats,
                    out_dim=mlp_out_dim,
                    hidden_dim=78,
                    dropout=.534,
                ) if metrics_key == "mlp" else None,
            )
            model = RelNwRegr(config)
            optimizer = torch.optim.AdamW(
                model.parameters(),
                # lr=.014,
                lr=1e-3,
            )
            loss_fn = nn.MSELoss()
            torch.manual_seed(seed)
            model.train()

            r_q_b_tensor = to_tensor(r_q_b)  # loop-invariant, converted once
            for _ in range(t_learner_n_epochs):
                optimizer.zero_grad()
                y_pred = model(x_back, y_back, x_query, r_q_b_tensor)
                loss = loss_fn(y_pred, y_query)
                loss.backward()
                optimizer.step()

            model.eval()
            with torch.no_grad():
                y_pred = model(x_train, y_train, x_val, to_tensor(r_v_nvi))
                y_val_pred_by_arm[label] = y_pred
        # T-learner estimate: treated-model prediction minus non-treated-model prediction.
        tau_val_pred = y_val_pred_by_arm[label_t] - y_val_pred_by_arm[label_nt]
        metrics[metrics_key].append(mean_squared_error(tau_val_true, tau_val_pred.numpy()))

    # for rel, (xqi, xbi, yqi, ybi, yvi, r_q_b, r_v_nvi, label) in product(
    #     (True, False),
    #     (
    #         (xt_q, xt_b, yt_q, yt_b, yv_t, r_q_b_treated, r_val_nvt, label_t),
    #         (xnt_q, xnt_b, ynt_q, ynt_b, yv_nt, r_q_b_nt, r_val_nvnt, label_nt),
    #     ),
    # ):
    #     _, _, _, model = train_nw_arbitrary(
    #         x_backgnd=xbi,
    #         y_backgnd=ybi,
    #         x_query=xqi,
    #         y_query=yqi,
    #         x_val=xv.to_numpy(),
    #         y_val=yvi,
    #         r_query_backgnd=r_q_b if rel else np.zeros_like(r_q_b),
    #         r_val_nonval=r_v_nvi if rel else np.zeros_like(r_v_nvi),
    #         cfg=NwModelConfig(input_dim=len(xbi[0])),
    #         lr=lr,
    #         n_epochs=n_epochs,
    #     )
    #     trained_models[f"rel={rel}"][label] = model

    # # rel as features
    # x_broad = np.concatenate((x.to_numpy(), r), axis=1)
    # xb_broad, xq_broad, xv_broad = x_broad[back_indices], x_broad[query_indices], x_broad[val_indices]
    # xt_q_broad, xnt_q_broad = x_broad[iq_t], x_broad[iq_nt]
    # xt_b_broad, xnt_b_broad = x_broad[ib_t], x_broad[ib_nt]
    # xv_broad = x_broad[val_indices]

    # for (xqi, xbi, yqi, ybi, yvi, label) in (
    #     (xt_q_broad, xt_b_broad, yt_q, yt_b, yv_t, label_t),
    #     (xnt_q_broad, xnt_b_broad, ynt_q, ynt_b, yv_nt, label_nt),
    # ):
    #     trained_models[nw_broad_key][label] = train_nw_arbitrary(
    #         x_backgnd=xbi,
    #         y_backgnd=ybi,
    #         x_query=xqi,
    #         y_query=yqi,
    #         x_val=xv_broad,
    #         y_val=yvi,
    #         r_query_backgnd=np.zeros((len(xqi), len(xbi))),
    #         r_val_nonval=np.zeros((len(xv_broad), len(xbi) + len(xqi))),
    #         cfg=NwModelConfig(input_dim=len(xbi[0])),
    #         lr=lr,
    #         n_epochs=n_epochs,
    #     )[-1]

    # # LightGBM
    # yt_train, ynt_train = np.concatenate([yt_q, yt_b]), np.concatenate([ynt_q, ynt_b])
    # for xq_ti, xb_ti, xq_nti, xb_nti, xv_i, lgb_key in (
    #     (xt_q, xt_b, xnt_q, xnt_b, xv, "lgb"),
    #     # (xt_q_broad, xt_b_broad, xnt_q_broad, xnt_b_broad, xv_broad, "lgb-rel"),
    # ):
    #     xt_train, xnt_train = np.concatenate([xq_ti, xb_ti]), np.concatenate([xq_nti, xb_nti])
    #     lgb_model_t = lgb.train(lgb_params, lgb.Dataset(xt_train, label=yt_train))
    #     lgb_model_nt = lgb.train(lgb_params, lgb.Dataset(xnt_train, ynt_train))
    #     tau_lgb = lgb_model_t.predict(xv_i) - lgb_model_nt.predict(xv_i)
    #     metrics[lgb_key].append(mean_squared_error(tau_val_true, tau_lgb))

    # # TabRel
    # def train_relnet_shorthand(x_: np.ndarray, y_: np.ndarray, r_: np.ndarray, 
    #                             bi_: np.ndarray, qi_: np.ndarray, vi_: np.ndarray) -> torch.Tensor:
    #     _, _, _, y_pred, _ = train_relnet(
    #         x=x_,
    #         y=y_,
    #         r=r_,
    #         backgnd_indices=bi_,
    #         query_indices=qi_,
    #         val_indices=vi_,
    #         lr=.007,
    #         n_epochs=1500,
    #         n_layers=2,
    #         periodic_embed_dim=None,
    #         embed_dim=32,
    #         num_heads=2,
    #         progress_bar=True,
    #     )
    #     return y_pred

    # torch.manual_seed(seed)
    # xt = np.concatenate([xt_b, xt_q, xv])
    # bi_t = np.array(range(len(xt_b)))
    # qi_t = np.array(range(len(xt_q))) + len(xt_b)
    # vi_t = np.array(range(len(xv))) + len(xt_b) + len(xt_q)

    # xnt = np.concatenate([xnt_b, xnt_q, xv])
    # yt = np.concatenate([yt_b, yt_q, yv_t])
    # ynt = np.concatenate([ynt_b, ynt_q, yv_nt])
    # bi_nt = np.array(range(len(xnt_b)))
    # qi_nt = np.array(range(len(xnt_q))) + len(xnt_b)
    # vi_nt = np.array(range(len(xv))) + len(xnt_b) + len(xnt_q)

    # i_t = np.concatenate([ib_t, iq_t, val_indices])
    # i_nt = np.concatenate([ib_nt, iq_nt, val_indices])
    # r_bqv_t = r[i_t][:, i_t]
    # r_bqv_nt = r[i_nt][:, i_nt]
    # y_pred_relnet_t = train_relnet_shorthand(x_=xt, y_=yt, r_=r_bqv_t, bi_=bi_t, qi_=qi_t, vi_=vi_t)
    # y_pred_relnet_nt = train_relnet_shorthand(x_=xnt, y_=ynt, r_=r_bqv_nt, bi_=bi_nt, qi_=qi_nt, vi_=vi_nt)
    # tau_pred_relnet = y_pred_relnet_t - y_pred_relnet_nt

    # metrics["relnet"].append(mean_squared_error(tau_val_true, tau_pred_relnet))

    # Scores models stored in `trained_models`; with the blocks above commented out
    # nothing populates it, so this loop currently does not execute its body.
    for key, models in trained_models.items():
        y_pred_treated = models[label_t].y_val_pred
        y_pred_nt = models[label_nt].y_val_pred

        tau_pred = y_pred_treated - y_pred_nt
        metrics[key].append(mean_squared_error(tau_val_true, tau_pred)) # PEHE
    
    

In [None]:
# lr=.014: {'mlp': '4.54 & 0.49', 'nw_learn_norm': '5.97 & 1.00', 'lgb': '3.62 & 0.26'}
# lr=1e-3: {'mlp': '4.07 & 0.33', 'nw_learn_norm': '4.74 & 0.38', 'lgb': '3.62 & 0.26'}
# FIXED B, Q order, lr=1e-3: {'mlp': '3.72 & 0.31', 'nw_learn_norm': '4.57 & 0.44', 'lgb': '3.62 & 0.26'}
# FIXED BQ, lr=1e-3, n_epochs=200 (not 500): {'mlp': '3.49 & 0.42', 'nw_learn_norm': '6.24 & 0.46'}
# FIXED BQ, lr=1e-3, n_epochs=150: {'mlp': '3.38 & 0.42', 'nw_learn_norm': '6.43 & 0.46'}
# FIXED BQ, lr=1e-3, n_epochs=100: {'mlp': '3.31 & 0.34', 'nw_learn_norm': '6.60 & 0.45'}
# FIXED BQ, lr=1e-3, n_epochs=50: {'mlp': '5.19 & 0.62', 'nw_learn_norm': '6.75 & 0.43'}

# Mean & standard deviation over seeds for every metric collected by the cell above.
{k: f"{np.mean(v):.2f} & {np.std(v):.2f}" for k, v in metrics.items()}

In [None]:
import matplotlib.pyplot as plt

# Per-seed spread of the "relnet" metric.
# NOTE(review): in the preceding cell the code that appends to metrics["relnet"] is
# commented out; `metrics` is a defaultdict, so this reads an empty list there.
plt.scatter(np.ones(len(metrics["relnet"])), metrics["relnet"])


In [None]:
# Standard deviation of the "relnet" metric, ignoring runs with a value of 15 or more.
# NOTE(review): the result is only meaningful when metrics["relnet"] has been populated.
np.std([m for m in metrics["relnet"] if m < 15])

# X-learner

In [None]:
import torch

def train_relnet_shorthand(x_b_: np.ndarray, x_q_: np.ndarray, x_v_: np.ndarray,
                           y_b_: np.ndarray, y_q_: np.ndarray, 
                           r_: np.ndarray, # r: (q, b, v) x (q, b, v)
                           seed: int) -> torch.Tensor:
    """Train TabRel on stacked (query, background, validation) data and predict validation.

    The samples are stacked in the order query, background, validation; `r_` must
    be given in that same row/column order. Validation targets are unknown to the
    caller and are filled with zeros.

    Args:
        x_b_, x_q_, x_v_: Features of the background, query and validation samples.
        y_b_, y_q_: Targets of the background and query samples.
        r_: Relation matrix over the stacked samples, ordered (q, b, v) x (q, b, v).
        seed: Seed passed to ``torch.manual_seed`` before training.

    Returns:
        The fourth value returned by ``train_relnet`` (the validation predictions).
    """
    n_q, n_b, n_v = len(x_q_), len(x_b_), len(x_v_)

    x_ = np.concatenate([x_q_, x_b_, x_v_])
    y_ = np.concatenate([y_q_, y_b_, np.zeros(n_v)])

    # Positions of each group inside the stacked arrays.
    qi_ = np.arange(n_q)
    # One index per BACKGROUND row. Sizing this range by the number of query rows
    # selected the wrong rows whenever the two groups differ in size - dropping
    # background rows, or pulling zero-target validation rows into the background.
    bi_ = np.arange(n_b) + n_q
    vi_ = np.arange(n_v) + n_q + n_b

    torch.manual_seed(seed)
    _, _, _, y_pred, _ = train_relnet(
        x=x_,
        y=y_,
        r=r_,
        backgnd_indices=bi_,
        query_indices=qi_,
        val_indices=vi_,
        lr=.007,
        n_epochs=1000,
        n_layers=2,
        periodic_embed_dim=None,
        embed_dim=32,
        num_heads=2,
        progress_bar=False,
    )
    return y_pred

metrics = defaultdict(list)

# Names this cell used to pick up implicitly from whichever cell ran before it.
# They are bound here so the cell runs on its own; the values mirror the ones used
# by the T-learner cell (`x = x_all`, `mlp_out_dim = 40`).
x = x_all
mlp_out_dim = 40
x_learner_n_epochs: Final[int] = 800

label_t: Final[str] = "treated"
label_nt: Final[str] = "non-treated"
lgb_usual_key: Final[str] = "lgb"
lgb_broad_key: Final[str] = "lgb-rel"

# rel as features: every sample's row of `r` appended to its covariates
# (independent of the seed, so built once).
x_broad = np.concatenate((x.to_numpy(), r), axis=1)

for seed in tqdm(range(15)):
    # seed = 1
    np.random.seed(seed)
    query_indices, back_indices, val_indices = generate_indices(seed, n_total=x_len)
    y_np = data_y_fact.to_numpy()
    xq, xb, xv = x.iloc[query_indices], x.iloc[back_indices], x.iloc[val_indices]
    yq, yb = data_y_fact.iloc[query_indices], data_y_fact.iloc[back_indices]
    tq, tb = data_treatment[query_indices], data_treatment[back_indices]

    xt_q, yt_q, xnt_q, ynt_q = split_treated_non_treated(xq, tq, yq)
    xt_b, yt_b, xnt_b, ynt_b = split_treated_non_treated(xb, tb, yb)

    # Absolute sample indices of the four (query/background) x (treated/non-treated) groups.
    iq_t = np.array([i for i in query_indices if data_treatment[i] == 1])
    iq_nt = np.array([i for i in query_indices if data_treatment[i] == 0])

    ib_t = np.array([i for i in back_indices if data_treatment[i] == 1])
    ib_nt = np.array([i for i in back_indices if data_treatment[i] == 0])

    data_y_cfact = ihdp_data[y_cfact_colname]
    yv_t = np.array([data_y_fact[i] if data_treatment[i] == 1 else data_y_cfact[i] for i in val_indices])
    yv_nt = np.array([data_y_fact[i] if data_treatment[i] == 0 else data_y_cfact[i] for i in val_indices])

    # Training sets in this cell are always stacked in (query, background) order.
    i_train_t = np.concatenate([iq_t, ib_t])
    i_train_nt = np.concatenate([iq_nt, ib_nt])

    r_q_b_treated = r[iq_t][:, ib_t]
    r_q_b_nt = r[iq_nt][:, ib_nt]
    r_val_nvt = r[val_indices][:, i_train_t]  # rel between val and treated train
    r_val_nvnt = r[val_indices][:, i_train_nt]  # rel between val and non-treated train

    xb_broad, xq_broad, xv_broad = x_broad[back_indices], x_broad[query_indices], x_broad[val_indices]
    xt_q_broad, xnt_q_broad = x_broad[iq_t], x_broad[iq_nt]
    xt_b_broad, xnt_b_broad = x_broad[ib_t], x_broad[ib_nt]

    # for LightGBM
    tau_true_val = yv_t - yv_nt
    yt_train, ynt_train = np.concatenate([yt_q, yt_b]), np.concatenate([ynt_q, ynt_b])

    # Stage 1: fit one outcome model per arm and impute the individual effects of
    # the training samples. Layout: tau_train[lgb_key]["t" | "nt"]["q" | "b"].
    tau_train = {}

    for xq_ti, xb_ti, xq_nti, xb_nti, xv_i, lgb_key in (
        (xt_q, xt_b, xnt_q, xnt_b, xv, lgb_usual_key),
        (xt_q_broad, xt_b_broad, xnt_q_broad, xnt_b_broad, xv_broad, lgb_broad_key),
    ):
        xt_train, xnt_train = np.concatenate([xq_ti, xb_ti]), np.concatenate([xq_nti, xb_nti])
        lgb_model_t = lgb.train(lgb_params, lgb.Dataset(xt_train, label=yt_train))
        lgb_model_nt = lgb.train(lgb_params, lgb.Dataset(xnt_train, label=ynt_train))

        # impute treatment effects using first-stage learners

        # treated samples: observed outcome minus predicted untreated outcome
        yt_train_cfact_lgb = lgb_model_nt.predict(xt_train)
        tau_train_t = yt_train - yt_train_cfact_lgb
        n_query_t = len(xq_ti)

        # untreated samples: predicted treated outcome minus observed outcome
        ynt_train_cfact_lgb = lgb_model_t.predict(xnt_train)
        tau_train_nt = ynt_train_cfact_lgb - ynt_train
        n_query_nt = len(xq_nti)

        tau_train[lgb_key] = {
            "t": {"q": tau_train_t[:n_query_t], "b": tau_train_t[n_query_t:]},
            "nt": {"q": tau_train_nt[:n_query_nt], "b": tau_train_nt[n_query_nt:]},
        }

    # stage 2
    for train_weights, is_rel, is_mlp, (xq_ti, xb_ti, xq_nti, xb_nti, xv_i, lgb_key), in product(
        (True, False),
        (True, False),
        (True, False),
        (
            (xt_q, xt_b, xnt_q, xnt_b, xv.to_numpy(), lgb_usual_key),
            (xt_q_broad, xt_b_broad, xnt_q_broad, xnt_b_broad, xv_broad, lgb_broad_key),
        ),
    ):
        # rel-as-features is only combined with the plain NW model ...
        if lgb_key == lgb_broad_key and (train_weights or is_rel or is_mlp):
            continue
        # ... and the MLP variant only with relations and without a trainable weight matrix.
        if is_mlp and (train_weights or not is_rel):
            continue

        xt_train, xnt_train = np.concatenate([xq_ti, xb_ti]), np.concatenate([xq_nti, xb_nti])

        # retrieve treatment effects imputed by first-stage learners

        taus_t = tau_train[lgb_key]["t"]
        tau_train_t = np.concatenate([taus_t["q"], taus_t["b"]])

        taus_nt = tau_train[lgb_key]["nt"]
        tau_train_nt = np.concatenate([taus_nt["q"], taus_nt["b"]])

        # stage 2 learners: LightGBM
        # if not train_weights and not is_rel: # just do it once per lgb_key
        #     lgb_model_tau_t = lgb.train(lgb_params, lgb.Dataset(xt_train, label=tau_train_t))
        #     lgb_model_tau_nt = lgb.train(lgb_params, lgb.Dataset(xnt_train, label=tau_train_nt))

        #     tau_lgb = .5 * (lgb_model_tau_t.predict(xv_i) + lgb_model_tau_nt.predict(xv_i))
        #     metrics[lgb_key].append(mean_squared_error(tau_true_val, tau_lgb))

        # stage 2 learners: NW
        taus_t_q, taus_t_b = taus_t["q"], taus_t["b"]
        taus_nt_q, taus_nt_b = taus_nt["q"], taus_nt["b"]

        # Feature count of the matrices actually used in this iteration.
        n_feats_i = len(xb_ti[0])
        cfg = NwModelConfig(
            # Same convention as the S- and T-learner configs: with an MLP in front,
            # `input_dim` is the MLP output size.
            # NOTE(review): presumably `input_dim` only matters when the weight matrix
            # is trainable (never the case together with the MLP here) - confirm.
            input_dim=mlp_out_dim if is_mlp else n_feats_i,
            trainable_weights_matrix=train_weights,
            mlp_config=MlpConfig(
                in_dim=n_feats_i,
                out_dim=mlp_out_dim,
                hidden_dim=78,
                dropout=.534,
            ) if is_mlp else None,
        )
        # The validation inputs given to `train_nw_arbitrary` are placeholders (zero
        # targets, zero relations); the real predictions are computed further down.
        _, _, _, model_tau_t = train_nw_arbitrary(
            x_backgnd=xb_ti,
            y_backgnd=taus_t_b,
            x_query=xq_ti,
            y_query=taus_t_q,
            x_val=xv_i,
            y_val=np.zeros(len(xv_i)),
            r_query_backgnd=r_q_b_treated if is_rel else np.zeros_like(r_q_b_treated),
            r_val_nonval=np.zeros_like(r_val_nvt),
            cfg=cfg, # TODO trainable weight matrix
            n_epochs=x_learner_n_epochs,
            lr=1e-3,
        )

        _, _, _, model_tau_nt = train_nw_arbitrary(
            x_backgnd=xb_nti,
            y_backgnd=taus_nt_b,
            x_query=xq_nti,
            y_query=taus_nt_q,
            x_val=xv_i,
            y_val=np.zeros(len(xv_i)),
            r_query_backgnd=r_q_b_nt if is_rel else np.zeros_like(r_q_b_nt),
            r_val_nonval=np.zeros_like(r_val_nvnt),
            cfg=cfg,
            n_epochs=x_learner_n_epochs,
            lr=1e-3,
        )
        # NOTE(review): the models are called as returned by `train_nw_arbitrary`;
        # confirm they are in eval mode (relevant for the dropout of the MLP variant).
        with torch.no_grad():
            tau_pred1 = model_tau_t.model(
                to_tensor(xt_train),
                to_tensor(tau_train_t),
                to_tensor(xv_i),
                to_tensor(r_val_nvt if is_rel else np.zeros_like(r_val_nvt)),
            )
            tau_pred0 = model_tau_nt.model(
                to_tensor(xnt_train),
                to_tensor(tau_train_nt),
                to_tensor(xv_i),
                to_tensor(r_val_nvnt if is_rel else np.zeros_like(r_val_nvnt)),
            )
        # X-learner combination with equal weights for the two arm-specific estimates.
        tau_pred = .5 * (tau_pred0 + tau_pred1)
        metrics[f"NW-rel{is_rel}-broad{lgb_key == lgb_broad_key}-trainWeights{train_weights}_mlp{is_mlp}"].append(mean_squared_error(tau_true_val, tau_pred))

        # TabRel
        if train_weights or is_rel or lgb_key == lgb_broad_key:
            continue

        # Absolute indices in (query, background, validation) order - the order in
        # which `train_relnet_shorthand` stacks its samples and expects `r_`. The
        # former (background, query, validation) order attached relations to the
        # wrong samples.
        i_treated = np.concatenate([iq_t, ib_t, val_indices])
        i_non_treated = np.concatenate([iq_nt, ib_nt, val_indices])
        tau_pred1 = train_relnet_shorthand(
            x_b_=xb_ti,
            x_q_=xq_ti,
            x_v_=xv_i,
            y_b_=taus_t_b,
            y_q_=taus_t_q,
            r_=r[i_treated][:, i_treated],
            seed=seed,
        )
        tau_pred0 = train_relnet_shorthand(
            x_b_=xb_nti,
            x_q_=xq_nti,
            x_v_=xv_i,
            y_b_=taus_nt_b,
            y_q_=taus_nt_q,
            r_=r[i_non_treated][:, i_non_treated],
            seed=seed,
        )
        tau_pred_relnet = .5 * (tau_pred1 + tau_pred0)

        metrics["relnet"].append(mean_squared_error(tau_true_val, tau_pred_relnet))

In [None]:
# --- 800 iterations, B, Q order as before: ---
# {'NW-relTrue-broadFalse-trainWeightsTrue_mlpFalse': '5.11 & 0.30',
#  'NW-relFalse-broadFalse-trainWeightsTrue_mlpFalse': '5.35 & 0.40',
#  'NW-relTrue-broadFalse-trainWeightsFalse_mlpTrue': '4.14 & 0.53',
#  'NW-relTrue-broadFalse-trainWeightsFalse_mlpFalse': '6.33 & 1.44',
#  'NW-relFalse-broadFalse-trainWeightsFalse_mlpFalse': '6.69 & 1.50',
#  'NW-relFalse-broadTrue-trainWeightsFalse_mlpFalse': '6.89 & 1.93'}

# --- 200 iterations, B, Q order as before: ---
# {'NW-relTrue-broadFalse-trainWeightsTrue_mlpFalse': '5.31 & 0.28',
#  'NW-relFalse-broadFalse-trainWeightsTrue_mlpFalse': '5.91 & 1.51',
#  'NW-relTrue-broadFalse-trainWeightsFalse_mlpTrue': '4.27 & 0.45',
#  'NW-relTrue-broadFalse-trainWeightsFalse_mlpFalse': '6.09 & 1.39',
#  'NW-relFalse-broadFalse-trainWeightsFalse_mlpFalse': '6.27 & 1.51',
#  'NW-relFalse-broadTrue-trainWeightsFalse_mlpFalse': '6.64 & 1.89'}

# --- 800 iterations, changed B, Q order: --- TODO why so? not everywhere changed in Stage 1!
# {'NW-relTrue-broadFalse-trainWeightsFalse_mlpTrue': '8.30 & 1.02'}

# Mean & standard deviation over seeds for every metric collected by the X-learner cell.
{k: f"{np.mean(v):.2f} & {np.std(v):.2f}" for k, v in metrics.items()}