In [None]:
%load_ext autoreload
%autoreload 2

from itertools import product

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import truncnorm
from tabrel.benchmark.nw_regr import run_training, make_r
from tabrel.train import train_relnet
from tqdm import tqdm

In [None]:
def synthetic_2d_data(n_samples: int, n_clusters: int, seed: int, func_type: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate a clustered 2-feature regression toy dataset.

    Features are drawn uniformly from [-1, 1) and every sample gets a random
    integer cluster label in ``[0, n_clusters)``. The target is
    ``x1 + 0.5 * cluster + g(x2)`` where ``g`` is chosen by ``func_type``.

    Args:
        n_samples: number of rows to generate.
        n_clusters: number of distinct cluster labels.
        seed: value used to seed NumPy's global random generator.
        func_type: ``"linear"`` (``2 * x2``), ``"square"`` (``0.5 * x2 ** 2``)
            or ``"sin"`` (``sin(x2)``).

    Returns:
        ``(x, y, clusters)`` with shapes ``(n_samples, 2)``, ``(n_samples,)``
        and ``(n_samples,)``.

    Raises:
        ValueError: if ``func_type`` is not one of the supported names.
    """
    # Seeds the *global* NumPy generator, so the draw order below matters.
    np.random.seed(seed)
    features = np.random.uniform(-1, 1, (n_samples, 2))
    cluster_ids = np.random.randint(0, n_clusters, n_samples)

    first, second = features[:, 0], features[:, 1]

    # Contribution of the second feature to the target.
    if func_type == "linear":
        second_term = 2 * second
    elif func_type == "square":
        second_term = .5 * second ** 2
    elif func_type == "sin":
        second_term = np.sin(second)
    else:
        raise ValueError("unknown type: " + func_type)

    target = first + cluster_ids * .5
    target += second_term

    return features, target, cluster_ids


# Quick smoke run of the 2-D generator; x, y, c are reassigned by the 3-D call later in this cell.
x, y, c = synthetic_2d_data(n_samples=10, n_clusters=3, seed=42, func_type="square")

def synthetic_3d_data(n_samples: int, n_clusters: int, seed: int, x3_type: str, depend_type: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate a clustered 3-feature regression toy dataset.

    All three features start as uniform draws from [-1, 1). With
    ``x3_type="truncnorm"`` the third feature is replaced by draws from a
    normal distribution (mean 1, std 1) truncated to [0, 4]. Every sample gets
    a random integer cluster label in ``[0, n_clusters)``. The target is
    ``x1 + 0.5 * cluster + g(x2, x3)`` where ``g`` is chosen by ``depend_type``.

    Args:
        n_samples: number of rows to generate.
        n_clusters: number of distinct cluster labels.
        seed: value used to seed NumPy's global random generator.
        x3_type: ``"uniform"`` (keep the uniform draw) or ``"truncnorm"``.
        depend_type: ``"quad"`` (``x2 + x3 ** 2``), ``"cube"``
            (``x2 ** 2 + x3 ** 3``) or ``"sin"`` (``sin(x2) + cos(x3)``).

    Returns:
        ``(x, y, clusters)`` with shapes ``(n_samples, 3)``, ``(n_samples,)``
        and ``(n_samples,)``.

    Raises:
        ValueError: if ``x3_type`` or ``depend_type`` is not a supported name.
    """
    np.random.seed(seed)

    x = np.random.uniform(-1, 1, (n_samples, 3))
    if x3_type == "uniform":
        pass
    elif x3_type == "truncnorm":
        mean, std = 1, 1
        lower_bound, upper_bound = 0, 4
        # truncnorm takes its clip points in units of standard deviations from the mean.
        a = (lower_bound - mean) / std
        b = (upper_bound - mean) / std
        x[:, -1] = truncnorm.rvs(a, b, loc=mean, scale=std, size=n_samples)
    else:
        raise ValueError("unknown x3_type " + x3_type)

    clusters = np.random.randint(0, n_clusters, n_samples)

    y = x.T[0] + .5 * clusters
    if depend_type == "quad":
        y += x.T[1] + x.T[2] ** 2
    elif depend_type == "cube":
        y += x.T[1] ** 2 + x.T[2] ** 3
    elif depend_type == "sin":
        y += np.sin(x.T[1]) + np.cos(x.T[2])
    else:
        # Previously an unknown name fell through silently and produced a
        # target that ignored x2 and x3; fail loudly like the x3_type check.
        raise ValueError("unknown depend_type " + depend_type)

    return x, y, clusters
    
# Quick smoke run of the 3-D generator with a truncated-normal third feature.
x, y, c = synthetic_3d_data(
    n_samples=100,
    n_clusters=3,
    seed=42,
    x3_type="truncnorm",
    depend_type="quad",
)
# plt.hist(x[:, 2])


In [None]:
from datetime import datetime
from typing import Final

import optuna 

from tabrel.benchmark.nw_regr import MlpConfig
from tabrel.optuna import RelTrainData, build_objective_nw_mlp

def build_data_2d(n_samples: int, n_clusters: int, seed: int, func_type: str) -> RelTrainData:
    """Build a ``RelTrainData`` bundle from the synthetic 2-D dataset.

    The samples are split by position: the first third are query samples, the
    second third are validation samples, and everything that remains is the
    background set.

    Args:
        n_samples: total number of generated samples.
        n_clusters: number of cluster labels in the generated data.
        seed: seed forwarded to ``synthetic_2d_data``.
        func_type: target-function name forwarded to ``synthetic_2d_data``.

    Returns:
        ``RelTrainData`` holding the features, targets, the result of
        ``make_r`` applied to the cluster labels, and the three index arrays.
    """
    # Forward n_clusters instead of a hard-coded 3 so the argument takes effect.
    x, y, c = synthetic_2d_data(n_samples, n_clusters, seed=seed, func_type=func_type)
    n_query = n_val = n_samples // 3

    return RelTrainData(
        r=make_r(c),
        x=x,
        y=y,
        query_ids=np.arange(n_query),
        val_ids=np.arange(n_query, n_query + n_val),
        # The background set absorbs the remainder when n_samples is not divisible by 3.
        back_ids=np.arange(n_query + n_val, n_samples),
    )

# Module-level counter; objective_mlp increments it on every call.
seed_current = 0


def objective_mlp(trial: optuna.Trial) -> float:
    """Optuna objective evaluated on a freshly generated "square" 2-D dataset.

    Each call increments the module-level ``seed_current`` and uses the new
    value both to generate the data and as the seed passed to
    ``build_objective_nw_mlp``.

    Args:
        trial: the Optuna trial being evaluated.

    Returns:
        The value returned by ``build_objective_nw_mlp`` for this trial.
    """
    global seed_current
    seed_current += 1
    trial_seed = seed_current

    trial_data = build_data_2d(n_samples=300, n_clusters=3, seed=trial_seed, func_type="square")
    return build_objective_nw_mlp(
        trial=trial,
        data=trial_data,
        n_epochs=100,
        seed=trial_seed,
    )

# Trials are persisted to a local SQLite file in the working directory.
sqlite_path: Final[str] = "sqlite:///db.sqlite3"

# The timestamp in the name gives every run of this cell its own study in the shared DB.
study_label = f"synth2d_{datetime.now()}"
study = optuna.create_study(
    storage=sqlite_path,
    study_name=study_label,
    direction="maximize",
)

# NOTE(review): no n_trials / timeout is passed; per Optuna's documented defaults the
# search then has no limit and runs until interrupted, so "Run All" would not get past
# this cell. Confirm this is intended.
study.optimize(objective_mlp)

In [None]:
from collections import defaultdict


# Maps "<metric>_nSamples=<n>_funcType=<f>" to the list of per-seed values.
metrics = defaultdict(list)

# Every (dataset size, target function) pair is repeated over 15 seeds.
experiment_grid = list(product([300, 1000], ["linear", "square", "sin"]))

for n_samples, func_type in tqdm(experiment_grid):

    for seed in range(15):
        data = build_data_2d(n_samples, 3, seed, func_type)

        mlp_config = MlpConfig(
            in_dim=data.x.shape[1],
            hidden_dim=75,
            out_dim=18,
            dropout=.5,
        )

        # The same relation matrix is passed both as `r` and as `rel_as_feats`.
        res = run_training(
            x=data.x,
            y=data.y,
            r=data.r,
            backgnd_indices=data.back_ids,
            query_indices=data.query_ids,
            val_indices=data.val_ids,
            lr=1e-3,
            n_epochs=50,
            rel_as_feats=data.r,
            mlp_config=mlp_config,
        )

        # Disabled RelNet baseline, kept for reference. It refers to names
        # (r, back_ids, query_ids, val_ids, r2s_relnet) that this loop does not define.
        # res_relnet = train_relnet(
        #     x=x,
        #     y=y,
        #     r=r,
        #     backgnd_indices=back_ids,
        #     query_indices=query_ids,
        #     val_indices=val_ids,
        #     lr=1e-3,
        #     n_epochs=1000,
        #     n_layers=2,
        #     num_heads=4,
        # )

        # r2s_relnet.append(res_relnet[1])

        # Collect every metric returned for this run under its configuration key.
        for k, v in res.items():
            metrics[f"{k}_nSamples={n_samples}_funcType={func_type}"].append(v)

In [None]:
from tabrel.utils.misc import to_df


# Convert the collected metrics to a table via to_df (called with decimal_places=4).
df_metrics = to_df(metrics, decimal_places=4)

# Index by the "label" column (the column itself is kept), then show only the rows
# whose label matches the rel=T ... mlp=T pattern.
labelled_metrics = df_metrics.set_index(df_metrics["label"])
labelled_metrics.filter(regex=".*rel=T.*mlp=T.*", axis=0)