In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
from tqdm import tqdm

from tabrel.train import train_relnet
from tabrel.benchmark.nw_regr import run_training, generate_toy_regr_data, make_r

def generate_multidim_noisy_data(n_samples: int, n_clusters: int, x_dim: int, seed: int) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Draw a synthetic regression dataset whose target is shifted by a cluster label.

    Args:
        n_samples: Number of rows to generate.
        n_clusters: Cluster labels are drawn uniformly from ``0 .. n_clusters - 1``.
        x_dim: Number of feature columns. Columns 0, 1 and 2 are read when the
            target is built, so fewer than three columns raises ``IndexError``.
        seed: Value passed to ``np.random.seed``; this reseeds NumPy's *global*
            generator as a side effect.

    Returns:
        ``(x, y, clusters)`` where ``x`` has shape ``(n_samples, x_dim)`` with
        entries drawn uniformly from [-1, 1), ``clusters`` holds one integer
        label per row, and ``y = sin(x0) + cos(x1) + x2 + cluster``.
        No separate noise term is added; the remaining columns of ``x`` simply
        do not influence ``y``.
    """
    np.random.seed(seed)
    features = np.random.uniform(low=-1, high=1, size=(n_samples, x_dim))
    cluster_ids = np.random.randint(0, n_clusters, size=n_samples)

    first, second, third = features[:, 0], features[:, 1], features[:, 2]
    target = np.sin(first) + np.cos(second) + third + cluster_ids
    return features, target, cluster_ids

def make_random_r(seed: int, clusters: np.ndarray) -> np.ndarray:
    """Build a random symmetric 0/1 relation matrix that only links same-cluster samples.

    The diagonal is always 1. For every unordered pair of distinct samples that
    share a cluster label, a fair draw from ``np.random.choice((True, False))``
    decides whether the pair is linked; pairs from different clusters are never
    linked and consume no random draw.

    Args:
        seed: Value passed to ``np.random.seed`` (reseeds NumPy's global generator).
        clusters: One cluster label per sample.

    Returns:
        A float array of shape ``(len(clusters), len(clusters))``.
    """
    np.random.seed(seed)
    size = len(clusters)
    relations = np.eye(size)
    for row in range(size):
        row_cluster = clusters[row]
        for col in range(row):
            if clusters[col] != row_cluster:
                continue  # different clusters: never related, no draw consumed
            if np.random.choice((True, False)):
                relations[row, col] = relations[col, row] = 1
    return relations

# Number of synthetic samples generated per run; the benchmark cells split it into thirds.
n_samples = 300
# NOTE(review): not referenced anywhere in the visible cells — confirm it is still needed.
n_ones = 70

In [None]:
# Benchmark `run_training` on each selected target function and report the
# mean / std of MSE and R2 over the seeds.
for y_func in (
    # "square",
    # "sign",
    "noisy",
):
    r2_list = []
    mse_list = []
    for seed in tqdm(range(1), desc=y_func):
        if y_func == "noisy":
            x, y, c = generate_multidim_noisy_data(n_samples=n_samples, n_clusters=3, x_dim=6, seed=seed)
        else:
            x, y, c = generate_toy_regr_data(n_samples=n_samples, n_clusters=3, y_func=y_func, seed=seed)
            x = x.numpy()
            y = y.numpy()

        # Contiguous, disjoint split: [validation | query | background].
        # The query slice has to end at n_validate + n_query: since
        # n_query == n_validate, slicing [n_validate:n_query] is empty and the
        # background slice would swallow the rows intended for the query set.
        n_query = n_validate = n_samples // 3
        samples_indices = np.arange(n_samples)
        validate_indices = samples_indices[:n_validate]
        query_indices = samples_indices[n_validate : n_validate + n_query]
        back_indices = samples_indices[n_validate + n_query :]

        r = make_random_r(seed, c)
        # NOTE(review): `rel_as_feats` receives the relation matrix itself —
        # confirm that `run_training` expects a matrix here rather than a flag.
        results = run_training(
            x=x,
            y=y,
            r=r,
            backgnd_indices=back_indices,
            query_indices=query_indices,
            val_indices=validate_indices,
            lr=3e-3,
            n_epochs=1500,
            rel_as_feats=r,
        )

        # Each entry of `results` is indexed as (mse, r2, ...) per model label.
        mse_list.append({k: v[0] for k, v in results.items()})
        r2_list.append({k: v[1] for k, v in results.items()})

    print(f"=== {y_func} ===")
    for label, lst in (("mse", mse_list), ("r2", r2_list)):
        results_df = pd.DataFrame(lst)
        print("\n")
        print(label)
        print(pd.DataFrame({"mean": results_df.mean().round(3), "std": results_df.std().round(3)}))

Well, I give up on finding the errors in my initial code, so I will just rewrite everything in a single cell. Of course, I'll copy some code from my vectorized NW implementation.

In [None]:
from collections import defaultdict

import lightgbm as lgb
import torch
import torch.nn as nn
from sklearn.metrics import r2_score, mean_squared_error
from tabrel.benchmark.nw_regr import RelNwRegr, NwModelConfig




# Benchmark over seeds: train_relnet on the raw features, plus RelNwRegr with and
# without a trainable weights matrix, on both the raw features ("xInit") and the
# features extended with the relation-matrix rows ("xExtended").
metrics = defaultdict(list)
n_epochs = 1000  # optimisation steps for every RelNwRegr model
for seed in tqdm(range(15)):
    # Alternative 1-D toy dataset (swap in for the block below):
    # n_feats = 1
    # x_initial, y, c = generate_toy_regr_data(
    #      n_samples=n_samples, n_clusters=3, seed=seed, y_func="sign",
    # )
    # r = make_r(clusters=c)
    # x_initial, y = x_initial.numpy(), y.numpy()

    n_feats = 7
    x_initial, y, c = generate_multidim_noisy_data(
        n_samples=n_samples, n_clusters=3, x_dim=n_feats, seed=seed
    )
    r = make_random_r(seed, c)

    # "xExtended" appends every sample's row of the relation matrix as extra features.
    x_extended = np.concatenate([x_initial, r], axis=1)

    # Contiguous, disjoint split: [background | query | validation].
    n_back = n_query = n_samples // 3
    n_train = n_back + n_query
    back_indices = np.arange(n_back)
    query_indices = np.arange(n_back, n_train)
    val_indices = np.arange(n_train, n_samples)

    # Targets and relations do not depend on the feature variant, so build them once.
    r_torch = torch.tensor(r, dtype=torch.float32)
    y_torch = torch.tensor(y, dtype=torch.float32)
    y_back, y_q = y_torch[:n_back], y_torch[n_back:n_train]
    y_train, y_val = y_torch[:n_train], y_torch[n_train:]

    # train_relnet is only run on the raw (un-normalised) features.
    # It must be given this cell's own `val_indices`: a `validate_indices`
    # variable left over from an earlier cell points at the background rows.
    results_relnet = train_relnet(
        x=x_initial,
        y=y,
        r=r,
        backgnd_indices=back_indices,
        query_indices=query_indices,
        val_indices=val_indices,
        lr=1e-3,
        num_heads=4,
        progress_bar=False,
        n_epochs=1000,
    )
    relnet_mse, relnet_r2 = results_relnet[:2]
    metrics["tabrel_mse"].append(relnet_mse)
    metrics["tabrel_r2"].append(relnet_r2)

    for x, x_label in ((x_initial, "xInit"), (x_extended, "xExtended")):
        # Standardise every column using statistics of the full sample set.
        x_mean = np.mean(x, axis=0, keepdims=True)
        x_std = np.std(x, axis=0, keepdims=True)
        x_norm = (x - x_mean) / x_std
        x_torch = torch.tensor(x_norm, dtype=torch.float32)

        x_back, x_q = x_torch[:n_back], x_torch[n_back:n_train]
        x_train, x_val = x_torch[:n_train], x_torch[n_train:]

        # Relations between query/background rows (training) and between
        # validation/train rows (evaluation).
        r_q_b = r_torch[n_back:n_train, :n_back]
        r_val_train = r_torch[n_train:, :n_train]
        if x_label == "xExtended":
            # The relations are already present as feature columns, so the
            # explicit relation input is blanked out for this variant.
            r_q_b = torch.zeros_like(r_q_b)
            r_val_train = torch.zeros_like(r_val_train)

        # lgb_params = {"objective": "regression", "metric": "rmse", "verbosity": -1}
        # train_data = lgb.Dataset(x_train.numpy(), label=y_train.numpy())
        # model_lgb = lgb.train(lgb_params, train_data)
        # y_pred = model_lgb.predict(x_val)
        # metrics[f"lgb_{x_label}_mse"].append(mean_squared_error(y_val, y_pred))
        # metrics[f"lgb_{x_label}_r2"].append(r2_score(y_val, y_pred))

        for learnable_norm in (True, False):
            # Seed before the model is built so that anything drawn at
            # construction time is reproducible as well.
            torch.manual_seed(seed)
            config = NwModelConfig(
                init_sigma=1.0,
                init_r_scale=1.0,
                input_dim=x.shape[1],  # n_feats for xInit, n_feats + n_samples for xExtended
                trainable_weights_matrix=learnable_norm,
            )
            model = RelNwRegr(config)
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
            loss_fn = nn.MSELoss()
            model.train()

            # Full-batch training: predict the query targets from the background set.
            for _ in range(n_epochs):
                optimizer.zero_grad()
                y_pred = model(x_back, y_back, x_q, r_q_b)
                loss = loss_fn(y_pred, y_q)
                loss.backward()
                optimizer.step()

            # Evaluation: background + query rows together act as the reference set.
            model.eval()
            with torch.no_grad():
                y_pred = model(
                    x_train,
                    y_train,
                    x_val,
                    r_val_train,
                )
                y_pred_np = y_pred.numpy()
                y_val_np = y_val.numpy()

                mse = mean_squared_error(y_val_np, y_pred_np)
                r2 = r2_score(y_val_np, y_pred_np)
                metrics[f"learnNorm{learnable_norm}_{x_label}_mse"].append(mse)
                metrics[f"learnNorm{learnable_norm}_{x_label}_r2"].append(r2)

for k, v in metrics.items():
    print(f"{k}:\tmean {np.mean(v):.4f}\tstd {np.std(v):.4f}")