In [None]:
%load_ext autoreload
%autoreload 2
from typing import Final

import numpy as np
import pandas as pd
from tqdm import tqdm

from tabrel.train import train_relnet
from tabrel.benchmark.nw_regr import run_training, generate_toy_regr_data, make_r, make_random_r, generate_multidim_noisy_data

# Number of samples requested from the data generators; also the total that
# the validation/query/background index split below is derived from.
n_samples: Final[int] = 300
# Epoch count passed to both run_training and train_relnet.
n_epochs: Final[int] = 5000
# Passed as x_dim to generate_multidim_noisy_data (used for the "noisy" target only).
n_features: Final[int] = 7

In [None]:
# Repeated-seed experiment: for each target type, generate synthetic regression
# data, collect the metrics returned by run_training and by train_relnet
# (stored under the "tabrel" key), then print mean/std over seeds for MSE and R2.
for y_func in (
    # Toy targets served by generate_toy_regr_data; uncomment to include them.
    # "square",
    # "sign",
    "noisy",
):
    r2_list = []
    mse_list = []
    for seed in tqdm(range(15), desc=y_func):
        if y_func == "noisy":
            x, y, c = generate_multidim_noisy_data(
                n_samples=n_samples, n_clusters=3, x_dim=n_features, seed=seed
            )
        else:
            x, y, c = generate_toy_regr_data(
                n_samples=n_samples, n_clusters=3, y_func=y_func, seed=seed
            )
            # The toy generator's outputs expose .numpy() (presumably torch
            # tensors -- TODO confirm); convert so both branches yield arrays.
            x = x.numpy()
            y = y.numpy()
            c = c.numpy()

        # Contiguous three-way split of the sample indices:
        # [0, n_validate) -> validation, the next n_query -> query,
        # everything remaining up to n_samples -> background.
        n_query = n_validate = n_samples // 3
        validate_indices = np.arange(n_validate)
        query_indices = np.arange(n_validate, n_validate + n_query)
        back_indices = np.arange(n_query + n_validate, n_samples)

        r = make_random_r(seed, c)
        # NOTE(review): the original author flagged this call with "this does
        # not work for whatever reason". The cause is not visible from this
        # notebook -- TODO investigate (the same `r` is passed as both `r` and
        # `rel_as_feats`; confirm that is what run_training expects).
        results = run_training(
            x=x,
            y=y,
            r=r,
            backgnd_indices=back_indices,
            query_indices=query_indices,
            val_indices=validate_indices,
            lr=3e-3,
            n_epochs=n_epochs,
            rel_as_feats=r,
        )

        # Each value in `results` is indexable: element 0 is collected as the
        # MSE and element 1 as the R2 for that key.
        mse_list.append({k: v[0] for k, v in results.items()})
        r2_list.append({k: v[1] for k, v in results.items()})

        results_relnet = train_relnet(
            x=x,
            y=y,
            r=r,
            backgnd_indices=back_indices,
            query_indices=query_indices,
            val_indices=validate_indices,
            lr=1e-3,
            num_heads=4,
            progress_bar=False,
            n_epochs=n_epochs,
        )
        # Only the first two returned items (MSE, R2) are used here.
        relnet_mse, relnet_r2 = results_relnet[:2]
        mse_list[-1]["tabrel"] = relnet_mse
        r2_list[-1]["tabrel"] = relnet_r2

    # One row per seed, one column per key; summarise across seeds.
    print(f"=== {y_func} ===")
    for label, lst in (("mse", mse_list), ("r2", r2_list)):
        results_df = pd.DataFrame(lst)
        print("\n")
        print(label)
        print(pd.DataFrame({"mean": results_df.mean().round(3), "std": results_df.std().round(3)}))