In [None]:
%load_ext autoreload
%autoreload 2
from pathlib import Path

import numpy as np
import pandas as pd

def load_ihdp_data(ihdp_path: Path, table_index: int = 0) -> tuple[pd.DataFrame, list[str], str]:
    """Load one IHDP replication table and add the per-unit effect target ``delta_y``.

    Args:
        ihdp_path: Directory containing ``columns.txt`` and a ``csv/`` subdirectory of
            header-less ``*.csv`` tables.
        table_index: Which table to load, counted in *sorted* filename order
            (default: the first one).

    Returns:
        ``(data, exclude_cols, y_col_name)``:
        ``data`` is the table with named columns plus ``delta_y``;
        ``exclude_cols`` lists the treatment/outcome columns that must not be used as features;
        ``y_col_name`` is ``"delta_y"``.

    Raises:
        FileNotFoundError: if ``csv/`` contains no ``*.csv`` file.
        IndexError: if ``table_index`` is out of range.
    """
    # Each name in columns.txt loses its last character (presumably a trailing comma — TODO confirm).
    # The last two entries are dropped and the feature names x2..x25 are spelled out explicitly.
    ihdp_cols = [s[:-1] for s in np.loadtxt(ihdp_path / "columns.txt", dtype=str)][:-2]
    ihdp_cols.extend([f"x{i}" for i in range(2, 26)])

    # Path.glob yields files in arbitrary, OS-dependent order; sort so the chosen table is reproducible.
    csv_paths = sorted((ihdp_path / "csv").glob("*.csv"))
    if not csv_paths:
        raise FileNotFoundError(f"no *.csv tables found in {ihdp_path / 'csv'}")
    data = pd.read_csv(csv_paths[table_index], header=None)
    data.columns = ihdp_cols

    y_col_name = "delta_y"
    # Assuming y_factual is the outcome under the observed treatment, this equals y(1) - y(0)
    # for every row:
    #   treated rows:   (y0 - y1) * (-1) = y1 - y0
    #   untreated rows: (y1 - y0) * 1    = y1 - y0
    data[y_col_name] = (data["y_cfactual"] - data["y_factual"]) * (-1) ** data["treatment"]
    exclude_cols = ["treatment", "y_cfactual", "y_factual", "mu0", "mu1"]
    return data, exclude_cols, y_col_name

# Location of the CEVAE IHDP dataset (directory with `columns.txt` and `csv/`).
# Set the IHDP_DIR environment variable to point elsewhere instead of editing this absolute path.
import os

IHDP_DIR = Path(
    os.environ.get("IHDP_DIR", "/Users/vzuev/Documents/git/gh_zuevval/tabrel/CEVAE/datasets/IHDP")
)
ihdp_data, ihdp_exclude_cols, ihdp_y_colname = load_ihdp_data(IHDP_DIR)
ihdp_data.head()

In [None]:
from typing import Final

# Feature matrix: everything except the treatment/outcome columns and the derived target.
x_all = ihdp_data.drop(columns=[*ihdp_exclude_cols, ihdp_y_colname])

# The first `ihdp_last_numeric_index` feature columns are the numeric ones; the rest go to `x_cat`.
ihdp_last_numeric_index: Final[int] = 6
x_numeric, x_cat = (
    x_all.iloc[:, :ihdp_last_numeric_index],
    x_all.iloc[:, ihdp_last_numeric_index:],
)
x_numeric

In [None]:
import seaborn as sns

# `pairplot` treats `hue` as categorical, so colouring by the continuous effect directly would
# create one hue level (and one legend entry) per unique value. Bin the effect into quartiles instead.
y_quartile_colname = f"{ihdp_y_colname}_quartile"
x_num_y = x_numeric.assign(
    **{y_quartile_colname: pd.qcut(ihdp_data[ihdp_y_colname], q=4, duplicates="drop")}
)
sns.pairplot(x_num_y, hue=y_quartile_colname);

In [None]:
from itertools import product
from tqdm import tqdm

# Pairwise relation matrix: units i and j are related (r[i, j] = 1.0) when they share
# the same value of `group_col`, which is therefore removed from the feature matrix.
group_col: Final[str] = "x4"
x = x_all.drop(columns=[group_col])
x_len = len(x)
categories = x_all[group_col]
print("n_categories", categories.nunique(dropna=False))

