In [None]:
%load_ext autoreload
%autoreload 2
from typing import Final

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from tabrel.train import train_relnet
from tabrel.benchmark.nw_regr import make_r, make_random_r, generate_toy_regr_data

n_samples: Final[int] = 300
n_clusters: Final[int] = 3
seed: Final[int] = 42
x, y, c = generate_toy_regr_data(
    n_samples=n_samples, n_clusters=n_clusters, seed=seed, distr="uniform", y_func="square"
)

# The generator returns objects exposing `.numpy()` (presumably torch tensors —
# TODO confirm); keep a NumPy copy of the cluster labels for the helpers below.
c_np = c.numpy()

# Explicit axes instead of the pyplot state machine, with labels so the figure
# stands on its own when the notebook is skimmed.
data_fig, data_ax = plt.subplots()
data_ax.scatter(x, y, c=c_np)
data_ax.set(xlabel="x", ylabel="y", title="Toy regression data (colour = cluster)")

# Two relation matrices built from the same cluster labels, compared in the runs below:
# one from `make_r`, one from `make_random_r` (seeded for reproducibility).
r_deterministic = make_r(c_np)
r_random = make_random_r(seed=seed, clusters=c_np)

In [None]:
# Contiguous three-way split of the sample indices: the first third is the
# background set, the second third the query set (background + query = training),
# and whatever remains is the validation set.
# NOTE(review): a contiguous split assumes `generate_toy_regr_data` does not
# emit samples grouped by cluster; if it does, shuffle first — TODO confirm.
n_back: Final[int] = n_samples // 3
n_query: Final[int] = n_samples // 3
n_train: Final[int] = n_back + n_query
n_val: Final[int] = n_samples - n_train

back_inds: Final[np.ndarray] = np.arange(0, n_back)
query_inds: Final[np.ndarray] = np.arange(n_back, n_train)
val_inds: Final[np.ndarray] = np.arange(n_train, n_samples)

# NumPy views of the features / targets, used by `train_relnet` and for plotting.
x_np = x.numpy()
y_np = y.numpy()

def train_relnet_shorthand(
    r: np.ndarray,
    *,
    lr: float = 1e-2,
    n_epochs: int = 2000,
    lr_decay: float = 0.9,
    lr_decay_step: int = 100,
    print_loss: bool = True,
    title: str = "",
) -> None:
    """Train a RelNet on this notebook's fixed split and plot validation predictions.

    Reads the notebook-level ``x_np``, ``y_np``, ``c_np`` and the
    ``back_inds`` / ``query_inds`` / ``val_inds`` split. Only the relation
    matrix and the optimisation settings vary between calls.

    Args:
        r: relation matrix forwarded to ``train_relnet`` (presumably built by
            ``make_r`` / ``make_random_r`` over all samples — TODO confirm shape).
        lr: learning rate forwarded to ``train_relnet``.
        n_epochs: number of epochs forwarded to ``train_relnet``.
        lr_decay: forwarded to ``train_relnet`` as ``lr_decay``.
        lr_decay_step: forwarded to ``train_relnet`` as ``lr_decay_step``.
        print_loss: forwarded to ``train_relnet``.
        title: figure title; when empty, a summary of the validation metrics is used.

    The defaults reproduce the settings this notebook originally hard-coded.
    """
    mse, r2, _, y_val_pred, y_val_true = train_relnet(
        x_np,
        y_np,
        r=r,
        backgnd_indices=back_inds,
        query_indices=query_inds,
        val_indices=val_inds,
        lr=lr,
        n_epochs=n_epochs,
        progress_bar=False,
        print_loss=print_loss,
        lr_decay=lr_decay,
        lr_decay_step=lr_decay_step,
    )
    print(f"mse: {mse:.3f}\tr2: {r2:.3f}")

    # A dedicated figure per call: the original drew onto pyplot's current axes,
    # so several calls in one cell were overlaid on a single plot.
    fig, ax = plt.subplots()
    x_val = x_np[val_inds]
    c_val = c_np[val_inds]
    ax.scatter(x_val, y_val_true, c=c_val, label="true")
    ax.scatter(x_val, y_val_pred, c=c_val, marker="*", label="predicted")
    ax.set(
        xlabel="x",
        ylabel="y",
        title=title or f"Validation: mse={mse:.3f}, r2={r2:.3f}",
    )
    ax.legend()
    # Show now so the printed metrics and their figure stay adjacent in the output.
    plt.show()

# Run 1: relation matrix built by `make_r` from the cluster labels.
train_relnet_shorthand(r=r_deterministic)

In [None]:
# Run 2: same training setup, but with the relation matrix from `make_random_r`.
train_relnet_shorthand(r=r_random)