In [None]:
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from tabrel.benchmark.nw_regr import train_nw, compute_relation_matrix, NwModelConfig, NwTrainConfig
from typing import Final

# Shared model configuration and the grid the fitted curves are evaluated on:
# 50 evenly spaced points in [-1, 1], reshaped to a column of shape (50, 1).
model_cfg = NwModelConfig()
x_grid = torch.linspace(start=-1, end=1, steps=50).unsqueeze(1)

# Train one regressor per seed (single cluster, `use_rel=False`) and record
# its prediction curve on the grid.
n_runs: Final[int] = 30
all_predictions = []
for seed in tqdm(range(n_runs), desc="Running predictions"):
    train_cfg = NwTrainConfig(n_clusters=1, use_rel=False, n_query=50, seed=seed)
    reg = train_nw(model_cfg=model_cfg, train_cfg=train_cfg)
    # All-zero relation matrix of shape (len(x_grid), n_backgnd).
    r = torch.zeros((len(x_grid), train_cfg.n_backgnd))
    # Pure inference: the grid predictions are only plotted, so no autograd
    # graph needs to be built for them.
    with torch.no_grad():
        y_grid = reg.model(reg.x_backgnd, reg.y_backgnd, x_grid, r)
    # Store each run as a 1-D curve so that the stacked tensor is always
    # (runs, grid points) and a per-run NaN mask over dim 1 applies directly.
    all_predictions.append(y_grid.detach().flatten())

# Stack the per-run curves along a new leading "run" axis.
all_predictions = torch.stack(all_predictions)
# Flag runs for which the reduction over dim 1 finds at least one NaN,
# and keep only the remaining runs.
nan_rows = all_predictions.isnan().any(dim=1)
valid_indices = ~nan_rows
all_predictions_clean = all_predictions[valid_indices]
# Mean and standard deviation across the retained runs, flattened for plotting.
y_mean = torch.mean(all_predictions_clean, dim=0).flatten()
y_std = torch.std(all_predictions_clean, dim=0).flatten()
x_grid_flat = torch.flatten(x_grid)

# Plot the across-run mean curve with a +/- one standard deviation band,
# overlaid with the query data of the most recently trained regressor (`reg`).
fig, ax = plt.subplots(figsize=(10, 6))
ax.fill_between(
    x_grid_flat,
    y_mean - y_std,
    y_mean + y_std,
    color="blue",
    alpha=0.2,
    # Raw f-string: "\p" is not a valid escape sequence in a normal string
    # literal, so the backslash of the mathtext "\pm" must be kept verbatim.
    label=rf"$\pm1$ std. dev. ({len(all_predictions_clean)} runs)",
)
ax.plot(x_grid_flat, y_mean, "--", color="blue", label="Mean prediction")
ax.scatter(reg.x_query.flatten(), reg.y_query_true, color="black", label="Example query data")

ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Vanilla NW Regression")
ax.legend()
ax.grid(True)

In [None]:
# Evaluate `reg`; when the cells are run in order this is the regressor trained
# in the last iteration of the seed loop above. The returned value is shown as
# the cell output.
reg.evaluate()

In [None]:
n_clusters: Final[int] = 3

# Repeat the multi-cluster experiment with and without the relation matrix.
# Each setting trains `n_runs` regressors (one per seed), evaluates them on
# `x_grid`, and produces one figure that is saved as a PDF.
# Relies on `model_cfg`, `x_grid`, `x_grid_flat` and `n_runs` from the cells above.
for use_rel in (True, False):
    # Per-cluster list of 1-D grid predictions, one entry per seed.
    cluster_predictions = {i_c: [] for i_c in range(n_clusters)}
    all_r2, all_mse = [], []
    for seed in tqdm(range(n_runs), desc=f"Use relation: {use_rel}"):
        train_cfg = NwTrainConfig(
            n_clusters=n_clusters, lr=0.05, n_backgnd=200, n_query=100, use_rel=use_rel, seed=seed,
            x_distr="uniform", y_func="sign",
        )
        reg = train_nw(model_cfg, train_cfg)

        # Track NaNs over *all* clusters of this run, not only the last one.
        run_has_nan = False
        for i_c in range(n_clusters):
            if use_rel:
                # Relation matrix between the background samples and a grid
                # whose points are all assigned to cluster `i_c`.
                r = compute_relation_matrix(
                    backgnd_clusters=reg.clusters_backgnd,
                    query_clusters=torch.Tensor([i_c]).expand(len(x_grid))
                )
            else:
                # All-zero relation matrix of shape (len(x_grid), n_backgnd).
                r = torch.zeros((len(x_grid), train_cfg.n_backgnd))

            y_grid = reg.model(reg.x_backgnd, reg.y_backgnd, x_grid, r)
            cluster_predictions[i_c].append(y_grid.detach().flatten())
            run_has_nan = run_has_nan or bool(y_grid.isnan().any())

        # Only runs whose grid predictions are NaN-free for every cluster
        # contribute to the aggregated metrics.
        if not run_has_nan:
            metrics = reg.evaluate()
            all_r2.append(metrics["r2"])
            all_mse.append(metrics["mse"])

    # Example query data for the scatter plot: taken from the last trained run.
    last_query_x = reg.x_query.flatten()
    last_query_y = reg.y_query_true

    fig, ax = plt.subplots(figsize=(5, 4))
    ax.grid(True)
    for i_c in range(n_clusters):
        preds = torch.stack(cluster_predictions[i_c])
        # Drop runs (rows) that contain a NaN anywhere on the grid.
        valid_indices = ~torch.isnan(preds).any(dim=1)
        preds_clean = preds[valid_indices]

        y_mean = preds_clean.mean(dim=0)
        y_std = preds_clean.std(dim=0)

        ax.fill_between(
            x_grid_flat,
            y_mean - y_std,
            y_mean + y_std,
            alpha=0.2
        )
        ax.plot(x_grid_flat, y_mean, "--", label=f"Cluster {i_c+1} mean" if use_rel else "mean")
        if not use_rel:
            break  # one plot is enough - others are just the same

    ax.scatter(last_query_x, last_query_y, color="black", label="Example query data")

    all_r2 = torch.tensor(all_r2)
    all_mse = torch.tensor(all_mse)

    # Raw f-strings keep the mathtext "\pm" backslash verbatim ("\p" is not a
    # valid escape sequence); the line break stays in a separate normal literal
    # so that it is still interpreted as a newline.
    ax.text(
        0.05, 0.03,
        (rf"$R^2 = {all_r2.mean():.2f} \pm {all_r2.std():.2f}$"
         "\n"
         rf"MSE $ = {all_mse.mean():.2f} \pm {all_mse.std():.2f}$"),
        transform=ax.transAxes,
        bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.3', alpha=0.8)
    )

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(f"NW Regression ({'relationship-aware' if use_rel else 'vanilla'})")
    ax.legend()
    fig.savefig(f"nw_regr_rel_{use_rel}.pdf")