In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
from tqdm import tqdm

from tabrel.train import train_relnet
from tabrel.benchmark.nw_regr import run_training
from tabrel.benchmark.nw_regr import generate_toy_regr_data


def generate_multidim_noisy_data(n_samples: int, n_clusters: int, x_dim: int, seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Draw a synthetic multi-feature regression dataset with cluster labels.

    Seeds NumPy's *global* random state with ``seed`` (callers that use
    ``np.random`` afterwards continue from that state), then samples features
    uniformly from [-1, 1) and one integer cluster id per sample.

    The target only uses the first three feature columns, so ``x_dim`` must be
    at least 3; any further columns do not influence ``y``. No random term is
    added to the target: it is fully determined by the features and the
    cluster id.

    Args:
        n_samples: Number of rows to generate.
        n_clusters: Cluster ids are drawn from ``0 .. n_clusters - 1``.
        x_dim: Number of feature columns.
        seed: Value passed to ``np.random.seed``.

    Returns:
        ``(x, y, clusters)`` where ``x`` has shape ``(n_samples, x_dim)`` and
        ``y`` and ``clusters`` have shape ``(n_samples,)``.
    """
    np.random.seed(seed)

    # Draw order matters for reproducibility: features first, then cluster ids.
    features = np.random.uniform(low=-1, high=1, size=(n_samples, x_dim))
    cluster_ids = np.random.randint(low=0, high=n_clusters, size=n_samples)

    first_col = features[:, 0]
    second_col = features[:, 1]
    third_col = features[:, 2]

    # y = sin(x0) + cos(x1) + x2 + cluster id (the cluster acts as an offset).
    target = np.sin(first_col) + np.cos(second_col) + third_col + cluster_ids
    return features, target, cluster_ids


n_samples = 300
n_ones = 70

# Benchmark all models on each target function, repeating over 10 seeds and
# reporting the mean / std of the collected metrics per model.
for y_func in (
    "square",
    "sign",
    "noisy",
):
    r2_list = []
    mse_list = []
    for seed in tqdm(range(10), desc=y_func):
        if y_func == "noisy":
            x, y, c = generate_multidim_noisy_data(n_samples=n_samples, n_clusters=3, x_dim=4, seed=seed)
        else:
            x, y, c = generate_toy_regr_data(n_samples=n_samples, n_clusters=3, y_func=y_func, seed=seed)
            # The toy generator's x and y are converted with .numpy() here;
            # c is used as returned.
            x = x.numpy()
            y = y.numpy()

        # Symmetric relation matrix with ones on the diagonal: every pair of
        # samples from the same cluster is linked with probability 1/2.
        # The draws come from NumPy's global RNG.
        # NOTE(review): for the non-"noisy" targets reproducibility of r depends
        # on whether generate_toy_regr_data seeds np.random — confirm.
        r = np.eye(n_samples)
        for i in range(n_samples):
            for j in range(i):
                if c[i] == c[j] and np.random.choice((True, False)):
                    r[i, j] = 1
                    r[j, i] = 1

        # Three disjoint, contiguous partitions: validation, query, background.
        # The query slice must end at n_validate + n_query; slicing
        # [n_validate:n_query] with n_validate == n_query would be empty.
        n_query = n_validate = n_samples // 3
        samples_indices = np.arange(n_samples)
        query_end = n_validate + n_query
        validate_indices = samples_indices[:n_validate]
        query_indices = samples_indices[n_validate:query_end]
        back_indices = samples_indices[query_end:]

        results = run_training(
            x=x,
            y=y,
            r=r,
            backgnd_indices=back_indices,
            query_indices=query_indices,
            val_indices=validate_indices,
            lr=1e-3,
            n_epochs=500,
            rel_as_feats=r,
        )

        results["relnet"] = train_relnet(
            x=x,
            y=y,
            r=r,
            backgnd_indices=back_indices,
            query_indices=query_indices,
            val_indices=validate_indices,
            lr=1e-3,
            num_heads=4,
            progress_bar=False,
            n_epochs=1000,
        )

        # Each result value is indexed as (mse, r2): element 0 goes to the MSE
        # table, element 1 to the R2 table.
        mse_list.append({k: v[0] for k, v in results.items()})
        r2_list.append({k: v[1] for k, v in results.items()})

    print(f"=== {y_func} ===")
    for label, lst in (("mse", mse_list), ("r2", r2_list)):
        # One row per seed, one column per model.
        results_df = pd.DataFrame(lst)
        print("\n")
        print(label)
        print(pd.DataFrame({"mean": results_df.mean().round(3), "std": results_df.std().round(3)}))