# Broadcasting a column against a row compares every (i, j) pair in one vectorised call,
# instead of ~x_len**2 scalar np.isclose calls in a Python loop. Positional access via
# to_numpy() also avoids depending on the frame's index labels.
category_values = categories.to_numpy(dtype=float)
r = np.isclose(category_values[:, None], category_values[None, :]).astype(float)

r

# S-learner

In [None]:
from tabrel.benchmark.nw_regr import run_training, metrics_mean

# Model variants compared below; metrics["pehe"][k] collects one value per seed for labels[k].
labels = ["rel", "nrel", "lgb", "rel-fts", "lgb-rel"]
metrics = {"pehe": [[] for _label in labels]}

def generate_indices(
    seed: int,
    n_query: int = 200,
    back_end: int = 300,
    n_total: "int | None" = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Randomly split row positions into query / background / validation index arrays.

    Seeds NumPy's *global* RNG with ``seed``, so later global-RNG calls are affected too.
    It then draws one permutation of ``range(n_total)`` and cuts it into three consecutive
    slices.

    Args:
        seed: Seed passed to ``np.random.seed``.
        n_query: Number of query rows: permutation positions ``[0, n_query)``.
        back_end: Exclusive end position of the background slice. The background set is
            positions ``[n_query, back_end)``, i.e. ``back_end - n_query`` rows (100 with
            the defaults). All remaining rows are validation.
        n_total: Number of rows to split; defaults to the global ``x_len``.

    Returns:
        ``(query_indices, background_indices, validation_indices)``.

    Raises:
        ValueError: if ``0 <= n_query <= back_end <= n_total`` does not hold, which would
            otherwise silently produce short or empty slices.
    """
    if n_total is None:
        n_total = x_len
    if not 0 <= n_query <= back_end <= n_total:
        raise ValueError(
            f"need 0 <= n_query <= back_end <= n_total, got {n_query}, {back_end}, {n_total}"
        )
    np.random.seed(seed)
    indices = np.random.permutation(n_total)
    return indices[:n_query], indices[n_query:back_end], indices[back_end:]


for seed in tqdm(range(10)):
    query_indices, back_indices, val_indices = generate_indices(seed)

    res = run_training(
        x=x.to_numpy(),
        y=ihdp_data[ihdp_y_colname].to_numpy(),
        r=r,
        backgnd_indices=back_indices,
        query_indices=query_indices,
        val_indices=val_indices,
        lr=1e-4,
        n_epochs=10,
        rel_as_feats=r,
    )
    # Results are attributed to `labels` purely by position. Fail loudly on a size mismatch
    # rather than raising an opaque IndexError or silently leaving some metric lists empty.
    if len(res) != len(labels):
        raise ValueError(
            f"run_training returned {len(res)} results but {len(labels)} labels are defined: {list(res)}"
        )
    for i, v in enumerate(res.values()):
        # NOTE(review): assumes v[0] is the validation MSE of the predicted effect (i.e. PEHE) — confirm
        mse = v[0]
        metrics["pehe"][i].append(mse)


# Convert each label's per-seed list into an array for `metrics_mean`.
for i in range(len(labels)):
    metrics["pehe"][i] = np.array(metrics["pehe"][i])
metrics_mean(metrics, labels)

# T-learner

In [None]:
from collections import defaultdict
from tabrel.benchmark.nw_regr import train_nw_arbitrary, NwModelConfig
from sklearn.metrics import mean_squared_error

# Factual / counterfactual outcome columns and the treatment column used by the T-learner.
y_fact_colname = "y_factual"
y_cfact_colname = "y_cfactual"
data_y_fact = ihdp_data[y_fact_colname]
data_treatment = ihdp_data["treatment"]

def split_treated_non_treated(
    x: pd.DataFrame,
    treatment: "pd.Series | np.ndarray",
    y_fact: "pd.Series | pd.DataFrame",
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split features and factual outcomes into treated (treatment == 1) and other rows.

    Args:
        x: Feature rows.
        treatment: Treatment flags. A Series is aligned on ``x``'s index by ``.loc``;
            an ndarray is applied positionally and must have ``len(x)`` entries.
        y_fact: Factual outcomes for the same rows (indexed like ``x``).

    Returns:
        ``(x_treated, y_treated, x_non_treated, y_non_treated)`` as NumPy arrays,
        each keeping the original row order.
    """
    treated = treatment == 1
    x_treated, y_treated = x.loc[treated], y_fact.loc[treated]
    x_non_treated, y_non_treated = x.loc[~treated], y_fact.loc[~treated]
    return x_treated.to_numpy(), y_treated.to_numpy(), x_non_treated.to_numpy(), y_non_treated.to_numpy()

label_t: Final[str] = "treated"
label_nt: Final[str] = "non-treated"

# Seed-independent arrays, hoisted out of the loop. All lookups below are positional,
# matching the `.iloc` used for x / y, so they do not depend on ihdp_data's index labels.
treatment_np = data_treatment.to_numpy()
y_fact_np = data_y_fact.to_numpy()
y_cfact_np = ihdp_data[y_cfact_colname].to_numpy()

metrics = defaultdict(list)
for seed in tqdm(range(20)):
    # generate_indices seeds NumPy's global RNG itself, so no separate np.random.seed is needed.
    query_indices, back_indices, val_indices = generate_indices(seed)
    xq, xb, xv = x.iloc[query_indices], x.iloc[back_indices], x.iloc[val_indices]
    yq, yb = data_y_fact.iloc[query_indices], data_y_fact.iloc[back_indices]
    tq, tb = data_treatment.iloc[query_indices], data_treatment.iloc[back_indices]

    xt_q, yt_q, xnt_q, ynt_q = split_treated_non_treated(xq, tq, yq)
    xt_b, yt_b, xnt_b, ynt_b = split_treated_non_treated(xb, tb, yb)

    # Row positions of each group, in the same order as the split above. Boolean masking keeps
    # an integer dtype even for an empty group (np.array([]) would be float and unusable as an index).
    iq_t = query_indices[treatment_np[query_indices] == 1]
    iq_nt = query_indices[treatment_np[query_indices] == 0]

    ib_t = back_indices[treatment_np[back_indices] == 1]
    ib_nt = back_indices[treatment_np[back_indices] == 0]

    # Validation targets for each arm: the factual outcome where the observed treatment matches
    # the arm, otherwise the counterfactual outcome.
    t_val = treatment_np[val_indices]
    yv_t = np.where(t_val == 1, y_fact_np[val_indices], y_cfact_np[val_indices])
    yv_nt = np.where(t_val == 0, y_fact_np[val_indices], y_cfact_np[val_indices])

    # Training rows per arm: query rows first, then background rows.
    i_train_t = np.concatenate([iq_t, ib_t])
    i_train_nt = np.concatenate([iq_nt, ib_nt])

    r_q_b_treated = r[iq_t][:, ib_t]
    r_q_b_nt = r[iq_nt][:, ib_nt]
    r_val_nvt = r[val_indices][:, i_train_t]  # rel between val and treated train
    r_val_nvnt = r[val_indices][:, i_train_nt]  # rel between val and non-treated train

    # Fit one model per (use relations?, arm) pair; rel=False feeds all-zero relation matrices.
    trained_models = {
        "rel=True": {},
        "rel=False": {},
    }
    for rel, (xqi, xbi, yqi, ybi, yvi, r_q_b, r_v_nvi, label) in product(
        (True, False),
        (
            (xt_q, xt_b, yt_q, yt_b, yv_t, r_q_b_treated, r_val_nvt, label_t),
            (xnt_q, xnt_b, ynt_q, ynt_b, yv_nt, r_q_b_nt, r_val_nvnt, label_nt),
        ),
    ):
        _, _, _, model = train_nw_arbitrary(
            x_backgnd=xbi,
            y_backgnd=ybi,
            x_query=xqi,
            y_query=yqi,
            x_val=xv.to_numpy(),
            y_val=yvi,
            r_query_backgnd=r_q_b if rel else np.zeros_like(r_q_b),
            r_val_nonval=r_v_nvi if rel else np.zeros_like(r_v_nvi),
            cfg=NwModelConfig(),
            lr=1e-3,
            n_epochs=50,
        )
        trained_models[f"rel={rel}"][label] = model

    # The effect estimate is the treated-minus-control prediction on the validation rows;
    # the MSE against the true effect is recorded as PEHE.
    for rel in (True, False):
        key = f"rel={rel}"
        models = trained_models[key]
        y_true_treated = models[label_t].y_val_true
        y_pred_treated = models[label_t].y_val_pred
        y_true_nt = models[label_nt].y_val_true
        y_pred_nt = models[label_nt].y_val_pred

        tau_true = y_true_treated - y_true_nt
        tau_pred = y_pred_treated - y_pred_nt
        metrics[key].append(mean_squared_error(tau_true, tau_pred)) # PEHE

In [None]:
# Mean PEHE over seeds for each variant, rounded to two decimals for display.
dict((key, round(np.mean(values), 2)) for key, values in metrics.items())