In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import lightgbm as lgb
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import mean_squared_error

from tabrel.train import train_relnet
from tabrel.utils.treatment import load_ihdp_data, generate_indices


# Load IHDP together with the names of its special columns.
(
    ihdp_data,
    ihdp_exclude_cols,
    ihdp_tau_colname,
    ihdp_y_fact_colname,
    ihdp_y_cfact_colname,
    ihdp_treatment_colname,
) = load_ihdp_data(Path("../CEVAE/datasets/IHDP"))

# Model inputs: everything except the columns load_ihdp_data lists for
# exclusion and the ground-truth effect column (kept out of the features).
x_all = ihdp_data.drop(columns=ihdp_exclude_cols + [ihdp_tau_colname])
y_fact = ihdp_data[ihdp_y_fact_colname]
y_cfact = ihdp_data[ihdp_y_cfact_colname]
treatment_all = ihdp_data[ihdp_treatment_colname]

# Plain LightGBM regression baseline, silenced.
lgb_params = dict(objective="regression", metric="rmse", verbosity=-1)

def _tau_from_stacked_preds(y_pred, signs):
    """Return per-unit effect estimates from stacked validation predictions.

    ``y_pred`` must contain the predictions for the validation units' factual
    rows followed by the predictions for their counterfactual copies (the
    order of ``v_rs_ids`` below), with ``len(signs)`` units in each half.
    ``signs`` is +1 for treated and -1 for control units, so
    ``(factual - counterfactual) * signs`` is always treated-arm minus
    control-arm prediction.
    NOTE(review): assumes the tau column is defined as treated minus control
    outcome — confirm against load_ihdp_data.
    """
    n_val = len(signs)
    y_fact_hat, y_cfact_hat = y_pred[:n_val], y_pred[n_val:]
    return (y_fact_hat - y_cfact_hat) * signs


# Per-seed PEHE (MSE of the estimated per-unit effect) for: a relational
# S-learner (treatment enters only through the relation matrix), a standard
# S-learner (treatment as a feature, no relations) and a LightGBM S-learner.
pehes_rs, pehes_s, pehes_lgb = [], [], []
for seed in range(20):
    q_ids, b_ids, v_ids = generate_indices(seed=seed, n_total=len(x_all))
    n_val = len(v_ids)

    # Append a copy of every validation unit with its treatment flipped and its
    # counterfactual outcome as the target, so a single model predicts both arms.
    x_rs = pd.concat([x_all, x_all.iloc[v_ids]], ignore_index=True)
    y_fact_rs = pd.concat([y_fact, y_cfact.iloc[v_ids]], ignore_index=True)
    treatment_rs = pd.concat([treatment_all, 1 - treatment_all.iloc[v_ids]], ignore_index=True)

    # Positions of the appended counterfactual copies in the *_rs frames.
    new_indices = np.arange(len(x_all), len(x_all) + n_val)
    # Factual rows first, counterfactual copies second: _tau_from_stacked_preds
    # splits predictions in exactly this order. Reversing it would flip the
    # sign of every estimated effect.
    v_rs_ids = np.concatenate([v_ids, new_indices])

    # Relation matrix: 1 where two rows belong to the same treatment arm.
    treatment_array = treatment_rs.to_numpy()
    r_rs = (treatment_array[:, None] == treatment_array[None, :]).astype(int)

    # Positional (.iloc) lookups: v_ids are row positions, not index labels.
    tau_val_ground = ihdp_data[ihdp_tau_colname].iloc[v_ids].to_numpy()
    # +1 for treated, -1 for control validation units (treatment assumed 0/1).
    v_treatments = np.where(treatment_all.iloc[v_ids].to_numpy() == 1, 1, -1)

    def train_relnet_shorthand(x: pd.DataFrame, r: np.ndarray, seed: int) -> float:
        """Train a RelNet on this split's (x, y_fact_rs, r) and return the validation PEHE.

        NOTE(review): assumes train_relnet returns validation predictions in
        the order of ``val_indices`` — confirm in tabrel.train.
        """
        torch.manual_seed(seed)
        _, _, _, y_val_pred, _ = train_relnet(
            x=x.to_numpy(),
            y=y_fact_rs.to_numpy(),
            r=r,
            backgnd_indices=b_ids,
            query_indices=q_ids,
            val_indices=v_rs_ids,
            lr=.01,
            n_epochs=500,
            n_layers=2,
            periodic_embed_dim=None,
            num_heads=2,
            progress_bar=True,
            print_loss=False,
        )
        tau_val_pred = _tau_from_stacked_preds(y_val_pred, v_treatments)
        return mean_squared_error(tau_val_ground, tau_val_pred)

    # Relational S-learner: treatment is not a feature, only a relation.
    pehe_val_rs = train_relnet_shorthand(x_rs, r=r_rs, seed=seed)
    print(f"seed {seed}\tPEHE {pehe_val_rs:.2f} (RS-learner)")
    pehes_rs.append(pehe_val_rs)

    # Standard S-learner: treatment becomes an input column, relations are all zero.
    x_s = x_rs.copy()
    x_s[ihdp_treatment_colname] = treatment_rs
    r_s = np.zeros_like(r_rs)
    pehe_val_s = train_relnet_shorthand(x_s, r=r_s, seed=seed)
    print(f"seed {seed}\tPEHE {pehe_val_s:.2f} (standard S-learner)")
    pehes_s.append(pehe_val_s)

    # LightGBM S-learner on the same features, fit on background + query rows.
    train_ids = np.concatenate([b_ids, q_ids])
    x_s_train = x_s.iloc[train_ids].to_numpy()
    y_s_train = y_fact_rs.iloc[train_ids]
    lgb_model = lgb.train(lgb_params, lgb.Dataset(x_s_train, label=y_s_train))

    x_s_val = x_s.iloc[v_rs_ids].to_numpy()
    val_pred_lgb = lgb_model.predict(x_s_val)
    tau_val_lgb = _tau_from_stacked_preds(val_pred_lgb, v_treatments)
    pehe_val_lgb = mean_squared_error(tau_val_ground, tau_val_lgb)
    print(f"seed {seed}\tPEHE {pehe_val_lgb:.2f} (LightGBM)")
    pehes_lgb.append(pehe_val_lgb)
