In [None]:
from pathlib import Path

import numpy as np
import pandas as pd

def load_ihdp_data(ihdp_path: Path) -> tuple[pd.DataFrame, list[str], str]:
    """Load one IHDP replication table and derive the outcome-difference target.

    Args:
        ihdp_path: Directory containing ``columns.txt`` and a ``csv/`` sub-directory
            with header-less ``*.csv`` tables.

    Returns:
        ``(data, exclude_cols, y_col_name)``:
        ``data`` is the loaded table with named columns plus a derived ``"delta_y"``
        column; ``exclude_cols`` lists the treatment / outcome / mu columns that are
        not covariates; ``y_col_name`` is ``"delta_y"``.

    Raises:
        ValueError: if ``ihdp_path / "csv"`` contains no ``*.csv`` files.
    """
    # Each whitespace-separated token of columns.txt loses its last character
    # (presumably a trailing comma) and the last two tokens are discarded; the
    # covariates x2..x25 are then named explicitly.
    # NOTE(review): this relies on a specific columns.txt layout — confirm against the file.
    ihdp_cols = [s[:-1] for s in np.loadtxt(ihdp_path / "columns.txt", dtype=str)][:-2]
    ihdp_cols.extend([f"x{i}" for i in range(2, 26)])

    csv_dir = ihdp_path / "csv"
    # Path.glob yields entries in arbitrary, filesystem-dependent order; sort so the
    # selected table is the same on every machine and every run.
    csv_paths = sorted(csv_dir.glob("*.csv"))
    if not csv_paths:
        raise ValueError(f"no *.csv files found in {csv_dir}")
    # TODO choose a table; for now only the first table (in sorted order) is used.
    data = pd.read_csv(csv_paths[0], header=None)
    data.columns = ihdp_cols

    y_col_name = "delta_y"
    # Assuming treatment is coded 0/1, (-1) ** treatment flips the sign for treated rows,
    # so delta_y is y_cfactual - y_factual for t=0 and y_factual - y_cfactual for t=1
    # (presumably the treated-minus-control outcome in both cases).
    data[y_col_name] = (data["y_cfactual"] - data["y_factual"]) * (-1) ** data["treatment"]
    exclude_cols = ["treatment", "y_cfactual", "y_factual", "mu0", "mu1"]
    return data, exclude_cols, y_col_name

import os

# Directory holding columns.txt and the csv/ tables. Set the IHDP_DIR environment
# variable to point elsewhere instead of editing this hardcoded local fallback.
IHDP_DIR = Path(
    os.environ.get("IHDP_DIR", "/Users/vzuev/Documents/git/gh_zuevval/tabrel/CEVAE/datasets/IHDP")
)
ihdp_data, ihdp_exclude_cols, ihdp_y_colname = load_ihdp_data(IHDP_DIR)
ihdp_data.head()

In [None]:
from typing import Final

# Covariates only: drop the bookkeeping columns returned by the loader and the target.
columns_to_drop = [*ihdp_exclude_cols, ihdp_y_colname]
x_all = ihdp_data.drop(columns=columns_to_drop)

# The first `ihdp_last_numeric_index` covariate columns (an exclusive bound, i.e. a
# count) form x_numeric; every remaining column goes to x_cat.
ihdp_last_numeric_index: Final[int] = 6
numeric_slice = slice(None, ihdp_last_numeric_index)
cat_slice = slice(ihdp_last_numeric_index, None)
x_numeric, x_cat = x_all.iloc[:, numeric_slice], x_all.iloc[:, cat_slice]
x_numeric

In [None]:
import seaborn as sns

# Pair plot of the numeric covariates, coloured by the outcome difference.
# delta_y is continuous: passed straight to `hue`, every distinct value becomes its own
# hue level (one diagonal KDE per value, unreadable colouring). Bin it into quantiles
# first so the colouring separates low- and high-outcome groups.
N_Y_QUANTILE_BINS = 4
y_bin_col = f"{ihdp_y_colname}_quantile"
x_num_y = x_numeric.assign(
    **{y_bin_col: pd.qcut(ihdp_data[ihdp_y_colname], q=N_Y_QUANTILE_BINS, duplicates="drop")}
)
# Trailing `;` suppresses the PairGrid repr; the figure still renders.
sns.pairplot(x_num_y, hue=y_bin_col);

In [None]:
from itertools import product
from tqdm import tqdm

# Relation matrix: r[i, j] = 1 when rows i and j have (numerically) the same value in
# `group_col`, else 0. The grouping column itself is removed from the features `x`.
group_col: Final[str] = "x4"
x = x_all.drop(columns=[group_col])
x_len = len(x)
categories = x_all[group_col]
print("n_categories", len(categories.unique()))

# Vectorised pairwise comparison with the same np.isclose(a_i, a_j) argument order as a
# per-pair loop. It works positionally, so it stays correct even when the frame's index
# is not 0..n-1 (label lookups like categories[i] would then misbehave).
category_values = categories.to_numpy()
r = np.isclose(category_values[:, None], category_values[None, :]).astype(np.float64)

r

In [None]:
from tabrel.benchmark.nw_regr import run_training

# Seeded random permutation of all rows, cut into query / background / validation sets.
np.random.seed(42)
indices = np.random.permutation(x_len)
n_query, n_back = 200, 300
# Cut points: query = indices[:n_query], background = indices[n_query:n_back],
# validation = indices[n_back:].
# NOTE(review): `n_back` is an end offset, not a size — the background set has
# n_back - n_query = 100 rows. Confirm that 300 background rows was not intended.
query_indices, back_indices, val_indices = np.split(indices, [n_query, n_back])

features = x.to_numpy()
target = ihdp_data[ihdp_y_colname].to_numpy()
res = run_training(
    x=features,
    y=target,
    r=r,
    query_indices=query_indices,
    backgnd_indices=back_indices,
    val_indices=val_indices,
    n_epochs=10,
    lr=1e-4,
)

# `res` is keyed by "rel=True" / "rel=False"; each entry is printed without its last element.
for rel in (True, False):
    print(rel, res[f"rel={rel}"][:-1])