In [None]:
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import truncnorm
from tabrel.benchmark.nw_regr import run_training
from tabrel.train import train_relnet


def synthetic_2d_data(n_samples: int, n_clusters: int, seed: int, type: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate a clustered 2-feature regression toy dataset.

    Features are drawn uniformly from [-1, 1) and every sample is assigned a
    random integer cluster id in ``[0, n_clusters)``. The target is
    ``x0 + 0.5 * cluster + f(x1)`` where ``f`` is selected by ``type``.

    Args:
        n_samples: Number of samples to draw.
        n_clusters: Number of distinct cluster ids.
        seed: Seed passed to ``np.random.seed`` (this reseeds NumPy's global RNG).
        type: Shape of the dependence on the second feature; one of
            ``"linear"`` (``2 * x1``), ``"square"`` (``0.5 * x1 ** 2``) or
            ``"sin"`` (``sin(x1)``).

    Returns:
        ``(x, y, clusters)`` with shapes ``(n_samples, 2)``, ``(n_samples,)``
        and ``(n_samples,)``.

    Raises:
        ValueError: If ``type`` is not one of the supported names.
    """
    np.random.seed(seed)
    features = np.random.uniform(-1, 1, (n_samples, 2))
    cluster_ids = np.random.randint(0, n_clusters, n_samples)

    first_col = features[:, 0]
    second_col = features[:, 1]

    # Contribution of the second feature, chosen by the requested dependence.
    if type == "linear":
        second_term = 2 * second_col
    elif type == "square":
        second_term = .5 * second_col ** 2
    elif type == "sin":
        second_term = np.sin(second_col)
    else:
        raise ValueError("unknown type: " + type)

    # Same evaluation order as (x0 + 0.5 * cluster) + f(x1).
    target = first_col + cluster_ids * .5 + second_term
    return features, target, cluster_ids


# Quick smoke run: a tiny 10-sample, 3-cluster "square" dataset.
# Binds the notebook-level names x, y, c (they are reassigned further down).
x, y, c = synthetic_2d_data(10, n_clusters=3, seed=42, type="square")

def synthetic_3d_data(n_samples: int, n_clusters: int, seed: int, x3_type: str, depend_type: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate a clustered 3-feature regression toy dataset.

    All three features start as uniform draws from [-1, 1). With
    ``x3_type="truncnorm"`` the last feature is replaced by samples from a
    normal distribution (mean 1, std 1) truncated to [0, 4]. Every sample gets
    a random integer cluster id in ``[0, n_clusters)``. The target is
    ``x0 + 0.5 * cluster + g(x1, x2)`` where ``g`` is selected by
    ``depend_type``.

    Args:
        n_samples: Number of samples to draw.
        n_clusters: Number of distinct cluster ids.
        seed: Seed passed to ``np.random.seed`` (this reseeds NumPy's global RNG).
        x3_type: Distribution of the third feature: ``"uniform"`` or
            ``"truncnorm"``.
        depend_type: Dependence of the target on features 1 and 2:
            ``"quad"`` (``x1 + x2 ** 2``), ``"cube"`` (``x1 ** 2 + x2 ** 3``)
            or ``"sin"`` (``sin(x1) + cos(x2)``).

    Returns:
        ``(x, y, clusters)`` with shapes ``(n_samples, 3)``, ``(n_samples,)``
        and ``(n_samples,)``.

    Raises:
        ValueError: If ``x3_type`` or ``depend_type`` is not a supported name.
    """
    np.random.seed(seed)

    x = np.random.uniform(-1, 1, (n_samples, 3))
    if x3_type == "uniform":
        pass  # keep the uniform draw for the third feature
    elif x3_type == "truncnorm":
        mean, std = 1, 1
        lower_bound, upper_bound = 0, 4
        # scipy's truncnorm takes its clip points in standardized units.
        a = (lower_bound - mean) / std
        b = (upper_bound - mean) / std
        x[:, -1] = truncnorm.rvs(a, b, loc=mean, scale=std, size=n_samples)
    else:
        raise ValueError("unknown x3_type " + x3_type)

    clusters = np.random.randint(0, n_clusters, n_samples)

    y = x.T[0] + .5 * clusters
    if depend_type == "quad":
        y += x.T[1] + x.T[2] ** 2
    elif depend_type == "cube":
        y += x.T[1] ** 2 + x.T[2] ** 3
    elif depend_type == "sin":
        y += np.sin(x.T[1]) + np.cos(x.T[2])
    else:
        # Previously an unknown name fell through silently and produced a
        # target without any x1/x2 contribution; fail loudly instead, like
        # the x3_type check above.
        raise ValueError("unknown depend_type " + depend_type)

    return x, y, clusters
    
# Sanity check of the truncated-normal option: draw 100 samples and plot a
# histogram of the third feature (column index 2).
# Rebinds the notebook-level names x, y, c.
x, y, c = synthetic_3d_data(100, n_clusters=3, seed=42, x3_type="truncnorm", depend_type="quad")
plt.hist(x[:, 2])

def make_r(clusters: np.ndarray) -> np.ndarray:
    """Build the pairwise "same cluster" relation matrix.

    Args:
        clusters: 1-D sequence of cluster labels, one per sample.

    Returns:
        A float ``(n, n)`` array whose entry ``[i, j]`` is 1 when samples
        ``i`` and ``j`` carry the same label and 0 otherwise.
    """
    labels = np.asarray(clusters)
    # Broadcasting a column against a row compares every pair at once,
    # replacing n*n Python-level iterations with one vectorized comparison.
    return (labels[:, None] == labels[None, :]).astype(float)

In [None]:
# Benchmark on the 2-feature synthetic data: for every (sample count, dependence
# type) setting, repeat over 10 seeds and store the mean/std of each model's score.
# The collected value is element [1] of each result entry; judging by the
# variable names it is an R^2 score -- TODO confirm against run_training / train_relnet.
results = {}
# The loop variable is intentionally NOT named `type`: at notebook scope that
# would rebind the builtin `type` for every cell executed afterwards.
for n_samples, data_type in product([300, 1000], ["linear", "square", "sin"]):
    r2s_relnet = []
    r2s_nw = []
    r2s_rel_nw = []
    r2s_feat_rel_nw = []
    r2s_lgb = []
    r2s_feat_rel_lgb = []

    for seed in range(10):
        x, y, c = synthetic_2d_data(n_samples, 3, seed=seed, type=data_type)
        r = make_r(c)

        # Contiguous index split: first third -> query, second third ->
        # validation, everything remaining -> background.
        n_query = n_val = n_samples // 3
        query_ids = np.arange(n_query)
        val_ids = np.arange(n_query, n_query + n_val)
        back_ids = np.arange(n_query + n_val, n_samples)

        res = run_training(
            x=x,
            y=y,
            r=r,
            backgnd_indices=back_ids,
            query_indices=query_ids,
            val_indices=val_ids,
            lr=1e-3,
            n_epochs=50,
            rel_as_feats=r,
        )

        r2s_nw.append(res["rel=False"][1])
        r2s_rel_nw.append(res["rel=True"][1])
        r2s_feat_rel_nw.append(res["rel-as-feats"][1])
        r2s_lgb.append(res["lgb"][1])
        r2s_feat_rel_lgb.append(res["lgb-rel"][1])

        res_relnet = train_relnet(
            x=x,
            y=y,
            r=r,
            backgnd_indices=back_ids,
            query_indices=query_ids,
            val_indices=val_ids,
            lr=1e-3,
            n_epochs=1000,
            n_layers=2,
            num_heads=4,
        )

        r2s_relnet.append(res_relnet[1])

    # Aggregate the per-seed scores of this setting into (mean, std) pairs.
    results[f"n_samples={n_samples};type={data_type}"] = {
        k: (np.mean(v), np.std(v))
        for k, v in (
            ("rel=True", r2s_rel_nw),
            ("rel=False", r2s_nw),
            ("rel-feats", r2s_feat_rel_nw),
            ("relnet", r2s_relnet),
            ("lgb", r2s_lgb),
            ("lgb-rel", r2s_feat_rel_lgb),
        )
    }

In [None]:
from itertools import chain
# One dict per experiment setting, keyed "<setting>;<model>", holding the
# (mean, std) pair rounded to 4 decimals. As the cell's last expression it is
# shown as the cell output.
[
    {
        f"{setting};{model}": (round(stats[0], 4), round(stats[1], 4))
        for model, stats in per_model.items()
    }
    for setting, per_model in results.items()
]