In [None]:
import matplotlib.pyplot as plt
import torch

from tabrel.benchmark.nw_regr import train_nw, compute_relation_matrix, NwModelConfig, NwTrainConfig

# Baseline run: a single cluster and use_rel=False, i.e. no relation information is used.
train_cfg = NwTrainConfig(n_clusters=1, use_rel=False, n_test=50)
model_cfg = NwModelConfig()
reg = train_nw(model_cfg=model_cfg, train_cfg=train_cfg)

# Dense grid of query points used to draw the fitted curve; shape (50, 1).
# NOTE(review): assumes the inputs lie roughly in [-1, 1] — confirm against the data generator.
# x_grid is reused by the comparison cell further down.
x_grid = torch.linspace(start=-1, end=1, steps=50).unsqueeze(1)

# All-zero relation matrix: one row per query point, one column per training point.
r = torch.zeros((len(x_grid), train_cfg.n_train))

# Pure inference: no autograd graph is needed just to draw the curve.
with torch.no_grad():
    y_grid = reg.model(reg.x_train, reg.y_train, x_grid, r)

fig, ax = plt.subplots()
ax.scatter(reg.x_test.flatten(), reg.y_test_true, color="black", label="test data")
ax.plot(x_grid.flatten(), y_grid.detach(), "--", label="model prediction")
ax.set(title="Baseline (no relations)", xlabel="x", ylabel="y")
# Trailing semicolon suppresses the Legend repr in the cell output.
ax.legend();

In [None]:
# Evaluation metrics of the model trained in the cell above; the returned value
# (the comparison cell indexes it with 'r2' and 'mae') is shown as this cell's output.
reg.evaluate()

In [None]:
from typing import Final

# Compare the model trained with and without relation information on a
# multi-cluster problem; one figure per setting.
N_CLUSTERS: Final[int] = 3

for use_rel in (True, False):
    train_cfg = NwTrainConfig(n_clusters=N_CLUSTERS, lr=0.1, n_train=100, n_test=50, use_rel=use_rel)
    reg = train_nw(model_cfg, train_cfg)

    fig, ax = plt.subplots()
    ax.scatter(reg.x_test.flatten(), reg.y_test_true, color="black", label="test data")

    # Relation matrices to draw curves for; every matrix has one row per *query*
    # point (len(x_grid)) and one column per training point.
    if use_rel:
        # One curve per cluster: every grid point is assigned cluster i_c.
        # NOTE(review): presumably compute_relation_matrix relates each query to the
        # training points of the same cluster — confirm in tabrel.benchmark.nw_regr.
        relation_mats = {
            f"cluster {i_c}": compute_relation_matrix(
                train_clusters=reg.clusters_train,
                test_clusters=torch.Tensor([i_c]).expand(len(x_grid)),
            )
            for i_c in range(N_CLUSTERS)
        }
    else:
        # Without relations the input is identical for every cluster, so a single
        # curve suffices. Size by the grid, not by n_test (they only coincide by accident).
        relation_mats = {"no relations": torch.zeros((len(x_grid), train_cfg.n_train))}

    for label, r in relation_mats.items():
        # Pure inference: no autograd graph is needed just to draw the curve.
        with torch.no_grad():
            y_grid = reg.model(reg.x_train, reg.y_train, x_grid, r)
        ax.plot(x_grid.flatten(), y_grid.detach(), "--", label=label)

    metrics = reg.evaluate()
    ax.set(
        xlabel="x",
        ylabel="y",
        title=f"Rel: {use_rel}, $R^2\\approx{metrics['r2']:.3f}$, MAE$\\approx{metrics['mae']:.3f}$",
    )
    ax.legend()

